"""FastAPI service exposing a fine-tuned causal language model for SQL chat."""

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "saadkhi/SQL_Chat_finetuned_model"

# Upper bound on the number of tokens a single request may ask for. Without a
# cap, one client can tie up the model for an arbitrarily long generation.
MAX_NEW_TOKENS_LIMIT = 1024

app = FastAPI(title="SQL Chatbot API")

# Load the tokenizer and model once, at import time, so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # float16 only pays off on GPU; when device_map="auto" places the model on
    # CPU, half precision is slow or unsupported for some ops, so use float32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)


class QueryRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Text passed verbatim to the tokenizer; an empty prompt is rejected (422).
    prompt: str = Field(..., min_length=1)
    # Maximum number of tokens to generate after the prompt (1..MAX_NEW_TOKENS_LIMIT).
    max_new_tokens: int = Field(default=256, ge=1, le=MAX_NEW_TOKENS_LIMIT)


class QueryResponse(BaseModel):
    """Response body for ``POST /generate``."""

    # Decoded generated text only; the prompt is not echoed back.
    response: str


@app.post("/generate", response_model=QueryResponse)
def generate_answer(request: QueryRequest):
    """Generate a sampled completion for ``request.prompt``.

    Decoding samples with temperature 0.7 and top-p 0.9, so repeated calls with
    the same prompt may return different text.

    Args:
        request: Validated request containing the prompt and the token budget.

    Returns:
        A dict ``{"response": text}``. ``text`` is the decoded newly generated
        tokens with special tokens removed.
    """
    inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
    # Remember how many tokens belong to the prompt so they can be stripped
    # from the output below.
    prompt_length = inputs["input_ids"].shape[-1]

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Many causal-LM tokenizers define no pad token. generate() would fall
        # back to EOS anyway, but it logs a warning on every call.
        pad_token_id = tokenizer.eos_token_id

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=request.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=pad_token_id,
        )

    # For decoder-only models, generate() returns prompt + continuation.
    # Drop the prompt so the API returns only the model's answer.
    new_token_ids = output_ids[0][prompt_length:]
    output_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    return {"response": output_text}