# sqlspace / app.py
# Last commit 9be26a7 by 16pramodh: "adding files" (887 bytes)
import logging

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import uvicorn
# Web application object; the /predict route is registered on it below.
app = FastAPI()

# Checkpoint identifier handed to `from_pretrained` (a Hub repo id or local path).
MODEL_NAME = "16pramodh/t2s_model"

# Load the tokenizer and the seq2seq model a single time, when the module is
# imported, so every request reuses the same instances instead of reloading.
# Tuple elements are evaluated left to right, so the tokenizer loads first.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)
class QueryRequest(BaseModel):
    """JSON request body accepted by ``POST /predict``.

    Attributes:
        text: Input text that is tokenized and fed to the model; the endpoint
            returns the SQL generated from it.
    """

    text: str
@app.post("/predict")
def predict(request: QueryRequest):
    """Generate a SQL query from the request text.

    The text is tokenized into PyTorch tensors, passed to the model's
    ``generate`` (output capped by ``max_length=256``), and the first
    generated sequence is decoded with special tokens removed.

    Returns:
        ``{"sql": <generated text>}`` on success.  If tokenization,
        generation, or decoding raises, the traceback is written to the
        server log and the response body is ``{"error": <message>}`` with
        HTTP status 500.  The body shape is unchanged from earlier versions,
        but failures are no longer reported as 200 OK.
    """
    try:
        inputs = tokenizer(request.text, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=256)
        sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Request boundary: don't let one bad request crash the worker, but
        # keep the traceback server-side (it was previously discarded) and
        # signal failure via the status code, not just the body key.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"sql": sql_query}
if __name__ == "__main__":
    # Direct-run entry point (e.g. `python app.py`) for local testing:
    # serve the app with uvicorn on all interfaces, port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)