# ll4b / app.py
# Provenance: Hugging Face Space "ll4b", commit 48a1762 (verified),
# last updated by user "nagose" ("Update app.py").
import logging
import json
import time
import uuid
from typing import List, Optional, Dict, Any, Union
import httpx
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ====================== Configuration ======================
MODEL_ID = "qwen3.5-4b" # Model name entered in CoPaw
LLAMA_SERVER_URL = "http://127.0.0.1:8080" # Local llama-server address
app = FastAPI(title="Qwen3.5-4B Proxy for CoPaw")
# CORS middleware (required by CoPaw): allow any origin/method/header so the
# browser client can reach this proxy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ====================== Extra endpoints required by CoPaw ======================
@app.get("/health")
async def health():
    """Liveness probe: always reports this proxy as healthy."""
    return dict(status="healthy")
@app.get("/v1/me")
async def get_me():
    """Return the static local-user profile that CoPaw expects."""
    profile = {
        "id": "local-user",
        "name": "Local User",
        "email": "user@localhost",
        "is_admin": True,
    }
    return profile
@app.get("/v1/dashboard/bots")
async def get_bots():
    """CoPaw dashboard stub: this proxy has no configured bots."""
    return dict(objects=[])
@app.get("/v1/models")
async def list_models():
    """OpenAI-style model listing exposing the single proxied model."""
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "created": 1773000000,
        "owned_by": "user",
    }
    return {"object": "list", "data": [model_entry]}
# ====================== Request/response models ======================
class Message(BaseModel):
    """A single chat message in OpenAI request format."""

    # Speaker role (e.g. "system", "user", "assistant").
    role: str
    # Plain string, or a list of structured parts such as
    # [{"type": "text", "text": ...}]; flattened by convert_content_to_str.
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible).

    Note: `tool_choice` in the OpenAI API is either a string
    ("auto" / "none" / "required") or an object
    ({"type": "function", "function": {"name": ...}}). The previous
    str-only annotation made pydantic reject object payloads with a 422,
    so the annotation is widened; the field is not otherwise consumed here.
    """

    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1024
    model: Optional[str] = MODEL_ID
    stream: Optional[bool] = False
    tools: Optional[List[Dict[str, Any]]] = None
    # Widened from Optional[str] to also accept the OpenAI object form.
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
    """Flatten an OpenAI structured `content` field into plain text.

    Strings pass through unchanged; a list of parts keeps only the
    `{"type": "text", ...}` entries, joined with newlines; None becomes "".
    Any other value falls back to `str(content)`.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)
# ====================== Chat endpoint ======================
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """OpenAI-compatible chat endpoint, proxied to the local llama-server.

    Steps:
      1. Flatten structured message content to plain strings.
      2. If `tools` are supplied, inject them via prompt engineering
         (llama-server is not handed the tools natively here).
      3. Forward the request; when `stream=True`, relay upstream SSE
         "data:" lines verbatim.

    Raises:
        HTTPException: 502 when llama-server cannot be reached
            (non-stream path), or the upstream status code and body
            when llama-server returns a non-200 response.
    """
    # 1. Convert message format: llama-server expects plain-text content.
    messages = [
        {"role": m.role, "content": convert_content_to_str(m.content)}
        for m in req.messages
    ]

    # 2. Handle tools with simple prompt engineering.
    if req.tools:
        tools_json = json.dumps(req.tools, ensure_ascii=False)
        tool_prompt = (
            f"你是一个助手,可以使用以下工具:\n{tools_json}\n"
            f"当用户的问题需要调用工具时,请输出 <tool_call>{{...}}</tool_call> 格式的 JSON。"
        )
        # Merge into an existing system message if there is one, otherwise
        # prepend a new system message.
        system_index = next((i for i, m in enumerate(messages) if m["role"] == "system"), None)
        if system_index is not None:
            messages[system_index]["content"] += "\n\n" + tool_prompt
        else:
            messages.insert(0, {"role": "system", "content": tool_prompt})

    # 3. Build the request body forwarded to llama-server.
    payload = {
        "messages": messages,
        "temperature": req.temperature,
        "max_tokens": req.max_tokens,
        "stream": req.stream,
        "model": "local"  # llama-server may ignore this field
    }

    # 4. Streaming: forward upstream SSE "data:" lines unchanged.
    if req.stream:
        async def generate():
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream(
                    "POST",
                    f"{LLAMA_SERVER_URL}/v1/chat/completions",
                    json=payload,
                    headers={"Content-Type": "application/json"}
                ) as response:
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            yield line + "\n\n"
        return StreamingResponse(generate(), media_type="text/event-stream")

    # 5. Non-streaming: forward once and return the upstream JSON.
    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(
                f"{LLAMA_SERVER_URL}/v1/chat/completions",
                json=payload,
                headers={"Content-Type": "application/json"}
            )
    except httpx.RequestError as exc:
        # Fix: an unreachable llama-server previously escaped as an unhandled
        # exception (HTTP 500); surface it as a 502 Bad Gateway instead.
        raise HTTPException(status_code=502, detail=f"llama-server request failed: {exc}") from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=resp.status_code, detail=resp.text)
    return resp.json()
@app.get("/")
async def root():
    """Root status page for the proxy."""
    return dict(status="running", model="Qwen3.5-4B via llama-server")