| import torch
|
| import torch.nn as nn
|
| from transformers import BertTokenizer, BertModel
|
| from pathlib import Path
|
| from functools import lru_cache
|
|
|
class EnhancedBertForSequenceClassification(nn.Module):
    """BERT encoder followed by a BiLSTM, self-attention and an MLP head.

    Pipeline implemented by ``forward``:
    BERT last hidden state -> dropout -> 2-direction LSTM -> LayerNorm ->
    multi-head self-attention -> max-pool over the sequence axis -> classifier.

    Args:
        model_name: Name or path handed to ``BertModel.from_pretrained``.
        num_classes: Width of the final linear layer (number of logits).
        dropout: Dropout probability applied to the BERT output and inside
            the classifier head.
        lstm_hidden_size: Hidden size of each LSTM direction. The default
            (256) reproduces the original fixed architecture.
        lstm_layers: Number of stacked LSTM layers (default 2, as before).
        attention_heads: Number of attention heads (default 8, as before).
            ``2 * lstm_hidden_size`` must be divisible by this value.
    """

    def __init__(
        self,
        model_name='bert-base-uncased',
        num_classes=2,
        dropout=0.3,
        lstm_hidden_size=256,
        lstm_layers=2,
        attention_heads=8,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)

        # The LSTM is bidirectional, so its output concatenates both
        # directions; every layer after it works on 2 * lstm_hidden_size
        # features. Deriving the width here keeps the layers below in sync.
        feature_dim = 2 * lstm_hidden_size

        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            # Inter-layer dropout only exists between stacked layers; passing
            # a non-zero value with a single layer just triggers a warning.
            dropout=0.2 if lstm_layers > 1 else 0.0,
            bidirectional=True,
        )

        # Left in the default sequence-first layout; ``forward`` transposes
        # around the call.
        self.attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=attention_heads,
            dropout=0.1,
        )

        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, input_ids, attention_mask):
        """Compute classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids forwarded to BERT
                (assumed shape ``(batch, seq_len)`` -- TODO confirm with callers).
            attention_mask: Mask forwarded to BERT alongside ``input_ids``.

        Returns:
            Tensor of logits whose last dimension is ``num_classes``, one row
            per batch element.

        NOTE(review): ``attention_mask`` is only given to BERT. The LSTM, the
        self-attention layer and the max-pool below all run over every
        position, padded ones included. Adding masking here would change the
        outputs of already-trained weights, so it is intentionally left as is.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = self.dropout(encoded.last_hidden_state)

        recurrent_out, _ = self.lstm(hidden_states)
        recurrent_out = self.layer_norm(recurrent_out)

        # nn.MultiheadAttention was built sequence-first, so swap the batch
        # and sequence axes going in and swap them back coming out.
        seq_first = recurrent_out.transpose(0, 1)
        attended, _ = self.attention(seq_first, seq_first, seq_first)
        attended = attended.transpose(0, 1)

        # Max over the sequence axis; torch.max returns (values, indices).
        pooled_output = torch.max(attended, dim=1)[0]

        return self.classifier(pooled_output)
|
|
|
@lru_cache(maxsize=1)
def get_model():
    """
    Load the fine-tuned BERT model and tokenizer.

    The result is memoised with ``lru_cache``, so the files are read and the
    model is built only on the first call; later calls return the same objects.

    The model directory is ``enhanced_bert_welfake_model``, located three
    directory levels above this file. It must contain the tokenizer files and
    a ``model.pth`` checkpoint.

    Returns:
        tuple: (model, tokenizer, checkpoint) where ``model`` is in eval mode
        on CUDA when available (CPU otherwise) and ``checkpoint`` is the full
        dict loaded from ``model.pth`` (metadata and ``model_state_dict``).

    Raises:
        KeyError: If the checkpoint has no ``model_state_dict`` entry.
    """
    model_path = Path(__file__).parent.parent.parent / "enhanced_bert_welfake_model"

    tokenizer = BertTokenizer.from_pretrained(str(model_path))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(
        model_path / "model.pth",
        map_location=device,
    )

    # Architecture hyper-parameters stored next to the weights; the defaults
    # apply when the checkpoint does not record them.
    num_classes = checkpoint.get('num_classes', 2)
    model_config = checkpoint.get('config', {})
    dropout = model_config.get('dropout', 0.3)
    model_name = model_config.get('model_name', 'bert-base-uncased')

    model = EnhancedBertForSequenceClassification(
        model_name=model_name,
        num_classes=num_classes,
        dropout=dropout,
    )

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model, tokenizer, checkpoint
|
|
|
def _label_map(num_classes, classification_type):
    """Return the index -> label-name table for the given class layout."""
    if classification_type == 'binary' and num_classes == 2:
        return {
            0: "real",
            1: "fake",
        }
    if num_classes == 6:
        return {
            0: "pants-fire",
            1: "false",
            2: "barely-true",
            3: "half-true",
            4: "mostly-true",
            5: "true",
        }
    # Unknown layout: fall back to generic names.
    return {i: f"class_{i}" for i in range(num_classes)}


def _model_device(model):
    """Return the device of the model's parameters, else CUDA-if-available."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        # Not an nn.Module, or a module without parameters: keep the
        # original global choice.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def predict_fake_news(text: str, model=None, tokenizer=None, checkpoint=None):
    """
    Predict whether a news article is fake or real.

    Args:
        text: News article text (can be title only, or title [SEP] text format).
            Text without a ``[SEP]`` marker is duplicated on both sides of one.
        model: Pre-loaded model (optional).
        tokenizer: Pre-loaded tokenizer (optional).
        checkpoint: Model checkpoint with metadata (optional). Only
            ``num_classes`` and ``classification_type`` are read; without a
            checkpoint a two-class ``binary`` model is assumed.

    If either ``model`` or ``tokenizer`` is missing, all three objects are
    taken from ``get_model()``.

    Returns:
        dict: Prediction results with keys ``text``, ``prediction``,
        ``confidence``, ``probabilities``, ``is_fake`` and
        ``classification_type``.
    """
    if model is None or tokenizer is None:
        model, tokenizer, checkpoint = get_model()

    # Put the inputs on the model's own device. Choosing from global CUDA
    # availability breaks when a caller passes a CPU model on a CUDA host.
    device = _model_device(model)

    # Title-only input is expanded to the "title [SEP] text" form by repeating it.
    if '[SEP]' not in text:
        formatted_text = f"{text} [SEP] {text}"
    else:
        formatted_text = text

    num_classes = checkpoint.get('num_classes', 2) if checkpoint else 2
    classification_type = checkpoint.get('classification_type', 'binary') if checkpoint else 'binary'

    labels = _label_map(num_classes, classification_type)

    # NOTE(review): always padding to 512 tokens is kept on purpose. The model
    # in this file does not mask padding after BERT, so the padded length
    # presumably has to match what was used in training -- confirm before
    # changing.
    encoding = tokenizer(
        formatted_text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=1)
        # .item() requires exactly one row, i.e. a single input text.
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class].item()

    prob_dict = {labels[i]: float(probabilities[0][i].item()) for i in range(num_classes)}

    if classification_type == 'binary':
        is_fake = predicted_class == 1
    else:
        # In the six-label table above, indices 0-2 are "pants-fire",
        # "false" and "barely-true".
        is_fake = predicted_class < 3

    return {
        "text": text,
        "prediction": labels[predicted_class],
        "confidence": float(confidence),
        "probabilities": prob_dict,
        "is_fake": is_fake,
        "classification_type": classification_type,
    }
|
|
|