rdsarjito commited on
Commit
83d56a4
·
1 Parent(s): 16ad677

[UPDATE]Model

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -60,11 +60,11 @@ print("OCR reader initialized.")
60
  class TextModelWithClassifier(nn.Module):
61
  def __init__(self, base_model):
62
  super(TextModelWithClassifier, self).__init__()
63
- self.base_model = base_model
64
  self.classifier = nn.Linear(base_model.config.hidden_size, 1)
65
 
66
  def forward(self, input_ids, attention_mask):
67
- outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
68
  pooled_output = outputs.pooler_output if hasattr(outputs, 'pooler_output') else outputs.last_hidden_state[:, 0]
69
  logits = self.classifier(pooled_output)
70
  return type('Output', (), {'logits': logits})()
@@ -101,13 +101,27 @@ fusion_model = LateFusionModel(image_model_for_fusion, text_model)
101
  # Load state_dict
102
  model_path = "models/best_mlp_fusion_model_state_dict.pt"
103
  if os.path.exists(model_path):
104
- fusion_model.load_state_dict(torch.load(model_path, map_location=device))
105
- print("Fusion model loaded from local state_dict successfully!")
 
 
 
 
 
 
 
106
  else:
107
  print("Fusion model not found locally. Downloading from Hugging Face Hub...")
108
  model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_mlp_fusion_model_state_dict.pt")
109
- fusion_model.load_state_dict(torch.load(model_path, map_location=device))
110
- print("Fusion model downloaded and loaded successfully!")
 
 
 
 
 
 
 
111
 
112
  fusion_model.to(device)
113
  fusion_model.eval()
 
60
  class TextModelWithClassifier(nn.Module):
61
  def __init__(self, base_model):
62
  super(TextModelWithClassifier, self).__init__()
63
+ self.bert = base_model # Use 'bert' to match saved state_dict keys
64
  self.classifier = nn.Linear(base_model.config.hidden_size, 1)
65
 
66
  def forward(self, input_ids, attention_mask):
67
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
68
  pooled_output = outputs.pooler_output if hasattr(outputs, 'pooler_output') else outputs.last_hidden_state[:, 0]
69
  logits = self.classifier(pooled_output)
70
  return type('Output', (), {'logits': logits})()
 
101
  # Load state_dict
102
  model_path = "models/best_mlp_fusion_model_state_dict.pt"
103
  if os.path.exists(model_path):
104
+ state_dict = torch.load(model_path, map_location=device)
105
+ try:
106
+ fusion_model.load_state_dict(state_dict, strict=True)
107
+ print("Fusion model loaded from local state_dict successfully!")
108
+ except RuntimeError as e:
109
+ print(f"Warning: Some keys didn't match. Trying with strict=False...")
110
+ print(f"Error details: {str(e)[:500]}")
111
+ fusion_model.load_state_dict(state_dict, strict=False)
112
+ print("Fusion model loaded with strict=False (some keys may be missing)")
113
  else:
114
  print("Fusion model not found locally. Downloading from Hugging Face Hub...")
115
  model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_mlp_fusion_model_state_dict.pt")
116
+ state_dict = torch.load(model_path, map_location=device)
117
+ try:
118
+ fusion_model.load_state_dict(state_dict, strict=True)
119
+ print("Fusion model downloaded and loaded successfully!")
120
+ except RuntimeError as e:
121
+ print(f"Warning: Some keys didn't match. Trying with strict=False...")
122
+ print(f"Error details: {str(e)[:500]}")
123
+ fusion_model.load_state_dict(state_dict, strict=False)
124
+ print("Fusion model loaded with strict=False (some keys may be missing)")
125
 
126
  fusion_model.to(device)
127
  fusion_model.eval()