Spaces:
Sleeping
Sleeping
rdsarjito
commited on
Commit
·
83d56a4
1
Parent(s):
16ad677
[UPDATE]Model
Browse files
app.py
CHANGED
|
@@ -60,11 +60,11 @@ print("OCR reader initialized.")
|
|
| 60 |
class TextModelWithClassifier(nn.Module):
|
| 61 |
def __init__(self, base_model):
|
| 62 |
super(TextModelWithClassifier, self).__init__()
|
| 63 |
-
self.
|
| 64 |
self.classifier = nn.Linear(base_model.config.hidden_size, 1)
|
| 65 |
|
| 66 |
def forward(self, input_ids, attention_mask):
|
| 67 |
-
outputs = self.
|
| 68 |
pooled_output = outputs.pooler_output if hasattr(outputs, 'pooler_output') else outputs.last_hidden_state[:, 0]
|
| 69 |
logits = self.classifier(pooled_output)
|
| 70 |
return type('Output', (), {'logits': logits})()
|
|
@@ -101,13 +101,27 @@ fusion_model = LateFusionModel(image_model_for_fusion, text_model)
|
|
| 101 |
# Load state_dict
|
| 102 |
model_path = "models/best_mlp_fusion_model_state_dict.pt"
|
| 103 |
if os.path.exists(model_path):
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
else:
|
| 107 |
print("Fusion model not found locally. Downloading from Hugging Face Hub...")
|
| 108 |
model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_mlp_fusion_model_state_dict.pt")
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
fusion_model.to(device)
|
| 113 |
fusion_model.eval()
|
|
|
|
| 60 |
class TextModelWithClassifier(nn.Module):
|
| 61 |
def __init__(self, base_model):
|
| 62 |
super(TextModelWithClassifier, self).__init__()
|
| 63 |
+
self.bert = base_model # Use 'bert' to match saved state_dict keys
|
| 64 |
self.classifier = nn.Linear(base_model.config.hidden_size, 1)
|
| 65 |
|
| 66 |
def forward(self, input_ids, attention_mask):
|
| 67 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 68 |
pooled_output = outputs.pooler_output if hasattr(outputs, 'pooler_output') else outputs.last_hidden_state[:, 0]
|
| 69 |
logits = self.classifier(pooled_output)
|
| 70 |
return type('Output', (), {'logits': logits})()
|
|
|
|
| 101 |
# Load state_dict
|
| 102 |
model_path = "models/best_mlp_fusion_model_state_dict.pt"
|
| 103 |
if os.path.exists(model_path):
|
| 104 |
+
state_dict = torch.load(model_path, map_location=device)
|
| 105 |
+
try:
|
| 106 |
+
fusion_model.load_state_dict(state_dict, strict=True)
|
| 107 |
+
print("Fusion model loaded from local state_dict successfully!")
|
| 108 |
+
except RuntimeError as e:
|
| 109 |
+
print(f"Warning: Some keys didn't match. Trying with strict=False...")
|
| 110 |
+
print(f"Error details: {str(e)[:500]}")
|
| 111 |
+
fusion_model.load_state_dict(state_dict, strict=False)
|
| 112 |
+
print("Fusion model loaded with strict=False (some keys may be missing)")
|
| 113 |
else:
|
| 114 |
print("Fusion model not found locally. Downloading from Hugging Face Hub...")
|
| 115 |
model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_mlp_fusion_model_state_dict.pt")
|
| 116 |
+
state_dict = torch.load(model_path, map_location=device)
|
| 117 |
+
try:
|
| 118 |
+
fusion_model.load_state_dict(state_dict, strict=True)
|
| 119 |
+
print("Fusion model downloaded and loaded successfully!")
|
| 120 |
+
except RuntimeError as e:
|
| 121 |
+
print(f"Warning: Some keys didn't match. Trying with strict=False...")
|
| 122 |
+
print(f"Error details: {str(e)[:500]}")
|
| 123 |
+
fusion_model.load_state_dict(state_dict, strict=False)
|
| 124 |
+
print("Fusion model loaded with strict=False (some keys may be missing)")
|
| 125 |
|
| 126 |
fusion_model.to(device)
|
| 127 |
fusion_model.eval()
|