dthh commited on
Commit
dd9d3f9
·
verified ·
1 Parent(s): 335feed

Upload Models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. Models.py +457 -0
Models.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # import sys
3
+ from functools import partial
4
+ from pathlib import Path
5
+ import logging
6
+
7
+ import torch
8
+ from huggingface_hub import hf_hub_download
9
+ from torch import Tensor, nn
10
+ from torchvision import models, transforms
11
+ import pandas as pd
12
+ from collections import defaultdict
13
+
14
+
15
class ModelInterface:
    """
    Interface for managing image classification and regression tasks.

    Manages road-type, surface-type and surface-quality models: validates the
    configuration, downloads missing model files from Hugging Face, and runs
    batched predictions over raw images.
    """

    def __init__(self, config):
        """
        Initialize the ModelInterface.

        Parameters:
            config (dict): Configuration dictionary containing the following keys:
                - gpu_kernel (int): GPU index to use for computations. Defaults to the first available GPU if available, otherwise CPU.
                - transform_surface (dict): Parameters for surface type and quality image transformations, including resize, crop, and normalization settings.
                - transform_road_type (dict): Parameters for road type image transformations, similar to surface transformations.
                - model_root (str): Directory path where model files are stored locally. Defaults to folder name 'models'.
                - models (dict): Dictionary mapping prediction levels (e.g., 'road_type', 'surface_type') to model file names.
                - hf_model_repo (str): Hugging Face repository ID for downloading models if not found locally.
        """
        self.device = self._validate_device(config.get('gpu_kernel', ''))
        self.model_root = Path(config.get("model_root", "models"))
        self.models = config.get("models")
        self.hf_model_repo = config.get("hf_model_repo", "")
        self._validate_models()
        self._default_normalization = (NORM_MEAN, NORM_SD)
        self.transform_surface = self._validate_transform(config.get("transform_surface", None), "surface_type")
        self.transform_road_type = self._validate_transform(config.get("transform_road_type", None), "road_type")

    def _validate_device(self, gpu_kernel):
        """
        Resolve the torch device for the given GPU index.

        Parameters:
        - gpu_kernel: GPU index; an empty string selects the default CUDA device.

        Returns:
            torch.device: Selected CUDA device if available, otherwise CPU.
        """
        try:
            cuda = "cuda" if gpu_kernel == '' else f"cuda:{gpu_kernel}"
            return torch.device(
                cuda if torch.cuda.is_available() else "cpu"
            )
        except Exception as e:
            # Any malformed index or CUDA problem degrades gracefully to CPU.
            logging.warning(f"An unexpected error occurred while selecting GPU: {e}\n"
                            + "Falling back to CPU.")
            return torch.device("cpu")

    def _validate_models(self):
        """
        Check if model files exist and download from hugging face if not.

        Raises:
            TypeError: If no models are defined in the configuration.
        """
        if self.models is None:
            raise TypeError("No models are defined.")

        log_model_not_defined = "No model for '{level_string}' is defined. Prediction is skipped."

        # check surface type model
        level = "surface_type"
        model_file = self.models.get(level)
        # Empty mapping keeps the quality loop below a no-op when no surface
        # model is configured (the original code raised NameError here).
        surface_class_to_idx = {}
        if model_file is None:
            logging.warning(log_model_not_defined.format(level_string=model_to_info_string[level]))
        else:
            self.download_model(model_file)
            _, surface_class_to_idx, _ = self.load_model(model=model_file)

        # check quality models
        level = "surface_quality"
        sub_models = self.models.get(level)
        # BUG FIX: the original tested the stale `model_file` from the surface
        # check instead of `sub_models`, so a missing quality config could be
        # silently ignored (or crash) depending on the surface model.
        if sub_models is None:
            logging.warning(log_model_not_defined.format(level_string=model_to_info_string[level]))
        else:
            for surface_type in surface_class_to_idx:
                model_file = sub_models.get(surface_type)
                if model_file is None:
                    logging.warning(log_model_not_defined.format(level_string=surface_type))
                else:
                    self.download_model(model_file)
                    self.load_model(model=model_file)

        # check road type model
        level = "road_type"
        model_file = self.models.get(level)
        if model_file is None:
            logging.warning(log_model_not_defined.format(level_string=model_to_info_string[level]))
        else:
            self.download_model(model_file)
            self.load_model(model=model_file)

    def _validate_transform(self, transform, level):
        """
        Validate the transformation for a given model type if the model exists.

        Parameters:
        - transform (dict): transformation.
        - level (str): model level.

        Returns:
        dict: transformation with at least a "normalize" entry.
        """
        # BUG FIX: the original only replaced None when `level` was configured,
        # so an unconfigured level with transform=None crashed on the
        # membership test below.
        if transform is None:
            if level in self.models:
                logging.warning(f"No transformation for {model_to_info_string[level]} prediction defined.")
            transform = {}

        if "normalize" not in transform:
            logging.info(f"No normalization parameters for {model_to_info_string[level]} prediction provided. Using default values.")
            transform["normalize"] = self._default_normalization

        return transform

    def download_model(self, model):
        """
        Download a model from Hugging Face repository.

        Parameters:
        - model (str): Model file name.

        Returns:
        None

        Raises:
            Exception: Re-raises any error from the Hugging Face download.
        """
        model_path = self.model_root / model
        # load model data from hugging face if not locally available
        if not model_path.exists():
            logging.info(
                f"Model file not found at {model_path}. Downloading from Hugging Face..."
            )
            try:
                os.makedirs(self.model_root, exist_ok=True)
                model_path = hf_hub_download(
                    repo_id=self.hf_model_repo, filename=model, local_dir=self.model_root
                )
                logging.info(f"Model file downloaded successfully to {model_path}.")
            except Exception as e:
                logging.error(f"An unexpected error occurred while downloading the model {model}: {e}")
                raise

    @staticmethod
    def custom_crop(img, crop_style=None):
        """
        Crop an image according to the specified style.

        Parameters:
        - img (PIL.Image): Input image to be cropped.
        - crop_style (str, optional): Style of cropping (e.g., 'lower_middle_half').

        Returns:
        PIL.Image: Cropped image (unchanged if crop_style is unknown).
        """
        im_width, im_height = img.size
        if crop_style == CROP_LOWER_MIDDLE_HALF:
            # Middle half of the width, lower half of the height.
            top = im_height / 2
            left = im_width / 4
            height = im_height / 2
            width = im_width / 2
        elif crop_style == CROP_LOWER_HALF:
            top = im_height / 2
            left = 0
            height = im_height / 2
            width = im_width
        else:  # None, or not valid
            logging.warning(f"Cropping method {crop_style} is not defined. Image is not cropped.")
            return img

        cropped_img = transforms.functional.crop(img, top, left, height, width)
        return cropped_img

    def transform(
        self,
        resize=None,
        crop=None,
        to_tensor=True,
        normalize=None,
    ):
        """
        Create a PyTorch image transformation function based on specified parameters.

        Parameters:
        - resize ((int, int) or int, optional): Target size for resizing, e.g. (height, width). If int, then used for both height and width.
        - crop (str, optional): crop style e.g. 'lower_middle_third'
        - to_tensor (bool, optional): Converts the PIL Image (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        - normalize (tuple of lists [r, g, b], optional): Mean and standard deviation for normalization.

        Returns:
        PyTorch image transformation function.
        """
        transform_list = []

        if crop is not None:
            # partial binds the crop style so Lambda receives only the image.
            transform_list.append(
                transforms.Lambda(partial(self.custom_crop, crop_style=crop))
            )

        if resize is not None:
            if isinstance(resize, int):
                resize = (resize, resize)
            transform_list.append(transforms.Resize(resize))

        if to_tensor:
            transform_list.append(transforms.ToTensor())

        if normalize is not None:
            transform_list.append(transforms.Normalize(*normalize))

        composed_transform = transforms.Compose(transform_list)
        return composed_transform

    def preprocessing(self, img_data_raw, transform):
        """
        Preprocess raw image data using a specified transformation.

        Parameters:
        - img_data_raw (list): List of raw images to preprocess.
        - transform (dict): Dictionary of transformation parameters.

        Returns:
        torch.Tensor: Preprocessed image tensor (stacked along dim 0).

        Raises:
            ValueError: If the image list is empty.
        """
        if not img_data_raw:
            raise ValueError("Image data is empty.")

        transform = self.transform(**transform)
        img_data = torch.stack([transform(img) for img in img_data_raw])
        return img_data

    def load_model(self, model):
        """
        Load a model from local storage.

        Parameters:
        - model (str): Model file name.

        Returns:
        nn.Module: Loaded model.
        dict: Mapping of classes to indices.
        bool: Whether the model is for regression.
        """
        model_path = self.model_root / model
        try:
            # NOTE(review): torch.load unpickles the checkpoint — only load
            # model files from the trusted configured repository.
            model_state = torch.load(model_path, map_location=self.device)
            model_name = model_state["model_name"]
            is_regression = model_state["is_regression"]
            class_to_idx = model_state["class_to_idx"]
            # Regression heads always have a single output unit.
            num_classes = 1 if is_regression else len(class_to_idx)
            model_state_dict = model_state["model_state_dict"]
            model_cls = model_mapping[model_name]
            model = model_cls(num_classes=num_classes)
            model.load_state_dict(model_state_dict)
        except Exception as e:
            logging.error(f"An unexpected error occurred while loading the model {model_path}: {e}")
            raise

        return model, class_to_idx, is_regression

    def predict(self, model, data):
        """
        Perform predictions using the specified model and input data.

        Parameters:
        - model (nn.Module): The model to use for predictions.
        - data (torch.Tensor): Batch of input data.

        Returns:
        torch.Tensor: Predicted values or class probabilities.
        """
        model.to(self.device)
        model.eval()

        image_batch = data.to(self.device)

        with torch.no_grad():
            batch_outputs = model(image_batch)
            batch_values = model.get_class_probabilities(batch_outputs)

        return batch_values

    @staticmethod
    def predict_to_classes(batch_values, class_to_idx):
        """
        Map predicted values to classes.

        Parameters:
        - batch_values (torch.Tensor): Batch of prediction values. A 1-D tensor
          is treated as regression output; 2-D as per-class probabilities.
        - class_to_idx (dict): Mapping from class names to indices.

        Returns:
        list: List of predicted values.
        list: List of predicted classes.
        """
        idx_to_class = {i: cls for cls, i in class_to_idx.items()}

        if batch_values.dim() < 2:
            # Regression: round to the nearest class index and clamp into the
            # valid index range so out-of-range predictions still map to a class.
            lo = min(class_to_idx.values())
            hi = max(class_to_idx.values())
            classes = [
                idx_to_class[min(max(idx.item(), lo), hi)]
                for idx in batch_values.round().int()
            ]
        else:
            # Classification: pick the argmax class per sample.
            classes = [idx_to_class[idx.item()] for idx in torch.argmax(batch_values, dim=1)]

        values = batch_values.tolist()
        return values, classes

    def batch_classifications(self, img_data_raw, img_ids=None):
        """
        Perform batch classification for multiple prediction levels (road type, surface type, surface quality).

        Parameters:
        - img_data_raw (list): List of raw images to classify.
        - img_ids (list, optional): List of IDs corresponding to the images. Defaults to indices.

        Returns:
        list: One entry per image:
            [img_id, road_class, road_value, surface_class, surface_value,
             quality_class, quality_value]. Entries are None for any level
            whose model is not configured.
        """
        if not img_data_raw:
            logging.info("Input data is empty. No predictions performed.")
            return []

        n_images = len(img_data_raw)

        # default image ids
        if img_ids is None:
            img_ids = range(n_images)

        # BUG FIX: pre-fill every result slot so a missing model yields None
        # in the combined output instead of a NameError below.
        road_values, road_classes = [None] * n_images, [None] * n_images
        surface_values, surface_classes = [None] * n_images, [None] * n_images
        quality_values, quality_classes = [None] * n_images, [None] * n_images
        data = None  # surface-transformed batch, reused by the quality models

        # road type
        model_file = self.models.get("road_type")
        if model_file is not None:
            model, class_to_idx, _ = self.load_model(model=model_file)
            road_data = self.preprocessing(img_data_raw, self.transform_road_type)
            values = self.predict(model, road_data)
            road_values, road_classes = self.predict_to_classes(values, class_to_idx)

        # surface type
        model_file = self.models.get("surface_type")
        if model_file is not None:
            model, class_to_idx, _ = self.load_model(model=model_file)
            data = self.preprocessing(img_data_raw, self.transform_surface)
            values = self.predict(model, data)
            surface_values, surface_classes = self.predict_to_classes(values, class_to_idx)

        # surface quality: one sub-model per predicted surface type
        sub_models = self.models.get("surface_quality")
        if sub_models is not None and data is not None:
            # Group image indices by their predicted surface type so each
            # quality sub-model runs once on its own sub-batch.
            surface_indices = defaultdict(list)
            for i, surface_type in enumerate(surface_classes):
                surface_indices[surface_type].append(i)

            for surface_type, indices in surface_indices.items():
                model_file = sub_models.get(surface_type)
                if model_file is not None:
                    model, class_to_idx, _ = self.load_model(model=model_file)
                    values = self.predict(model, data[indices])
                    values, classes = self.predict_to_classes(values, class_to_idx)
                    for idx, vl, cls in zip(indices, values, classes):
                        quality_values[idx] = vl
                        quality_classes[idx] = cls

        # final results combination
        return [
            [
                img_ids[i],
                road_classes[i],
                road_values[i],
                surface_classes[i],
                surface_values[i],
                quality_classes[i],
                quality_values[i],
            ]
            for i in range(n_images)
        ]
