Spaces:
Sleeping
Sleeping
"""
Summarization Model Configuration
===================================
Manages AI model selection and inference for article summarization.

English Pipeline:
    Uses configurable models (default: t5-small) via MODEL_TYPE env var.
    Options: t5-small, distilbart, bart, t5, pegasus, led

Hindi Pipeline (via english_summary.py fallback only):
    Uses L3Cube-Pune/Hindi-BART-Summary (override via HINDI_MODEL_NAME env var).
    NOTE: The recommended Hindi path is hindi_summary.py (mT5 ONNX + Groq).
    This model is only used if summarize() is called directly on Hindi articles.

Architecture:
    SummarizationModel is a thread-safe singleton. Each language's model is
    loaded lazily on first use and cached in memory for the batch run.

Usage:
    from backend.summarization.model import get_summarizer
    summarizer = get_summarizer()
    summary = summarizer.summarize(article_text, max_words=150, language="english")
"""
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
| import os | |
| import threading | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# ─────────────────────────────────────────────
# Device Configuration
# ─────────────────────────────────────────────
# DEVICE is 0 when the first CUDA device should be used and -1 for CPU.
# USE_GPU records whether a GPU was *requested*, not whether one is available.
ENV_DEVICE = os.getenv('DEVICE', 'cpu').lower()

if ENV_DEVICE == 'gpu':
    USE_GPU = True
    DEVICE = 0 if torch.cuda.is_available() else -1
    if DEVICE == -1:
        print("Warning: GPU requested but CUDA not available, falling back to CPU")
elif ENV_DEVICE == 'cpu':
    USE_GPU = False
    DEVICE = -1
else:
    # Any other value is treated as "auto": CUDA when available, otherwise CPU.
    USE_GPU = True
    DEVICE = 0 if torch.cuda.is_available() else -1


def _cpu_thread_count(default: int = 4) -> int:
    """Read the CPU thread count from the MAX_WORKERS environment variable.

    Args:
        default: Value used when MAX_WORKERS is unset or cannot be parsed.

    Returns:
        A thread count of at least 1. Empty or non-numeric values fall back
        to ``default`` instead of raising at import time.
    """
    raw = os.getenv('MAX_WORKERS', str(default))
    try:
        count = int(raw)
    except (TypeError, ValueError):
        print(f"Warning: Invalid MAX_WORKERS '{raw}', using {default}")
        return default
    # torch rejects thread counts below 1.
    return max(1, count)


# CPU Optimization
if DEVICE == -1:
    CPU_THREADS = _cpu_thread_count()
    torch.set_num_threads(CPU_THREADS)
    try:
        torch.set_num_interop_threads(CPU_THREADS)
    except RuntimeError as exc:
        # The inter-op pool can only be configured once per process (e.g. this
        # module being reloaded); keep whatever is already in effect.
        print(f"Warning: could not set inter-op threads: {exc}")
# ─────────────────────────────────────────────
# Model Selection
# ─────────────────────────────────────────────
MODEL_TYPE = os.getenv('MODEL_TYPE', 't5-small').lower()
HINDI_MODEL_NAME = os.getenv('HINDI_MODEL_NAME', 'L3Cube-Pune/Hindi-BART-Summary')

# English checkpoints selectable through MODEL_TYPE.
#   name             - checkpoint identifier passed to from_pretrained()
#   max_length       - upper bound on the generation max_length
#   min_length       - baseline for the generation min_length
#   max_input_length - tokenizer truncation limit for the article text
MODELS = {
    "t5-small": dict(
        name="t5-small",
        max_length=300,
        min_length=80,
        max_input_length=1024,
        description="T5 Small - Fast CPU inference, ~240MB (BEST FOR GITHUB ACTIONS)",
    ),
    "distilbart": dict(
        name="sshleifer/distilbart-cnn-12-6",
        max_length=130,
        min_length=30,
        max_input_length=1024,
        description="DistilBART - Faster than BART, ~600MB (GOOD FOR GITHUB ACTIONS)",
    ),
    "bart": dict(
        name="facebook/bart-large-cnn",
        max_length=130,
        min_length=30,
        max_input_length=1024,
        description="BART - Good balance of speed and quality, ~1.6GB",
    ),
    "t5": dict(
        name="t5-base",
        max_length=150,
        min_length=30,
        max_input_length=512,
        description="T5 Base - Versatile text-to-text model, ~850MB",
    ),
    "pegasus": dict(
        name="google/pegasus-xsum",
        max_length=128,
        min_length=32,
        max_input_length=512,
        description="Pegasus - Optimized for news summarization, ~2.2GB",
    ),
    "led": dict(
        name="allenai/led-base-16384",
        max_length=150,
        min_length=30,
        max_input_length=4096,
        description="LED - Best for long documents, ~500MB",
    ),
}

