AnatoliiG committed on
Commit
010db11
·
1 Parent(s): a8a31d7

cancel process

Browse files
Files changed (2) hide show
  1. src/api/routes.py +51 -53
  2. src/core/engine.py +29 -5
src/api/routes.py CHANGED
@@ -1,6 +1,7 @@
1
  import asyncio
2
  import json
3
  import logging
 
4
 
5
  from fastapi import APIRouter, HTTPException, Request
6
  from fastapi.responses import StreamingResponse
@@ -25,10 +26,6 @@ async def chat_completions(request: Request):
25
  except Exception:
26
  raise HTTPException(status_code=400, detail="Invalid JSON")
27
 
28
- logger.info(
29
- f"API Request received. Model: {data.get('model')}, Stream: {data.get('stream', True)}"
30
- )
31
-
32
  messages = [
33
  {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
34
  for m in data.get("messages", [])
@@ -37,19 +34,17 @@ async def chat_completions(request: Request):
37
  max_tokens = data.get("max_tokens", settings.DEFAULT_MAX_TOKENS)
38
  temperature = data.get("temperature", settings.DEFAULT_TEMP)
39
  top_p = data.get("top_p", 0.95)
40
-
41
  stop = data.get("stop", [])
42
  if isinstance(stop, str):
43
  stop = [stop]
44
 
45
- default_stops = ["<|im_end|>", "<|endoftext|>"]
46
  for s in default_stops:
47
  if s not in stop:
48
  stop.append(s)
49
 
50
- stream_req = data.get("stream", True)
51
 
52
- # --- Логика Streaming ---
53
  async def stream_generator():
54
  queue = asyncio.Queue()
55
  loop = asyncio.get_running_loop()
@@ -61,57 +56,60 @@ async def chat_completions(request: Request):
61
  "temperature": float(temperature),
62
  "top_p": float(top_p),
63
  "stop": stop,
 
64
  }
65
 
 
66
  for chunk in engine.generate_stream(messages, **gen_kwargs):
67
  loop.call_soon_threadsafe(queue.put_nowait, chunk)
68
- loop.call_soon_threadsafe(queue.put_nowait, None) # Конец
 
69
  except Exception as e:
70
- logger.error(f"Generation error: {e}")
 
71
  loop.call_soon_threadsafe(queue.put_nowait, {"error": str(e)})
72
 
73
  loop.run_in_executor(None, worker)
74
 
75
- while True:
76
- chunk = await queue.get()
77
-
78
- if chunk is None:
79
- yield "data: [DONE]\n\n"
80
- break
81
-
82
- if isinstance(chunk, dict) and "error" in chunk:
83
- err_json = json.dumps(
84
- {"error": {"message": chunk["error"], "type": "internal_error"}}
85
- )
86
- yield f"data: {err_json}\n\n"
87
- break
88
-
89
- # Стандартный чанк
90
- yield f"data: {json.dumps(chunk)}\n\n"
91
-
92
- if stream_req:
93
- headers = {
94
- "X-Accel-Buffering": "no",
95
- "Cache-Control": "no-cache",
96
- "Connection": "keep-alive",
97
- "Content-Type": "text/event-stream",
98
- }
99
- return StreamingResponse(
100
- stream_generator(), media_type="text/event-stream", headers=headers
101
- )
102
-
103
- else:
104
-
105
- def run_sync():
106
- with engine.lock:
107
- return engine.llm.create_chat_completion(
108
- messages=messages,
109
- max_tokens=int(max_tokens),
110
- temperature=float(temperature),
111
- top_p=float(top_p),
112
- stop=stop,
113
- stream=False,
114
- )
115
-
116
- response = await asyncio.to_thread(run_sync)
117
- return response
 
1
  import asyncio
2
  import json
3
  import logging
4
+ import threading
5
 
6
  from fastapi import APIRouter, HTTPException, Request
7
  from fastapi.responses import StreamingResponse
 
26
  except Exception:
27
  raise HTTPException(status_code=400, detail="Invalid JSON")
28
 
 
 
 
 
29
  messages = [
30
  {"role": m.get("role", "user"), "content": get_clean_text(m.get("content"))}
31
  for m in data.get("messages", [])
 
34
  max_tokens = data.get("max_tokens", settings.DEFAULT_MAX_TOKENS)
35
  temperature = data.get("temperature", settings.DEFAULT_TEMP)
36
  top_p = data.get("top_p", 0.95)
 
37
  stop = data.get("stop", [])
38
  if isinstance(stop, str):
39
  stop = [stop]
40
 
41
+ default_stops = ["<|im_end|>", "<|endoftext|>", "<|file_sep|>"]
42
  for s in default_stops:
43
  if s not in stop:
44
  stop.append(s)
45
 
46
+ abort_event = threading.Event()
47
 
 
48
  async def stream_generator():
49
  queue = asyncio.Queue()
50
  loop = asyncio.get_running_loop()
 
56
  "temperature": float(temperature),
57
  "top_p": float(top_p),
58
  "stop": stop,
59
+ "abort_event": abort_event,
60
  }
61
 
62
+ # Запускаем генерацию
63
  for chunk in engine.generate_stream(messages, **gen_kwargs):
64
  loop.call_soon_threadsafe(queue.put_nowait, chunk)
65
+
66
+ loop.call_soon_threadsafe(queue.put_nowait, None)
67
  except Exception as e:
68
+ if not abort_event.is_set():
69
+ logger.error(f"Generation error: {e}")
70
  loop.call_soon_threadsafe(queue.put_nowait, {"error": str(e)})
71
 
72
  loop.run_in_executor(None, worker)
73
 
74
+ try:
75
+ while True:
76
+ if await request.is_disconnected():
77
+ logger.info("Client disconnected! Aborting generation...")
78
+ abort_event.set()
79
+ break
80
+
81
+ try:
82
+ chunk = await asyncio.wait_for(queue.get(), timeout=0.1)
83
+ except asyncio.TimeoutError:
84
+ continue
85
+
86
+ if chunk is None:
87
+ yield "data: [DONE]\n\n"
88
+ break
89
+
90
+ if isinstance(chunk, dict) and "error" in chunk:
91
+ if abort_event.is_set():
92
+ break
93
+ err_json = json.dumps(
94
+ {"error": {"message": chunk["error"], "type": "internal_error"}}
95
+ )
96
+ yield f"data: {err_json}\n\n"
97
+ break
98
+
99
+ yield f"data: {json.dumps(chunk)}\n\n"
100
+
101
+ except asyncio.CancelledError:
102
+ logger.info("Task cancelled. Stopping worker.")
103
+ abort_event.set()
104
+ raise
105
+
106
+ # Возвращаем стрим
107
+ headers = {
108
+ "X-Accel-Buffering": "no",
109
+ "Cache-Control": "no-cache",
110
+ "Connection": "keep-alive",
111
+ "Content-Type": "text/event-stream",
112
+ }
113
+ return StreamingResponse(
114
+ stream_generator(), media_type="text/event-stream", headers=headers
115
+ )
 
src/core/engine.py CHANGED
@@ -1,5 +1,6 @@
1
  import threading
2
- from typing import Any, Dict, Generator, List
 
3
 
4
  from huggingface_hub import hf_hub_download
5
  from llama_cpp import Llama
@@ -24,14 +25,18 @@ class ModelEngine:
24
  n_ctx=settings.CONTEXT_SIZE,
25
  n_threads=settings.N_THREADS,
26
  n_gpu_layers=settings.N_GPU_LAYERS,
27
- verbose=False,
28
  )
