# app.py — Hugging Face Space "qw2b" (commit ce316b2, "Update app.py" by nagose)
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
import torch
import json
import time
import uuid
import re
import logging
from typing import List, Optional, Dict, Any, Union
from threading import Thread
# Configure logging (optional; aids debugging)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ====================== Model configuration ======================
MODEL_NAME = "Qwen/Qwen3.5-2B"  # swap for another checkpoint if desired
MODEL_ID = "qwen3.5-2b"  # must match the CoPaw client configuration

# 4-bit quantization config (remove quantization if it fails or memory allows fp16)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

print("🔹 加载模型:", MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Some checkpoints ship without a pad token; reuse EOS so padding works below
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Try quantized loading first; on any failure fall back to plain fp16 loading
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )
except Exception as e:
    print(f"量化加载失败,尝试普通加载: {e}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16
    )
print("✅ 模型加载完成")
app = FastAPI(title="Qwen3.5-2B API (OpenAI 兼容)")

# ====================== CORS middleware (required by the CoPaw client) ======================
# Wide-open CORS: any origin/method/header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ====================== Extra endpoints required by CoPaw ======================
@app.get("/health")
async def health():
    """Liveness probe; always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload
@app.get("/v1/me")
async def get_me():
    """Return a static local-user profile (there is no real auth backend)."""
    profile = dict(
        id="local-user",
        name="Local User",
        email="user@localhost",
        is_admin=True,
    )
    return profile
@app.get("/v1/dashboard/bots")
async def get_bots():
    """Dashboard bot listing — this server hosts no bots, so it's empty."""
    empty_listing = {"objects": []}
    return empty_listing
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing containing the single hosted model."""
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "created": 1773000000,
        "owned_by": "qwen",
    }
    return {"object": "list", "data": [model_entry]}
# ====================== Request/response data models ======================
class Message(BaseModel):
    """One chat message in an OpenAI-style request."""
    # e.g. "system" / "user" / "assistant"; not validated here
    role: str
    # content may be a plain string or an OpenAI structured-content part list
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
class ChatRequest(BaseModel):
    """OpenAI /v1/chat/completions request body (subset of fields used here)."""
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1024
    model: Optional[str] = MODEL_ID
    stream: Optional[bool] = False
    # OpenAI tool-calling fields; tool definitions get folded into the system prompt
    tools: Optional[List[Dict[str, Any]]] = None
    # accepted but currently ignored by the handler
    tool_choice: Optional[str] = None
def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
    """Flatten an OpenAI structured content value into plain text.

    None becomes "", a string passes through unchanged, a list of parts keeps
    only the ``{"type": "text"}`` entries (joined with newlines), and any
    other value falls back to ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        fragments = [
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        return "\n".join(fragments)
    return str(content)
# ====================== Streaming generation (sync generator to avoid blocking the async loop) ======================
def stream_generate(messages, temperature=0.7, max_new_tokens=1024):
    """Yield OpenAI-style SSE chunks for one chat completion.

    Runs ``model.generate`` on a background thread and relays decoded tokens
    from a TextIteratorStreamer as ``chat.completion.chunk`` events. The
    stream is: one role chunk, then content chunks, then a finish chunk and
    ``data: [DONE]``. On any error an ``finish_reason: "error"`` chunk is
    emitted instead, followed by ``[DONE]``.
    """
    chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
    try:
        # Render the chat template into a single prompt string
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=60.0  # iteration below raises if no token arrives within 60s
        )
        gen_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": temperature > 0,  # greedy decoding at temperature == 0
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id
        }
        # generate() blocks, so run it on a worker thread and consume here
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()
        # Initial role chunk (optional per spec, but OpenAI normally sends one)
        yield f"data: {json.dumps({'id': chunk_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
        # Content chunks — one SSE event per decoded text fragment
        for new_text in streamer:
            if new_text:
                chunk = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": MODEL_ID,
                    "choices": [{
                        "index": 0,
                        "delta": {"content": new_text},
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
        # Terminal chunk with finish_reason "stop", then the sentinel
        yield f"data: {json.dumps({'id': chunk_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        # Best-effort error reporting over the same SSE stream
        # NOTE(review): "error" is not a standard OpenAI finish_reason — clients
        # may only recognize stop/length/tool_calls/content_filter; confirm.
        logger.error(f"Streaming error: {e}")
        error_chunk = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "error"
            }]
        }
        yield f"data: {json.dumps(error_chunk)}\n\n"
        yield "data: [DONE]\n\n"
