# simplifier / app.py
# -*- coding: utf-8 -*-
"""
FastAPI Application loading FLAN-T5-Base (approx 780MB) directly from Hugging Face
for low-latency, API-free simplification based purely on prompt engineering.
"""
import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# --- Configuration ---
# SWITCHED TO FLAN-T5-Base (approx 780MB) for superior instruction-following accuracy.
BASE_MODEL_ID = "google/flan-t5-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- Global Model Variables ---
tokenizer = None
model = None
model_loaded_status = "PENDING"
# Initialize FastAPI app
app = FastAPI(
title="HF FLAN-T5-Base Simplifier",
description="Loads FLAN-T5-Base for low-latency, instruction-based simplification.",
version="1.0.0"
)
# Pydantic schema for the input request body
class TextRequest(BaseModel):
text: str
# --- Model Loading and Initialization (Startup Event) ---
@app.on_event("startup")
def load_model_on_startup():
"""Loads the FLAN-T5-Base model directly from Hugging Face."""
global tokenizer, model, model_loaded_status
try:
print(f"Loading base model {BASE_MODEL_ID} on device: {DEVICE}")
# 1. Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# 2. Load Model
# CRITICAL SPEED FIX: Force bfloat16 for optimal T4 GPU performance
model = AutoModelForSeq2SeqLM.from_pretrained(
BASE_MODEL_ID,
torch_dtype=torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float16,
).to(DEVICE).eval()
model_loaded_status = "OK"
print("Model loaded successfully from Hugging Face.")
except Exception as e:
model_loaded_status = f"ERROR: {str(e)}"
print(f"FATAL MODEL LOADING ERROR: {model_loaded_status}")
# --- API Endpoints ---
@app.get("/health")
def health_check():
"""Returns the status of the API and model loading."""
return {"status": "ok" if model_loaded_status == "OK" else "error", "detail": model_loaded_status}
@app.post("/simplify")
def simplify_text_api(request: TextRequest):
"""Accepts complex text and returns the simplified version."""
if model_loaded_status != "OK":
return {"error": "Model failed to load during startup. Check logs."}
text = request.text
if not text:
return {"simplified_text": ""}
# FINAL QUALITY FIX: AGGRESSIVE, DETAILED PROMPT for filtering and simplification.
prompt = (
f"You are a text clarity editor. Preserve all core facts and context. "
f"Remove all filler words (like 'uh', 'um', 'you know'), jargon, and unnecessary complexity. "
f"Output ONLY the simplified text. Simplify: {text}"
)
try:
# 1. Tokenize Input
inputs = tokenizer(
prompt,
return_tensors="pt",
max_length=128,
truncation=True
).to(DEVICE)
# 2. Generate Output
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=128,
num_beams=4,
length_penalty=0.6,
repetition_penalty=2.0
)
# 3. Decode and return the result
simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"simplified_text": simplified_text}
except Exception as e:
print(f"Inference error: {e}")
return {"error": "Inference failed due to an internal server error."}