Fake / Backend /app /models /bert_model.py
Ravi1212's picture
uploaded all the dependencies
bf5067d verified
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from pathlib import Path
from functools import lru_cache
class EnhancedBertForSequenceClassification(nn.Module):
    """BERT encoder topped with a BiLSTM, self-attention and an MLP head.

    Data flow: BERT token states -> dropout -> 2-layer bidirectional LSTM
    (256 units per direction, 512 features in total) -> LayerNorm ->
    8-head self-attention -> max over the sequence axis -> 3-layer classifier.

    NOTE(review): ``attention_mask`` is only handed to BERT. The LSTM, the
    self-attention and the max-pool all see padded positions as well, so the
    output depends on how much padding an input carries. Presumably training
    padded the same way inference does; keep them consistent (TODO confirm).

    Args:
        model_name: Identifier passed to ``BertModel.from_pretrained``.
        num_classes: Width of the final logits layer.
        dropout: Dropout probability applied to BERT's output and inside the
            classifier head.
    """

    def __init__(self, model_name='bert-base-uncased', num_classes=2, dropout=0.3):
        super().__init__()
        lstm_units = 256                 # hidden size per LSTM direction
        feature_dim = 2 * lstm_units     # bidirectional output width (512)

        self.num_classes = num_classes

        # Attribute names must not change: they are the state_dict keys the
        # saved checkpoint is loaded against.
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_units,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True,
        )

        # Built without batch_first, so it expects (seq, batch, dim) inputs.
        self.attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=8,
            dropout=0.1,
        )

        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalised) class logits; last dim is ``num_classes``.

        Args:
            input_ids: Token ids, presumably (batch, seq_len) from a BERT
                tokenizer — TODO confirm.
            attention_mask: Mask matching ``input_ids``; consumed by BERT only.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        encoded = self.dropout(encoded)

        recurrent, _ = self.lstm(encoded)
        recurrent = self.layer_norm(recurrent)

        # Swap batch and sequence axes for the seq-first attention layer,
        # then swap them back afterwards.
        seq_first = recurrent.permute(1, 0, 2)
        attended, _ = self.attention(seq_first, seq_first, seq_first)
        attended = attended.permute(1, 0, 2)

        # Max-pool across the sequence axis.
        pooled = attended.max(dim=1).values
        return self.classifier(pooled)
