Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +4 -4
inference.py
CHANGED
|
@@ -591,7 +591,7 @@ class GenerativeInferenceModel:
|
|
| 591 |
|
| 592 |
# Create a new ResNet model with pretrained weights
|
| 593 |
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 594 |
-
model = nn.Sequential(
|
| 595 |
loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
|
| 596 |
|
| 597 |
# Show missing keys by layer type
|
|
@@ -642,7 +642,7 @@ class GenerativeInferenceModel:
|
|
| 642 |
|
| 643 |
if norm_state_dict:
|
| 644 |
try:
|
| 645 |
-
|
| 646 |
print("Successfully loaded normalizer parameters")
|
| 647 |
except Exception as e:
|
| 648 |
print(f"Warning: Could not load normalizer parameters: {e}")
|
|
@@ -655,12 +655,12 @@ class GenerativeInferenceModel:
|
|
| 655 |
# Fallback to PyTorch's pretrained model
|
| 656 |
print("Falling back to PyTorch's pretrained model")
|
| 657 |
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 658 |
-
model = nn.Sequential(
|
| 659 |
else:
|
| 660 |
# Fallback to PyTorch's pretrained model
|
| 661 |
print("No checkpoint available, using PyTorch's pretrained model")
|
| 662 |
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 663 |
-
model = nn.Sequential(
|
| 664 |
|
| 665 |
model = model.to(device)
|
| 666 |
model.eval() # Set to evaluation mode
|
|
|
|
| 591 |
|
| 592 |
# Create a new ResNet model with pretrained weights
|
| 593 |
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 594 |
+
model = nn.Sequential(normalizer, resnet)
|
| 595 |
loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
|
| 596 |
|
| 597 |
# Show missing keys by layer type
|
|
|
|
| 642 |
|
| 643 |
if norm_state_dict:
|
| 644 |
try:
|
| 645 |
+
normalizer.load_state_dict(norm_state_dict, strict=False)
|
| 646 |
print("Successfully loaded normalizer parameters")
|
| 647 |
except Exception as e:
|
| 648 |
print(f"Warning: Could not load normalizer parameters: {e}")
|
|
|
|
| 655 |
# Fallback to PyTorch's pretrained model
|
| 656 |
print("Falling back to PyTorch's pretrained model")
|
| 657 |
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 658 |
+
model = nn.Sequential(normalizer, resnet)
|
| 659 |
else:
|
| 660 |
# Fallback to PyTorch's pretrained model
|
| 661 |
print("No checkpoint available, using PyTorch's pretrained model")
|
| 662 |
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 663 |
+
model = nn.Sequential(normalizer, resnet)
|
| 664 |
|
| 665 |
model = model.to(device)
|
| 666 |
model.eval() # Set to evaluation mode
|