Spaces:
Runtime error
Runtime error
from pathlib import Path

import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from inference import inferenceAPI
| # from transformers import RobertaTokenizerFast, EncoderDecoderModel | |
| # ------- MODELE HUGGING FACE QUI MARCHE BIEN ------------------------------------ | |
| # device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| # ckpt = 'mrm8488/camembert2camembert_shared-finetuned-french-summarization' | |
| # tokenizer = RobertaTokenizerFast.from_pretrained(ckpt) | |
| # model = EncoderDecoderModel.from_pretrained(ckpt).to(device) | |
| # def generate_summary(text): | |
| # inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt") | |
| # input_ids = inputs.input_ids.to(device) | |
| # attention_mask = inputs.attention_mask.to(device) | |
| # output = model.generate(input_ids, attention_mask=attention_mask) | |
| # return tokenizer.decode(output[0], skip_special_tokens=True) | |
| # ---------------------------------------------------------------------------------- | |
| # ------ NOTRE MODELE -------------------------------------------------------------- | |
def summarize(text: str):
    """Summarize ``text`` with the project's own model.

    Delegates to ``inferenceAPI`` (the inference entry point adapted to take
    plain text) and joins the pieces it returns with single spaces.

    Args:
        text: Raw input text to summarize.

    Returns:
        The summary as one space-separated string.
    """
    # NOTE(review): assumes inferenceAPI returns an iterable of strings
    # (e.g. one per generated sentence) — confirm against inference.py.
    summary_parts = inferenceAPI(text)
    return " ".join(summary_parts)
| # ---------------------------------------------------------------------------------- | |
| # -------- API --------------------------------------------------------------------- | |
app = FastAPI()

# Resolve asset directories relative to this file rather than the current
# working directory, so the app starts regardless of where it is launched
# from (StaticFiles raises RuntimeError when its directory does not exist).
_BASE_DIR = Path(__file__).resolve().parent

templates = Jinja2Templates(directory=str(_BASE_DIR / "templates"))
# Serve CSS and other static assets under /static.
app.mount("/static", StaticFiles(directory=str(_BASE_DIR / "static")), name="static")
# NOTE(review): this publicly serves the raw Jinja template sources under
# /templates. Kept in case pages link assets from there — remove if unused.
app.mount(
    "/templates",
    StaticFiles(directory=str(_BASE_DIR / "templates")),
    name="templates",
)
# Without a route decorator FastAPI never exposes this handler; serve it on GET /.
@app.get("/")
async def index(request: Request):
    """Render the home page with the empty input form.

    Args:
        request: The incoming request; Jinja2Templates expects it under the
            ``"request"`` key of the template context.

    Returns:
        The rendered ``index.html.jinja`` template.
    """
    return templates.TemplateResponse("index.html.jinja", {"request": request})
# Without a route decorator FastAPI never exposes this handler.
# NOTE(review): the path must match the form's action/method in
# templates/index.html.jinja — confirm the form POSTs to "/".
@app.post("/")
async def prediction(request: Request, text: str = Form(...)):
    """Summarize the submitted text and re-render the page with the result.

    Args:
        request: The incoming request; the template context must include it.
        text: Contents of the required ``text`` form field.

    Returns:
        ``index.html.jinja`` rendered with the submitted text and its summary.
    """
    # Model inference is synchronous and potentially slow. Calling it directly
    # inside an ``async def`` would block the event loop (and every other
    # request), so run it in the threadpool instead.
    summary = await run_in_threadpool(summarize, text)
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request, "text": text, "summary": summary}
    )
| # ------------------------------------------------------------------------------------ | |
# Start a development server that reloads whenever a saved file in the repo changes.
if __name__ == "__main__":
    # uvicorn needs an import string ("module:attr") to enable reload=True.
    # Derive the module name from this file instead of hard-coding "api", so
    # the server still starts if the file is named differently (e.g. app.py).
    uvicorn.run(f"{Path(__file__).stem}:app", port=8000, reload=True)