# File size: 7,001 Bytes | bf5067d
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from pathlib import Path
from functools import lru_cache
class EnhancedBertForSequenceClassification(nn.Module):
    """BERT encoder topped with a BiLSTM, self-attention and an MLP classifier.

    Pipeline: BERT token states -> dropout -> 2-layer bidirectional LSTM ->
    layer norm -> multi-head self-attention -> max over the sequence axis ->
    three-layer feed-forward head producing ``num_classes`` logits.

    The sub-module attribute names (``bert``, ``lstm``, ``attention``,
    ``classifier``, ``layer_norm``) are the keys of the saved ``state_dict``
    and must not be renamed.

    Args:
        model_name: Name or path handed to ``BertModel.from_pretrained``.
        num_classes: Width of the final linear layer (number of logits).
        dropout: Dropout probability applied to the BERT output and inside
            the classification head.
    """

    def __init__(self, model_name='bert-base-uncased', num_classes=2, dropout=0.3):
        super().__init__()
        self.num_classes = num_classes

        # Each LSTM direction emits `lstm_hidden` features, so everything
        # downstream of the bidirectional LSTM works on twice that width.
        lstm_hidden = 256
        feature_dim = 2 * lstm_hidden

        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)

        # Recurrent layer stacked on top of the BERT token representations.
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.2,
        )

        # Self-attention over the LSTM features. It is built without
        # batch_first, which is why forward() swaps the first two axes
        # around the call.
        self.attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=8,
            dropout=0.1,
        )

        # Feed-forward head; the layer positions inside the Sequential
        # (0, 3, 6 for the linears) are part of the state_dict keys.
        head_layers = [
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Normalisation applied to the LSTM output before attention.
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, input_ids, attention_mask):
        """Compute classification logits for a batch of token ids.

        Args:
            input_ids: Token id tensor forwarded to BERT.
            attention_mask: Mask tensor forwarded to BERT.

        Returns:
            The output of the classification head, one row of logits per
            batch element.
        """
        # NOTE(review): attention_mask is consumed by BERT only; the LSTM,
        # the self-attention and the max-pool below see every position,
        # padding included. Presumably training did the same - confirm
        # before introducing any masking here.
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Per-token hidden states, regularised with dropout.
        token_states = self.dropout(encoded.last_hidden_state)

        # Recurrent pass followed by layer normalisation.
        recurrent_states, _ = self.lstm(token_states)
        normalised = self.layer_norm(recurrent_states)

        # Swap the batch and sequence axes for the attention module,
        # attend with query = key = value, then swap back.
        seq_first = normalised.transpose(0, 1)
        attended, _ = self.attention(seq_first, seq_first, seq_first)
        batch_first = attended.transpose(0, 1)

        # Keep the per-feature maximum across the sequence axis.
        pooled = batch_first.max(dim=1).values

        return self.classifier(pooled)
