fiewolf1000 commited on
Commit
aa269af
·
verified ·
1 Parent(s): 6b2cd32

Update inference_node.py

Browse files
Files changed (1) hide show
  1. inference_node.py +74 -46
inference_node.py CHANGED
@@ -4,7 +4,7 @@ from pydantic import BaseModel
4
  import os
5
  import logging
6
  import torch
7
- import asyncio # 新增异步依赖
8
  from transformers import (
9
  AutoModelForCausalLM, AutoTokenizer,
10
  BitsAndBytesConfig, TextStreamer
@@ -15,11 +15,13 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname
15
  logger = logging.getLogger("inference_node")
16
  app = FastAPI(title="推理节点服务(单一模型)")
17
 
18
- # 2. 模型配置(每个节点仅加载一个模型,通过环境变量指定
19
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-2-0.5B-Instruct") # 节点启动时指定模型
20
- hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
 
21
 
22
- # 3. 4bit量化(适配16G内存)
23
  bnb_config = BitsAndBytesConfig(
24
  load_in_4bit=True,
25
  bnb_4bit_use_double_quant=True,
@@ -27,34 +29,51 @@ bnb_config = BitsAndBytesConfig(
27
  bnb_4bit_compute_dtype=torch.bfloat16
28
  )
29
 
30
- # 4. 加载模型(启动时加载单一模型
31
- logger.info(f"加载模型:{MODEL_NAME}(4bit量化)")
32
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=hf_token, padding_side="right")
33
- model = AutoModelForCausalLM.from_pretrained(
34
- MODEL_NAME,
35
- quantization_config=bnb_config,
36
- device_map="auto",
37
- use_auth_token=hf_token
38
- )
39
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
40
- logger.info(f"模型加载完成:{MODEL_NAME}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # 5. 请求模型(与总控约定的格式)
43
  class NodeInferenceRequest(BaseModel):
44
- prompt: str # 总控拼接好的完整Prompt(含用户上下文)
45
- max_tokens: int = 1024
46
 
47
- # 6. 流式推理接口(仅处理推理不存上下文
48
  @app.post("/node/stream-infer")
49
  async def stream_infer(req: NodeInferenceRequest, request: Request):
50
  try:
51
- # 模型生成(流式):异步线程避免阻塞事件循环
52
- inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
53
-
54
- # 异步生成器:必须用 async def
 
 
 
55
  async def generate_chunks():
56
  generated_text = ""
57
- # 模型生成是同步操作,用线程池异步执行(避免阻塞FastAPI
58
  loop = asyncio.get_running_loop()
59
  outputs = await loop.run_in_executor(
60
  None, # 使用默认线程池
@@ -63,42 +82,51 @@ async def stream_infer(req: NodeInferenceRequest, request: Request):
63
  streamer=streamer,
64
  max_new_tokens=req.max_tokens,
65
  do_sample=True,
66
- temperature=0.7,
67
- pad_token_id=tokenizer.eos_token_id
68
  )
69
  )
70
-
71
- # 逐段处理生成结果
72
  for token in outputs[0][len(inputs["input_ids"][0]):]:
73
- # 检查客户端是否断开连接(提前终止,避免无效计算
74
  if await request.is_disconnected():
75
  logger.info("客户端断开连接,停止生成")
76
  break
77
-
78
- # 解码token并处理双引号转义(避免JSON格式错误)
79
  token_text = tokenizer.decode(token, skip_special_tokens=True)
80
  generated_text += token_text
81
- escaped_text = token_text.replace('"', '\\"') # 提前处理双引号转义
82
-
83
- # 用 str.format() 拼接JSON,彻底避免f-string引号冲突
84
- json_chunk = '{{"chunk":"{}","finish":false}}\n'.format(escaped_text)
85
- yield json_chunk
86
-
87
- # 生成结束标识(固定字符串,无变量,直接返回)
88
  yield '{"chunk":"","finish":true}\n'
89
 
90
- # 返回流式响应(指定媒体类型为JSON流)
91
- return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")
 
 
 
 
92
 
93
  except Exception as e:
94
- logger.error(f"推理失败:{str(e)}")
95
- raise HTTPException(status_code=500, detail=f"节点推理失败:{str(e)}")
 
96
 
97
- # 7. 健康检查接口(总控用于节点状态检测
98
  @app.get("/node/health")
99
  async def node_health():
100
- return {"status": "healthy", "model": MODEL_NAME}
 
 
 
 
 
101
 
102
  if __name__ == "__main__":
103
  import uvicorn
104
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
4
  import os
5
  import logging
6
  import torch
7
+ import asyncio
8
  from transformers import (
9
  AutoModelForCausalLM, AutoTokenizer,
10
  BitsAndBytesConfig, TextStreamer
 
15
  logger = logging.getLogger("inference_node")
16
  app = FastAPI(title="推理节点服务(单一模型)")
17
 
18
+ # 2. 模型配置(修复:使用正确的模型支持通过环境变量覆盖
19
+ # 正确模型名:Qwen/Qwen-0.5B-Instruct(Hugging Face 官方存在)
20
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-0.5B-Instruct")
21
+ # 从环境变量获取 Hugging Face 令牌(必填,部分模型需登录)
22
+ HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
23
 
24
+ # 3. 4bit量化配置(适配16G内存,降低显存占用
25
  bnb_config = BitsAndBytesConfig(
26
  load_in_4bit=True,
27
  bnb_4bit_use_double_quant=True,
 
29
  bnb_4bit_compute_dtype=torch.bfloat16
30
  )
31
 
32
+ # 4. 加载模型(修复:用 token 参数替代 use_auth_token增加错误捕获
33
+ try:
34
+ logger.info(f"开始加载模型:{MODEL_NAME}(4bit量化)")
35
+ # 加载 Tokenizer(修复参数:用 token 替代 use_auth_token)
36
+ tokenizer = AutoTokenizer.from_pretrained(
37
+ MODEL_NAME,
38
+ token=HF_TOKEN, # 新参数:传递 Hugging Face 令牌
39
+ padding_side="right", # 避免生成时的警告
40
+ trust_remote_code=True # 加载 Qwen 模型需开启(支持自定义代码)
41
+ )
42
+ # 加载量化模型
43
+ model = AutoModelForCausalLM.from_pretrained(
44
+ MODEL_NAME,
45
+ quantization_config=bnb_config,
46
+ device_map="auto", # 自动分配 GPU/CPU(优先用 GPU)
47
+ token=HF_TOKEN, # 传递令牌(部分模型需授权)
48
+ trust_remote_code=True # Qwen 模型必需
49
+ )
50
+ # 流式生成器(逐段输出结果)
51
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
52
+ logger.info(f"模型 {MODEL_NAME} 加载成功!显存占用约 2-3GB")
53
+ except Exception as e:
54
+ logger.error(f"模型加载失败:{str(e)}", exc_info=True)
55
+ # 启动时加载失败直接退出(避免服务异常运行)
56
+ raise SystemExit(f"模型加载失败,服务终止:{str(e)}")
57
 
58
+ # 5. 请求模型(与总控约定的格式,无修改
59
  class NodeInferenceRequest(BaseModel):
60
+ prompt: str # 总控拼接好的完整 Prompt(含用户上下文)
61
+ max_tokens: int = 1024 # 最大生成长度
62
 
63
+ # 6. 流式推理接口(核心逻辑无修改确保异步兼容
64
  @app.post("/node/stream-infer")
65
  async def stream_infer(req: NodeInferenceRequest, request: Request):
66
  try:
67
+ # 预处理 Prompt(Qwen 模型专用方法构建输入)
68
+ inputs = tokenizer.build_chat_input(
69
+ [{"role": "user", "content": req.prompt}], # 适配 Qwen 对话格式
70
+ add_generation_prompt=True # 自动添加“助手回复”的提示
71
+ ).to(model.device)
72
+
73
+ # 异步生成器:避免阻塞 FastAPI 事件循环
74
  async def generate_chunks():
75
  generated_text = ""
76
+ # 用线程池执行同步的模型生成(避免阻塞异步接口
77
  loop = asyncio.get_running_loop()
78
  outputs = await loop.run_in_executor(
79
  None, # 使用默认线程池
 
82
  streamer=streamer,
83
  max_new_tokens=req.max_tokens,
84
  do_sample=True,
85
+ temperature=0.7, # 随机性(0~1,越小越确定)
86
+ pad_token_id=tokenizer.eos_token_id # 避免警告
87
  )
88
  )
89
+
90
+ # 逐段解码并返回结果
91
  for token in outputs[0][len(inputs["input_ids"][0]):]:
92
+ # 检查客户端是否断开连接(提前终止,节省资源
93
  if await request.is_disconnected():
94
  logger.info("客户端断开连接,停止生成")
95
  break
96
+ # 解码 Token(跳过特殊字符)
 
97
  token_text = tokenizer.decode(token, skip_special_tokens=True)
98
  generated_text += token_text
99
+ # 处理双引号转义(确保 JSON 格式合法)
100
+ escaped_text = token_text.replace('"', '\\"')
101
+ # 用 format 拼接 JSON,避免引号冲突
102
+ yield '{{"chunk":"{}","finish":false}}\n'.format(escaped_text)
103
+
104
+ # 生成结束标识
 
105
  yield '{"chunk":"","finish":true}\n'
106
 
107
+ # 返回流式响应(指定媒体类型为 JSON 流)
108
+ return StreamingResponse(
109
+ generate_chunks(),
110
+ media_type="application/x-ndjson",
111
+ headers={"Cache-Control": "no-cache"}
112
+ )
113
 
114
  except Exception as e:
115
+ error_msg = f"推理失败:{str(e)}"
116
+ logger.error(error_msg, exc_info=True)
117
+ raise HTTPException(status_code=500, detail=error_msg)
118
 
119
+ # 7. 健康检查接口(总控用于检测节点状态)
120
  @app.get("/node/health")
121
  async def node_health():
122
+ return {
123
+ "status": "healthy",
124
+ "model": MODEL_NAME,
125
+ "support_stream": True,
126
+ "note": "Qwen-0.5B-Instruct 4bit量化,显存占用~2GB"
127
+ }
128
 
129
  if __name__ == "__main__":
130
  import uvicorn
131
+ # 启动服务(Hugging Face Space 默认端口 7860
132
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")