|
|
import fastapi |
|
|
from fastapi import HTTPException |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
def get_model_path(model_name):
    """Resolve a short model alias to its local filesystem path.

    Known aliases map to fixed on-disk checkpoint directories; any
    unrecognized name is returned unchanged (assumed to already be a
    usable path or HF model id).
    """
    known_paths = {
        "llama3.1-8b-instruct": "/home/lijiaqi/ParaAgent/model/meta-llama/Llama-3.1-8B-Instruct",
        "qwen2.5-7b-instruct": "/home/lijiaqi/ParaAgent/model/Qwen/Qwen2.5-7B-Instruct",
    }
    return known_paths.get(model_name, model_name)
|
|
|
|
|
|
|
|
class LocalModelClient:
    """Wraps a locally stored causal LM + tokenizer behind a minimal
    OpenAI-style ``chat`` interface.

    Loading happens eagerly in ``__init__`` — constructing this object
    reads the full model weights from disk.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model_data = self.load_model(model_name)

    def load_model(self, model_name: str):
        """Load model and tokenizer from the resolved local path.

        Raises:
            HTTPException: 404 if the resolved path does not exist on disk.
        """
        model_path = get_model_path(model_name)
        if not os.path.exists(model_path):
            raise HTTPException(status_code=404, detail="Model path not found")

        model = AutoModelForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return {"model": model, "tokenizer": tokenizer}

    def chat(self, messages: list, stream: bool = False):
        """Generate a completion for an OpenAI-style message list.

        Args:
            messages: list of dicts with at least a ``"content"`` key
                (and, for chat-template models, a ``"role"`` key).
            stream: accepted for interface compatibility but ignored —
                the response is always returned in full.

        Returns:
            An OpenAI-like dict: ``choices[0].message.content`` holds only
            the newly generated text, and ``usage`` reports prompt vs.
            completion token counts separately.
        """
        tokenizer = self.model_data["tokenizer"]
        model = self.model_data["model"]

        # Prefer the model's own chat template so role/turn formatting
        # matches what the instruct model was trained on; fall back to a
        # plain join for tokenizers without one.
        if getattr(tokenizer, "chat_template", None):
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            input_text = " ".join([msg["content"] for msg in messages])

        inputs = tokenizer(input_text, return_tensors="pt")
        prompt_len = inputs["input_ids"].shape[1]

        # max_new_tokens (not max_length) so a long prompt cannot consume
        # the whole budget and produce an empty completion; no_grad avoids
        # building an autograd graph during inference.
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
            )

        # Decode only the tokens generated after the prompt — the raw
        # output sequence begins with the prompt tokens, which must not be
        # echoed back in the chat response.
        new_tokens = outputs[0][prompt_len:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return {
            "choices": [{"message": {"content": response_text}}],
            "usage": {
                # Count only generated tokens, not prompt + generated.
                "completion_tokens": len(new_tokens),
                "prompt_tokens": prompt_len,
            },
        }
|
|
|
|
|
|
|
|
app = fastapi.FastAPI() |
|
|
|
|
|
|
|
|
class ModelRequest(BaseModel):
    """Request body for ``POST /predict/``."""

    # Model alias or path; resolved via get_model_path by LocalModelClient.
    model_name: str
    # OpenAI-style message list; each entry is expected to carry a
    # "content" key (see LocalModelClient.chat) — schema not validated here.
    messages: list
|
|
|
|
|
# Cache of loaded clients keyed by model name, so each model's weights are
# read from disk once per process instead of on every request.
_CLIENT_CACHE: dict = {}


@app.post("/predict/")
async def predict(request: ModelRequest):
    """Run a chat completion on the requested local model.

    Returns:
        ``{"content": str, "usage": {"completion_tokens": int,
        "prompt_tokens": int}}``

    Raises:
        HTTPException: 404 if the model path is missing (propagated from
            loading), 500 for any other failure.
    """
    try:
        client = _CLIENT_CACHE.get(request.model_name)
        if client is None:
            client = LocalModelClient(request.model_name)
            _CLIENT_CACHE[request.model_name] = client

        response = client.chat(messages=request.messages, stream=False)

        content = response["choices"][0]["message"]["content"]
        usage_info = {
            "completion_tokens": response["usage"]["completion_tokens"],
            "prompt_tokens": response["usage"]["prompt_tokens"],
        }
        return {"content": content, "usage": usage_info}

    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 404 from load_model)
        # unchanged instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Dev entry point: serve the app on all interfaces, port 8000.
    # NOTE(review): 0.0.0.0 exposes the server on every network interface —
    # confirm this is intended outside a trusted environment.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|