# An unrecognised MODEL_TYPE is reported and replaced by the default.
if MODEL_TYPE not in MODELS:
    valid_models = ", ".join(MODELS)
    print(
        "\n".join(
            (
                f"Warning: Invalid MODEL_TYPE '{MODEL_TYPE}' in .env",
                f"Valid options: {valid_models}",
                "Falling back to default: t5-small",
            )
        )
    )
    MODEL_TYPE = "t5-small"

# Per-language entries that take precedence over MODELS[MODEL_TYPE].
LANGUAGE_MODEL_OVERRIDES = {
    "hindi": dict(
        name=HINDI_MODEL_NAME,
        max_length=220,
        min_length=70,
        max_input_length=1024,
        description="Hindi-BART-Summary by L3Cube Pune",
        is_t5=False,
    ),
}
def _fallback_summary(words, max_words: int) -> str:
    """Build a summary by truncating the original article.

    Args:
        words: The article already split into words.
        max_words: Maximum number of words to keep.

    Returns:
        The first ``max_words`` words joined by single spaces, or an empty
        string when ``max_words`` is zero or negative.
    """
    # A negative limit would slice from the end and return almost the whole
    # article; treat it as "no words" instead.
    limit = max(0, max_words)
    return " ".join(words[:limit]).strip()
def _normalize_summary_length(summary: str, original_words, max_words: int) -> str:
    """Clamp a generated summary to the requested word budget.

    Args:
        summary: Text produced by the model.
        original_words: The source article split into words, used for the
            truncation fallback.
        max_words: Maximum number of words allowed in the result.

    Returns:
        ``summary`` cut to at most ``max_words`` words. When the summary is
        empty or shorter than the minimum acceptable length, the truncated
        original article is returned instead.
    """
    if not summary:
        return _fallback_summary(original_words, max_words)

    summary_words = summary.split()
    if len(summary_words) > max_words:
        summary_words = summary_words[:max_words]
        summary = " ".join(summary_words)

    # The floor is 35 words or 55% of the budget, whichever is larger, but it
    # must never exceed the budget itself: otherwise any max_words below 35
    # would reject every summary, since the text was just cut to max_words.
    min_words = min(max_words, max(35, int(max_words * 0.55)))
    if len(summary_words) < min_words:
        return _fallback_summary(original_words, max_words)
    return summary
def _language_model_config(language: str):
    """Pick the model configuration for a language.

    Args:
        language: Language name; ``None`` or an empty string means English.
            Matching ignores case and surrounding whitespace.

    Returns:
        A ``(config, key)`` tuple. Languages listed in
        LANGUAGE_MODEL_OVERRIDES return their own entry and normalized name;
        every other language returns ``MODELS[MODEL_TYPE]`` with the key
        ``"english"``.
    """
    normalized = (language or "english").strip().lower()
    try:
        return LANGUAGE_MODEL_OVERRIDES[normalized], normalized
    except KeyError:
        # No dedicated model for this language: use the configured English one.
        return MODELS[MODEL_TYPE], "english"
def _is_t5_model(language: str) -> bool:
    """Report whether the model chosen for ``language`` is a T5 variant.

    Args:
        language: Language name; ``None`` or an empty string means English.
            Matching ignores case and surrounding whitespace.

    Returns:
        For languages in LANGUAGE_MODEL_OVERRIDES, the entry's ``"is_t5"``
        flag (False when absent). Otherwise, whether MODEL_TYPE starts with
        ``"t5"``.
    """
    normalized = (language or "english").strip().lower()
    if normalized not in LANGUAGE_MODEL_OVERRIDES:
        return MODEL_TYPE.startswith("t5")
    return LANGUAGE_MODEL_OVERRIDES[normalized].get("is_t5", False)
