SQL_chatbot_API / app.py
saadkhi's picture
Update app.py
c6fae16 verified
raw
history blame
1.11 kB
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
# Identifier passed to ``from_pretrained`` for both the tokenizer and the model.
MODEL_ID = "saadkhi/SQL_Chat_finetuned_model"

# Tokenizer and model are loaded at import time, so request handlers reuse
# the same objects instead of reloading them per call.
# NOTE(review): float16 weights are requested unconditionally; confirm this
# is acceptable on hosts where ``device_map="auto"`` resolves to CPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)

app = FastAPI(title="SQL Chatbot API")
class QueryRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    ``max_new_tokens`` is validated to be at least 1 so that a zero or
    negative budget is rejected during request validation instead of being
    forwarded to ``model.generate``.
    """

    # Text that is tokenized and handed to the model.
    prompt: str
    # Number of new tokens the model may generate; defaults to 256.
    # NOTE(review): no upper bound is enforced, so a client can request an
    # arbitrarily long generation -- consider adding ``le=<limit>`` once a
    # limit suitable for the deployment is agreed.
    max_new_tokens: int = Field(256, ge=1)
class QueryResponse(BaseModel):
    """Response body returned by the ``/generate`` endpoint."""

    # Decoded text produced by the generation call.
    response: str
@app.post("/generate", response_model=QueryResponse)
def generate_answer(request: QueryRequest):
    """Generate text for ``request.prompt`` with the module-level model.

    The prompt is tokenized, the resulting tensors are moved to the model's
    device, and the model samples up to ``request.max_new_tokens`` new tokens
    (temperature 0.7, top-p 0.9). The first returned sequence is decoded with
    special tokens removed.

    Args:
        request: Validated request body carrying the prompt and token budget.

    Returns:
        A dict with a single ``"response"`` key holding the decoded text.
    """
    # Tokenize first, then place the tensors on whatever device the model uses.
    encoded = tokenizer(request.prompt, return_tensors="pt")
    encoded = encoded.to(model.device)

    sampling_options = {
        "max_new_tokens": request.max_new_tokens,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
    }

    # Gradient tracking is disabled for the generation call.
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling_options)

    # Only the first sequence is decoded.
    # NOTE(review): for decoder-only models the generated sequence presumably
    # begins with the prompt tokens, so the decoded text would echo the
    # prompt -- confirm that callers expect this.
    first_sequence = sequences[0]
    decoded = tokenizer.decode(first_sequence, skip_special_tokens=True)
    return {"response": decoded}