#!/usr/bin/env python3
"""
Enhanced IndicBERT Processor with Fine-tuning Capabilities
Supports both inference and fine-tuning for Indian language misinformation detection.
"""

import os
import logging
from functools import lru_cache
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed in recent releases
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
    get_linear_schedule_with_warmup
)

logger = logging.getLogger(__name__)


class MisinformationDataset(Dataset):
    """PyTorch Dataset for misinformation detection"""

    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length: int = 512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = int(self.labels[idx])

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }


class EnhancedIndicBERTProcessor:
    """Enhanced IndicBERT processor with fine-tuning and caching"""

    def __init__(self, model_name: str = "ai4bharat/indic-bert", num_labels: int = 2):
        self.model_name = model_name
        self.num_labels = num_labels
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = None
        self.model = None
        self.classification_model = None
        self.is_finetuned = False

        # Embedding caching is handled by the @lru_cache decorator on
        # get_embeddings(); cache_size documents its maxsize.
        self.cache_size = 1000

        logger.info(f"🧠 Initializing Enhanced IndicBERT on {self.device}")
        self._initialize_model()

    def _initialize_model(self, for_classification: bool = False):
        """Initialize IndicBERT model"""
        try:
            # SAFEGUARD: Limit PyTorch CPU threads to 1 to prevent massive
            # MKL memory pool bloat in threaded environments
            torch.set_num_threads(1)

            logger.info(f"Loading tokenizer from {self.model_name}...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            if for_classification:
                logger.info("Loading IndicBERT for sequence classification...")
                self.classification_model = AutoModelForSequenceClassification.from_pretrained(
                    self.model_name,
                    num_labels=self.num_labels
                )
                self.classification_model.to(self.device)
            else:
                logger.info("Loading base IndicBERT model...")
                self.model = AutoModel.from_pretrained(self.model_name)
                self.model.to(self.device)

            logger.info("āœ… IndicBERT loaded successfully")

        except Exception as e:
            logger.error(f"āŒ Failed to load IndicBERT: {e}")
            raise

    @lru_cache(maxsize=1000)
    def get_embeddings(self, text: str) -> np.ndarray:
        """Get IndicBERT embeddings with LRU caching"""
        if not self.model:
            self._initialize_model(for_classification=False)

        try:
            # Tokenize
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Mean-pool the last hidden state into a single sentence vector
            self.model.eval()
            with torch.no_grad():
                outputs = self.model(**inputs)
                embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

            return embeddings.flatten()

        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            return np.random.rand(768)  # fallback: random vector of IndicBERT's hidden size
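    # --- Illustrative helper (not part of the original file) ---
    # A minimal sketch of one common use of the cached embeddings: cosine
    # similarity between two texts. The method name and the zero-vector guard
    # are assumptions, not an established API of this project.
    def embedding_similarity(self, text_a: str, text_b: str) -> float:
        """Cosine similarity between the IndicBERT embeddings of two texts."""
        vec_a = self.get_embeddings(text_a)
        vec_b = self.get_embeddings(text_b)
        norm = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        if norm == 0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / norm)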
    def get_embeddings_batch(self, texts: List[str], batch_size: int = 16) -> np.ndarray:
        """Get embeddings for multiple texts efficiently"""
        if not self.model:
            self._initialize_model(for_classification=False)

        all_embeddings = []
        self.model.eval()

        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                inputs = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                    max_length=512
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                outputs = self.model(**inputs)
                embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
                all_embeddings.append(embeddings)

        return np.vstack(all_embeddings)

    def fine_tune(self,
                  train_texts: List[str],
                  train_labels: List[int],
                  val_texts: Optional[List[str]] = None,
                  val_labels: Optional[List[int]] = None,
                  epochs: int = 3,
                  batch_size: int = 16,
                  learning_rate: float = 2e-5,
                  output_dir: str = "models/finetuned_indicbert",
                  save_steps: int = 500):
        """Fine-tune IndicBERT for misinformation classification.

        Note: save_steps is currently unused; the best checkpoint is written
        once per epoch when validation accuracy improves. Without a validation
        set, call save_model() explicitly after training.
        """
        logger.info("šŸš€ Starting IndicBERT fine-tuning...")
        logger.info(f"Training samples: {len(train_texts)}")
        if val_texts:
            logger.info(f"Validation samples: {len(val_texts)}")

        # Initialize classification model
        self._initialize_model(for_classification=True)

        # Create datasets
        train_dataset = MisinformationDataset(train_texts, train_labels, self.tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        val_loader = None
        if val_texts and val_labels:
            val_dataset = MisinformationDataset(val_texts, val_labels, self.tokenizer)
            val_loader = DataLoader(val_dataset, batch_size=batch_size)

        # Optimizer and scheduler
        optimizer = AdamW(self.classification_model.parameters(), lr=learning_rate)
        total_steps = len(train_loader) * epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps
        )

        # Training loop
        best_val_accuracy = 0.0

        for epoch in range(epochs):
            logger.info(f"\nšŸ“š Epoch {epoch + 1}/{epochs}")

            # Training
            self.classification_model.train()
            train_loss = 0
            train_correct = 0
            train_total = 0

            progress_bar = tqdm(train_loader, desc="Training")
            for batch in progress_bar:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                optimizer.zero_grad()

                outputs = self.classification_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss
                logits = outputs.logits

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.classification_model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                train_loss += loss.item()
                predictions = torch.argmax(logits, dim=1)
                train_correct += (predictions == labels).sum().item()
                train_total += labels.size(0)

                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{train_correct/train_total:.4f}'
                })

            avg_train_loss = train_loss / len(train_loader)
            train_accuracy = train_correct / train_total
            logger.info(f"Training Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.4f}")

            # Validation
            if val_loader:
                val_accuracy, val_loss = self._evaluate(val_loader)
                logger.info(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}")

                # Save best model
                if val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    self.save_model(output_dir)
                    logger.info(f"šŸ’¾ Saved best model (accuracy: {val_accuracy:.4f})")

        self.is_finetuned = True
        logger.info("āœ… Fine-tuning completed!")

        return {
            'final_train_accuracy': train_accuracy,
            'best_val_accuracy': best_val_accuracy if val_loader else None
        }
    def _evaluate(self, data_loader) -> Tuple[float, float]:
        """Evaluate model on validation/test set"""
        self.classification_model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in data_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                outputs = self.classification_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                total_loss += outputs.loss.item()
                predictions = torch.argmax(outputs.logits, dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        accuracy = correct / total
        avg_loss = total_loss / len(data_loader)
        return accuracy, avg_loss

    def predict(self, text: str) -> Dict:
        """Predict misinformation for a single text"""
        if not self.classification_model:
            raise ValueError("Model not trained or loaded. Call fine_tune() or load_model() first.")

        self.classification_model.eval()

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.classification_model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)
            prediction = torch.argmax(logits, dim=1).item()

        return {
            'prediction': 'fake' if prediction == 1 else 'real',
            'confidence': probabilities[0][prediction].item(),
            'probabilities': {
                'real': probabilities[0][0].item(),
                'fake': probabilities[0][1].item()
            }
        }

    def predict_batch(self, texts: List[str], batch_size: int = 16) -> List[Dict]:
        """Predict misinformation for multiple texts"""
        if not self.classification_model:
            raise ValueError("Model not trained or loaded. Call fine_tune() or load_model() first.")

        results = []
        self.classification_model.eval()

        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                inputs = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                    max_length=512
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                outputs = self.classification_model(**inputs)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=1)
                predictions = torch.argmax(logits, dim=1)

                for j, pred in enumerate(predictions):
                    results.append({
                        'prediction': 'fake' if pred.item() == 1 else 'real',
                        'confidence': probabilities[j][pred].item(),
                        'probabilities': {
                            'real': probabilities[j][0].item(),
                            'fake': probabilities[j][1].item()
                        }
                    })

        return results

    def save_model(self, output_dir: str):
        """Save fine-tuned model"""
        os.makedirs(output_dir, exist_ok=True)

        if self.classification_model:
            self.classification_model.save_pretrained(output_dir)
            self.tokenizer.save_pretrained(output_dir)
            logger.info(f"Model saved to {output_dir}")
        else:
            logger.warning("No classification model to save")

    def load_model(self, model_dir: str):
        """Load fine-tuned model"""
        try:
            logger.info(f"Loading fine-tuned model from {model_dir}...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
            self.classification_model = AutoModelForSequenceClassification.from_pretrained(model_dir)
            self.classification_model.to(self.device)
            self.is_finetuned = True
            logger.info("āœ… Fine-tuned model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def quantize_model(self):
        """Apply dynamic quantization for faster inference"""
        if not self.classification_model:
            logger.warning("No classification model to quantize")
            return

        logger.info("šŸ”§ Applying dynamic quantization...")
        self.classification_model = torch.quantization.quantize_dynamic(
            self.classification_model,
            {torch.nn.Linear},
            dtype=torch.qint8
        )
        logger.info("āœ… Model quantized successfully (expected 2-3x speedup)")
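    # --- Illustrative helper (not part of the original file) ---
    # A sketch of a convenience wrapper around _evaluate() for raw texts and
    # labels, reusing MisinformationDataset. Assumes the same 0=real/1=fake
    # labelling scheme as fine_tune(); the method name is an assumption.
    def evaluate_on_texts(self, texts: List[str], labels: List[int],
                          batch_size: int = 16) -> Tuple[float, float]:
        """Return (accuracy, avg_loss) of the classifier on raw texts."""
        if not self.classification_model:
            raise ValueError("Model not trained or loaded. Call fine_tune() or load_model() first.")
        dataset = MisinformationDataset(texts, labels, self.tokenizer)
        loader = DataLoader(dataset, batch_size=batch_size)
        return self._evaluate(loader)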
    def analyze_indian_context(self, text: str) -> Dict:
        """Analyze Indian context and cultural references"""
        text_lower = text.lower()

        # Indian political terms
        political_terms = [
            'modi', 'rahul gandhi', 'bjp', 'congress', 'aap', 'parliament',
            'lok sabha', 'rajya sabha', 'chief minister', 'governor',
            'president', 'prime minister'
        ]

        # Indian cultural terms
        cultural_terms = [
            'bollywood', 'cricket', 'ipl', 'festival', 'diwali', 'holi', 'eid',
            'temple', 'mosque', 'gurudwara', 'church', 'hindu', 'muslim',
            'sikh', 'christian'
        ]

        # Indian economic terms
        economic_terms = [
            'rupee', 'rbi', 'gst', 'demonetization', 'digital india',
            'make in india', 'startup india', 'skill india', 'jan dhan', 'aadhaar'
        ]

        # Indian geographic terms
        geographic_terms = [
            'kashmir', 'punjab', 'kerala', 'tamil nadu', 'maharashtra',
            'gujarat', 'bengal', 'assam', 'bihar', 'uttar pradesh',
            'rajasthan', 'karnataka'
        ]

        analysis = {
            'political_context': sum(1 for term in political_terms if term in text_lower),
            'cultural_context': sum(1 for term in cultural_terms if term in text_lower),
            'economic_context': sum(1 for term in economic_terms if term in text_lower),
            'geographic_context': sum(1 for term in geographic_terms if term in text_lower),
        }

        # Indian relevance score: total matched terms, capped at 10 and scaled to [0, 1]
        total_context = sum(analysis.values())
        analysis['indian_relevance_score'] = min(total_context / 10, 1.0)

        return analysis
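
# --- Minimal usage sketch (not part of the original file) ---
# Illustrates the intended call flow with placeholder data; real fine-tuning
# needs a proper labelled corpus, and the texts/labels below are toy values.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    processor = EnhancedIndicBERTProcessor()

    # Embedding inference (no fine-tuning required)
    vec = processor.get_embeddings("ą¤Æą¤¹ ą¤ą¤• ą¤Ŗą¤°ą„€ą¤•ą„ą¤·ą¤£ ą¤µą¤¾ą¤•ą„ą¤Æ ą¤¹ą„ˆą„¤")
    print(f"Embedding dimension: {vec.shape[0]}")

    # Fine-tuning on placeholder data (assumption: 0 = real, 1 = fake)
    train_texts = ["genuine news text", "fabricated claim text"] * 8
    train_labels = [0, 1] * 8
    processor.fine_tune(train_texts, train_labels, epochs=1, batch_size=4)

    # Classification after fine-tuning
    result = processor.predict("Example headline to classify")
    print(result)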