Summariser / app /main.py
govardhan-06
b
6836deb
raw
history blame contribute delete
942 Bytes
import asyncio
import pickle

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BartForConditionalGeneration, BartTokenizer
app = FastAPI()

# Hugging Face checkpoint identifier used for both the tokenizer and the model.
model_name: str = "facebook/bart-large-cnn"

# Both objects are loaded once, when this module is imported, and are then
# reused by every summarisation call below.
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)  # Load the Bart model
def _summarise_blocking(text, num_beams, max_length, min_length, length_penalty):
    """Tokenise *text*, run beam-search generation and decode the summary (blocking)."""
    # truncation=True clips the encoded input to the tokenizer's maximum model
    # length; without it, over-long documents exceed what the model accepts
    # and generation fails instead of returning a summary.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    summary_ids = model.generate(
        inputs["input_ids"],
        # Passing the mask explicitly avoids generate() having to guess it.
        attention_mask=inputs["attention_mask"],
        num_beams=num_beams,
        max_length=max_length,
        min_length=min_length,
        length_penalty=length_penalty,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


async def mymodel(
    text,
    *,
    num_beams=4,
    max_length=100,
    min_length=30,
    length_penalty=2.0,
):
    """Summarise *text* with the module-level BART tokenizer and model.

    Args:
        text: The document to summarise.
        num_beams: Beam width passed to ``model.generate``.
        max_length: Maximum summary length (tokens) passed to ``model.generate``.
        min_length: Minimum summary length (tokens) passed to ``model.generate``.
        length_penalty: Length penalty passed to ``model.generate``.

    Returns:
        The decoded summary string with special tokens removed.
    """
    loop = asyncio.get_running_loop()
    # Tokenisation and generation are blocking; running them in the default
    # executor keeps the event loop free to serve other requests meanwhile.
    return await loop.run_in_executor(
        None,
        _summarise_blocking,
        text,
        num_beams,
        max_length,
        min_length,
        length_penalty,
    )
class Request(BaseModel):
    """JSON request body accepted by the POST /summarise endpoint."""

    # The text the client wants summarised.
    input_text: str
class Response(BaseModel):
    """JSON response body returned by the POST /summarise endpoint."""

    # The generated summary text.
    summ_text: str
@app.post("/summarise", response_model=Response)
async def predict(request: Request) -> Response:
    """Summarise ``request.input_text`` and return it wrapped in a ``Response``."""
    return Response(summ_text=await mymodel(request.input_text))