
nagose committed · commit c18b8e3 · verified · 1 parent: d4dc38d

Update app.py

Files changed (1): app.py (+146 −36)
app.py CHANGED
@@ -1,19 +1,21 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware  # new
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
 import torch
 import json
 import time
 import uuid
+import re
 from typing import List, Optional, Dict, Any
 from threading import Thread

-# ====================== Your 7B model ======================
+# ====================== Model configuration ======================
 MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
-MODEL_ID = "qwen2.5-7b"  # must match the model name entered in CoPaw
+MODEL_ID = "qwen2.5-7b"  # custom model identifier; the frontend must use the same name

+# 4-bit quantization config (works on CPU/GPU)
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
@@ -21,19 +23,24 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16
 )

-print("🔹 Loading model: Qwen2.5-7B-Instruct")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+print("🔹 Loading model: Qwen2.5-7B-Instruct (4-bit quantized)")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+# Make sure the tokenizer has a pad_token
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     quantization_config=bnb_config,
-    device_map="auto",
+    device_map="auto",  # pick the device automatically (CPU/GPU)
     trust_remote_code=True,
     low_cpu_mem_usage=True
 )
+print("✅ Model loaded")

-app = FastAPI(title="Qwen2.5-7B API (CoPaw compatible)")
+app = FastAPI(title="Qwen2.5-7B API (OpenAI compatible)")

-# ====================== New: CORS middleware (required by CoPaw) ======================
+# ====================== CORS middleware ======================
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -42,7 +49,7 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# ====================== New: endpoints CoPaw's connection test needs ======================
+# ====================== Extra endpoints required by CoPaw ======================
 @app.get("/health")
 async def health():
     return {"status": "healthy"}
@@ -60,17 +67,22 @@ async def get_me():
 async def get_bots():
     return {"objects": []}

-# ====================== Existing /v1/models (already present, no change needed) ======================
 @app.get("/v1/models")
 async def list_models():
+    """Return the model list in OpenAI format"""
     return {
+        "object": "list",
         "data": [
-            {"id": MODEL_ID, "object": "model", "created": 1773000000, "owned_by": "qwen"}
-        ],
-        "object": "list"
+            {
+                "id": MODEL_ID,
+                "object": "model",
+                "created": 1773000000,
+                "owned_by": "qwen"
+            }
+        ]
     }

-# ====================== Request schema ======================
+# ====================== Request/response data models ======================
 class Message(BaseModel):
     role: str
     content: Optional[str] = None
@@ -81,60 +93,158 @@ class ChatRequest(BaseModel):
     max_tokens: Optional[int] = 1024
     model: Optional[str] = MODEL_ID
     stream: Optional[bool] = False
-    tools: Optional[List[Dict]] = None
+    tools: Optional[List[Dict[str, Any]]] = None
     tool_choice: Optional[str] = None

 # ====================== Streaming generation ======================
 def stream_generate(messages, temperature=0.7, max_new_tokens=1024):
     try:
+        # Build the prompt with the chat template
         text = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-        inputs = tokenizer([text], return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-        from transformers import TextIteratorStreamer
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
+
+        streamer = TextIteratorStreamer(
+            tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True,
+            timeout=60.0
+        )

         gen_kwargs = {
             **inputs,
             "streamer": streamer,
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
-            "do_sample": True
+            "do_sample": temperature > 0,
+            "pad_token_id": tokenizer.pad_token_id,
+            "eos_token_id": tokenizer.eos_token_id
         }
+
         thread = Thread(target=model.generate, kwargs=gen_kwargs)
         thread.start()

+        # Send the role first (the OpenAI format expects it)
+        chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
+        yield f"data: {json.dumps({'id': chunk_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': MODEL_ID, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n"
+
         for new_text in streamer:
-            chunk = {
-                "id": f"chatcmpl-{uuid.uuid4()}",
-                "object": "chat.completion.chunk",
-                "created": int(time.time()),
-                "model": MODEL_ID,
-                "choices": [{"index": 0, "delta": {"content": new_text}, "finish_reason": None}]
-            }
-            yield f"data: {json.dumps(chunk)}\n\n"
+            if new_text:
+                chunk = {
+                    "id": chunk_id,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": MODEL_ID,
+                    "choices": [{
+                        "index": 0,
+                        "delta": {"content": new_text},
+                        "finish_reason": None
+                    }]
+                }
+                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
+
+        # Send the closing chunk
+        yield f"data: {json.dumps({'id': chunk_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': MODEL_ID, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
         yield "data: [DONE]\n\n"
     except Exception as e:
-        yield f"data: {json.dumps({'error': str(e)})}\n\n"
+        yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"

-# ====================== Chat endpoint ======================
+# ====================== Non-streaming generation (with tool-call support) ======================
 @app.post("/v1/chat/completions")
 async def chat_completions(req: ChatRequest):
-    messages = [{"role": m.role, "content": m.content} for m in req.messages]
+    # Build the base message list
+    base_messages = [{"role": m.role, "content": m.content} for m in req.messages]
+
+    # If tools were supplied, fold them into a system prompt (the recommended approach for Qwen 2.5)
+    if req.tools:
+        tools_json = json.dumps(req.tools, ensure_ascii=False)
+        # Build a tool-call prompt that demands a specific output format
+        tool_prompt = f"""You are an assistant that can use the following tools:
+{tools_json}
+When the user's question requires a tool call, output a <tool_call>...</tool_call> tag containing a JSON object that must include "name" and "arguments" fields. "arguments" is an object holding the parameters the tool needs.
+Example: <tool_call>{{"name": "get_weather", "arguments": {{"location": "Beijing"}}}}</tool_call>
+If no tool call is needed, answer normally."""
+        messages = [{"role": "system", "content": tool_prompt}] + base_messages
+    else:
+        messages = base_messages
+
+    # Streaming path
     if req.stream:
-        return StreamingResponse(stream_generate(messages, req.temperature, req.max_tokens), media_type="text/event-stream")
+        return StreamingResponse(
+            stream_generate(messages, req.temperature, req.max_tokens),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "Content-Type": "text/event-stream"
+            }
+        )

+    # Non-streaming generation
     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer([text], return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-    outputs = model.generate(**inputs, max_new_tokens=req.max_tokens, temperature=req.temperature, do_sample=True)
+    inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=req.max_tokens,
+            temperature=req.temperature,
+            do_sample=req.temperature > 0,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id
+        )
+
     response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)

+    # Parse tool calls (Qwen emits <tool_call>...</tool_call> tags)
+    tool_calls = None
+    clean_response = response
+    tool_call_matches = re.findall(r'<tool_call>(.*?)</tool_call>', response, re.DOTALL)
+
+    if tool_call_matches:
+        tool_calls = []
+        for match in tool_call_matches:
+            try:
+                tool_call_data = json.loads(match)
+                # Convert to the OpenAI tool-call format
+                tool_calls.append({
+                    "id": f"call_{uuid.uuid4().hex[:8]}",
+                    "type": "function",
+                    "function": {
+                        "name": tool_call_data.get("name"),
+                        "arguments": json.dumps(tool_call_data.get("arguments", {}), ensure_ascii=False)
+                    }
+                })
+            except Exception as e:
+                print(f"Failed to parse tool call: {e}")
+        # Strip all tool_call tags and keep the remaining text (if any)
+        clean_response = re.sub(r'<tool_call>.*?</tool_call>', '', response, flags=re.DOTALL).strip()
+
+    # Token usage accounting
+    prompt_tokens = len(inputs.input_ids[0])
+    completion_tokens = len(outputs[0]) - prompt_tokens
+
+    # Build the OpenAI-format response
     return {
-        "id": f"chatcmpl-{uuid.uuid4()}",
+        "id": f"chatcmpl-{uuid.uuid4().hex}",
         "object": "chat.completion",
         "created": int(time.time()),
         "model": req.model,
-        "choices": [{"index": 0, "message": {"role": "assistant", "content": response}, "finish_reason": "stop"}]
+        "choices": [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": clean_response if not tool_calls else None,
+                "tool_calls": tool_calls
+            },
+            "finish_reason": "tool_calls" if tool_calls else "stop"
+        }],
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens
+        }
    }

 @app.get("/")
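
For reference, a minimal client sketch for the new tool-call path. This is not part of the commit: it assumes the `requests` library, and the BASE_URL and `get_weather` tool are placeholders to replace with your own Space URL and tool schema.

import requests

BASE_URL = "https://YOUR-SPACE.hf.space"  # placeholder; use your Space's URL

payload = {
    "model": "qwen2.5-7b",
    "messages": [{"role": "user", "content": "What is the weather in Beijing?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }],
}

resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload).json()
choice = resp["choices"][0]
if choice["finish_reason"] == "tool_calls":
    # The server converted <tool_call> tags into OpenAI-style tool calls
    for call in choice["message"]["tool_calls"]:
        print(call["function"]["name"], call["function"]["arguments"])
else:
    print(choice["message"]["content"])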
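
And a sketch of a streaming consumer for the SSE format this commit emits (same placeholder base URL; per the code above, the first chunk carries only the assistant role, later chunks carry content deltas):

import json
import requests

BASE_URL = "https://YOUR-SPACE.hf.space"  # placeholder

payload = {
    "model": "qwen2.5-7b",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}

with requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        # The role-only first chunk has no "content" key, so default to ""
        print(delta.get("content", ""), end="", flush=True)
print()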