fiewolf1000 committed on
Commit ca2bc68 · verified · 1 Parent(s): 59a13f0

Update inference_node.py

Files changed (1)
  1. inference_node.py +172 -126
inference_node.py CHANGED
@@ -6,236 +6,282 @@ import logging
  import torch
  import asyncio
  import time
  from transformers import (
      AutoModelForCausalLM, AutoTokenizer,
      BitsAndBytesConfig, TextStreamer
  )

- # 1. Basic configuration: tweaked log format with more detail
  logging.basicConfig(
-     level=logging.INFO,
      format="%(asctime)s-%(name)s-%(levelname)s-%(module)s:%(lineno)d-%(message)s"
  )
- logger = logging.getLogger("inference_node_deepseek")
- app = FastAPI(title="Inference node service (DeepSeek-Math-7B-RL)")

- # 2. Model configuration: a public DeepSeek model with no access restrictions
- MODEL_NAME = os.getenv("MODEL_NAME", "deepseek-ai/deepseek-math-7b-rl")
- MODEL_REVISION = "main"  # pin the main branch to avoid revision-resolution errors
- HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")  # public model, may be left unset

- # 3. 4-bit quantization config (fits 16 GB memory, tuned for DeepSeek)
  bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_use_double_quant=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_compute_dtype=torch.float16  # lowers memory use for DeepSeek
  )

- # 4. Load the DeepSeek model
  try:
-     logger.info(f"Loading model: {MODEL_NAME} (revision: {MODEL_REVISION}, 4-bit quantized)")
-     # Load the tokenizer
      tokenizer = AutoTokenizer.from_pretrained(
          MODEL_NAME,
          revision=MODEL_REVISION,
          token=HF_TOKEN,
          padding_side="right",
-         trust_remote_code=True
      )
-     # Set pad_token manually
      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token
-         logger.info(f"pad_token set to eos_token: {tokenizer.eos_token}")

-     # Load the quantized model
      model = AutoModelForCausalLM.from_pretrained(
          MODEL_NAME,
          revision=MODEL_REVISION,
          quantization_config=bnb_config,
-         device_map="auto",  # automatic GPU/CPU placement
          token=HF_TOKEN,
          trust_remote_code=True,
-         torch_dtype=torch.float16
      )
-     # Log the device map for easier debugging
-     logger.info(f"Model device map: {model.hf_device_map}")

-     # Streaming generator
-     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-     logger.info(f"Model {MODEL_NAME} loaded! Roughly 5-6 GB of memory (4-bit quantized)")
  except Exception as e:
-     logger.error(f"Model loading failed: {str(e)}", exc_info=True)
-     raise SystemExit(f"Service terminated: {str(e)}")

- # 5. Request model
  class NodeInferenceRequest(BaseModel):
-     prompt: str  # input text
-     max_tokens: int = 1024
-     is_math: bool = False  # whether this is a math task
-     request_id: str = None  # new: unique request ID for tracing

- # 6. Streaming inference endpoint with detailed logging
  @app.post("/node/stream-infer")
  async def stream_infer(req: NodeInferenceRequest, request: Request):
-     # Generate a unique request ID if none was supplied
-     request_id = req.request_id or f"req_{int(time.time() * 1000)}"
      start_time = time.time()
      total_tokens = 0
      first_token_time = None

      try:
-         # Log the request parameters
          logger.info(
-             f"Inference request received | request_id={request_id} | "
-             f"is_math={req.is_math} | max_tokens={req.max_tokens} | "
-             f"prompt_length={len(req.prompt)}"
          )

-         # Build the prompt
-         if req.is_math:
-             prompt = f"""You are a professional math assistant; solve problems with detailed steps.
- Problem: {req.prompt}
- Solution (with steps):"""
-         else:
-             prompt = f"""You are a general-purpose chat assistant; answer questions clearly and accurately.
- Question: {req.prompt}
- Answer:"""

-         # Build the model inputs
          inputs = tokenizer(
              prompt,
              return_tensors="pt",
              truncation=True,
-             max_length=2048
-         ).to(model.device)
-
-         input_tokens = len(inputs["input_ids"][0])
-         logger.info(
-             f"Request preprocessing done | request_id={request_id} | "
-             f"input_tokens={input_tokens} | device={model.device}"
          )

-         # Asynchronous generation
          async def generate_chunks():
              nonlocal total_tokens, first_token_time

              loop = asyncio.get_running_loop()
-             generate_start = time.time()

-             # Run model.generate
-             outputs = await loop.run_in_executor(
-                 None,
-                 lambda: model.generate(
-                     **inputs,
-                     streamer=streamer,
-                     max_new_tokens=req.max_tokens,
-                     do_sample=True,
-                     temperature=0.3 if req.is_math else 0.7,
-                     top_p=0.95,
-                     pad_token_id=tokenizer.pad_token_id,
-                     eos_token_id=tokenizer.eos_token_id
-                 )
-             )

-             generate_end = time.time()
-             logger.info(
-                 f"Generation finished | request_id={request_id} | "
-                 f"generate_time={generate_end - generate_start:.2f}s"
-             )
-
              # Process the generated result
-             generated_tokens = outputs[0][len(inputs["input_ids"][0]):]
              total_tokens = len(generated_tokens)
              logger.info(
-                 f"Processing generated result | request_id={request_id} | "
-                 f"generated_tokens={total_tokens}"
              )

              for i, token in enumerate(generated_tokens):
-                 # Record the time to first token
                  if i == 0:
                      first_token_time = time.time()
                      logger.info(
-                         f"First token generated | request_id={request_id} | "
-                         f"first_token_latency={first_token_time - start_time:.2f}s"
                      )

                  if await request.is_disconnected():
-                     logger.warning(f"Client disconnected | request_id={request_id} | generated_tokens={i+1}")
                      break

-                 # Decode the token
                  token_text = tokenizer.decode(token, skip_special_tokens=True)
                  if token_text.endswith(tokenizer.eos_token):
-                     logger.info(f"EOS reached | request_id={request_id} | position={i+1}")
                      break

-                 # Escape for JSON
                  escaped_text = token_text.replace('"', '\\"').replace('\n', '\\n')
                  yield '{{"chunk":"{}","finish":false,"request_id":"{}"}}\n'.format(escaped_text, request_id)
-
-                 # Log progress every 50 generated tokens
-                 if (i + 1) % 50 == 0:
-                     logger.info(
-                         f"Generation progress | request_id={request_id} | "
-                         f"completed_tokens={i+1}/{total_tokens} | "
-                         f"speed={(i+1)/(time.time() - generate_start):.2f}tokens/s"
-                     )

-             # End-of-stream marker
              yield '{{"chunk":"","finish":true,"request_id":"{}"}}\n'.format(request_id)

          return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")

      except Exception as e:
-         error_msg = f"Inference failed: {str(e)}"
          logger.error(
-             f"Inference error | request_id={request_id} | "
-             f"error={error_msg} | elapsed_time={time.time() - start_time:.2f}s",
              exc_info=True
          )
          raise HTTPException(status_code=500, detail=error_msg)
      finally:
-         # Log request completion
          elapsed_time = time.time() - start_time
          if total_tokens > 0 and elapsed_time > 0:
              speed = total_tokens / elapsed_time
              logger.info(
-                 f"Request finished | request_id={request_id} | "
-                 f"total_tokens={total_tokens} | "
-                 f"total_time={elapsed_time:.2f}s | "
-                 f"average_speed={speed:.2f}tokens/s"
-             )
-         else:
-             logger.info(
-                 f"Request finished | request_id={request_id} | "
-                 f"total_time={elapsed_time:.2f}s | no valid output generated"
              )

- # 7. Health check endpoint with extra detail
  @app.get("/node/health")
  async def node_health():
-     # Check whether the model is usable
      model_available = model is not None  # isinstance() against an Auto* factory class is always False
-     tokenizer_available = tokenizer is not None
-
-     # Collect device info
-     device_info = str(model.device) if model_available else "unknown"

      return {
-         "status": "healthy" if model_available and tokenizer_available else "unhealthy",
          "model": MODEL_NAME,
-         "model_revision": MODEL_REVISION,
-         "model_available": model_available,
-         "tokenizer_available": tokenizer_available,
-         "device": device_info,
-         "support_stream": True,
-         "timestamp": time.time(),
-         "note": "DeepSeek-Math-7B-RL, 4-bit quantized, fits 16 GB memory, supports math reasoning and general chat"
      }

  if __name__ == "__main__":
      import uvicorn
-     logger.info("Starting inference service...")
-     uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
  import torch
  import asyncio
  import time
+ import psutil  # new: for CPU monitoring
  from transformers import (
      AutoModelForCausalLM, AutoTokenizer,
      BitsAndBytesConfig, TextStreamer
  )

+ # --------------------------
+ # 1. Environment and performance tuning (core)
+ # --------------------------
+ # Pin the CPU thread count (dedicated 2-core setup; avoids thread-switching overhead)
+ os.environ["OMP_NUM_THREADS"] = "2"
+ os.environ["MKL_NUM_THREADS"] = "2"
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable tokenizer parallelism (inefficient on 2 cores)
+
+ # --------------------------
+ # 2. Logging (finer monitoring granularity)
+ # --------------------------
  logging.basicConfig(
+     level=logging.INFO,
      format="%(asctime)s-%(name)s-%(levelname)s-%(module)s:%(lineno)d-%(message)s"
  )
+ logger = logging.getLogger("optimized_deepseek_math")
+ app = FastAPI(title="Optimized DeepSeek-Math inference service (2-core CPU)")

+ # --------------------------
+ # 3. Model configuration (quantization and loading tweaks)
+ # --------------------------
+ MODEL_NAME = os.getenv("MODEL_NAME", "deepseek-ai/deepseek-math-7b-rl")
+ MODEL_REVISION = "main"
+ HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

+ # 4-bit quantization parameters tuned for a 2-core CPU
  bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",  # nf4 recommended for math models; small precision loss
+     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_quant_storage=torch.uint8  # smaller storage dtype, fewer memory accesses
  )

+ # --------------------------
+ # 4. Model loading (hardware-aware)
+ # --------------------------
  try:
+     logger.info(f"Loading model: {MODEL_NAME} (4-bit quantized, tuned for 2 CPU cores)")
+
+     # Load the tokenizer (slow tokenizer for steadier memory use)
      tokenizer = AutoTokenizer.from_pretrained(
          MODEL_NAME,
          revision=MODEL_REVISION,
          token=HF_TOKEN,
          padding_side="right",
+         trust_remote_code=True,
+         use_fast=False  # the slow tokenizer is more stable on 2 cores
      )
      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token
+         logger.info(f"pad_token set to: {tokenizer.eos_token}")

+     # Load the model (force CPU; skip GPU detection)
      model = AutoModelForCausalLM.from_pretrained(
          MODEL_NAME,
          revision=MODEL_REVISION,
          quantization_config=bnb_config,
+         device_map="cpu",  # force CPU on the 2-core host; avoids auto-placement overhead
          token=HF_TOKEN,
          trust_remote_code=True,
+         torch_dtype=torch.float16,
+         low_cpu_mem_usage=True  # low-memory loading mode to cut the peak footprint
      )

+     # Check CPU instruction-set support (AVX2 noticeably speeds up math kernels)
+     try:
+         import subprocess
+         avx2_support = subprocess.check_output(
+             "grep -c avx2 /proc/cpuinfo", shell=True
+         ).decode().strip()
+         logger.info(f"CPU AVX2 support: {'yes' if int(avx2_support) > 0 else 'no'}")
+     except Exception as e:
+         logger.warning(f"AVX2 detection failed: {str(e)}")
+
+     # Streaming generator (less intermediate buffering; the timeout argument
+     # belongs to TextIteratorStreamer, not TextStreamer, so it is omitted here)
+     streamer = TextStreamer(
+         tokenizer,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+
+     logger.info(f"Model loaded! Memory in use: {psutil.virtual_memory().used / 1024**3:.2f}GB")
  except Exception as e:
+     logger.error(f"Model loading failed: {str(e)}", exc_info=True)
+     raise SystemExit(f"Service terminated: {str(e)}")

+ # --------------------------
+ # 5. Request model (trimmed parameters)
+ # --------------------------
  class NodeInferenceRequest(BaseModel):
+     prompt: str
+     max_tokens: int = 512  # shorter default on 2 cores to bound total latency
+     is_math: bool = False
+     request_id: str = None

+ # --------------------------
+ # 6. Streaming inference endpoint (core optimization)
+ # --------------------------
  @app.post("/node/stream-infer")
  async def stream_infer(req: NodeInferenceRequest, request: Request):
+     request_id = req.request_id or f"req_{int(time.time()*1000)}"
      start_time = time.time()
      total_tokens = 0
      first_token_time = None
+     cpu_monitor_interval = 10  # monitor CPU once every 10 generated tokens

      try:
+         # Log the basic request info
          logger.info(
+             f"Request started | request_id={request_id} | "
+             f"prompt_len={len(req.prompt)} | max_tokens={req.max_tokens}"
          )

+         # Build the prompt (leaner template; less wasted computation)
+         prompt = f"Problem: {req.prompt}\n{'Solution (with steps)' if req.is_math else 'Answer'}:"

+         # Input handling (cap the length strictly to keep the 2-core CPU from overloading)
          inputs = tokenizer(
              prompt,
              return_tensors="pt",
              truncation=True,
+             max_length=1536  # reserve 512 tokens for the generated output
          )
+         input_tokens = len(inputs["input_ids"][0])
+         logger.info(f"Input processing done | input_tokens={input_tokens}")

+         # Asynchronous generation logic
          async def generate_chunks():
              nonlocal total_tokens, first_token_time

              loop = asyncio.get_running_loop()
+             # Precompute the generation kwargs (fewer branches while generating)
+             gen_kwargs = {
+                 **inputs,
+                 "streamer": streamer,
+                 "max_new_tokens": req.max_tokens,
+                 "do_sample": True,
+                 "temperature": 0.2 if req.is_math else 0.6,  # lower randomness speeds up generation
+                 "top_p": 0.9 if req.is_math else 0.95,
+                 "pad_token_id": tokenizer.pad_token_id,
+                 "eos_token_id": tokenizer.eos_token_id,
+                 "repetition_penalty": 1.05  # mild repetition suppression at little extra cost
+             }

+             # Run generation while monitoring the CPU
+             def generate_and_monitor():
+                 import threading
+                 stop_monitor = threading.Event()
+
+                 def log_cpu_usage():
+                     # Log overall and per-core CPU usage roughly once per second
+                     while not stop_monitor.is_set():
+                         cpu_percent = psutil.cpu_percent(interval=1)
+                         per_core = psutil.cpu_percent(percpu=True)
+                         logger.info(
+                             f"CPU monitor | request_id={request_id} | "
+                             f"overall={cpu_percent}% | per_core={per_core}"
+                         )
+
+                 # Start the CPU monitoring thread
+                 cpu_logger = threading.Thread(target=log_cpu_usage, daemon=True)
+                 cpu_logger.start()
+
+                 # Run generation, then stop the monitor cleanly (signalling an
+                 # Event is safer than forcing an exception into the thread)
+                 try:
+                     return model.generate(**gen_kwargs)
+                 finally:
+                     stop_monitor.set()
+
+             # Run the monitored generation in an executor
+             outputs = await loop.run_in_executor(None, generate_and_monitor)

              # Process the generated result
+             generated_tokens = outputs[0][input_tokens:]
              total_tokens = len(generated_tokens)
              logger.info(
+                 f"Generation finished | request_id={request_id} | "
+                 f"generated_tokens={total_tokens} | "
+                 f"elapsed={(time.time()-start_time):.2f}s"
              )

+             # Stream the results back
              for i, token in enumerate(generated_tokens):
                  if i == 0:
                      first_token_time = time.time()
                      logger.info(
+                         f"First token | request_id={request_id} | "
+                         f"latency={(first_token_time - start_time):.2f}s"
                      )

+                 # Detect client disconnects
                  if await request.is_disconnected():
+                     logger.warning(f"Client disconnected | request_id={request_id} | {i+1} tokens generated")
                      break

+                 # Decode and escape
                  token_text = tokenizer.decode(token, skip_special_tokens=True)
                  if token_text.endswith(tokenizer.eos_token):
                      break

                  escaped_text = token_text.replace('"', '\\"').replace('\n', '\\n')
                  yield '{{"chunk":"{}","finish":false,"request_id":"{}"}}\n'.format(escaped_text, request_id)

+             # End-of-stream marker
              yield '{{"chunk":"","finish":true,"request_id":"{}"}}\n'.format(request_id)

          return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")

      except Exception as e:
+         error_msg = f"Inference failed: {str(e)}"
          logger.error(
+             f"Request error | request_id={request_id} | "
+             f"error={error_msg} | elapsed={(time.time()-start_time):.2f}s",
              exc_info=True
          )
          raise HTTPException(status_code=500, detail=error_msg)
      finally:
+         # Performance summary
          elapsed_time = time.time() - start_time
          if total_tokens > 0 and elapsed_time > 0:
              speed = total_tokens / elapsed_time
              logger.info(
+                 f"Request summary | request_id={request_id} | "
+                 f"total_tokens={total_tokens} | total_time={elapsed_time:.2f}s | "
+                 f"avg_speed={speed:.2f}token/s | "
+                 f"memory_used={psutil.virtual_memory().used / 1024**3:.2f}GB"
              )

+ # --------------------------
+ # 7. Enhanced health check endpoint
+ # --------------------------
  @app.get("/node/health")
  async def node_health():
+     # Live hardware status
+     cpu_percent = psutil.cpu_percent(interval=0.5)
+     mem_usage = psutil.virtual_memory().percent
      model_available = model is not None  # isinstance() against an Auto* factory class is always False

      return {
+         "status": "healthy" if model_available else "unhealthy",
          "model": MODEL_NAME,
+         "hardware": {
+             "cpu_cores": psutil.cpu_count(logical=False),
+             "logical_cores": psutil.cpu_count(logical=True),
+             "cpu_usage": f"{cpu_percent}%",
+             "memory_usage": f"{mem_usage}%"
+         },
+         "performance": {
+             "target_speed": "1.5-2 token/s (2-core CPU)",
+             "quantization": "4bit NF4"
+         },
+         "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
      }

  if __name__ == "__main__":
      import uvicorn
+     # Startup parameters tuned for 2 cores
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=7860,
+         log_level="info",
+         workers=1  # single worker on a 2-core host; avoids resource contention
+     )
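
A minimal client sketch for the streaming endpoint in this commit, assuming the service runs locally on port 7860 as configured in the __main__ block. The field names ("chunk", "finish", "request_id") mirror the NDJSON lines that stream_infer yields; the requests dependency and the consume_stream helper are illustrative and not part of this commit.

# Hypothetical client for /node/stream-infer; not part of the commit above.
import json
import requests  # assumed HTTP client; any streaming-capable client works

def consume_stream(prompt: str, is_math: bool = False,
                   base_url: str = "http://localhost:7860") -> str:
    """POST a prompt and collect NDJSON chunks as they arrive."""
    payload = {"prompt": prompt, "is_math": is_math, "max_tokens": 256}
    pieces = []
    with requests.post(f"{base_url}/node/stream-infer",
                       json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue  # skip blank keep-alive lines
            chunk = json.loads(line)  # one JSON object per line (x-ndjson)
            if chunk["finish"]:
                break
            print(chunk["chunk"], end="", flush=True)
            pieces.append(chunk["chunk"])
    return "".join(pieces)

if __name__ == "__main__":
    consume_stream("What is 12 * 17?", is_math=True)

Since the server escapes only double quotes and newlines by hand, other control characters in generated text could still produce lines that json.loads rejects, so a hardened client should guard the parse. A quick liveness probe is GET /node/health (for example, curl http://localhost:7860/node/health).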