@lru_cache(maxsize=1)
def get_model():
    """Load the fine-tuned model, its tokenizer and the raw checkpoint.

    The result is memoised with ``lru_cache(maxsize=1)``, so the files are
    read on the first call only and every later call returns the same
    objects.

    The model directory is ``enhanced_bert_welfake_model`` located three
    levels above this file; it must contain the tokenizer files and
    ``model.pth``.

    Returns:
        tuple: ``(model, tokenizer, checkpoint)`` where ``model`` is an
        ``EnhancedBertForSequenceClassification`` in eval mode on the
        selected device, ``tokenizer`` is the ``BertTokenizer`` loaded from
        the model directory and ``checkpoint`` is the object returned by
        ``torch.load``.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    model_dir = Path(__file__).parent.parent.parent / "enhanced_bert_welfake_model"

    tokenizer = BertTokenizer.from_pretrained(str(model_dir))

    # Prefer the GPU when one is visible; tensors in the checkpoint are
    # remapped to the same device while loading.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): torch.load unpickles the file. That is acceptable for a
    # checkpoint shipped with the project; do not point this at untrusted
    # files without restricting it (e.g. weights_only=True, if the
    # checkpoint contents allow it).
    checkpoint = torch.load(
        model_dir / "model.pth",
        map_location=device,
    )

    # Architecture hyper-parameters stored alongside the weights, with the
    # class defaults as fallback when the checkpoint does not carry them.
    num_classes = checkpoint.get('num_classes', 2)
    model_config = checkpoint.get('config', {})
    dropout = model_config.get('dropout', 0.3)
    model_name = model_config.get('model_name', 'bert-base-uncased')

    # Rebuild the architecture, then overwrite its weights with the
    # fine-tuned ones.
    model = EnhancedBertForSequenceClassification(
        model_name=model_name,
        num_classes=num_classes,
        dropout=dropout,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # inference only: disables dropout

    return model, tokenizer, checkpoint
def _resolve_device(model):
    """Return the device holding *model*'s parameters, or the default device."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        # Not an nn.Module (or it has no parameters): fall back to the same
        # choice get_model() makes when it loads the model.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def predict_fake_news(text: str, model=None, tokenizer=None, checkpoint=None):
    """Predict whether a news article is fake or real.

    Args:
        text: News article text, either a headline/claim alone or already in
            the ``title [SEP] text`` format.
        model: Pre-loaded model (optional). Called as
            ``model(input_ids, attention_mask)`` and expected to return one
            row of logits.
        tokenizer: Pre-loaded tokenizer (optional).
        checkpoint: Checkpoint mapping with the ``'num_classes'`` and
            ``'classification_type'`` metadata (optional). When it is
            missing or empty, binary classification with 2 classes is
            assumed.

    If ``model`` or ``tokenizer`` is ``None``, all three objects are taken
    from ``get_model()`` and any supplied ``checkpoint`` is ignored.

    Returns:
        dict: With the keys ``"text"`` (the original input),
        ``"prediction"`` (label of the arg-max class), ``"confidence"``
        (its probability), ``"probabilities"`` (label -> probability for
        every model output), ``"is_fake"`` and ``"classification_type"``.
    """
    if model is None or tokenizer is None:
        model, tokenizer, checkpoint = get_model()

    # Send the inputs to wherever the model actually lives instead of
    # guessing; a caller may pass a CPU model on a CUDA-capable machine.
    device = _resolve_device(model)

    # Match the "title [SEP] text" format: when only a headline/claim was
    # given, it is used as both title and body.
    if '[SEP]' in text:
        formatted_text = text
    else:
        formatted_text = f"{text} [SEP] {text}"

    # Classification metadata from the checkpoint, defaulting to binary.
    metadata = checkpoint if checkpoint else {}
    num_classes = metadata.get('num_classes', 2)
    classification_type = metadata.get('classification_type', 'binary')

    # Label mapping based on classification type.
    # NOTE: WELFake dataset uses:
    #   0 = real (legitimate news)
    #   1 = fake (fake/misleading news)
    if classification_type == 'binary' and num_classes == 2:
        labels = {
            0: "real",
            1: "fake",
        }
    elif num_classes == 6:
        labels = {
            0: "pants-fire",
            1: "false",
            2: "barely-true",
            3: "half-true",
            4: "mostly-true",
            5: "true",
        }
    else:
        labels = {i: f"class_{i}" for i in range(num_classes)}

    # Always pad/truncate to 512 tokens. The sequence length is kept fixed
    # on purpose - presumably it mirrors the training setup; confirm before
    # switching to dynamic padding.
    encoding = tokenizer(
        formatted_text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

    class_probs = probabilities[0].tolist()
    confidence = class_probs[predicted_class]

    # Iterate over what the model really produced rather than over the
    # metadata's num_classes: the two can disagree when no checkpoint was
    # supplied, which used to raise KeyError or drop probabilities. Outputs
    # without a known label get a generic "class_<i>" name.
    def label_of(index):
        return labels.get(index, f"class_{index}")

    prob_dict = {label_of(i): float(p) for i, p in enumerate(class_probs)}

    if classification_type == 'binary':
        is_fake = predicted_class == 1  # class 1 is "fake" in WELFake dataset
    else:
        # pants-fire, false, barely-true are considered fake.
        # NOTE(review): this threshold is also applied to the generic
        # "class_<i>" label scheme, where its meaning is unclear - confirm.
        is_fake = predicted_class < 3

    return {
        "text": text,  # original text, not the formatted variant
        "prediction": label_of(predicted_class),
        "confidence": float(confidence),
        "probabilities": prob_dict,
        "is_fake": is_fake,
        "classification_type": classification_type,
    }
|