# simcourt/api_pool/local_model_api.py
import fastapi
from fastapi import HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
# Resolve a short model name to a local checkpoint path; unknown names are
# returned unchanged (e.g. a Hub ID or an absolute path).
def get_model_path(model_name):
    if model_name == "llama3.1-8b-instruct":
        return "/home/lijiaqi/ParaAgent/model/meta-llama/Llama-3.1-8B-Instruct"
    elif model_name == "qwen2.5-7b-instruct":
        return "/home/lijiaqi/ParaAgent/model/Qwen/Qwen2.5-7B-Instruct"
    else:
        return model_name
# Request schema for the model API: a model name plus an OpenAI-style
# message list.
class ModelRequest(BaseModel):
    model_name: str
    messages: list
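
# A minimal sketch of a request body this schema accepts (values are
# illustrative, not taken from the original source):
#
#   {
#       "model_name": "qwen2.5-7b-instruct",
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello!"}
#       ]
#   }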

app = fastapi.FastAPI()

# Global cache so the same model is never loaded twice.
loaded_models = {}

# Load a model and tokenizer, reusing the cached copy when available.
def load_model(model_name: str):
    print("model_name to load: ", model_name)
    if model_name in loaded_models:
        return loaded_models[model_name]
    model_path = get_model_path(model_name)
    if not os.path.exists(model_path):
        raise HTTPException(status_code=404, detail="Model path not found")
    print("model_path ok: ", model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    print(f"Model and tokenizer loaded from {model_path}")
    loaded_models[model_name] = {"model": model, "tokenizer": tokenizer}
    print("model loaded: ", model_name)
    return loaded_models[model_name]
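
# A possible warm-up hook (an assumption, not in the original file): preload a
# default model at startup so the first request does not pay the load cost.
#
#   @app.on_event("startup")
#   def warm_up():
#       load_model("qwen2.5-7b-instruct")  # hypothetical default choice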

# API route: run one chat completion against the requested model.
@app.post("/predict/")
async def predict(request: ModelRequest):
    try:
        # Load (or reuse) the requested model and tokenizer.
        model_data = load_model(request.model_name)
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]
        # Manual prompt construction, kept for reference: builds a Qwen/ChatML
        # prompt by hand instead of using the tokenizer's chat template.
        # print("model loaded : ", model)
        # prompt = ""
        # for msg in request.messages:
        #     role = msg['role']
        #     content = msg['content']
        #     if role == "system":
        #         prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
        #     elif role == "user":
        #         prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
        #     elif role == "assistant":
        #         prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
        # prompt += "<|im_start|>assistant\n"  # open an assistant turn so generation starts there
        # inputs = tokenizer(prompt, return_tensors="pt")

        # Build the prompt via the tokenizer's chat template instead.
        # (pad_token_id is a generate() argument, so it is passed there below.)
        prompt = tokenizer.apply_chat_template(
            request.messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        print(f"inputs: {inputs}")
        # Generate; max_new_tokens bounds the completion length, and passing
        # the full encoding supplies the attention mask as well.
        outputs = model.generate(
            **inputs,
            max_new_tokens=8000,
            pad_token_id=tokenizer.pad_token_id,
        )
        print(f"outputs: {outputs}")
        # Decode the full sequence (prompt + completion).
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"result: {result}")
        # Extract the final assistant reply: take the text after the *last*
        # "assistant" marker, so assistant turns in the prompt are skipped.
        if "assistant\n" in result:
            response_text = result.rsplit("assistant\n", 1)[-1].strip()
        else:
            response_text = result.strip()  # fallback in case parsing fails
        print(f"response_text: {response_text}")
        # Build an OpenAI-style response payload.
        return {
            "choices": [{"message": {"content": response_text}}],
            "usage": {
                # generate() echoes the prompt, so subtract it from the completion count.
                "completion_tokens": len(outputs[0]) - len(inputs["input_ids"][0]),
                "prompt_tokens": len(inputs["input_ids"][0]),
            },
        }
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Start the FastAPI server.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
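
# A minimal client sketch (illustrative; not part of the original file).
# It assumes the server above is running locally on port 8000:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/predict/",
#       json={
#           "model_name": "qwen2.5-7b-instruct",
#           "messages": [{"role": "user", "content": "Hello!"}],
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])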