arthu1 commited on
Commit
27f81ac
·
1 Parent(s): 264ac43

Revert to Docker/FastAPI + add dynamic INT8 quantization for faster CPU inference

Browse files
Files changed (4) hide show
  1. Dockerfile +15 -0
  2. README.md +15 -18
  3. app.py +259 -167
  4. requirements.txt +3 -0
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Minimal CPU inference image for the North Air 1 FastAPI service.
FROM python:3.11-slim

WORKDIR /app

# Unbuffered stdout so container logs appear immediately; 7860 is the
# HuggingFace Spaces default app port.
ENV PYTHONUNBUFFERED=1 \
    PORT=7860

# Install dependencies before copying sources so code edits don't
# invalidate the pip cache layer.
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY . /app

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -3,33 +3,30 @@ title: North Air API
3
  emoji: 🌬️
4
  colorFrom: blue
5
  colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.12.0
8
- app_file: app.py
9
- hardware: zero-a10g
10
  ---
11
 
12
- # North Air 1 — ZeroGPU API
13
 
14
- GPU-accelerated inference via HuggingFace ZeroGPU (free A100 time-slices).
15
 
16
- ## API Endpoints
17
-
18
- - `POST /api/chat` — non-streaming chat
19
- - `POST /api/chat_stream` — streaming chat (newline-delimited JSON events)
20
-
21
- ### Request format
22
 
 
23
  ```json
24
  {
25
- "data": ["{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}"]
 
 
 
26
  }
27
  ```
28
 
29
- ### Response format
30
-
31
  ```json
32
- {
33
- "data": ["{\"output\":\"Hey! I'm North Air 1...\",\"model\":\"north-air-1\",\"tokens_generated\":42,\"latency_ms\":1200}"]
34
- }
35
  ```
 
3
  emoji: 🌬️
4
  colorFrom: blue
5
  colorTo: green
6
+ sdk: docker
7
+ app_port: 7860
 
 
8
  ---
9
 
10
+ # North Air 1 API
11
 
12
+ Optimized CPU inference with dynamic INT8 quantization.
13
 
14
+ Endpoints:
15
+ - `GET /health`
16
+ - `POST /chat`
17
+ - `POST /chat/stream`
 
 
18
 
19
+ Request shape (`/chat`):
20
  ```json
21
  {
22
+ "model": "north-air-1",
23
+ "messages": [
24
+ {"role": "user", "content": "Hello"}
25
+ ]
26
  }
27
  ```
28
 
29
+ Response shape:
 
30
  ```json
31
+ {"output": "...", "inference": "pytorch-int8"}
 
 
32
  ```
app.py CHANGED
@@ -1,14 +1,16 @@
1
  import os
2
  import re
3
- import json
4
  import time
 
 
5
  from threading import Thread
6
 
7
  import torch
8
- import gradio as gr
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
10
 
11
- # ─── Config ───
12
  MODEL_DIR = os.getenv("MODEL_DIR", "./final_model")
13
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
14
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.6"))
@@ -18,39 +20,144 @@ SYSTEM_PROMPT = """You are North Air 1, built by North Air. 0.6B params, a custo
18
  Be direct, helpful, concise. Use markdown. Write clean code. Never fabricate facts.
19
  If asked who you are: "I'm North Air 1, built by North Air." You are NOT ChatGPT/GPT-4/Claude/etc."""
20
 
21
- # ─── Load model on CPU first, ZeroGPU moves it to GPU per-call ───
22
- def _load_model():
23
- adapter_cfg = os.path.join(MODEL_DIR, "adapter_config.json")
24
- if os.path.exists(adapter_cfg):
25
- from peft import AutoPeftModelForCausalLM
26
- model = AutoPeftModelForCausalLM.from_pretrained(
27
- MODEL_DIR, torch_dtype=torch.float16, device_map={"": "cpu"},
28
- )
29
- tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
30
- else:
31
- model = AutoModelForCausalLM.from_pretrained(
32
- MODEL_DIR, torch_dtype=torch.float16, device_map={"": "cpu"},
33
- trust_remote_code=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  )
35
- tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, trust_remote_code=True)
 
 
 
 
 
