from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
import torch
import json
import time
import uuid
import re
import logging
from typing import List, Optional, Dict, Any, Union
from threading import Thread

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# ====================== Model configuration ======================
MODEL_NAME = "Qwen/Qwen3.5-4B"  # Hugging Face checkpoint to load
MODEL_ID = "qwen3.5-4b"  # Model name configured in CoPaw; reported by /v1/models and in responses

logger.info("🔹 加载模型: %s", MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Tokenizing with padding=True requires a pad token; reuse EOS when none is defined.
    tokenizer.pad_token = tokenizer.eos_token

# Preferred path: 4-bit NF4 quantization via bitsandbytes. The config is built
# inside the try so that a failure while constructing it (not only while
# loading the weights) also reaches the unquantized fallback below.
bnb_config = None
try:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )
except Exception as e:
    logger.warning(f"量化加载失败,尝试普通加载: {e}")
    # float16 is only a sensible default when a GPU is present; on a CPU-only
    # host half-precision kernels are slow or missing, so use float32 there.
    fallback_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=fallback_dtype
    )
logger.info("✅ 模型加载完成")

app = FastAPI(title="Qwen3.5-4B API")

# ====================== CORS middleware ======================
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# web page make credentialed calls to this server; confirm it is only exposed locally.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ====================== Endpoints required by CoPaw ======================
@app.get("/health")
async def health():
    """Liveness probe; always returns a static "healthy" payload."""
    return {"status": "healthy"}


@app.get("/v1/me")
async def get_me():
    """Return a fixed local admin user profile (no authentication is performed)."""
    return {
        "id": "local-user",
        "name": "Local User",
        "email": "user@localhost",
        "is_admin": True
    }


@app.get("/v1/dashboard/bots")
async def get_bots():
    """Return an empty bot list."""
    return {"objects": []}
@app.get("/v1/models")
async def list_models():
    """List the single model served by this process (OpenAI `/v1/models` shape)."""
    return {
        "object": "list",
        "data": [
            {
                "id": MODEL_ID,
                "object": "model",
                "created": 1773000000,
                "owned_by": "qwen"
            }
        ]
    }


# ====================== Request/response models ======================
# Defaults, also used when a client sends an explicit JSON null for these fields.
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 1024


class Message(BaseModel):
    """A single chat message in OpenAI format."""
    role: str
    # Plain text, or a list of OpenAI content parts (only "text" parts are used).
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    # Tool calls issued by an earlier assistant turn; rendered back into the
    # prompt so the model can see which tools it already called.
    tool_calls: Optional[List[Dict[str, Any]]] = None


class ChatRequest(BaseModel):
    """Body of POST /v1/chat/completions (the subset of the OpenAI schema used here)."""
    messages: List[Message]
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
    model: Optional[str] = MODEL_ID
    stream: Optional[bool] = False
    tools: Optional[List[Dict[str, Any]]] = None
    # OpenAI accepts either a string ("none"/"auto"/"required") or an object
    # naming a specific function.
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None


def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
    """Flatten OpenAI message content into plain text.

    ``None`` becomes ``""``. For a list of content parts, the ``text`` of every
    part whose ``type`` is ``"text"`` is kept and the pieces are joined with
    newlines; other parts (e.g. images) are dropped. Anything else goes through
    ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def _render_tool_calls(tool_calls):
    """Render OpenAI-style tool calls as the ``<tool_call>`` text the model is prompted to emit."""
    rendered = []
    for call in tool_calls or []:
        function = call.get("function") or {} if isinstance(call, dict) else {}
        arguments = function.get("arguments", {})
        if isinstance(arguments, str):
            # OpenAI sends arguments as a JSON string; decode it so the prompt
            # shows an object, matching the format requested from the model.
            try:
                arguments = json.loads(arguments)
            except ValueError:
                pass
        payload = {"name": function.get("name"), "arguments": arguments}
        rendered.append(f"<tool_call>\n{json.dumps(payload, ensure_ascii=False)}\n</tool_call>")
    return "\n".join(rendered)


# ====================== Keep a single system message, placed first ======================
def normalize_messages(messages):
    """Merge all non-empty system messages into one leading system message.

    Empty system messages are dropped. Non-system messages keep their relative
    order.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A new list whose first element is the merged system message (contents
        joined with a blank line) if any non-empty system message existed.
    """
    system_texts = [m["content"] for m in messages if m["role"] == "system" and m.get("content")]
    others = [m for m in messages if m["role"] != "system"]
    if not system_texts:
        # Also drops empty system messages here; returning the input unchanged
        # would leave them in place, possibly not at the start.
        return others
    return [{"role": "system", "content": "\n\n".join(system_texts)}] + others


def _build_tool_prompt(tools):
    """Build the system-prompt section describing ``tools`` and the ``<tool_call>`` output format."""
    tools_json = json.dumps(tools, ensure_ascii=False)
    return f"""你是一个助手,可以使用以下工具:
{tools_json}

