ttoosi commited on
Commit
45d4825
·
verified ·
1 Parent(s): 15e838d

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +18 -10
inference.py CHANGED
@@ -29,6 +29,14 @@ MODEL_URLS = {
29
  'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
30
  }
31
 
 
 
 
 
 
 
 
 
32
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
33
  IMAGENET_STD = [0.229, 0.224, 0.225]
34
 
@@ -282,7 +290,8 @@ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50
282
  class GenerativeInferenceModel:
283
  def __init__(self):
284
  self.models = {}
285
- self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
 
286
  self.labels = get_imagenet_labels()
287
 
288
  def verify_model_integrity(self, model, model_type):
@@ -337,21 +346,20 @@ class GenerativeInferenceModel:
337
  return True
338
 
339
  def load_model(self, model_type):
340
- """Load model from checkpoint or use pretrained model."""
341
  if model_type in self.models:
342
  print(f"Using cached {model_type} model")
343
  return self.models[model_type]
344
-
345
- # Record loading time for performance analysis
346
  start_time = time.time()
347
  model_path = download_model(model_type)
348
-
349
- # Create a sequential model with normalizer and ResNet50
 
 
 
 
350
  resnet = models.resnet50()
351
- model = nn.Sequential(
352
- self.normalizer, # Normalizer is part of the model sequence
353
- resnet
354
- )
355
 
356
  # Load the model checkpoint
357
  if model_path:
 
29
  'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
30
  }
31
 
32
+ # Per-model input size and normalization
33
+ MODEL_PREPROC = {
34
+ "resnet50_robust": {"size": 224, "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
35
+ "resnet50_standard": {"size": 224, "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
36
+ # Typical for face models trained ArcFace/InsightFace-style
37
+ "resnet50_robust_face": {"size": 112, "mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
38
+ }
39
+
40
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
41
  IMAGENET_STD = [0.229, 0.224, 0.225]
42
 
 
290
  class GenerativeInferenceModel:
291
  def __init__(self):
292
  self.models = {}
293
+ #self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
294
+ self.model_preproc = {}
295
  self.labels = get_imagenet_labels()
296
 
297
  def verify_model_integrity(self, model, model_type):
 
346
  return True
347
 
348
  def load_model(self, model_type):
 
349
  if model_type in self.models:
350
  print(f"Using cached {model_type} model")
351
  return self.models[model_type]
352
+
 
353
  start_time = time.time()
354
  model_path = download_model(model_type)
355
+
356
+ # pick preproc for this model
357
+ pre = MODEL_PREPROC.get(model_type, {"size": 224, "mean": IMAGENET_MEAN, "std": IMAGENET_STD})
358
+ normalizer = NormalizeByChannelMeanStd(pre["mean"], pre["std"]).to(device)
359
+ self.model_preproc[model_type] = pre
360
+
361
  resnet = models.resnet50()
362
+ model = nn.Sequential(normalizer, resnet)
 
 
 
363
 
364
  # Load the model checkpoint
365
  if model_path: