Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,33 +22,24 @@ else:
|
|
| 22 |
logging.error(f"Model file not found: {model_path}")
|
| 23 |
raise FileNotFoundError(f"Model file not found: {model_path}")
|
| 24 |
|
| 25 |
-
# Create
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
config.id2label = {str(i): label for i, label in enumerate(labels)}
|
| 29 |
-
config.label2id = {label: str(i) for i, label in enumerate(labels)}
|
| 30 |
-
logging.info(f"Custom config created with {len(labels)} labels")
|
| 31 |
|
| 32 |
-
# Load the model with
|
| 33 |
-
logging.info("Loading the model with custom
|
| 34 |
-
model = ViTForImageClassification(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
try:
|
| 37 |
# Load the state dict
|
| 38 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
| 39 |
-
|
| 40 |
-
# Check if the state dict keys match the model's keys
|
| 41 |
-
model_keys = set(model.state_dict().keys())
|
| 42 |
-
loaded_keys = set(state_dict.keys())
|
| 43 |
-
|
| 44 |
-
if model_keys != loaded_keys:
|
| 45 |
-
logging.warning("Mismatch in state dict keys. Attempting to adjust...")
|
| 46 |
-
# Adjust keys if necessary (e.g., remove 'module.' prefix if it exists)
|
| 47 |
-
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
| 48 |
-
model.load_state_dict(new_state_dict)
|
| 49 |
-
else:
|
| 50 |
-
model.load_state_dict(state_dict)
|
| 51 |
-
|
| 52 |
logging.info("Model loaded successfully")
|
| 53 |
except Exception as e:
|
| 54 |
logging.error(f"Error loading model: {str(e)}")
|
|
@@ -57,10 +48,11 @@ except Exception as e:
|
|
| 57 |
model.eval()
|
| 58 |
logging.info("Model set to evaluation mode")
|
| 59 |
|
| 60 |
-
# Load
|
| 61 |
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 62 |
logging.info("Feature extractor loaded")
|
| 63 |
|
|
|
|
| 64 |
logging.info("Model and feature extractor loaded successfully")
|
| 65 |
|
| 66 |
# Define the prediction function
|
|
|
|
| 22 |
logging.error(f"Model file not found: {model_path}")
|
| 23 |
raise FileNotFoundError(f"Model file not found: {model_path}")
|
| 24 |
|
| 25 |
+
# Create label mappings consistent with training
|
| 26 |
+
id2label = {str(i): label for i, label in enumerate(labels)}
|
| 27 |
+
label2id = {label: str(i) for i, label in enumerate(labels)}
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# Load the model with custom label mapping
|
| 30 |
+
logging.info("Loading the model with custom label mapping")
|
| 31 |
+
model = ViTForImageClassification.from_pretrained(
|
| 32 |
+
"google/vit-base-patch16-224-in21k",
|
| 33 |
+
num_labels=len(labels),
|
| 34 |
+
id2label=id2label,
|
| 35 |
+
label2id=label2id,
|
| 36 |
+
ignore_mismatched_sizes=True
|
| 37 |
+
)
|
| 38 |
|
| 39 |
try:
|
| 40 |
# Load the state dict
|
| 41 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
| 42 |
+
model.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
logging.info("Model loaded successfully")
|
| 44 |
except Exception as e:
|
| 45 |
logging.error(f"Error loading model: {str(e)}")
|
|
|
|
| 48 |
model.eval()
|
| 49 |
logging.info("Model set to evaluation mode")
|
| 50 |
|
| 51 |
+
# Load feature extractor
|
| 52 |
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 53 |
logging.info("Feature extractor loaded")
|
| 54 |
|
| 55 |
+
|
| 56 |
logging.info("Model and feature extractor loaded successfully")
|
| 57 |
|
| 58 |
# Define the prediction function
|