Ashok75 committed
Commit 6e9c061 · verified · 1 Parent(s): d0b2daf

Upload 5 files

Files changed (3):
  1. README.md +137 -96
  2. app.py +24 -281
  3. server_runtime.py +522 -0
README.md CHANGED
@@ -1,96 +1,137 @@
- ---
- title: React
- emoji: 🌍
- colorFrom: yellow
- colorTo: gray
- sdk: gradio
- sdk_version: 6.9.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-
-
- # Nanbeige4.1-3B Inference Server
-
- Lightweight remote LLM inference service for Enterprise ReAct Agent systems.
-
- ## Overview
-
- This Hugging Face Space hosts the **Nanbeige4.1-3B** model as a remote inference API, designed to work with local agent orchestration systems. The model runs entirely in this Space, while all agent logic, tools, and memory systems run on the user's local machine.
-
- ## Model Information
-
- - **Model**: [Nanbeige/Nanbeige4.1-3B](https://huggingface.co/Nanbeige/Nanbeige4.1-3B)
- - **Parameters**: 3B
- - **Context Window**: 8K tokens
- - **Capabilities**: Tool calling, reasoning, 500+ tool invocation rounds
- - **License**: Apache 2.0
-
- ## API Endpoints
-
- ### POST /chat
- Main chat completion endpoint (OpenAI-compatible).
-
- **Request:**
- ```json
- {
-   "messages": [
-     {"role": "system", "content": "You are a helpful assistant."},
-     {"role": "user", "content": "Hello!"}
-   ],
-   "tools": [...],
-   "stream": false,
-   "max_tokens": 2048,
-   "temperature": 0.6,
-   "top_p": 0.95
- }
- ```
-
- **Response:**
- ```json
- {
-   "id": "chatcmpl-...",
-   "object": "chat.completion",
-   "created": 1234567890,
-   "model": "Nanbeige/Nanbeige4.1-3B",
-   "choices": [...],
-   "usage": {
-     "prompt_tokens": 20,
-     "completion_tokens": 50,
-     "total_tokens": 70
-   }
- }
- ```
-
- ### GET /chat
- Web interface for testing.
-
- ### GET /health
- Health check endpoint.
-
- ## Usage with Local Agent
-
- ```python
- import requests
-
- response = requests.post(
-     "https://your-space.hf.space/chat",
-     json={
-         "messages": [{"role": "user", "content": "Hello!"}],
-         "temperature": 0.6
-     }
- )
- result = response.json()
- ```
-
- ## Hardware Requirements
-
- - **GPU**: Recommended (CUDA-compatible)
- - **CPU**: Fallback supported
- - **Memory**: ~8GB RAM minimum
-
- ## Local Agent Repository
-
- For the complete local agent system that connects to this Space, see the companion repository.
+ # HF Space Backend (Streaming LLM Server)
+
+ This folder contains the Hugging Face Space backends for two model deployments that share the same production runtime.
+
+ ## Files
+ - `app.py`: Nanbeige deployment entrypoint (`Nanbeige/Nanbeige4.1-3B`)
+ - `main.py`: LiquidAI deployment entrypoint (`LiquidAI/LFM2.5-1.2B-Thinking`)
+ - `server_runtime.py`: shared queue + worker + streaming runtime used by both entrypoints
+ - `index.html`: lightweight local streaming test UI
+ - `requirements.txt`: runtime dependencies
+
+ ## Runtime Architecture
+ Both servers use the same execution flow:
+
+ Client Request
+ -> FastAPI `/chat`
+ -> `asyncio.Queue` request buffer
+ -> worker pool (`asyncio` tasks)
+ -> concurrency gate (`asyncio.Semaphore`)
+ -> one generation thread per request (`model.generate`)
+ -> per-request `TextIteratorStreamer`
+ -> SSE token stream to client
+
+ ### Why this structure
+ - Keeps the event loop responsive.
+ - Prevents response mixing across users (each request gets isolated state).
+ - Caps concurrency to what the CPU/GPU can sustain.
+ - Buffers burst traffic in the queue instead of failing hard.
+
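The flow above can be sketched with plain asyncio primitives. This is an illustrative toy, not the actual `server_runtime.py` code; the names (`handle`, `serve`, the instantly-finishing stand-in thread) are made up for the sketch:

```python
import asyncio
import threading

MAX_WORKERS = 2  # stand-in for the hardware-aware worker count

async def handle(request_id: str, gate: asyncio.Semaphore) -> str:
    """One request: pass the concurrency gate, run blocking work on a thread."""
    async with gate:
        done = threading.Event()
        # Stand-in for the per-request model.generate thread.
        thread = threading.Thread(target=done.set)
        thread.start()
        await asyncio.to_thread(thread.join)  # join without blocking the loop
        return f"{request_id}:done"

async def serve(requests):
    queue: asyncio.Queue = asyncio.Queue()
    gate = asyncio.Semaphore(MAX_WORKERS)
    results = []

    async def worker():
        while True:
            item = await queue.get()
            if item is None:  # shutdown sentinel
                queue.task_done()
                return
            results.append(await handle(item, gate))
            queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(MAX_WORKERS)]
    for rid in requests:
        await queue.put(rid)   # buffer the burst
    for _ in workers:
        await queue.put(None)  # one sentinel per worker
    await asyncio.gather(*workers)
    return results

if __name__ == "__main__":
    print(asyncio.run(serve(["r1", "r2", "r3"])))
```

The queue absorbs bursts while the semaphore bounds how many requests actually generate at once, which is the same split of responsibilities described above.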
+ ## Concurrency
+ Hardware-aware worker count:
+ - CPU: `1..4` workers (based on core count)
+ - GPU: `3..5` workers (based on VRAM tier)
+
+ Override at runtime:
+ - `HF_MAX_WORKERS`
+
+ Queue settings:
+ - `HF_QUEUE_MAX_SIZE` (default: `512`)
+
+ ## Thread Lifecycle and Safety
+ - Each request gets its own generation thread.
+ - Each request has a cancellation event.
+ - `CancelAwareStoppingCriteria` stops generation when the client disconnects or cancels.
+ - The streamer is explicitly ended in a `finally` block.
+ - The generation thread is joined with a long timeout (`HF_GENERATION_JOIN_TIMEOUT_SECONDS`, default `180`) to avoid orphaned work.
+
+ This fixes the old short-join behavior that frequently produced:
+ - `Generation thread did not finish within timeout`
+
+ ## Metrics and Logging
+ Per-request logs include:
+ - request queued
+ - worker start/end
+ - first-token latency
+ - generated token count
+ - tokens/sec
+ - active workers
+ - queue size
+
+ Token-by-token debug logging is optional:
+ - `HF_DEBUG_TOKEN_LOGS=1`
+
+ ## API
+ ### `POST /chat`
+ Body:
+ - `messages`: chat messages
+ - `stream`: `true` for SSE streaming
+ - `max_tokens`: maximum number of new tokens to generate
+ - `temperature`: optional; the model default is used if omitted
+ - `tools`: optional tool schemas for the chat template
+
+ Streaming response format:
+ - SSE `data: {"type":"token","content":"..."}` chunks
+ - final `{"type":"done","content":""}` event
+
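A minimal client for this streaming format could look like the following stdlib-only sketch (the Space URL is a placeholder you must substitute; only the `data: {...}` event shape comes from the server description above):

```python
import json
import urllib.request

def parse_sse_line(line: str):
    """Return the decoded JSON payload of a `data: {...}` line, else None."""
    line = line.strip()
    if not line.startswith("data: "):
        return None
    return json.loads(line[len("data: "):])

def iter_chat_tokens(base_url: str, messages, max_tokens: int = 512):
    """Yield token strings from the server's /chat SSE stream."""
    body = json.dumps(
        {"messages": messages, "stream": True, "max_tokens": max_tokens}
    ).encode("utf-8")
    req = urllib.request.Request(
        f"{base_url}/chat",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        for raw in resp:  # read the response line by line
            event = parse_sse_line(raw.decode("utf-8"))
            if event is None:
                continue
            if event["type"] == "token":
                yield event["content"]
            elif event["type"] in ("done", "error"):
                return
```

Usage would be something like `for tok in iter_chat_tokens("https://your-space.hf.space", [{"role": "user", "content": "Hello"}]): print(tok, end="", flush=True)`.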
+ ### `GET /health`
+ Returns:
+ - `status`
+ - `model_loaded`
+ - `device`
+ - `active_workers`
+ - `queue_size`
+ - `max_workers`
+
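Since the model loads during lifespan startup, a client may want to poll `/health` until `model_loaded` is true before sending work. A stdlib sketch (the helper names and the Space URL are illustrative, not part of this repo):

```python
import json
import time
import urllib.request

def is_ready(payload: dict) -> bool:
    """True once /health reports the model as loaded."""
    return bool(payload.get("model_loaded"))

def wait_until_ready(base_url: str, timeout_s: float = 120.0, interval_s: float = 2.0) -> dict:
    """Poll GET /health until the model is loaded; return the last payload."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health") as resp:
                payload = json.loads(resp.read().decode("utf-8"))
            if is_ready(payload):
                return payload
        except OSError:
            pass  # Space may still be booting; retry
        time.sleep(interval_s)
    raise TimeoutError(f"{base_url} not ready after {timeout_s}s")
```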
+ ### `GET /index`
+ Serves the `index.html` test page.
+
+ ## Model-Specific Settings
+ ### `app.py` (Nanbeige4.1-3B)
+ - `max_input_tokens=32768`
+ - `eos_token_id=166101`
+ - `default_temperature=0.6`
+ - `top_p=0.95`
+ - `repetition_penalty=1.0`
+ - `tokenizer_use_fast=False`
+
+ ### `main.py` (LFM2.5-1.2B-Thinking)
+ - `max_input_tokens=32768`
+ - `default_temperature=0.1`
+ - `top_p=0.1`
+ - `top_k=50`
+ - `repetition_penalty=1.05`
+ - `eos_token_id` taken from the tokenizer config
+
+ ## Environment Variables
+ - `HF_MAX_WORKERS`
+ - `HF_QUEUE_MAX_SIZE`
+ - `HF_STREAMER_TIMEOUT_SECONDS`
+ - `HF_GENERATION_JOIN_TIMEOUT_SECONDS`
+ - `HF_MAX_INPUT_TOKENS`
+ - `HF_MAX_NEW_TOKENS`
+ - `HF_DEBUG_TOKEN_LOGS`
+
+ ## Model Documentation References
+ ### Nanbeige / `app.py`
+ - https://huggingface.co/Nanbeige/Nanbeige4.1-3B
+ - https://huggingface.co/Nanbeige/Nanbeige4.1-3B/blob/main/README.md
+ - https://huggingface.co/Nanbeige/Nanbeige4.1-3B/blob/main/Nanbeige4.1-3B-Report.pdf
+ - https://huggingface.co/Nanbeige/Nanbeige4.1-3B/blob/main/generation_config.json
+ - https://huggingface.co/Nanbeige/Nanbeige4.1-3B/blob/main/config.json
+
+ ### LiquidAI / `main.py`
+ - https://huggingface.co/LiquidAI/LFM2.5-1.2B-Thinking
+ - https://huggingface.co/LiquidAI/LFM2.5-1.2B-Thinking/blob/main/README.md
+ - https://huggingface.co/LiquidAI/LFM2.5-1.2B-Thinking/blob/main/chat_template.jinja
+ - https://huggingface.co/LiquidAI/LFM2.5-1.2B-Thinking/blob/main/config.json
+ - https://docs.liquid.ai/lfm/key-concepts/chat-template
+ - https://docs.liquid.ai/lfm/key-concepts/text-generation-and-prompting
+ - https://docs.liquid.ai/lfm/key-concepts/tool-use
+ - https://huggingface.co/docs/transformers/en/chat_templating#using-applychattemplate
+
+ ## Notes
+ - The model is loaded once per process during FastAPI lifespan startup.
+ - `index.html` is intentionally a simple streaming test page, not the production frontend.
+ - Both entrypoints (`app.py`, `main.py`) now behave consistently by design.
+
app.py CHANGED
@@ -1,294 +1,37 @@
  """
- HuggingFace Space application for Nanbeige4.1-3B model inference.
- Provides streaming chat completion API.
- """

- import os
- import json
- import asyncio
- import time
- from typing import AsyncGenerator, List, Dict, Any, Optional
- from contextlib import asynccontextmanager
- from datetime import datetime

- from fastapi import FastAPI, HTTPException
- from fastapi.responses import StreamingResponse
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import FileResponse
- from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
- from threading import Thread
- import torch
- import logging

- logger = logging.getLogger(__name__)
- logging.basicConfig(level=logging.INFO)

- # Model configuration
  MODEL_NAME = "Nanbeige/Nanbeige4.1-3B"
- MAX_LENGTH = 32768
-
- # Global model and tokenizer
- model = None
- tokenizer = None
-
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-
- class Message(BaseModel):
-     role: str
-     content: str
-
-
- class ChatRequest(BaseModel):
-     messages: List[Message]
-     stream: bool = True
-     max_tokens: int = 8192  # Increased from 2048 (supports up to 131072)
-     temperature: float = 0.6  # Nanbeige4.1-3B recommended
-     tools: Optional[List[Dict]] = None
-

- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     """Application lifespan handler."""
-     global model, tokenizer
-
-     print("Loading model...")
-     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
-     model = AutoModelForCausalLM.from_pretrained(
-         MODEL_NAME,
-         trust_remote_code=True,
-         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-         device_map="auto" if torch.cuda.is_available() else None
      )
-
-     if not torch.cuda.is_available():
-         model = model.to("cpu")
-
-     print("Model loaded successfully!")
-     yield
-
-     # Cleanup
-     del model
-     del tokenizer
-     torch.cuda.empty_cache()
-
-
- app = FastAPI(
-     title="Nanbeige4.1-3B Inference API",
-     description="Streaming chat completion API for Nanbeige4.1-3B",
-     version="1.0.0",
-     lifespan=lifespan
- )
-
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
  )


- def format_messages_proper(messages: List[Message], tools: Optional[List[Dict]] = None) -> str:
-     """Format messages using the model's proper chat template.
-
-     Nanbeige4.1-3B uses the HF transformers chat template.
-     This ensures proper formatting for both regular and tool-aware conversations.
-     """
-     global tokenizer
-
-     # Convert Message objects to dicts for tokenizer
-     message_dicts = [{"role": msg.role, "content": msg.content} for msg in messages]
-
-     # Use tokenizer's built-in chat template for proper formatting
-     if tools:
-         # Tool-aware formatting (for function calling)
-         prompt = tokenizer.apply_chat_template(
-             message_dicts,
-             tools=tools,
-             add_generation_prompt=True,
-             tokenize=False
-         )
-     else:
-         # Regular chat formatting
-         prompt = tokenizer.apply_chat_template(
-             message_dicts,
-             add_generation_prompt=True,
-             tokenize=False
-         )
-
-     return prompt
-
-
- async def stream_tokens(prompt: str, max_tokens: int, temperature: float, tools: Optional[List[Dict]] = None) -> AsyncGenerator[str, None]:
-     """Stream tokens from the model token-by-token as fast as generated.
-
-     Uses Nanbeige4.1-3B recommended hyperparameters.
-     """
-     global model, tokenizer
-
-     start_time = time.time()
-     logger.info(f"Starting token generation for prompt length: {len(prompt)}")
-
-     inputs = tokenizer(
-         prompt,
-         return_tensors="pt",
-         truncation=True,
-         max_length=2048
-     )
-
-     if torch.cuda.is_available():
-         inputs = inputs.to("cuda")
-
-     # Create streamer with timeout to prevent hanging
-     streamer = TextIteratorStreamer(
-         tokenizer,
-         skip_prompt=True,
-         skip_special_tokens=True,
-         timeout=300.0  # 5 min timeout per token
-     )
-
-     generation_kwargs = dict(
-         **inputs,
-         streamer=streamer,
-         max_new_tokens=min(max_tokens, 131072),  # Support up to model's max (131072)
-         temperature=temperature,
-         top_p=0.95,  # Nanbeige4.1-3B recommended
-         repetition_penalty=1.0,  # Nanbeige4.1-3B recommended
-         do_sample=temperature > 0,
-         eos_token_id=166101,  # Nanbeige4.1-3B specific EOS token
-         pad_token_id=tokenizer.eos_token_id
-     )
-
-     # Run generation in separate thread (non-blocking)
-     thread = Thread(target=model.generate, kwargs=generation_kwargs, daemon=False)
-     thread.start()
-
-     generated_text = ""
-     token_count = 0
-     first_token_time = None
-
-     try:
-         for new_text in streamer:
-             if new_text:  # Skip empty strings
-                 generated_text += new_text
-                 token_count += 1
-
-                 # Log first token time (time to first byte)
-                 if first_token_time is None:
-                     first_token_time = time.time() - start_time
-                     logger.info(f"First token generated in {first_token_time:.2f}s")
-
-                 # preview logging to verify streaming works
-                 logger.info(f"streaming token #{token_count}: {repr(new_text)}")
-
-                 # Yield SSE event immediately (no buffering)
-                 data = json.dumps({"type": "token", "content": new_text})
-                 yield f"data: {data}\n\n"
-                 # let the event loop schedule a send/flush so proxies don't buffer
-                 await asyncio.sleep(0)
-                 logger.debug(f"Token {token_count}: {repr(new_text[:20])}...")
-
-         # Log generation stats
-         total_time = time.time() - start_time
-         tokens_per_sec = token_count / total_time if total_time > 0 else 0
-         logger.info(f"Generation complete: {token_count} tokens in {total_time:.2f}s ({tokens_per_sec:.2f} tok/s)")
-
-         # Signal completion
-         yield f"data: {json.dumps({'type': 'done', 'content': ''})}\n\n"
-
-     except Exception as e:
-         logger.error(f"Token generation error: {e}", exc_info=True)
-         yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"
-
-     finally:
-         # Wait for thread to finish
-         thread.join(timeout=5)
-         if thread.is_alive():
-             logger.warning("Generation thread did not finish within timeout")
-
-
- @app.get("/")
- async def root():
-     """Root endpoint."""
-     return {
-         "name": "Nanbeige4.1-3B Inference API",
-         "version": "1.0.0",
-         "model": MODEL_NAME,
-         "status": "running"
-     }
-
-
- @app.get("/index", response_class=FileResponse)
- async def serve_chat():
-     """Serve chat.html as index."""
-     return FileResponse(os.path.join(BASE_DIR, "index.html"))
-
-
- @app.get("/health")
- async def health():
-     """Health check endpoint."""
-     return {
-         "status": "healthy",
-         "model_loaded": model is not None and tokenizer is not None
-     }
-
-
- @app.post("/chat")
- async def chat(request: ChatRequest):
-     """
-     Chat completion endpoint with streaming support.
-     """
-     if model is None or tokenizer is None:
-         raise HTTPException(status_code=503, detail="Model not loaded yet")
-
-     # Format messages using the model's proper chat template
-     prompt = format_messages_proper(request.messages, request.tools)
-
-     if request.stream:
-         # Return streaming response with anti-buffering headers
-         return StreamingResponse(
-             stream_tokens(prompt, request.max_tokens, request.temperature, request.tools),
-             media_type="text/event-stream",
-             headers={
-                 "Cache-Control": "no-cache, no-store, must-revalidate",
-                 "Pragma": "no-cache",
-                 "Expires": "0",
-                 "Connection": "keep-alive",
-                 "X-Accel-Buffering": "no",
-                 "Transfer-Encoding": "chunked"
-             }
-         )
-     else:
-         # Non-streaming response
-         inputs = tokenizer(prompt, return_tensors="pt")
-         if torch.cuda.is_available():
-             inputs = inputs.to("cuda")
-
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=min(request.max_tokens, 131072),  # Support up to model's max
-             temperature=request.temperature,
-             top_p=0.95,  # Nanbeige4.1-3B recommended
-             repetition_penalty=1.0,  # Nanbeige4.1-3B recommended
-             do_sample=request.temperature > 0,
-             eos_token_id=166101,  # Model-specific EOS token
-             pad_token_id=tokenizer.eos_token_id
-         )
-
-         response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         # Extract only the assistant's response
-         response_text = response_text[len(prompt):].strip()
-
-         return {
-             "content": response_text,
-             "usage": {
-                 "prompt_tokens": inputs.input_ids.shape[1],
-                 "completion_tokens": outputs.shape[1] - inputs.input_ids.shape[1]
-             }
-         }
-
-
  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
  """
+ Hugging Face Space server for Nanbeige/Nanbeige4.1-3B.
+
+ This file uses the shared runtime with:
+ - async queue buffering
+ - worker pool + semaphore concurrency
+ - safe per-request generation thread lifecycle
+ """
+
+ try:
+     from .server_runtime import RuntimeConfig, create_hf_space_app
+ except ImportError:  # pragma: no cover - direct script execution
+     from server_runtime import RuntimeConfig, create_hf_space_app


  MODEL_NAME = "Nanbeige/Nanbeige4.1-3B"

+ app = create_hf_space_app(
+     RuntimeConfig(
+         model_name=MODEL_NAME,
+         title="Nanbeige4.1-3B Inference API",
+         description="Streaming chat completion API for Nanbeige4.1-3B",
+         max_input_tokens=32768,
+         eos_token_id=166101,
+         default_temperature=0.6,
+         top_p=0.95,
+         repetition_penalty=1.0,
+         tokenizer_use_fast=False,
+         logger_name=__name__,
      )
  )


  if __name__ == "__main__":
      import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=7860)
server_runtime.py ADDED
@@ -0,0 +1,522 @@
+ """
+ Shared Hugging Face Space runtime for streaming chat inference.
+
+ This module provides:
+ - one-time global model loading
+ - async request queue
+ - worker pool with semaphore-based concurrency limits
+ - per-request streamer/thread isolation
+ - SSE streaming responses
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import logging
+ import os
+ import time
+ import uuid
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from queue import Empty as QueueEmpty
+ from threading import Event as ThreadEvent
+ from threading import Thread
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import FileResponse, StreamingResponse
+ from pydantic import BaseModel
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     StoppingCriteria,
+     StoppingCriteriaList,
+     TextIteratorStreamer,
+ )
+
+
+ class Message(BaseModel):
+     role: str
+     content: str
+
+
+ class ChatRequest(BaseModel):
+     messages: List[Message]
+     stream: bool = True
+     max_tokens: int = 8192
+     temperature: Optional[float] = None
+     tools: Optional[List[Dict[str, Any]]] = None
+
+
+ @dataclass(frozen=True)
+ class RuntimeConfig:
+     model_name: str
+     title: str
+     description: str
+     version: str = "1.0.0"
+     max_input_tokens: int = 32768
+     max_new_tokens: int = 131072
+     top_p: float = 0.95
+     top_k: Optional[int] = None
+     repetition_penalty: float = 1.0
+     eos_token_id: Optional[int] = None
+     default_temperature: float = 0.6
+     tokenizer_use_fast: Optional[bool] = None
+     logger_name: str = "hf_space"
+
+
+ @dataclass
+ class GenerationTask:
+     request_id: str
+     prompt: str
+     max_tokens: int
+     temperature: float
+     output_queue: asyncio.Queue[Optional[Dict[str, Any]]]
+     created_at: float = field(default_factory=time.time)
+     cancel_event: ThreadEvent = field(default_factory=ThreadEvent)
+     prompt_tokens: int = 0
+     generated_tokens: int = 0
+     first_token_latency: Optional[float] = None
+     start_time: Optional[float] = None
+     end_time: Optional[float] = None
+
+
+ class CancelAwareStoppingCriteria(StoppingCriteria):
+     """Stops generation when the request is cancelled/disconnected."""
+
+     def __init__(self, cancel_event: ThreadEvent):
+         self.cancel_event = cancel_event
+
+     def __call__(self, input_ids, scores, **kwargs) -> bool:
+         return self.cancel_event.is_set()
+
+
+ def _is_truthy(value: str) -> bool:
+     return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
+ def _format_sse_event(payload: Dict[str, Any]) -> str:
+     return f"data: {json.dumps(payload)}\n\n"
+
+
+ def _detect_concurrency(device: str) -> int:
+     # Allow environment override if needed for debugging/tuning.
+     override = os.getenv("HF_MAX_WORKERS", "").strip()
+     if override:
+         try:
+             parsed = int(override)
+             if parsed > 0:
+                 return parsed
+         except ValueError:
+             pass
+
+     if device == "cuda" and torch.cuda.is_available():
+         total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
+         if total_vram_gb >= 20:
+             return 5
+         if total_vram_gb >= 10:
+             return 4
+         return 3
+
+     cpu_count = os.cpu_count() or 1
+     return max(1, min(4, max(1, cpu_count // 2)))
+
+
+ def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
+     logger = logging.getLogger(config.logger_name)
+     logging.basicConfig(level=logging.INFO)
+
+     debug_token_logs = _is_truthy(os.getenv("HF_DEBUG_TOKEN_LOGS", "0"))
+     queue_max_size = int(os.getenv("HF_QUEUE_MAX_SIZE", "512"))
+     streamer_timeout = float(os.getenv("HF_STREAMER_TIMEOUT_SECONDS", "8"))
+     join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180"))
+     max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens)))
+     max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens)))
+
+     base_dir = os.path.dirname(os.path.abspath(__file__))
+
+     model = None
+     tokenizer = None
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     max_workers = _detect_concurrency(device)
+
+     request_queue: asyncio.Queue[Optional[GenerationTask]] = asyncio.Queue(maxsize=queue_max_size)
+     worker_tasks: List[asyncio.Task] = []
+     worker_semaphore = asyncio.Semaphore(max_workers)
+
+     active_workers = 0
+     active_workers_lock = asyncio.Lock()
+
+     async def set_active_workers(delta: int) -> int:
+         nonlocal active_workers
+         async with active_workers_lock:
+             active_workers += delta
+             if active_workers < 0:
+                 active_workers = 0
+             return active_workers
+
+     def format_messages_proper(messages: List[Message], tools: Optional[List[Dict[str, Any]]] = None) -> str:
+         message_dicts = [{"role": msg.role, "content": msg.content} for msg in messages]
+         if tools:
+             return tokenizer.apply_chat_template(
+                 message_dicts,
+                 tools=tools,
+                 add_generation_prompt=True,
+                 tokenize=False,
+             )
+         return tokenizer.apply_chat_template(
+             message_dicts,
+             add_generation_prompt=True,
+             tokenize=False,
+         )
+
+     async def run_generation(task: GenerationTask, worker_id: int) -> None:
+         request_start = time.time()
+         task.start_time = request_start
+         await set_active_workers(+1)
+
+         try:
+             logger.info(
+                 "[%s] worker=%d start queue_size=%d active_workers=%d",
+                 task.request_id,
+                 worker_id,
+                 request_queue.qsize(),
+                 active_workers,
+             )
+
+             inputs = tokenizer(
+                 task.prompt,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=max_input_tokens,
+                 add_special_tokens=False,
+             )
+
+             task.prompt_tokens = int(inputs.input_ids.shape[1])
+
+             if device == "cuda":
+                 inputs = inputs.to("cuda")
+
+             streamer = TextIteratorStreamer(
+                 tokenizer,
+                 skip_prompt=True,
+                 skip_special_tokens=True,
+                 timeout=streamer_timeout,
+             )
+
+             stopping_criteria = StoppingCriteriaList(
+                 [CancelAwareStoppingCriteria(task.cancel_event)]
+             )
+
+             generation_kwargs: Dict[str, Any] = dict(
+                 **inputs,
+                 streamer=streamer,
+                 max_new_tokens=min(task.max_tokens, max_new_tokens_limit),
+                 temperature=task.temperature,
+                 top_p=config.top_p,
+                 repetition_penalty=config.repetition_penalty,
+                 do_sample=task.temperature > 0,
+                 eos_token_id=config.eos_token_id if config.eos_token_id is not None else tokenizer.eos_token_id,
+                 pad_token_id=tokenizer.eos_token_id,
+                 stopping_criteria=stopping_criteria,
+             )
+             if config.top_k is not None:
+                 generation_kwargs["top_k"] = config.top_k
+
+             generation_error: Dict[str, Exception] = {}
+             generation_done = ThreadEvent()
+
+             def generate_target() -> None:
+                 try:
+                     with torch.inference_mode():
+                         model.generate(**generation_kwargs)
+                 except Exception as exc:  # pragma: no cover - defensive logging
+                     generation_error["error"] = exc
+                     logger.error("[%s] generation thread error: %s", task.request_id, exc, exc_info=True)
+                 finally:
+                     generation_done.set()
+                     try:
+                         streamer.end()
+                     except Exception:
+                         # Best-effort close of streamer queue.
+                         pass
+
+             generation_thread = Thread(
+                 target=generate_target,
+                 name=f"gen-{task.request_id[:8]}",
+                 daemon=True,
+             )
+             generation_thread.start()
+
+             stream_iter = iter(streamer)
+             while True:
+                 if task.cancel_event.is_set():
+                     logger.info("[%s] cancellation requested", task.request_id)
+                     break
+
+                 try:
+                     new_text = await asyncio.to_thread(next, stream_iter)
+                 except StopIteration:
+                     break
+                 except QueueEmpty:
+                     if generation_done.is_set():
+                         break
+                     continue
+                 except Exception as exc:  # pragma: no cover - defensive logging
+                     if generation_done.is_set():
+                         break
+                     logger.error("[%s] streamer read error: %s", task.request_id, exc, exc_info=True)
+                     generation_error["error"] = exc
+                     break
+
+                 if not new_text:
+                     continue
+
+                 task.generated_tokens += 1
+                 if task.first_token_latency is None:
+                     task.first_token_latency = time.time() - request_start
+                     logger.info(
+                         "[%s] first_token=%.2fs worker=%d",
+                         task.request_id,
+                         task.first_token_latency,
+                         worker_id,
+                     )
+
+                 if debug_token_logs:
+                     logger.info("[%s] token#%d: %r", task.request_id, task.generated_tokens, new_text)
+
+                 await task.output_queue.put({"type": "token", "content": new_text})
+                 await asyncio.sleep(0)
+
+             # Ensure generation thread is not left running in background.
+             try:
+                 await asyncio.wait_for(asyncio.to_thread(generation_thread.join), timeout=join_timeout)
+             except asyncio.TimeoutError:
+                 logger.error(
+                     "[%s] generation thread still alive after %.1fs join timeout",
+                     task.request_id,
+                     join_timeout,
+                 )
+
+             if task.cancel_event.is_set():
+                 await task.output_queue.put({"type": "error", "content": "Generation interrupted. You can continue."})
+             elif "error" in generation_error:
+                 await task.output_queue.put({"type": "error", "content": str(generation_error["error"])})
+             else:
+                 await task.output_queue.put({"type": "done", "content": ""})
+
+         except Exception as exc:
+             logger.error("[%s] worker failure: %s", task.request_id, exc, exc_info=True)
+             await task.output_queue.put({"type": "error", "content": str(exc)})
+         finally:
+             task.end_time = time.time()
+             duration = max(1e-6, task.end_time - request_start)
+             tps = task.generated_tokens / duration
+             logger.info(
+                 "[%s] worker=%d end tokens=%d duration=%.2fs tok_s=%.2f active_workers=%d queue_size=%d",
+                 task.request_id,
+                 worker_id,
+                 task.generated_tokens,
+                 duration,
+                 tps,
+                 active_workers,
+                 request_queue.qsize(),
+             )
+
+             await task.output_queue.put(None)
+             await set_active_workers(-1)
+
+     async def worker_loop(worker_id: int) -> None:
+         logger.info("Worker-%d started", worker_id)
+         while True:
+             task = await request_queue.get()
+             if task is None:
+                 request_queue.task_done()
+                 logger.info("Worker-%d received shutdown signal", worker_id)
+                 break
+
+             try:
+                 if task.cancel_event.is_set():
+                     await task.output_queue.put({"type": "error", "content": "Request cancelled before execution."})
+                     await task.output_queue.put(None)
+                     continue
+
+                 async with worker_semaphore:
+                     await run_generation(task, worker_id)
+             finally:
+                 request_queue.task_done()
+
+         logger.info("Worker-%d stopped", worker_id)
+
+     @asynccontextmanager
+     async def lifespan(app: FastAPI):
+         nonlocal model, tokenizer, worker_tasks, max_workers, device
+
+         logger.info("Loading model %s on %s", config.model_name, device)
+         tokenizer_kwargs: Dict[str, Any] = {"trust_remote_code": True}
+         if config.tokenizer_use_fast is not None:
+             tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
+         tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
+         model = AutoModelForCausalLM.from_pretrained(
+             config.model_name,
+             trust_remote_code=True,
+             torch_dtype="auto" if device == "cuda" else torch.float32,
+             device_map="auto" if device == "cuda" else None,
+         )
+
+         if device != "cuda":
+             model = model.to("cpu")
+
+         logger.info(
+             "Model loaded: %s | device=%s | max_workers=%d | queue_max_size=%d",
+             config.model_name,
+             device,
+             max_workers,
+             queue_max_size,
+         )
+         logger.info(
+             "Runtime config: max_input_tokens=%d max_new_tokens_limit=%d top_p=%.3f top_k=%s rep_penalty=%.3f",
+             max_input_tokens,
+             max_new_tokens_limit,
+             config.top_p,
+             str(config.top_k),
+             config.repetition_penalty,
+         )
+
+         worker_tasks = [
+             asyncio.create_task(worker_loop(i + 1), name=f"generation-worker-{i + 1}")
+             for i in range(max_workers)
+         ]
+
+         try:
+             yield
+         finally:
+             logger.info("Shutting down workers...")
+             for _ in worker_tasks:
+                 await request_queue.put(None)
+             await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+             logger.info("Releasing model resources...")
+             del model
+             del tokenizer
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+     app = FastAPI(
+         title=config.title,
+         description=config.description,
+         version=config.version,
+         lifespan=lifespan,
+     )
+
+     app.add_middleware(
+         CORSMiddleware,
+         allow_origins=["*"],
+         allow_credentials=True,
+         allow_methods=["*"],
+         allow_headers=["*"],
+     )
+
+     @app.get("/")
+     async def root():
+         return {
+             "name": config.title,
+             "version": config.version,
+             "model": config.model_name,
+             "status": "running",
+             "device": device,
+             "max_workers": max_workers,
+         }
+
+     @app.get("/index", response_class=FileResponse)
+     async def serve_chat():
+         return FileResponse(os.path.join(base_dir, "index.html"))
+
+     @app.get("/health")
+     async def health():
+         return {
+             "status": "healthy",
+             "model_loaded": model is not None and tokenizer is not None,
+             "device": device,
+             "active_workers": active_workers,
+             "queue_size": request_queue.qsize(),
+             "max_workers": max_workers,
+         }
+
+     @app.post("/chat")
+     async def chat(request: ChatRequest):
+         if model is None or tokenizer is None:
+             raise HTTPException(status_code=503, detail="Model not loaded yet")
453
+
454
+ prompt = format_messages_proper(request.messages, request.tools)
455
+ task = GenerationTask(
456
+ request_id=uuid.uuid4().hex,
457
+ prompt=prompt,
458
+ max_tokens=request.max_tokens,
459
+ temperature=request.temperature if request.temperature is not None else config.default_temperature,
460
+ output_queue=asyncio.Queue(maxsize=2048),
461
+ )
462
+
463
+ logger.info(
464
+ "[%s] queued request prompt_len=%d queue_size=%d",
465
+ task.request_id,
466
+ len(prompt),
467
+ request_queue.qsize(),
468
+ )
469
+ await request_queue.put(task)
470
+
471
+ if request.stream:
472
+ async def stream_events():
473
+ try:
474
+ while True:
475
+ event = await task.output_queue.get()
476
+ if event is None:
477
+ break
478
+ yield _format_sse_event(event)
479
+ except asyncio.CancelledError:
480
+ task.cancel_event.set()
481
+ raise
482
+ finally:
483
+ task.cancel_event.set()
484
+
485
+ return StreamingResponse(
486
+ stream_events(),
487
+ media_type="text/event-stream",
488
+ headers={
489
+ "Cache-Control": "no-cache, no-store, must-revalidate",
490
+ "Pragma": "no-cache",
491
+ "Expires": "0",
492
+ "Connection": "keep-alive",
493
+ "X-Accel-Buffering": "no",
494
+ "Transfer-Encoding": "chunked",
495
+ },
496
+ )
497
+
498
+ chunks: List[str] = []
499
+ error_message: Optional[str] = None
500
+ while True:
501
+ event = await task.output_queue.get()
502
+ if event is None:
503
+ break
504
+ event_type = event.get("type")
505
+ if event_type == "token":
506
+ chunks.append(str(event.get("content", "")))
507
+ elif event_type == "error":
508
+ error_message = str(event.get("content", "Generation failed"))
509
+
510
+ if error_message:
511
+ raise HTTPException(status_code=500, detail=error_message)
512
+
513
+ response_text = "".join(chunks).strip()
514
+ return {
515
+ "content": response_text,
516
+ "usage": {
517
+ "prompt_tokens": task.prompt_tokens,
518
+ "completion_tokens": task.generated_tokens,
519
+ },
520
+ }
521
+
522
+ return app
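
The shutdown path in `lifespan` relies on a sentinel protocol: one `None` per worker is pushed onto the shared queue, and each `worker_loop` exits when it dequeues the sentinel, so `asyncio.gather` can then wait for every worker to finish. A minimal standalone sketch of that pattern (names such as `run_demo` are illustrative, not from the server):

```python
import asyncio


async def run_demo() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    processed = []

    async def worker(worker_id: int) -> None:
        # Mirrors worker_loop: block on the queue, exit on the None sentinel.
        while True:
            item = await queue.get()
            if item is None:
                queue.task_done()
                break
            processed.append(item)
            queue.task_done()

    workers = [asyncio.create_task(worker(i + 1)) for i in range(2)]
    for item in (1, 2, 3):
        await queue.put(item)

    # Mirrors the lifespan shutdown: one sentinel per worker, then gather.
    for _ in workers:
        await queue.put(None)
    await asyncio.gather(*workers, return_exceptions=True)
    return sorted(processed)


result = asyncio.run(run_demo())
print(result)  # → [1, 2, 3]
```

Pushing exactly one sentinel per worker matters: fewer would leave some workers blocked on `get()` forever, and `gather` would never return.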
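
On the client side, a streaming `/chat` response is a `text/event-stream` body whose events carry the same `{"type": ..., "content": ...}` dicts the server puts on `task.output_queue`. A rough decoder sketch, assuming `_format_sse_event` emits each event as a `data: <json>` line followed by a blank line (the standard SSE framing; the actual wire format is defined by `_format_sse_event`, which is not shown in this diff):

```python
import json


def collect_tokens(sse_payload: str) -> str:
    """Join 'token' events from an SSE body into the final completion text."""
    chunks = []
    for line in sse_payload.splitlines():
        # Assumed framing: one 'data: <json>' line per event.
        if not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        if event.get("type") == "token":
            chunks.append(str(event.get("content", "")))
        elif event.get("type") == "error":
            raise RuntimeError(event.get("content", "Generation failed"))
    return "".join(chunks)


payload = (
    'data: {"type": "token", "content": "Hel"}\n\n'
    'data: {"type": "token", "content": "lo"}\n\n'
)
print(collect_tokens(payload))  # → Hello
```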