from fastapi import FastAPI, HTTPException, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel import os import logging import torch import asyncio from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer ) # 1. 基础配置 logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s") logger = logging.getLogger("inference_node") app = FastAPI(title="推理节点服务(Qwen-7B)") # 2. 模型配置(Qwen-7B 公开模型,无需HF Token) MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-7B") HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN", "") # 留空即可 # 3. 4bit量化配置(适配16G内存,显存占用约4-5GB) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) # 4. 加载模型(关键:显式处理tokenizer缺失的配置) try: logger.info(f"开始加载模型:{MODEL_NAME}(4bit量化)") tokenizer = AutoTokenizer.from_pretrained( MODEL_NAME, token=HF_TOKEN, padding_side="right", # 右侧padding,避免生成时截断 trust_remote_code=True, # Qwen模型必需(加载自定义tokenizer) eos_token="<|endoftext|>", # 显式指定结束符(兼容旧版本) pad_token="<|endoftext|>" # 显式指定padding符(避免生成警告) ) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, quantization_config=bnb_config, device_map="auto", # 自动分配设备(优先GPU) token=HF_TOKEN, trust_remote_code=True, torch_dtype=torch.bfloat16 # 匹配量化计算精度 ) # 流式输出配置(跳过提示词,只返回生成内容) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) logger.info(f"模型 {MODEL_NAME} 加载成功!显存占用约 4-5GB(4bit 量化)") except Exception as e: logger.error(f"模型加载失败:{str(e)}", exc_info=True) raise SystemExit(f"服务终止:{str(e)}") # 5. 请求体定义(用户输入prompt和生成参数) class NodeInferenceRequest(BaseModel): prompt: str # 用户提问内容 max_tokens: int = 1024 # 最大生成长度(默认1024) temperature: float = 0.7 # 随机性(0-1,越大越多样) # 6. 流式推理接口(核心修复:绕开chat_template,直接构建输入) @app.post("/node/stream-infer") async def stream_infer(req: NodeInferenceRequest, request: Request): try: # -------------------------- # 关键修复:手动构建Qwen原生对话格式 # Qwen要求格式:<|user|>用户输入<|end|><|assistant|> # -------------------------- user_prompt = req.prompt.strip() # 构建模型能理解的输入文本(无需依赖chat_template) input_text = f"<|user|>{user_prompt}<|end|><|assistant|>" # 编码输入(转换为模型可处理的张量,并移动到GPU) inputs = tokenizer( input_text, return_tensors="pt", # 返回PyTorch张量 truncation=True, # 截断过长输入(避免OOM) max_length=2048 # 输入最大长度(根据模型能力调整) ).to(model.device) # 异步生成流式内容(避免阻塞FastAPI主线程) async def generate_chunks(): loop = asyncio.get_running_loop() # 在线程池中运行同步的模型生成(不阻塞事件循环) outputs = await loop.run_in_executor( None, # 使用默认线程池 lambda: model.generate( **inputs, streamer=streamer, # 流式输出支持 max_new_tokens=req.max_tokens, # 最大生成长度 do_sample=True, # 启用采样(生成多样内容) temperature=req.temperature, # 随机性控制 pad_token_id=tokenizer.pad_token_id, # padding符ID eos_token_id=tokenizer.eos_token_id # 结束符ID(生成停止标志) ) ) # 提取生成的内容(排除输入部分,只取新生成的token) input_token_len = inputs["input_ids"].shape[1] # 输入token长度 generated_tokens = outputs[0][input_token_len:] # 仅保留新生成的token # 逐token解码并返回(流式输出核心) for token in generated_tokens: # 检查客户端是否断开连接(避免无效生成) if await request.is_disconnected(): logger.info("客户端已断开连接,停止生成") break # 解码单个token(跳过特殊符号,如<|end|>) token_text = tokenizer.decode( token, skip_special_tokens=True, # 跳过特殊token(如结束符、分隔符) clean_up_tokenization_spaces=True # 清理多余空格 ) # 转义双引号(避免JSON格式错误) escaped_text = token_text.replace('"', '\\"') # 按NDJSON格式返回(每行一个JSON对象,兼容流式解析) yield f'{{"chunk":"{escaped_text}","finish":false}}\n' # 生成结束标志(告知客户端生成完成) yield '{"chunk":"","finish":true}\n' # 返回流式响应(媒体类型为application/x-ndjson,支持逐行解析) return StreamingResponse( generate_chunks(), media_type="application/x-ndjson" ) except Exception as e: logger.error(f"推理失败:{str(e)}", exc_info=True) # 记录详细错误堆栈 raise HTTPException(status_code=500, detail=f"推理服务异常:{str(e)}") # 7. 健康检查接口(用于监控服务状态) @app.get("/node/health") async def node_health(): # 检查模型和tokenizer是否正常加载 is_model_ready = model is not None and tokenizer is not None return { "status": "healthy" if is_model_ready else "unhealthy", "model": MODEL_NAME, "support_stream": True, "note": "Qwen-7B 4bit量化(适配16G内存),绕开chat_template兼容旧版本", "timestamp": str(asyncio.get_event_loop().time()) } # 8. 启动服务(仅在直接运行脚本时执行) if __name__ == "__main__": import uvicorn # 启动UVicorn服务(host=0.0.0.0允许外部访问,port=7860为默认端口) uvicorn.run( app, host="0.0.0.0", port=7860, log_level="info", workers=1 # 单进程(模型不支持多进程共享,避免重复加载) )