rdsarjito committed on
Commit
16ad677
·
1 Parent(s): dab4200

[UPDATE]Model

Browse files
Files changed (1) hide show
  1. app.py +33 -8
app.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn as nn
7
  from PIL import Image
8
  import requests
9
  import easyocr
10
- from transformers import AutoTokenizer
11
  from torchvision import transforms
12
  from torchvision import models
13
  from torchvision.transforms import functional as F
@@ -57,6 +57,18 @@ print("OCR reader initialized.")
57
 
58
  # --- Model ---
59
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  class LateFusionModel(nn.Module):
61
  def __init__(self, image_model, text_model):
62
  super(LateFusionModel, self).__init__()
@@ -75,18 +87,31 @@ class LateFusionModel(nn.Module):
75
 
76
  return fused_logits, image_logits, text_logits, weights
77
 
78
- # Load model
79
- model_path = "models/best_fusion_model.pt"
 
 
 
 
 
 
 
 
 
 
 
80
  if os.path.exists(model_path):
81
- fusion_model = torch.load(model_path, map_location=device, weights_only=False)
 
82
  else:
83
- model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_fusion_model.pt")
84
- fusion_model = torch.load(model_path, map_location=device, weights_only=False)
 
 
85
 
86
- # fusion_model = unwrap_dataparallel(fusion_model)
87
  fusion_model.to(device)
88
  fusion_model.eval()
89
- print("Fusion model loaded successfully!")
90
 
91
  # Load Image-Only Model
92
  # Load image model from state_dict
 
7
  from PIL import Image
8
  import requests
9
  import easyocr
10
+ from transformers import AutoTokenizer, AutoModel
11
  from torchvision import transforms
12
  from torchvision import models
13
  from torchvision.transforms import functional as F
 
57
 
58
  # --- Model ---
59
 
60
class TextModelWithClassifier(nn.Module):
    """Wrap a text encoder with a single-logit linear classification head.

    The wrapped encoder is called with ``input_ids`` and ``attention_mask``
    keyword arguments and must expose ``config.hidden_size``. Its output must
    provide either a non-``None`` ``pooler_output`` or a ``last_hidden_state``.
    """

    def __init__(self, base_model):
        """Store the encoder and build a ``hidden_size -> 1`` linear head.

        Args:
            base_model: Encoder whose ``config.hidden_size`` gives the width
                of the vector fed to the classifier.
        """
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(base_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and classify its pooled representation.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            An object with a ``logits`` attribute holding the classifier
            output.
        """
        # Local import keeps this class self-contained without touching the
        # file-level import block.
        from types import SimpleNamespace

        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)

        # The attribute can exist but be None (e.g. an encoder built without a
        # pooling layer), so a bare hasattr() check is not enough. Fall back to
        # the first-token hidden state whenever no pooled output is available.
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        logits = self.classifier(pooled_output)
        return SimpleNamespace(logits=logits)
71
+
72
  class LateFusionModel(nn.Module):
73
  def __init__(self, image_model, text_model):
74
  super(LateFusionModel, self).__init__()
 
87
 
88
  return fused_logits, image_logits, text_logits, weights
89
 
90
# Load Fusion Model
# The checkpoint is a state_dict, so the architecture has to be built first
# and the saved parameters loaded into it afterwards.
image_model_for_fusion = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.DEFAULT)
num_features = image_model_for_fusion.classifier[1].in_features
# Swap the stock classifier head for a single-logit linear layer.
image_model_for_fusion.classifier = nn.Linear(num_features, 1)

text_base_model = AutoModel.from_pretrained('indobenchmark/indobert-base-p1')
text_model = TextModelWithClassifier(text_base_model)

fusion_model = LateFusionModel(image_model_for_fusion, text_model)

# Resolve the state_dict path: prefer the local copy, otherwise fetch it from
# the Hugging Face Hub. Loading happens once, after the path is known.
model_path = "models/best_mlp_fusion_model_state_dict.pt"
if os.path.exists(model_path):
    fusion_loaded_message = "Fusion model loaded from local state_dict successfully!"
else:
    print("Fusion model not found locally. Downloading from Hugging Face Hub...")
    model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_mlp_fusion_model_state_dict.pt")
    fusion_loaded_message = "Fusion model downloaded and loaded successfully!"

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint (the file may come from the network) cannot execute
# arbitrary code. NOTE(review): assumes the checkpoint holds only tensors, as
# its "state_dict" name suggests - confirm against the published file.
fusion_model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
print(fusion_loaded_message)

fusion_model.to(device)
fusion_model.eval()
print("Fusion model ready!")
115
 
116
  # Load Image-Only Model
117
  # Load image model from state_dict