# simcourt/api_pool/local_model_api_vllm_base.py

import os

import fastapi
from fastapi import HTTPException
from pydantic import BaseModel
from vllm import LLM, SamplingParams


# Resolve a model name to its local checkpoint path.
def get_model_path(model_name):
    if model_name == "Llama-3.1-8B-Instruct":
        return "/liuzyai04/thuir/LLM/Meta-Llama-3.1-8B-Instruct"
    elif model_name == "qwen2.5-7b-instruct":
        return "/liuzyai04/thuir/LLM/Qwen2.5-7B-Instruct"
    elif model_name == "qwen2.5-32b-instruct":
        return "/liuzyai04/thuir/LLM/Qwen2.5-32B-Instruct"
    elif model_name == "QwQ-32B":
        return "/liuzyai04/thuir/LLM/QwQ-32B"
    elif model_name == "glm-4-9b-chat":
        return "/liuzyai04/thuir/LLM/glm-4-9b-chat"
    else:
        return model_name
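

# Example (hypothetical call): get_model_path("qwen2.5-7b-instruct") returns
# "/liuzyai04/thuir/LLM/Qwen2.5-7B-Instruct". Names not in the table are
# returned unchanged, and load_model below then requires them to exist as a
# local path (os.path.exists check), so arbitrary hub ids are rejected.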


# Request schema for the prediction API.
class ModelRequest(BaseModel):
    model_name: str
    messages: list


app = fastapi.FastAPI()

# Cache of already-loaded models, keyed by model name.
loaded_models = {}


# Load a model, reusing the cached instance when available.
def load_model(model_name: str):
    print("loading model: ", model_name)
    if model_name in loaded_models:
        return loaded_models[model_name]
    print("model not found, loading...")
    model_path = get_model_path(model_name)
    if not os.path.exists(model_path):
        raise HTTPException(status_code=404, detail="Model path not found")
    print("model_path ok: ", model_path)
    # Load the model with vLLM.
    tp_size = 4
    model = LLM(
        model_path,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=tp_size,  # tensor parallelism
    )
    print(f"Model loaded from {model_path} using {tp_size} GPUs")
    loaded_models[model_name] = model
    return loaded_models[model_name]
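

# A minimal sketch (an assumption, not part of the original file): tp_size is
# hard-coded to 4 above, but it could be derived from the visible GPUs; torch
# is already a vLLM dependency. Note that vLLM requires the model's
# attention-head count to be divisible by tensor_parallel_size, so the raw
# device count may need rounding down to a valid divisor.
#
#   import torch
#   tp_size = max(1, torch.cuda.device_count())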


# Prediction route: load the requested model and generate a chat completion.
@app.post("/predict/")
async def predict(request: ModelRequest):
    try:
        # Load (or reuse) the model.
        print("IN PREDICT~")
        model = load_model(request.model_name)
        # Prepare sampling parameters.
        sampling_params = SamplingParams(
            temperature=0.7,
            max_tokens=1024,
            stop=None,
            seed=42,
        )
        # print(f'last message: {request.messages[-1]["content"]}')
        # Build the prompt with the model's chat template.
        prompt = model.get_tokenizer().apply_chat_template(
            request.messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Generate the reply with vLLM.
        outputs = model.generate(prompt, sampling_params)
        response_text = outputs[0].outputs[0].text.strip()
        # Count prompt and completion tokens.
        input_tokens = len(model.get_tokenizer().encode(prompt))
        output_tokens = len(model.get_tokenizer().encode(response_text))
        return {
            "choices": [{"message": {"content": response_text}}],
            "usage": {
                "completion_tokens": output_tokens,
                "prompt_tokens": input_tokens,
            },
        }
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g., the 404 from load_model)
        # instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# Start the FastAPI server.
if __name__ == "__main__":
    import uvicorn

    # uvicorn.run(app, host="0.0.0.0", port=8000)
    uvicorn.run(app)
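

# A minimal client sketch (assumptions: the server is running at uvicorn's
# default http://127.0.0.1:8000; model name and message are illustrative):
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:8000/predict/",
#       json={
#           "model_name": "qwen2.5-7b-instruct",
#           "messages": [{"role": "user", "content": "Hello!"}],
#       },
#   )
#   data = resp.json()
#   print(data["choices"][0]["message"]["content"])
#   print(data["usage"])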