class SummarizationModel:
    """Singleton that loads and runs HuggingFace seq2seq summarization models.

    One instance is shared by every caller. A model is loaded the first time
    summarize() needs it for a given language and is then kept in
    ``self._models``. Model loading and tokenization/generation are
    serialized with ``self._lock``.
    """

    _instance = None
    # Guards first-time construction so concurrent callers share one instance.
    _instance_lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            with cls._instance_lock:
                # Re-check under the lock: another thread may have finished
                # construction while this one was waiting.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._lock = threading.Lock()
                    instance._models = {}
                    # Publish only after the attributes exist so no caller
                    # can observe a half-initialized instance.
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        """Keep state created by __new__; __init__ runs on every construction."""
        if not hasattr(self, "_models"):
            self._models = {}

    @staticmethod
    def _generation_lengths(max_words: int, model_config) -> tuple:
        """Compute ``(max_length, min_length)`` for generate().

        Args:
            max_words: Requested summary length in words.
            model_config: Mapping with ``"max_length"`` and ``"min_length"``.

        Returns:
            ``max_length`` is twice ``max_words`` capped by the model config.
            ``min_length`` is the larger of the config minimum and half of
            ``max_words``, kept 20 below ``max_length`` where possible, at
            least 20, and never above ``max_length``.
        """
        max_length = min(int(max_words * 2.0), model_config["max_length"])
        min_length = min(
            max(model_config["min_length"], int(max_words * 0.5)),
            max_length - 20,
        )
        min_length = max(20, min_length)
        # For very small budgets the floor of 20 can exceed max_length, which
        # is an unsatisfiable constraint for generation; clamp it.
        min_length = min(min_length, max_length)
        return max_length, min_length

    def _load_model(self, language: str):
        """Load the tokenizer and model for ``language`` into ``self._models``.

        summarize() calls this while holding ``self._lock``.

        Args:
            language: Language name resolved through _language_model_config.

        Raises:
            Exception: Any error raised while loading is printed and re-raised.
        """
        model_config, model_key = _language_model_config(language)
        device_name = "GPU (CUDA)" if DEVICE == 0 else "CPU"
        if DEVICE == 0 and torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            print(f"Using device: {device_name} ({gpu_name})")
        else:
            print(f"Using device: {device_name}")
            if USE_GPU and not torch.cuda.is_available():
                print("Warning: GPU requested but CUDA not available, falling back to CPU")
        print(f"Loading model: {model_config['name']}")
        print(f"Description: {model_config['description']}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_config["name"])
            model = AutoModelForSeq2SeqLM.from_pretrained(model_config["name"])
            if DEVICE == 0:
                model = model.to("cuda")
            self._models[model_key] = {
                "tokenizer": tokenizer,
                "model": model,
                "config": model_config,
                "device": "cuda" if DEVICE == 0 else "cpu",
            }
            print("Model loaded successfully!\n")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def summarize(self, text: str, max_words: int = 80, language: str = "english") -> str:
        """Generate a summary of the input text.

        Args:
            text: The article body text to summarize.
            max_words: Maximum number of words in the output summary.
            language: "english" or "hindi" — determines which model to use.

        Returns:
            Summary string. Empty/blank input and articles shorter than 40
            words are returned unchanged. If generation fails or yields an
            unusably short result, the first ``max_words`` words of the
            original article are returned instead.

        Raises:
            Exception: Errors raised while loading the model propagate to the
                caller; only tokenization/generation errors trigger the
                fallback.
        """
        if not text or not text.strip():
            return text

        words = text.split()
        # Cap the text handed to the tokenizer at the first 600 words.
        max_input_words = 600
        if len(words) > max_input_words:
            text = " ".join(words[:max_input_words])
        if len(words) < 40:
            return text

        model_config, model_key = _language_model_config(language)
        if model_key not in self._models:
            with self._lock:
                # Re-check under the lock so only one thread loads the model.
                if model_key not in self._models:
                    self._load_model(model_key)

        model_bundle = self._models[model_key]
        tokenizer = model_bundle["tokenizer"]
        model = model_bundle["model"]
        device = model_bundle["device"]

        use_t5 = _is_t5_model(model_key)
        if use_t5:
            # T5 checkpoints are prompted with a task prefix.
            text = "summarize: " + text

        max_length, min_length = self._generation_lengths(max_words, model_config)
        generate_kwargs = {
            "max_length": max_length,
            "min_length": min_length,
            "num_beams": 4,
            "early_stopping": True,
        }
        if use_t5:
            generate_kwargs["length_penalty"] = 2.5
            generate_kwargs["no_repeat_ngram_size"] = 3
        else:
            generate_kwargs["length_penalty"] = 2.0

        try:
            with self._lock:
                inputs = tokenizer(
                    text,
                    max_length=model_config["max_input_length"],
                    truncation=True,
                    return_tensors="pt",
                )
                if device == "cuda":
                    inputs = {k: v.to("cuda") for k, v in inputs.items()}
                summary_ids = model.generate(inputs["input_ids"], **generate_kwargs)
                summary = tokenizer.decode(
                    summary_ids[0],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            if not summary or len(summary.strip()) < 20:
                return _fallback_summary(words, max_words)
            return _normalize_summary_length(summary.strip(), words, max_words)
        except Exception as e:
            # Best effort: a failed generation must not abort the batch run.
            print(f"Summarization error: {e}")
            return _fallback_summary(words, max_words)
def get_summarizer():
    """Return the shared SummarizationModel, creating it on first use."""
    summarizer = SummarizationModel()
    return summarizer