# simcourt/api_pool/local_api.py
# (Uploaded via huggingface_hub, commit 94bdfd0)
import fastapi
from fastapi import HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
def get_model_path(model_name):
    """Resolve a short model alias to its local checkpoint directory.

    Unknown names are passed through unchanged so callers may supply an
    explicit filesystem path directly.
    """
    alias_to_path = {
        "llama3.1-8b-instruct": "/home/lijiaqi/ParaAgent/model/meta-llama/Llama-3.1-8B-Instruct",
        "qwen2.5-7b-instruct": "/home/lijiaqi/ParaAgent/model/Qwen/Qwen2.5-7B-Instruct",
    }
    return alias_to_path.get(model_name, model_name)
# Client class that wraps a locally stored causal LM behind an
# OpenAI-style chat-completion interface.
class LocalModelClient:
    """Loads a local HuggingFace causal LM and serves chat completions."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model_data = self.load_model(model_name)

    def load_model(self, model_name: str):
        """Load the model and tokenizer from disk.

        Raises:
            HTTPException: 404 when the resolved checkpoint path does not exist.
        """
        model_path = get_model_path(model_name)
        if not os.path.exists(model_path):
            raise HTTPException(status_code=404, detail="Model path not found")
        model = AutoModelForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return {"model": model, "tokenizer": tokenizer}

    def chat(self, messages: list, stream: bool = False):
        """Generate a completion for a chat history.

        Args:
            messages: list of dicts, each with "role" and "content" keys;
                only the "content" fields are used (roles are ignored).
            stream: accepted for interface compatibility but not implemented.

        Returns:
            OpenAI-style dict with "choices" and "usage" keys.
        """
        tokenizer = self.model_data["tokenizer"]
        model = self.model_data["model"]
        input_text = " ".join(msg["content"] for msg in messages)
        inputs = tokenizer(input_text, return_tensors="pt")
        prompt_len = inputs["input_ids"].shape[1]
        # BUG FIX: max_length counts the prompt, so long prompts produced
        # empty/truncated completions; max_new_tokens bounds only the reply.
        outputs = model.generate(inputs["input_ids"], max_new_tokens=100)
        # BUG FIX: generate() returns prompt + completion; decode only the
        # newly generated tokens so the response doesn't echo the prompt.
        completion_ids = outputs[0][prompt_len:]
        response_text = tokenizer.decode(completion_ids, skip_special_tokens=True)
        return {
            "choices": [{"message": {"content": response_text}}],
            "usage": {
                # BUG FIX: previously counted prompt + completion tokens here.
                "completion_tokens": len(completion_ids),
                "prompt_tokens": prompt_len,
            },
        }
# FastAPI application instance (served by uvicorn in the __main__ guard below).
app = fastapi.FastAPI()
# Request schema for the /predict/ endpoint.
class ModelRequest(BaseModel):
    # Model alias or filesystem path, resolved via get_model_path().
    model_name: str
    # Chat history; each item is expected to be a {"role", "content"} dict.
    # The incoming JSON field name must be exactly "messages".
    messages: list
@app.post("/predict/")
async def predict(request: ModelRequest):
    """Run one chat completion against the requested local model.

    Returns a dict with "content" (the generated text) and "usage"
    (prompt/completion token counts). NOTE: a fresh LocalModelClient —
    and therefore a full model load — happens on every request; consider
    caching clients per model_name if this endpoint sees real traffic.
    """
    try:
        # Build a client for the requested model (loads weights from disk).
        client = LocalModelClient(request.model_name)
        # Run generation over the supplied chat history.
        response = client.chat(messages=request.messages, stream=False)
        content = response["choices"][0]["message"]["content"]
        usage_info = {
            "completion_tokens": response["usage"]["completion_tokens"],
            "prompt_tokens": response["usage"]["prompt_tokens"],
        }
        return {"content": content, "usage": usage_info}
    except HTTPException:
        # BUG FIX: deliberate HTTP errors (e.g. the 404 raised when the
        # model path is missing) were previously swallowed by the generic
        # handler below and re-wrapped as 500s. Let them propagate as-is.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Start the FastAPI server when this file is executed directly.
if __name__ == "__main__":
    import uvicorn
    # Listen on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)