AnatoliiG committed on
Commit
a8a31d7
·
1 Parent(s): 862ded5

Fix operation in agent API mode

Browse files
src/api/routes.py CHANGED
@@ -1,6 +1,6 @@
1
  import asyncio
2
  import json
3
- import threading
4
 
5
  from fastapi import APIRouter, HTTPException, Request
6
  from fastapi.responses import StreamingResponse
@@ -9,6 +9,9 @@ from src.core.config import settings
9
  from src.core.engine import engine
10
  from src.utils.helpers import get_clean_text
11
 
 
 
 
12
  router = APIRouter()
13
 
14
 
@@ -17,13 +20,33 @@ async def chat_completions(request: Request):
17
  if not engine.llm:
18
  raise HTTPException(status_code=500, detail="Model not loaded")
19
 
20
- data = await request.json()
 
 
 
 
 
 
 
 
21
  messages = [
22
  {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
23
  for m in data.get("messages", [])
24
  ]
 
25
  max_tokens = data.get("max_tokens", settings.DEFAULT_MAX_TOKENS)
26
  temperature = data.get("temperature", settings.DEFAULT_TEMP)
 
 
 
 
 
 
 
 
 
 
 
27
  stream_req = data.get("stream", True)
28
 
29
  # --- Логика Streaming ---
@@ -33,28 +56,49 @@ async def chat_completions(request: Request):
33
 
34
  def worker():
35
  try:
36
- for chunk in engine.generate_stream(messages, max_tokens, temperature):
 
 
 
 
 
 
 
37
  loop.call_soon_threadsafe(queue.put_nowait, chunk)
38
- loop.call_soon_threadsafe(queue.put_nowait, None)
39
  except Exception as e:
 
40
  loop.call_soon_threadsafe(queue.put_nowait, {"error": str(e)})
41
 
42
  loop.run_in_executor(None, worker)
43
 
44
  while True:
45
  chunk = await queue.get()
 
46
  if chunk is None:
47
  yield "data: [DONE]\n\n"
48
  break
49
 
50
  if isinstance(chunk, dict) and "error" in chunk:
51
- yield f"data: {json.dumps({'error': chunk['error']})}\n\n"
 
 
 
52
  break
53
 
 
54
  yield f"data: {json.dumps(chunk)}\n\n"
55
 
56
  if stream_req:
57
- return StreamingResponse(stream_generator(), media_type="text/event-stream")
 
 
 
 
 
 
 
 
58
 
59
  else:
60
 
@@ -64,6 +108,8 @@ async def chat_completions(request: Request):
64
  messages=messages,
65
  max_tokens=int(max_tokens),
66
  temperature=float(temperature),
 
 
67
  stream=False,
68
  )
69
 
 
1
  import asyncio
2
  import json
3
+ import logging
4
 
5
  from fastapi import APIRouter, HTTPException, Request
6
  from fastapi.responses import StreamingResponse
 
9
  from src.core.engine import engine
10
  from src.utils.helpers import get_clean_text
11
 
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
  router = APIRouter()
16
 
17
 
 
20
  if not engine.llm:
21
  raise HTTPException(status_code=500, detail="Model not loaded")
22
 
23
+ try:
24
+ data = await request.json()
25
+ except Exception:
26
+ raise HTTPException(status_code=400, detail="Invalid JSON")
27
+
28
+ logger.info(
29
+ f"API Request received. Model: {data.get('model')}, Stream: {data.get('stream', True)}"
30
+ )
31
+
32
  messages = [
33
  {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
34
  for m in data.get("messages", [])
35
  ]
36
+
37
  max_tokens = data.get("max_tokens", settings.DEFAULT_MAX_TOKENS)
38
  temperature = data.get("temperature", settings.DEFAULT_TEMP)
39
+ top_p = data.get("top_p", 0.95)
40
+
41
+ stop = data.get("stop", [])
42
+ if isinstance(stop, str):
43
+ stop = [stop]
44
+
45
+ default_stops = ["<|im_end|>", "<|endoftext|>"]
46
+ for s in default_stops:
47
+ if s not in stop:
48
+ stop.append(s)
49
+
50
  stream_req = data.get("stream", True)
51
 
52
  # --- Логика Streaming ---
 
56
 
57
  def worker():
58
  try:
59
+ gen_kwargs = {
60
+ "max_tokens": int(max_tokens),
61
+ "temperature": float(temperature),
62
+ "top_p": float(top_p),
63
+ "stop": stop,
64
+ }
65
+
66
+ for chunk in engine.generate_stream(messages, **gen_kwargs):
67
  loop.call_soon_threadsafe(queue.put_nowait, chunk)
68
+ loop.call_soon_threadsafe(queue.put_nowait, None) # Конец
69
  except Exception as e:
70
+ logger.error(f"Generation error: {e}")
71
  loop.call_soon_threadsafe(queue.put_nowait, {"error": str(e)})
72
 
73
  loop.run_in_executor(None, worker)
74
 
75
  while True:
76
  chunk = await queue.get()
77
+
78
  if chunk is None:
79
  yield "data: [DONE]\n\n"
80
  break
81
 
82
  if isinstance(chunk, dict) and "error" in chunk:
83
+ err_json = json.dumps(
84
+ {"error": {"message": chunk["error"], "type": "internal_error"}}
85
+ )
86
+ yield f"data: {err_json}\n\n"
87
  break
88
 
89
+ # Стандартный чанк
90
  yield f"data: {json.dumps(chunk)}\n\n"
91
 
92
  if stream_req:
93
+ headers = {
94
+ "X-Accel-Buffering": "no",
95
+ "Cache-Control": "no-cache",
96
+ "Connection": "keep-alive",
97
+ "Content-Type": "text/event-stream",
98
+ }
99
+ return StreamingResponse(
100
+ stream_generator(), media_type="text/event-stream", headers=headers
101
+ )
102
 
103
  else:
104
 
 
108
  messages=messages,
109
  max_tokens=int(max_tokens),
110
  temperature=float(temperature),
111
+ top_p=float(top_p),
112
+ stop=stop,
113
  stream=False,
114
  )
