# RAG.axera / llm_api.py
# Author: yongqiang — commit 1ed9a31 ("Initialize the repository")
# llm_api.py
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
import uvicorn
import numpy as np
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from config import LLM_HF_MODEL, LLM_AX_MODEL, LLM_API_PORT
# FastAPI application exposing the local LLM inference endpoint.
app = FastAPI(title="Fast-API", description="本地推理接口")
# Torch device used for tokenized input tensors (`model_inputs` below);
# the axengine runtime itself consumes numpy arrays, not torch tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
"""
axengine 相关
"""
from ml_dtypes import bfloat16
from utils.infer_func import InferManager
# Module-level singletons — declared here but populated lazily by
# init_model() on application startup.
tokenizer = None  # transformers tokenizer loaded from LLM_HF_MODEL
imer = None  # axengine InferManager wrapping the compiled model
embeds = None  # token-embedding matrix loaded from .npy, indexed by token id
def init_model():
    """Lazily initialize the tokenizer, axengine inference manager and
    embedding table into the module-level globals.

    Safe to call more than once: a populated ``tokenizer`` short-circuits
    re-initialization.
    """
    global tokenizer, imer, embeds
    if tokenizer is not None:
        # Already initialized — nothing to do.
        return
    cfg = AutoConfig.from_pretrained(LLM_HF_MODEL)
    imer = InferManager(cfg, LLM_AX_MODEL, model_type="qwen2")
    # Pre-exported embedding weights; rows are looked up by token id at
    # request time (see the /generate handler).
    embeds = np.load(os.path.join(LLM_AX_MODEL, "model.embed_tokens.weight.npy"))
    # Tokenizer comes from the HF model directory.
    tokenizer = AutoTokenizer.from_pretrained(LLM_HF_MODEL, trust_remote_code=True)
    print("✅ 模型加载完成。")
# FastAPI startup hook: load the model once before serving requests.
# NOTE(review): `@app.on_event` is deprecated in recent FastAPI versions in
# favor of lifespan handlers; kept as-is to avoid behavior changes.
@app.on_event("startup")
async def startup_event():
    """Initialize tokenizer/model state when the application starts."""
    init_model()
class GenRequest(BaseModel):
    """Request body for POST /generate."""
    # The user prompt (typically already retrieval-augmented upstream).
    prompt: str
    # Sampling controls. NOTE(review): only `prompt` is consumed by the live
    # axengine path in generate_text(); max_tokens/temperature/top_p appear
    # only in the commented-out HF-generate code there and are currently unused.
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.6
    top_p: Optional[float] = 0.9
class GenResponse(BaseModel):
    """Declared response model for /generate. The handler actually returns a
    StreamingResponse (SSE), which FastAPI passes through unvalidated, so this
    model mainly shapes the generated OpenAPI docs."""
    # Full generated text (unused by the streaming path).
    text: str
@app.post("/generate", response_model=GenResponse)
def generate_text(req: GenRequest):
try:
# input_ids = tokenizer(req.prompt, return_tensors="pt").input_ids.to(device)
# with torch.no_grad():
# output_ids = model.generate(
# input_ids=input_ids,
# max_new_tokens=req.max_tokens,
# temperature=req.temperature,
# top_p=req.top_p,
# # do_sample=True,
# eos_token_id=tokenizer.eos_token_id
# )
# response_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
messages = [
{"role": "system", "content": "你的名字叫做 [AXERA-RAG 助手]. 你是一个高效、精准的问答助手. 你可以根据上下文内容, 回答用户提出的问题, 回答时不要提及多余的、无用的内容, 且仅输出你的回答."},
{"role": "user", "content": req.prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
"""
axengine 框架模型推理
"""
input_ids = model_inputs['input_ids']
inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
prefill_data = inputs_embeds
prefill_data = prefill_data.astype(bfloat16)
token_ids = input_ids[0].cpu().numpy().tolist()
generated_text = ""
def generate_stream():
nonlocal token_ids, generated_text
token_ids = imer.prefill(tokenizer, token_ids, prefill_data[0], slice_len=128)
generated_text += tokenizer.decode(token_ids[-1], skip_special_tokens=True)
# response_text = imer.decode(tokenizer, token_ids, embeds, slice_len=128)
# 去掉 prompt 的前缀, 只保留生成部分
# generated_text = response_text[len(req.prompt):].strip()
# generated_text = response_text
# return GenResponse(text=generated_text)
# 流式输出控制
prefill_word = tokenizer.decode(token_ids[-1], skip_special_tokens=True)
prefill_word = prefill_word.strip().replace("\n", "\\n").replace("\"", "\\\"")
seq_len = len(token_ids) - 1
prefill_len = 128
for step_idx in range(imer.max_seq_len):
if prefill_len > 0 and step_idx < seq_len:
continue
token_ids, next_token_id = imer.decode_next_token(tokenizer, token_ids, embeds, slice_len=128, step_idx=step_idx)
if next_token_id == tokenizer.eos_token_id and next_token_id > seq_len:
break
try:
if next_token_id is not None:
word = tokenizer.decode([next_token_id], skip_special_tokens=True)
generated_text += word
if prefill_word is not None:
word = prefill_word + word
prefill_word = None
# 以适合前端处理的 SSE 格式输出
# 处理特殊字符
word = word.strip().replace("\n", "\\n").replace("\"", "\\\"")
# import pdb; pdb.set_trace()
yield f"data: {{\"token\": \"{word}\"}}\n\n"
except Exception as e:
print(f"Error decoding token {next_token_id}: {e}")
return StreamingResponse(
generate_stream(),
media_type="text/event-stream", # 必须使用SSE格式
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用Nginx缓冲
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=LLM_API_PORT, reload=False)