Ashok75 committed · verified
Commit 3db39aa · 1 Parent(s): aed8238

Upload app.py
Files changed (1)
app.py +112 -38
app.py CHANGED
@@ -6,8 +6,10 @@ Provides streaming chat completion API.
 import os
 import json
 import asyncio
+import time
 from typing import AsyncGenerator, List, Dict, Any, Optional
 from contextlib import asynccontextmanager
+from datetime import datetime

 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
@@ -17,6 +19,10 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from threading import Thread
 import torch
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)

 # Model configuration
 MODEL_NAME = "Nanbeige/Nanbeige4.1-3B"
@@ -36,8 +42,8 @@ class Message(BaseModel):
 class ChatRequest(BaseModel):
     messages: List[Message]
     stream: bool = True
-    max_tokens: int = 2048
-    temperature: float = 0.6
+    max_tokens: int = 8192  # Increased from 2048 (supports up to 131072)
+    temperature: float = 0.6  # Nanbeige4.1-3B recommended
     tools: Optional[List[Dict]] = None


@@ -83,59 +89,121 @@ app.add_middleware(
 )


-def format_messages(messages: List[Message]) -> str:
-    """Format messages into prompt string."""
-    formatted = []
-    for msg in messages:
-        if msg.role == "system":
-            formatted.append(f"System: {msg.content}")
-        elif msg.role == "user":
-            formatted.append(f"User: {msg.content}")
-        elif msg.role == "assistant":
-            formatted.append(f"Assistant: {msg.content}")
+def format_messages_proper(messages: List[Message], tools: Optional[List[Dict]] = None) -> str:
+    """Format messages using the model's proper chat template.
+
+    Nanbeige4.1-3B uses the HF transformers chat template.
+    This ensures proper formatting for both regular and tool-aware conversations.
+    """
+    global tokenizer

-    formatted.append("Assistant:")
-    return "\n\n".join(formatted)
+    # Convert Message objects to dicts for tokenizer
+    message_dicts = [{"role": msg.role, "content": msg.content} for msg in messages]
+
+    # Use tokenizer's built-in chat template for proper formatting
+    if tools:
+        # Tool-aware formatting (for function calling)
+        prompt = tokenizer.apply_chat_template(
+            message_dicts,
+            tools=tools,
+            add_generation_prompt=True,
+            tokenize=False
+        )
+    else:
+        # Regular chat formatting
+        prompt = tokenizer.apply_chat_template(
+            message_dicts,
+            add_generation_prompt=True,
+            tokenize=False
+        )
+
+    return prompt


-async def stream_tokens(prompt: str, max_tokens: int, temperature: float) -> AsyncGenerator[str, None]:
-    """Stream tokens from the model."""
+async def stream_tokens(prompt: str, max_tokens: int, temperature: float, tools: Optional[List[Dict]] = None) -> AsyncGenerator[str, None]:
+    """Stream tokens from the model token-by-token as fast as generated.
+
+    Uses Nanbeige4.1-3B recommended hyperparameters.
+    """
     global model, tokenizer

-    inputs = tokenizer(prompt, return_tensors="pt")
+    start_time = time.time()
+    logger.info(f"Starting token generation for prompt length: {len(prompt)}")
+
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=2048
+    )
+
     if torch.cuda.is_available():
         inputs = inputs.to("cuda")

+    # Create streamer with timeout to prevent hanging
     streamer = TextIteratorStreamer(
         tokenizer,
         skip_prompt=True,
-        skip_special_tokens=True
+        skip_special_tokens=True,
+        timeout=300.0  # 5 min timeout per token
     )

     generation_kwargs = dict(
-        inputs,
+        **inputs,
         streamer=streamer,
-        max_new_tokens=max_tokens,
+        max_new_tokens=min(max_tokens, 131072),  # Support up to model's max (131072)
         temperature=temperature,
+        top_p=0.95,  # Nanbeige4.1-3B recommended
+        repetition_penalty=1.0,  # Nanbeige4.1-3B recommended
         do_sample=temperature > 0,
+        eos_token_id=166101,  # Nanbeige4.1-3B specific EOS token
         pad_token_id=tokenizer.eos_token_id
     )

-    # Run generation in separate thread
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    # Run generation in separate thread (non-blocking)
+    thread = Thread(target=model.generate, kwargs=generation_kwargs, daemon=False)
     thread.start()

     generated_text = ""
-    for new_text in streamer:
-        generated_text += new_text
-        # Yield each token
-        data = json.dumps({"type": "token", "content": new_text})
-        yield f"data: {data}\n\n"
+    token_count = 0
+    first_token_time = None
+
+    try:
+        for new_text in streamer:
+            if new_text:  # Skip empty strings
+                generated_text += new_text
+                token_count += 1
+
+                # Log first token time (time to first byte)
+                if first_token_time is None:
+                    first_token_time = time.time() - start_time
+                    logger.info(f"First token generated in {first_token_time:.2f}s")
+
+                # preview logging to verify streaming works
+                logger.info(f"streaming token #{token_count}: {repr(new_text)}")
+
+                # Yield SSE event immediately (no buffering)
+                data = json.dumps({"type": "token", "content": new_text})
+                yield f"data: {data}\n\n"
+                logger.debug(f"Token {token_count}: {repr(new_text[:20])}...")
+
+        # Log generation stats
+        total_time = time.time() - start_time
+        tokens_per_sec = token_count / total_time if total_time > 0 else 0
+        logger.info(f"Generation complete: {token_count} tokens in {total_time:.2f}s ({tokens_per_sec:.2f} tok/s)")
+
+        # Signal completion
+        yield f"data: {json.dumps({'type': 'done', 'content': ''})}\n\n"

-    # Signal completion
-    yield f"data: {json.dumps({'type': 'done', 'content': ''})}\n\n"
+    except Exception as e:
+        logger.error(f"Token generation error: {e}", exc_info=True)
+        yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"

-    thread.join()
+    finally:
+        # Wait for thread to finish
+        thread.join(timeout=5)
+        if thread.is_alive():
+            logger.warning("Generation thread did not finish within timeout")


 @app.get("/")
@@ -172,18 +240,21 @@ async def chat(request: ChatRequest):
     if model is None or tokenizer is None:
         raise HTTPException(status_code=503, detail="Model not loaded yet")

-    # Format messages into prompt
-    prompt = format_messages(request.messages)
+    # Format messages using the model's proper chat template
+    prompt = format_messages_proper(request.messages, request.tools)

     if request.stream:
-        # Return streaming response
+        # Return streaming response with anti-buffering headers
        return StreamingResponse(
-            stream_tokens(prompt, request.max_tokens, request.temperature),
+            stream_tokens(prompt, request.max_tokens, request.temperature, request.tools),
            media_type="text/event-stream",
            headers={
-                "Cache-Control": "no-cache",
+                "Cache-Control": "no-cache, no-store, must-revalidate",
+                "Pragma": "no-cache",
+                "Expires": "0",
                "Connection": "keep-alive",
-                "X-Accel-Buffering": "no"
+                "X-Accel-Buffering": "no",
+                "Transfer-Encoding": "chunked"
            }
        )
     else:
@@ -194,9 +265,12 @@ async def chat(request: ChatRequest):

         outputs = model.generate(
             **inputs,
-            max_new_tokens=request.max_tokens,
+            max_new_tokens=min(request.max_tokens, 131072),  # Support up to model's max
             temperature=request.temperature,
+            top_p=0.95,  # Nanbeige4.1-3B recommended
+            repetition_penalty=1.0,  # Nanbeige4.1-3B recommended
             do_sample=request.temperature > 0,
+            eos_token_id=166101,  # Model-specific EOS token
             pad_token_id=tokenizer.eos_token_id
         )

@@ -215,4 +289,4 @@ async def chat(request: ChatRequest):

 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
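
Usage sketch: a minimal Python client for the SSE stream produced by stream_tokens, which emits "data: {...}" frames with "type" set to "token", "done", or "error". The route path for the chat handler is not shown in this diff, so the POST /chat path and the localhost:7860 base URL below are assumptions; adjust them to match the actual route decorator and deployment.

# sse_client.py -- minimal sketch of a client for the streaming endpoint.
# Assumptions: the chat handler is mounted at POST /chat and the app runs
# locally on port 7860 as in the __main__ block above.
import json
import requests

BASE_URL = "http://localhost:7860"  # assumption: local uvicorn instance

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "stream": True,
    "max_tokens": 256,
    "temperature": 0.6,
}

with requests.post(f"{BASE_URL}/chat", json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    pieces = []
    for line in resp.iter_lines(decode_unicode=True):
        # Each SSE frame looks like: data: {"type": "token", "content": "..."}
        if not line or not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        if event["type"] == "token":
            pieces.append(event["content"])
            print(event["content"], end="", flush=True)
        elif event["type"] == "error":
            raise RuntimeError(event["content"])
        elif event["type"] == "done":
            break
    print()

full_text = "".join(pieces)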
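Tool payload sketch: format_messages_proper forwards request.tools directly into tokenizer.apply_chat_template(..., tools=...). Recent transformers releases accept tools as a list of JSON-schema function descriptions like the one below; whether Nanbeige4.1-3B's chat template actually renders tool definitions this way is an assumption, and the tool name and parameters (get_weather, location) are hypothetical, shown only to illustrate the expected shape of the ChatRequest.tools field.

# Hypothetical tool definition in the JSON-schema style that
# apply_chat_template(..., tools=...) accepts in recent transformers versions.
weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool name
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City name"},
            },
            "required": ["location"],
        },
    },
}

payload = {
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "stream": True,
    "tools": [weather_tool],  # matches ChatRequest.tools: Optional[List[Dict]]
}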