File size: 3,633 Bytes
d536fa3
 
48da5f8
 
d536fa3
624eb07
 
 
 
d536fa3
624eb07
d536fa3
48da5f8
 
d536fa3
624eb07
48da5f8
d536fa3
 
f47337e
d536fa3
 
 
48da5f8
 
d536fa3
 
 
 
 
 
624eb07
f47337e
 
d536fa3
f47337e
48da5f8
f47337e
d536fa3
48da5f8
d536fa3
f47337e
48da5f8
d536fa3
48da5f8
 
 
d536fa3
48da5f8
 
624eb07
f47337e
48da5f8
624eb07
d536fa3
f47337e
 
624eb07
d536fa3
 
 
 
f47337e
48da5f8
624eb07
d536fa3
 
 
f47337e
 
624eb07
d536fa3
 
 
624eb07
48da5f8
 
 
 
 
 
624eb07
d536fa3
f47337e
d536fa3
 
 
 
 
 
 
f47337e
d536fa3
 
 
 
 
48da5f8
d536fa3
 
624eb07
f47337e
d536fa3
 
624eb07
d536fa3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- coding: utf-8 -*-
"""
FastAPI Application loading FLAN-T5-Base (approx 780MB) directly from Hugging Face
for low-latency, API-free simplification based purely on prompt engineering.
"""
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os

# --- Configuration ---
# SWITCHED TO FLAN-T5-Base (approx 780MB) for superior instruction-following accuracy.
BASE_MODEL_ID = "google/flan-t5-base" 
# Prefer GPU when available; all tensors/models are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Global Model Variables ---
# Populated by the startup event handler; both stay None until loading succeeds.
tokenizer = None
model = None
# Lifecycle flag consumed by /health and /simplify:
# "PENDING" before startup, "OK" on success, or "ERROR: <message>" on failure.
model_loaded_status = "PENDING"

# Initialize FastAPI app (metadata shown in the auto-generated /docs UI)
app = FastAPI(
    title="HF FLAN-T5-Base Simplifier",
    description="Loads FLAN-T5-Base for low-latency, instruction-based simplification.",
    version="1.0.0"
)

# Pydantic schema for the input request body
class TextRequest(BaseModel):
    """Request body for /simplify: the raw text to be simplified."""
    # Complex / filler-laden input text; may be empty (handled by the endpoint).
    text: str
    
# --- Model Loading and Initialization (Startup Event) ---

@app.on_event("startup")
def load_model_on_startup():
    """Load the FLAN-T5-Base tokenizer and model from Hugging Face at startup.

    Populates the module-level ``tokenizer``, ``model`` and
    ``model_loaded_status`` globals. On failure the exception text is stored
    in ``model_loaded_status`` so /health can report it instead of the
    process crashing at boot.
    """
    global tokenizer, model, model_loaded_status
    try:
        print(f"Loading base model {BASE_MODEL_ID} on device: {DEVICE}")

        # 1. Load Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

        # 2. Pick a dtype that matches the device.
        # BUGFIX: the previous expression fell through to float16 on CPU,
        # where many ops are unimplemented or very slow — CPU inference
        # must run in float32. GPU keeps the bf16-preferred fast path.
        if DEVICE == "cuda":
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            dtype = torch.float32

        # 3. Load Model on the target device in eval (inference) mode.
        model = AutoModelForSeq2SeqLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=dtype,
        ).to(DEVICE).eval()

        model_loaded_status = "OK"
        print("Model loaded successfully from Hugging Face.")

    except Exception as e:
        # Record the failure for the /health and /simplify guards.
        model_loaded_status = f"ERROR: {str(e)}"
        print(f"FATAL MODEL LOADING ERROR: {model_loaded_status}")

# --- API Endpoints ---

@app.get("/health")
def health_check():
    """Report API liveness together with the model-loading outcome."""
    loaded_ok = model_loaded_status == "OK"
    return {
        "status": "ok" if loaded_ok else "error",
        "detail": model_loaded_status,
    }

@app.post("/simplify")
def simplify_text_api(request: TextRequest):
    """Accept complex text and return the simplified version.

    Returns ``{"simplified_text": ...}`` on success. Raises HTTP 503 when the
    model never loaded and HTTP 500 when inference fails — previously these
    error cases returned a 200 response carrying an ``"error"`` key, which
    clients could silently mistake for success.
    """
    if model_loaded_status != "OK":
        # Startup failed (or is still pending): service is unavailable.
        raise HTTPException(
            status_code=503,
            detail="Model failed to load during startup. Check logs.",
        )

    # BUGFIX: strip so whitespace-only input short-circuits like empty input
    # instead of wasting a full generation pass on blank text.
    text = request.text.strip()
    if not text:
        return {"simplified_text": ""}

    # FINAL QUALITY FIX: AGGRESSIVE, DETAILED PROMPT for filtering and simplification.
    prompt = (
        f"You are a text clarity editor. Preserve all core facts and context. "
        f"Remove all filler words (like 'uh', 'um', 'you know'), jargon, and unnecessary complexity. "
        f"Output ONLY the simplified text. Simplify: {text}"
    )

    try:
        # 1. Tokenize the prompt (truncated to the encoder budget).
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=128,
            truncation=True,
        ).to(DEVICE)

        # 2. Generate — no_grad avoids building an autograd graph at inference.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=4,
                length_penalty=0.6,
                repetition_penalty=2.0
            )

        # 3. Decode the first (best-beam) sequence, dropping special tokens.
        simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"simplified_text": simplified_text}

    except Exception as e:
        # Log the root cause server-side; surface a generic 500 to the client.
        print(f"Inference error: {e}")
        raise HTTPException(
            status_code=500,
            detail="Inference failed due to an internal server error.",
        ) from e