@lru_cache(maxsize=1)
def get_model():
    """
    Load the fine-tuned BERT model and tokenizer.

    The result is cached with ``lru_cache(maxsize=1)``, so files are read only
    on the first call and later calls return the same objects.

    Returns:
        tuple: (model, tokenizer, checkpoint_info) where
            - model is an ``EnhancedBertForSequenceClassification`` in eval
              mode on CUDA when available, otherwise CPU;
            - tokenizer is the ``BertTokenizer`` stored in the model directory;
            - checkpoint_info is the checkpoint dict without its
              ``'model_state_dict'`` entry (metadata such as ``'num_classes'``,
              ``'classification_type'`` and ``'config'`` when present).

    Raises:
        FileNotFoundError: if ``model.pth`` is missing from the model directory.
    """
    # Three directories above this file — presumably the backend root; verify
    # if this module is moved.
    model_path = Path(__file__).parent.parent.parent / "enhanced_bert_welfake_model"
    weights_file = model_path / "model.pth"
    if not weights_file.is_file():
        raise FileNotFoundError(f"Model checkpoint not found: {weights_file}")

    tokenizer = BertTokenizer.from_pretrained(str(model_path))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): torch.load unpickles the file; acceptable only because the
    # checkpoint ships with the application and is not user-supplied.
    checkpoint = torch.load(weights_file, map_location=device)

    # Rebuild the architecture from the checkpoint's metadata so the state
    # dict shapes line up.
    model_config = checkpoint.get('config', {})
    model = EnhancedBertForSequenceClassification(
        model_name=model_config.get('model_name', 'bert-base-uncased'),
        num_classes=checkpoint.get('num_classes', 2),
        dropout=model_config.get('dropout', 0.3),
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # load_state_dict copies the tensors into the model's own parameters, so
    # returning the raw checkpoint would let lru_cache pin a second full copy
    # of the weights (in GPU memory when CUDA is used). Keep metadata only.
    checkpoint_info = {
        key: value for key, value in checkpoint.items() if key != 'model_state_dict'
    }
    return model, tokenizer, checkpoint_info
def _resolve_device(model):
    """Return the device holding ``model``'s parameters.

    Falls back to CUDA when available, else CPU, if the model exposes no
    parameters (e.g. a plain callable).
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _labels_and_fake_rule(num_classes, classification_type):
    """Pick the label names and the "counts as fake" rule for a class count.

    Returns:
        tuple: (labels, is_fake_class) where ``labels`` maps class index to
        label name and ``is_fake_class(index)`` returns a bool.
    """
    # WELFake binary scheme: 0 = real (legitimate news), 1 = fake.
    if classification_type == 'binary' and num_classes == 2:
        return {0: "real", 1: "fake"}, lambda idx: idx == 1

    # Six-grade truthfulness scale; the three least truthful grades
    # (pants-fire, false, barely-true) count as fake.
    if num_classes == 6:
        labels = {
            0: "pants-fire",
            1: "false",
            2: "barely-true",
            3: "half-true",
            4: "mostly-true",
            5: "true",
        }
        return labels, lambda idx: idx < 3

    # Unknown scheme: generic names, keeping the legacy fake rule.
    labels = {i: f"class_{i}" for i in range(num_classes)}
    if classification_type == 'binary':
        return labels, lambda idx: idx == 1
    return labels, lambda idx: idx < 3


def predict_fake_news(text: str, model=None, tokenizer=None, checkpoint=None):
    """
    Predict whether a news article is fake or real.

    Args:
        text: News article text (title only, or "title [SEP] text" format).
        model: Pre-loaded model (optional). If either ``model`` or
            ``tokenizer`` is None, all three objects come from ``get_model()``.
        tokenizer: Pre-loaded tokenizer (optional).
        checkpoint: Checkpoint metadata dict (optional). Its
            ``'classification_type'`` selects the label scheme (default
            ``'binary'``).

    Returns:
        dict: ``text`` (the original, unformatted input), ``prediction`` (label
        of the top class), ``confidence`` (its probability), ``probabilities``
        (label -> probability), ``is_fake`` (bool) and ``classification_type``.

    The number of classes is read from the model's output width, so a model
    passed without its checkpoint still gets a consistent label map. Inputs
    are moved to the device the model's parameters live on.
    """
    if model is None or tokenizer is None:
        model, tokenizer, checkpoint = get_model()
    device = _resolve_device(model)

    # Match the training format "title [SEP] text". A bare headline/claim is
    # used as both title and body.
    formatted_text = text if '[SEP]' in text else f"{text} [SEP] {text}"

    # Always pad to the full 512 tokens: EnhancedBertForSequenceClassification
    # passes attention_mask only to BERT, and its LSTM/attention/max-pool see
    # padded positions, so changing the padding would change its predictions.
    encoding = tokenizer(
        formatted_text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=1)[0]

    # The logits width is authoritative: the checkpoint may be absent or its
    # 'num_classes' may not describe the model actually passed in.
    num_classes = probabilities.size(0)
    classification_type = (
        checkpoint.get('classification_type', 'binary') if checkpoint else 'binary'
    )
    labels, is_fake_class = _labels_and_fake_rule(num_classes, classification_type)

    predicted_class = int(torch.argmax(probabilities).item())
    confidence = float(probabilities[predicted_class].item())
    prob_dict = {labels[i]: float(p) for i, p in enumerate(probabilities.tolist())}

    return {
        "text": text,  # original input, not the formatted version
        "prediction": labels[predicted_class],
        "confidence": confidence,
        "probabilities": prob_dict,
        "is_fake": is_fake_class(predicted_class),
        "classification_type": classification_type,
    }