# ====================== Chat endpoint (supports tool calling) ======================
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """OpenAI-compatible chat completion endpoint.

    Steps: flatten structured message content to strings; if tools were
    declared, fold a tool-usage instruction into a single leading system
    message; then either stream SSE chunks (``stream=True``) or run a full
    generation and parse ``<tool_call>`` tags out of the model output into
    OpenAI ``tool_calls`` entries.
    """
    # 1. Flatten every message's content to a plain string
    converted_messages = []
    for m in req.messages:
        converted_messages.append({
            "role": m.role,
            "content": convert_content_to_str(m.content)
        })
    # 2. Handle tools: guarantee exactly one system message, at the front,
    #    carrying the tool-calling instructions (prompt text kept verbatim)
    if req.tools:
        tools_json = json.dumps(req.tools, ensure_ascii=False)
        tool_prompt = f"""你是一个助手,可以使用以下工具:
{tools_json}
当用户的问题需要调用工具时,请输出 <tool_call>...</tool_call> 标签,内部是一个 JSON 对象,必须包含 "name" 和 "arguments" 字段。arguments 是一个对象,包含工具所需的参数。
例如:<tool_call>{{"name": "get_weather", "arguments": {{"location": "Beijing"}}}}</tool_call>
如果不需要调用工具,则正常回答。"""
        # Find the first existing system message, if any
        system_index = None
        for i, msg in enumerate(converted_messages):
            if msg["role"] == "system":
                system_index = i
                break
        if system_index is not None:
            # Append the tool instructions to the existing system message
            converted_messages[system_index]["content"] += "\n\n" + tool_prompt
            messages = converted_messages
        else:
            # No system message yet — insert a new one at the front
            messages = [{"role": "system", "content": tool_prompt}] + converted_messages
    else:
        messages = converted_messages
    # 3. Streaming path: hand off to the SSE generator
    if req.stream:
        return StreamingResponse(
            stream_generate(messages, req.temperature, req.max_tokens),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream"
            }
        )
    # 4. Non-streaming generation
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_tokens,
            temperature=req.temperature,
            do_sample=req.temperature > 0,  # greedy decoding at temperature == 0
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens (skip the prompt prefix)
    response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
    # 5. Parse tool calls out of the raw model text.
    #    The pattern also accepts an unterminated trailing <tool_call> ("|$").
    tool_calls = None
    clean_response = response
    tool_call_matches = re.findall(r'<tool_call>(.*?)(?:</tool_call>|$)', response, re.DOTALL)
    if tool_call_matches:
        tool_calls = []
        for match in tool_call_matches:
            try:
                tool_call_data = json.loads(match.strip())
                tool_calls.append({
                    "id": f"call_{uuid.uuid4().hex[:8]}",
                    "type": "function",
                    "function": {
                        "name": tool_call_data.get("name"),
                        # OpenAI expects arguments as a JSON-encoded string
                        "arguments": json.dumps(tool_call_data.get("arguments", {}), ensure_ascii=False)
                    }
                })
            except Exception as e:
                print(f"工具调用解析失败: {e}")
        # Strip all tool_call tags, keeping whatever text remains.
        # NOTE(review): this only removes *closed* tags, while the findall
        # above also matched an unclosed one — an unterminated tag would stay
        # in clean_response; confirm whether that is intended.
        clean_response = re.sub(r'<tool_call>.*?</tool_call>', '', response, flags=re.DOTALL).strip()
    # 6. Build the OpenAI-shaped response envelope
    prompt_tokens = len(inputs.input_ids[0])
    completion_tokens = len(outputs[0]) - prompt_tokens
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                # Per OpenAI convention: content is null when tool calls are made
                "content": clean_response if not tool_calls else None,
                "tool_calls": tool_calls
            },
            "finish_reason": "tool_calls" if tool_calls else "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }
@app.get("/")
async def root():
    """Basic status endpoint exposing the loaded model name."""
    info = dict(status="running", model=MODEL_NAME)
    return info
if __name__ == "__main__":
    import uvicorn
    # Longer keep-alive timeout helps avoid connections dropping mid-stream
    uvicorn.run(app, host="0.0.0.0", port=7860, timeout_keep_alive=120)