import os

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct")

app = FastAPI(title="Qwen FastAPI")
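
# Loaded once at startup; both stay None until startup_event() runs.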
tokenizer = None
model = None


class GenerateRequest(BaseModel):
    system_prompt: str
    user_prompt: str
    max_new_tokens: int = 400
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True
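
# Illustrative request body (the values here are made up for the example):
#   {"system_prompt": "You are concise.",
#    "user_prompt": "Why is the sky blue?",
#    "max_new_tokens": 128}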


# Note: @app.on_event is deprecated in recent FastAPI releases in favor of
# lifespan handlers; it still works here but may emit a DeprecationWarning.
@app.on_event("startup")
def startup_event():
    global tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
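
    # Pick a dtype: bfloat16 roughly halves memory on GPU; float32 is the
    # safe default on CPU. (This assumes a GPU recent enough for bfloat16.)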
    has_cuda = torch.cuda.is_available()
    dtype = torch.bfloat16 if has_cuda else torch.float32
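
    # device_map="auto" lets transformers place the weights on the available
    # device(s); note that this option requires the accelerate package.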
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto",
    )

    print("Model ready")


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME}


@app.post("/generate")
def generate(req: GenerateRequest):
    messages = [
        {"role": "system", "content": req.system_prompt},
        {"role": "user", "content": req.user_prompt},
    ]
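
    # Render the chat turns into a single prompt string using the model's
    # built-in chat template.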
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    print("\n=== Incoming Request ===")
    print("SYSTEM:", req.system_prompt)
    print("USER:", req.user_prompt)

    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            temperature=req.temperature,
            top_p=req.top_p,
        )
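
    # generate() returns the prompt plus the completion; slice off the prompt
    # so only the newly generated tokens are decoded.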
    new_tokens = generated_ids[0, model_inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n=== Model Response ===")
    print(response)
    print("======================\n")

    return {"response": response}
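

# A minimal sketch for running the server directly; it assumes the uvicorn
# package is installed. (The file name is an assumption: if this module is
# saved as app.py, `uvicorn app:app` would work equally well.)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)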