nagose committed on
Commit
48a1762
·
verified ·
1 Parent(s): b6585a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -62
app.py CHANGED
@@ -4,36 +4,22 @@ import time
4
  import uuid
5
  from typing import List, Optional, Dict, Any, Union
6
 
7
- from fastapi import FastAPI
 
8
  from fastapi.responses import StreamingResponse
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel
11
- from llama_cpp import Llama
12
 
13
- # 配置日志
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
- # ====================== 模型配置 ======================
18
- # 使用 Hugging Face GGUF 模型(4B Q4_K_M 版本)
19
- REPO_ID = "lmstudio-community/Qwen3.5-4B-GGUF"
20
- FILENAME = "Qwen3.5-4B-Q4_K_M.gguf"
21
- MODEL_ID = "qwen3.5-4b" # CoPaw 中配置的模型名称
22
-
23
- # 加载模型(自动从 HF 下载并缓存)
24
- logger.info(f"正在从 {REPO_ID} 加载模型 {FILENAME}...")
25
- llm = Llama.from_pretrained(
26
- repo_id=REPO_ID,
27
- filename=FILENAME,
28
- n_ctx=4096, # 上下文窗口,可根据需求调整
29
- n_threads=None, # 自动使用所有 CPU 线程
30
- verbose=False,
31
- )
32
- logger.info("模型加载完成!")
33
 
34
- app = FastAPI(title="Qwen3.5-4B GGUF API (CoPaw兼容)")
35
 
36
- # ====================== CORS 中间件 ======================
37
  app.add_middleware(
38
  CORSMiddleware,
39
  allow_origins=["*"],
@@ -42,7 +28,7 @@ app.add_middleware(
42
  allow_headers=["*"],
43
  )
44
 
45
- # ====================== CoPaw 所需端点 ======================
46
  @app.get("/health")
47
  async def health():
48
  return {"status": "healthy"}
@@ -74,7 +60,7 @@ async def list_models():
74
  ]
75
  }
76
 
77
- # ====================== 请求/响应数据模型 ======================
78
  class Message(BaseModel):
79
  role: str
80
  content: Optional[Union[str, List[Dict[str, Any]]]] = None
@@ -88,8 +74,8 @@ class ChatRequest(BaseModel):
88
  tools: Optional[List[Dict[str, Any]]] = None
89
  tool_choice: Optional[str] = None
90
 
91
- # ====================== 辅助函数 ======================
92
  def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
 
93
  if content is None:
94
  return ""
95
  if isinstance(content, str):
@@ -105,10 +91,10 @@ def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]])
105
  # ====================== 聊天接口 ======================
106
  @app.post("/v1/chat/completions")
107
  async def chat_completions(req: ChatRequest):
108
- # 转换消息格式
109
  messages = [{"role": m.role, "content": convert_content_to_str(m.content)} for m in req.messages]
110
 
111
- # 处理 tools:将工具描述合并到 system 消息中
112
  if req.tools:
113
  tools_json = json.dumps(req.tools, ensure_ascii=False)
