# -*- coding: utf-8 -*-
"""
FastAPI application loading FLAN-T5-Base (approx 780MB) directly from Hugging Face
for low-latency, API-free simplification based purely on prompt engineering.
"""
# Standard library
import os

# Third-party
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# --- Configuration ---
# SWITCHED TO FLAN-T5-Base (approx 780MB) for superior instruction-following accuracy.
BASE_MODEL_ID = "google/flan-t5-base"
# Prefer the GPU when one is visible; the model and all input tensors are
# moved onto this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Global Model Variables ---
# Both remain None until load_model_on_startup() completes successfully.
tokenizer = None
model = None
# Loading state machine: "PENDING" -> "OK" on success, or "ERROR: <detail>"
# on failure; endpoints consult this instead of crashing.
model_loaded_status = "PENDING"

# Initialize FastAPI app
app = FastAPI(
    title="HF FLAN-T5-Base Simplifier",
    description="Loads FLAN-T5-Base for low-latency, instruction-based simplification.",
    version="1.0.0"
)
# Pydantic schema for the input request body
class TextRequest(BaseModel):
    """Request body schema: the raw text to be simplified."""
    # Complex/verbose input text; an empty string is accepted and echoed back.
    text: str
# --- Model Loading and Initialization (Startup Event) ---
@app.on_event("startup")
def load_model_on_startup():
    """Load the FLAN-T5-Base tokenizer and model from Hugging Face.

    Registered as a FastAPI startup handler (the original function was
    defined but never hooked up, so the model was never loaded). On
    success it populates the module-level ``tokenizer``/``model`` globals
    and sets ``model_loaded_status`` to "OK"; on failure it records the
    error string so endpoints can report it instead of crashing the server.
    """
    global tokenizer, model, model_loaded_status
    try:
        print(f"Loading base model {BASE_MODEL_ID} on device: {DEVICE}")
        # 1. Load Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

        # 2. Pick a dtype the hardware can actually run:
        #    - bfloat16 on GPUs that support it (optimal T4/Ampere perf),
        #    - float16 on older CUDA GPUs,
        #    - float32 on CPU (the old code forced fp16 on CPU, which is
        #      unsupported/slow there and can produce NaNs with T5).
        if DEVICE == "cuda":
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            dtype = torch.float32

        # 3. Load the model, move it to the target device, freeze for inference.
        model = AutoModelForSeq2SeqLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=dtype,
        ).to(DEVICE).eval()
        model_loaded_status = "OK"
        print("Model loaded successfully from Hugging Face.")
    except Exception as e:
        # Keep the server alive; endpoints surface this status to clients.
        model_loaded_status = f"ERROR: {str(e)}"
        print(f"FATAL MODEL LOADING ERROR: {model_loaded_status}")
# --- API Endpoints ---
@app.get("/health")
def health_check():
    """Return API liveness plus the model-loading status.

    Exposed as GET /health (the original function was defined but never
    routed). "status" is "ok" only after the startup loader finished
    successfully; "detail" carries the pending/failure string so probes
    and orchestrators can see why the service is degraded.
    """
    return {"status": "ok" if model_loaded_status == "OK" else "error", "detail": model_loaded_status}
@app.post("/simplify")
def simplify_text_api(request: TextRequest):
    """Accept complex text and return the simplified version.

    Exposed as POST /simplify (the original function was defined but never
    routed). Returns {"simplified_text": ...} on success, or {"error": ...}
    when the model is unavailable or inference fails.
    """
    if model_loaded_status != "OK":
        return {"error": "Model failed to load during startup. Check logs."}
    text = request.text
    if not text:
        return {"simplified_text": ""}

    # FINAL QUALITY FIX: AGGRESSIVE, DETAILED PROMPT for filtering and simplification.
    prompt = (
        f"You are a text clarity editor. Preserve all core facts and context. "
        f"Remove all filler words (like 'uh', 'um', 'you know'), jargon, and unnecessary complexity. "
        f"Output ONLY the simplified text. Simplify: {text}"
    )
    try:
        # 1. Tokenize Input.
        #    max_length=512 (T5's trained context). The instruction preamble
        #    alone costs ~50 tokens, so the previous 128 cap silently
        #    truncated the user's text after roughly 80 tokens — directly
        #    contradicting the "preserve all core facts" instruction.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(DEVICE)

        # 2. Generate Output (beam search; penalties curb rambling/repeats).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=4,
                length_penalty=0.6,
                repetition_penalty=2.0
            )

        # 3. Decode and return the result
        simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"simplified_text": simplified_text}
    except Exception as e:
        # Log server-side detail; return a generic message to the client.
        print(f"Inference error: {e}")
        return {"error": "Inference failed due to an internal server error."}