han145 committed on
Commit
26be9f6
·
verified ·
1 Parent(s): 4b6a18b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -33
app.py CHANGED
@@ -2,11 +2,13 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import json
5
- from fastapi import FastAPI, Request
 
6
  from fastapi.responses import JSONResponse
7
  import logging
8
  import time
9
- import re
 
10
 
11
  # 配置日志
12
  logging.basicConfig(level=logging.INFO)
@@ -15,10 +17,18 @@ logger = logging.getLogger(__name__)
15
  # 全局变量
16
  model = None
17
  tokenizer = None
18
- device = "cpu" # 默认使用CPU
 
 
 
 
 
 
 
 
19
 
20
  def load_model():
21
- """加载模型 - 不使用device_map以避免accelerate依赖"""
22
  global model, tokenizer, device
23
 
24
  if model is not None:
@@ -27,7 +37,6 @@ def load_model():
27
  try:
28
  model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
29
 
30
- # 检查是否有GPU可用
31
  if torch.cuda.is_available():
32
  device = "cuda"
33
  logger.info("检测到GPU可用,将使用GPU加速")
@@ -36,13 +45,11 @@ def load_model():
36
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_name)
38
 
39
- # 加载模型
40
  model = AutoModelForCausalLM.from_pretrained(
41
  model_name,
42
- torch_dtype=torch.float32 # 使用float32确保兼容性
43
  )
44
 
45
- # 将模型移动到设备
46
  model = model.to(device)
47
 
48
  if tokenizer.pad_token is None:
@@ -54,8 +61,40 @@ def load_model():
54
  logger.error(f"模型加载失败: {e}")
55
  return False
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def generate_response(message):
58
- """生成模型响应 - 修正版"""
59
  if not load_model():
60
  return "模型加载失败,请稍后重试"
61
 
@@ -65,7 +104,6 @@ def generate_response(message):
65
  {"role": "user", "content": message}
66
  ]
67
 
68
- # 使用tokenizer的apply_chat_template方法
69
  formatted_prompt = tokenizer.apply_chat_template(
70
  prompt,
71
  tokenize=False,
@@ -77,14 +115,14 @@ def generate_response(message):
77
  formatted_prompt,
78
  return_tensors="pt",
79
  truncation=True,
80
- max_length=512 # 减少输入长度
81
  ).to(device)
82
 
83
- # 生成回复 - 减少生成长度
84
  with torch.no_grad():
85
  outputs = model.generate(
86
  **inputs,
87
- max_new_tokens=128, # 减少生成长度
88
  temperature=0.7,
89
  top_p=0.9,
90
  do_sample=True,
@@ -93,40 +131,43 @@ def generate_response(message):
93
  repetition_penalty=1.1
94
  )
95
 
96
- # 解码回复 - 跳过特殊标记
97
  response = tokenizer.decode(
98
  outputs[0][inputs.input_ids.shape[-1]:],
99
  skip_special_tokens=True
100
  )
101
 
102
- # 关键修正:移除模型内部的思考过程
103
- # 只保留最终回复内容
104
- if "</think>" in response:
105
- # 提取最终回复部分
106
- final_response = response.split("</think>")[-1].strip()
107
- # 移除可能的换行符和多余空格
108
- final_response = re.sub(r'\n+', ' ', final_response).strip()
109
- return final_response
110
- else:
111
- return response.strip()
112
 
113
  except Exception as e:
114
  logger.error(f"生成回复时出错: {str(e)}")
115
  return f"生成回复时出错: {str(e)}"
116
 
117
  # 创建FastAPI应用
118
- app = FastAPI()
 
 
 
 
 
119
 
120
- # 添加API端点
 
 
 
 
121
  @app.post("/api/chat")
122
- async def chat_api(request: Request):
123
- """OpenAI兼容的聊天API端点"""
 
 
 
124
  try:
125
  # 解析请求数据
126
  data = await request.json()
127
  messages = data.get("messages", [])
128
  model_name = data.get("model", "deepseek-r1")
129
- max_tokens = data.get("max_tokens", 128) # 默认128
130
  temperature = data.get("temperature", 0.7)
