File size: 5,540 Bytes
1ed9a31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# llm_api.py
import json
import os
from typing import Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from config import LLM_HF_MODEL, LLM_AX_MODEL, LLM_API_PORT


app = FastAPI(title="Fast-API", description="本地推理接口")
device = "cuda" if torch.cuda.is_available() else "cpu"

"""
axengine 相关
"""
from ml_dtypes import bfloat16
from utils.infer_func import InferManager

# 定义全局变量,但先不初始化
tokenizer = None
imer = None
embeds = None

def init_model():
    """Load the tokenizer, axengine inference manager, and embedding table.

    Idempotent: a second call is a no-op, so repeated startup events are safe.
    Populates the module-level globals ``tokenizer``, ``imer`` and ``embeds``.
    """
    global tokenizer, imer, embeds
    if tokenizer is not None:
        # Already initialized — guard against duplicate startup calls.
        return
    model_cfg = AutoConfig.from_pretrained(LLM_HF_MODEL)
    imer = InferManager(model_cfg, LLM_AX_MODEL, model_type="qwen2")
    # Exported input-embedding table; used to turn token ids into embeddings
    # for the NPU model (which takes embeddings, not ids).
    embeds = np.load(os.path.join(LLM_AX_MODEL, "model.embed_tokens.weight.npy"))
    tokenizer = AutoTokenizer.from_pretrained(LLM_HF_MODEL, trust_remote_code=True)
    print("✅ 模型加载完成。")

# 添加 FastAPI 的启动事件
@app.on_event("startup")
async def startup_event():
    init_model()

class GenRequest(BaseModel):
    """Request body for POST /generate."""
    prompt: str  # user question; wrapped in a chat template by the handler
    # NOTE(review): max_tokens / temperature / top_p are accepted but not
    # currently forwarded to the axengine decode loop — confirm intent.
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.6
    top_p: Optional[float] = 0.9

class GenResponse(BaseModel):
    """Declared response schema for /generate.

    NOTE(review): the endpoint actually returns a StreamingResponse (SSE),
    so this model only shapes the OpenAPI docs, not the real payload.
    """
    text: str

@app.post("/generate", response_model=GenResponse)
def generate_text(req: GenRequest):
    try:
        # input_ids = tokenizer(req.prompt, return_tensors="pt").input_ids.to(device)

        # with torch.no_grad():
        #     output_ids = model.generate(
        #         input_ids=input_ids,
        #         max_new_tokens=req.max_tokens,
        #         temperature=req.temperature,
        #         top_p=req.top_p,
        #         # do_sample=True,
        #         eos_token_id=tokenizer.eos_token_id
        #     )

        # response_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        messages = [
            {"role": "system", "content": "你的名字叫做 [AXERA-RAG 助手]. 你是一个高效、精准的问答助手. 你可以根据上下文内容, 回答用户提出的问题, 回答时不要提及多余的、无用的内容, 且仅输出你的回答."},
            {"role": "user", "content": req.prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(device)

        """
        axengine 框架模型推理
        """
        input_ids = model_inputs['input_ids']
        inputs_embeds = np.take(embeds, input_ids.cpu().numpy(), axis=0)
        prefill_data = inputs_embeds
        prefill_data = prefill_data.astype(bfloat16)
        token_ids = input_ids[0].cpu().numpy().tolist()
        generated_text = ""

        def generate_stream():
            nonlocal token_ids, generated_text
            token_ids = imer.prefill(tokenizer, token_ids, prefill_data[0], slice_len=128)
            generated_text += tokenizer.decode(token_ids[-1], skip_special_tokens=True)

            # response_text = imer.decode(tokenizer, token_ids, embeds, slice_len=128)
            # 去掉 prompt 的前缀, 只保留生成部分
            # generated_text = response_text[len(req.prompt):].strip()
            # generated_text = response_text
            # return GenResponse(text=generated_text)

            # 流式输出控制
            prefill_word = tokenizer.decode(token_ids[-1], skip_special_tokens=True)
            prefill_word = prefill_word.strip().replace("\n", "\\n").replace("\"", "\\\"")

            seq_len = len(token_ids) - 1
            prefill_len = 128
            for step_idx in range(imer.max_seq_len):
                if prefill_len > 0 and step_idx < seq_len:
                    continue
                token_ids, next_token_id = imer.decode_next_token(tokenizer, token_ids, embeds, slice_len=128, step_idx=step_idx)
                if next_token_id == tokenizer.eos_token_id and next_token_id > seq_len:
                    break
                try:
                    if next_token_id is not None:
                        word = tokenizer.decode([next_token_id], skip_special_tokens=True)
                        generated_text += word
                        if prefill_word is not None:
                            word = prefill_word + word
                            prefill_word = None
                        # 以适合前端处理的 SSE 格式输出
                        # 处理特殊字符
                        word = word.strip().replace("\n", "\\n").replace("\"", "\\\"")
                        # import pdb; pdb.set_trace()
                        yield f"data: {{\"token\": \"{word}\"}}\n\n"
                except Exception as e:
                    print(f"Error decoding token {next_token_id}: {e}")

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",  # 必须使用SSE格式
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # 禁用Nginx缓冲
            }
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=LLM_API_PORT, reload=False)