|
|
import fastapi |
|
|
from fastapi import HTTPException |
|
|
from pydantic import BaseModel |
|
|
from vllm import LLM, SamplingParams |
|
|
import os |
|
|
|
|
|
|
|
|
def get_model_path(model_name):
    """Resolve a short model alias to its local filesystem path.

    Unknown names fall through unchanged, so callers may pass a raw
    model path directly instead of a registered alias.
    """
    alias_to_path = {
        "Llama-3.1-8B-Instruct": "/liuzyai04/thuir/LLM/Meta-Llama-3.1-8B-Instruct",
        "qwen2.5-7b-instruct": "/liuzyai04/thuir/LLM/Qwen2.5-7B-Instruct",
        "qwen2.5-32b-instruct": "/liuzyai04/thuir/LLM/Qwen2.5-32B-Instruct",
        "QwQ-32B": "/liuzyai04/thuir/LLM/QwQ-32B",
        "glm-4-9b-chat": "/liuzyai04/thuir/LLM/glm-4-9b-chat",
    }
    return alias_to_path.get(model_name, model_name)
|
|
|
|
|
|
|
|
class ModelRequest(BaseModel):
    """Request body for the /predict/ endpoint."""

    # Registered alias (see get_model_path) or a direct model path.
    model_name: str

    # Chat history forwarded to the tokenizer's chat template.
    # Presumably [{"role": ..., "content": ...}, ...] — confirm against callers.
    messages: list
|
|
|
|
|
# FastAPI application exposing the /predict/ endpoint below.
app = fastapi.FastAPI()


# Process-wide cache: model_name -> loaded vLLM engine.
# Filled lazily by load_model() so each model is loaded at most once.
loaded_models = {}
|
|
|
|
|
|
|
|
def load_model(model_name: str):
    """Return the vLLM engine for *model_name*, loading and caching it on first use.

    Raises:
        HTTPException: 404 when the resolved model path does not exist on disk.
    """
    print("loading model: ", model_name)
    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached

    print("model not found, loading...")
    model_path = get_model_path(model_name)
    if not os.path.exists(model_path):
        raise HTTPException(status_code=404, detail="Model path not found")
    print("model_path ok: ", model_path)

    # Shard the model across this many GPUs via tensor parallelism.
    tp_size = 4
    engine = LLM(
        model_path,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=tp_size,
    )
    print(f"Model loaded from {model_path} using {tp_size} GPUs")

    loaded_models[model_name] = engine
    return engine
|
|
|
|
|
|
|
|
@app.post("/predict/")
async def predict(request: ModelRequest):
    """Generate one chat completion for *request.messages* with the requested model.

    Returns an OpenAI-style payload:
        {"choices": [{"message": {"content": ...}}],
         "usage": {"completion_tokens": ..., "prompt_tokens": ...}}

    Raises:
        HTTPException: 404 when the model path is unknown (from load_model),
            500 for any other failure.
    """
    try:
        print("IN PREDICT~")
        model = load_model(request.model_name)

        # Fixed seed keeps generations reproducible across identical requests.
        sampling_params = SamplingParams(
            temperature=0.7,
            max_tokens=1024,
            stop=None,
            seed=42,
        )

        # Hoist the tokenizer: the original fetched it three separate times.
        tokenizer = model.get_tokenizer()
        prompt = tokenizer.apply_chat_template(
            request.messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        outputs = model.generate(prompt, sampling_params)
        completion = outputs[0].outputs[0]
        response_text = completion.text.strip()

        input_tokens = len(tokenizer.encode(prompt))
        # Use the engine's own generated token ids for an exact count;
        # re-encoding the *stripped* text (as before) can miscount because
        # re-tokenization is not guaranteed to round-trip.
        output_tokens = len(completion.token_ids)

        return {
            "choices": [{"message": {"content": response_text}}],
            "usage": {
                "completion_tokens": output_tokens,
                "prompt_tokens": input_tokens
            }
        }

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 404 from load_model)
        # instead of rewrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the server with uvicorn's default host/port when run as a script.
    import uvicorn

    uvicorn.run(app=app)
|
|
|