Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import BertTokenizer, BertModel | |
| from pathlib import Path | |
| from functools import lru_cache | |
class EnhancedBertForSequenceClassification(nn.Module):
    """BERT encoder -> BiLSTM -> layer norm -> self-attention -> max-pool -> MLP head.

    The submodule attribute names (``bert``, ``dropout``, ``lstm``, ``attention``,
    ``classifier``, ``layer_norm``) and the layer order inside ``classifier`` are
    the keys that ``load_state_dict`` matches saved weights against, so they must
    stay stable.

    Args:
        model_name: Identifier or path handed to ``BertModel.from_pretrained``.
        num_classes: Number of logits produced by the final linear layer.
        dropout: Dropout probability used after the BERT output and between the
            layers of the classification head.
    """

    def __init__(self, model_name='bert-base-uncased', num_classes=2, dropout=0.3):
        super().__init__()
        self.num_classes = num_classes

        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)

        # A bidirectional LSTM with 256 units per direction emits 2 * 256 = 512
        # features per token; attention, layer norm and the head are sized to it.
        lstm_units = 256
        features = 2 * lstm_units

        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_units,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True,
        )
        # Constructed without batch_first, so it expects sequence-first tensors;
        # forward() transposes around the call.
        self.attention = nn.MultiheadAttention(
            embed_dim=features,
            num_heads=8,
            dropout=0.1,
        )

        head_layers = [
            nn.Linear(features, 256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self.layer_norm = nn.LayerNorm(features)

    def forward(self, input_ids, attention_mask):
        """Compute raw (unnormalised) class logits.

        Args:
            input_ids: Token ids fed to BERT.
            attention_mask: BERT attention mask.
                NOTE(review): the mask is only given to BERT; the LSTM, the
                self-attention and the max-pool below also see padded positions.
                Left unchanged, presumably matching how checkpoints were trained.

        Returns:
            Output of ``classifier`` applied to the pooled features, one row per
            sequence in the batch.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        tokens = self.dropout(encoded.last_hidden_state)

        recurrent, _ = self.lstm(tokens)
        recurrent = self.layer_norm(recurrent)

        # The LSTM is batch-first; the attention layer is sequence-first.
        seq_first = recurrent.transpose(0, 1)
        attended, _ = self.attention(seq_first, seq_first, seq_first)
        attended = attended.transpose(0, 1)

        # Max over the token axis (dim 1 of the batch-first tensor).
        pooled, _ = attended.max(dim=1)
        return self.classifier(pooled)
@lru_cache(maxsize=1)
def get_model():
    """Load the fine-tuned BERT model, its tokenizer and the raw checkpoint.

    Memoised with ``functools.lru_cache``: the files are read from disk only on the
    first call. Later calls return the very same (model, tokenizer, checkpoint)
    objects. Previously ``lru_cache`` was imported but never applied, so every
    call reloaded everything.

    Files are read from ``enhanced_bert_welfake_model``, located at
    ``Path(__file__).parent.parent.parent``. That is two levels above the directory
    containing this module. Tokenizer files are loaded from that directory and
    weights plus metadata from ``model.pth`` inside it.

    Returns:
        tuple: (model, tokenizer, checkpoint_info)
            - model: ``EnhancedBertForSequenceClassification`` with the checkpoint's
              weights, moved to CUDA if available (else CPU) and set to eval mode.
            - tokenizer: the ``BertTokenizer`` for the model directory.
            - checkpoint_info: the loaded checkpoint dict. It is shared by every
              caller because of the cache, so treat it as read-only.
    """
    model_path = Path(__file__).parent.parent.parent / "enhanced_bert_welfake_model"

    tokenizer = BertTokenizer.from_pretrained(str(model_path))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): torch.load unpickles the file; only ever point this at a
    # trusted, locally produced checkpoint.
    checkpoint = torch.load(
        model_path / "model.pth",
        map_location=device,
    )

    # Rebuild the architecture from the hyper-parameters stored with the weights,
    # so the state dict below matches layer-for-layer.
    model_config = checkpoint.get('config', {})
    model = EnhancedBertForSequenceClassification(
        model_name=model_config.get('model_name', 'bert-base-uncased'),
        num_classes=checkpoint.get('num_classes', 2),
        dropout=model_config.get('dropout', 0.3),
    )

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, tokenizer, checkpoint
def predict_fake_news(text: str, model=None, tokenizer=None, checkpoint=None):
    """Predict whether a news article is fake or real.

    If either ``model`` or ``tokenizer`` is missing, all three of model, tokenizer
    and checkpoint are taken from ``get_model()``. This replaces any values the
    caller did supply.

    Args:
        text: News article text. Either a bare title/claim, or the training format
            ``"title [SEP] body"``. Text without ``[SEP]`` is duplicated as
            ``"text [SEP] text"`` to mimic the training format.
        model: Pre-loaded model (optional). Called as
            ``model(input_ids, attention_mask)``; its logits are soft-maxed over dim 1.
        tokenizer: Pre-loaded tokenizer (optional).
        checkpoint: Checkpoint dict with metadata (optional). ``num_classes`` and
            ``classification_type`` are read from it when present. Otherwise
            ``num_classes`` defaults to the width of the model's output, and the type
            to ``'binary'`` for two classes or ``'multiclass'`` for anything else.

    Returns:
        dict: ``text`` (the original input, not the formatted one), ``prediction``
        (label of the arg-max class), ``confidence`` (its probability),
        ``probabilities`` (label -> probability), ``is_fake`` (bool) and
        ``classification_type``.
    """
    if model is None or tokenizer is None:
        model, tokenizer, checkpoint = get_model()
    meta = checkpoint or {}

    # Put the inputs on the device the model actually lives on. Guessing "cuda if
    # available" breaks a caller-supplied CPU model on a CUDA machine. Fall back to
    # that guess only for models that expose no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Match the training format "title [SEP] text"; a bare headline/claim is used
    # as both title and body.
    formatted_text = text if '[SEP]' in text else f"{text} [SEP] {text}"

    # padding='max_length' is kept on purpose. EnhancedBertForSequenceClassification
    # passes the attention mask only to BERT, so padded positions still reach its
    # LSTM/attention/max-pool and the padded length can influence the logits.
    encoding = tokenizer(
        formatted_text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=1)

    # Without metadata, trust the model's own output width instead of assuming 2.
    # Assuming 2 made any non-binary model raise KeyError on labels[predicted_class].
    num_classes = meta.get('num_classes', probabilities.shape[1])
    # 'multiclass' is only a placeholder name for a non-binary default; the name
    # the training code uses for such checkpoints is not visible here.
    classification_type = meta.get(
        'classification_type', 'binary' if num_classes == 2 else 'multiclass'
    )

    # NOTE: WELFake dataset uses 0 = real (legitimate news), 1 = fake.
    if classification_type == 'binary' and num_classes == 2:
        labels = {0: "real", 1: "fake"}
    elif num_classes == 6:
        labels = {
            0: "pants-fire",
            1: "false",
            2: "barely-true",
            3: "half-true",
            4: "mostly-true",
            5: "true",
        }
    else:
        labels = {i: f"class_{i}" for i in range(num_classes)}

    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0][predicted_class].item()
    prob_dict = {labels[i]: float(probabilities[0][i].item()) for i in range(num_classes)}

    if classification_type == 'binary':
        is_fake = predicted_class == 1  # class 1 is "fake" in WELFake
    else:
        is_fake = predicted_class < 3  # pants-fire, false, barely-true count as fake

    return {
        "text": text,  # original text, not the [SEP]-formatted one
        "prediction": labels[predicted_class],
        "confidence": float(confidence),
        "probabilities": prob_dict,
        "is_fake": is_fake,
        "classification_type": classification_type,
    }