fiewolf1000 committed
Commit a31d0b8 · verified · 1 Parent(s): 1426e3d

Update inference_node.py

Files changed (1): inference_node.py (+28, -54)
inference_node.py CHANGED
@@ -13,15 +13,13 @@ from transformers import (
 # 1. Basic configuration
 logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s")
 logger = logging.getLogger("inference_node")
-app = FastAPI(title="Inference node service (single model)")
+app = FastAPI(title="Inference node service (Qwen-7B)")
 
-# 2. Model configuration (fix: use the correct model name, overridable via environment variable)
-# Correct model name: Qwen/Qwen-0.5B-Instruct (exists officially on Hugging Face)
-MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-0.5B-Instruct")
-# Hugging Face token from the environment (required, some models are gated)
-HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
+# 2. Model configuration (use the actually existing Qwen-7B)
+MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-7B")
+HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")  # Qwen-7B is public, may be left unset
 
-# 3. 4-bit quantization config (fits 16 GB memory, lowers VRAM usage)
+# 3. 4-bit quantization config (fits 16 GB memory)
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
@@ -29,104 +27,80 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 
-# 4. Load the model (fix: pass token instead of use_auth_token, add error handling)
+# 4. Load the model
 try:
     logger.info(f"Loading model: {MODEL_NAME} (4-bit quantized)")
-    # Load the tokenizer (fix: token replaces use_auth_token)
     tokenizer = AutoTokenizer.from_pretrained(
         MODEL_NAME,
-        token=HF_TOKEN,         # new parameter: passes the Hugging Face token
-        padding_side="right",   # avoids a warning during generation
-        trust_remote_code=True  # required for Qwen models (custom code)
+        token=HF_TOKEN,         # may be left unset for public models
+        padding_side="right",
+        trust_remote_code=True  # required for Qwen models
     )
-    # Load the quantized model
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
         quantization_config=bnb_config,
-        device_map="auto",      # auto-places layers on GPU/CPU (GPU preferred)
-        token=HF_TOKEN,         # pass the token (some models are gated)
-        trust_remote_code=True  # required for Qwen models
+        device_map="auto",
+        token=HF_TOKEN,
+        trust_remote_code=True
     )
-    # Streaming generator (emits output piece by piece)
     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    logger.info(f"Model {MODEL_NAME} loaded! VRAM usage approx. 2-3 GB")
+    logger.info(f"Model {MODEL_NAME} loaded! VRAM usage approx. 4-5 GB (4-bit quantized)")
 except Exception as e:
     logger.error(f"Model loading failed: {str(e)}", exc_info=True)
-    # Exit immediately when loading fails at startup (avoid running in a broken state)
-    raise SystemExit(f"Model loading failed, terminating service: {str(e)}")
+    raise SystemExit(f"Service terminated: {str(e)}")
 
-# 5. Request model (format agreed with the controller, unchanged)
+# 5. Request model
 class NodeInferenceRequest(BaseModel):
-    prompt: str             # full prompt assembled by the controller (includes user context)
-    max_tokens: int = 1024  # maximum generation length
+    prompt: str
+    max_tokens: int = 1024
 
-# 6. Streaming inference endpoint (core logic unchanged, async-compatible)
+# 6. Streaming inference endpoint
 @app.post("/node/stream-infer")
 async def stream_infer(req: NodeInferenceRequest, request: Request):
     try:
-        # Preprocess the prompt (Qwen models build inputs with a dedicated method)
         inputs = tokenizer.build_chat_input(
-            [{"role": "user", "content": req.prompt}],  # Qwen chat format
-            add_generation_prompt=True  # automatically appends the assistant-reply prompt
+            [{"role": "user", "content": req.prompt}],
+            add_generation_prompt=True
         ).to(model.device)
 
-        # Async generator: avoids blocking the FastAPI event loop
        async def generate_chunks():
-            generated_text = ""
-            # Run the synchronous model generation in a thread pool (keeps the async endpoint responsive)
            loop = asyncio.get_running_loop()
            outputs = await loop.run_in_executor(
-                None,  # default thread pool
                lambda: model.generate(
+                None,
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=req.max_tokens,
                    do_sample=True,
-                    temperature=0.7,  # randomness (0-1, lower is more deterministic)
-                    pad_token_id=tokenizer.eos_token_id  # avoids a warning
+                    temperature=0.7,
+                    pad_token_id=tokenizer.eos_token_id
                )
            )
 
-            # Decode and yield the result piece by piece
            for token in outputs[0][len(inputs["input_ids"][0]):]:
-                # Check whether the client has disconnected (stop early, save resources)
                if await request.is_disconnected():
-                    logger.info("Client disconnected, stopping generation")
                    break
-                # Decode the token (skipping special tokens)
                token_text = tokenizer.decode(token, skip_special_tokens=True)
-                generated_text += token_text
-                # Escape double quotes (keeps the JSON valid)
                escaped_text = token_text.replace('"', '\\"')
-                # Build the JSON line with format() to avoid quote conflicts
                yield '{{"chunk":"{}","finish":false}}\n'.format(escaped_text)
-
-            # End-of-generation marker
            yield '{"chunk":"","finish":true}\n'
 
-        # Return a streaming response (media type: JSON stream)
-        return StreamingResponse(
-            generate_chunks(),
-            media_type="application/x-ndjson",
-            headers={"Cache-Control": "no-cache"}
-        )
+        return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")
 
    except Exception as e:
-        error_msg = f"Inference failed: {str(e)}"
-        logger.error(error_msg, exc_info=True)
-        raise HTTPException(status_code=500, detail=error_msg)
+        logger.error(f"Inference failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
 
-# 7. Health-check endpoint (used by the controller to probe node status)
+# 7. Health check
 @app.get("/node/health")
 async def node_health():
    return {
        "status": "healthy",
        "model": MODEL_NAME,
        "support_stream": True,
-        "note": "Qwen-0.5B-Instruct, 4-bit quantized, ~2 GB VRAM"
+        "note": "Qwen-7B, 4-bit quantized, fits 16 GB memory"
    }
 
 if __name__ == "__main__":
    import uvicorn
-    # Start the service (Hugging Face Spaces default port: 7860)
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
 
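One caveat on the prompt-building step kept by this commit: build_chat_input is not part of the standard transformers tokenizer API (it ships with certain trust_remote_code tokenizers), so it may raise AttributeError depending on the checkpoint. The library-level equivalent is apply_chat_template, available when the tokenizer carries a chat template (Qwen chat variants do; the Qwen-7B base checkpoint may not). A minimal sketch under that assumption:

    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": req.prompt}],  # same message shape as in the diff
        add_generation_prompt=True,  # append the assistant-turn marker
        return_tensors="pt",
    ).to(model.device)
    # generate() then takes the tensor positionally instead of **inputs
    outputs = model.generate(input_ids, max_new_tokens=req.max_tokens)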
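On the chunk serialization: escaping only double quotes leaves backslashes, newlines, and control characters untouched, any of which makes the emitted line invalid JSON. json.dumps applies the full escaping rules; a small sketch:

    import json

    def ndjson_line(chunk: str, finish: bool) -> str:
        # json.dumps escapes quotes, backslashes, newlines, and control characters.
        return json.dumps({"chunk": chunk, "finish": finish}, ensure_ascii=False) + "\n"

    # ndjson_line('say "hi"\n', False)
    # -> '{"chunk": "say \\"hi\\"\\n", "finish": false}\n'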
 
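For reference, a hypothetical client for the two endpoints, using httpx to consume the NDJSON stream (the URL assumes the uvicorn defaults above, and the field names follow the chunk format in the diff):

    import asyncio
    import json

    import httpx

    async def main() -> None:
        async with httpx.AsyncClient(timeout=None) as client:
            # Probe the node before sending work.
            health = await client.get("http://localhost:7860/node/health")
            print(health.json())

            # Stream an inference request line by line.
            async with client.stream(
                "POST",
                "http://localhost:7860/node/stream-infer",
                json={"prompt": "Hello", "max_tokens": 128},
            ) as resp:
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    payload = json.loads(line)
                    if payload["finish"]:
                        break
                    print(payload["chunk"], end="", flush=True)

    asyncio.run(main())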