fiewolf1000 committed · Commit 59a13f0 · verified · 1 Parent(s): 2246f9f

Update inference_node.py

Files changed (1)
  1. inference_node.py +129 -35
inference_node.py CHANGED
@@ -5,21 +5,23 @@ import os
 import logging
 import torch
 import asyncio
+import time
 from transformers import (
     AutoModelForCausalLM, AutoTokenizer,
     BitsAndBytesConfig, TextStreamer
 )
 
-# 1. Basic configuration
-logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s")
+# 1. Basic configuration - richer log format with more detail
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s-%(name)s-%(levelname)s-%(module)s:%(lineno)d-%(message)s"
+)
 logger = logging.getLogger("inference_node_deepseek")
 app = FastAPI(title="Inference node service (DeepSeek-Math-7B-RL)")
 
 # 2. Model configuration: DeepSeek's official public model with no access restrictions
-# Correct ID: deepseek-ai/deepseek-math-7b-rl (public, no token needed; supports math/general chat)
-# Added revision="main": explicitly load the main branch to avoid revision-resolution errors
 MODEL_NAME = os.getenv("MODEL_NAME", "deepseek-ai/deepseek-math-7b-rl")
-MODEL_REVISION = "main"  # key: pin the model branch so the files can be found
+MODEL_REVISION = "main"  # explicitly load the main branch to avoid revision-resolution errors
 HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")  # public model, may be left empty
 
 # 3. 4-bit quantization config (fits 16 GB memory, tuned for DeepSeek)
@@ -30,49 +32,67 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16  # lower VRAM use, suits DeepSeek
 )
 
-# 4. Load the DeepSeek model (revision parameter added so the files are found)
+# 4. Load the DeepSeek model
 try:
     logger.info(f"Loading model: {MODEL_NAME} (revision: {MODEL_REVISION}, 4-bit quantized)")
-    # Load the tokenizer (revision parameter added to match the model files)
+    # Load the tokenizer
     tokenizer = AutoTokenizer.from_pretrained(
         MODEL_NAME,
-        revision=MODEL_REVISION,  # key: pin the branch
+        revision=MODEL_REVISION,
         token=HF_TOKEN,
         padding_side="right",
-        trust_remote_code=True  # required by DeepSeek: loads the custom tokenizer logic
+        trust_remote_code=True
     )
-    # Set pad_token manually (DeepSeek ships none by default; avoids generation warnings)
+    # Set pad_token manually
    if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+        logger.info(f"pad_token set to eos_token: {tokenizer.eos_token}")
 
-    # Load the quantized model (same revision pinned)
+    # Load the quantized model
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
-        revision=MODEL_REVISION,  # key: same branch as the tokenizer
+        revision=MODEL_REVISION,
         quantization_config=bnb_config,
         device_map="auto",  # automatic GPU/CPU placement
         token=HF_TOKEN,
-        trust_remote_code=True,  # required by DeepSeek: loads the custom model structure
+        trust_remote_code=True,
         torch_dtype=torch.float16
     )
-    # Streaming generator (keeps special tokens for conversational coherence)
+    # Log the device map, useful for debugging
+    logger.info(f"Model device map: {model.hf_device_map}")
+
+    # Streaming generator
     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
     logger.info(f"Model {MODEL_NAME} loaded successfully! VRAM use roughly 5-6 GB (4-bit quantized)")
 except Exception as e:
     logger.error(f"Model load failed: {str(e)}", exc_info=True)
     raise SystemExit(f"Service terminated: {str(e)}")
 
-# 5. Request model (supports math reasoning and general chat)
+# 5. Request model
 class NodeInferenceRequest(BaseModel):
-    prompt: str  # the user request (e.g. "solve the quadratic equation x²-5x+6=0")
+    prompt: str  # the user request
     max_tokens: int = 1024
-    is_math: bool = False  # optional: math-task flag, tunes the generation settings
+    is_math: bool = False  # whether this is a math task
+    request_id: str = None  # new: unique request ID for tracing
 
-# 6. Streaming inference endpoint (DeepSeek chat format, math-aware)
+# 6. Streaming inference endpoint - with detailed logging
 @app.post("/node/stream-infer")
 async def stream_infer(req: NodeInferenceRequest, request: Request):
+    # Generate a unique request ID if none was provided
+    request_id = req.request_id or f"req_{int(time.time() * 1000)}"
+    start_time = time.time()
+    total_tokens = 0
+    first_token_time = None
+
     try:
-        # Adapt to the DeepSeek chat format (math tasks get a dedicated prompt for accuracy)
+        # Log the request parameters
+        logger.info(
+            f"Received inference request | request_id={request_id} | "
+            f"is_math={req.is_math} | max_tokens={req.max_tokens} | "
+            f"prompt_length={len(req.prompt)}"
+        )
+
+        # Build the prompt
         if req.is_math:
             prompt = f"""You are a professional math assistant; solve math problems with detailed steps.
 Problem: {req.prompt}
@@ -82,18 +102,28 @@ async def stream_infer(req: NodeInferenceRequest, request: Request):
 Problem: {req.prompt}
 Answer:"""
 
-        # Build the inputs (standard tokenizer call, avoids compatibility issues)
+        # Build the inputs
         inputs = tokenizer(
             prompt,
             return_tensors="pt",
             truncation=True,
-            max_length=2048  # cap the input length, leaving room for generation
+            max_length=2048
         ).to(model.device)
+
+        input_tokens = len(inputs["input_ids"][0])
+        logger.info(
+            f"Request preprocessing done | request_id={request_id} | "
+            f"input_tokens={input_tokens} | device={model.device}"
+        )
 
-        # Async generator (ensures streaming output)
+        # Async generator
         async def generate_chunks():
+            nonlocal total_tokens, first_token_time
+
             loop = asyncio.get_running_loop()
-            # Call DeepSeek generation (low temperature for math so the steps stay correct)
+            generate_start = time.time()
+
+            # Call model generation
             outputs = await loop.run_in_executor(
                 None,
                 lambda: model.generate(
@@ -101,47 +131,111 @@ async def stream_infer(req: NodeInferenceRequest, request: Request):
                     streamer=streamer,
                     max_new_tokens=req.max_tokens,
                     do_sample=True,
-                    temperature=0.3 if req.is_math else 0.7,  # low temperature (0.3) for math tasks
+                    temperature=0.3 if req.is_math else 0.7,
                     top_p=0.95,
                     pad_token_id=tokenizer.pad_token_id,
                     eos_token_id=tokenizer.eos_token_id
                 )
             )
+
+            generate_end = time.time()
+            logger.info(
+                f"Model generation done | request_id={request_id} | "
+                f"generate_time={generate_end - generate_start:.2f}s"
+            )
 
-            # Decode piece by piece (only the generated part, excluding the input prompt)
+            # Process the generated output
             generated_tokens = outputs[0][len(inputs["input_ids"][0]):]
+            total_tokens = len(generated_tokens)
+            logger.info(
+                f"Processing generated output | request_id={request_id} | "
+                f"generated_tokens={total_tokens}"
+            )
+
-            for token in generated_tokens:
+            for i, token in enumerate(generated_tokens):
+                # Record the first-token time
+                if i == 0:
+                    first_token_time = time.time()
+                    logger.info(
+                        f"First token generated | request_id={request_id} | "
+                        f"first_token_latency={first_token_time - start_time:.2f}s"
+                    )
+
                 if await request.is_disconnected():
-                    logger.info("Client disconnected, stopping generation")
+                    logger.warning(f"Client disconnected | request_id={request_id} | generated_tokens={i+1}")
                     break
+
-                # Decode the token (skip special tokens, keep plain text)
+                # Decode the token
                 token_text = tokenizer.decode(token, skip_special_tokens=True)
                 if token_text.endswith(tokenizer.eos_token):
+                    logger.info(f"EOS token encountered | request_id={request_id} | position={i+1}")
                     break
+
-                # Escape for JSON (so the controller can parse the stream)
+                # Escape for JSON
                 escaped_text = token_text.replace('"', '\\"').replace('\n', '\\n')
-                yield '{{"chunk":"{}","finish":false}}\n'.format(escaped_text)
+                yield '{{"chunk":"{}","finish":false,"request_id":"{}"}}\n'.format(escaped_text, request_id)
+
+                # Log progress every 50 generated tokens
+                if (i + 1) % 50 == 0:
+                    logger.info(
+                        f"Generation progress | request_id={request_id} | "
+                        f"completed_tokens={i+1}/{total_tokens} | "
+                        f"speed={(i+1)/(time.time() - generate_start):.2f}tokens/s"
+                    )
+
             # End-of-generation marker (braces doubled so str.format emits literal JSON braces)
-            yield '{"chunk":"","finish":true}\n'
+            yield '{{"chunk":"","finish":true,"request_id":"{}"}}\n'.format(request_id)
 
         return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")
 
     except Exception as e:
         error_msg = f"Inference failed: {str(e)}"
-        logger.error(error_msg, exc_info=True)
+        logger.error(
+            f"Inference error | request_id={request_id} | "
+            f"error={error_msg} | elapsed_time={time.time() - start_time:.2f}s",
+            exc_info=True
+        )
         raise HTTPException(status_code=500, detail=error_msg)
+    finally:
+        # Log request-completion info
+        elapsed_time = time.time() - start_time
+        if total_tokens > 0 and elapsed_time > 0:
+            speed = total_tokens / elapsed_time
+            logger.info(
+                f"Request finished | request_id={request_id} | "
+                f"total_tokens={total_tokens} | "
+                f"total_time={elapsed_time:.2f}s | "
+                f"average_speed={speed:.2f}tokens/s"
+            )
+        else:
+            logger.info(
+                f"Request finished | request_id={request_id} | "
+                f"total_time={elapsed_time:.2f}s | no valid output generated"
+            )
 
-# 7. Health check (confirms the model loaded correctly)
+# 7. Health check endpoint - with extra detail
 @app.get("/node/health")
 async def node_health():
+    # Check model/tokenizer availability (note: isinstance against the Auto* factory classes would always be False)
+    model_available = model is not None
+    tokenizer_available = tokenizer is not None
+
+    # Gather device info
+    device_info = str(model.device) if model_available else "unknown"
+
     return {
-        "status": "healthy",
+        "status": "healthy" if model_available and tokenizer_available else "unhealthy",
         "model": MODEL_NAME,
         "model_revision": MODEL_REVISION,
+        "model_available": model_available,
+        "tokenizer_available": tokenizer_available,
+        "device": device_info,
         "support_stream": True,
+        "timestamp": time.time(),
         "note": "DeepSeek-Math-7B-RL, 4-bit quantized, fits 16 GB memory, supports math reasoning and general chat"
     }
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
+    logger.info("Starting inference service...")
+    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
 
 
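For reference, a minimal client sketch for consuming the NDJSON stream that /node/stream-infer emits after this change. It is an assumption-laden illustration, not part of the commit: it presumes the service is reachable at http://localhost:7860 (the __main__ config above), and it uses the third-party requests library. The payload fields mirror NodeInferenceRequest.

# Hypothetical client for the streaming endpoint above; the host/port and the
# `requests` dependency are assumptions, not part of this commit.
import json
import requests

payload = {
    "prompt": "Solve the quadratic equation x^2 - 5x + 6 = 0",
    "max_tokens": 256,
    "is_math": True,
    "request_id": "demo_001",  # optional; the server generates one if omitted
}

with requests.post(
    "http://localhost:7860/node/stream-infer",
    json=payload,
    stream=True,  # read the NDJSON body line by line as it arrives
    timeout=300,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        # Each line is one JSON object: {"chunk": ..., "finish": ..., "request_id": ...}
        event = json.loads(line)
        if event["finish"]:
            break
        print(event["chunk"], end="", flush=True)

A liveness probe works the same way: requests.get("http://localhost:7860/node/health").json() returns the status, device, and revision fields that this commit adds to the health endpoint.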