Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -27,8 +27,8 @@ class MultiTaskModel(nn.Module):
|
|
| 27 |
########################################
|
| 28 |
# 2. Reconstruct the Model and Load Weights
|
| 29 |
########################################
|
| 30 |
-
#
|
| 31 |
-
num_obj_classes =
|
| 32 |
|
| 33 |
device = torch.device("cpu")
|
| 34 |
|
|
@@ -43,7 +43,7 @@ model.to(device)
|
|
| 43 |
|
| 44 |
# Download the state dict from HF Hub.
|
| 45 |
repo_id = "Abdu07/multitask-model" # Your repo name
|
| 46 |
-
filename = "Yolloplusclassproject_weights.pth" #
|
| 47 |
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 48 |
|
| 49 |
# Load the state dict and update the model.
|
|
@@ -55,12 +55,10 @@ model.eval()
|
|
| 55 |
# 3. Define Label Mappings and Transforms
|
| 56 |
########################################
|
| 57 |
# Update these with your actual label mappings.
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
# ... add the rest of your object classes ...
|
| 63 |
-
}
|
| 64 |
bin_label_names = ["AI-Generated", "Real"]
|
| 65 |
|
| 66 |
# Define the validation transforms (must match those used during training)
|
|
|
|
| 27 |
########################################
|
| 28 |
# 2. Reconstruct the Model and Load Weights
|
| 29 |
########################################
|
| 30 |
+
# IMPORTANT: Use the same number of object classes as in training.
|
| 31 |
+
num_obj_classes = 494 # Updated to match the state dict from training
|
| 32 |
|
| 33 |
device = torch.device("cpu")
|
| 34 |
|
|
|
|
| 43 |
|
| 44 |
# Download the state dict from HF Hub.
|
| 45 |
repo_id = "Abdu07/multitask-model" # Your repo name
|
| 46 |
+
filename = "Yolloplusclassproject_weights.pth" # The state dict file you uploaded
|
| 47 |
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 48 |
|
| 49 |
# Load the state dict and update the model.
|
|
|
|
| 55 |
# 3. Define Label Mappings and Transforms
|
| 56 |
########################################
|
| 57 |
# Update these with your actual label mappings.
|
| 58 |
+
# They should reflect the 494 unique pseudo-labels produced during training.
|
| 59 |
+
# For this example, we assume that the mapping is stored somewhere.
|
| 60 |
+
# Here we provide a dummy mapping for illustration. Replace it with your real mapping.
|
| 61 |
+
idx_to_obj_label = {i: f"label_{i}" for i in range(num_obj_classes)}
|
|
|
|
|
|
|
| 62 |
bin_label_names = ["AI-Generated", "Real"]
|
| 63 |
|
| 64 |
# Define the validation transforms (must match those used during training)
|