# Uploaded via huggingface_hub (commit 634567d, verified)
"""
BART (Bidirectional and Auto-Regressive Transformers) Abstractive Summarization
State-of-the-art sequence-to-sequence model for text generation
Professional implementation with comprehensive features
"""
# Handle imports when running directly (python models/bart.py)
# For proper package usage, run as: python -m models.bart
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from transformers import BartForConditionalGeneration, BartTokenizer
import torch
import logging
from typing import Dict, List, Optional, Union
from models.base_summarizer import BaseSummarizer
logger = logging.getLogger(__name__)
class BARTSummarizer(BaseSummarizer):
    """
    BART implementation for abstractive text summarization.

    Model Architecture:
        - Encoder: Bidirectional transformer (like BERT)
        - Decoder: Auto-regressive transformer (like GPT)
        - Pre-trained on denoising tasks

    Key Features:
        - Generates human-like, fluent summaries
        - Can paraphrase and compress information
        - Handles long documents effectively
        - State-of-the-art performance on CNN/DailyMail

    Training Objective:
        Trained to reconstruct original text from corrupted versions:
        - Token masking
        - Token deletion
        - Sentence permutation
        - Document rotation

    Mathematical Foundation:
        Self-Attention: Attention(Q,K,V) = softmax(QK^T/sqrt(d_k))V
        Where Q=Query, K=Key, V=Value, d_k=dimension of keys
    """

    # BART's positional-embedding limit: inputs longer than this are truncated
    # by the tokenizer before encoding.
    MAX_INPUT_TOKENS: int = 1024

    def __init__(self,
                 model_name: str = "facebook/bart-large-cnn",
                 device: Optional[str] = None,
                 use_fp16: bool = False):
        """
        Initialize BART Summarizer.

        Args:
            model_name: HuggingFace model identifier
            device: Computing device ('cuda', 'cpu', or None for auto-detect)
            use_fp16: Use 16-bit floating point for faster inference (requires GPU)

        Raises:
            Exception: Re-raised from HuggingFace if the model/tokenizer
                cannot be downloaded or loaded.
        """
        super().__init__(model_name="BART", model_type="Abstractive")
        logger.info(f"Loading BART model: {model_name}")
        logger.info("Initial model loading may take 2-3 minutes...")

        # Determine device (auto-detect CUDA when not given explicitly).
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info(f"Using device: {self.device}")

        # startswith() also matches explicit device strings like "cuda:0",
        # which a plain == "cuda" comparison would miss.
        if self.device.startswith("cuda"):
            logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
            logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        # Load tokenizer and model; any failure is logged then re-raised so
        # callers see the original exception.
        try:
            self.tokenizer = BartTokenizer.from_pretrained(model_name)
            self.model = BartForConditionalGeneration.from_pretrained(model_name)
            self.model.to(self.device)

            # FP16 halves memory and speeds up inference, but only on GPU.
            if use_fp16 and self.device.startswith("cuda"):
                self.model.half()
                logger.info("Using FP16 precision for faster inference")

            # Evaluation mode: disables dropout so generation is deterministic.
            self.model.eval()
            self.model_name_full = model_name
            self.is_initialized = True
            logger.info("BART model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load BART model: {e}")
            raise

    def summarize(self,
                  text: str,
                  max_length: int = 150,
                  min_length: int = 50,
                  num_beams: int = 4,
                  length_penalty: float = 2.0,
                  no_repeat_ngram_size: int = 3,
                  early_stopping: bool = True,
                  do_sample: bool = False,
                  temperature: float = 1.0,
                  top_k: int = 50,
                  top_p: float = 0.95) -> str:
        """
        Generate abstractive summary using BART.

        Beam Search: Maintains top-k hypotheses at each step.
        Length Penalty: Exponential penalty applied to sequence length.

        Args:
            text: Input text to summarize
            max_length: Maximum summary length in tokens
            min_length: Minimum summary length in tokens
            num_beams: Number of beams for beam search (higher = better quality, slower)
            length_penalty: >1.0 favors longer sequences, <1.0 favors shorter
            no_repeat_ngram_size: Prevent repetition of n-grams
            early_stopping: Stop when num_beams hypotheses are complete
                (beam search only; ignored when do_sample=True)
            do_sample: Use sampling instead of beam search
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only top k tokens for sampling
            top_p: Nucleus sampling threshold

        Returns:
            Generated summary string
        """
        self.validate_input(text)

        # Single sequence: no padding needed. Padding to max_length here would
        # force the encoder to attend over up to 1024 tokens of padding.
        inputs = self.tokenizer(
            text,
            max_length=self.MAX_INPUT_TOKENS,
            truncation=True,
            return_tensors="pt"
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # Shared generation arguments; the decoding-strategy-specific ones are
        # added below so beam-only flags never leak into the sampling path.
        gen_kwargs = {
            "attention_mask": attention_mask,
            "max_length": max_length,
            "min_length": min_length,
            "no_repeat_ngram_size": no_repeat_ngram_size,
        }
        if do_sample:
            # Sampling-based generation (more diverse). early_stopping and
            # length_penalty are beam-search parameters and are omitted here.
            gen_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
        else:
            # Beam search generation (more deterministic, higher quality).
            gen_kwargs.update(
                num_beams=num_beams,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
            )

        with torch.no_grad():
            summary_ids = self.model.generate(input_ids, **gen_kwargs)

        summary = self.tokenizer.decode(
            summary_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        return summary

    def batch_summarize(self,
                        texts: List[str],
                        batch_size: int = 4,
                        max_length: int = 150,
                        min_length: int = 50,
                        **kwargs) -> List[str]:
        """
        Efficiently summarize multiple texts in batches.

        Args:
            texts: List of texts to summarize
            batch_size: Number of texts to process simultaneously
            max_length: Maximum summary length
            min_length: Minimum summary length
            **kwargs: Additional generation parameters; 'num_beams',
                'length_penalty' and 'no_repeat_ngram_size' are honored
                with the same defaults as summarize()

        Returns:
            List of generated summaries (same order as input; empty list
            for empty input)
        """
        logger.info(f"Batch summarizing {len(texts)} texts (batch_size={batch_size})")
        summaries: List[str] = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            # Dynamic padding: pad only to the longest sequence in this batch.
            inputs = self.tokenizer(
                batch,
                max_length=self.MAX_INPUT_TOKENS,
                truncation=True,
                padding=True,
                return_tensors="pt"
            )
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            with torch.no_grad():
                summary_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=max_length,
                    min_length=min_length,
                    # Forward the generation knobs callers may pass, instead of
                    # silently dropping everything but num_beams.
                    num_beams=kwargs.get('num_beams', 4),
                    length_penalty=kwargs.get('length_penalty', 2.0),
                    no_repeat_ngram_size=kwargs.get('no_repeat_ngram_size', 3),
                    early_stopping=True
                )

            summaries.extend(
                self.tokenizer.decode(ids, skip_special_tokens=True)
                for ids in summary_ids
            )
            logger.info(f"Processed batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")

        return summaries

    def get_model_info(self) -> Dict:
        """Return comprehensive model information (architecture, training,
        reported benchmark scores, qualitative strengths/weaknesses)."""
        info = super().get_model_info()
        info.update({
            'algorithm': 'Transformer Encoder-Decoder',
            'architecture': {
                'encoder': 'Bidirectional (BERT-like)',
                'decoder': 'Auto-regressive (GPT-like)',
                'layers': '12 encoder + 12 decoder',
                'attention_heads': 16,
                'hidden_size': 1024,
                'parameters': '406M'
            },
            'training': {
                'objective': 'Denoising autoencoder',
                'noise_functions': [
                    'Token masking',
                    'Token deletion',
                    'Sentence permutation',
                    'Document rotation'
                ],
                'dataset': 'Large-scale web text + CNN/DailyMail fine-tuning'
            },
            'performance': {
                'rouge_1': '44.16',
                'rouge_2': '21.28',
                'rouge_l': '40.90',
                'benchmark': 'CNN/DailyMail test set'
            },
            'advantages': [
                'Generates fluent, human-like summaries',
                'Can paraphrase and compress effectively',
                'Handles long documents well',
                'State-of-the-art performance'
            ],
            'limitations': [
                'May introduce factual errors',
                'Computationally intensive',
                'Requires GPU for fast inference',
                'Black-box nature (less interpretable)'
            ]
        })
        return info

    def __del__(self):
        """Release cached GPU memory when the object is garbage-collected."""
        try:
            if hasattr(self, 'device') and str(self.device).startswith('cuda'):
                torch.cuda.empty_cache()
        except Exception:
            # During interpreter shutdown module globals (torch) may already
            # be torn down; a finalizer must never raise.
            pass
# Test the implementation
def _demo() -> None:
    """Smoke-test the summarizer on a short sample passage and print metrics."""
    sample_text = """
    Machine learning has revolutionized artificial intelligence in recent years.
    Deep learning neural networks can now perform tasks that were impossible just
    a decade ago. Computer vision systems can recognize objects in images with
    superhuman accuracy. Natural language processing models can generate human-like
    text and translate between languages. Reinforcement learning has enabled AI
    to master complex games like Go and StarCraft. These advances have been driven
    by increases in computing power, availability of large datasets, and algorithmic
    innovations. However, challenges remain in areas like explainability, fairness,
    and robustness. The field continues to evolve rapidly with new breakthroughs
    occurring regularly.
    """
    banner = "=" * 70

    print(banner)
    print("BART SUMMARIZER - PROFESSIONAL TEST")
    print(banner)

    # Initialize summarizer (downloads/loads the pretrained model).
    summarizer = BARTSummarizer()

    # Generate summary with metrics.
    result = summarizer.summarize_with_metrics(
        sample_text,
        max_length=100,
        min_length=30,
        num_beams=4
    )
    meta = result['metadata']

    print(f"\nModel: {meta['model_name']}")
    print(f"Type: {meta['model_type']}")
    print(f"Device: {summarizer.device}")
    print(f"Input Length: {meta['input_length']} words")
    print(f"Summary Length: {meta['summary_length']} words")
    print(f"Compression Ratio: {meta['compression_ratio']:.2%}")
    print(f"Processing Time: {meta['processing_time']:.4f} seconds")
    print(f"\n{'Generated Summary:':-^70}")
    print(result['summary'])
    print("\n" + banner)

    model_info = summarizer.get_model_info()
    print(f"Architecture: {model_info['architecture']}")
    print(f"Performance: {model_info['performance']}")


if __name__ == "__main__":
    _demo()