| | |
import json
import os
from typing import Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from config import LLM_HF_MODEL, LLM_AX_MODEL, LLM_API_PORT
| |
|
| |
|
# FastAPI application exposing the local inference HTTP API.
app = FastAPI(title="Fast-API", description="本地推理接口")

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
| |
|
| | """ |
| | axengine 相关 |
| | """ |
| | from ml_dtypes import bfloat16 |
| | from utils.infer_func import InferManager |
| |
|
| | |
# Lazily-initialized globals, populated once by init_model() at startup.
tokenizer = None  # HF tokenizer loaded from LLM_HF_MODEL
imer = None       # axengine InferManager running the exported model
embeds = None     # embedding table loaded from model.embed_tokens.weight.npy
| |
|
def init_model():
    """Load the tokenizer, axengine runtime and embedding table exactly once.

    Populates the module-level ``tokenizer``, ``imer`` and ``embeds`` globals;
    calling again after a successful load is a no-op.
    """
    global tokenizer, imer, embeds
    if tokenizer is not None:
        return  # already initialized

    hf_config = AutoConfig.from_pretrained(LLM_HF_MODEL)
    imer = InferManager(hf_config, LLM_AX_MODEL, model_type="qwen2")

    # Host-side embedding lookup table exported alongside the axengine model.
    embedding_path = os.path.join(LLM_AX_MODEL, "model.embed_tokens.weight.npy")
    embeds = np.load(embedding_path)

    tokenizer = AutoTokenizer.from_pretrained(LLM_HF_MODEL, trust_remote_code=True)
    print("✅ 模型加载完成。")
| |
|
| | |
@app.on_event("startup")
async def startup_event():
    # Load tokenizer/runtime/embeddings once when the server starts.
    # NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
    # favor of the lifespan API — confirm against the installed version.
    init_model()
| |
|
class GenRequest(BaseModel):
    """Request body for POST /generate."""
    prompt: str  # user question; wrapped in a chat template server-side
    max_tokens: Optional[int] = 1024     # accepted but currently unused by the handler
    temperature: Optional[float] = 0.6   # accepted but currently unused by the handler
    top_p: Optional[float] = 0.9         # accepted but currently unused by the handler
| |
|
class GenResponse(BaseModel):
    """Declared response model for /generate.

    The handler actually returns an SSE StreamingResponse, so this model only
    shapes the generated OpenAPI schema — it is never instantiated here.
    """
    text: str
| |
|
@app.post("/generate", response_model=GenResponse)
def generate_text(req: GenRequest):
    """Stream tokens generated for ``req.prompt`` as Server-Sent Events.

    Each event has the form ``data: {"token": "<word>"}``.  Despite the
    ``response_model=GenResponse`` declaration, the endpoint returns a
    ``StreamingResponse``, so the declared model only affects the OpenAPI docs.

    Raises:
        HTTPException: 500 with the underlying error message if request
            preparation (templating, tokenization, embedding lookup) fails.
    """
    try:
        # Build a chat-formatted prompt with the fixed system instruction.
        messages = [
            {"role": "system", "content": "你的名字叫做 [AXERA-RAG 助手]. 你是一个高效、精准的问答助手. 你可以根据上下文内容, 回答用户提出的问题, 回答时不要提及多余的、无用的内容, 且仅输出你的回答."},
            {"role": "user", "content": req.prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(device)

        # axengine inference: embeddings are looked up on the host and fed to
        # the runtime as bfloat16.
        input_ids = model_inputs["input_ids"]
        inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
        prefill_data = inputs_embeds.astype(bfloat16)
        token_ids = input_ids[0].cpu().numpy().tolist()
        generated_text = ""  # accumulated full text (not returned; kept for parity/debugging)

        def generate_stream():
            nonlocal token_ids, generated_text
            # Prefill the KV cache with the whole prompt; the returned token
            # list ends with the first generated token.
            token_ids = imer.prefill(tokenizer, token_ids, prefill_data[0], slice_len=128)
            generated_text += tokenizer.decode(token_ids[-1], skip_special_tokens=True)

            # First word produced by prefill; emitted together with the first
            # decoded word below.  (Previously this was JSON-escaped here AND
            # again after concatenation, double-escaping it; serialization now
            # happens exactly once, at yield time.)
            prefill_word = tokenizer.decode(token_ids[-1], skip_special_tokens=True)

            seq_len = len(token_ids) - 1
            # Steps below seq_len were consumed by prefill, so start after them
            # (replaces the old constant-guarded `continue` inside the loop).
            for step_idx in range(seq_len, imer.max_seq_len):
                token_ids, next_token_id = imer.decode_next_token(
                    tokenizer, token_ids, embeds, slice_len=128, step_idx=step_idx
                )
                # NOTE(review): comparing next_token_id (a vocabulary id)
                # against seq_len (a position) looks suspicious — presumably
                # the intent is just to stop on EOS; confirm against
                # InferManager before changing.
                if next_token_id == tokenizer.eos_token_id and next_token_id > seq_len:
                    break
                try:
                    if next_token_id is not None:
                        word = tokenizer.decode([next_token_id], skip_special_tokens=True)
                        generated_text += word
                        if prefill_word is not None:
                            word = prefill_word + word
                            prefill_word = None
                        # json.dumps performs complete JSON escaping (quotes,
                        # backslashes, control characters) — the manual
                        # replace() chain it supersedes missed backslashes.
                        payload = json.dumps({"token": word.strip()}, ensure_ascii=False)
                        yield f"data: {payload}\n\n"
                except Exception as e:
                    # Best effort: skip a token that fails to decode/serialize
                    # rather than killing the stream mid-response.
                    print(f"Error decoding token {next_token_id}: {e}")

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # disable reverse-proxy buffering for SSE
            }
        )

    except Exception as e:
        # Surface setup failures as a 500; errors inside the generator happen
        # after the response has started and cannot be converted to HTTP errors.
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
if __name__ == "__main__":
    # Run the API directly (no auto-reload) on the configured port.
    uvicorn.run(app, host="0.0.0.0", port=LLM_API_PORT, reload=False)
| |
|