# -*- coding: utf-8 -*-
"""
FastAPI Application loading FLAN-T5-Base (approx 780MB) directly from Hugging Face
for low-latency, API-free simplification based purely on prompt engineering.
"""
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
# --- Configuration ---
# SWITCHED TO FLAN-T5-Base (approx 780MB) for superior instruction-following accuracy.
BASE_MODEL_ID = "google/flan-t5-base"  # Hugging Face Hub model identifier
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
# --- Global Model Variables ---
# Populated by the startup event handler; both remain None until loading succeeds.
tokenizer = None
model = None
# Loading state machine: "PENDING" -> "OK" on success or "ERROR: <detail>" on
# failure; inspected by the /health and /simplify endpoints before using the model.
model_loaded_status = "PENDING"
# Initialize FastAPI app
app = FastAPI(
    title="HF FLAN-T5-Base Simplifier",
    description="Loads FLAN-T5-Base for low-latency, instruction-based simplification.",
    version="1.0.0"
)
# Pydantic schema for the input request body
class TextRequest(BaseModel):
    """Request body for POST /simplify."""
    # The complex/verbose input text to be simplified.
    text: str
# --- Model Loading and Initialization (Startup Event) ---
@app.on_event("startup")
def load_model_on_startup():
    """Load the FLAN-T5-Base tokenizer and model from Hugging Face.

    Populates the module-level ``tokenizer``, ``model`` and
    ``model_loaded_status`` globals. On failure the error text is stored in
    ``model_loaded_status`` (instead of crashing the server) so the /health
    endpoint can surface it.
    """
    global tokenizer, model, model_loaded_status
    try:
        print(f"Loading base model {BASE_MODEL_ID} on device: {DEVICE}")
        # 1. Load Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
        # 2. Select dtype.
        # BUGFIX: the original fell back to float16 on CPU, but many CPU ops
        # are not implemented for Half (e.g. "addmm_impl_cpu_ not implemented
        # for 'Half'"), which breaks inference on non-CUDA hosts. Use half
        # precision only on CUDA (bfloat16 where supported, e.g. A100/T4+
        # driver stacks that report bf16 support); keep float32 on CPU.
        if DEVICE == "cuda":
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            dtype = torch.float32
        # 3. Load Model in eval mode (no dropout at inference time).
        model = AutoModelForSeq2SeqLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=dtype,
        ).to(DEVICE).eval()
        model_loaded_status = "OK"
        print("Model loaded successfully from Hugging Face.")
    except Exception as e:
        # Record the failure; endpoints check model_loaded_status before use.
        model_loaded_status = f"ERROR: {str(e)}"
        print(f"FATAL MODEL LOADING ERROR: {model_loaded_status}")
# --- API Endpoints ---
@app.get("/health")
def health_check():
    """Report API liveness and whether the model finished loading."""
    healthy = model_loaded_status == "OK"
    return {
        "status": "ok" if healthy else "error",
        "detail": model_loaded_status,
    }
@app.post("/simplify")
def simplify_text_api(request: TextRequest):
    """Accept complex text and return a simplified version.

    Returns ``{"simplified_text": ...}`` on success, or ``{"error": ...}``
    when the model is unavailable or inference fails (dict-shaped error
    responses kept for backward compatibility with existing clients).
    """
    if model_loaded_status != "OK":
        return {"error": "Model failed to load during startup. Check logs."}
    text = request.text
    # Robustness: treat whitespace-only input the same as empty input
    # instead of running the model on it.
    if not text or not text.strip():
        return {"simplified_text": ""}
    # Aggressive, detailed instruction prompt for filtering and simplification.
    prompt = (
        f"You are a text clarity editor. Preserve all core facts and context. "
        f"Remove all filler words (like 'uh', 'um', 'you know'), jargon, and unnecessary complexity. "
        f"Output ONLY the simplified text. Simplify: {text}"
    )
    try:
        # 1. Tokenize Input.
        # BUGFIX: max_length was 128, but the instruction prompt alone uses
        # roughly 60 tokens, so user text beyond ~70 tokens was silently
        # truncated. FLAN-T5 comfortably supports 512 input tokens.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(DEVICE)
        # 2. Generate Output (no gradient tracking needed at inference time).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=4,             # beam search for higher-quality output
                length_penalty=0.6,      # mildly favor shorter outputs
                repetition_penalty=2.0   # discourage repeated phrases
            )
        # 3. Decode and return the result
        simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"simplified_text": simplified_text}
    except Exception as e:
        # Log details server-side; never leak internals to clients.
        print(f"Inference error: {e}")
        return {"error": "Inference failed due to an internal server error."}
|