387
+
388
+
389
class CustomEfficientNetV2SLinear(nn.Module):
    """
    EfficientNetV2-S backbone with a linear head for classification or regression.

    Attributes:
        features (nn.Sequential): Convolutional feature extractor from EfficientNetV2-S.
        avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling layer.
        classifier (nn.Sequential): Fully connected head producing the outputs.
        is_regression (bool): True when the model predicts a single continuous value.
        criterion (callable): Loss function used for training the model.
    """

    def __init__(self, num_classes, avg_pool=1):
        super().__init__()

        backbone = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
        # Swap the final linear layer so the head matches num_classes; pooling
        # to avg_pool x avg_pool scales the flattened feature count accordingly.
        head_in = backbone.classifier[-1].in_features * (avg_pool * avg_pool)
        backbone.classifier[-1] = nn.Linear(head_in, num_classes, bias=True)

        self.features = backbone.features
        self.avgpool = nn.AdaptiveAvgPool2d(avg_pool)
        self.classifier = backbone.classifier

        # A single output unit means regression (MSE); otherwise classification.
        self.is_regression = num_classes == 1
        self.criterion = nn.MSELoss if self.is_regression else nn.CrossEntropyLoss

    def get_class_probabilities(self, x):
        """Flatten regression outputs, or softmax logits into class probabilities."""
        if self.is_regression:
            return x.flatten()
        return nn.functional.softmax(x, dim=1)

    def forward(self, x: Tensor) -> Tensor:
        """Run feature extraction, pooling, flattening and the classifier head."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    # def get_optimizer_layers(self):
    #     return self.classifier
439
+
440
+
441
# Constants
# Key under which the EfficientNetV2-S linear model is registered.
EFFNET_LINEAR = "efficientNetV2SLinear"

# Supported crop styles (see ModelInterface.custom_crop).
CROP_LOWER_MIDDLE_HALF = "lower_middle_half"
CROP_LOWER_HALF = "lower_half"

# Default per-channel [R, G, B] normalization statistics.
NORM_MEAN = [0.42834484577178955, 0.4461250305175781, 0.4350937306880951]
NORM_SD = [0.22991590201854706, 0.23555299639701843, 0.26348039507865906]
447
+
448
+
449
# Registry resolving a checkpoint's stored model name to its implementing class.
model_mapping = {
    EFFNET_LINEAR: CustomEfficientNetV2SLinear,
}
452
+
453
# Human-readable labels for each prediction level, used in log messages.
model_to_info_string = {
    "surface_type": "surface type",
    "road_type": "road type",
    "surface_quality": "quality",
}