File size: 6,970 Bytes
91bd68b
 
 
 
 
 
aa269af
91bd68b
 
ea17503
91bd68b
 
 
 
 
a31d0b8
91bd68b
9908754
a31d0b8
9908754
91bd68b
9908754
91bd68b
 
 
 
 
 
 
ea17503
aa269af
 
 
 
e95d674
9908754
 
ea17503
9908754
aa269af
 
 
 
9908754
a31d0b8
9908754
 
aa269af
ea17503
 
a31d0b8
aa269af
 
a31d0b8
91bd68b
9908754
91bd68b
9908754
 
 
91bd68b
ea17503
91bd68b
 
 
9908754
ea17503
 
9908754
 
ea17503
9908754
e95d674
ea17503
e95d674
 
9908754
 
ea17503
aa269af
 
ea17503
d5049a2
6b2cd32
ea17503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9908754
91bd68b
9908754
ea17503
 
 
 
 
 
 
 
 
 
 
 
 
91bd68b
 
9908754
 
 
 
 
91bd68b
 
ea17503
 
91bd68b
9908754
91bd68b
 
9908754
 
aa269af
9908754
aa269af
 
ea17503
9908754
aa269af
91bd68b
9908754
91bd68b
 
ea17503
9908754
 
 
 
 
ea17503
9908754
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import asyncio
import json
import logging
import os
from threading import Thread

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig, TextIteratorStreamer, TextStreamer
)

# 1. Basic configuration: logging format and the FastAPI application object.
logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s")
logger = logging.getLogger("inference_node")
app = FastAPI(title="推理节点服务(Qwen-7B)")

# 2. Model configuration (Qwen-7B is a public model; no HF token is required).
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-7B")  
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN", "")  # may be left empty

# 3. 4-bit quantization config (fits ~16 GB hosts; roughly 4-5 GB of VRAM).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,        # double quantization saves additional memory
    bnb_4bit_quant_type="nf4",             # NF4 quantization scheme
    bnb_4bit_compute_dtype=torch.bfloat16  # compute in bfloat16
)

# 4. Load tokenizer and model (key point: explicitly supply tokenizer settings
#    that may be missing from the repo config).  Any failure aborts the process.
try:
    logger.info(f"开始加载模型:{MODEL_NAME}(4bit量化)")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        token=HF_TOKEN,
        padding_side="right",  # pad on the right so generation is not truncated
        trust_remote_code=True,  # required for Qwen (loads its custom tokenizer code)
        eos_token="<|endoftext|>",  # explicit end-of-sequence token (older-version compat)
        pad_token="<|endoftext|>"   # explicit padding token (avoids generation warnings)
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",  # automatic device placement (GPU preferred)
        token=HF_TOKEN,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16  # match the quantization compute dtype
    )
    # Streaming helper configured to skip the prompt and special tokens.
    # NOTE(review): TextStreamer writes to stdout and is shared module state;
    # it is not a per-request stream to the HTTP client.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    logger.info(f"模型 {MODEL_NAME} 加载成功!显存占用约 4-5GB(4bit 量化)")
except Exception as e:
    logger.error(f"模型加载失败:{str(e)}", exc_info=True)
    raise SystemExit(f"服务终止:{str(e)}")

# 5. Request body schema (user prompt plus generation parameters).
class NodeInferenceRequest(BaseModel):
    """Payload for /node/stream-infer: the prompt and sampling controls."""
    prompt: str  # user question / input text
    max_tokens: int = 1024  # maximum number of newly generated tokens (default 1024)
    temperature: float = 0.7  # sampling randomness in (0, 1]; higher = more diverse

