ray-lei committed
Commit d802fc4 · verified · 1 Parent(s): 82628df

Update app.py

Files changed (1)
  1. app.py +195 -41
app.py CHANGED
@@ -1,63 +1,217 @@
 import os
-from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-
 # Set the cache directories to avoid /.cache permission issues
 os.environ["HF_HOME"] = "/tmp"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp"
 os.environ["HF_HUB_CACHE"] = "/tmp"
 
-# Initialize FastAPI
-app = FastAPI()
-
-# Model ID
-MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
-
-print("Loading model... (this may take a while the first time)")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir="/tmp")
-
-# Load the model onto the GPU (T4 supports bfloat16; switch to float16 if VRAM is tight)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-    trust_remote_code=True,
-    cache_dir="/tmp"
+import time
+import uuid
+from typing import List, Optional, Union, Dict, Any
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import json
+
+# Initialize the FastAPI app
+app = FastAPI(title="Qwen Coder API", version="1.0.0")
+
+# CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
 )
-model.eval()
-print("Model loaded.")
 
-# Generation endpoint (simple OpenAI /v1/completions-compatible version)
-@app.post("/v1/completions")
-async def completions(request: Request):
-    data = await request.json()
-    prompt = data.get("prompt") or ""
-    max_tokens = data.get("max_tokens", 512)
-
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
+# Global variables
+model = None
+tokenizer = None
+
+# Pydantic model definitions
+class Message(BaseModel):
+    role: str
+    content: str
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    temperature: Optional[float] = 0.7
+    max_tokens: Optional[int] = 2048
+    stream: Optional[bool] = False
+    top_p: Optional[float] = 0.9
+
+class ChatCompletionChoice(BaseModel):
+    index: int
+    message: Message
+    finish_reason: str
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+    usage: Usage
+
+class Model(BaseModel):
+    id: str
+    object: str = "model"
+    created: int
+    owned_by: str = "qwen"
+
+class ModelListResponse(BaseModel):
+    object: str = "list"
+    data: List[Model]
+
+def load_model():
+    """Load the Qwen Coder model."""
+    global model, tokenizer
+
+    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
+
+    print("Loading tokenizer...")
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True
+    )
+
+    print("Loading model...")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        trust_remote_code=True,
+        low_cpu_mem_usage=True
+    )
+
+    print("Model loaded successfully!")
+
+def format_messages(messages: List[Message]) -> str:
+    """Format the messages for the Qwen chat template."""
+    formatted_messages = []
+    for msg in messages:
+        formatted_messages.append({
+            "role": msg.role,
+            "content": msg.content
+        })
+
+    # Use the tokenizer's chat template
+    text = tokenizer.apply_chat_template(
+        formatted_messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    return text
+
+def generate_response(prompt: str, temperature: float, max_tokens: int, top_p: float) -> str:
+    """Generate a model response."""
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
             max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
             do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id
         )
-
-    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Return only the newly generated part
+    response = tokenizer.decode(
+        outputs[0][inputs['input_ids'].shape[1]:],
+        skip_special_tokens=True
+    )
+
+    return response.strip()
+
+@app.on_event("startup")
+async def startup_event():
+    """Load the model at application startup."""
+    load_model()
 
-    # Return in OpenAI API format
-    return JSONResponse({
-        "id": "cmpl-1",
-        "object": "text_completion",
-        "choices": [
-            {"index": 0, "text": text, "finish_reason": "stop"}
+@app.get("/")
+async def root():
+    return {"message": "Qwen Coder API Server is running!"}
+
+@app.get("/v1/models", response_model=ModelListResponse)
+async def list_models():
+    """List the available models."""
+    return ModelListResponse(
+        data=[
+            Model(
+                id="qwen2.5-coder-7b-instruct",
+                created=int(time.time()),
+                owned_by="qwen"
+            )
         ]
-    })
+    )
 
-@app.get("/")
-def root():
-    return {"status": "ok", "model": MODEL_ID}
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def chat_completions(request: ChatCompletionRequest):
+    """Handle chat completion requests."""
+    try:
+        if model is None or tokenizer is None:
+            raise HTTPException(status_code=503, detail="Model not loaded")
+
+        # Format the messages
+        prompt = format_messages(request.messages)
+
+        # Generate the response
+        response_text = generate_response(
+            prompt,
+            request.temperature,
+            request.max_tokens,
+            request.top_p
+        )
+
+        # Build the response
+        completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
+
+        # Compute token usage (simplified)
+        prompt_tokens = len(tokenizer.encode(prompt))
+        completion_tokens = len(tokenizer.encode(response_text))
+
+        response = ChatCompletionResponse(
+            id=completion_id,
+            created=int(time.time()),
+            model=request.model,
+            choices=[
+                ChatCompletionChoice(
+                    index=0,
+                    message=Message(role="assistant", content=response_text),
+                    finish_reason="stop"
+                )
+            ],
+            usage=Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens
+            )
+        )
+
+        return response
+
+    except Exception as e:
+        print(f"Error processing request: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint."""
+    return {
+        "status": "healthy",
+        "model_loaded": model is not None and tokenizer is not None
+    }
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
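
Once the Space is running, the new OpenAI-style endpoint can be exercised with any HTTP client. Below is a minimal sketch, assuming the server is reachable at http://localhost:7860 (the host/port uvicorn binds above) and that the requests package is installed; the model id mirrors the one returned by /v1/models, and the prompt is only an illustration:

import requests

# Assumed base URL; replace with the actual Space URL when deployed remotely.
BASE_URL = "http://localhost:7860"

# Optional sanity check against the /health endpoint added in this commit.
print(requests.get(f"{BASE_URL}/health").json())

# Chat completion request matching the ChatCompletionRequest schema above.
payload = {
    "model": "qwen2.5-coder-7b-instruct",
    "messages": [
        {"role": "user", "content": "Write a Python function that reverses a string."}
    ],
    "temperature": 0.7,
    "max_tokens": 512,
}

# Generation on a 7B model can take a while, hence the generous timeout.
resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

Since the response follows OpenAI's chat.completion shape, the official openai client should also work when its base_url points at this server; note that streaming is not implemented, so stream should stay at its default of False.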