# mc4b / app.py
# Last commit by nagose: "Update app.py" (24bf9c2, verified)
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
import torch
import json
import time
import uuid
import re
import logging
from typing import List, Optional, Dict, Any, Union
from threading import Thread
# Logging configuration: INFO level, timestamped records with logger name and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# ====================== Model configuration ======================
MODEL_NAME = "Qwen/Qwen3.5-4B"  # Hugging Face repository of the 4B-parameter model to load
MODEL_ID = "qwen3.5-4b"  # model name exposed to clients (the name configured in CoPaw)

# 4-bit NF4 weights with double quantization; compute runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

logger.info(f"🔹 加载模型: {MODEL_NAME}")
# NOTE: trust_remote_code=True allows the model repository to execute its own
# Python code at load time; only point MODEL_NAME at repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Prompts are tokenized with padding=True and pad_token_id is passed to
    # generate(), so a pad token must exist; reuse EOS.
    tokenizer.pad_token = tokenizer.eos_token

try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
except Exception as e:
    # Best-effort fallback: load unquantized float16 weights if 4-bit loading
    # fails (presumably bitsandbytes/CUDA unavailable — the cause is logged).
    logger.warning(f"量化加载失败,尝试普通加载: {e}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
logger.info("✅ 模型加载完成")
app = FastAPI(title="Qwen3.5-4B API")

# ====================== CORS middleware ======================
# Wide-open CORS: any origin, method and header, with credentials allowed.
# NOTE(review): fine for a local/demo service; tighten before public exposure.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# ====================== Endpoints required by CoPaw ======================
@app.get("/health")
async def health():
    """Liveness probe: unconditionally reports the service as healthy."""
    state = "healthy"
    return {"status": state}
@app.get("/v1/me")
async def get_me():
    """Return a fixed local admin profile; no request data is inspected."""
    return dict(
        id="local-user",
        name="Local User",
        email="user@localhost",
        is_admin=True,
    )
@app.get("/v1/dashboard/bots")
async def get_bots():
    """Return an empty bot list."""
    return dict(objects=[])
@app.get("/v1/models")
async def list_models():
    """List the single served model in the OpenAI ``/v1/models`` response shape."""
    model_card = {
        "id": MODEL_ID,
        "object": "model",
        "created": 1773000000,  # hard-coded Unix timestamp
        "owned_by": "qwen",
    }
    return {"object": "list", "data": [model_card]}
# ====================== Request / response models ======================
class Message(BaseModel):
    """One chat message.

    ``content`` may be plain text, a list of content-part dicts (e.g.
    ``{"type": "text", "text": ...}``), or omitted.
    """

    role: str
    content: Union[str, List[Dict[str, Any]], None] = None
class ChatRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions`` (a subset of the OpenAI schema)."""

    messages: List[Message]
    # Sampling temperature; values <= 0 select greedy decoding.
    temperature: Optional[float] = 0.7
    # Upper bound on newly generated tokens.
    max_tokens: Optional[int] = 1024
    # Echoed back in non-streaming responses.
    model: Optional[str] = MODEL_ID
    # When true, the reply is sent as server-sent events.
    stream: Optional[bool] = False
    # Tool definitions; serialized into the system prompt.
    tools: Optional[List[Dict[str, Any]]] = None
    # Accepted for OpenAI compatibility but not used by this server. OpenAI
    # clients send either a string ("auto", "none", "required") or an object
    # naming a specific function, so both forms must validate — a str-only
    # type rejected the object form with HTTP 422.
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
    """Flatten message ``content`` into a single string.

    ``None`` becomes ``""``; a string is returned as is; for a list, the
    ``"text"`` of every dict part whose ``"type"`` is ``"text"`` is kept and the
    pieces are joined with newlines (other parts are ignored). Any other value
    is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    return "\n".join(
        piece.get("text", "")
        for piece in content
        if isinstance(piece, dict) and piece.get("type") == "text"
    )
# ====================== Core fix: exactly one system message, always first ======================
def normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Collapse all system messages into a single leading system message.

    Non-empty system messages are joined in their original order, separated by
    a blank line, and placed at index 0. Empty system messages are dropped.
    Non-system messages keep their relative order. The input list and its
    dicts are not mutated; a new list is returned.

    NOTE(review): presumably required because the model's chat template only
    accepts a system message in the first position — confirm against the template.

    Args:
        messages: ``{"role": str, "content": str}`` dicts.

    Returns:
        The normalized message list.
    """
    system_msgs = [m for m in messages if m["role"] == "system" and m.get("content")]
    non_system = [m for m in messages if m["role"] != "system"]
    if not system_msgs:
        # Previously the original list was returned here, which let empty
        # system messages (possibly mid-conversation) reach the template.
        return non_system
    combined_content = "\n\n".join(m["content"] for m in system_msgs)
    return [{"role": "system", "content": combined_content}] + non_system
# ====================== Streaming generation (TextIteratorStreamer) ======================
def stream_generate(messages, temperature=0.7, max_new_tokens=1024):
    """Yield an OpenAI-style ``chat.completion.chunk`` server-sent-event stream.

    ``model.generate`` runs in a background thread feeding a
    ``TextIteratorStreamer``; this generator drains the streamer and emits:
    a role chunk, one content chunk per non-empty text fragment, a final chunk
    with ``finish_reason="stop"``, then ``data: [DONE]``. If tokenization or
    generation fails, a chunk with ``finish_reason="error"`` followed by
    ``data: [DONE]`` is emitted instead of the stop chunk.

    Args:
        messages: ``{"role": str, "content": str}`` dicts; normalized here.
        temperature: sampling temperature; ``<= 0`` means greedy decoding,
            ``None`` falls back to 0.7.
        max_new_tokens: generation budget; ``None`` falls back to 1024.

    Yields:
        SSE-formatted strings (``"data: ...\\n\\n"``).
    """
    chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
    # The request schema allows explicit nulls; don't crash on ``None > 0``.
    if temperature is None:
        temperature = 0.7
    if max_new_tokens is None:
        max_new_tokens = 1024

    def _sse(delta, finish_reason=None):
        """Format one chunk with the given delta as an SSE ``data:`` line."""
        payload = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    try:
        # Make sure the single system message comes first.
        messages = normalize_messages(messages)
        logger.info(f"Starting stream generation with {len(messages)} messages")
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
        input_len = len(inputs.input_ids[0])
        logger.info(f"Input token length: {input_len}")

        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=300.0,
        )
        do_sample = temperature > 0
        gen_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if do_sample:
            # Temperature only matters when sampling; with greedy decoding it
            # just triggers a generation-config warning.
            gen_kwargs["temperature"] = temperature

        generation_errors = []

        def _run_generation():
            """Run generate(); on failure record the error and close the streamer."""
            try:
                model.generate(**gen_kwargs)
            except Exception as exc:
                generation_errors.append(exc)
                # Without this the consumer loop below blocks until the
                # streamer's 300 s timeout expires.
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.start()

        # Role chunk.
        yield _sse({"role": "assistant"})

        # Content chunks.
        chunk_count = 0
        for new_text in streamer:
            if new_text:
                chunk_count += 1
                yield _sse({"content": new_text})
        thread.join()
        if generation_errors:
            raise generation_errors[0]
        logger.info(f"Generated {chunk_count} text chunks")

        # Final chunk.
        yield _sse({}, "stop")
        yield "data: [DONE]\n\n"
    except ConnectionError as client_err:  # includes BrokenPipeError
        logger.warning(f"Client disconnected during streaming: {client_err}")
        return
    except Exception:
        # Generation failures (e.g. RuntimeError on CUDA OOM) end up here and
        # are reported to the client rather than mistaken for a disconnect.
        logger.exception("Unexpected streaming error")
        yield _sse({}, "error")
        yield "data: [DONE]\n\n"
# ====================== Chat endpoint =======================
def _build_prompt_messages(req):
    """Turn a ``ChatRequest`` into template-ready ``{"role", "content"}`` dicts.

    Multi-part content is flattened to text. When ``req.tools`` is set, the
    tool-usage instructions are appended to the first system message (or
    prepended as a new system message). The result is then normalized to a
    single leading system message.
    """
    messages = [
        {"role": m.role, "content": convert_content_to_str(m.content)}
        for m in req.messages
    ]
    if req.tools:
        tools_json = json.dumps(req.tools, ensure_ascii=False)
        tool_prompt = f"""你是一个助手,可以使用以下工具:
{tools_json}
当用户的问题需要调用工具时,请输出 <tool_call>...</tool_call> 标签,内部是一个 JSON 对象,必须包含 "name" 和 "arguments" 字段。arguments 是一个对象,包含工具所需的参数。
例如:<tool_call>{{"name": "get_weather", "arguments": {{"location": "Beijing"}}}}</tool_call>
如果不需要调用工具,则正常回答。"""
        first_system = next((m for m in messages if m["role"] == "system"), None)
        if first_system is None:
            messages.insert(0, {"role": "system", "content": tool_prompt})
        elif first_system["content"]:
            first_system["content"] += "\n\n" + tool_prompt
        else:
            # Avoid a leading blank line when the existing system prompt is empty.
            first_system["content"] = tool_prompt
    return normalize_messages(messages)


def _generate_completion(messages, temperature, max_new_tokens):
    """Run blocking, non-streaming generation.

    Returns:
        ``(text, prompt_tokens, completion_tokens)``, where ``text`` is the
        decoded completion with special tokens skipped.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Only meaningful when sampling; otherwise it just triggers a warning.
        gen_kwargs["temperature"] = temperature
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    prompt_tokens = len(inputs.input_ids[0])
    completion_ids = outputs[0][prompt_tokens:]
    response = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return response, prompt_tokens, len(completion_ids)


# A tool call may be left unterminated if generation stopped early, hence ``|$``.
_TOOL_CALL_PATTERN = re.compile(r'<tool_call>(.*?)(?:</tool_call>|$)', re.DOTALL)
_CLOSED_TOOL_CALL_PATTERN = re.compile(r'<tool_call>.*?</tool_call>', re.DOTALL)


def _parse_tool_calls(response):
    """Extract ``<tool_call>{json}</tool_call>`` blocks from model output.

    Returns:
        ``(tool_calls, clean_response)``.
        - ``tool_calls`` is a list of OpenAI-style ``function`` tool calls, or
          ``None`` if none parsed.
        - ``clean_response`` is the text with complete tool-call blocks removed
          and stripped; it is ``response`` untouched if no tag was found.
    """
    matches = _TOOL_CALL_PATTERN.findall(response)
    if not matches:
        return None, response
    tool_calls = []
    for raw in matches:
        try:
            call = json.loads(raw.strip())
            name = call.get("name")
            arguments = call.get("arguments", {})
        except (ValueError, AttributeError) as e:  # bad JSON / not an object
            logger.warning(f"工具调用解析失败: {e}")
            continue
        if not isinstance(name, str) or not name:
            logger.warning(f"Skipping tool call without a valid name: {raw.strip()!r}")
            continue
        # OpenAI expects ``arguments`` as a JSON string; don't double-encode
        # when the model already produced one.
        if not isinstance(arguments, str):
            arguments = json.dumps(arguments, ensure_ascii=False)
        tool_calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        })
    clean_response = _CLOSED_TOOL_CALL_PATTERN.sub('', response).strip()
    return (tool_calls or None), clean_response


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """OpenAI-compatible chat completions endpoint.

    With ``req.stream`` set, returns an SSE ``StreamingResponse`` produced by
    ``stream_generate``. In that mode any ``<tool_call>`` markup is forwarded
    as plain content; tool calls are only parsed for non-streaming requests.
    Otherwise, generation runs in a worker thread and a ``chat.completion``
    object is returned. When tool calls were parsed, ``content`` is ``None``
    and ``finish_reason`` is ``"tool_calls"``.
    """
    from fastapi.concurrency import run_in_threadpool

    # The schema allows explicit nulls; fall back to the schema defaults.
    temperature = 0.7 if req.temperature is None else req.temperature
    max_new_tokens = 1024 if req.max_tokens is None else req.max_tokens

    messages = _build_prompt_messages(req)

    if req.stream:
        return StreamingResponse(
            stream_generate(messages, temperature, max_new_tokens),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
            },
        )

    # model.generate blocks; running it directly in this coroutine would stall
    # the event loop (health checks, other streams) for the whole generation.
    response, prompt_tokens, completion_tokens = await run_in_threadpool(
        _generate_completion, messages, temperature, max_new_tokens
    )
    tool_calls, clean_response = _parse_tool_calls(response)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": clean_response if not tool_calls else None,
                "tool_calls": tool_calls,
            },
            "finish_reason": "tool_calls" if tool_calls else "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
@app.get("/")
async def root():
    """Report that the service is running and which model repository it serves."""
    payload = {"status": "running"}
    payload["model"] = MODEL_NAME
    return payload
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces, port 7860; idle keep-alive connections are
    # held open for up to 1200 seconds.
    server_options = {"host": "0.0.0.0", "port": 7860, "timeout_keep_alive": 1200}
    uvicorn.run(app, **server_options)