Spaces:
Sleeping
Sleeping
Update inference_node.py
Browse files- inference_node.py +19 -16
inference_node.py
CHANGED
|
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
|
| 4 |
import os
|
| 5 |
import logging
|
| 6 |
import torch
|
|
|
|
| 7 |
from transformers import (
|
| 8 |
AutoModelForCausalLM, AutoTokenizer,
|
| 9 |
BitsAndBytesConfig, TextStreamer
|
|
@@ -47,17 +48,16 @@ class NodeInferenceRequest(BaseModel):
|
|
| 47 |
@app.post("/node/stream-infer")
|
| 48 |
async def stream_infer(req: NodeInferenceRequest, request: Request):
|
| 49 |
try:
|
| 50 |
-
#
|
| 51 |
inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
|
| 52 |
|
| 53 |
-
#
|
| 54 |
async def generate_chunks():
|
| 55 |
generated_text = ""
|
| 56 |
-
#
|
| 57 |
-
loop = asyncio.
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
None,
|
| 61 |
lambda: model.generate(
|
| 62 |
**inputs,
|
| 63 |
streamer=streamer,
|
|
@@ -68,22 +68,26 @@ async def stream_infer(req: NodeInferenceRequest, request: Request):
|
|
| 68 |
)
|
| 69 |
)
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
for token in outputs[0][len(inputs["input_ids"][0]):]:
|
| 74 |
-
#
|
| 75 |
if await request.is_disconnected():
|
| 76 |
logger.info("客户端断开连接,停止生成")
|
| 77 |
break
|
|
|
|
|
|
|
| 78 |
token_text = tokenizer.decode(token, skip_special_tokens=True)
|
| 79 |
generated_text += token_text
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
#
|
| 85 |
yield '{"chunk":"","finish":true}\n'
|
| 86 |
|
|
|
|
| 87 |
return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")
|
| 88 |
|
| 89 |
except Exception as e:
|
|
@@ -97,5 +101,4 @@ async def node_health():
|
|
| 97 |
|
| 98 |
if __name__ == "__main__":
|
| 99 |
import uvicorn
|
| 100 |
-
|
| 101 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 4 |
import os
|
| 5 |
import logging
|
| 6 |
import torch
|
| 7 |
+
import asyncio # 新增异步依赖
|
| 8 |
from transformers import (
|
| 9 |
AutoModelForCausalLM, AutoTokenizer,
|
| 10 |
BitsAndBytesConfig, TextStreamer
|
|
|
|
| 48 |
@app.post("/node/stream-infer")
|
| 49 |
async def stream_infer(req: NodeInferenceRequest, request: Request):
|
| 50 |
try:
|
| 51 |
+
# 模型生成(流式):用异步线程避免阻塞事件循环
|
| 52 |
inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
|
| 53 |
|
| 54 |
+
# 异步生成器:必须用 async def
|
| 55 |
async def generate_chunks():
|
| 56 |
generated_text = ""
|
| 57 |
+
# 模型生成是同步操作,用线程池异步执行(避免阻塞FastAPI)
|
| 58 |
+
loop = asyncio.get_running_loop()
|
| 59 |
+
outputs = await loop.run_in_executor(
|
| 60 |
+
None, # 使用默认线程池
|
|
|
|
| 61 |
lambda: model.generate(
|
| 62 |
**inputs,
|
| 63 |
streamer=streamer,
|
|
|
|
| 68 |
)
|
| 69 |
)
|
| 70 |
|
| 71 |
+
# 逐段处理生成结果
|
|
|
|
| 72 |
for token in outputs[0][len(inputs["input_ids"][0]):]:
|
| 73 |
+
# 检查客户端是否断开连接(提前终止,避免无效计算)
|
| 74 |
if await request.is_disconnected():
|
| 75 |
logger.info("客户端断开连接,停止生成")
|
| 76 |
break
|
| 77 |
+
|
| 78 |
+
# 解码token并处理双引号转义(避免JSON格式错误)
|
| 79 |
token_text = tokenizer.decode(token, skip_special_tokens=True)
|
| 80 |
generated_text += token_text
|
| 81 |
+
escaped_text = token_text.replace('"', '\\"') # 提前处理双引号转义
|
| 82 |
+
|
| 83 |
+
# 用 str.format() 拼接JSON,彻底避免f-string引号冲突
|
| 84 |
+
json_chunk = '{{"chunk":"{}","finish":false}}\n'.format(escaped_text)
|
| 85 |
+
yield json_chunk
|
| 86 |
|
| 87 |
+
# 生成结束标识(固定字符串,无变量,直接返回)
|
| 88 |
yield '{"chunk":"","finish":true}\n'
|
| 89 |
|
| 90 |
+
# 返回流式响应(指定媒体类型为JSON流)
|
| 91 |
return StreamingResponse(generate_chunks(), media_type="application/x-ndjson")
|
| 92 |
|
| 93 |
except Exception as e:
|
|
|
|
| 101 |
|
| 102 |
if __name__ == "__main__":
|
| 103 |
import uvicorn
|
| 104 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|