114
  tool_prompt = (
@@ -122,47 +108,42 @@ async def chat_completions(req: ChatRequest):
122
  else:
123
  messages.insert(0, {"role": "system", "content": tool_prompt})
124
 
125
- # 流式处理
126
- if req.stream:
127
- stream = llm.create_chat_completion_openai_v1(
128
- messages=messages,
129
- temperature=req.temperature,
130
- max_tokens=req.max_tokens,
131
- stream=True,
132
- )
133
 
 
 
134
  async def generate():
135
- chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
136
- for chunk in stream:
137
- if chunk.choices:
138
- delta = chunk.choices[0].delta
139
- finish_reason = chunk.choices[0].finish_reason
140
- response_chunk = {
141
- "id": chunk_id,
142
- "object": "chat.completion.chunk",
143
- "created": int(time.time()),
144
- "model": req.model,
145
- "choices": [{
146
- "index": 0,
147
- "delta": delta.model_dump(exclude_none=True),
148
- "finish_reason": finish_reason
149
- }]
150
- }
151
- yield f"data: {json.dumps(response_chunk)}\n\n"
152
- if finish_reason:
153
- yield "data: [DONE]\n\n"
154
  return StreamingResponse(generate(), media_type="text/event-stream")
155
 
156
- # 非流式处理
157
  else:
158
- response = llm.create_chat_completion_openai_v1(
159
- messages=messages,
160
- temperature=req.temperature,
161
- max_tokens=req.max_tokens,
162
- stream=False,
163
- )
164
- return response
 
 
165
 
166
  @app.get("/")
167
  async def root():
168
- return {"status": "running", "model": f"{REPO_ID}/{FILENAME}"}
 
4
  import uuid
5
  from typing import List, Optional, Dict, Any, Union
6
 
7
+ import httpx
8
+ from fastapi import FastAPI, HTTPException
9
  from fastapi.responses import StreamingResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
 
12
 
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
+ # ====================== 配置 ======================
17
+ MODEL_ID = "qwen3.5-4b" # CoPaw 中填写的模型名称
18
+ LLAMA_SERVER_URL = "http://127.0.0.1:8080" # 本地 llama-server 地址
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ app = FastAPI(title="Qwen3.5-4B Proxy for CoPaw")
21
 
22
+ # CORS 中间件(CoPaw 必须)
23
  app.add_middleware(
24
  CORSMiddleware,
25
  allow_origins=["*"],
 
28
  allow_headers=["*"],
29
  )
30
 
31
+ # ====================== CoPaw 所需额外端点 ======================
32
@app.get("/health")
async def health():
    """Liveness probe used by CoPaw; always reports a healthy status."""
    status_payload = {"status": "healthy"}
    return status_payload
 
60
  ]
61
  }
62
 
63
+ # ====================== 请求/响应模型 ======================
64
  class Message(BaseModel):
65
  role: str
66
  content: Optional[Union[str, List[Dict[str, Any]]]] = None
 
74
  tools: Optional[List[Dict[str, Any]]] = None
75
  tool_choice: Optional[str] = None
76
 
 
77
  def convert_content_to_str(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
78
+ """将 OpenAI 结构化 content 转换为纯文本"""
79
  if content is None:
80
  return ""
81
  if isinstance(content, str):
 
91
  # ====================== 聊天接口 ======================
92
  @app.post("/v1/chat/completions")
93
  async def chat_completions(req: ChatRequest):
94
+ # 1. 转换消息格式
95
  messages = [{"role": m.role, "content": convert_content_to_str(m.content)} for m in req.messages]
96
 
97
+ # 2. 处理 tools(简单提示工程)
98
  if req.tools:
99
  tools_json = json.dumps(req.tools, ensure_ascii=False)
100
  tool_prompt = (
 
108
  else:
109
  messages.insert(0, {"role": "system", "content": tool_prompt})
110
 
111
+ # 3. 构造转发给 llama-server 的请求体
112
+ payload = {
113
+ "messages": messages,
114
+ "temperature": req.temperature,
115
+ "max_tokens": req.max_tokens,
116
+ "stream": req.stream,
117
+ "model": "local" # llama-server 可能忽略此字段
118
+ }
119
 
120
+ # 4. 流式处理
121
+ if req.stream:
122
  async def generate():
123
+ async with httpx.AsyncClient(timeout=None) as client:
124
+ async with client.stream(
125
+ "POST",
126
+ f"{LLAMA_SERVER_URL}/v1/chat/completions",
127
+ json=payload,
128
+ headers={"Content-Type": "application/json"}
129
+ ) as response:
130
+ async for line in response.aiter_lines():
131
+ if line.startswith("data: "):
132
+ yield line + "\n\n"
 
 
 
 
 
 
 
 
 
133
  return StreamingResponse(generate(), media_type="text/event-stream")
134
 
135
+ # 5. 非流式处理
136
  else:
137
+ async with httpx.AsyncClient(timeout=300.0) as client:
138
+ resp = await client.post(
139
+ f"{LLAMA_SERVER_URL}/v1/chat/completions",
140
+ json=payload,
141
+ headers={"Content-Type": "application/json"}
142
+ )
143
+ if resp.status_code != 200:
144
+ raise HTTPException(status_code=resp.status_code, detail=resp.text)
145
+ return resp.json()
146
 
147
@app.get("/")
async def root():
    """Root status endpoint: reports that the proxy is running and which model backs it."""
    info = {"status": "running", "model": "Qwen3.5-4B via llama-server"}
    return info