| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import StreamingResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer |
| import torch |
| import json |
| import time |
| import uuid |
| import re |
| import logging |
| from typing import List, Optional, Dict, Any, Union |
| from threading import Thread |
|
|
| |
# Configure root logging: timestamped INFO-level records for all module loggers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
| |
# Hugging Face repo to load, and the model id reported by the OpenAI-compatible API.
MODEL_NAME = "Qwen/Qwen3.5-4B"
MODEL_ID = "qwen3.5-4b"

# 4-bit NF4 quantization with nested (double) quantization; matmuls computed in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)
|
|
# Load the tokenizer; some causal-LM tokenizers ship without a pad token, so
# fall back to using EOS for padding (needed for batched tokenization below).
print("🔹 加载模型:", MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Try the 4-bit quantized load first; if it fails (e.g. bitsandbytes missing or
# unsupported hardware), fall back to a plain fp16 load.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )
except Exception as e:
    # Runtime message (Chinese): "quantized load failed, trying plain load".
    print(f"量化加载失败,尝试普通加载: {e}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16
    )

# Runtime message (Chinese): "model loading complete".
print("✅ 模型加载完成")
|
|
app = FastAPI(title="Qwen3.5-4B API")

# Wide-open CORS: any origin/method/header, credentials allowed.
# NOTE(review): fine for a local single-user server; tighten before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as healthy."""
    status_payload = {"status": "healthy"}
    return status_payload
|
|
@app.get("/v1/me")
async def get_me():
    """Return a fixed local-user identity (this server has no real auth)."""
    return dict(
        id="local-user",
        name="Local User",
        email="user@localhost",
        is_admin=True,
    )
|
|
@app.get("/v1/dashboard/bots")
async def get_bots():
    """Dashboard stub: this server hosts no bots, so the listing is empty."""
    empty_listing = {"objects": []}
    return empty_listing
|
|
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing containing the single hosted model."""
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "created": 1773000000,
        "owned_by": "qwen",
    }
    return {"object": "list", "data": [model_entry]}
|
|
| |
class Message(BaseModel):
    """One chat message; content is either a plain string or an
    OpenAI-style list of content parts ({"type": "text", "text": ...})."""
    role: str
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
|
|
class ChatRequest(BaseModel):
    """Request body for the OpenAI-compatible /v1/chat/completions endpoint."""
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1024
    model: Optional[str] = MODEL_ID  # echoed back in the non-streaming response
    stream: Optional[bool] = False  # when True, respond with SSE chunks
    tools: Optional[List[Dict[str, Any]]] = None  # OpenAI tool/function specs
    tool_choice: Optional[str] = None  # accepted for compatibility; not read anywhere in this file
|
|
def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
    """Flatten OpenAI-style message content into a single plain string.

    None becomes "", strings pass through unchanged, and content-part
    lists are reduced to their "text" parts joined by newlines. Any other
    value is stringified with str().
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        fragments = [
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        return "\n".join(fragments)
    return str(content)
|
|
| |
def normalize_messages(messages):
    """Merge all non-empty system messages into one leading system message.

    System messages with empty/missing content are dropped; non-system
    messages keep their relative order. Returns a new list and does not
    mutate the input.
    """
    # Only system messages that actually carry content are worth merging.
    system_msgs = [m for m in messages if m["role"] == "system" and m.get("content")]
    non_system = [m for m in messages if m["role"] != "system"]
    if system_msgs:
        combined_content = "\n\n".join(m["content"] for m in system_msgs)
        return [{"role": "system", "content": combined_content}] + non_system
    # Bug fix: previously this returned `messages` unchanged, which kept
    # empty-content system messages whenever no non-empty one existed —
    # contradicting the documented filtering behavior.
    return non_system
|
|
| |
def stream_generate(messages, temperature=0.7, max_new_tokens=1024):
    """Yield OpenAI-compatible SSE chunks for a streamed chat completion.

    Runs ``model.generate`` on a background thread and forwards decoded
    text pieces from a ``TextIteratorStreamer`` as ``chat.completion.chunk``
    events, terminated by a ``finish_reason: stop`` chunk and the
    ``data: [DONE]`` sentinel.

    Args:
        messages: list of {"role", "content"} dicts.
        temperature: sampling temperature; <= 0 switches to greedy decoding.
        max_new_tokens: generation budget for new tokens.

    Yields:
        SSE-formatted strings ("data: {json}\\n\\n").
    """
    chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
    try:
        # Merge system messages to the front before applying the chat template.
        messages = normalize_messages(messages)
        logger.info(f"Starting stream generation with {len(messages)} messages")
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
        input_len = len(inputs.input_ids[0])
        logger.info(f"Input token length: {input_len}")

        # skip_prompt drops the echoed input; the timeout keeps a stalled
        # generate() thread from blocking this generator forever.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=300.0
        )

        gen_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": temperature > 0,  # greedy when temperature == 0
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id
        }

        # generate() blocks, so run it off-thread while we drain the streamer.
        gen_thread = Thread(target=model.generate, kwargs=gen_kwargs)
        gen_thread.start()

        # Initial chunk announcing the assistant role (OpenAI convention).
        yield f"data: {json.dumps({'id': chunk_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"

        token_count = 0
        for new_text in streamer:
            if new_text:
                # Bug fix: count only non-empty pieces — the streamer can
                # yield empty strings, which previously inflated the count.
                token_count += 1
                chunk = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": MODEL_ID,
                    "choices": [{
                        "index": 0,
                        "delta": {"content": new_text},
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
        # Fix: reap the worker thread (the streamer is exhausted, so generate()
        # has finished) instead of leaking it.
        gen_thread.join()
        logger.info(f"Generated {token_count} tokens")

        # Terminal chunk with a finish_reason, then the SSE end sentinel.
        yield f"data: {json.dumps({'id': chunk_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"

    except (BrokenPipeError, ConnectionError, RuntimeError) as client_err:
        # Client went away mid-stream; nothing useful left to send.
        logger.warning(f"Client disconnected during streaming: {client_err}")
        return
    except Exception:
        logger.exception("Unexpected streaming error")
        # Best-effort error notification — the transport itself may be dead.
        try:
            error_chunk = {
                "id": chunk_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "error"
                }]
            }
            yield f"data: {json.dumps(error_chunk)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception:  # fix: was a bare except, which also swallowed SystemExit/KeyboardInterrupt
            pass
|
|
| |
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """OpenAI-compatible chat completions endpoint.

    Flattens message content to plain strings, optionally injects a
    prompt describing the available tools, then either streams SSE
    chunks (req.stream) or performs one blocking generation and parses
    any <tool_call> tags the model emitted into OpenAI-style tool_calls.
    """
    # Flatten every message's content (string or content-part list) to text.
    converted_messages = []
    for m in req.messages:
        converted_messages.append({
            "role": m.role,
            "content": convert_content_to_str(m.content)
        })

    # Tool support is prompt-based: describe the tools in a system prompt
    # (Chinese, runtime text) asking the model to emit <tool_call> tags
    # containing {"name": ..., "arguments": {...}} JSON.
    if req.tools:
        tools_json = json.dumps(req.tools, ensure_ascii=False)
        tool_prompt = f"""你是一个助手,可以使用以下工具:
{tools_json}
当用户的问题需要调用工具时,请输出 <tool_call>...</tool_call> 标签,内部是一个 JSON 对象,必须包含 "name" 和 "arguments" 字段。arguments 是一个对象,包含工具所需的参数。
例如:<tool_call>{{"name": "get_weather", "arguments": {{"location": "Beijing"}}}}</tool_call>
如果不需要调用工具,则正常回答。"""

        # Append the tool prompt to the first existing system message, or
        # prepend a fresh system message if there is none.
        system_index = None
        for i, msg in enumerate(converted_messages):
            if msg["role"] == "system":
                system_index = i
                break

        if system_index is not None:
            converted_messages[system_index]["content"] += "\n\n" + tool_prompt
            messages = converted_messages
        else:
            messages = [{"role": "system", "content": tool_prompt}] + converted_messages
    else:
        messages = converted_messages

    # Merge system messages and move them to the front.
    messages = normalize_messages(messages)

    # Streaming mode: delegate to the SSE generator.
    if req.stream:
        return StreamingResponse(
            stream_generate(messages, req.temperature, req.max_tokens),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream"
            }
        )

    # Non-streaming mode: one blocking generation.
    # NOTE(review): this runs on the event loop thread and will block other
    # requests for the duration of generate() — confirm that is acceptable.
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_tokens,
            temperature=req.temperature,
            do_sample=req.temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (slice off the prompt).
    response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)

    # Extract <tool_call> payloads; the `(?:</tool_call>|$)` alternative
    # tolerates a missing closing tag when generation hit the token limit.
    tool_calls = None
    clean_response = response
    tool_call_matches = re.findall(r'<tool_call>(.*?)(?:</tool_call>|$)', response, re.DOTALL)

    if tool_call_matches:
        tool_calls = []
        for match in tool_call_matches:
            try:
                tool_call_data = json.loads(match.strip())
                tool_calls.append({
                    "id": f"call_{uuid.uuid4().hex[:8]}",
                    "type": "function",
                    "function": {
                        "name": tool_call_data.get("name"),
                        # OpenAI format requires arguments as a JSON *string*.
                        "arguments": json.dumps(tool_call_data.get("arguments", {}), ensure_ascii=False)
                    }
                })
            except Exception as e:
                # Runtime log message (Chinese): "tool call parsing failed".
                logger.warning(f"工具调用解析失败: {e}")
        # Remove the tool_call tags from the visible assistant text.
        clean_response = re.sub(r'<tool_call>.*?</tool_call>', '', response, flags=re.DOTALL).strip()

    prompt_tokens = len(inputs.input_ids[0])
    completion_tokens = len(outputs[0]) - prompt_tokens

    # OpenAI-style completion payload; content is None when tool calls were
    # successfully parsed (empty tool_calls lists fall back to plain text).
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": clean_response if not tool_calls else None,
                "tool_calls": tool_calls
            },
            "finish_reason": "tool_calls" if tool_calls else "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }
|
|
@app.get("/")
async def root():
    """Basic status endpoint exposing the loaded model's repo name."""
    status_info = {"status": "running", "model": MODEL_NAME}
    return status_info
|
|
# Script entry point: serve on all interfaces, port 7860. The long keep-alive
# timeout prevents slow streaming responses from being cut off by the server.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, timeout_keep_alive=1200)