Devang1290
feat: deploy News Whisper on-demand search API (FastAPI + Docker)
2cb327c
"""
Summarization Model Configuration
===================================
Manages AI model selection and inference for article summarization.
English Pipeline:
Uses configurable models (default: t5-small) via MODEL_TYPE env var.
Options: t5-small, distilbart, bart, t5, pegasus, led
Hindi Pipeline (via english_summary.py fallback only):
Uses L3Cube-Pune/Hindi-BART-Summary.
NOTE: The recommended Hindi path is hindi_summary.py (mT5 ONNX + Groq).
This model is only used if summarize is called directly on Hindi articles.
Architecture:
SummarizationModel is a thread-safe singleton. Each language's model is
loaded lazily on first use and cached in memory for the batch run.
Usage:
from backend.summarization.model import get_summarizer
summarizer = get_summarizer()
summary = summarizer.summarize(article_text, max_words=150, language="english")
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import threading
from dotenv import load_dotenv
load_dotenv()
# ─────────────────────────────────────────────
# Device Configuration
# ─────────────────────────────────────────────
# DEVICE=cpu forces CPU. DEVICE=gpu (or any other value) requests CUDA and
# falls back to CPU when CUDA is unavailable. The resulting DEVICE is 0 for
# the first CUDA device and -1 for CPU. The value is stripped so stray
# whitespace in .env (e.g. "cpu ") does not silently request a GPU.
ENV_DEVICE = os.getenv('DEVICE', 'cpu').strip().lower()
if ENV_DEVICE == 'cpu':
    USE_GPU = False
    DEVICE = -1
else:
    USE_GPU = True
    DEVICE = 0 if torch.cuda.is_available() else -1
    if ENV_DEVICE == 'gpu' and DEVICE == -1:
        print("Warning: GPU requested but CUDA not available, falling back to CPU")
# CPU Optimization
if DEVICE == -1:
    _raw_workers = os.getenv('MAX_WORKERS', '4')
    try:
        # torch rejects non-positive thread counts, so clamp to at least 1.
        CPU_THREADS = max(1, int(_raw_workers))
    except ValueError:
        print(f"Warning: invalid MAX_WORKERS {_raw_workers!r}, using 4 CPU threads")
        CPU_THREADS = 4
    torch.set_num_threads(CPU_THREADS)
    try:
        torch.set_num_interop_threads(CPU_THREADS)
    except RuntimeError as e:
        # Inter-op threads can only be set once per process and before any
        # inter-op parallel work starts; another module may have set them
        # already. That should not make importing this module fail.
        print(f"Warning: could not set inter-op threads: {e}")
# ─────────────────────────────────────────────
# Model Selection
# ─────────────────────────────────────────────
_DEFAULT_MODEL_TYPE = "t5-small"
# English model key chosen via MODEL_TYPE; stripped so "bart " is accepted.
MODEL_TYPE = os.getenv('MODEL_TYPE', _DEFAULT_MODEL_TYPE).strip().lower()
HINDI_MODEL_NAME = os.getenv('HINDI_MODEL_NAME', 'L3Cube-Pune/Hindi-BART-Summary')
# Per-model settings, keyed by MODEL_TYPE:
#   name:             HuggingFace model id passed to from_pretrained()
#   max_length:       cap on the generate() max_length (tokens)
#   min_length:       base generate() min_length (tokens)
#   max_input_length: tokenizer truncation length for the article
#   description:      printed when the model is loaded
MODELS = {
    "t5-small": {
        "name": "t5-small",
        "max_length": 300,
        "min_length": 80,
        # NOTE(review): t5-base below uses 512; T5 tokenizers usually report
        # model_max_length=512 — confirm 1024 is intended for t5-small.
        "max_input_length": 1024,
        "description": "T5 Small - Fast CPU inference, ~240MB (BEST FOR GITHUB ACTIONS)"
    },
    "distilbart": {
        "name": "sshleifer/distilbart-cnn-12-6",
        "max_length": 130,
        "min_length": 30,
        "max_input_length": 1024,
        "description": "DistilBART - Faster than BART, ~600MB (GOOD FOR GITHUB ACTIONS)"
    },
    "bart": {
        "name": "facebook/bart-large-cnn",
        "max_length": 130,
        "min_length": 30,
        "max_input_length": 1024,
        "description": "BART - Good balance of speed and quality, ~1.6GB"
    },
    "t5": {
        "name": "t5-base",
        "max_length": 150,
        "min_length": 30,
        "max_input_length": 512,
        "description": "T5 Base - Versatile text-to-text model, ~850MB"
    },
    "pegasus": {
        "name": "google/pegasus-xsum",
        "max_length": 128,
        "min_length": 32,
        "max_input_length": 512,
        "description": "Pegasus - Optimized for news summarization, ~2.2GB"
    },
    "led": {
        "name": "allenai/led-base-16384",
        "max_length": 150,
        "min_length": 30,
        "max_input_length": 4096,
        "description": "LED - Best for long documents, ~500MB"
    }
}
if MODEL_TYPE not in MODELS:
    valid_models = ", ".join(MODELS.keys())
    print(f"Warning: Invalid MODEL_TYPE '{MODEL_TYPE}' in .env")
    print(f"Valid options: {valid_models}")
    print(f"Falling back to default: {_DEFAULT_MODEL_TYPE}")
    MODEL_TYPE = _DEFAULT_MODEL_TYPE
# Settings for the Hindi model. They replace the MODEL_TYPE-selected English
# config whenever summarize() is called with language="hindi". "is_t5" is
# False, so no "summarize: " prefix and the non-T5 generation settings are used.
_HINDI_OVERRIDE = {
    "name": HINDI_MODEL_NAME,
    "max_length": 220,
    "min_length": 70,
    "max_input_length": 1024,
    "description": "Hindi-BART-Summary by L3Cube Pune",
    "is_t5": False
}
# Language key (lower-case) -> model config used instead of MODELS[MODEL_TYPE].
LANGUAGE_MODEL_OVERRIDES = {"hindi": _HINDI_OVERRIDE}
def _fallback_summary(words, max_words: int) -> str:
return " ".join(words[:max_words]).strip()
def _normalize_summary_length(summary: str, original_words, max_words: int) -> str:
if not summary:
return _fallback_summary(original_words, max_words)
summary_words = summary.split()
if len(summary_words) > max_words:
summary = " ".join(summary_words[:max_words]).strip()
summary_words = summary.split()
min_words = max(35, int(max_words * 0.55))
if len(summary_words) < min_words:
return _fallback_summary(original_words, max_words)
return summary
def _language_model_config(language: str):
    """Resolve which model config serves ``language``.

    Returns a ``(config, key)`` pair. Languages listed in
    LANGUAGE_MODEL_OVERRIDES (matched case-insensitively, whitespace-trimmed)
    get their own config and key. Everything else, including an empty or
    ``None`` language, maps to ``(MODELS[MODEL_TYPE], "english")``.
    """
    normalized = (language or "english").strip().lower()
    override = LANGUAGE_MODEL_OVERRIDES.get(normalized)
    if override is None:
        return MODELS[MODEL_TYPE], "english"
    return override, normalized
def _is_t5_model(language: str) -> bool:
    """Tell whether the model serving ``language`` is a T5-family model.

    Override configs decide through their "is_t5" flag (default False).
    Otherwise the answer is derived from the English MODEL_TYPE key.
    """
    normalized = (language or "english").strip().lower()
    override = LANGUAGE_MODEL_OVERRIDES.get(normalized)
    if override is not None:
        return override.get("is_t5", False)
    return MODEL_TYPE.startswith("t5")
class SummarizationModel:
    """Singleton that loads and runs HuggingFace seq2seq summarization models.

    A model bundle (tokenizer, model, config, device) is loaded lazily per
    language key on the first summarize() call for that language. It is kept
    in ``self._models`` for the life of the instance. One instance-level lock
    serializes model loading and every tokenize/generate/decode call.
    """
    _instance = None
    # Guards first-time construction. Without it, two threads racing on the
    # first SummarizationModel() call could each build an instance (each with
    # its own lock and cache). A thread could also observe the instance before
    # _lock/_models were set.
    _instance_lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._lock = threading.Lock()
                    instance._models = {}
                    # Publish only once the instance is fully initialized.
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # __init__ runs on every SummarizationModel() call. State is created
        # once in __new__, so the model cache must not be reset here.
        if not hasattr(self, "_models"):
            self._models = {}

    def _load_model(self, language: str):
        """Load the tokenizer and model for ``language`` into ``self._models``.

        The bundle is stored under the key returned by _language_model_config
        ("english" unless the language has an override). summarize() calls
        this while holding ``self._lock``.

        Raises:
            Any exception from loading the tokenizer/model, after printing it.
        """
        model_config, model_key = _language_model_config(language)
        device_name = "GPU (CUDA)" if DEVICE == 0 else "CPU"
        if DEVICE == 0 and torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            print(f"Using device: {device_name} ({gpu_name})")
        else:
            print(f"Using device: {device_name}")
            if USE_GPU and not torch.cuda.is_available():
                print("Warning: GPU requested but CUDA not available, falling back to CPU")
        print(f"Loading model: {model_config['name']}")
        print(f"Description: {model_config['description']}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_config["name"])
            model = AutoModelForSeq2SeqLM.from_pretrained(model_config["name"])
            if DEVICE == 0:
                model = model.to("cuda")
            self._models[model_key] = {
                "tokenizer": tokenizer,
                "model": model,
                "config": model_config,
                "device": "cuda" if DEVICE == 0 else "cpu"
            }
            print("Model loaded successfully!\n")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    @staticmethod
    def _generation_lengths(max_words: int, model_config: dict) -> tuple:
        """Return ``(max_length, min_length)`` token bounds for generate().

        About two tokens are allowed per requested word, capped by the model's
        configured max_length. min_length is at least 20 and at most
        ``max_length - 20``. It is finally clamped to max_length, so tiny
        budgets never request a minimum longer than the maximum.
        """
        max_length = min(int(max_words * 2.0), model_config["max_length"])
        min_length = min(max(model_config["min_length"], int(max_words * 0.5)), max_length - 20)
        min_length = max(20, min_length)
        # For max_words < 10, max_length < 20 and the floor above would exceed it.
        return max_length, min(min_length, max_length)

    def summarize(self, text: str, max_words: int = 80, language: str = "english") -> str:
        """Summarize ``text`` with the model configured for ``language``.

        Inputs longer than 600 words are cut to their first 600 words. Texts
        under 40 words are not sent to the model: they are returned unchanged
        if they fit in ``max_words``, otherwise cut to ``max_words`` words.

        Args:
            text: The article body text to summarize.
            max_words: Maximum number of words in the output summary.
            language: "english" or "hindi" — determines which model to use;
                other values use the English model.

        Returns:
            The summary string. Empty or whitespace-only ``text`` is returned
            as-is. If generation fails or its output is unusable, the first
            ``max_words`` words of the original text are returned instead.

        Raises:
            Any exception raised while loading the model on first use for a
            language (loading happens outside the fallback ``try``).
        """
        if not text or not text.strip():
            return text
        words = text.split()
        max_input_words = 600
        if len(words) > max_input_words:
            text = " ".join(words[:max_input_words])
        if len(words) < 40:
            # Too short to be worth summarizing. Still honor the max_words cap.
            if len(words) <= max_words:
                return text
            return _fallback_summary(words, max_words)

        model_config, model_key = _language_model_config(language)
        if model_key not in self._models:
            with self._lock:
                # Re-check under the lock: another thread may have loaded it.
                if model_key not in self._models:
                    self._load_model(model_key)
        bundle = self._models[model_key]
        tokenizer = bundle["tokenizer"]
        model = bundle["model"]
        device = bundle["device"]

        is_t5 = _is_t5_model(model_key)
        if is_t5:
            # T5 checkpoints select the task through a text prefix.
            text = "summarize: " + text

        max_length, min_length = self._generation_lengths(max_words, model_config)
        generate_kwargs = {
            "max_length": max_length,
            "min_length": min_length,
            "num_beams": 4,
            "early_stopping": True,
        }
        if is_t5:
            generate_kwargs["length_penalty"] = 2.5
            generate_kwargs["no_repeat_ngram_size"] = 3
        else:
            generate_kwargs["length_penalty"] = 2.0

        try:
            with self._lock:
                inputs = tokenizer(
                    text,
                    max_length=model_config["max_input_length"],
                    truncation=True,
                    return_tensors="pt"
                )
                if device == "cuda":
                    inputs = {k: v.to("cuda") for k, v in inputs.items()}
                summary_ids = model.generate(
                    inputs["input_ids"],
                    # Pass the mask explicitly rather than letting generate() infer it.
                    attention_mask=inputs.get("attention_mask"),
                    **generate_kwargs
                )
                summary = tokenizer.decode(
                    summary_ids[0],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                )
            if not summary or len(summary.strip()) < 20:
                return _fallback_summary(words, max_words)
            return _normalize_summary_length(summary.strip(), words, max_words)
        except Exception as e:
            # Best-effort: any inference failure degrades to the extractive fallback.
            print(f"Summarization error: {e}")
            return _fallback_summary(words, max_words)
def get_summarizer():
    """Return the process-wide SummarizationModel.

    SummarizationModel() always hands back the same cached instance. Models
    are only loaded later, on the first summarize() call per language.
    """
    summarizer = SummarizationModel()
    return summarizer