# test-api / main.py
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct")

app = FastAPI(title="Qwen FastAPI")

tokenizer = None
model = None


class GenerateRequest(BaseModel):
    system_prompt: str
    user_prompt: str
    max_new_tokens: int = 400
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True
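
# Example request body for /generate (values are illustrative; only
# system_prompt and user_prompt are required, the other fields fall back
# to the defaults above):
# {
#   "system_prompt": "You are a helpful assistant.",
#   "user_prompt": "Summarize what this API does in one sentence.",
#   "max_new_tokens": 200,
#   "temperature": 0.7
# }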


@app.on_event("startup")
def startup_event():
    global tokenizer, model

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # dtype: bfloat16 on CUDA, float32 on CPU
    has_cuda = torch.cuda.is_available()
    dtype = torch.bfloat16 if has_cuda else torch.float32

    # Load model (auto device placement)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto"
    )
    print("Model ready")  # ✅ as requested


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME}


@app.post("/generate")
def generate(req: GenerateRequest):
    global tokenizer, model

    # Build the chat-formatted prompt using the model's chat template
    messages = [
        {"role": "system", "content": req.system_prompt},
        {"role": "user", "content": req.user_prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    print("\n=== Incoming Request ===")
    print("SYSTEM:", req.system_prompt)
    print("USER:", req.user_prompt)

    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            temperature=req.temperature,
            top_p=req.top_p,
        )

    # Keep only the newly generated tokens (strip the echoed prompt)
    new_tokens = generated_ids[0, model_inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n=== Model Response ===")
    print(response)
    print("======================\n")

    return {"response": response}
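

# --- Local run / quick test (sketch) ---
# A common way to serve this app directly is via uvicorn; the host and port
# below are illustrative, and a deployment may start the server differently.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example call once the server is up (endpoint and fields as defined above):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"system_prompt": "You are helpful.", "user_prompt": "Say hi."}'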