# llm_api.py
"""FastAPI service exposing a local LLM inference endpoint.

Token generation runs on the axengine runtime (``InferManager``); the HF
checkpoint is used only for its tokenizer/config and the embedding table,
which is loaded separately as a NumPy array.
"""
import json
import os

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from config import LLM_HF_MODEL, LLM_AX_MODEL, LLM_API_PORT

# axengine-specific imports
from ml_dtypes import bfloat16
from utils.infer_func import InferManager

app = FastAPI(title="Fast-API", description="本地推理接口")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Globals populated lazily by init_model() at application startup.
tokenizer = None  # HF tokenizer for the Qwen2 checkpoint
imer = None       # axengine inference manager
embeds = None     # np.ndarray embedding table, indexed by token id


def init_model():
    """Load tokenizer, axengine runtime and embedding table (idempotent)."""
    global tokenizer, imer, embeds
    if tokenizer is None:  # guard against double initialization
        cfg = AutoConfig.from_pretrained(LLM_HF_MODEL)
        imer = InferManager(cfg, LLM_AX_MODEL, model_type="qwen2")
        # The embedding table lives outside the compiled axengine model,
        # so embeddings are looked up manually before prefill/decode.
        embeds = np.load(os.path.join(LLM_AX_MODEL, "model.embed_tokens.weight.npy"))
        tokenizer = AutoTokenizer.from_pretrained(LLM_HF_MODEL, trust_remote_code=True)
        print("✅ 模型加载完成。")


@app.on_event("startup")
async def startup_event():
    """Run model initialization once when the server starts."""
    init_model()


class GenRequest(BaseModel):
    """Request body for /generate.

    NOTE(review): max_tokens / temperature / top_p are accepted for API
    compatibility but are not currently forwarded to the axengine runtime.
    """
    prompt: str
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.6
    top_p: Optional[float] = 0.9


class GenResponse(BaseModel):
    """Declared response model (the endpoint actually streams SSE)."""
    text: str


@app.post("/generate", response_model=GenResponse)
def generate_text(req: GenRequest):
    """Run axengine inference on the prompt and stream tokens back as SSE.

    Raises:
        HTTPException(500): on any failure while preparing the request.
    """
    try:
        messages = [
            {"role": "system",
             "content": "你的名字叫做 [AXERA-RAG 助手]. 你是一个高效、精准的问答助手. "
                        "你可以根据上下文内容, 回答用户提出的问题, 回答时不要提及多余的、无用的内容, 且仅输出你的回答."},
            {"role": "user", "content": req.prompt},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(device)

        # Manual embedding lookup: the axengine model consumes embeddings,
        # not token ids. The runtime expects bfloat16 inputs.
        input_ids = model_inputs["input_ids"]
        inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
        prefill_data = inputs_embeds.astype(bfloat16)
        token_ids = input_ids[0].cpu().numpy().tolist()
        generated_text = ""

        def generate_stream():
            """Yield SSE `data:` events, one generated token per event."""
            nonlocal token_ids, generated_text
            token_ids = imer.prefill(tokenizer, token_ids, prefill_data[0], slice_len=128)
            generated_text += tokenizer.decode(token_ids[-1], skip_special_tokens=True)

            # The token produced by prefill is held back and emitted together
            # with the first decoded token below.
            prefill_word = tokenizer.decode(token_ids[-1], skip_special_tokens=True).strip()
            seq_len = len(token_ids) - 1

            for step_idx in range(imer.max_seq_len):
                # Steps below the prompt length were already handled by
                # prefill (the original guarded this with a constant-true
                # `prefill_len > 0 and` conjunct, removed here).
                if step_idx < seq_len:
                    continue
                token_ids, next_token_id = imer.decode_next_token(
                    tokenizer, token_ids, embeds, slice_len=128, step_idx=step_idx
                )
                # BUGFIX: the original compared the *token id* against the
                # prompt length (`next_token_id > seq_len`), which is
                # meaningless. Stop on EOS once at least one token has been
                # generated past the prompt. TODO(review): confirm intended
                # stop semantics against the InferManager API.
                if next_token_id == tokenizer.eos_token_id and len(token_ids) > seq_len:
                    break
                if next_token_id is None:
                    continue
                try:
                    word = tokenizer.decode([next_token_id], skip_special_tokens=True)
                    generated_text += word
                    if prefill_word is not None:
                        word = prefill_word + word
                        prefill_word = None
                    # BUGFIX: build the SSE payload with json.dumps instead of
                    # hand-rolled escaping, which missed backslashes and other
                    # control characters and could emit invalid JSON.
                    payload = json.dumps({"token": word.strip()}, ensure_ascii=False)
                    yield f"data: {payload}\n\n"
                except Exception as e:
                    # Best-effort: log and keep streaming the next token.
                    print(f"Error decoding token {next_token_id}: {e}")

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",  # SSE format required by the frontend
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # disable Nginx response buffering
            },
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=LLM_API_PORT, reload=False)