shahzaib201's picture
Update main.py
d223493 verified
raw
history blame contribute delete
998 Bytes
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load the summarization model and its tokenizer once at import time so
# every request reuses the same in-memory instances (first run downloads
# the weights from the Hugging Face Hub).
model_name = "shahzaib201/AI_OEL"  # Hub repo id of the fine-tuned seq2seq model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pydantic model for input validation
class TextInput(BaseModel):
    """Request body for the /summarize_text endpoint."""

    text: str  # raw text to summarize
    max_length: int = 150  # upper bound passed to model.generate(max_length=...)
# Initialize FastAPI app (ASGI application serving the summarization endpoint)
app = FastAPI()
# Endpoint for text summarization
@app.post("/summarize_text")
async def summarize_text_endpoint(item: TextInput):
    """Summarize ``item.text`` with the module-level seq2seq model.

    Parameters
    ----------
    item : TextInput
        ``text`` to summarize and an optional ``max_length`` (default 150)
        capping the length of the generated summary.

    Returns
    -------
    dict
        ``{"summary": <generated summary string>}``.
    """
    # Tokenize the input; inputs longer than 1024 tokens are truncated.
    inputs = tokenizer(item.text, return_tensors="pt", max_length=1024, truncation=True)
    # Fix: forward the attention_mask alongside input_ids (the original
    # dropped it, which makes generate() unreliable on truncated/padded
    # inputs), and disable gradient tracking — this is pure inference.
    with torch.no_grad():
        summary_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=item.max_length,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
        )
    # Decode the generated token ids back into a plain string.
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}