当用户的问题需要调用工具时,请输出 <tool_call>...</tool_call> 标签,内部是一个 JSON 对象,必须包含 "name" 和 "arguments" 字段。arguments 是一个对象,包含工具所需的参数。
例如:<tool_call>{{"name": "get_weather", "arguments": {{"location": "Beijing"}}}}</tool_call>
如果不需要调用工具,则正常回答。"""


def _build_messages(req):
    """Convert a ChatRequest into chat-template-ready dicts.

    Adds the tool instructions to the first system message, or prepends a new
    system message when there is none. Then merges all system messages into
    one leading message.
    """
    messages = []
    for m in req.messages:
        content = convert_content_to_str(m.content)
        if m.role == "assistant" and m.tool_calls:
            rendered = _render_tool_calls(m.tool_calls)
            content = f"{content}\n{rendered}" if content else rendered
        messages.append({"role": m.role, "content": content})

    # tool_choice="none" means the model must not call tools, so don't advertise them.
    if req.tools and req.tool_choice != "none":
        tool_prompt = _build_tool_prompt(req.tools)
        first_system = next((msg for msg in messages if msg["role"] == "system"), None)
        if first_system is not None:
            first_system["content"] += "\n\n" + tool_prompt
        else:
            messages.insert(0, {"role": "system", "content": tool_prompt})

    return normalize_messages(messages)


def _prepare_inputs(messages):
    """Render ``messages`` with the chat template and tokenize them onto the model's device."""
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return tokenizer([text], return_tensors="pt", padding=True).to(model.device)


def _generation_kwargs(temperature, max_new_tokens):
    """Translate OpenAI sampling parameters into ``model.generate`` keyword arguments.

    ``None`` values fall back to DEFAULT_TEMPERATURE / DEFAULT_MAX_TOKENS.
    A temperature <= 0 selects greedy decoding, and no temperature is passed.
    """
    if temperature is None:
        temperature = DEFAULT_TEMPERATURE
    if max_new_tokens is None:
        max_new_tokens = DEFAULT_MAX_TOKENS
    kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature)
    else:
        kwargs["do_sample"] = False
    return kwargs


def _finish_reason(new_token_ids, max_new_tokens):
    """Return "length" if generation stopped at the token limit without emitting EOS, else "stop"."""
    if max_new_tokens is None:
        max_new_tokens = DEFAULT_MAX_TOKENS
    hit_limit = len(new_token_ids) >= max_new_tokens
    ended_on_eos = len(new_token_ids) > 0 and int(new_token_ids[-1]) == tokenizer.eos_token_id
    return "length" if hit_limit and not ended_on_eos else "stop"


def _make_stop_criteria(stop_event):
    """Return stopping criteria that end ``model.generate`` once ``stop_event`` is set."""
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One boolean per sequence in the batch.
            return torch.full(
                (input_ids.shape[0],), stop_event.is_set(),
                dtype=torch.bool, device=input_ids.device,
            )

    return StoppingCriteriaList([_StopOnEvent()])


def _sse_chunk(chunk_id, created, delta, finish_reason=None):
    """Format one ``chat.completion.chunk`` object as a server-sent-event line."""
    chunk = {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason
        }]
    }
    return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"


# ====================== Streaming generation (TextIteratorStreamer) ======================
def stream_generate(messages, temperature=0.7, max_new_tokens=1024):
    """Yield OpenAI-style SSE chunks for a streamed chat completion.

    ``model.generate`` runs in a background thread and feeds a
    TextIteratorStreamer. The chunks go out in this order:

    1. a role chunk;
    2. one content chunk per non-empty text piece;
    3. a final chunk with ``finish_reason`` "stop" or "length";
    4. ``data: [DONE]``.

    On failure an error chunk (``finish_reason`` "error") is sent instead of
    the final chunk. When the consumer goes away (the generator is closed),
    generation is signalled to stop.
    """
    from threading import Event

    chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    stop_event = Event()
    try:
        # Ensure the system message comes first (idempotent if already normalized).
        messages = normalize_messages(messages)
        logger.info(f"Starting stream generation with {len(messages)} messages")

        inputs = _prepare_inputs(messages)
        input_len = inputs.input_ids.shape[1]
        logger.info(f"Input token length: {input_len}")

        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=300.0
        )
        gen_kwargs = {
            **inputs,
            **_generation_kwargs(temperature, max_new_tokens),
            "streamer": streamer,
            "stopping_criteria": _make_stop_criteria(stop_event),
        }
        result = {}

        def _run_generate():
            try:
                result["output"] = model.generate(**gen_kwargs)
            except Exception as exc:
                result["error"] = exc
                # Unblock the consumer loop below. Otherwise it would wait for
                # the streamer timeout before noticing the failure.
                streamer.end()

        thread = Thread(target=_run_generate, daemon=True)
        thread.start()

        yield _sse_chunk(chunk_id, created, {"role": "assistant"})

        chunk_count = 0
        for new_text in streamer:
            if new_text:
                chunk_count += 1
                yield _sse_chunk(chunk_id, created, {"content": new_text})

        thread.join()
        if "error" in result:
            raise result["error"]
        logger.info(f"Streamed {chunk_count} text chunks")

        new_ids = result["output"][0][input_len:]
        yield _sse_chunk(chunk_id, created, {}, _finish_reason(new_ids, max_new_tokens))
        yield "data: [DONE]\n\n"
    except GeneratorExit:
        logger.warning("Client disconnected during streaming")
        raise
    except Exception:
        logger.exception("Unexpected streaming error")
        yield _sse_chunk(chunk_id, created, {}, "error")
        yield "data: [DONE]\n\n"
    finally:
        # Stop the background generate() if it is still running (e.g. client
        # disconnected), so it does not keep consuming the GPU.
        stop_event.set()