29
  print("Model loaded successfully!")
30
  except Exception as e:
31
  print(f"CRITICAL ERROR loading model: {e}")
32
 
33
- # Изменили сигнатуру: теперь принимает **kwargs
34
- def generate_stream(self, messages: List[Dict[str, str]], **kwargs) -> Generator:
 
 
 
 
35
  if not self.llm:
36
  raise RuntimeError("Model not loaded")
37
 
@@ -39,7 +44,18 @@ class ModelEngine:
39
  temperature = kwargs.get("temperature", settings.DEFAULT_TEMP)
40
  stop = kwargs.get("stop", [])
41
 
42
- with self.lock:
 
 
 
 
 
 
 
 
 
 
 
43
  stream = self.llm.create_chat_completion(
44
  messages=messages,
45
  max_tokens=int(max_tokens),
@@ -48,8 +64,16 @@ class ModelEngine:
48
  stream=True,
49
  top_p=kwargs.get("top_p", 0.95),
50
  )
 
51
  for chunk in stream:
 
 
 
 
52
  yield chunk
53
 
 
 
 
54
 
55
  engine = ModelEngine()
 
1
  import threading
2
+ import time
3
+ from typing import Any, Dict, Generator, List, Optional
4
 
5
  from huggingface_hub import hf_hub_download
6
  from llama_cpp import Llama
 
25
  n_ctx=settings.CONTEXT_SIZE,
26
  n_threads=settings.N_THREADS,
27
  n_gpu_layers=settings.N_GPU_LAYERS,
28
+ verbose=True,
29
  )
30
  print("Model loaded successfully!")
31
  except Exception as e:
32
  print(f"CRITICAL ERROR loading model: {e}")
33
 
34
+ def generate_stream(
35
+ self,
36
+ messages: List[Dict[str, str]],
37
+ abort_event: Optional[threading.Event] = None, # Новый аргумент
38
+ **kwargs,
39
+ ) -> Generator:
40
  if not self.llm:
41
  raise RuntimeError("Model not loaded")
42
 
 
44
  temperature = kwargs.get("temperature", settings.DEFAULT_TEMP)
45
  stop = kwargs.get("stop", [])
46
 
47
+ acquired = False
48
+ while not acquired:
49
+ if abort_event and abort_event.is_set():
50
+ print("Request aborted while waiting in queue.")
51
+ return
52
+
53
+ acquired = self.lock.acquire(timeout=0.5)
54
+
55
+ try:
56
+ if abort_event and abort_event.is_set():
57
+ return
58
+
59
  stream = self.llm.create_chat_completion(
60
  messages=messages,
61
  max_tokens=int(max_tokens),
 
64
  stream=True,
65
  top_p=kwargs.get("top_p", 0.95),
66
  )
67
+
68
  for chunk in stream:
69
+ if abort_event and abort_event.is_set():
70
+ print("Request aborted during generation.")
71
+ break
72
+
73
  yield chunk
74
 
75
+ finally:
76
+ self.lock.release()
77
+
78
 
79
  engine = ModelEngine()