# Author: Zeldeo — commit 3a68ffb "Update app.py" (verified)
# app.py
import logging

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Flan-T5 checkpoint served by this app.
MODEL_NAME = "google/flan-t5-base"

# Tokenizer and seq2seq weights are loaded once, at import time, so every
# request handled below reuses the same objects instead of reloading them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode (dropout disabled) can be chained in one expression.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).eval()
app = FastAPI(title="Flan-T5 Service")

# Cross-origin policy is fully open: browsers on any origin may call this API
# with any HTTP method and any request header.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
class ImproveRequest(BaseModel):
    """JSON request body accepted by ``POST /improve``."""

    # Raw text that is tokenized and passed to the model.
    text: str
@app.get("/")
def health():
    """Liveness endpoint: report that the service is up and which checkpoint it serves."""
    status_payload = {"status": "ok", "model": MODEL_NAME}
    return status_payload
@app.post("/improve")
def improve_text(req: ImproveRequest):
    """Run the request text through the seq2seq model and return the decoded output.

    Returns a JSON body ``{"success": True, "improved_text": ...}`` on success,
    or ``{"success": False, "error": ...}`` with HTTP 500 if tokenization or
    generation raises.

    Blank (empty or whitespace-only) input is returned unchanged without
    invoking the model. Input is tokenized with truncation enabled, so it is
    cut to the tokenizer's maximum input length; this keeps a single very
    large request from consuming unbounded memory and CPU time.
    """
    # Nothing to improve: skip the expensive tokenize/generate round-trip.
    if not req.text.strip():
        return JSONResponse({"success": True, "improved_text": req.text})

    try:
        # With truncation=True and no explicit max_length, transformers caps
        # the input at tokenizer.model_max_length.
        inputs = tokenizer(req.text, return_tensors="pt", truncation=True)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=512)
        improved = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Top-level request boundary: keep the full traceback in the server
        # logs instead of only echoing the message back to the client.
        # NOTE(review): str(e) still exposes internal error text to callers.
        logging.getLogger(__name__).exception("Text improvement failed")
        return JSONResponse({"success": False, "error": str(e)}, status_code=500)

    return JSONResponse({"success": True, "improved_text": improved})