Ashok75 committed on
Commit
90937b1
·
verified ·
1 Parent(s): 90e69ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +607 -38
app.py CHANGED
@@ -1,53 +1,622 @@
 
 
 
 
 
 
 
1
  import torch
2
- from flask import Flask, request, Response, render_template
 
 
 
 
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from threading import Thread
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- app = Flask(__name__)
 
 
 
 
 
 
7
 
8
- # Load Nanbeige 4.1 3B
9
- model_id = "Nanbeige/Nanbeige4.1-3B"
10
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
11
- model = AutoModelForCausalLM.from_pretrained(
12
- model_id,
13
- torch_dtype=torch.bfloat16,
14
- device_map="auto",
15
- trust_remote_code=True
 
 
 
 
 
 
 
16
  )
17
 
18
- @app.route('/chat', methods=['POST'])
19
- def chat():
20
- user_msg = request.json.get("message")
21
-
22
- # System Prompt Construction [14, 32]
23
- prompt = f"<|system|>\nYou are an GAKR AI ASSISTANT. Always think before answering.dont think heavely and answer directly if you know the answer and if you want any latest content or anything call the web_search tool to get the content like latest data and web data and all\n<|user|>\n{user_msg}\n<|assistant|>\n<thought>"
24
-
25
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
26
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
27
-
28
- generation_kwargs = dict(
29
- **inputs,
30
- streamer=streamer,
31
- max_new_tokens=1024,
32
- do_sample=True,
33
- temperature=0.7,
34
- pad_token_id=tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
38
  thread.start()
39
 
40
- def stream():
41
- # Start with the tag we forced in the prompt
42
- yield "<thought>"
43
- for new_text in streamer:
44
- yield new_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- return Response(stream(), mimetype='text/plain')
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- @app.route('/')
49
- def index():
50
- return render_template('index.html')
51
 
52
- if __name__ == '__main__':
53
- app.run(host='0.0.0.0', port=7860)
 
 
 
1
+ """
2
+ Nanbeige4.1-3B Inference Server for Hugging Face Space
3
+ Lightweight API server exposing /chat endpoint for remote agent communication
4
+ """
5
+
6
import asyncio
import json
import os
from contextlib import asynccontextmanager
from threading import Thread
from typing import AsyncGenerator, Dict, Generator, List, Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
18
+
19
# Global model instances — populated once by load_model() during the FastAPI
# lifespan startup; every request handler reads these.
model = None
tokenizer = None

# Model configuration
MODEL_ID = "Nanbeige/Nanbeige4.1-3B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Generation defaults; per-request values in ChatRequest override them.
DEFAULT_MAX_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.6
DEFAULT_TOP_P = 0.95
29
+
30
+
31
class ChatMessage(BaseModel):
    """One message of the conversation history (OpenAI-style schema)."""
    # format_messages_for_model recognizes system/user/assistant/tool;
    # other role values are silently dropped there.
    role: str = Field(..., description="Message role: system, user, assistant, or tool")
    content: str = Field(..., description="Message content")
    # Assistant tool calls; serialized into the prompt as a <tool_calls> block.
    tool_calls: Optional[List[Dict]] = Field(None, description="Tool calls from assistant")
    # Correlates a tool-role message with the assistant tool call it answers.
    tool_call_id: Optional[str] = Field(None, description="Tool call ID for tool responses")
36
+
37
+
38
class ChatRequest(BaseModel):
    """Request body for POST /chat (OpenAI chat-completions style)."""
    messages: List[ChatMessage] = Field(..., description="Conversation history")
    tools: Optional[List[Dict]] = Field(None, description="Available tools for function calling")
    # When true the endpoint answers with Server-Sent Events chunks.
    stream: bool = Field(default=False, description="Enable streaming response")
    max_tokens: int = Field(default=DEFAULT_MAX_TOKENS, ge=1, le=8192)
    temperature: float = Field(default=DEFAULT_TEMPERATURE, ge=0.0, le=2.0)
    top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
    stop: Optional[List[str]] = Field(default=None, description="Stop sequences")
46
+
47
+
48
class ChatResponse(BaseModel):
    """Non-streaming response body, mirroring OpenAI's chat.completion shape."""
    id: str
    object: str = "chat.completion"
    # Unix timestamp in seconds.
    created: int
    model: str
    choices: List[Dict]
    # prompt/completion/total token counts; omitted for streaming responses.
    usage: Optional[Dict] = None
55
+
56
+
57
def load_model():
    """Load Nanbeige4.1-3B model and tokenizer into the module globals.

    Invoked exactly once from the FastAPI lifespan hook before the server
    accepts requests. Uses fp16 + accelerate device placement on CUDA and
    fp32 on CPU.
    """
    global model, tokenizer

    print(f"Loading {MODEL_ID} on {DEVICE}...")

    # Left padding so the generated continuation follows the prompt directly
    # when inputs are padded.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        padding_side="left"
    )

    # Set pad token if not present
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if DEVICE == "cuda" else None,
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )

    # device_map is None on CPU, so move the weights explicitly.
    if DEVICE == "cpu":
        model = model.to(DEVICE)

    model.eval()
    print(f"Model loaded successfully on {DEVICE}")
