Spaces:
Sleeping
Sleeping
rdsarjito committed on
Commit ·
16ad677
1
Parent(s): dab4200
[UPDATE]Model
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import torch.nn as nn
|
|
| 7 |
from PIL import Image
|
| 8 |
import requests
|
| 9 |
import easyocr
|
| 10 |
-
from transformers import AutoTokenizer
|
| 11 |
from torchvision import transforms
|
| 12 |
from torchvision import models
|
| 13 |
from torchvision.transforms import functional as F
|
|
@@ -57,6 +57,18 @@ print("OCR reader initialized.")
|
|
| 57 |
|
| 58 |
# --- Model ---
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
class LateFusionModel(nn.Module):
|
| 61 |
def __init__(self, image_model, text_model):
|
| 62 |
super(LateFusionModel, self).__init__()
|
|
@@ -75,18 +87,31 @@ class LateFusionModel(nn.Module):
|
|
| 75 |
|
| 76 |
return fused_logits, image_logits, text_logits, weights
|
| 77 |
|
| 78 |
-
# Load
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
if os.path.exists(model_path):
|
| 81 |
-
fusion_model
|
|
|
|
| 82 |
else:
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
# fusion_model = unwrap_dataparallel(fusion_model)
|
| 87 |
fusion_model.to(device)
|
| 88 |
fusion_model.eval()
|
| 89 |
-
print("Fusion model
|
| 90 |
|
| 91 |
# Load Image-Only Model
|
| 92 |
# Load image model from state_dict
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
import requests
|
| 9 |
import easyocr
|
| 10 |
+
from transformers import AutoTokenizer, AutoModel
|
| 11 |
from torchvision import transforms
|
| 12 |
from torchvision import models
|
| 13 |
from torchvision.transforms import functional as F
|
|
|
|
| 57 |
|
| 58 |
# --- Model ---
|
| 59 |
|
| 60 |
+
class TextModelWithClassifier(nn.Module):
    """Wrap a text encoder with a single-logit linear classification head.

    Args:
        base_model: Encoder invoked as
            ``base_model(input_ids=..., attention_mask=...)``. It must expose
            ``config.hidden_size`` and return an object carrying
            ``last_hidden_state`` and, optionally, ``pooler_output``.
    """

    def __init__(self, base_model):
        super(TextModelWithClassifier, self).__init__()
        self.base_model = base_model
        # One output unit: the head emits a single logit per input row.
        self.classifier = nn.Linear(base_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and classify its pooled representation.

        Args:
            input_ids: Passed through to the encoder unchanged.
            attention_mask: Passed through to the encoder unchanged.

        Returns:
            A lightweight object whose ``logits`` attribute holds the
            classifier output.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)

        # ``pooler_output`` can be present as an attribute yet set to None
        # (e.g. an encoder built without a pooling layer). A plain hasattr()
        # check would then feed None into the linear layer and crash, so the
        # fallback is used whenever the pooled output is missing *or* None.
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            # First position of the sequence -- presumably the [CLS] token.
            pooled_output = outputs.last_hidden_state[:, 0]

        logits = self.classifier(pooled_output)
        # Ad-hoc object exposing ``.logits`` so callers can treat the result
        # like a model output that carries a ``logits`` attribute.
        return type('Output', (), {'logits': logits})()
|
| 71 |
+
|
| 72 |
class LateFusionModel(nn.Module):
|
| 73 |
def __init__(self, image_model, text_model):
|
| 74 |
super(LateFusionModel, self).__init__()
|
|
|
|
| 87 |
|
| 88 |
return fused_logits, image_logits, text_logits, weights
|
| 89 |
|
| 90 |
+
# Load Fusion Model
# The checkpoint is a bare state_dict, so the architecture has to be built
# first and the saved tensors loaded into it afterwards.

# weights=None: load_state_dict() below is called in its default strict mode,
# which replaces every parameter and buffer of the model. Fetching the
# ImageNet weights here would only add a download that is thrown away.
image_model_for_fusion = models.efficientnet_b3(weights=None)
num_features = image_model_for_fusion.classifier[1].in_features
# Swap the stock classifier for a single-logit linear head.
image_model_for_fusion.classifier = nn.Linear(num_features, 1)

text_base_model = AutoModel.from_pretrained('indobenchmark/indobert-base-p1')
text_model = TextModelWithClassifier(text_base_model)

fusion_model = LateFusionModel(image_model_for_fusion, text_model)

# Load state_dict: prefer the local copy, otherwise fetch it from the Hub.
model_path = "models/best_mlp_fusion_model_state_dict.pt"
loaded_from_local = os.path.exists(model_path)
if not loaded_from_local:
    print("Fusion model not found locally. Downloading from Hugging Face Hub...")
    model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_mlp_fusion_model_state_dict.pt")

# NOTE(review): torch.load unpickles the file, and in the download branch the
# file comes from a remote repository. Consider passing weights_only=True once
# the minimum supported torch version is confirmed to accept it.
fusion_model.load_state_dict(torch.load(model_path, map_location=device))
if loaded_from_local:
    print("Fusion model loaded from local state_dict successfully!")
else:
    print("Fusion model downloaded and loaded successfully!")

fusion_model.to(device)
# Inference only: switch the module to evaluation mode.
fusion_model.eval()
print("Fusion model ready!")
|
| 115 |
|
| 116 |
# Load Image-Only Model
|
| 117 |
# Load image model from state_dict
|