fiewolf1000 committed
Commit 6b2cd32 · verified · 1 Parent(s): d5049a2

Update inference_node.py

Files changed (1)
  1. inference_node.py +19 -16
inference_node.py CHANGED
@@ -4,6 +4,7 @@ from pydantic import BaseModel
 import os
 import logging
 import torch
+import asyncio  # new async dependency
 from transformers import (
     AutoModelForCausalLM, AutoTokenizer,
     BitsAndBytesConfig, TextStreamer
@@ -47,17 +48,16 @@ class NodeInferenceRequest(BaseModel):
 @app.post("/node/stream-infer")
 async def stream_infer(req: NodeInferenceRequest, request: Request):
     try:
-        # Model generation (streaming)
+        # Model generation (streaming): run it in a worker thread so the event loop is not blocked
         inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
 
-        # The async generator needs async def
+        # Async generator: must be declared with async def
         async def generate_chunks():
             generated_text = ""
-            # Use an async version of model.generate, or create a task
-            loop = asyncio.get_event_loop()
-            # Run generation in a separate thread to avoid blocking the event loop
-            future = loop.run_in_executor(
-                None,
+            # model.generate is synchronous; run it in a thread pool (avoid blocking FastAPI)
+            loop = asyncio.get_running_loop()
+            outputs = await loop.run_in_executor(
+                None,  # use the default thread pool
                 lambda: model.generate(
                     **inputs,
                     streamer=streamer,
@@ -68,22 +68,26 @@ async def stream_infer(req: NodeInferenceRequest, request: Request):
                 )
             )
 
-            outputs = await future
-
+            # Process the generated tokens one by one
             for token in outputs[0][len(inputs["input_ids"][0]):]:
-                # Check whether the client has disconnected (avoid wasted generation)
+                # Check whether the client has disconnected (stop early, avoid wasted compute)
                 if await request.is_disconnected():
                     logger.info("Client disconnected, stopping generation")
                     break
+
+                # Decode the token and escape double quotes (avoid malformed JSON)
                 token_text = tokenizer.decode(token, skip_special_tokens=True)
                 generated_text += token_text
-                # Return JSON in the format agreed with the controller (so it can be passed through)
-                # Fix the quote-escaping problem
-                yield f'{{"chunk":"{token_text.replace(\'"', '\\"')}","finish":false}}\n'
+                escaped_text = token_text.replace('"', '\\"')  # escape double quotes up front
+
+                # Build the JSON with str.format() to avoid f-string quoting conflicts
+                json_chunk = '{{"chunk":"{}","finish":false}}\n'.format(escaped_text)
+                yield json_chunk
 
-            # End-of-generation marker
+            # End-of-generation marker (fixed string, no variables)
             yield '{"chunk":"","finish":true}\n'
 
+        # Return the streaming response (media type: newline-delimited JSON)
         return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")
 
     except Exception as e:
@@ -97,5 +101,4 @@ async def node_health():
 
 if __name__ == "__main__":
     import uvicorn
-    import asyncio  # added import
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
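
For reference, a minimal client-side sketch of consuming the /node/stream-infer NDJSON stream defined above. The endpoint path, the "prompt" field, the chunk/finish keys, and port 7860 come from the diff; the localhost host and the use of the requests library are assumptions for illustration.

import json
import requests

# Hypothetical client: assumes the node is reachable on localhost:7860 (port taken from uvicorn.run above)
resp = requests.post(
    "http://localhost:7860/node/stream-infer",
    json={"prompt": "Hello"},  # "prompt" matches the field read from NodeInferenceRequest in the handler
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line:
        continue  # skip empty lines between chunks
    data = json.loads(line)  # one {"chunk": ..., "finish": ...} object per line
    if data["finish"]:
        break
    print(data["chunk"], end="", flush=True)

One design note on the escaping change: replace('"', '\\"') only handles double quotes, so a chunk containing a backslash or newline could still break json.loads on the client; building each line with json.dumps({"chunk": token_text, "finish": False}) would cover those cases as well, if that robustness is wanted.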