# NOTE(review): assumes the <tool_call> tags are ordinary (non-special) tokens,
# so they survive skip_special_tokens=True decoding. TODO confirm for this checkpoint.
_TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)(?:</tool_call>|$)", re.DOTALL)
_TOOL_CALL_BLOCK_RE = re.compile(r"<tool_call>.*?(?:</tool_call>|$)", re.DOTALL)


def _parse_tool_calls(response):
    """Extract ``<tool_call>{...}</tool_call>`` blocks from model output.

    A trailing block that was never closed (e.g. cut off by the token limit)
    is still considered.

    Returns:
        ``(tool_calls, clean_text)``. ``tool_calls`` is a list of OpenAI-style
        tool call dicts, or None if no block parsed. ``clean_text`` is the
        response with every block removed, stripped of surrounding whitespace.
    """
    tool_calls = []
    for raw in _TOOL_CALL_RE.findall(response):
        try:
            data = json.loads(raw.strip())
        except ValueError as e:
            logger.warning(f"工具调用解析失败: {e}")
            continue
        if not isinstance(data, dict) or not data.get("name"):
            logger.warning(f"工具调用解析失败: expected an object with 'name', got {data!r}")
            continue
        arguments = data.get("arguments", {})
        if not isinstance(arguments, str):
            # OpenAI expects arguments as a JSON-encoded string; don't double-encode.
            arguments = json.dumps(arguments, ensure_ascii=False)
        tool_calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {
                "name": data["name"],
                "arguments": arguments
            }
        })
    clean_response = _TOOL_CALL_BLOCK_RE.sub("", response).strip()
    return (tool_calls or None), clean_response


def _generate_blocking(messages, temperature, max_new_tokens):
    """Run a full, non-streaming generation.

    Returns:
        ``(prompt_token_count, new_token_ids, decoded_text)``.
    """
    inputs = _prepare_inputs(messages)
    prompt_len = inputs.input_ids.shape[1]
    with torch.no_grad():
        outputs = model.generate(**inputs, **_generation_kwargs(temperature, max_new_tokens))
    new_ids = outputs[0][prompt_len:]
    text = tokenizer.decode(new_ids, skip_special_tokens=True)
    return prompt_len, new_ids, text


# ====================== Chat endpoint ======================
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """OpenAI-compatible chat completions.

    Returns an SSE stream when ``req.stream`` is true. Otherwise returns one
    JSON completion, with tool calls parsed out of ``<tool_call>`` blocks.

    Raises:
        HTTPException: 400 if ``max_tokens`` is given and is less than 1.
    """
    from fastapi.concurrency import run_in_threadpool

    if req.max_tokens is not None and req.max_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be >= 1")

    messages = _build_messages(req)

    if req.stream:
        return StreamingResponse(
            stream_generate(messages, req.temperature, req.max_tokens),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream"
            }
        )

    # model.generate blocks for the whole generation. Run it in a worker thread
    # so the event loop keeps serving other requests (health checks, streams).
    prompt_tokens, new_ids, response = await run_in_threadpool(
        _generate_blocking, messages, req.temperature, req.max_tokens
    )
    completion_tokens = len(new_ids)

    tool_calls, clean_response = _parse_tool_calls(response)
    finish_reason = "tool_calls" if tool_calls else _finish_reason(new_ids, req.max_tokens)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": clean_response if not tool_calls else None,
                "tool_calls": tool_calls
            },
            "finish_reason": finish_reason
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }


@app.get("/")
async def root():
    """Basic status endpoint reporting the loaded checkpoint name."""
    return {"status": "running", "model": MODEL_NAME}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, timeout_keep_alive=1200)