131
 
132
  # 提取最后一条用户消息
@@ -178,12 +219,15 @@ async def chat_api(request: Request):
178
 
179
  # 创建Gradio界面
180
  with gr.Blocks(title="DeepSeek-R1 API服务") as demo:
181
- gr.Markdown("""
182
  # DeepSeek-R1 API 服务
183
  *基于DeepSeek-R1-Distill-Qwen-1.5B模型*
184
 
 
 
185
  ## API端点信息
186
- - **OpenAI兼容端点**: `/api/chat`
 
187
  - **模型名称**: `deepseek-r1`
188
  """)
189
 
@@ -196,11 +240,9 @@ with gr.Blocks(title="DeepSeek-R1 API服务") as demo:
196
 
197
  # 处理函数
198
  def respond(message):
199
- """处理用户输入"""
200
  if not message.strip():
201
  return ""
202
 
203
- # 生成响应
204
  response = generate_response(message)
205
  return response
206
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import json
5
+ from fastapi import FastAPI, Request, HTTPException, Security, Depends
6
+ from fastapi.security import APIKeyHeader
7
  from fastapi.responses import JSONResponse
8
  import logging
9
  import time
10
+ import os
11
+ from typing import Optional
12
 
13
  # 配置日志
14
  logging.basicConfig(level=logging.INFO)
 
17
  # 全局变量
18
  model = None
19
  tokenizer = None
20
+ device = "cpu"
21
+
22
+ # 安全配置
23
+ # 从环境变量读取配置,默认启用安全认证
24
+ TEST_MODE: bool = os.getenv("TEST_MODE", "false").lower() == "true"
25
+ API_KEYS = os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2").split(",")
26
+
27
+ # 初始化API密钥头认证
28
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
29
 
30
  def load_model():
31
+ """加载模型"""
32
  global model, tokenizer, device
33
 
34
  if model is not None:
 
37
  try:
38
  model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
39
 
 
40
  if torch.cuda.is_available():
41
  device = "cuda"
42
  logger.info("检测到GPU可用,将使用GPU加速")
 
45
 
46
  tokenizer = AutoTokenizer.from_pretrained(model_name)
47
 
 
48
  model = AutoModelForCausalLM.from_pretrained(
49
  model_name,
50
+ torch_dtype=torch.float32
51
  )
52
 
 
53
  model = model.to(device)
54
 
55
  if tokenizer.pad_token is None:
 
61
  logger.error(f"模型加载失败: {e}")
62
  return False
63
 
64
def verify_api_key(
    request_key_header: Optional[str] = Security(api_key_header) if not TEST_MODE else None,
) -> str:
    """Validate the ``X-API-Key`` request header (FastAPI dependency).

    When ``TEST_MODE`` is true, authentication is skipped entirely.
    Otherwise the supplied key must match one of the non-empty entries in
    ``API_KEYS``.

    Args:
        request_key_header: Value of the ``X-API-Key`` header as resolved by
            ``api_key_header``; ``None`` when no value was supplied.

    Returns:
        The accepted API key, or the literal ``"test_mode_bypass"`` when
        ``TEST_MODE`` is enabled.

    Raises:
        HTTPException: 401 when the key is missing or does not match any
            configured key.
    """
    # Local import keeps this block self-contained; hmac is standard library.
    import hmac

    logger.info(f"当前安全模式: {'测试模式' if TEST_MODE else '生产模式'}")

    # Test mode: skip authentication.
    if TEST_MODE:
        logger.info("测试模式下跳过API密钥验证")
        return "test_mode_bypass"

    # Production mode: a key must be present.
    if request_key_header is None:
        logger.warning("请求头中缺少API密钥")
        raise HTTPException(
            status_code=401,
            detail="缺少API密钥,请在请求头中添加 X-API-Key"
        )

    supplied = request_key_header.encode("utf-8")
    key_is_valid = False
    for configured_key in API_KEYS:
        candidate = configured_key.strip()
        if not candidate:
            # An empty entry (e.g. API_KEYS="" or a trailing comma) must never
            # act as a valid credential.
            continue
        # Constant-time comparison; deliberately no early exit so the timing
        # does not reveal which configured key (if any) matched.
        if hmac.compare_digest(supplied, candidate.encode("utf-8")):
            key_is_valid = True

    if not key_is_valid:
        # Never log the supplied value: a near-miss of a real key would
        # otherwise end up in the log files.
        logger.warning("无效的API密钥尝试")
        raise HTTPException(
            status_code=401,
            detail="无效的API密钥"
        )

    logger.info("API密钥验证通过")
    return request_key_header
