# -*- coding: utf-8 -*-
"""
FastAPI application that loads FLAN-T5-Base directly from Hugging Face and
uses it for local (no external API) text simplification driven purely by
prompt engineering.

Endpoints:
    GET  /health    -- model load status; HTTP 503 while the model is unusable.
    POST /simplify  -- body ``{"text": "..."}``; returns
                       ``{"simplified_text": "..."}`` or ``{"error": "..."}``
                       with a non-2xx status code.
"""
import logging
import os

import torch
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

# --- Configuration ---
# FLAN-T5-Base is used for its instruction-following accuracy. The model ID may
# be overridden via the BASE_MODEL_ID environment variable (it must be loadable
# with AutoModelForSeq2SeqLM); the default is unchanged.
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "google/flan-t5-base")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Token budget for the encoder input (instruction prompt + user text). The
# instruction prompt alone uses a few dozen tokens and truncation drops tokens
# from the end, i.e. from the user's text, so this must leave room for it.
# 512 is the input length T5 models were pre-trained with. Inputs are not
# padded, so short requests pay nothing for the larger budget.
MAX_INPUT_TOKENS = 512
# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 128

# --- Global model state (populated by the startup handler) ---
tokenizer = None
model = None
# "PENDING" until startup runs, then "OK" or "ERROR: <exception message>".
model_loaded_status = "PENDING"

# Initialize FastAPI app
app = FastAPI(
    title="HF FLAN-T5-Base Simplifier",
    description="Loads FLAN-T5-Base for low-latency, instruction-based simplification.",
    version="1.0.0",
)


class TextRequest(BaseModel):
    """Request body for ``POST /simplify``."""

    text: str


def _select_dtype(device: str) -> torch.dtype:
    """Return the torch dtype to load the model with on *device*.

    - CPU: float32. Half-precision matmuls on CPU are very slow or not
      implemented, depending on the torch build.
    - CUDA: bfloat16 only when the GPU supports it natively (compute
      capability >= 8, i.e. Ampere or newer); otherwise float16. The
      capability is checked directly because recent PyTorch releases make
      ``torch.cuda.is_bf16_supported()`` report True on GPUs such as the T4,
      where bf16 is merely emulated and slow.
    """
    if device != "cuda":
        return torch.float32
    major, _minor = torch.cuda.get_device_capability()
    return torch.bfloat16 if major >= 8 else torch.float16


# NOTE: ``on_event`` is deprecated in recent FastAPI releases in favour of
# lifespan handlers; it is kept here for compatibility with older versions.
@app.on_event("startup")
def load_model_on_startup() -> None:
    """Load the tokenizer and model from Hugging Face into the module globals.

    Failures are recorded in ``model_loaded_status`` instead of being raised,
    so the server still starts and ``/health`` can report the error.
    """
    global tokenizer, model, model_loaded_status
    try:
        dtype = _select_dtype(DEVICE)
        logger.info(
            "Loading base model %s on device %s (dtype=%s)", BASE_MODEL_ID, DEVICE, dtype
        )
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
        model = (
            AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID, torch_dtype=dtype)
            .to(DEVICE)
            .eval()
        )
    except Exception as e:
        # Deliberately best-effort: keep serving so the failure is observable.
        model_loaded_status = f"ERROR: {e}"
        logger.exception("FATAL MODEL LOADING ERROR: %s", model_loaded_status)
        return
    model_loaded_status = "OK"
    logger.info("Model loaded successfully from Hugging Face.")


# --- API Endpoints ---
@app.get("/health")
def health_check():
    """Report API/model status.

    Returns HTTP 200 when the model is loaded and HTTP 503 otherwise, so that
    load balancers and orchestrators act on the status code. The JSON body
    has the same shape in both cases.
    """
    is_ok = model_loaded_status == "OK"
    body = {"status": "ok" if is_ok else "error", "detail": model_loaded_status}
    if is_ok:
        return body
    return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=body)


@app.post("/simplify")
def simplify_text_api(request: TextRequest):
    """Simplify ``request.text`` with the loaded model.

    Returns ``{"simplified_text": ...}`` on success; empty or whitespace-only
    input yields an empty string without running the model. On failure it
    returns ``{"error": ...}`` with HTTP 503 (model unavailable) or HTTP 500
    (inference error).
    """
    if model_loaded_status != "OK":
        return JSONResponse(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            content={"error": "Model failed to load during startup. Check logs."},
        )

    text = request.text.strip()
    if not text:
        return {"simplified_text": ""}

    # Instruction prompt for filtering filler and simplifying; user text goes
    # last, so it is what gets truncated if the input exceeds MAX_INPUT_TOKENS.
    prompt = (
        "You are a text clarity editor. Preserve all core facts and context. "
        "Remove all filler words (like 'uh', 'um', 'you know'), jargon, and unnecessary complexity. "
        f"Output ONLY the simplified text. Simplify: {text}"
    )

    try:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=MAX_INPUT_TOKENS,
            truncation=True,
        ).to(DEVICE)

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                num_beams=4,
                length_penalty=0.6,
                repetition_penalty=2.0,  # discourage repeated tokens in the output
            )

        simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception:
        logger.exception("Inference error")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"error": "Inference failed due to an internal server error."},
        )

    return {"simplified_text": simplified_text}