# Source: Devang1290, commit 2cb327c — "feat: deploy News Whisper on-demand search API (FastAPI + Docker)"
"""
Model Configuration
Switch between different summarization models easily by changing MODEL_TYPE.
English uses MODEL_TYPE from .env. Hindi uses Hindi-BART-Summary by L3Cube Pune.
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import threading
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Device Configuration - Read from .env or auto-detect
ENV_DEVICE = os.getenv('DEVICE', 'cpu').lower()
if ENV_DEVICE == 'cpu':
    USE_GPU = False
    DEVICE = -1
else:
    # 'gpu' requests CUDA explicitly; any other (invalid) value auto-detects.
    USE_GPU = True
    DEVICE = 0 if torch.cuda.is_available() else -1
    if ENV_DEVICE == 'gpu' and DEVICE == -1:
        print("Warning: GPU requested but CUDA not available, falling back to CPU")

# CPU Optimization Settings: bound intra-op and inter-op thread pools
# when running inference on the CPU.
if DEVICE == -1:
    CPU_THREADS = int(os.getenv('MAX_WORKERS', '4'))
    torch.set_num_threads(CPU_THREADS)
    torch.set_num_interop_threads(CPU_THREADS)
# Model Selection - Read from environment or use default
MODEL_TYPE = os.getenv('MODEL_TYPE', 't5-small').lower()
HINDI_MODEL_NAME = os.getenv('HINDI_MODEL_NAME', 'L3Cube-Pune/Hindi-BART-Summary')
# Valid options for English MODEL_TYPE: "distilbart", "t5-small", "bart", "t5", "pegasus", "led"


def _model_spec(name, max_length, min_length, max_input_length, description):
    # Build one configuration entry for the MODELS table below.
    return {
        "name": name,
        "max_length": max_length,
        "min_length": min_length,
        "max_input_length": max_input_length,
        "description": description,
    }


# Model Configurations, keyed by MODEL_TYPE value.
MODELS = {
    # CPU-Optimized Models (Recommended for GitHub Actions)
    "t5-small": _model_spec(
        "t5-small", 300, 80, 1024,
        "T5 Small - Fast CPU inference, ~240MB (BEST FOR GITHUB ACTIONS)"),
    "distilbart": _model_spec(
        "sshleifer/distilbart-cnn-12-6", 130, 30, 1024,
        "DistilBART - Faster than BART, ~600MB (GOOD FOR GITHUB ACTIONS)"),
    # Standard Models (Better quality, slower on CPU)
    "bart": _model_spec(
        "facebook/bart-large-cnn", 130, 30, 1024,
        "BART - Good balance of speed and quality, ~1.6GB"),
    "t5": _model_spec(
        "t5-base", 150, 30, 512,
        "T5 Base - Versatile text-to-text model, ~850MB"),
    "pegasus": _model_spec(
        "google/pegasus-xsum", 128, 32, 512,
        "Pegasus - Optimized for news summarization, ~2.2GB"),
    "led": _model_spec(
        "allenai/led-base-16384", 150, 30, 4096,
        "LED - Best for long documents, ~500MB"),
}
# Validate MODEL_TYPE: unknown values fall back to the t5-small default.
if MODEL_TYPE not in MODELS:
    print(f"Warning: Invalid MODEL_TYPE '{MODEL_TYPE}' in .env")
    print("Valid options: " + ", ".join(MODELS.keys()))
    print("Falling back to default: t5-small")
    MODEL_TYPE = "t5-small"
# Per-language model overrides; languages not listed here use the
# English model selected by MODEL_TYPE.
LANGUAGE_MODEL_OVERRIDES = {
    "hindi": dict(
        name=HINDI_MODEL_NAME,
        max_length=220,
        min_length=70,
        max_input_length=1024,
        description="Hindi-BART-Summary by L3Cube Pune",
        is_t5=False,
    ),
}
def _fallback_summary(words, max_words: int) -> str:
return " ".join(words[:max_words]).strip()
def _normalize_summary_length(summary: str, original_words, max_words: int) -> str:
if not summary:
return _fallback_summary(original_words, max_words)
summary_words = summary.split()
if len(summary_words) > max_words:
summary = " ".join(summary_words[:max_words]).strip()
summary_words = summary.split()
min_words = max(35, int(max_words * 0.55))
if len(summary_words) < min_words:
return _fallback_summary(original_words, max_words)
return summary
def _language_model_config(language: str):
    """Resolve (model_config, model_key) for *language*.

    Unknown or missing languages fall back to the English model
    selected by MODEL_TYPE.
    """
    key = (language or "english").strip().lower()
    override = LANGUAGE_MODEL_OVERRIDES.get(key)
    if override is not None:
        return override, key
    return MODELS[MODEL_TYPE], "english"
def _is_t5_model(language: str) -> bool:
    """Return True when the model serving *language* needs a T5 task prefix."""
    key = (language or "english").strip().lower()
    try:
        return LANGUAGE_MODEL_OVERRIDES[key].get("is_t5", False)
    except KeyError:
        # Not an override language: inspect the English model type name.
        return MODEL_TYPE.startswith("t5")
class SummarizationModel:
    """Lazily-loading, thread-safe singleton around seq2seq summarizers.

    One tokenizer/model pair per language key ("english", "hindi", ...)
    is cached in ``self._models``. Loading and generation are serialized
    with a single instance lock because the underlying Hugging Face
    models are not thread-safe.
    """

    _instance = None
    # FIX: class-level lock guarding singleton construction. Previously two
    # threads racing through __new__ could each build an instance, and the
    # instance was published before _lock/_models were assigned.
    _creation_lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._creation_lock:
                # Re-check under the lock (double-checked locking).
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._lock = threading.Lock()
                    instance._models = {}
                    # Publish only after the instance is fully initialized.
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # Defensive only: __new__ already sets _models. This covers exotic
        # construction paths (e.g. unpickling) that bypass __new__.
        if not hasattr(self, "_models"):
            self._models = {}

    def _load_model(self, language: str):
        """Load and cache the tokenizer/model pair for *language*.

        Expected to be called while holding ``self._lock`` (summarize()
        does). Re-raises any exception from ``from_pretrained``.
        """
        model_config, model_key = _language_model_config(language)

        # Display device info
        device_name = "GPU (CUDA)" if DEVICE == 0 else "CPU"
        if DEVICE == 0 and torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            print(f"Using device: {device_name} ({gpu_name})")
        else:
            print(f"Using device: {device_name}")
        if USE_GPU and not torch.cuda.is_available():
            print("Warning: GPU requested but CUDA not available, falling back to CPU")

        print(f"Loading model: {model_config['name']}")
        print(f"Description: {model_config['description']}")

        try:
            # Load model and tokenizer directly for better compatibility
            tokenizer = AutoTokenizer.from_pretrained(model_config["name"])
            model = AutoModelForSeq2SeqLM.from_pretrained(model_config["name"])
            # Move model to the configured device
            if DEVICE == 0:
                model = model.to("cuda")
            self._models[model_key] = {
                "tokenizer": tokenizer,
                "model": model,
                "config": model_config,
                "device": "cuda" if DEVICE == 0 else "cpu",
            }
            print("Model loaded successfully!\n")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def summarize(self, text: str, max_words: int = 80, language: str = "english") -> str:
        """Summarize *text* to roughly *max_words* words.

        Returns the input unchanged for empty or very short text, and a
        word-truncated copy of the original on any generation failure.
        """
        if not text or not text.strip():
            return text

        words = text.split()
        # Conservative word cap: one word is often more than one token and
        # t5-small tops out at 512 tokens; 600 words keeps enough context.
        max_input_words = 600
        if len(words) > max_input_words:
            text = " ".join(words[:max_input_words])
        # Texts under 40 words are not worth summarizing.
        if len(words) < 40:
            return text

        model_config, model_key = _language_model_config(language)
        # Double-checked lazy load: cheap unlocked read, then re-check
        # under the lock so each model is loaded exactly once.
        if model_key not in self._models:
            with self._lock:
                if model_key not in self._models:
                    self._load_model(model_key)

        model_bundle = self._models[model_key]
        tokenizer = model_bundle["tokenizer"]
        model = model_bundle["model"]
        device = model_bundle["device"]

        # T5-family models require an explicit task prefix.
        if _is_t5_model(model_key):
            text = "summarize: " + text

        # Token budgets: ~2 tokens per requested word, clamped to the
        # model's own limits; keep min comfortably below max.
        max_length = min(int(max_words * 2.0), model_config["max_length"])
        min_length = min(max(model_config["min_length"], int(max_words * 0.5)), max_length - 20)
        min_length = max(20, min_length)

        try:
            with self._lock:
                # Tokenize input (hard-truncated to the model's input limit)
                inputs = tokenizer(
                    text,
                    max_length=model_config["max_input_length"],
                    truncation=True,
                    return_tensors="pt"
                )
                if device == "cuda":
                    inputs = {k: v.to("cuda") for k, v in inputs.items()}

                # FIX: pass **inputs so attention_mask reaches generate();
                # previously only input_ids was forwarded, which triggers a
                # transformers warning and can mis-handle padding. no_grad
                # avoids building an autograd graph during inference.
                with torch.no_grad():
                    if _is_t5_model(model_key):
                        # T5 works better with these parameters
                        summary_ids = model.generate(
                            **inputs,
                            max_length=max_length,
                            min_length=min_length,
                            num_beams=4,
                            length_penalty=2.5,
                            early_stopping=True,
                            no_repeat_ngram_size=3
                        )
                    else:
                        # BART, Pegasus, LED, DistilBART
                        summary_ids = model.generate(
                            **inputs,
                            max_length=max_length,
                            min_length=min_length,
                            num_beams=4,
                            length_penalty=2.0,
                            early_stopping=True
                        )

                summary = tokenizer.decode(
                    summary_ids[0],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                )

            # Fallback if the generated summary is empty or implausibly short
            if not summary or len(summary.strip()) < 20:
                return _fallback_summary(words, max_words)
            return _normalize_summary_length(summary.strip(), words, max_words)
        except Exception as e:
            print(f"Summarization error: {e}")
            # Return truncated original text as the fallback
            return _fallback_summary(words, max_words)
# Global model instance
def get_summarizer():
    """Return the process-wide SummarizationModel singleton."""
    return SummarizationModel()