import logging

import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
| |
# Model repository id / path handed to both `from_pretrained` loaders below.
MODEL_NAME = "16pramodh/t2s_model"

# Tokenizer and model are loaded once, at import time, and are shared by
# every request handled by `predict`. The tuple is evaluated left to right,
# so the tokenizer is still loaded before the model.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)
|
|
| |
# The ASGI application; route handlers register on it via `@app.<method>`
# decorators, and the `__main__` guard serves it with uvicorn.
app: FastAPI = FastAPI()
|
|
| |
# JSON request body accepted by `POST /predict`.
# `text` is the raw input string fed to the tokenizer; presumably a
# natural-language question, since the endpoint returns SQL (confirm with
# the model's training data).
class QueryRequest(BaseModel):
    text: str
|
|
@app.post("/predict")
def predict(request: QueryRequest):
    """Run the seq2seq model on ``request.text`` and return the decoded output.

    On success, returns ``{"sql": <decoded text>}`` with HTTP 200.

    If tokenization, generation or decoding raises, the traceback is logged
    and ``{"error": <message>}`` is returned with HTTP 500. The response body
    has the same shape as before, but the status code now signals the failure
    instead of reporting it as a 200.

    Declared with a plain ``def`` (not ``async def``), so FastAPI runs it in
    its worker threadpool and the blocking model call does not stall the
    event loop.
    """
    try:
        inputs = tokenizer(request.text, return_tensors="pt")
        # max_length caps the length of the generated token sequence.
        outputs = model.generate(**inputs, max_length=256)
        sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Request boundary: any tokenizer/model failure becomes a 500
        # response rather than an unhandled error. The full traceback goes
        # to the server log; only the message is sent to the client.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"sql": sql_query}
|
|
| |
if __name__ == "__main__":
    # Listen on every network interface (not just localhost) on port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|