from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import os
import json
import time
import logging
import torch
import asyncio
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig, TextStreamer
)
# 1. Basic configuration
logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s")
logger = logging.getLogger("inference_node")
app = FastAPI(title="Inference Node Service (Qwen-7B)")
# 2. Model configuration (Qwen-7B is a public model; no HF token required)
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-7B")
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN", "")  # may be left empty
# 3. 4-bit quantization config (fits 16 GB RAM; roughly 4-5 GB of VRAM)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
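# Rough arithmetic behind the 4-5 GB estimate above: 7B parameters at
# 4 bits each is about 7e9 * 0.5 bytes ≈ 3.5 GB for the weights; the
# quantization constants, KV cache and activations account for the rest.
# Exact usage depends on sequence length and batch size.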
# 4. Load the model (key point: explicitly set tokenizer fields the checkpoint does not ship with)
try:
    logger.info(f"Loading model: {MODEL_NAME} (4-bit quantized)")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        token=HF_TOKEN,
        padding_side="right",       # pad on the right to avoid truncating generations
        trust_remote_code=True,     # required for Qwen (loads its custom tokenizer)
        eos_token="<|endoftext|>",  # set the end-of-sequence token explicitly (older-version compatibility)
        pad_token="<|endoftext|>"   # set the padding token explicitly (silences generation warnings)
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",          # place layers automatically (GPU first)
        token=HF_TOKEN,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16  # match the quantized compute dtype
    )
    # Console streamer: prints generated text to the server log as it is produced
    # (skipping the prompt and special tokens). Note it does NOT feed the HTTP response.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    logger.info(f"Model {MODEL_NAME} loaded. VRAM usage is roughly 4-5 GB (4-bit quantized)")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}", exc_info=True)
    raise SystemExit(f"Service terminated: {str(e)}")
# 5. Request body (user prompt plus generation parameters)
class NodeInferenceRequest(BaseModel):
    prompt: str               # the user's question
    max_tokens: int = 1024    # maximum number of generated tokens (default 1024)
    temperature: float = 0.7  # sampling randomness (0-1; higher = more varied)
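# Example request body (illustrative values):
#   {"prompt": "Briefly explain 4-bit quantization", "max_tokens": 256, "temperature": 0.7}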
# 6. Inference endpoint (core fix: bypass chat_template and build the input by hand)
@app.post("/infer")  # route registration was missing in the original; the path name is a local choice
async def stream_infer(req: NodeInferenceRequest, request: Request):
    try:
        # --------------------------
        # Key fix: build the dialogue markup by hand instead of relying on
        # tokenizer.apply_chat_template (absent on older checkpoints).
        # Note that <|user|>...<|end|><|assistant|> is a hand-rolled format;
        # Qwen's chat checkpoints use ChatML-style <|im_start|>/<|im_end|>
        # markers, so adjust this if you switch models.
        # --------------------------
        user_prompt = req.prompt.strip()
        input_text = f"<|user|>{user_prompt}<|end|><|assistant|>"
        # Encode the input (tensors placed on the model's device)
        inputs = tokenizer(
            input_text,
            return_tensors="pt",  # return PyTorch tensors
            truncation=True,      # truncate overly long inputs (avoids OOM)
            max_length=2048       # maximum input length (tune to the model)
        ).to(model.device)
        # Run generation in a worker thread so the FastAPI event loop is not
        # blocked. Note: generate() runs to completion first; the loop below
        # replays the finished output token by token rather than streaming it
        # live (see the TextIteratorStreamer sketch after this endpoint).
        async def generate_chunks():
            loop = asyncio.get_running_loop()
            outputs = await loop.run_in_executor(
                None,  # default thread pool
                lambda: model.generate(
                    **inputs,
                    streamer=streamer,                    # echoes text to the server log only
                    max_new_tokens=req.max_tokens,        # generation length cap
                    do_sample=True,                       # enable sampling (varied output)
                    temperature=req.temperature,          # randomness control
                    pad_token_id=tokenizer.pad_token_id,  # padding token id
                    eos_token_id=tokenizer.eos_token_id   # stop token id
                )
            )
            # Keep only the newly generated tokens (drop the prompt portion)
            input_token_len = inputs["input_ids"].shape[1]
            generated_tokens = outputs[0][input_token_len:]
            # Decode and emit token by token. Caveat: decoding single tokens
            # can split multi-byte characters (e.g. CJK) across chunks.
            for token in generated_tokens:
                # Stop if the client has disconnected (avoids wasted work)
                if await request.is_disconnected():
                    logger.info("Client disconnected; stopping generation")
                    break
                token_text = tokenizer.decode(
                    token,
                    skip_special_tokens=True,          # drop special tokens (eos, separators)
                    clean_up_tokenization_spaces=True  # tidy stray spaces
                )
                # json.dumps escapes quotes, backslashes and newlines correctly
                # (escaping only '"' by hand yields invalid JSON for other
                # control characters)
                yield json.dumps({"chunk": token_text, "finish": False}, ensure_ascii=False) + "\n"
            # Final marker so the client knows generation is complete
            yield json.dumps({"chunk": "", "finish": True}, ensure_ascii=False) + "\n"
        # NDJSON stream: one JSON object per line, parseable incrementally
        return StreamingResponse(
            generate_chunks(),
            media_type="application/x-ndjson"
        )
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)  # log the full traceback
        raise HTTPException(status_code=500, detail=f"Inference service error: {str(e)}")
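# --------------------------------------------------------------------------
# A minimal sketch of true live streaming, assuming transformers'
# TextIteratorStreamer (which yields text pieces while generate() runs in a
# background thread, and buffers partial tokens so multi-byte characters are
# not split). The /infer_live path and function name are local choices, not
# part of the original service.
# --------------------------------------------------------------------------
from threading import Thread
from transformers import TextIteratorStreamer

@app.post("/infer_live")
async def stream_infer_live(req: NodeInferenceRequest):
    input_text = f"<|user|>{req.prompt.strip()}<|end|><|assistant|>"
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=2048
    ).to(model.device)
    # One streamer per request; it is consumed as an iterator below
    live_streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        **inputs,
        streamer=live_streamer,
        max_new_tokens=req.max_tokens,
        do_sample=True,
        temperature=req.temperature,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so run it in a daemon thread that feeds the streamer
    Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()

    def iter_chunks():
        # StreamingResponse runs sync generators in a thread pool, so blocking
        # on the streamer here does not stall the event loop
        for text in live_streamer:
            yield json.dumps({"chunk": text, "finish": False}, ensure_ascii=False) + "\n"
        yield json.dumps({"chunk": "", "finish": True}, ensure_ascii=False) + "\n"

    return StreamingResponse(iter_chunks(), media_type="application/x-ndjson")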
# 7. Health-check endpoint (for service monitoring)
@app.get("/health")  # route registration was missing in the original; the path name is a local choice
async def node_health():
    # Verify that the model and tokenizer loaded correctly
    is_model_ready = model is not None and tokenizer is not None
    return {
        "status": "healthy" if is_model_ready else "unhealthy",
        "model": MODEL_NAME,
        "support_stream": True,
        "note": "Qwen-7B, 4-bit quantized (fits 16 GB RAM); bypasses chat_template for older-version compatibility",
        "timestamp": str(time.time())  # wall-clock time (the event-loop clock is monotonic, not a timestamp)
    }
# 8. Entry point (only runs when the script is executed directly)
if __name__ == "__main__":
    import uvicorn
    # host=0.0.0.0 allows external access; 7860 is the default Spaces port
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        workers=1  # single process: the model cannot be shared across workers, and multiple workers would each load it
    )