Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +18 -10
inference.py
CHANGED
|
@@ -29,6 +29,14 @@ MODEL_URLS = {
|
|
| 29 |
'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
|
| 30 |
}
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 33 |
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 34 |
|
|
@@ -282,7 +290,8 @@ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50
|
|
| 282 |
class GenerativeInferenceModel:
|
| 283 |
def __init__(self):
|
| 284 |
self.models = {}
|
| 285 |
-
self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
|
|
|
|
| 286 |
self.labels = get_imagenet_labels()
|
| 287 |
|
| 288 |
def verify_model_integrity(self, model, model_type):
|
|
@@ -337,21 +346,20 @@ class GenerativeInferenceModel:
|
|
| 337 |
return True
|
| 338 |
|
| 339 |
def load_model(self, model_type):
|
| 340 |
-
"""Load model from checkpoint or use pretrained model."""
|
| 341 |
if model_type in self.models:
|
| 342 |
print(f"Using cached {model_type} model")
|
| 343 |
return self.models[model_type]
|
| 344 |
-
|
| 345 |
-
# Record loading time for performance analysis
|
| 346 |
start_time = time.time()
|
| 347 |
model_path = download_model(model_type)
|
| 348 |
-
|
| 349 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
resnet = models.resnet50()
|
| 351 |
-
model = nn.Sequential(
|
| 352 |
-
self.normalizer, # Normalizer is part of the model sequence
|
| 353 |
-
resnet
|
| 354 |
-
)
|
| 355 |
|
| 356 |
# Load the model checkpoint
|
| 357 |
if model_path:
|
|
|
|
| 29 |
'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
|
| 30 |
}
|
| 31 |
|
| 32 |
+
# Per-model input size and normalization
|
| 33 |
+
MODEL_PREPROC = {
|
| 34 |
+
"resnet50_robust": {"size": 224, "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
|
| 35 |
+
"resnet50_standard": {"size": 224, "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
|
| 36 |
+
# Typical for face models trained ArcFace/InsightFace-style
|
| 37 |
+
"resnet50_robust_face": {"size": 112, "mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 41 |
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 42 |
|
|
|
|
| 290 |
class GenerativeInferenceModel:
|
| 291 |
def __init__(self):
|
| 292 |
self.models = {}
|
| 293 |
+
#self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
|
| 294 |
+
self.model_preproc = {}
|
| 295 |
self.labels = get_imagenet_labels()
|
| 296 |
|
| 297 |
def verify_model_integrity(self, model, model_type):
|
|
|
|
| 346 |
return True
|
| 347 |
|
| 348 |
def load_model(self, model_type):
|
|
|
|
| 349 |
if model_type in self.models:
|
| 350 |
print(f"Using cached {model_type} model")
|
| 351 |
return self.models[model_type]
|
| 352 |
+
|
|
|
|
| 353 |
start_time = time.time()
|
| 354 |
model_path = download_model(model_type)
|
| 355 |
+
|
| 356 |
+
# pick preproc for this model
|
| 357 |
+
pre = MODEL_PREPROC.get(model_type, {"size": 224, "mean": IMAGENET_MEAN, "std": IMAGENET_STD})
|
| 358 |
+
normalizer = NormalizeByChannelMeanStd(pre["mean"], pre["std"]).to(device)
|
| 359 |
+
self.model_preproc[model_type] = pre
|
| 360 |
+
|
| 361 |
resnet = models.resnet50()
|
| 362 |
+
model = nn.Sequential(normalizer, resnet)
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
# Load the model checkpoint
|
| 365 |
if model_path:
|