|
|
import fastapi |
|
|
from fastapi import HTTPException |
|
|
from pydantic import BaseModel |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import os |
|
|
|
|
|
|
|
|
def get_model_path(model_name):
    """Resolve a short model alias to its local checkpoint directory.

    Unknown names are returned unchanged, so callers may pass a raw
    filesystem path or repo id directly.
    """
    alias_to_path = {
        "llama3.1-8b-instruct": "/home/lijiaqi/ParaAgent/model/meta-llama/Llama-3.1-8B-Instruct",
        "qwen2.5-7b-instruct": "/home/lijiaqi/ParaAgent/model/Qwen/Qwen2.5-7B-Instruct",
    }
    return alias_to_path.get(model_name, model_name)
|
|
|
|
|
|
|
|
class ModelRequest(BaseModel):
    """Request payload for the /predict/ endpoint."""

    # Short alias resolved by get_model_path(), or a direct model path.
    model_name: str

    # Chat history; passed straight to tokenizer.apply_chat_template, so
    # presumably role/content message dicts — verify against callers.
    messages: list
|
|
|
|
|
# FastAPI application instance serving the /predict/ endpoint below.
app = fastapi.FastAPI()


# In-process cache: model_name -> {"model": ..., "tokenizer": ...}.
# Populated lazily by load_model() so each model is loaded at most once.
loaded_models = {}
|
|
|
|
|
|
|
|
def load_model(model_name: str):
    """Return the cached {"model", "tokenizer"} pair for *model_name*,
    loading both from disk on first use.

    Raises:
        HTTPException: 404 when the resolved model path does not exist.
    """
    print("model_name to load: ", model_name)

    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached

    model_path = get_model_path(model_name)
    if not os.path.exists(model_path):
        raise HTTPException(status_code=404, detail="Model path not found")
    print("model_path ok: ", model_path)

    # Load weights first, then the tokenizer, and cache them as a pair.
    entry = {
        "model": AutoModelForCausalLM.from_pretrained(model_path),
        "tokenizer": AutoTokenizer.from_pretrained(model_path),
    }
    print(f"Model and tokenizer loaded from {model_path}")

    loaded_models[model_name] = entry
    print("model loaded: ", model_name)
    return entry
|
|
|
|
|
|
|
|
@app.post("/predict/")
async def predict(request: ModelRequest):
    """Generate a chat completion with the requested model.

    Returns an OpenAI-style payload:
    ``{"choices": [{"message": {"content": ...}}], "usage": {...}}``.

    Raises:
        HTTPException: 404 if the model path cannot be resolved
            (propagated from load_model), 500 for any other failure.
    """
    try:
        model_data = load_model(request.model_name)
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]

        # Render the chat history with the model's template, then tokenize.
        # (pad_token_id is a generate() argument, not a template argument —
        # the original passed it to apply_chat_template, where it is ignored.)
        prompt = tokenizer.apply_chat_template(
            request.messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        print(f"inputs: {inputs}")

        # Some tokenizers (e.g. Llama) define no pad token; fall back to EOS
        # so generate() can pad without warnings or incorrect attention.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=8000,  # NOTE: max_length includes the prompt tokens
            pad_token_id=pad_id,
        )
        print(f"outputs: {outputs}")

        # Decode only the newly generated tokens so the prompt (and any role
        # markers) never leak into the response — more robust than the old
        # split-on-"assistant\n" heuristic.
        prompt_len = inputs["input_ids"].shape[1]
        response_text = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()
        print(f"response_text: {response_text}")

        return {
            "choices": [{"message": {"content": response_text}}],
            "usage": {
                # Report only generated tokens; the original counted the
                # entire sequence (prompt + completion) as completion_tokens.
                "completion_tokens": len(outputs[0]) - prompt_len,
                "prompt_tokens": prompt_len,
            },
        }
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 404 from load_model)
        # instead of masking them as generic 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app directly with uvicorn.
    import uvicorn

    bind_host = "0.0.0.0"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
|
|
|