Text_Summarizer / app.py
Amanprime's picture
Update app.py
ede19cb verified
Raw
History Blame Contribute Delete
2.26 kB
from fastapi import FastAPI, Request
from pydantic import BaseModel # validate text input
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import re
from fastapi.templating import Jinja2Templates # UI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
# FastAPI application object with its title/description/version metadata.
app = FastAPI(
    title="Text Summarizer App",
    description="Text Summarization using T5",
    version="1.0",
)

# Load the T5 summarization checkpoint and its tokenizer once, at import time.
model = T5ForConditionalGeneration.from_pretrained("Amanprime/T5_Text_Summarization")
tokenizer = T5Tokenizer.from_pretrained("Amanprime/T5_Text_Summarization")

# Pick the compute device: Apple MPS is probed first, then CUDA, else CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
model.to(device)

# Jinja2 templates are loaded from the relative directory ".".
templates = Jinja2Templates(directory=".")
# Request-body schema validated by pydantic: a single required string field
# named "dialogue".
class DialogueInput(BaseModel):
    dialogue: str  # raw text to be summarized
# Precompiled patterns used by clean_data.
_HTML_TAG_RE = re.compile(r"<[^>]*>")  # any <...> span, even across line breaks
_WHITESPACE_RE = re.compile(r"\s+")  # runs of spaces, tabs, CR/LF, etc.


def clean_data(text: str) -> str:
    """Normalize raw text before it is tokenized.

    Steps, in order:
      1. Replace every ``<...>`` span (HTML tags such as ``<p>``, ``<h1>``)
         with a single space.
      2. Collapse every run of whitespace (including ``\\r\\n``) into one space.
      3. Strip leading/trailing whitespace and lowercase the result.

    Tags are removed *before* whitespace is collapsed so that the spaces
    substituted for tags are themselves collapsed; doing it the other way
    round leaves double/triple spaces wherever a tag used to be.

    Args:
        text: The raw input string.

    Returns:
        The cleaned, lowercased, single-spaced string (possibly empty).
    """
    text = _HTML_TAG_RE.sub(" ", text)
    text = _WHITESPACE_RE.sub(" ", text)
    return text.strip().lower()
def summarize_dialogue(dialogue: str) -> str:
    """Generate a summary of ``dialogue`` with the module-level T5 model.

    The text is cleaned with ``clean_data``, tokenized (truncated to at most
    512 tokens), summarized with 4-beam search (at most 150 output tokens),
    and decoded back to a plain string.

    Args:
        dialogue: Raw text to summarize.

    Returns:
        The decoded summary with special tokens removed, or ``""`` when the
        input is empty after cleaning.
    """
    dialogue = clean_data(dialogue)
    if not dialogue:
        # Nothing to summarize; avoid generating text from an empty prompt.
        return ""

    # Only one sequence is encoded per call, so no padding is requested:
    # padding to max_length would make the encoder process 512 positions
    # regardless of how short the input is.
    encoded = tokenizer(
        dialogue,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # Ensure the model is on the same device as the inputs.
    model.to(device)
    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=150,
        num_beams=4,
        early_stopping=True,
    )

    # generate() returns a batch of token-id sequences; decode the only one.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# API endpoints
@app.post("/summarize/")
async def summarize(dialogue_input: DialogueInput):
    """Summarize the submitted dialogue.

    Expects a JSON body with a string field ``dialogue`` and responds with
    ``{"summary": "<generated summary>"}``.
    """
    # Imported locally so this handler brings its own dependency into scope.
    from fastapi.concurrency import run_in_threadpool

    # Model inference is blocking work; running it directly inside an
    # ``async def`` handler would stall the event loop (and every other
    # request, including the home page) until generation finishes.
    summary = await run_in_threadpool(summarize_dialogue, dialogue_input.dialogue)
    return {"summary": summary}
# Landing page: renders the "index.html" template for the incoming request.
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    page = templates.TemplateResponse(request=request, name="index.html")
    return page