Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, Form, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| import re | |
| from src.inference_lstm import inference_lstm | |
| from src.inference_t5 import inference_t5 | |
def summarize(text: str):
    """
    Returns the summary of an input text.

    The model to run is read from ``global_choose_model.var``, the function
    attribute written by ``global_choose_model``.

    Parameter
    ---------
    text : str
        A text to summarize.

    Returns
    -------
    :str
        The summary of the input text, or a short notice when no model
        has been chosen.
    """
    # The attribute only exists once global_choose_model() has stored a
    # choice. Default to "" so that a request made before any selection gets
    # the notice below instead of raising AttributeError.
    chosen = getattr(global_choose_model, "var", "")

    if chosen == "lstm":
        joined = " ".join(inference_lstm(text))
        # Drop a leading or trailing "1" and the <start>/<end> markers.
        return re.sub("^1|1$|<start>|<end>", "", joined)

    if chosen == "fineTunedT5":
        # Drop the "<extra_id_0> " marker from the T5 output.
        return re.sub("<extra_id_0> ", "", inference_t5(text))

    # Nothing (or nothing recognised) selected: always return a string
    # rather than falling off the end of the function with None.
    return "You have not chosen a model."
def global_choose_model(model_choice):
    """Store the user's model choice so that other functions can read it.

    The choice is kept on the function object itself, as the attribute
    ``global_choose_model.var``, which makes it reachable from outside this
    function without a module-level variable.

    Parameter
    ---------
    model_choice : str
        ``"lstm"`` or ``"fineTunedT5"`` to select a model, or ``" --- "``
        to clear the selection. Any other value leaves the stored choice
        unchanged.
    """
    if model_choice in ("lstm", "fineTunedT5"):
        global_choose_model.var = model_choice
    elif model_choice == " --- ":
        # The placeholder entry means "no model": stored as an empty string.
        global_choose_model.var = ""


# Start in the "no model chosen" state so that the attribute can be read
# before the first call without raising AttributeError.
global_choose_model.var = ""
# Module-level values shared by the endpoints below.
# In each entry, "model" is the identifier compared in global_choose_model;
# "name" is presumably the label shown to the user (confirm in the template).
model_list = [
    {"model": identifier, "name": label}
    for identifier, label in (
        (" --- ", " --- "),
        ("lstm", "LSTM"),
        ("fineTunedT5", "Fine-tuned T5"),
    )
]
selected_model = " --- "
model_choice = ""

# -------- API ---------------------------------------------------------------
app = FastAPI()

# The "templates" directory serves two purposes: it is the Jinja2 template
# root, and it is mounted as static files so that the CSS can be sent.
templates = Jinja2Templates(directory="templates")
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
# NOTE(review): no route decorator is visible on this function in this view;
# confirm how it is registered on `app`.
async def index(request: Request):
    """Render the index page of the app.

    Returns ``index.html.jinja`` rendered with the current route ``"/"``,
    the list of available models and the currently selected model.
    """
    context = {
        "request": request,
        "current_route": "/",
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", context)
# NOTE(review): no route decorator is visible on this function in this view;
# confirm how it is registered on `app`.
async def get_model(request: Request):
    """Render the model page of the app.

    Returns ``index.html.jinja`` rendered with the current route
    ``"/model"``, the list of available models and the currently selected
    model.
    """
    context = {
        "request": request,
        "current_route": "/model",
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", context)
# NOTE(review): no route decorator is visible on this function in this view;
# confirm how it is registered on `app`.
async def get_prediction(request: Request):
    """Render the predict page of the app.

    Returns ``index.html.jinja`` rendered with the current route
    ``"/predict"``.
    """
    # NOTE(review): unlike the other pages, no model_list / selected_model is
    # passed here - confirm the template copes with their absence.
    context = {"request": request, "current_route": "/predict"}
    return templates.TemplateResponse("index.html.jinja", context)
# NOTE(review): no route decorator is visible on this function in this view;
# confirm how it is registered on `app`.
async def choose_model(request: Request, model_choice: str = Form(None)):
    """Handle the model chosen by the user.

    When no choice was submitted, the page is rendered again with an error
    message and the previous selection is kept. Otherwise the choice is
    remembered as the selected model and passed to ``global_choose_model``,
    which connects it to the summarize function.

    Parameters
    ----------
    request : Request
        The incoming request, forwarded to the template.
    model_choice : str
        The ``model_choice`` form field; ``None`` when it was not sent.

    Returns
    -------
    The rendered ``index.html.jinja`` template response.
    """
    # Without this declaration the assignment below would only bind a local
    # name, and the other endpoints would keep rendering the initial " --- "
    # selection even though a model is active.
    global selected_model

    context = {"request": request, "model_list": model_list}

    if not model_choice:
        # Nothing submitted: report the problem, leave the selection as is.
        context["text"] = "Please select a model."
    else:
        selected_model = model_choice
        global_choose_model(model_choice)

    context["selected_model"] = selected_model
    return templates.TemplateResponse("index.html.jinja", context)
# NOTE(review): no route decorator is visible on this function in this view;
# confirm how it is registered on `app`.
async def prediction(request: Request, text: str = Form(None)):
    """Handle the text submitted by the user.

    An empty or missing text renders the page with an error message in the
    ``text`` slot; otherwise the text is passed to ``summarize`` and the page
    is rendered with both the text and its summary.
    """
    if text:
        text_fields = {"text": text, "summary": summarize(text)}
    else:
        text_fields = {"text": "Please enter your text."}

    context = {
        "request": request,
        **text_fields,
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", context)
| # ------------------------------------------------------------------------------------ | |
| # start the server and reload it on every saved change | |
| # if __name__ == "__main__": | |
| # uvicorn.run("api:app", port=8000, reload=True) | |