95
+
96
  def generate_response(message):
97
+ """生成模型响应"""
98
  if not load_model():
99
  return "模型加载失败,请稍后重试"
100
 
 
104
  {"role": "user", "content": message}
105
  ]
106
 
 
107
  formatted_prompt = tokenizer.apply_chat_template(
108
  prompt,
109
  tokenize=False,
 
115
  formatted_prompt,
116
  return_tensors="pt",
117
  truncation=True,
118
+ max_length=512
119
  ).to(device)
120
 
121
+ # 生成回复
122
  with torch.no_grad():
123
  outputs = model.generate(
124
  **inputs,
125
+ max_new_tokens=128,
126
  temperature=0.7,
127
  top_p=0.9,
128
  do_sample=True,
 
131
  repetition_penalty=1.1
132
  )
133
 
134
+ # 解码回复
135
  response = tokenizer.decode(
136
  outputs[0][inputs.input_ids.shape[-1]:],
137
  skip_special_tokens=True
138
  )
139
 
140
+ return response.strip()
 
 
 
 
 
 
 
 
 
141
 
142
  except Exception as e:
143
  logger.error(f"生成回复时出错: {str(e)}")
144
  return f"生成回复时出错: {str(e)}"
145
 
146
  # 创建FastAPI应用
147
+ app = FastAPI(title="DeepSeek-R1 API服务", description="带API密钥认证的大模型服务")
148
+
149
+ # API健康检查端点(无需认证)
150
@app.get("/")
async def root():
    """Public endpoint returning a status message and the current Unix time in whole seconds."""
    current_ts = int(time.time())
    return {"message": "DeepSeek-R1 API服务运行中", "timestamp": current_ts}
153
 
154
@app.get("/health")
async def health_check():
    """Report a fixed "healthy" status plus whether the module-level ``model`` is set."""
    is_loaded = model is not None
    return {"status": "healthy", "model_loaded": is_loaded}
157
+
158
+ # 受保护的聊天API端点
159
  @app.post("/api/chat")
160
+ async def chat_api(
161
+ request: Request,
162
+ api_key: str = Depends(verify_api_key) # 添加API密钥依赖
163
+ ):
164
+ """OpenAI兼容的聊天API端点(需要API密钥认证)"""
165
  try:
166
  # 解析请求数据
167
  data = await request.json()
168
  messages = data.get("messages", [])
169
  model_name = data.get("model", "deepseek-r1")
170
+ max_tokens = data.get("max_tokens", 128)
171
  temperature = data.get("temperature", 0.7)
172
 
173
  # 提取最后一条用户消息
 
219
 
220
  # 创建Gradio界面
221
  with gr.Blocks(title="DeepSeek-R1 API服务") as demo:
222
+ gr.Markdown(f"""
223
  # DeepSeek-R1 API 服务
224
  *基于DeepSeek-R1-Distill-Qwen-1.5B模型*
225
 
226
+ ## 安全状态: {'🔓 测试模式(认证已禁用)' if TEST_MODE else '🔒 生产模式(认证已启用)'}
227
+
228
  ## API端点信息
229
+ - **聊天端点**: `/api/chat` (需要API密钥认证)
230
+ - **健康检查**: `/health` (公开)
231
  - **模型名称**: `deepseek-r1`
232
  """)
233
 
 
240
 
241
  # 处理函数
242
  def respond(message):
 
243
  if not message.strip():
244
  return ""
245
 
 
246
  response = generate_response(message)
247
  return response
248