86
+
87
 
88
@asyncontextmanager if False else asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Blocks server startup until the model is fully loaded, so /chat never
    sees a half-initialized model.
    """
    # Startup
    load_model()
    yield
    # Shutdown - cleanup happens automatically
95
 
96
+
97
app = FastAPI(
    title="Nanbeige4.1-3B Inference API",
    description="Remote LLM inference service for Enterprise ReAct Agent",
    version="1.0.0",
    lifespan=lifespan
)

# CORS for local agent communication
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — tighten allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
112
 
113
+
114
def format_messages_for_model(messages: List[ChatMessage], tools: Optional[List[Dict]] = None) -> str:
    """Build the model prompt from API messages via the tokenizer's chat template.

    Assistant tool calls are serialized into the message content as a
    ``<tool_calls>`` block; if *tools* are supplied, their JSON description is
    appended to (or inserted as) the system message. Messages with an
    unrecognized role are dropped.
    """
    chat: List[Dict] = []

    for message in messages:
        role = message.role
        if role in ("system", "user"):
            chat.append({"role": role, "content": message.content})
        elif role == "assistant":
            body = message.content
            if message.tool_calls:
                # Inline the tool calls so the model sees its own prior calls.
                body = f"{body}\n<tool_calls>{json.dumps(message.tool_calls)}</tool_calls>"
            chat.append({"role": "assistant", "content": body})
        elif role == "tool":
            chat.append({
                "role": "tool",
                "content": message.content,
                "tool_call_id": message.tool_call_id,
            })

    # Advertise available tools through the system message.
    if tools:
        tool_text = "\n\nAvailable tools:\n" + json.dumps(tools, indent=2)
        if chat and chat[0]["role"] == "system":
            chat[0]["content"] += tool_text
        else:
            chat.insert(0, {"role": "system", "content": tool_text})

    return tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True
    )
153
+
154
+
155
def parse_tool_calls(response_text: str) -> tuple[str, Optional[List[Dict]]]:
    """Split a model response into plain content and parsed tool calls.

    Looks for a ``<tool_calls>...</tool_calls>`` block, JSON-decodes its
    body, and returns ``(content_before_block, tool_calls)``. On any parse
    failure the full text is returned unchanged with ``tool_calls=None``.

    Fixes over the original: the closing tag is searched *after* the opening
    tag (a stray ``</tool_calls>`` earlier in the text no longer produces a
    garbage slice), and a single JSON object is normalized to a one-element
    list so the declared ``List[Dict]`` return type actually holds.
    """
    content = response_text
    tool_calls: Optional[List[Dict]] = None

    open_tag, close_tag = "<tool_calls>", "</tool_calls>"
    start = response_text.find(open_tag)
    if start != -1:
        end = response_text.find(close_tag, start + len(open_tag))
        if end != -1:
            try:
                parsed = json.loads(response_text[start + len(open_tag):end])
            except (json.JSONDecodeError, ValueError):
                parsed = None
            if isinstance(parsed, dict):
                # Single call emitted as an object — normalize to a list.
                parsed = [parsed]
            if isinstance(parsed, list):
                tool_calls = parsed
                content = response_text[:start].strip()

    return content, tool_calls
172
+
173
+
174
def generate_stream(
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    stop: Optional[List[str]]
) -> Generator[str, None, None]:
    """Yield decoded text chunks for *prompt* as generation progresses.

    ``model.generate`` runs in a background thread and feeds a
    ``TextIteratorStreamer`` that this generator drains.

    Fixes over the original:
    - This is a *synchronous* generator (no ``await`` anywhere), so it is
      now annotated ``Generator`` instead of ``AsyncGenerator``.
    - Stop-sequence handling previously truncated a local variable but kept
      yielding the stop text and everything after it; now output is cut at
      the first stop match and nothing beyond it is yielded (the streamer is
      still drained so the worker thread can finish).
    """
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": temperature > 0,
        "streamer": streamer,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }

    if stop:
        # Also halt the underlying generate() call server-side.
        generation_kwargs["stopping_criteria"] = create_stopping_criteria(stop)

    # Run generation in a separate thread; the streamer bridges the output.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    stopped = False
    for new_text in streamer:
        if stopped:
            # Keep draining so the generation thread is never blocked.
            continue

        prev_len = len(generated_text)
        generated_text += new_text

        if stop:
            hits = [generated_text.find(s) for s in stop]
            hits = [h for h in hits if h != -1]
            if hits:
                cut = min(hits)
                # Yield only the part of this chunk that precedes the stop
                # match. (If the match began in an already-yielded chunk,
                # that prefix cannot be recalled.)
                if cut > prev_len:
                    yield generated_text[prev_len:cut]
                stopped = True
                continue

        yield new_text

    thread.join()
223
+
224
+
225
def create_stopping_criteria(stop_sequences: List[str]):
    """Build a StoppingCriteriaList that halts generation on any stop string.

    The original decoded the *entire* sequence — prompt included — on every
    step, so a stop string occurring in the prompt stopped generation
    immediately, and each check cost O(sequence length). This version
    records the sequence length at the first invocation and only decodes
    tokens generated after that point.
    """
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopSequenceCriteria(StoppingCriteria):
        def __init__(self, stops, tokenizer):
            self.stops = stops
            self.tokenizer = tokenizer
            self.start_len = None  # sequence length when first called

        def __call__(self, input_ids, scores, **kwargs):
            if self.start_len is None:
                # First call happens after the first new token, so everything
                # before it is prompt. NOTE(review): assumes this criteria
                # object is used for a single generate() call — a fresh one
                # is created per request by the callers in this file.
                self.start_len = max(input_ids.shape[1] - 1, 0)
            generated = self.tokenizer.decode(
                input_ids[0, self.start_len:], skip_special_tokens=True
            )
            return any(stop in generated for stop in self.stops)

    return StoppingCriteriaList([StopSequenceCriteria(stop_sequences, tokenizer)])
242
+
243
+
244
def generate_non_stream(
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    stop: Optional[List[str]]
) -> str:
    """Run one blocking generation pass and return the decoded completion.

    The prompt tokens are sliced off before decoding; if stop sequences are
    given, the completion is truncated at the first listed sequence found.
    """
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tail, not the echoed prompt.
    prompt_len = encoded["input_ids"].shape[1]
    completion = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

    # Cut at the first stop sequence (list order) that appears.
    for marker in (stop or []):
        if marker in completion:
            completion = completion[:completion.find(marker)]
            break

    return completion
277
+
278
+
279
@app.post("/chat", response_model=ChatResponse)
async def chat_completion(request: ChatRequest):
    """
    Main chat completion endpoint.
    Compatible with OpenAI-style API for easy integration.

    With ``stream=true`` the reply is served as Server-Sent Events
    (``chat.completion.chunk`` deltas followed by ``data: [DONE]``);
    otherwise a single ChatResponse JSON document is returned.
    """
    import time

    prompt = format_messages_for_model(request.messages, request.tools)

    if request.stream:
        # generate_stream is a *synchronous* generator (generation runs in a
        # worker thread), so this wrapper must be a plain generator as well.
        # The original used `async for` over it, which raises TypeError on
        # every streaming request. StreamingResponse iterates synchronous
        # generators in a threadpool, so the event loop is not blocked.
        def stream_response():
            for chunk in generate_stream(
                prompt,
                request.max_tokens,
                request.temperature,
                request.top_p,
                request.stop
            ):
                data = {
                    "id": f"chatcmpl-{int(time.time())}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": MODEL_ID,
                    "choices": [{
                        "index": 0,
                        "delta": {"content": chunk},
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(data)}\n\n"

            # Final chunk: empty delta carrying finish_reason, then the SSE
            # terminator expected by OpenAI-compatible clients.
            final_data = {
                "id": f"chatcmpl-{int(time.time())}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {json.dumps(final_data)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_response(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # disable proxy buffering for SSE
            }
        )

    else:
        # NOTE(review): this blocking generate runs on the event loop thread;
        # if concurrent generate() calls are safe for this model, prefer
        # `await asyncio.to_thread(generate_non_stream, ...)`.
        generated = generate_non_stream(
            prompt,
            request.max_tokens,
            request.temperature,
            request.top_p,
            request.stop
        )

        content, tool_calls = parse_tool_calls(generated)

        # Token accounting for the OpenAI-style usage block.
        input_tokens = len(tokenizer.encode(prompt))
        output_tokens = len(tokenizer.encode(generated))

        response = ChatResponse(
            id=f"chatcmpl-{int(time.time())}",
            object="chat.completion",
            created=int(time.time()),
            model=MODEL_ID,
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content,
                    "tool_calls": tool_calls
                },
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens
            }
        )

        return response
376
+
377
+
378
@app.get("/chat", response_class=HTMLResponse)
async def chat_interface():
    """Simple web interface for testing.

    Serves a self-contained HTML/JS page that POSTs the running conversation
    back to this same /chat route (non-streaming) and renders assistant
    replies, including any tool_calls payload.
    """
    # The markup below is a runtime string — keep its fetch() body in sync
    # with the POST /chat JSON contract (messages, stream, max_tokens,
    # temperature).
    return """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Nanbeige4.1-3B Chat</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            background: #1a1a2e;
            color: #eee;
            min-height: 100vh;
            display: flex;
            flex-direction: column;
        }
        header {
            background: #16213e;
            padding: 1rem 2rem;
            border-bottom: 1px solid #0f3460;
        }
        header h1 { font-size: 1.25rem; color: #e94560; }
        header p { font-size: 0.875rem; color: #888; margin-top: 0.25rem; }
        .chat-container {
            flex: 1;
            display: flex;
            flex-direction: column;
            max-width: 900px;
            width: 100%;
            margin: 0 auto;
            padding: 1rem;
        }
        .messages {
            flex: 1;
            overflow-y: auto;
            padding: 1rem;
            display: flex;
            flex-direction: column;
            gap: 1rem;
        }
        .message {
            max-width: 80%;
            padding: 1rem;
            border-radius: 12px;
            line-height: 1.6;
        }
        .message.user {
            align-self: flex-end;
            background: #e94560;
            color: white;
        }
        .message.assistant {
            align-self: flex-start;
            background: #16213e;
            border: 1px solid #0f3460;
        }
        .message.system {
            align-self: center;
            background: #0f3460;
            font-size: 0.875rem;
            color: #888;
        }
        .input-area {
            display: flex;
            gap: 0.5rem;
            padding: 1rem;
            background: #16213e;
            border-top: 1px solid #0f3460;
        }
        textarea {
            flex: 1;
            padding: 0.75rem 1rem;
            border: 1px solid #0f3460;
            border-radius: 8px;
            background: #1a1a2e;
            color: #eee;
            font-size: 1rem;
            resize: none;
            min-height: 50px;
            max-height: 150px;
        }
        textarea:focus {
            outline: none;
            border-color: #e94560;
        }
        button {
            padding: 0.75rem 1.5rem;
            background: #e94560;
            color: white;
            border: none;
            border-radius: 8px;
            cursor: pointer;
            font-size: 1rem;
            transition: background 0.2s;
        }
        button:hover { background: #d63d56; }
        button:disabled { background: #666; cursor: not-allowed; }
        .loading {
            display: inline-block;
            width: 20px;
            height: 20px;
            border: 2px solid #0f3460;
            border-top-color: #e94560;
            border-radius: 50%;
            animation: spin 1s linear infinite;
        }
        @keyframes spin { to { transform: rotate(360deg); } }
        .tool-calls {
            margin-top: 0.5rem;
            padding: 0.5rem;
            background: #0f3460;
            border-radius: 6px;
            font-size: 0.8rem;
            font-family: monospace;
        }
    </style>
</head>
<body>
    <header>
        <h1>Nanbeige4.1-3B Inference Server</h1>
        <p>Remote LLM service for Enterprise ReAct Agent</p>
    </header>
    <div class="chat-container">
        <div class="messages" id="messages"></div>
        <div class="input-area">
            <textarea id="input" placeholder="Type your message..." rows="1"></textarea>
            <button id="send" onclick="sendMessage()">Send</button>
        </div>
    </div>

    <script>
        const messages = document.getElementById('messages');
        const input = document.getElementById('input');
        const sendBtn = document.getElementById('send');
        let conversation = [];

        // Auto-resize textarea
        input.addEventListener('input', () => {
            input.style.height = 'auto';
            input.style.height = Math.min(input.scrollHeight, 150) + 'px';
        });

        // Enter to send, Shift+Enter for new line
        input.addEventListener('keydown', (e) => {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendMessage();
            }
        });

        function addMessage(role, content, toolCalls = null) {
            const div = document.createElement('div');
            div.className = `message ${role}`;
            div.textContent = content;
            if (toolCalls) {
                const toolDiv = document.createElement('div');
                toolDiv.className = 'tool-calls';
                toolDiv.textContent = 'Tool calls: ' + JSON.stringify(toolCalls, null, 2);
                div.appendChild(toolDiv);
            }
            messages.appendChild(div);
            messages.scrollTop = messages.scrollHeight;
        }

        async function sendMessage() {
            const text = input.value.trim();
            if (!text) return;

            addMessage('user', text);
            conversation.push({ role: 'user', content: text });
            input.value = '';
            input.style.height = 'auto';
            sendBtn.disabled = true;
            sendBtn.innerHTML = '<span class="loading"></span>';

            try {
                const response = await fetch('/chat', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        messages: conversation,
                        stream: false,
                        max_tokens: 2048,
                        temperature: 0.6
                    })
                });

                const data = await response.json();
                const assistantMsg = data.choices[0].message;

                addMessage('assistant', assistantMsg.content, assistantMsg.tool_calls);
                conversation.push({
                    role: 'assistant',
                    content: assistantMsg.content,
                    tool_calls: assistantMsg.tool_calls
                });
            } catch (error) {
                addMessage('system', 'Error: ' + error.message);
            } finally {
                sendBtn.disabled = false;
                sendBtn.textContent = 'Send';
            }
        }

        // Initial system message
        addMessage('system', 'Welcome! The model is ready for inference.');
    </script>
</body>
</html>
"""
592
+
593
+
594
@app.get("/health")
async def health_check():
    """Health probe: reports model identity, device, and load state."""
    is_loaded = model is not None and tokenizer is not None
    return {
        "status": "healthy",
        "model": MODEL_ID,
        "device": DEVICE,
        "model_loaded": is_loaded
    }
603
+
604
 
605
@app.get("/")
async def root():
    """Service metadata endpoint advertising the available routes."""
    available = {
        "chat": "/chat (POST for API, GET for web interface)",
        "health": "/health"
    }
    return {
        "message": "Nanbeige4.1-3B Inference Server",
        "endpoints": available,
        "model": MODEL_ID,
        "device": DEVICE
    }
617
 
 
 
 
618
 
619
if __name__ == "__main__":
    import uvicorn
    # PORT may be injected by the hosting environment; 7860 is the
    # Hugging Face Spaces convention.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)