ttoosi commited on
Commit
3529ca1
·
verified ·
1 Parent(s): 45d4825

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +4 -4
inference.py CHANGED
@@ -591,7 +591,7 @@ class GenerativeInferenceModel:
591
 
592
  # Create a new ResNet model with pretrained weights
593
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
594
- model = nn.Sequential(self.normalizer, resnet)
595
  loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
596
 
597
  # Show missing keys by layer type
@@ -642,7 +642,7 @@ class GenerativeInferenceModel:
642
 
643
  if norm_state_dict:
644
  try:
645
- self.normalizer.load_state_dict(norm_state_dict, strict=False)
646
  print("Successfully loaded normalizer parameters")
647
  except Exception as e:
648
  print(f"Warning: Could not load normalizer parameters: {e}")
@@ -655,12 +655,12 @@ class GenerativeInferenceModel:
655
  # Fallback to PyTorch's pretrained model
656
  print("Falling back to PyTorch's pretrained model")
657
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
658
- model = nn.Sequential(self.normalizer, resnet)
659
  else:
660
  # Fallback to PyTorch's pretrained model
661
  print("No checkpoint available, using PyTorch's pretrained model")
662
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
663
- model = nn.Sequential(self.normalizer, resnet)
664
 
665
  model = model.to(device)
666
  model.eval() # Set to evaluation mode
 
591
 
592
  # Create a new ResNet model with pretrained weights
593
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
594
+ model = nn.Sequential(normalizer, resnet)
595
  loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
596
 
597
  # Show missing keys by layer type
 
642
 
643
  if norm_state_dict:
644
  try:
645
+ normalizer.load_state_dict(norm_state_dict, strict=False)
646
  print("Successfully loaded normalizer parameters")
647
  except Exception as e:
648
  print(f"Warning: Could not load normalizer parameters: {e}")
 
655
  # Fallback to PyTorch's pretrained model
656
  print("Falling back to PyTorch's pretrained model")
657
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
658
+ model = nn.Sequential(normalizer, resnet)
659
  else:
660
  # Fallback to PyTorch's pretrained model
661
  print("No checkpoint available, using PyTorch's pretrained model")
662
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
663
+ model = nn.Sequential(normalizer, resnet)
664
 
665
  model = model.to(device)
666
  model.eval() # Set to evaluation mode