# 6. Streaming inference endpoint (bypasses chat_template; builds input directly).
#    Fixes vs. the previous version:
#      * Real streaming: a per-request TextIteratorStreamer plus a worker thread
#        delivers chunks while generation is still running.  The old code awaited
#        the COMPLETE generation before yielding anything, and passed the shared
#        module-level TextStreamer, which printed to server stdout rather than
#        the HTTP response (and is unsafe across concurrent requests).
#      * Valid NDJSON: chunks are serialized with json.dumps, which escapes
#        backslashes, newlines and control characters — the old manual
#        replace('"', '\\"') produced malformed JSON for such inputs.
#      * No broken glyphs: the streamer decodes incrementally, whereas per-token
#        decode could split a multi-byte (e.g. Chinese) character.
@app.post("/node/stream-infer")
async def stream_infer(req: NodeInferenceRequest, request: Request):
    """Stream a completion for ``req.prompt`` as NDJSON lines.

    Each line is ``{"chunk": <text>, "finish": false}``; a final
    ``{"chunk": "", "finish": true}`` line marks the end of generation.

    Raises:
        HTTPException(500): if preparing the generation fails.
    """
    try:
        # Build Qwen's native dialogue format by hand (no chat_template needed):
        # <|user|>question<|end|><|assistant|>
        user_prompt = req.prompt.strip()
        input_text = f"<|user|>{user_prompt}<|end|><|assistant|>"

        # Tokenize and move the tensors to the model's device.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,   # truncate over-long inputs to avoid OOM
            max_length=2048    # input length cap (adjust to model capability)
        ).to(model.device)

        async def generate_chunks():
            # Per-request streamer: yields decoded text pieces as they are produced.
            chunk_streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                **inputs,
                streamer=chunk_streamer,
                max_new_tokens=req.max_tokens,   # cap on newly generated tokens
                do_sample=True,                  # enable sampling (diverse output)
                temperature=req.temperature,     # randomness control
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,  # generation stop marker
            )
            # Run the blocking generate() in a worker thread so chunks can be
            # consumed concurrently; daemon=True so it never blocks shutdown.
            worker = Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
            worker.start()

            loop = asyncio.get_running_loop()
            stream_iter = iter(chunk_streamer)
            sentinel = object()
            while True:
                # next() blocks on the streamer's internal queue, so hop to the
                # default executor to keep the event loop responsive.
                piece = await loop.run_in_executor(None, next, stream_iter, sentinel)
                if piece is sentinel:
                    break
                # Stop sending if the client went away (generation itself cannot
                # be cancelled mid-flight, but we stop wasting bandwidth).
                if await request.is_disconnected():
                    logger.info("客户端已断开连接,停止生成")
                    break
                if piece:
                    # json.dumps performs all required JSON escaping.
                    yield json.dumps({"chunk": piece, "finish": False}, ensure_ascii=False) + "\n"

            # Final marker telling the client generation is complete.
            yield json.dumps({"chunk": "", "finish": True}) + "\n"

        # NDJSON media type: one JSON object per line, suitable for line parsing.
        return StreamingResponse(
            generate_chunks(),
            media_type="application/x-ndjson"
        )

    except Exception as e:
        logger.error(f"推理失败:{str(e)}", exc_info=True)  # log full traceback
        raise HTTPException(status_code=500, detail=f"推理服务异常:{str(e)}")

# 7. Health-check endpoint (used to monitor service status).
@app.get("/node/health")
async def node_health():
    """Report whether the model/tokenizer loaded, plus basic service metadata."""
    # Both objects are assigned at import time; None would indicate a bad load.
    is_model_ready = model is not None and tokenizer is not None
    return {
        "status": "healthy" if is_model_ready else "unhealthy",
        "model": MODEL_NAME,
        "support_stream": True,
        "note": "Qwen-7B 4bit量化(适配16G内存),绕开chat_template兼容旧版本",
        # Fix: asyncio.get_event_loop() is deprecated inside coroutines;
        # get_running_loop() is the supported way to reach the current loop.
        "timestamp": str(asyncio.get_running_loop().time())
    }

# 8. Script entry point: start the service only when run directly.
if __name__ == "__main__":
    import uvicorn

    # host 0.0.0.0 exposes the service externally; 7860 is the default port.
    # A single worker is required: the model cannot be shared across processes,
    # and additional workers would each load their own copy.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "workers": 1,
    }
    uvicorn.run(app, **server_options)