36
 
37
- if tokenizer.pad_token is None:
38
- tokenizer.pad_token = tokenizer.eos_token
39
- model.eval()
40
- return model, tokenizer
41
 
42
- try:
43
- MODEL, TOKENIZER = _load_model()
44
- LOAD_ERROR = None
45
- except Exception as exc:
46
- MODEL, TOKENIZER = None, None
47
- LOAD_ERROR = str(exc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
- def _build_prompt(messages, system, enable_thinking=False):
51
  has_system = any(m["role"] == "system" for m in messages)
52
  if not has_system:
53
  messages = [{"role": "system", "content": system}] + messages
 
54
  if hasattr(TOKENIZER, "apply_chat_template"):
55
  return TOKENIZER.apply_chat_template(
56
  messages, tokenize=False, add_generation_prompt=True,
@@ -59,178 +166,163 @@ def _build_prompt(messages, system, enable_thinking=False):
59
  return "\n".join(f"{m['role']}: {m['content']}" for m in messages) + "\nassistant:"
60
 
61
 
62
- # ─── GPU inference via ZeroGPU ───
63
- @gr.api(api_name="chat")
64
- @gr.GPU
65
- def chat_api(request_json: str) -> str:
66
- """Non-streaming chat. Called via Gradio API."""
67
- if MODEL is None:
68
- return json.dumps({"error": f"Model failed to load: {LOAD_ERROR}"})
69
 
70
- body = json.loads(request_json)
71
- messages = body.get("messages", [])
72
- if not messages:
73
- return json.dumps({"error": "messages are required"})
74
 
75
- system = body.get("system_prompt", SYSTEM_PROMPT)
76
- max_tokens = body.get("max_new_tokens", MAX_NEW_TOKENS)
77
- temperature = body.get("temperature", TEMPERATURE)
78
- top_p = body.get("top_p", TOP_P)
79
- enable_thinking = body.get("enable_thinking", False)
 
 
 
 
 
 
 
 
 
 
80
 
81
- msg_dicts = [{"role": m["role"], "content": m["content"]} for m in messages]
82
- prompt = _build_prompt(msg_dicts, system, enable_thinking)
 
 
 
 
 
 
 
 
 
 
 
83
  batch = TOKENIZER(prompt, return_tensors="pt", add_special_tokens=False)
84
 
85
- # Move to GPU (ZeroGPU provides it)
86
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
- MODEL.to(device)
88
- input_ids = batch["input_ids"].to(device)
89
- attention_mask = batch["attention_mask"].to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  t0 = time.time()
 
92
  with torch.no_grad():
93
  out = MODEL.generate(
94
- input_ids=input_ids,
95
- attention_mask=attention_mask,
96
- max_new_tokens=max_tokens,
97
- temperature=max(temperature, 0.01),
98
- top_p=top_p,
99
- top_k=40,
100
- do_sample=True,
101
- repetition_penalty=1.2,
102
- pad_token_id=TOKENIZER.pad_token_id,
103
- eos_token_id=TOKENIZER.eos_token_id,
104
  )
105
- elapsed = time.time() - t0
106
 
 
107
  generated_ids = out[0][input_ids.shape[1]:]
108
  completion = TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()
 
109
 
110
- # Parse thinking tags
111
- think_match = re.search(r"<think>(.*?)</think>", completion, re.DOTALL)
112
- thinking = ""
113
- answer = completion
114
- if think_match:
115
- thinking = think_match.group(1).strip()
116
- answer = re.sub(r"<think>.*?</think>", "", completion, flags=re.DOTALL).strip()
117
-
118
- return json.dumps({
119
  "output": answer,
120
  "thinking": thinking if thinking else None,
121
  "model": "north-air-1",
 
122
  "tokens_generated": len(generated_ids),
123
  "latency_ms": round(elapsed * 1000),
124
- })
125
-
126
-
127
- @gr.api(api_name="chat_stream")
128
- @gr.GPU
129
- def chat_stream_api(request_json: str) -> str:
130
- """Streaming chat. Returns all tokens as newline-delimited JSON events."""
131
- if MODEL is None:
132
- return json.dumps({"error": f"Model failed to load: {LOAD_ERROR}"})
133
 
134
- body = json.loads(request_json)
135
- messages = body.get("messages", [])
136
- if not messages:
137
- return json.dumps({"error": "messages are required"})
138
 
139
- system = body.get("system_prompt", SYSTEM_PROMPT)
140
- max_tokens = body.get("max_new_tokens", MAX_NEW_TOKENS)
141
- temperature = body.get("temperature", TEMPERATURE)
142
- top_p = body.get("top_p", TOP_P)
143
- enable_thinking = body.get("enable_thinking", False)
144
 
145
- msg_dicts = [{"role": m["role"], "content": m["content"]} for m in messages]
146
- prompt = _build_prompt(msg_dicts, system, enable_thinking)
147
- batch = TOKENIZER(prompt, return_tensors="pt", add_special_tokens=False)
148
 
149
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
150
- MODEL.to(device)
151
- input_ids = batch["input_ids"].to(device)
152
- attention_mask = batch["attention_mask"].to(device)
153
 
154
  streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
155
 
156
- gen_kwargs = {
157
- "input_ids": input_ids,
158
- "attention_mask": attention_mask,
159
- "max_new_tokens": max_tokens,
160
- "temperature": max(temperature, 0.01),
161
- "top_p": top_p,
162
- "top_k": 40,
163
- "do_sample": True,
164
- "repetition_penalty": 1.2,
165
- "pad_token_id": TOKENIZER.pad_token_id,
166
- "eos_token_id": TOKENIZER.eos_token_id,
167
- "streamer": streamer,
168
- }
169
 
170
  t0 = time.time()
171
- thread = Thread(target=lambda: MODEL.generate(**gen_kwargs))
172
  thread.start()
173
 
174
- # Collect all tokens (ZeroGPU doesn't support true SSE, so we batch)
175
- events = []
176
- token_count = 0
177
- in_thinking = False
178
- buf = ""
179
-
180
- for token_text in streamer:
181
- buf += token_text
182
- token_count += 1
183
-
184
- if "<think>" in buf and not in_thinking:
185
- in_thinking = True
186
- events.append(json.dumps({"type": "thinking_start"}))
187
- after = buf.split("<think>", 1)[1]
188
- buf = after if after else ""
189
-
190
- if "</think>" in buf and in_thinking:
191
- before = buf.split("</think>", 1)[0]
192
- if before:
193
- events.append(json.dumps({"type": "thinking", "text": before}))
194
- in_thinking = False
195
- events.append(json.dumps({"type": "thinking_end"}))
196
- after = buf.split("</think>", 1)[1].lstrip()
197
- buf = ""
198
- if after:
199
- events.append(json.dumps({"type": "text", "text": after}))
200
- continue
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  if buf:
203
  evt_type = "thinking" if in_thinking else "text"
204
- events.append(json.dumps({"type": evt_type, "text": buf}))
205
- buf = ""
 
206
 
207
- if buf:
208
- evt_type = "thinking" if in_thinking else "text"
209
- events.append(json.dumps({"type": evt_type, "text": buf}))
210
-
211
- thread.join()
212
- elapsed = time.time() - t0
213
- events.append(json.dumps({
214
- "type": "done",
215
- "tokens_generated": token_count,
216
- "latency_ms": round(elapsed * 1000),
217
- }))
218
-
219
- return "\n".join(events)
220
 
 
221
 
222
- # ─── Gradio UI (required for ZeroGPU Spaces) ───
223
- def gradio_chat(message, history):
224
- """Simple chat interface for the Gradio UI."""
225
- result_json = chat_api(json.dumps({
226
- "messages": [{"role": "user", "content": message}],
227
- }))
228
- result = json.loads(result_json)
229
- return result.get("output", result.get("error", "Error"))
230
 
231
-
232
- with gr.Blocks(title="North Air 1 API", theme=gr.themes.Base()) as demo:
233
- gr.Markdown("# North Air 1 API\n0.6B parameter AI by North Air. Use the chat below or call the API endpoints.")
234
- gr.ChatInterface(gradio_chat, type="messages")
235
-
236
- demo.launch()
 
1
  import os
2
  import re
 
3
  import time
4
+ import json
5
+ from typing import List, Optional
6
  from threading import Thread
7
 
8
  import torch
9
+ from fastapi import FastAPI, HTTPException
10
+ from fastapi.responses import StreamingResponse
11
+ from pydantic import BaseModel
12
+ from transformers import AutoTokenizer, TextIteratorStreamer
13
 
 
14
  MODEL_DIR = os.getenv("MODEL_DIR", "./final_model")
15
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
16
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.6"))
 
20
  Be direct, helpful, concise. Use markdown. Write clean code. Never fabricate facts.
21
  If asked who you are: "I'm North Air 1, built by North Air." You are NOT ChatGPT/GPT-4/Claude/etc."""
22
 
23
+
24
class Message(BaseModel):
    """A single chat message."""

    role: str      # e.g. "system", "user", "assistant" (checked against "system" when building the prompt)
    content: str   # message text


class ChatRequest(BaseModel):
    """Request body for POST /chat and POST /chat/stream."""

    messages: List[Message]
    model: Optional[str] = "north-air-1"      # accepted but not used for routing
    max_new_tokens: Optional[int] = None      # None -> MAX_NEW_TOKENS env default
    temperature: Optional[float] = None       # None -> TEMPERATURE env default
    top_p: Optional[float] = None             # None -> TOP_P env default
    system_prompt: Optional[str] = None       # None -> built-in SYSTEM_PROMPT
    stream: Optional[bool] = False            # if true, /chat delegates to /chat/stream
    enable_thinking: Optional[bool] = False   # forwarded to the chat template


app = FastAPI(title="North Air 1 API", version="4.0.0")

# ─── Model Loading: try ONNX first (fast), fallback to PyTorch ───
ONNX_SESSION = None    # onnxruntime.InferenceSession when an exported model is found
MODEL = None           # PyTorch model when ONNX is unavailable
TOKENIZER = None       # HF tokenizer; required by every endpoint
LOAD_ERROR = None      # human-readable load failure, surfaced via /health
INFERENCE_MODE = "pytorch"  # becomes "onnx" or "pytorch-int8" during load
48
+
49
+
50
def _try_load_onnx():
    """Try to open an ONNX Runtime session for faster CPU inference.

    Looks for MODEL_DIR/model_quantized.onnx first, then MODEL_DIR/model.onnx.
    On success sets the ONNX_SESSION and INFERENCE_MODE globals and returns
    True; returns False when no ONNX file exists or onnxruntime fails.

    NOTE(review): nothing else in this file generates through ONNX_SESSION —
    if this succeeds, MODEL stays None and the generate() calls in the chat
    endpoints have no ONNX path. Confirm this is intended before shipping
    an ONNX export.
    """
    global ONNX_SESSION, INFERENCE_MODE
    onnx_path = os.path.join(MODEL_DIR, "model_quantized.onnx")
    if not os.path.exists(onnx_path):
        onnx_path = os.path.join(MODEL_DIR, "model.onnx")
    if not os.path.exists(onnx_path):
        return False

    try:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Small fixed thread counts — presumably tuned for a small CPU Space; verify.
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 2
        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

        ONNX_SESSION = ort.InferenceSession(
            onnx_path, sess_options,
            providers=["CPUExecutionProvider"],
        )
        INFERENCE_MODE = "onnx"
        print(f"ONNX Runtime loaded: {onnx_path}")
        return True
    except Exception as e:
        # Missing onnxruntime package or a bad export — fall back to PyTorch.
        print(f"ONNX load failed: {e}")
        return False
77
 
 
 
 
 
78
 
79
def _load_model():
    """Load tokenizer and model: ONNX Runtime if an export exists, else PyTorch.

    Populates the MODEL / TOKENIZER / LOAD_ERROR / INFERENCE_MODE globals.
    Never raises: failures are recorded in LOAD_ERROR and reported by /health.
    """
    global MODEL, TOKENIZER, LOAD_ERROR, INFERENCE_MODE

    try:
        TOKENIZER = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, trust_remote_code=True)
        if TOKENIZER.pad_token is None:
            # generate() needs a pad token; reuse EOS, the usual convention.
            TOKENIZER.pad_token = TOKENIZER.eos_token
    except Exception as e:
        LOAD_ERROR = f"Tokenizer load failed: {e}"
        return

    # Try ONNX first
    if _try_load_onnx():
        print(f"Using ONNX Runtime ({INFERENCE_MODE})")
        return

    # Fallback: PyTorch on CPU in fp32 (fp16 is slow/unsupported on most CPUs)
    try:
        from transformers import AutoModelForCausalLM
        adapter_cfg = os.path.join(MODEL_DIR, "adapter_config.json")

        if os.path.exists(adapter_cfg):
            # LoRA adapter checkpoint: peft loads base model + adapter together.
            from peft import AutoPeftModelForCausalLM
            MODEL = AutoPeftModelForCausalLM.from_pretrained(
                MODEL_DIR, torch_dtype=torch.float32, device_map={"": "cpu"},
            )
        else:
            MODEL = AutoModelForCausalLM.from_pretrained(
                MODEL_DIR, torch_dtype=torch.float32, device_map={"": "cpu"},
                trust_remote_code=True,
            )

        MODEL.eval()

        # Apply PyTorch dynamic quantization (INT8) for ~1.5-2x speedup on CPU
        quantized = False
        try:
            MODEL = torch.quantization.quantize_dynamic(
                MODEL, {torch.nn.Linear}, dtype=torch.qint8,
            )
            quantized = True
            INFERENCE_MODE = "pytorch-int8"
            print("PyTorch dynamic INT8 quantization applied")
        except Exception as e:
            INFERENCE_MODE = "pytorch"
            print(f"Quantization skipped: {e}")

        # torch.compile is lazy: wrapping always "succeeds" here and any real
        # failure only surfaces on the first forward pass, after the except
        # below can no longer catch it. Dynamically-quantized modules are not
        # supported by the compiler, so only compile the un-quantized model.
        if not quantized:
            try:
                MODEL = torch.compile(MODEL, mode="reduce-overhead")
                print("torch.compile applied")
            except Exception:
                pass

        print(f"Model loaded: {INFERENCE_MODE}")

    except Exception as e:
        LOAD_ERROR = str(e)


# Load at import time so the first request doesn't pay the startup cost.
_load_model()
139
+
140
+
141
@app.get("/health")
def health():
    """Report service liveness plus model-load diagnostics.

    `ok` is true when either inference backend (PyTorch or ONNX) loaded;
    `error` carries the load failure message, if any.
    """
    backend_ready = (MODEL is not None) or (ONNX_SESSION is not None)
    status = {
        "ok": backend_ready,
        "model": "north-air-1",
        "version": "4.0.0",
        "architecture": "Qwen3-0.6B + LoRA r=64",
        "inference": INFERENCE_MODE,
        "features": ["streaming", "thinking", "quantized"],
        "model_dir": MODEL_DIR,
        "error": LOAD_ERROR,
    }
    return status
154
 
155
 
156
+ def _build_prompt(messages: list, system: str, enable_thinking: bool) -> str:
157
  has_system = any(m["role"] == "system" for m in messages)
158
  if not has_system:
159
  messages = [{"role": "system", "content": system}] + messages
160
+
161
  if hasattr(TOKENIZER, "apply_chat_template"):
162
  return TOKENIZER.apply_chat_template(
163
  messages, tokenize=False, add_generation_prompt=True,
 
166
  return "\n".join(f"{m['role']}: {m['content']}" for m in messages) + "\nassistant:"
167
 
168
 
169
+ def _parse_thinking(text: str) -> tuple:
170
+ think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
171
+ if think_match:
172
+ thinking = think_match.group(1).strip()
173
+ answer = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
174
+ return thinking, answer
175
+ return "", text
176
 
 
 
 
 
177
 
178
def _generation_kwargs(input_ids, attention_mask, max_new_tokens, temperature, top_p, **extra):
    """Assemble the keyword arguments for MODEL.generate().

    Temperature is floored at 0.01 so sampling stays valid. The remaining
    sampling knobs (top_k=40, repetition_penalty=1.2) are fixed. Entries in
    **extra (e.g. streamer=...) are merged last and may override any key.
    """
    kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=max(temperature, 0.01),
        top_p=top_p,
        top_k=40,
        do_sample=True,
        repetition_penalty=1.2,
        pad_token_id=TOKENIZER.pad_token_id,
        eos_token_id=TOKENIZER.eos_token_id,
    )
    kwargs.update(extra)
    return kwargs
192
+
193
 
194
def _check_model():
    """Abort with HTTP 500 if neither inference backend or the tokenizer loaded.

    Raises:
        HTTPException: 500 with the recorded LOAD_ERROR detail.
    """
    if MODEL is None and ONNX_SESSION is None:
        raise HTTPException(status_code=500, detail=f"Model failed to load: {LOAD_ERROR}")
    if TOKENIZER is None:
        raise HTTPException(status_code=500, detail=f"Tokenizer failed to load: {LOAD_ERROR}")
199
+
200
+
201
def _prepare_request(req: ChatRequest):
    """Tokenize the request and resolve per-request sampling overrides.

    Returns (batch, max_new_tokens, temperature, top_p), where batch is the
    tokenized chat-template prompt as PyTorch tensors. Unset request fields
    fall back to the module-level environment defaults.
    """
    sys_prompt = req.system_prompt or SYSTEM_PROMPT
    thinking = bool(req.enable_thinking)
    history = [{"role": m.role, "content": m.content} for m in req.messages]

    encoded = TOKENIZER(
        _build_prompt(history, sys_prompt, thinking),
        return_tensors="pt",
        add_special_tokens=False,
    )

    tokens_cap = req.max_new_tokens or MAX_NEW_TOKENS
    temp = TEMPERATURE if req.temperature is None else req.temperature
    nucleus = TOP_P if req.top_p is None else req.top_p

    return encoded, tokens_cap, temp, nucleus
214
+
215
+
216
@app.post("/chat")
def chat(req: ChatRequest):
    """Non-streaming chat completion.

    Returns the generated answer, the optional <think> reasoning, and basic
    token/latency metrics. Delegates to the SSE endpoint when req.stream is
    true.

    Raises:
        HTTPException 400: empty messages list.
        HTTPException 500: model/tokenizer failed to load.
        HTTPException 501: only an ONNX session loaded — this app has no
            ONNX generation path, only MODEL.generate().
    """
    _check_model()

    if not req.messages:
        raise HTTPException(status_code=400, detail="messages are required")

    if req.stream:
        return chat_stream(req)

    # Guard: _check_model() passes when only ONNX_SESSION is loaded, but the
    # generate() call below requires the PyTorch MODEL — without this check
    # an ONNX-only load would crash with AttributeError on None.
    if MODEL is None:
        raise HTTPException(status_code=501, detail="ONNX generation path is not implemented")

    batch, max_new_tokens, temperature, top_p = _prepare_request(req)
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    t0 = time.time()

    with torch.no_grad():
        out = MODEL.generate(
            **_generation_kwargs(input_ids, attention_mask, max_new_tokens, temperature, top_p)
        )

    elapsed = time.time() - t0
    # Decode only the newly generated suffix, not the prompt.
    generated_ids = out[0][input_ids.shape[1]:]
    completion = TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()
    thinking, answer = _parse_thinking(completion)

    return {
        "output": answer,
        "thinking": thinking if thinking else None,
        "model": "north-air-1",
        "inference": INFERENCE_MODE,
        "tokens_generated": len(generated_ids),
        "latency_ms": round(elapsed * 1000),
    }
 
 
 
 
 
 
 
 
250
 
 
 
 
 
251
 
252
@app.post("/chat/stream")
def chat_stream(req: ChatRequest):
    """Streaming chat via Server-Sent Events.

    Runs generation on a background thread and relays tokens as `data:` SSE
    events: thinking_start / thinking / thinking_end around <think> spans,
    `text` for visible output, and a final `done` event with token count and
    latency.

    Raises:
        HTTPException 400: empty messages list.
        HTTPException 500: model/tokenizer failed to load.
    """
    _check_model()

    if not req.messages:
        raise HTTPException(status_code=400, detail="messages are required")

    batch, max_new_tokens, temperature, top_p = _prepare_request(req)
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    # skip_prompt: only newly generated tokens reach the iterator.
    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = _generation_kwargs(
        input_ids, attention_mask, max_new_tokens, temperature, top_p,
        streamer=streamer,
    )

    t0 = time.time()
    thread = Thread(target=_generate_in_thread, args=(gen_kwargs,))
    thread.start()

    def event_stream():
        token_count = 0
        in_thinking = False
        buf = ""  # holds text until we know it isn't part of a split <think> tag

        for token_text in streamer:
            buf += token_text
            token_count += 1

            # Opening tag completed: emit thinking_start, keep any text after it.
            if "<think>" in buf and not in_thinking:
                in_thinking = True
                yield f"data: {json.dumps({'type': 'thinking_start'})}\n\n"
                after = buf.split("<think>", 1)[1]
                buf = after if after else ""

            # Closing tag completed: flush thinking text, then trailing answer text.
            if "</think>" in buf and in_thinking:
                before = buf.split("</think>", 1)[0]
                if before:
                    yield f"data: {json.dumps({'type': 'thinking', 'text': before})}\n\n"
                in_thinking = False
                yield f"data: {json.dumps({'type': 'thinking_end'})}\n\n"
                after = buf.split("</think>", 1)[1].lstrip()
                buf = ""
                if after:
                    yield f"data: {json.dumps({'type': 'text', 'text': after})}\n\n"
                continue

            # If buf ends with a proper prefix of a tag ("<", "<t", ...), hold
            # it back one more token so a tag split across token boundaries is
            # never emitted as visible text.
            partial_open = "<think"
            partial_close = "</think"
            if not in_thinking and buf.endswith(tuple(partial_open[:i] for i in range(1, len(partial_open) + 1))):
                continue
            if in_thinking and buf.endswith(tuple(partial_close[:i] for i in range(1, len(partial_close) + 1))):
                continue

            if buf:
                evt_type = "thinking" if in_thinking else "text"
                yield f"data: {json.dumps({'type': evt_type, 'text': buf})}\n\n"
                buf = ""

        # Streamer exhausted: flush any held-back remainder.
        if buf:
            evt_type = "thinking" if in_thinking else "text"
            yield f"data: {json.dumps({'type': evt_type, 'text': buf})}\n\n"
        if in_thinking:
            # Model never closed its <think> block — end the span cleanly.
            yield f"data: {json.dumps({'type': 'thinking_end'})}\n\n"

        thread.join()
        elapsed = time.time() - t0
        yield f"data: {json.dumps({'type': 'done', 'tokens_generated': token_count, 'latency_ms': round(elapsed * 1000), 'inference': INFERENCE_MODE})}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
324
 
 
 
 
 
 
 
 
 
325
 
326
def _generate_in_thread(kwargs):
    """Thread target: run MODEL.generate under no_grad.

    Output tokens are delivered to the caller through the TextIteratorStreamer
    carried inside `kwargs`; the return value of generate() is discarded.
    """
    with torch.no_grad():
        MODEL.generate(**kwargs)
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  torch>=2.2.0
2
  transformers>=4.45.0
3
  peft>=0.12.0
 
1
+ fastapi==0.115.0
2
+ uvicorn[standard]==0.30.6
3
+ pydantic==2.9.2
4
  torch>=2.2.0
5
  transformers>=4.45.0
6
  peft>=0.12.0