115
 
src/core/engine.py CHANGED
@@ -24,24 +24,29 @@ class ModelEngine:
24
  n_ctx=settings.CONTEXT_SIZE,
25
  n_threads=settings.N_THREADS,
26
  n_gpu_layers=settings.N_GPU_LAYERS,
27
- verbose=True,
28
  )
29
  print("Model loaded successfully!")
30
  except Exception as e:
31
  print(f"CRITICAL ERROR loading model: {e}")
32
 
33
- def generate_stream(
34
- self, messages: List[Dict[str, str]], max_tokens: int, temperature: float
35
- ) -> Generator:
36
  if not self.llm:
37
  raise RuntimeError("Model not loaded")
38
 
 
 
 
 
39
  with self.lock:
40
  stream = self.llm.create_chat_completion(
41
  messages=messages,
42
  max_tokens=int(max_tokens),
43
  temperature=float(temperature),
 
44
  stream=True,
 
45
  )
46
  for chunk in stream:
47
  yield chunk
 
24
  n_ctx=settings.CONTEXT_SIZE,
25
  n_threads=settings.N_THREADS,
26
  n_gpu_layers=settings.N_GPU_LAYERS,
27
+ verbose=False,
28
  )
29
  print("Model loaded successfully!")
30
  except Exception as e:
31
  print(f"CRITICAL ERROR loading model: {e}")
32
 
33
+ # Изменили сигнатуру: теперь принимает **kwargs
34
+ def generate_stream(self, messages: List[Dict[str, str]], **kwargs) -> Generator:
 
35
  if not self.llm:
36
  raise RuntimeError("Model not loaded")
37
 
38
+ max_tokens = kwargs.get("max_tokens", settings.DEFAULT_MAX_TOKENS)
39
+ temperature = kwargs.get("temperature", settings.DEFAULT_TEMP)
40
+ stop = kwargs.get("stop", [])
41
+
42
  with self.lock:
43
  stream = self.llm.create_chat_completion(
44
  messages=messages,
45
  max_tokens=int(max_tokens),
46
  temperature=float(temperature),
47
+ stop=stop,
48
  stream=True,
49
+ top_p=kwargs.get("top_p", 0.95),
50
  )
51
  for chunk in stream:
52
  yield chunk
src/ui/callbacks.py CHANGED
@@ -40,7 +40,9 @@ def bot_response(history, system_prompt, temperature, max_tokens):
40
  history.append({"role": "assistant", "content": ""})
41
 
42
  try:
43
- stream = engine.generate_stream(messages, max_tokens, temperature)
 
 
44
 
45
  partial_text = ""
46
  for chunk in stream:
 
40
  history.append({"role": "assistant", "content": ""})
41
 
42
  try:
43
+ stream = engine.generate_stream(
44
+ messages=messages, max_tokens=max_tokens, temperature=temperature
45
+ )
46
 
47
  partial_text = ""
48
  for chunk in stream:
src/ui/components.py CHANGED
@@ -24,7 +24,7 @@ def create_ui():
24
  temp = gr.Slider(0, 1, value=settings.DEFAULT_TEMP, label="Temperature")
25
  tokens = gr.Slider(
26
  512,
27
- settings.DEFAULT_MAX_TOKENS,
28
  value=settings.DEFAULT_MAX_TOKENS,
29
  label="Max New Tokens",
30
  step=128,
 
24
  temp = gr.Slider(0, 1, value=settings.DEFAULT_TEMP, label="Temperature")
25
  tokens = gr.Slider(
26
  512,
27
+ settings.CONTEXT_SIZE,
28
  value=settings.DEFAULT_MAX_TOKENS,
29
  label="Max New Tokens",
30
  step=128,