Briefly / app.py
vishnuraggav's picture
Initial Commit 2
36bda2a
raw
history blame contribute delete
854 Bytes
''' Import Modules '''
from fastapi import FastAPI
from pydantic import BaseModel
from pydantic import Field
from transformers import T5Tokenizer, T5ForConditionalGeneration
''' Model '''
# Identifier handed to both `from_pretrained` calls below, so the tokenizer
# and the model are always loaded from the same source.
# NOTE(review): the name suggests a local save directory, but the value looks
# like a plain model id — confirm which one is intended.
save_path = "t5-base"
# Both objects are created once, when this module is imported, and are reused
# by the request handler for every call.
tokenizer = T5Tokenizer.from_pretrained(save_path)
model = T5ForConditionalGeneration.from_pretrained(save_path)
''' Server '''
class Request(BaseModel):
    """JSON body accepted by the POST endpoint of this service."""

    # Raw text to be summarized.
    text: str
    # Upper bound forwarded to `model.generate(max_length=...)`.
    # A zero or negative bound cannot describe a valid output length, so it is
    # rejected during request validation instead of being passed to the model.
    # The field stays required (`...`), exactly as before.
    maxlen: int = Field(..., gt=0)
# FastAPI application instance; the POST route below is registered on it
# through the `@app.post` decorator.
app = FastAPI()
@app.post("/")
def main(request: Request):
    # Summarize `request.text` with the module-level tokenizer/model pair and
    # return the result as {"output": <summary string>}.

    # The input is prefixed with the "summarize: " task marker before encoding.
    prompt = "summarize: " + request.text

    # Encode to a "pt" tensor; inputs longer than 512 tokens are truncated.
    encoded = tokenizer.encode(
        prompt,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    # Generation settings; only `max_length` comes from the caller.
    generation_options = {
        "max_length": request.maxlen,
        "num_beams": 3,
        "early_stopping": True,
        "no_repeat_ngram_size": 5,
    }
    candidates = model.generate(encoded, **generation_options)

    # Only the first returned sequence is decoded, with special tokens removed.
    summary = tokenizer.decode(candidates[0], skip_special_tokens=True)
    return {"output": summary}