rdsarjito committed on
Commit
3e9fd4e
·
1 Parent(s): c48b7e8
Files changed (1) hide show
  1. app.py +36 -12
app.py CHANGED
@@ -74,16 +74,27 @@ class LateFusionModel(nn.Module):
74
  super(LateFusionModel, self).__init__()
75
  self.image_model = image_model
76
  self.text_model = text_model
77
- self.image_weight = nn.Parameter(torch.tensor(0.5))
78
- self.text_weight = nn.Parameter(torch.tensor(0.5))
 
 
 
 
 
 
 
79
 
80
  def forward(self, images, input_ids, attention_mask):
81
  with torch.no_grad():
82
  image_logits = self.image_model(images).squeeze(1)
83
  text_logits = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits.squeeze(1)
84
 
85
- weights = torch.softmax(torch.stack([self.image_weight, self.text_weight]), dim=0)
86
- fused_logits = weights[0] * image_logits + weights[1] * text_logits
 
 
 
 
87
 
88
  return fused_logits, image_logits, text_logits, weights
89
 
@@ -91,7 +102,11 @@ class LateFusionModel(nn.Module):
91
  # Create model architecture first
92
  image_model_for_fusion = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.DEFAULT)
93
  num_features = image_model_for_fusion.classifier[1].in_features
94
- image_model_for_fusion.classifier = nn.Linear(num_features, 1)
 
 
 
 
95
 
96
  text_base_model = AutoModel.from_pretrained('indobenchmark/indobert-base-p1')
97
  text_model = TextModelWithClassifier(text_base_model)
@@ -404,14 +419,23 @@ def predict_single_url(url):
404
  }
405
 
406
  confidence = gambling_prob if is_gambling else non_gambling_prob
407
- image_weight = weights[0].item()
408
- text_weight = weights[1].item()
 
 
 
 
 
 
 
 
 
409
 
410
  confidence_md = f"**Confidence:** {confidence:.1%}\n\n**Model Used:** Fusion Model (Image + Text)\n\n**Prediction:** {'🟥 Gambling' if is_gambling else '🟩 Non-Gambling'}"
411
 
412
- model_info = f"""**Model Type:** Fusion Model
413
- **Image Model:** EfficientNet-B3 (Weight: {image_weight:.1%})
414
- **Text Model:** IndoBERT (Weight: {text_weight:.1%})
415
 
416
  **Individual Predictions:**
417
  - ๐Ÿ–ผ๏ธ Image Model: {image_probs[0].item():.1%}
@@ -610,9 +634,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="Gambling Website D
610
  with gr.Column():
611
  file_input = gr.File(
612
  label="Upload URL File (.txt)",
613
- file_types=[".txt"],
614
- info="Upload a .txt file with one URL per line"
615
  )
 
616
  batch_predict_button = gr.Button(
617
  "🚀 Process Batch",
618
  variant="primary",
 
74
  super(LateFusionModel, self).__init__()
75
  self.image_model = image_model
76
  self.text_model = text_model
77
+ # MLP fusion layer (matching saved model structure)
78
+ # Structure: Linear(2, hidden) -> ReLU -> Dropout -> Linear(hidden, 1)
79
+ hidden_dim = 64 # Adjust if needed based on saved model
80
+ self.fusion_mlp = nn.Sequential(
81
+ nn.Linear(2, hidden_dim), # layer 0
82
+ nn.ReLU(), # layer 1 (no params)
83
+ nn.Dropout(0.1), # layer 2 (no params)
84
+ nn.Linear(hidden_dim, 1) # layer 3
85
+ )
86
 
87
  def forward(self, images, input_ids, attention_mask):
88
  with torch.no_grad():
89
  image_logits = self.image_model(images).squeeze(1)
90
  text_logits = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits.squeeze(1)
91
 
92
+ # Stack logits and pass through MLP
93
+ stacked_logits = torch.stack([image_logits, text_logits], dim=1)
94
+ fused_logits = self.fusion_mlp(stacked_logits).squeeze(1)
95
+
96
+ # For compatibility, create dummy weights
97
+ weights = torch.tensor([0.5, 0.5], device=fused_logits.device)
98
 
99
  return fused_logits, image_logits, text_logits, weights
100
 
 
102
  # Create model architecture first
103
  image_model_for_fusion = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.DEFAULT)
104
  num_features = image_model_for_fusion.classifier[1].in_features
105
+ # Match saved model structure: classifier.1 instead of classifier
106
+ image_model_for_fusion.classifier = nn.Sequential(
107
+ nn.Dropout(p=0.3, inplace=True),
108
+ nn.Linear(num_features, 1)
109
+ )
110
 
111
  text_base_model = AutoModel.from_pretrained('indobenchmark/indobert-base-p1')
112
  text_model = TextModelWithClassifier(text_base_model)
 
419
  }
420
 
421
  confidence = gambling_prob if is_gambling else non_gambling_prob
422
+
423
+ # Calculate relative contribution (approximation for MLP fusion)
424
+ image_contrib = abs(image_probs[0].item() - 0.5)
425
+ text_contrib = abs(text_probs[0].item() - 0.5)
426
+ total_contrib = image_contrib + text_contrib
427
+ if total_contrib > 0:
428
+ image_weight = image_contrib / total_contrib
429
+ text_weight = text_contrib / total_contrib
430
+ else:
431
+ image_weight = 0.5
432
+ text_weight = 0.5
433
 
434
  confidence_md = f"**Confidence:** {confidence:.1%}\n\n**Model Used:** Fusion Model (Image + Text)\n\n**Prediction:** {'🟥 Gambling' if is_gambling else '🟩 Non-Gambling'}"
435
 
436
+ model_info = f"""**Model Type:** Fusion Model (MLP)
437
+ **Image Model:** EfficientNet-B3
438
+ **Text Model:** IndoBERT
439
 
440
  **Individual Predictions:**
441
  - ๐Ÿ–ผ๏ธ Image Model: {image_probs[0].item():.1%}
 
634
  with gr.Column():
635
  file_input = gr.File(
636
  label="Upload URL File (.txt)",
637
+ file_types=[".txt"]
 
638
  )
639
+ gr.Markdown("💡 **Tip:** Upload a .txt file with one URL per line")
640
  batch_predict_button = gr.Button(
641
  "🚀 Process Batch",
642
  variant="primary",