astrosbd committed · Commit 5c783e4 · 0 Parent(s)

Initial commit

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,51 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ env/
+ ENV/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+
+ # Model files
+ *.pth
+ *.pt
+ *.ckpt
+ *.bin
+
+ # Config files
+ *.yaml
+ *.yml
+ !configs/*.yaml
+ !configs/*.yml
README.md ADDED
@@ -0,0 +1,60 @@
+ ---
+ title: Car Damage Insurance Fraud Detector
+ emoji: 🚗
+ colorFrom: gray
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 3.50.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ # Car Damage Insurance Fraud Detector
+
+ An AI-powered system that detects car damage and potential insurance fraud using deep learning models.
+
+ ## Features
+
+ - Damage Detection: Identifies and localizes car damage using Detectron2
+ - Deepfake Detection: Analyzes images for potential manipulation
+ - User-friendly Interface: Built with Gradio for easy interaction
+ - Multi-device Support: Works on CPU, CUDA, and MPS (Apple Silicon)
+
+ ## Requirements
+
+ - Python 3.8+
+ - PyTorch
+ - OpenCV
+ - Gradio
+ - Detectron2 (optional, not available for macOS)
+
+ ## Installation
+
+ 1. Clone the repository
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Usage
+
+ 1. Run the application:
+ ```bash
+ python app.py
+ ```
+
+ 2. Open your browser and navigate to the provided local URL
+
+ ## Model Requirements
+
+ - Damage detection model (Detectron2 format)
+ - Deepfake detection model (custom format)
+
+ ## License
+
+ Apache 2.0
+
+ ## Note
+
+ This application requires pre-trained models for both damage detection and deepfake detection. Make sure to have the appropriate model files in the correct locations before running the application.
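Note that the installation step above references a requirements.txt that is not part of this initial commit. A minimal sketch, inferred from the imports in app.py and the helper modules (package names and the pinned gradio version are assumptions based on the README metadata, not a file from this commit):

```text
# Hypothetical requirements.txt - inferred from imports, not part of this commit
torch
torchvision
opencv-python
numpy
Pillow
gradio==3.50.0
scikit-learn
PyYAML
python-box
# detectron2 is optional and installed separately (no macOS wheels)
```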
app.py ADDED
@@ -0,0 +1,551 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ import os
+ import sys
+ import time
+ import cv2
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from torchvision import transforms
+
+ # Add current directory to path
+ if not os.getcwd() in sys.path:
+     sys.path.append(os.getcwd())
+
+ # Detectron2 imports - wrapped in try-except to make them optional
+ try:
+     from detectron2.engine import DefaultPredictor
+     from detectron2.config import get_cfg
+     from detectron2.utils.visualizer import Visualizer, ColorMode
+     from detectron2 import model_zoo
+     DETECTRON2_AVAILABLE = True
+ except ImportError:
+     print("Warning: Detectron2 is not installed. Damage detection will not be available.")
+     DETECTRON2_AVAILABLE = False
+
+ # Custom model imports - also optional
+ try:
+     from configs.get_config import load_config
+     from models import *
+     MODELS_IMPORTED = True
+ except ImportError:
+     print("Warning: Custom models couldn't be imported. Only damage detection will work.")
+     MODELS_IMPORTED = False
+
+ def setup_device(device_str):
+     """Set up the computation device based on user input and availability"""
+     if device_str == 'auto':
+         if torch.cuda.is_available():
+             return torch.device('cuda:0')
+         elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+             return torch.device('mps')
+         else:
+             return torch.device('cpu')
+     elif device_str == 'cuda' and torch.cuda.is_available():
+         return torch.device('cuda:0')
+     elif device_str == 'mps' and hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+         return torch.device('mps')
+     else:
+         print(f"Warning: Device {device_str} not available, using CPU instead.")
+         return torch.device('cpu')
+
+ def setup_damage_detector(model_path, threshold=0.7):
+     """Set up the damage detection model using Detectron2"""
+     if not DETECTRON2_AVAILABLE:
+         print("Detectron2 is not installed. Cannot set up damage detector.")
+         return None, None
+
+     if model_path is None or not os.path.exists(model_path):
+         print("No damage model specified or file not found. Skipping damage detection.")
+         return None, None
+
+     cfg = get_cfg()
+     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+     cfg.MODEL.WEIGHTS = model_path
+     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # Only one class (damage)
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
+
+     # Explicitly set to use CPU if on Mac (MPS); guard the attribute for older torch
+     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+         cfg.MODEL.DEVICE = "cpu"
+         print("Mac MPS detected - forcing Detectron2 to use CPU")
+
+     try:
+         predictor = DefaultPredictor(cfg)
+         return predictor, cfg
+     except Exception as e:
+         print(f"Error setting up damage detector: {e}")
+         return None, cfg
+
+ def load_deepfake_model(model_path, cfg_path, device):
+     """Load the deepfake detection model"""
+     if not MODELS_IMPORTED:
+         print("Custom models module not imported. Cannot load deepfake model.")
+         return None, None
+
+     if model_path is None or not os.path.exists(model_path):
+         print("No deepfake model specified or file not found. Skipping deepfake detection.")
+         return None, None
+
+     if cfg_path is None or not os.path.exists(cfg_path):
+         print("No deepfake config specified or file not found. Skipping deepfake detection.")
+         return None, None
+
+     try:
+         # Load config
+         cfg = load_config(cfg_path)
+
+         # Build model
+         model = build_model(cfg.MODEL, MODELS)
+
+         # Load weights
+         print(f"Loading deepfake model from: {model_path}")
+         checkpoint = torch.load(model_path, map_location='cpu')
+
+         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+             model.load_state_dict(checkpoint['state_dict'])
+         else:
+             model.load_state_dict(checkpoint)
+
+         # Move model to device and set to evaluation mode
+         model = model.to(device)
+         if hasattr(cfg.MODEL, 'precision') and cfg.MODEL.precision == 'fp64':
+             model = model.to(torch.float64)
+         model.eval()
+
+         return model, cfg
+     except Exception as e:
+         print(f"Error loading deepfake model: {e}")
+         import traceback
+         traceback.print_exc()
+         return None, None
+
+ def preprocess_for_deepfake(image, cfg, device):
+     """Preprocess an image for deepfake detection"""
+     try:
+         # Convert to RGB if needed
+         if len(image.shape) == 3 and image.shape[2] == 3:
+             if image.dtype != np.uint8:
+                 image = (image * 255).astype(np.uint8)
+             rgb_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         else:
+             rgb_img = image
+
+         # Resize
+         img_resized = cv2.resize(rgb_img, (cfg.DATASET.IMAGE_SIZE[0], cfg.DATASET.IMAGE_SIZE[1]))
+
+         # Convert to PIL and apply transforms
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=cfg.DATASET.TRANSFORM.normalize.mean,
+                 std=cfg.DATASET.TRANSFORM.normalize.std
+             )
+         ])
+
+         img_tensor = transform(Image.fromarray(img_resized)).unsqueeze(0)  # Add batch dimension
+         img_tensor = img_tensor.to(device)
+
+         # Convert to correct precision
+         if hasattr(cfg.MODEL, 'precision') and cfg.MODEL.precision == 'fp64':
+             img_tensor = img_tensor.to(torch.float64)
+
+         return img_tensor
+     except Exception as e:
+         print(f"Error preprocessing image for deepfake detection: {e}")
+         import traceback
+         traceback.print_exc()
+         return None
+
+ def detect_damage(img, damage_detector):
+     """Detect damage in an image"""
+     try:
+         if img is None:
+             raise ValueError("Invalid image")
+
+         # If no damage detector available, return the whole image as region
+         if damage_detector is None:
+             print("No damage detector available. Using whole image as region.")
+             h, w = img.shape[:2]
+             damage_regions = [{
+                 "box": (0, 0, w, h),
+                 "score": 1.0,
+                 "mask": None
+             }]
+             return img, None, damage_regions
+
+         # Run inference
+         outputs = damage_detector(img)
+
+         # Get damage regions
+         instances = outputs["instances"].to("cpu")
+         boxes = instances.pred_boxes.tensor.numpy() if instances.has("pred_boxes") else []
+         scores = instances.scores.numpy() if instances.has("scores") else []
+         masks = instances.pred_masks.numpy() if instances.has("pred_masks") else []
+
+         damage_regions = []
+         for i in range(len(boxes)):
+             x1, y1, x2, y2 = map(int, boxes[i])
+             damage_regions.append({
+                 "box": (x1, y1, x2, y2),
+                 "score": float(scores[i]),
+                 "mask": masks[i] if len(masks) > i else None
+             })
+
+         if not damage_regions:
+             print("No damage detected. Using whole image.")
+             h, w = img.shape[:2]
+             damage_regions = [{
+                 "box": (0, 0, w, h),
+                 "score": 1.0,
+                 "mask": None
+             }]
+
+         return img, outputs, damage_regions
+     except Exception as e:
+         print(f"Error detecting damage: {e}")
+         # If error occurs, return the whole image as region
+         if 'img' in locals() and img is not None:
+             h, w = img.shape[:2]
+             damage_regions = [{
+                 "box": (0, 0, w, h),
+                 "score": 1.0,
+                 "mask": None
+             }]
+             return img, None, damage_regions
+         return None, None, []
+
+ def check_deepfake(image, damage_regions, deepfake_model, deepfake_cfg, device, threshold=0.5):
+     """Check if damage regions are deepfakes"""
+     results = []
+
+     if deepfake_model is None:
+         print("No deepfake model available. Skipping deepfake detection.")
+         return []
+
+     try:
+         # If no damage regions, check the entire image
+         if not damage_regions:
+             img_tensor = preprocess_for_deepfake(image, deepfake_cfg, device)
+             if img_tensor is None:
+                 return []
+
+             # Run inference
+             with torch.no_grad():
+                 outputs = deepfake_model(img_tensor)
+
+             # Extract outputs
+             if isinstance(outputs, list):
+                 outputs = outputs[0]
+
+             if isinstance(outputs, dict) and 'cls' in outputs:
+                 cls_outputs = outputs['cls']
+                 cls_prob = cls_outputs.sigmoid().cpu().numpy()
+             else:
+                 # Assuming the output is directly the classification probability
+                 cls_prob = outputs.sigmoid().cpu().numpy() if hasattr(outputs, 'sigmoid') else outputs.cpu().numpy()
+
+             if cls_prob.size > 0:
+                 is_fake = cls_prob[0][0] > threshold if cls_prob.ndim > 1 else cls_prob[0] > threshold
+                 confidence = cls_prob[0][0] if cls_prob.ndim > 1 else cls_prob[0]
+
+                 results.append({
+                     "region": "full_image",
+                     "deepfake_prob": float(confidence),
+                     "is_fake": bool(is_fake)
+                 })
+
+             return results
+
+         # Process each damage region
+         for i, region in enumerate(damage_regions):
+             x1, y1, x2, y2 = region["box"]
+             # Ensure coordinates are within image bounds
+             x1, y1 = max(0, x1), max(0, y1)
+             x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
+
+             # Extract region and check if it's a deepfake
+             if x2 > x1 and y2 > y1:
+                 # Get ROI
+                 roi = image[y1:y2, x1:x2]
+
+                 # Preprocess
+                 img_tensor = preprocess_for_deepfake(roi, deepfake_cfg, device)
+                 if img_tensor is None:
+                     continue
+
+                 # Run inference
+                 with torch.no_grad():
+                     outputs = deepfake_model(img_tensor)
+
+                 # Extract outputs
+                 if isinstance(outputs, list):
+                     outputs = outputs[0]
+
+                 if isinstance(outputs, dict) and 'cls' in outputs:
+                     cls_outputs = outputs['cls']
+                     cls_prob = cls_outputs.sigmoid().cpu().numpy()
+                 else:
+                     # Assuming the output is directly the classification probability
+                     cls_prob = outputs.sigmoid().cpu().numpy() if hasattr(outputs, 'sigmoid') else outputs.cpu().numpy()
+
+                 if cls_prob.size > 0:
+                     is_fake = cls_prob[0][0] > threshold if cls_prob.ndim > 1 else cls_prob[0] > threshold
+                     confidence = cls_prob[0][0] if cls_prob.ndim > 1 else cls_prob[0]
+
+                     results.append({
+                         "region_id": i,
+                         "box": (x1, y1, x2, y2),
+                         "deepfake_prob": float(confidence),
+                         "is_fake": bool(is_fake)
+                     })
+
+         return results
+     except Exception as e:
+         print(f"Error in deepfake detection: {e}")
+         import traceback
+         traceback.print_exc()
+         return []
+
+ def visualize_results(image, damage_outputs, deepfake_results, damage_threshold):
+     """Create visualization of damage detection and deepfake verification"""
+     try:
+         # Create a copy for visualization
+         img_copy = image.copy()
+
+         # Draw damage detection results
+         if damage_outputs is not None and DETECTRON2_AVAILABLE:
+             try:
+                 v = Visualizer(img_copy[:, :, ::-1], scale=1.0, instance_mode=ColorMode.IMAGE_BW)
+                 v = v.draw_instance_predictions(damage_outputs["instances"].to("cpu"))
+                 result_img = v.get_image()[:, :, ::-1]
+
+                 # Convert to a standard numpy array to ensure compatibility with OpenCV
+                 result_img = np.array(result_img, dtype=np.uint8)
+             except Exception as e:
+                 print(f"Error visualizing damage detection: {e}")
+                 result_img = img_copy
+         else:
+             result_img = img_copy
+
+         # Add deepfake detection results
+         for result in deepfake_results:
+             try:
+                 if "box" in result:
+                     x1, y1, x2, y2 = result["box"]
+                     fake_prob = result["deepfake_prob"]
+                     is_fake = result["is_fake"]
+                     region_id = result.get("region_id", 0)
+
+                     # Text for the region
+                     text = f"R{region_id}: {'FAKE' if is_fake else 'REAL'} ({fake_prob*100:.1f}%)"
+
+                     # Different colors for fake/real
+                     color = (0, 0, 255) if is_fake else (0, 255, 0)  # Red for fake, green for real
+
+                     # Ensure we have a standard numpy array
+                     if not isinstance(result_img, np.ndarray):
+                         result_img = np.array(result_img, dtype=np.uint8)
+
+                     # Draw rectangle and text
+                     cv2.rectangle(result_img, (x1, y1), (x2, y2), color, 2)
+                     cv2.putText(result_img, text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
+                 elif "region" in result and result["region"] == "full_image":
+                     fake_prob = result["deepfake_prob"]
+                     is_fake = result["is_fake"]
+
+                     # Text for the whole image
+                     text = f"Image: {'FAKE' if is_fake else 'REAL'} ({fake_prob*100:.1f}%)"
+
+                     # Different colors for fake/real
+                     color = (0, 0, 255) if is_fake else (0, 255, 0)  # Red for fake, green for real
+
+                     # Ensure we have a standard numpy array
+                     if not isinstance(result_img, np.ndarray):
+                         result_img = np.array(result_img, dtype=np.uint8)
+
+                     # Draw text
+                     cv2.putText(result_img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
+             except Exception as e:
+                 print(f"Error drawing result {result}: {e}")
+
+         return result_img
+     except Exception as e:
+         print(f"Error visualizing results: {e}")
+         import traceback
+         traceback.print_exc()
+         return np.array(image, dtype=np.uint8)  # Return the original image as a numpy array
+
+ def process_image(input_image, damage_model_path, deepfake_model_path, deepfake_cfg_path,
+                   damage_threshold, deepfake_threshold, skip_damage, device_str):
+     """Process an image through the car damage and deepfake detection pipeline"""
+     progress_info = []
+
+     # Convert Gradio image to numpy array
+     if isinstance(input_image, dict) and "path" in input_image:
+         img = cv2.imread(input_image["path"])
+     elif isinstance(input_image, str):
+         img = cv2.imread(input_image)
+     elif isinstance(input_image, np.ndarray):
+         # Make a copy to avoid modifying the original
+         img = input_image.copy()
+         # Convert from RGB to BGR (OpenCV format)
+         if len(img.shape) == 3 and img.shape[2] == 3:
+             img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+     else:
+         return None, "Error: Unsupported image format"
+
+     if img is None:
+         return None, "Error: Could not read the image"
+
+     # Progress update
+     progress_info.append("Image loaded successfully")
+
+     # Setup device
+     device = setup_device(device_str)
+     progress_info.append(f"Using device: {device}")
+
+     # Initialize models
+     damage_detector = None
+     deepfake_model = None
+     deepfake_cfg = None
+
+     # Setup damage detector if not skipped
+     if not skip_damage and damage_model_path:
+         progress_info.append("Setting up damage detector...")
+         damage_detector, detector_cfg = setup_damage_detector(damage_model_path, float(damage_threshold))
+         if damage_detector is None:
+             progress_info.append("Failed to initialize damage detector")
+         else:
+             progress_info.append("Damage detector initialized successfully")
+
+     # Setup deepfake detector
+     if deepfake_model_path and deepfake_cfg_path:
+         progress_info.append("Setting up deepfake detector...")
+         deepfake_model, deepfake_cfg = load_deepfake_model(deepfake_model_path, deepfake_cfg_path, device)
+         if deepfake_model is None:
+             progress_info.append("Failed to initialize deepfake detector")
+         else:
+             progress_info.append("Deepfake detector initialized successfully")
+
+     # Ensure at least one detector is working
+     if damage_detector is None and deepfake_model is None:
+         return None, "Error: Neither damage nor deepfake detector is available"
+
+     # Step 1: Detect damage or use whole image
+     progress_info.append("Detecting damage regions...")
+     start_time = time.time()
+     img, damage_outputs, damage_regions = detect_damage(img, damage_detector)
+     damage_time = time.time() - start_time
+
+     if img is None:
+         return None, "Error: Failed to process image"
+
+     # Print damage detection results
+     if damage_detector is not None and damage_regions:
+         progress_info.append(f"Detected {len(damage_regions)} damage regions in {damage_time:.3f} seconds")
+     else:
+         progress_info.append("Using the whole image for analysis")
+
+     # Step 2: Check if damage is deepfake
+     deepfake_results = []
+     if deepfake_model is not None:
+         progress_info.append("Performing deepfake detection...")
+         start_time = time.time()
+         deepfake_results = check_deepfake(
+             img, damage_regions, deepfake_model, deepfake_cfg, device, float(deepfake_threshold)
+         )
+         deepfake_time = time.time() - start_time
+
+         if deepfake_results:
+             progress_info.append(f"Deepfake detection completed in {deepfake_time:.3f} seconds")
+
+             # Generate report
+             for result in deepfake_results:
+                 if "region_id" in result:
+                     region_id = result["region_id"]
+                     fake_prob = result["deepfake_prob"]
+                     is_fake = result["is_fake"]
+                     progress_info.append(f"Region {region_id}: {'FAKE' if is_fake else 'REAL'} (Probability: {fake_prob*100:.2f}%)")
+                 elif "region" in result and result["region"] == "full_image":
+                     fake_prob = result["deepfake_prob"]
+                     is_fake = result["is_fake"]
+                     progress_info.append(f"Whole image: {'FAKE' if is_fake else 'REAL'} (Probability: {fake_prob*100:.2f}%)")
+         else:
+             progress_info.append("No deepfake detection results")
+
+     # Step 3: Visualize final results
+     progress_info.append("Generating visualization...")
+     result_img = visualize_results(img, damage_outputs, deepfake_results, float(damage_threshold))
+
+     # Convert back to RGB for Gradio
+     if len(result_img.shape) == 3 and result_img.shape[2] == 3:
+         result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
+
+     progress_info.append("Processing complete!")
+
+     return result_img, "\n".join(progress_info)
+
+ def create_gradio_interface():
+     with gr.Blocks(title="Car Damage & Deepfake Detection") as app:
+         gr.Markdown("# Car Damage Detection & Deepfake Verification")
+         gr.Markdown("Upload an image to detect car damage and check if it's a deepfake")
+
+         with gr.Tab("Basic Interface"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     input_image = gr.Image(type="numpy", label="Input Image")
+
+                     # Simple controls
+                     skip_damage = gr.Checkbox(label="Skip Damage Detection", value=False)
+                     damage_threshold = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05,
+                                                  label="Damage Detection Threshold")
+                     deepfake_threshold = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05,
+                                                    label="Deepfake Detection Threshold")
+                     device = gr.Dropdown(choices=["auto", "cuda", "cpu", "mps"], value="auto",
+                                          label="Computation Device")
+
+                     process_btn = gr.Button("Process Image", variant="primary")
+
+                 with gr.Column(scale=1):
+                     output_image = gr.Image(type="numpy", label="Result")
+                     output_text = gr.Textbox(label="Detection Results", lines=10)
+
+         with gr.Tab("Advanced Settings"):
+             with gr.Row():
+                 with gr.Column():
+                     damage_model_path = gr.Textbox(label="Damage Model Path",
+                                                    placeholder="Path to damage detection model (.pth)")
+                     deepfake_model_path = gr.Textbox(label="Deepfake Model Path",
+                                                      placeholder="Path to deepfake detection model (.pth)")
+                     deepfake_cfg_path = gr.Textbox(label="Deepfake Config Path",
+                                                    placeholder="Path to deepfake model config (.yaml)")
+
+         # Connect the process function
+         process_btn.click(
+             fn=process_image,
+             inputs=[
+                 input_image,
+                 damage_model_path,
+                 deepfake_model_path,
+                 deepfake_cfg_path,
+                 damage_threshold,
+                 deepfake_threshold,
+                 skip_damage,
+                 device
+             ],
+             outputs=[output_image, output_text]
+         )
+
+         # Examples
+         gr.Markdown("## Examples")
+         gr.Markdown("Note: Examples will only work if you have the appropriate models installed.")
+
+     return app
+
+ if __name__ == "__main__":
+     # Create and launch the Gradio interface
+     app = create_gradio_interface()
+     app.launch(share=True)  # Set share=False in production
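The closing comment suggests disabling the public share link in production. With Gradio's standard launch() options, a production-style launch might look like the sketch below (the host and port shown are assumptions, not values from this commit):

```python
# Hypothetical production launch: no public share link, bound to a fixed port
app = create_gradio_interface()
app.launch(share=False, server_name="0.0.0.0", server_port=7860)
```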
configs/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,7 @@
+ import os
+ import sys
+ current_file_path = os.path.abspath(__file__)
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
+ project_root_dir = os.path.dirname(parent_dir)
+ sys.path.append(parent_dir)
+ sys.path.append(project_root_dir)
configs/__init__.py ADDED
@@ -0,0 +1,7 @@
+ import os
+ import sys
+ current_file_path = os.path.abspath(__file__)
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
+ project_root_dir = os.path.dirname(parent_dir)
+ sys.path.append(parent_dir)
+ sys.path.append(project_root_dir)
configs/get_config.py ADDED
@@ -0,0 +1,16 @@
+ #-*- coding: utf-8 -*-
+ import os
+
+ from yaml import load, dump
+ try:
+     from yaml import CLoader as Loader, CDumper as Dumper
+ except ImportError:
+     from yaml import Loader, Dumper
+ from box import Box as edict
+
+
+ def load_config(cfg):
+     with open(cfg) as f:
+         config = load(f, Loader=Loader)
+
+     return edict(config)
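Since load_config wraps the parsed YAML in a Box, nested keys can be read with attribute access, which is how app.py uses cfg.DATASET.IMAGE_SIZE and cfg.MODEL.precision. A minimal sketch using the test config shipped in this commit:

```python
from configs.get_config import load_config

cfg = load_config("configs/test_config.yaml")
print(cfg.mode)                 # "test", via Box attribute access
print(cfg.label_dict["FF-DF"])  # 1 (keys containing '-' need item access)
```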
configs/test_config.yaml ADDED
@@ -0,0 +1,38 @@
+ mode: test
+ lmdb: False
+ rgb_dir: '/ssd_scratch/deep_fake_dataset/'
+ lmdb_dir: '/ssd_scratch/deep_fake_dataset/datasets_lmdbs/'
+ dataset_json_folder: './preprocessing/dataset_json_v6/'
+ label_dict:
+   # DFD
+   DFD_fake: 1
+   DFD_real: 0
+   # FF++ + FaceShifter(FF-real+FF-FH)
+   FF-SH: 1
+   FF-F2F: 1
+   FF-DF: 1
+   FF-FS: 1
+   FF-NT: 1
+   FF-FH: 1
+   FF-real: 0
+   # CelebDF
+   CelebDFv1_real: 0
+   CelebDFv1_fake: 1
+   CelebDFv2_real: 0
+   CelebDFv2_fake: 1
+   # DFDCP
+   DFDCP_Real: 0
+   DFDCP_FakeA: 1
+   DFDCP_FakeB: 1
+   # DFDC
+   DFDC_Fake: 1
+   DFDC_Real: 0
+   # DeeperForensics-1.0
+   DF_fake: 1
+   DF_real: 0
+   # UADFV
+   UADFV_Fake: 1
+   UADFV_Real: 0
+   # Roop
+   roop_Real: 0
+   roop_Fake: 1
configs/train_config copie.yaml ADDED
@@ -0,0 +1,43 @@
+ mode: train
+ lmdb: False
+ dry_run: False
+ rgb_dir: '/ssd_scratch/deep_fake_dataset/'
+ lmdb_dir: '/ssd_scratch/deep_fake_dataset/datasets_lmdbs/'
+ dataset_json_folder: './preprocessing/dataset_json_v6/'
+ SWA: False
+ save_avg: True
+ log_dir: ./logs/training/
+ # label settings
+ label_dict:
+   # DFD
+   DFD_fake: 1
+   DFD_real: 0
+   # FF++ + FaceShifter(FF-real+FF-FH)
+   FF-SH: 1
+   FF-F2F: 1
+   FF-DF: 1
+   FF-FS: 1
+   FF-NT: 1
+   FF-FH: 1
+   FF-real: 0
+   # CelebDF
+   CelebDFv1_real: 0
+   CelebDFv1_fake: 1
+   CelebDFv2_real: 0
+   CelebDFv2_fake: 1
+   # DFDCP
+   DFDCP_Real: 0
+   DFDCP_FakeA: 1
+   DFDCP_FakeB: 1
+   # DFDC
+   DFDC_Fake: 1
+   DFDC_Real: 0
+   # DeeperForensics-1.0
+   DF_fake: 1
+   DF_real: 0
+   # UADFV
+   UADFV_Fake: 1
+   UADFV_Real: 0
+   # Roop
+   roop_Real: 0
+   roop_Fake: 1
configs/train_config.yaml ADDED
@@ -0,0 +1,46 @@
+ mode: train
+ lmdb: False
+ dry_run: False
+ rgb_dir: '/ssd_scratch/deep_fake_dataset/'
+ lmdb_dir: '/ssd_scratch/deep_fake_dataset/datasets_lmdbs/'
+ dataset_json_folder: './preprocessing/dataset_json_v6/'
+ SWA: False
+ save_avg: True
+ log_dir: ./logs/training/
+ # label settings
+ label_dict:
+   # iFakeFaceDB labels
+   real: 0
+   fake: 1
+   # DFD
+   DFD_fake: 1
+   DFD_real: 0
+   # FF++ + FaceShifter(FF-real+FF-FH)
+   FF-SH: 1
+   FF-F2F: 1
+   FF-DF: 1
+   FF-FS: 1
+   FF-NT: 1
+   FF-FH: 1
+   FF-real: 0
+   # CelebDF
+   CelebDFv1_real: 0
+   CelebDFv1_fake: 1
+   CelebDFv2_real: 0
+   CelebDFv2_fake: 1
+   # DFDCP
+   DFDCP_Real: 0
+   DFDCP_FakeA: 1
+   DFDCP_FakeB: 1
+   # DFDC
+   DFDC_Fake: 1
+   DFDC_Real: 0
+   # DeeperForensics-1.0
+   DF_fake: 1
+   DF_real: 0
+   # UADFV
+   UADFV_Fake: 1
+   UADFV_Real: 0
+   # Roop
+   roop_Real: 0
+   roop_Fake: 1
loss/__init__.py ADDED
@@ -0,0 +1,11 @@
+ import os
+ import sys
+ current_file_path = os.path.abspath(__file__)
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
+ project_root_dir = os.path.dirname(parent_dir)
+ sys.path.append(parent_dir)
+ sys.path.append(project_root_dir)
+
+ from metrics.registry import LOSSFUNC
+
+ from .cross_entropy_loss import CrossEntropyLoss
loss/abstract_loss_func.py ADDED
@@ -0,0 +1,17 @@
+ import torch.nn as nn
+
+ class AbstractLossClass(nn.Module):
+     """Abstract class for loss functions."""
+     def __init__(self):
+         super(AbstractLossClass, self).__init__()
+
+     def forward(self, pred, label):
+         """
+         Args:
+             pred: prediction of the model
+             label: ground truth label
+
+         Return:
+             loss: loss value
+         """
+         raise NotImplementedError('Each subclass should implement the forward method.')
loss/cross_entropy_loss.py ADDED
@@ -0,0 +1,26 @@
+ import torch.nn as nn
+ from .abstract_loss_func import AbstractLossClass
+ from metrics.registry import LOSSFUNC
+
+
+ @LOSSFUNC.register_module(module_name="cross_entropy")
+ class CrossEntropyLoss(AbstractLossClass):
+     def __init__(self):
+         super().__init__()
+         self.loss_fn = nn.CrossEntropyLoss()
+
+     def forward(self, inputs, targets):
+         """
+         Computes the cross-entropy loss.
+
+         Args:
+             inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores.
+             targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices.
+
+         Returns:
+             A scalar tensor representing the cross-entropy loss.
+         """
+         # Compute the cross-entropy loss
+         loss = self.loss_fn(inputs, targets)
+
+         return loss
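Because the @LOSSFUNC.register_module decorator runs at import time (loss/__init__.py imports CrossEntropyLoss), the loss can then be looked up by name. A minimal sketch, assuming the project root is on sys.path:

```python
import torch
import loss  # importing the package registers "cross_entropy" in LOSSFUNC
from metrics.registry import LOSSFUNC

loss_fn = LOSSFUNC["cross_entropy"]()
logits = torch.randn(4, 2)            # (batch_size, num_classes)
targets = torch.tensor([0, 1, 1, 0])  # ground-truth class indices
print(loss_fn(logits, targets))       # scalar cross-entropy loss
```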
metrics/__init__.py ADDED
@@ -0,0 +1,7 @@
+ import os
+ import sys
+ current_file_path = os.path.abspath(__file__)
+ parent_dir = os.path.dirname(os.path.dirname(current_file_path))
+ project_root_dir = os.path.dirname(parent_dir)
+ sys.path.append(parent_dir)
+ sys.path.append(project_root_dir)
metrics/base_metrics_class.py ADDED
@@ -0,0 +1,204 @@
+ import numpy as np
+ from sklearn import metrics
+ import torch
+ import torch.nn as nn
+
+
+ def get_accracy(output, label):
+     _, prediction = torch.max(output, 1)  # argmax
+     correct = (prediction == label).sum().item()
+     accuracy = correct / prediction.size(0)
+     return accuracy
+
+
+ def get_prediction(output, label):
+     prob = nn.functional.softmax(output, dim=1)[:, 1]
+     prob = prob.view(prob.size(0), 1)
+     label = label.view(label.size(0), 1)
+     # print(prob.size(), label.size())
+     datas = torch.cat((prob, label.float()), dim=1)
+     return datas
+
+
+ def calculate_metrics_for_train(label, output):
+     if output.size(1) == 2:
+         prob = torch.softmax(output, dim=1)[:, 1]
+     else:
+         prob = output
+
+     # Accuracy
+     _, prediction = torch.max(output, 1)
+     correct = (prediction == label).sum().item()
+     accuracy = correct / prediction.size(0)
+
+     # Average Precision
+     y_true = label.cpu().detach().numpy()
+     y_pred = prob.cpu().detach().numpy()
+     ap = metrics.average_precision_score(y_true, y_pred)
+
+     # AUC and EER
+     try:
+         fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(),
+                                                  prob.squeeze().cpu().numpy(),
+                                                  pos_label=1)
+     except Exception:
+         # for the case when we only have one sample
+         return None, None, accuracy, ap
+
+     if np.isnan(fpr[0]) or np.isnan(tpr[0]):
+         # for the case when all the samples within a batch are fake/real
+         auc, eer = None, None
+     else:
+         auc = metrics.auc(fpr, tpr)
+         fnr = 1 - tpr
+         eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
+
+     return auc, eer, accuracy, ap
+
+
+ # ------------ compute average metrics of batches ---------------------
+ class Metrics_batch():
+     def __init__(self):
+         self.tprs = []
+         self.mean_fpr = np.linspace(0, 1, 100)
+         self.aucs = []
+         self.eers = []
+         self.aps = []
+
+         self.correct = 0
+         self.total = 0
+         self.losses = []
+
+     def update(self, label, output):
+         acc = self._update_acc(label, output)
+         if output.size(1) == 2:
+             prob = torch.softmax(output, dim=1)[:, 1]
+         else:
+             prob = output
+         # label = 1-label
+         # prob = torch.softmax(output, dim=1)[:, 1]
+         auc, eer = self._update_auc(label, prob)
+         ap = self._update_ap(label, prob)
+
+         return acc, auc, eer, ap
+
+     def _update_auc(self, lab, prob):
+         fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(),
+                                                  prob.squeeze().cpu().numpy(),
+                                                  pos_label=1)
+         if np.isnan(fpr[0]) or np.isnan(tpr[0]):
+             return -1, -1
+
+         auc = metrics.auc(fpr, tpr)
+         interp_tpr = np.interp(self.mean_fpr, fpr, tpr)
+         interp_tpr[0] = 0.0
+         self.tprs.append(interp_tpr)
+         self.aucs.append(auc)
+
+         # return auc
+
+         # EER
+         fnr = 1 - tpr
+         eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
+         self.eers.append(eer)
+
+         return auc, eer
+
+     def _update_acc(self, lab, output):
+         _, prediction = torch.max(output, 1)  # argmax
+         correct = (prediction == lab).sum().item()
+         accuracy = correct / prediction.size(0)
+         # self.accs.append(accuracy)
+         self.correct = self.correct + correct
+         self.total = self.total + lab.size(0)
+         return accuracy
+
+     def _update_ap(self, label, prob):
+         y_true = label.cpu().detach().numpy()
+         y_pred = prob.cpu().detach().numpy()
+         ap = metrics.average_precision_score(y_true, y_pred)
+         self.aps.append(ap)
+
+         return np.mean(ap)
+
+     def get_mean_metrics(self):
+         mean_acc, std_acc = self.correct / self.total, 0
+         mean_auc, std_auc = self._mean_auc()
+         mean_err, std_err = np.mean(self.eers), np.std(self.eers)
+         mean_ap, std_ap = np.mean(self.aps), np.std(self.aps)
+
+         return {'acc': mean_acc, 'auc': mean_auc, 'eer': mean_err, 'ap': mean_ap}
+
+     def _mean_auc(self):
+         mean_tpr = np.mean(self.tprs, axis=0)
+         mean_tpr[-1] = 1.0
+         mean_auc = metrics.auc(self.mean_fpr, mean_tpr)
+         std_auc = np.std(self.aucs)
+         return mean_auc, std_auc
+
+     def clear(self):
+         self.tprs.clear()
+         self.aucs.clear()
+         # self.accs.clear()
+         self.correct = 0
+         self.total = 0
+         self.eers.clear()
+         self.aps.clear()
+         self.losses.clear()
+
+
+ # ------------ compute average metrics of all data ---------------------
+ class Metrics_all():
+     def __init__(self):
+         self.probs = []
+         self.labels = []
+         self.correct = 0
+         self.total = 0
+
+     def store(self, label, output):
+         prob = torch.softmax(output, dim=1)[:, 1]
+         _, prediction = torch.max(output, 1)  # argmax
+         correct = (prediction == label).sum().item()
+         self.correct += correct
+         self.total += label.size(0)
+         self.labels.append(label.squeeze().cpu().numpy())
+         self.probs.append(prob.squeeze().cpu().numpy())
+
+     def get_metrics(self):
+         y_pred = np.concatenate(self.probs)
+         y_true = np.concatenate(self.labels)
+         # auc
+         fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
+         auc = metrics.auc(fpr, tpr)
+         # eer
+         fnr = 1 - tpr
+         eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
+         # ap
+         ap = metrics.average_precision_score(y_true, y_pred)
+         # acc
+         acc = self.correct / self.total
+         return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}
+
+     def clear(self):
+         self.probs.clear()
+         self.labels.clear()
+         self.correct = 0
+         self.total = 0
+
+
+ # only used to record a series of scalar values
+ class Recorder:
+     def __init__(self):
+         self.sum = 0
+         self.num = 0
+     def update(self, item, num=1):
+         if item is not None:
+             self.sum += item * num
+             self.num += num
+     def average(self):
+         if self.num == 0:
+             return None
+         return self.sum / self.num
+     def clear(self):
+         self.sum = 0
+         self.num = 0
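The Recorder class above keeps a weighted running average, which is how scalar values such as per-batch losses are typically aggregated. A minimal sketch:

```python
rec = Recorder()
rec.update(0.7, num=32)  # batch of 32 samples with mean loss 0.7
rec.update(0.5, num=32)
print(rec.average())     # 0.6, the sample-weighted mean
rec.clear()
```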
metrics/registry.py ADDED
@@ -0,0 +1,19 @@
+ class Registry(object):
+     def __init__(self):
+         self.data = {}
+
+     def register_module(self, module_name=None):
+         def _register(cls):
+             name = module_name
+             if module_name is None:
+                 name = cls.__name__
+             self.data[name] = cls
+             return cls
+         return _register
+
+     def __getitem__(self, key):
+         return self.data[key]
+
+ DETECTOR = Registry()
+ TRAINER = Registry()
+ LOSSFUNC = Registry()
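register_module acts both as a registering decorator and, via __getitem__, a lookup table; this is exactly how cross_entropy_loss.py registers itself with LOSSFUNC. A minimal sketch with a hypothetical class:

```python
from metrics.registry import DETECTOR

@DETECTOR.register_module(module_name="dummy")  # hypothetical example class
class DummyDetector:
    pass

assert DETECTOR["dummy"] is DummyDetector  # lookup by registered name
```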
metrics/utils.py ADDED
@@ -0,0 +1,92 @@
+ from sklearn import metrics
+ import numpy as np
+
+ def parse_metric_for_print(metric_dict):
+     if metric_dict is None:
+         return "\n"
+     s = "\n"
+     s += "================================ Each dataset best metric ================================ \n"
+     for key, value in metric_dict.items():
+         if key != 'avg':
+             s = s + f"| {key}: "
+             for k, v in value.items():
+                 s = s + f" {k}={v} "
+             s = s + "| \n"
+         else:
+             s += "============================================================================================= \n"
+             s += "================================== Average best metric ====================================== \n"
+             avg_dict = value
+             for avg_key, avg_value in avg_dict.items():
+                 if avg_key == 'dataset_dict':
+                     for key, value in avg_value.items():
+                         s = s + f"| {key}: {value} | \n"
+                 else:
+                     s = s + f"| avg {avg_key}: {avg_value} | \n"
+     s += "============================================================================================="
+     return s
+
+
+ def get_test_metrics(y_pred, y_true, img_names):
+     def get_video_metrics(image, pred, label):
+         result_dict = {}
+         new_label = []
+         new_pred = []
+         # print(image[0])
+         # print(pred.shape)
+         # print(label.shape)
+         for item in np.transpose(np.stack((image, pred, label)), (1, 0)):
+
+             s = item[0]
+             if '\\' in s:
+                 parts = s.split('\\')
+             else:
+                 parts = s.split('/')
+             a = parts[-2]
+             b = parts[-1]
+
+             if a not in result_dict:
+                 result_dict[a] = []
+
+             result_dict[a].append(item)
+         image_arr = list(result_dict.values())
+
+         for video in image_arr:
+             pred_sum = 0
+             label_sum = 0
+             leng = 0
+             for frame in video:
+                 pred_sum += float(frame[1])
+                 label_sum += int(frame[2])
+                 leng += 1
+             new_pred.append(pred_sum / leng)
+             new_label.append(int(label_sum / leng))
+         fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred)
+         v_auc = metrics.auc(fpr, tpr)
+         fnr = 1 - tpr
+         v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
+         return v_auc, v_eer
+
+
+     y_pred = y_pred.squeeze()
+     # For UCF, where labels for different manipulations are not consistent.
+     y_true[y_true >= 1] = 1
+     # auc
+     fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
+     auc = metrics.auc(fpr, tpr)
+     # eer
+     fnr = 1 - tpr
+     eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
+     # ap
+     ap = metrics.average_precision_score(y_true, y_pred)
+     # acc
+     prediction_class = (y_pred > 0.5).astype(int)
+     correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item()
+     acc = correct / len(prediction_class)
+     if type(img_names[0]) is not list:
+         # calculate video-level auc for the frame-level methods.
+         v_auc, _ = get_video_metrics(img_names, y_pred, y_true)
+     else:
+         # video-level methods
+         v_auc = auc
+
+     return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'video_auc': v_auc, 'label': y_true}
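get_test_metrics aggregates frame-level scores into a video-level AUC by grouping image paths on their parent directory. A minimal sketch with hypothetical paths and scores (two frames for each of two videos):

```python
import numpy as np
from metrics.utils import get_test_metrics

y_pred = np.array([0.9, 0.8, 0.2, 0.1])  # frame-level fake probabilities
y_true = np.array([1, 1, 0, 0])          # frame-level labels
img_names = ["vidA/f0.png", "vidA/f1.png", "vidB/f0.png", "vidB/f1.png"]

m = get_test_metrics(y_pred, y_true, img_names)
print(m["acc"], m["auc"], m["video_auc"])
```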
models/__init__.py ADDED
@@ -0,0 +1,29 @@
+ #-*- coding: utf-8 -*-
+ from .builder import MODELS, build_model
+ from .networks.arcface import (
+     SimpleClassificationDF,
+ )
+ from .networks.mrsa_resnet import (
+     PoseResNet, resnet_spec, Bottleneck
+ )
+ from .networks.pose_hrnet import (
+     PoseHighResolutionNet
+ )
+ from .networks.xception import (
+     Xception
+ )
+ from .networks.pose_efficientNet import (
+     PoseEfficientNet
+ )
+ from .networks.common import *
+ from .utils import (
+     load_pretrained, freeze_backbone,
+     load_model, save_model, unfreeze_backbone,
+     preset_model,
+ )
+
+
+ __all__ = ['SimpleClassificationDF', 'PoseResNet', 'MODELS', 'build_model',
+            'load_pretrained', 'freeze_backbone', 'resnet_spec',
+            'load_model', 'save_model', 'unfreeze_backbone', 'Bottleneck',
+            'preset_model', 'PoseHighResolutionNet', 'Xception', 'PoseEfficientNet']
models/builder.py ADDED
@@ -0,0 +1,45 @@
+ #-*- coding: utf-8 -*-
+ from typing import Dict, Any, Optional
+
+ import os
+ import sys
+ if not os.getcwd() in sys.path:
+     sys.path.append(os.getcwd())
+
+ from torch.nn import Sequential
+
+ from register.register import Registry, build_from_cfg
+
+
+ def build_model_from_cfg(cfg, registry, default_args=None):
+     """Build a PyTorch model from config dict(s). Different from
+     ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+     Args:
+         cfg (dict, list[dict]): The config of modules; it is either a config
+             dict or a list of config dicts. If cfg is a list,
+             the built modules will be wrapped with ``nn.Sequential``.
+         registry (:obj:`Registry`): A registry the module belongs to.
+         default_args (dict, optional): Default arguments to build the module.
+             Defaults to None.
+     Returns:
+         nn.Module: A built nn module.
+     """
+     if isinstance(cfg, list):
+         modules = [
+             build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+         ]
+         return Sequential(*modules)
+     else:
+         return build_from_cfg(cfg, registry, default_args)
+
+
+ MODELS = Registry('model', build_func=build_model_from_cfg)
+ HEADS = MODELS
+ BACKBONES = MODELS
+
+
+ def build_model(cfg: Dict,
+                 model: Registry,
+                 build_func=build_model_from_cfg,
+                 default_args: Optional[Dict] = None) -> Any:
+     return build_func(cfg, model, default_args)
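build_model delegates to build_from_cfg from register/register.py, which is not part of this commit; assuming it follows the usual mmcv-style convention (pop 'type' from the dict, look the class up in the registry, pass the remaining keys as kwargs), construction would look like this sketch:

```python
# Sketch under the assumption above; 'ResNet' is the name the backbone in
# models/networks/arcface.py registers itself under
from models.builder import MODELS, build_model

cfg = dict(type='ResNet', num_layers=50, drop_ratio=0.6, mode='ir')
backbone = build_model(cfg, MODELS)
```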
models/networks/arcface.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #-*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ from collections import namedtuple
5
+
6
+ from torch.nn import (Linear, Conv2d, BatchNorm1d, Softmax,
7
+ BatchNorm2d, PReLU, ReLU, Sigmoid,
8
+ Dropout2d, Dropout, AvgPool2d, MaxPool2d,
9
+ AdaptiveAvgPool2d, Sequential, Module, Parameter)
10
+ import torch.nn.functional as F
11
+ import torch
12
+
13
+ from ..builder import (
14
+ MODELS, HEADS, BACKBONES,
15
+ build_model,
16
+ )
17
+
18
+
19
+ ################################## Original Arcface Model #############################################################
20
+
21
+
22
+ class Flatten(Module):
23
+ def forward(self, input):
24
+ return input.view(input.size(0), -1)
25
+
26
+
27
+ def l2_norm(input,axis=1):
28
+ norm = torch.norm(input, 2, axis, True)
29
+ output = torch.div(input, norm)
30
+ return output
31
+
32
+
33
+ class SEModule(Module):
34
+ def __init__(self, channels, reduction):
35
+ super(SEModule, self).__init__()
36
+ self.avg_pool = AdaptiveAvgPool2d(1)
37
+ self.fc1 = Conv2d(
38
+ channels, channels // reduction, kernel_size=1, padding=0 ,bias=False)
39
+ self.relu = ReLU(inplace=True)
40
+ self.fc2 = Conv2d(
41
+ channels // reduction, channels, kernel_size=1, padding=0 ,bias=False)
42
+ self.sigmoid = Sigmoid()
43
+
44
+ def forward(self, x):
45
+ module_input = x
46
+ x = self.avg_pool(x)
47
+ x = self.fc1(x)
48
+ x = self.relu(x)
49
+ x = self.fc2(x)
50
+ x = self.sigmoid(x)
51
+ return module_input * x
52
+
53
+
54
+ class bottleneck_IR(Module):
55
+ def __init__(self, in_channel, depth, stride):
56
+ super(bottleneck_IR, self).__init__()
57
+ if in_channel == depth:
58
+ self.shortcut_layer = MaxPool2d(1, stride)
59
+ else:
60
+ self.shortcut_layer = Sequential(
61
+ Conv2d(in_channel, depth, (1, 1), stride ,bias=False), BatchNorm2d(depth))
62
+ self.res_layer = Sequential(
63
+ BatchNorm2d(in_channel),
64
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1 ,bias=False), PReLU(depth),
65
+ Conv2d(depth, depth, (3, 3), stride, 1 ,bias=False), BatchNorm2d(depth))
66
+
67
+ def forward(self, x):
68
+ shortcut = self.shortcut_layer(x)
69
+ res = self.res_layer(x)
70
+ return res + shortcut
71
+
72
+
73
+ class bottleneck_IR_SE(Module):
74
+ def __init__(self, in_channel, depth, stride):
75
+ super(bottleneck_IR_SE, self).__init__()
76
+ if in_channel == depth:
77
+ self.shortcut_layer = MaxPool2d(1, stride)
78
+ else:
79
+ self.shortcut_layer = Sequential(
80
+ Conv2d(in_channel, depth, (1, 1), stride ,bias=False),
81
+ BatchNorm2d(depth))
82
+ self.res_layer = Sequential(
83
+ BatchNorm2d(in_channel),
84
+ Conv2d(in_channel, depth, (3,3), (1,1),1 ,bias=False),
85
+ PReLU(depth),
86
+ Conv2d(depth, depth, (3,3), stride, 1 ,bias=False),
87
+ BatchNorm2d(depth),
88
+ SEModule(depth,16)
89
+ )
90
+
91
+ def forward(self,x):
92
+ shortcut = self.shortcut_layer(x)
93
+ res = self.res_layer(x)
94
+ return res + shortcut
95
+
96
+
97
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
98
+ '''A named tuple describing a ResNet block.'''
99
+
100
+
101
+ def get_block(in_channel, depth, num_units, stride = 2):
102
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units-1)]
103
+
104
+
105
+ def get_blocks(num_layers):
106
+ if num_layers == 50:
107
+ blocks = [
108
+ get_block(in_channel=64, depth=64, num_units = 3),
109
+ get_block(in_channel=64, depth=128, num_units=4),
110
+ get_block(in_channel=128, depth=256, num_units=14),
111
+ get_block(in_channel=256, depth=512, num_units=3)
112
+ ]
113
+ elif num_layers == 100:
114
+ blocks = [
115
+ get_block(in_channel=64, depth=64, num_units=3),
116
+ get_block(in_channel=64, depth=128, num_units=13),
117
+ get_block(in_channel=128, depth=256, num_units=30),
118
+ get_block(in_channel=256, depth=512, num_units=3)
119
+ ]
120
+ elif num_layers == 152:
121
+ blocks = [
122
+ get_block(in_channel=64, depth=64, num_units=3),
123
+ get_block(in_channel=64, depth=128, num_units=8),
124
+ get_block(in_channel=128, depth=256, num_units=36),
125
+ get_block(in_channel=256, depth=512, num_units=3)
126
+ ]
127
+ return blocks
128
+
129
+
130
+ @BACKBONES.register_module()
131
+ class ResNet(Module):
132
+ def __init__(self, num_layers=50, drop_ratio=0.6, mode='ir', **kwargs):
133
+ """
134
+ Implementation for ResNet 50, 101, 152 with/out SE module
135
+ """
136
+ super(ResNet, self).__init__()
137
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
138
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
139
+ blocks = get_blocks(num_layers)
140
+ if mode == 'ir':
141
+ unit_module = bottleneck_IR
142
+ elif mode == 'ir_se':
143
+ unit_module = bottleneck_IR_SE
144
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1 ,bias=False),
145
+ BatchNorm2d(64),
146
+ PReLU(64))
147
+ self.output_layer = Sequential(BatchNorm2d(512),
148
+ Dropout(drop_ratio),
149
+ Flatten(),
150
+ Linear(512 * 7 * 7, 512),
151
+ BatchNorm1d(512))
152
+ modules = []
153
+ for block in blocks:
154
+ for bottleneck in block:
155
+ modules.append(
156
+ unit_module(bottleneck.in_channel,
157
+ bottleneck.depth,
158
+ bottleneck.stride))
159
+ self.body = Sequential(*modules)
160
+
161
+ def forward(self,x):
162
+ x = self.input_layer(x)
163
+ x = self.body(x)
164
+ x = self.output_layer(x)
165
+ x = l2_norm(x)
166
+ return x
167
+
168
+
169
+ @HEADS.register_module()
170
+ class SimpleClassificationHead(Module):
171
+ def __init__(self, drop_ratio=0.6, in_planes=512, **kwargs):
172
+ super(SimpleClassificationHead, self).__init__()
173
+ self.classification_head = Sequential(Dropout(drop_ratio),
174
+ Linear(in_planes, 256),
175
+ BatchNorm1d(256),
176
+ Dropout(drop_ratio),
177
+ Linear(256, 128),
178
+ BatchNorm1d(128),
179
+ Dropout(drop_ratio),
180
+ Linear(128, 64),
181
+ BatchNorm1d(64),
182
+ Dropout(drop_ratio),
183
+ Linear(64, 32),
184
+ BatchNorm1d(32),
185
+ # Dropout(drop_ratio),
186
+ Linear(32, 1),
187
+ Sigmoid())
188
+
189
+ def forward(self, x):
190
+ x = self.classification_head(x)
191
+ return x
192
+
193
+
194
+ @MODELS.register_module()
195
+ class SimpleClassificationDF(Module):
196
+ def __init__(self, cfg: dict, **kwargs):
197
+ super(SimpleClassificationDF, self).__init__()
198
+ assert 'backbone' in cfg, 'Config for Backbones is mandatory!'
199
+ assert 'head' in cfg, 'Config for Heads is mandatory!'
200
+
201
+ self.backbone = BACKBONES.get(cfg.backbone.type)(**cfg.backbone)
202
+ self.head = HEADS.get(cfg.head.type)(**cfg.head)
203
+ self.model = Sequential(*[self.backbone,
204
+ self.head])
205
+
206
+ def forward(self, x):
207
+ x = self.model(x)
208
+ return x
209
+
210
+
211
+ ################################## MobileFaceNet #############################################################
212
+
213
+
214
+ class Conv_block(Module):
215
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
216
+ super(Conv_block, self).__init__()
217
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
218
+ self.bn = BatchNorm2d(out_c)
219
+ self.prelu = PReLU(out_c)
220
+
221
+ def forward(self, x):
222
+ x = self.conv(x)
223
+ x = self.bn(x)
224
+ x = self.prelu(x)
225
+ return x
226
+
227
+
228
+ class Linear_block(Module):
229
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
230
+ super(Linear_block, self).__init__()
231
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
232
+ self.bn = BatchNorm2d(out_c)
233
+
234
+ def forward(self, x):
235
+ x = self.conv(x)
236
+ x = self.bn(x)
237
+ return x
238
+
239
+
240
+ class Depth_Wise(Module):
241
+ def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
242
+ super(Depth_Wise, self).__init__()
243
+ self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
244
+ self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
245
+ self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
246
+ self.residual = residual
247
+
248
+ def forward(self, x):
249
+ if self.residual:
250
+ short_cut = x
251
+ x = self.conv(x)
252
+ x = self.conv_dw(x)
253
+ x = self.project(x)
254
+ if self.residual:
255
+ output = short_cut + x
256
+ else:
257
+ output = x
258
+ return output
259
+
260
+
261
+ class Residual(Module):
262
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
263
+ super(Residual, self).__init__()
264
+ modules = []
265
+ for _ in range(num_block):
266
+ modules.append(Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
267
+ self.model = Sequential(*modules)
268
+
269
+ def forward(self, x):
270
+ return self.model(x)
271
+
272
+
273
+ class MobileFaceNet(Module):
274
+ def __init__(self, embedding_size):
275
+ super(MobileFaceNet, self).__init__()
276
+ self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
277
+ self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
278
+ self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
279
+ self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
280
+ self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
281
+ self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
282
+ self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
283
+ self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
284
+ self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
285
+ self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7,7), stride=(1, 1), padding=(0, 0))
286
+ self.conv_6_flatten = Flatten()
287
+ self.linear = Linear(512, embedding_size, bias=False)
288
+ self.bn = BatchNorm1d(embedding_size)
289
+
290
+ def forward(self, x):
291
+ out = self.conv1(x)
292
+ out = self.conv2_dw(out)
293
+ out = self.conv_23(out)
294
+ out = self.conv_3(out)
295
+ out = self.conv_34(out)
296
+ out = self.conv_4(out)
297
+ out = self.conv_45(out)
298
+ out = self.conv_5(out)
299
+ out = self.conv_6_sep(out)
300
+ out = self.conv_6_dw(out)
301
+ out = self.conv_6_flatten(out)
302
+ out = self.linear(out)
303
+ out = self.bn(out)
304
+
305
+ return l2_norm(out)
306
+
307
+
308
+ ################################## Arcface head #############################################################
309
+
310
+
311
+ class Arcface(Module):
312
+ # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
313
+ def __init__(self, embedding_size=512, classnum=51332, s=64., m=0.5):
314
+ super(Arcface, self).__init__()
315
+ self.classnum = classnum
316
+ self.kernel = Parameter(torch.Tensor(embedding_size,classnum))
317
+ # initialize the class-weight kernel
318
+ self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
319
+ self.m = m # the margin value, default is 0.5
320
+ self.s = s # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369
321
+ self.cos_m = math.cos(m)
322
+ self.sin_m = math.sin(m)
323
+ self.mm = self.sin_m * m # sin(m) * m, used by the fallback below (original code notes "issue 1")
324
+ self.threshold = math.cos(math.pi - m)
325
+
326
+ def forward(self, embbedings, label):
327
+ # weights norm
328
+ nB = len(embbedings)
329
+ kernel_norm = l2_norm(self.kernel,axis=0)
330
+ # cos(theta+m)
331
+ cos_theta = torch.mm(embbedings,kernel_norm)
332
+ # output = torch.mm(embbedings,kernel_norm)
333
+ cos_theta = cos_theta.clamp(-1,1) # for numerical stability
334
+ cos_theta_2 = torch.pow(cos_theta, 2)
335
+ sin_theta_2 = 1 - cos_theta_2
336
+ sin_theta = torch.sqrt(sin_theta_2)
337
+ cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
338
+ # this condition keeps theta+m within the range [0, pi]:
339
+ # 0<=theta+m<=pi
340
+ # -m<=theta<=pi-m
341
+ cond_v = cos_theta - self.threshold
342
+ cond_mask = cond_v <= 0
343
+ keep_val = (cos_theta - self.mm) # when theta+m falls outside [0, pi], fall back to the CosFace-style margin
344
+ cos_theta_m[cond_mask] = keep_val[cond_mask]
345
+ output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
346
+ idx_ = torch.arange(0, nB, dtype=torch.long)
347
+ output[idx_, label] = cos_theta_m[idx_, label]
348
+ output *= self.s # scale up in order to make softmax work, first introduced in normface
349
+ return output
350
+
351
+
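A hedged training-step sketch for the head above (batch size and class count are illustrative): the margin-adjusted logits feed a plain cross-entropy loss, and gradients flow into both the embeddings and the head's kernel.

```python
import torch
import torch.nn.functional as F

head = Arcface(embedding_size=512, classnum=1000, s=64., m=0.5)
embeddings = l2_norm(torch.randn(8, 512))  # assumes l2_norm(input, axis=1) defined earlier in this file
labels = torch.randint(0, 1000, (8,))
logits = head(embeddings, labels)          # margin applied only to target-class logits
loss = F.cross_entropy(logits, labels)
loss.backward()
```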
352
+ ################################## Cosface head #############################################################
353
+
354
+
355
+ class Am_softmax(Module):
356
+ # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
357
+ def __init__(self,embedding_size=512,classnum=51332):
358
+ super(Am_softmax, self).__init__()
359
+ self.classnum = classnum
360
+ self.kernel = Parameter(torch.Tensor(embedding_size,classnum))
361
+ # initial kernel
362
+ self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
363
+ self.m = 0.35 # additive margin recommended by the paper
364
+ self.s = 30. # see normface https://arxiv.org/abs/1704.06369
365
+
366
+ def forward(self,embbedings,label):
367
+ kernel_norm = l2_norm(self.kernel,axis=0)
368
+ cos_theta = torch.mm(embbedings,kernel_norm)
369
+ cos_theta = cos_theta.clamp(-1,1) # for numerical stability
370
+ phi = cos_theta - self.m
371
+ label = label.view(-1,1) #size=(B,1)
372
+ index = cos_theta.data * 0.0 #size=(B,Classnum)
373
+ index.scatter_(1,label.data.view(-1,1),1)
374
+ index = index.bool() # boolean mask; byte-tensor indexing is deprecated in modern PyTorch
375
+ output = cos_theta * 1.0
376
+ output[index] = phi[index] # only modify the target-class logits
377
+ output *= self.s # scale up in order to make softmax work, first introduced in normface
378
+ return output
379
+
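The two heads place the margin differently. With s the scale, m the margin, and theta_y the angle between the embedding and the target-class weight, the target logit becomes:

```latex
\text{Arcface: } \operatorname{logit}_y = s\,\cos(\theta_y + m)
\qquad
\text{Am\_softmax: } \operatorname{logit}_y = s\,(\cos\theta_y - m)
```

Non-target logits are left at s cos(theta_j) in both cases.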
380
+
381
+ if __name__ == "__main__":
382
+ cfg = dict(num_layers=50, drop_ratio=0.6, mode='ir', type='Backbone')
383
+ backbone = MODELS.build(cfg)
384
+ print(backbone)
models/networks/common.py ADDED
@@ -0,0 +1,75 @@
1
+ #-*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ BN_MOMENTUM = 0.1
7
+
8
+
9
+ def point_wise_block(inplanes, outplanes):
10
+ return nn.Sequential(
11
+ nn.Conv2d(in_channels=inplanes, out_channels=outplanes, kernel_size=1, padding=0, stride=1, bias=False),
12
+ nn.BatchNorm2d(outplanes, momentum=BN_MOMENTUM),
13
+ nn.ReLU(inplace=True),
14
+ )
15
+
16
+
17
+ def conv_block(inplanes, outplanes, kernel_size, stride=1, padding=0):
18
+ return nn.Sequential(
19
+ nn.Conv2d(in_channels=inplanes, out_channels=outplanes, kernel_size=kernel_size, padding=padding, stride=stride, bias=False),
20
+ nn.BatchNorm2d(outplanes, momentum=BN_MOMENTUM),
21
+ nn.ReLU(inplace=True)
22
+ )
23
+
24
+
25
+ def conv3x3(in_planes, out_planes, stride=1):
26
+ """3x3 convolution with padding"""
27
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28
+ padding=1, bias=False)
29
+
30
+
31
+ class InceptionBlock(nn.Module):
32
+ def __init__(self, inplanes, outplanes, stride=1, pool_size=3):
33
+ self.inplanes = inplanes
34
+ self.outplanes = outplanes
35
+ self.stride = stride
36
+ self.pool_size = pool_size
37
+ super(InceptionBlock, self).__init__()
38
+
39
+ self.pw_block = point_wise_block(self.inplanes, self.outplanes//4)
40
+ self.mp_layer = nn.MaxPool2d(kernel_size=self.pool_size, stride=stride, padding=1)
41
+ self.conv3_block = conv_block(self.outplanes//4, self.outplanes//4, kernel_size=3, stride=1, padding=1)
42
+ self.conv5_block = conv_block(self.outplanes//4, self.outplanes//4, kernel_size=5, stride=1, padding=2)
43
+
44
+ def forward(self, x):
45
+ x1 = self.pw_block(x)
46
+
47
+ x2 = self.pw_block(x)
48
+ x2 = self.conv3_block(x2)
49
+
50
+ x3 = self.pw_block(x)
51
+ x3 = self.conv5_block(x3)
52
+
53
+ x4 = self.mp_layer(x)
54
+ x4 = self.pw_block(x4)
55
+
56
+ x = torch.cat((x1, x2, x3, x4), dim=1)
57
+ return x
58
+
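A hedged shape check for the block above: the four parallel branches each emit outplanes // 4 channels, so the concatenation restores outplanes. Note that self.pw_block is a single module reused across branches, so those 1x1 convolutions share weights.

```python
import torch

blk = InceptionBlock(inplanes=64, outplanes=128, stride=1, pool_size=3)
print(blk(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 128, 32, 32])
```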
59
+
60
+ class SELayer(nn.Module):
61
+ def __init__(self, channel, reduction=16):
62
+ super(SELayer, self).__init__()
63
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
64
+ self.fc = nn.Sequential(
65
+ nn.Linear(channel, channel // reduction, bias=False),
66
+ nn.ReLU(inplace=True),
67
+ nn.Linear(channel // reduction, channel, bias=False),
68
+ nn.Sigmoid()
69
+ )
70
+
71
+ def forward(self, x):
72
+ b, c, _, _ = x.size()
73
+ y = self.avg_pool(x).view(b, c)
74
+ y = self.fc(y).view(b, c, 1, 1)
75
+ return x * y.expand_as(x)
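SELayer is standard squeeze-and-excitation channel gating: global average pooling, a two-layer bottleneck MLP, and a sigmoid gate that rescales each channel. A hedged usage example:

```python
import torch

se = SELayer(channel=64, reduction=16)
x = torch.randn(2, 64, 16, 16)
print(se(x).shape)  # torch.Size([2, 64, 16, 16]) - same shape, channels reweighted
```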
models/networks/efficientNet.py ADDED
@@ -0,0 +1,490 @@
1
+ #-*- coding: utf-8 -*-
2
+ import math
3
+ import re
4
+ import collections
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.utils import model_zoo
11
+
12
+
13
+ # Parameters for the entire model (stem, all blocks, and head)
14
+ GlobalParams = collections.namedtuple('GlobalParams', [
15
+ 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
16
+ 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
17
+ 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top',
18
+ 'include_hm_decoder', 'head_conv', 'heads', 'num_layers', 'INIT_WEIGHTS',
19
+ 'use_c2', 'use_c3', 'use_c4', 'use_c51', 'efpn', 'se_layer', 'tfpn'])
20
+
21
+ # Parameters for an individual model block
22
+ BlockArgs = collections.namedtuple('BlockArgs', [
23
+ 'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
24
+ 'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
25
+
26
+ # Set GlobalParams and BlockArgs's defaults
27
+ GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
28
+ BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
29
+
30
+
31
+ # Swish activation function
32
+ if hasattr(nn, 'SiLU'):
33
+ Swish = nn.SiLU
34
+ else:
35
+ # For compatibility with old PyTorch versions
36
+ class Swish(nn.Module):
37
+ def forward(self, x):
38
+ return x * torch.sigmoid(x)
39
+
40
+
41
+ def round_filters(filters, global_params):
42
+ """Calculate and round number of filters based on width multiplier.
43
+ Use width_coefficient, depth_divisor and min_depth of global_params.
44
+ Args:
45
+ filters (int): Filters number to be calculated.
46
+ global_params (namedtuple): Global params of the model.
47
+ Returns:
48
+ new_filters: New filters number after calculating.
49
+ """
50
+ multiplier = global_params.width_coefficient
51
+ if not multiplier:
52
+ return filters
53
+ # TODO: modify the params names.
54
+ # maybe the names (width_divisor,min_width)
55
+ # are more suitable than (depth_divisor,min_depth).
56
+ divisor = global_params.depth_divisor
57
+ min_depth = global_params.min_depth
58
+ filters *= multiplier
59
+ min_depth = min_depth or divisor # pay attention to this line when using min_depth
60
+ # follow the formula transferred from official TensorFlow implementation
61
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
62
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
63
+ new_filters += divisor
64
+ return int(new_filters)
65
+
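A hedged worked example of the rounding rule (the 1.4 multiplier is the b4 width coefficient from efficientnet_params below): 32 x 1.4 = 44.8 rounds to the nearest multiple of 8, i.e. 48, and the 10% rule prevents rounding down too far.

```python
gp = GlobalParams(width_coefficient=1.4, depth_divisor=8, min_depth=None)
print(round_filters(32, gp))  # 48
```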
66
+
67
+ def round_repeats(repeats, global_params):
68
+ """Calculate module's repeat number of a block based on depth multiplier.
69
+ Use depth_coefficient of global_params.
70
+ Args:
71
+ repeats (int): num_repeat to be calculated.
72
+ global_params (namedtuple): Global params of the model.
73
+ Returns:
74
+ new repeat: New repeat number after calculating.
75
+ """
76
+ multiplier = global_params.depth_coefficient
77
+ if not multiplier:
78
+ return repeats
79
+ # follow the formula transferred from official TensorFlow implementation
80
+ return int(math.ceil(multiplier * repeats))
81
+
82
+
83
+ def drop_connect(inputs, p, training):
84
+ """Drop connect.
85
+ Args:
86
+ inputs (tensor: BCHW): Input of this structure.
87
+ p (float: 0.0~1.0): Probability of drop connection.
88
+ training (bool): The running mode.
89
+ Returns:
90
+ output: Output after drop connection.
91
+ """
92
+ assert 0 <= p <= 1, 'p must be in range of [0,1]'
93
+
94
+ if not training:
95
+ return inputs
96
+
97
+ batch_size = inputs.shape[0]
98
+ keep_prob = 1 - p
99
+
100
+ # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
101
+ random_tensor = keep_prob
102
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
103
+ binary_tensor = torch.floor(random_tensor)
104
+
105
+ output = inputs / keep_prob * binary_tensor
106
+ return output
107
+
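A hedged illustration of the stochastic-depth behaviour above: in training mode whole samples are zeroed with probability p and survivors are rescaled by 1 / (1 - p), preserving the expectation; in eval mode the function is the identity.

```python
import torch

x = torch.ones(4, 1, 1, 1)
out = drop_connect(x, p=0.5, training=True)
print(out.view(-1))  # each entry is 0.0 (dropped) or 2.0 (kept and rescaled)
```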
108
+
109
+ def get_same_padding_conv2d(image_size=None):
110
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
111
+ Static padding is necessary for ONNX exporting of models.
112
+ Args:
113
+ image_size (int or tuple): Size of the image.
114
+ Returns:
115
+ Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
116
+ """
117
+ if image_size is None:
118
+ return Conv2dDynamicSamePadding
119
+ else:
120
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
121
+
122
+
123
+ class Conv2dDynamicSamePadding(nn.Conv2d):
124
+ """2D Convolutions like TensorFlow, for a dynamic image size.
125
+ The padding is operated in forward function by calculating dynamically.
126
+ """
127
+
128
+ # Tips for 'SAME' mode padding.
129
+ # Given the following:
130
+ # i: width or height
131
+ # s: stride
132
+ # k: kernel size
133
+ # d: dilation
134
+ # p: padding
135
+ # Output after Conv2d:
136
+ # o = floor((i+p-((k-1)*d+1))/s+1)
137
+ # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
138
+ # => p = (i-1)*s+((k-1)*d+1)-i
139
+
140
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
141
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
142
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
143
+
144
+ def forward(self, x):
145
+ ih, iw = x.size()[-2:]
146
+ kh, kw = self.weight.size()[-2:]
147
+ sh, sw = self.stride
148
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
149
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
150
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
151
+ if pad_h > 0 or pad_w > 0:
152
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
153
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
154
+
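A hedged check of the 'SAME' behaviour derived in the comment above: the output spatial size is ceil(input / stride), independent of the kernel size.

```python
import torch

conv = Conv2dDynamicSamePadding(3, 8, kernel_size=5, stride=2)
print(conv(torch.randn(1, 3, 15, 15)).shape)  # torch.Size([1, 8, 8, 8]); ceil(15 / 2) == 8
```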
155
+
156
+ class Conv2dStaticSamePadding(nn.Conv2d):
157
+ """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
158
+ The padding module is computed in the constructor, then applied in forward.
159
+ """
160
+
161
+ # With the same calculation as Conv2dDynamicSamePadding
162
+
163
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
164
+ super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
165
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
166
+
167
+ # Calculate padding based on image size and save it
168
+ assert image_size is not None
169
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
170
+ kh, kw = self.weight.size()[-2:]
171
+ sh, sw = self.stride
172
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
173
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
174
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
175
+ if pad_h > 0 or pad_w > 0:
176
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
177
+ pad_h // 2, pad_h - pad_h // 2))
178
+ else:
179
+ self.static_padding = nn.Identity()
180
+
181
+ def forward(self, x):
182
+ x = self.static_padding(x)
183
+ x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
184
+ return x
185
+
186
+
187
+ def get_model_params(model_name, override_params):
188
+ """Get the block args and global params for a given model name.
189
+ Args:
190
+ model_name (str): Model's name.
191
+ override_params (dict): A dict to modify global_params.
192
+ Returns:
193
+ blocks_args, global_params
194
+ """
195
+ if model_name.startswith('efficientnet'):
196
+ w, d, s, p = efficientnet_params(model_name)
197
+ # note: all models have drop connect rate = 0.2
198
+ blocks_args, global_params = efficientnet(
199
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
200
+ else:
201
+ raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
202
+ if override_params:
203
+ # ValueError will be raised here if override_params has fields not included in global_params.
204
+ global_params = global_params._replace(**override_params)
205
+ return blocks_args, global_params
206
+
207
+
208
+ def efficientnet_params(model_name):
209
+ """Map EfficientNet model name to parameter coefficients.
210
+ Args:
211
+ model_name (str): Model name to be queried.
212
+ Returns:
213
+ params_dict[model_name]: A (width,depth,res,dropout) tuple.
214
+ """
215
+ params_dict = {
216
+ # Coefficients: width,depth,res,dropout
217
+ 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
218
+ 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
219
+ 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
220
+ 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
221
+ 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
222
+ 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
223
+ 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
224
+ 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
225
+ 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
226
+ 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
227
+ }
228
+ return params_dict[model_name]
229
+
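For example (values taken from the table above):

```python
w, d, s, p = efficientnet_params('efficientnet-b4')
print(w, d, s, p)  # 1.4 1.8 380 0.4
```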
230
+
231
+ def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
232
+ dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000,
233
+ include_top=True, include_hm_decoder=False, head_conv=None,
234
+ heads=None, use_c2=False, use_c3=False, use_c4=False, use_c51=False,
235
+ num_layers=None, INIT_WEIGHTS=None, efpn=False, se_layer=False, tfpn=False):
236
+ """Create BlockArgs and GlobalParams for efficientnet model.
237
+ Args:
238
+ width_coefficient (float)
239
+ depth_coefficient (float)
240
+ image_size (int)
241
+ dropout_rate (float)
242
+ drop_connect_rate (float)
243
+ num_classes (int)
244
+ Meaning as the name suggests.
245
+ Returns:
246
+ blocks_args, global_params.
247
+ """
248
+
249
+ # Block args for the whole model (efficientnet-b0 by default)
250
+ # They are modified in the EfficientNet constructor according to the chosen model
251
+ blocks_args = [
252
+ 'r1_k3_s11_e1_i32_o16_se0.25',
253
+ 'r2_k3_s22_e6_i16_o24_se0.25',
254
+ 'r2_k5_s22_e6_i24_o40_se0.25',
255
+ 'r3_k3_s22_e6_i40_o80_se0.25',
256
+ 'r3_k5_s11_e6_i80_o112_se0.25',
257
+ 'r4_k5_s22_e6_i112_o192_se0.25',
258
+ 'r1_k3_s11_e6_i192_o320_se0.25',
259
+ ]
260
+ blocks_args = BlockDecoder.decode(blocks_args)
261
+
262
+ global_params = GlobalParams(
263
+ width_coefficient=width_coefficient,
264
+ depth_coefficient=depth_coefficient,
265
+ image_size=image_size,
266
+ dropout_rate=dropout_rate,
267
+
268
+ num_classes=num_classes,
269
+ batch_norm_momentum=0.99,
270
+ batch_norm_epsilon=1e-3,
271
+ drop_connect_rate=drop_connect_rate,
272
+ depth_divisor=8,
273
+ min_depth=None,
274
+ include_top=include_top,
275
+ include_hm_decoder=include_hm_decoder,
276
+ head_conv=head_conv,
277
+ heads=heads,
278
+ use_c2=use_c2,
279
+ use_c3=use_c3,
280
+ use_c4=use_c4,
281
+ use_c51=use_c51,
282
+ efpn=efpn,
283
+ tfpn=tfpn,
284
+ se_layer=se_layer,
285
+ num_layers=num_layers,
286
+ INIT_WEIGHTS=INIT_WEIGHTS
287
+ )
288
+
289
+ return blocks_args, global_params
290
+
291
+
292
+ class BlockDecoder(object):
293
+ """Block Decoder for readability,
294
+ straight from the official TensorFlow repository.
295
+ """
296
+
297
+ @staticmethod
298
+ def _decode_block_string(block_string):
299
+ """Get a block through a string notation of arguments.
300
+ Args:
301
+ block_string (str): A string notation of arguments.
302
+ Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
303
+ Returns:
304
+ BlockArgs: The namedtuple defined at the top of this file.
305
+ """
306
+ assert isinstance(block_string, str)
307
+
308
+ ops = block_string.split('_')
309
+ options = {}
310
+ for op in ops:
311
+ splits = re.split(r'(\d.*)', op)
312
+ if len(splits) >= 2:
313
+ key, value = splits[:2]
314
+ options[key] = value
315
+
316
+ # Check stride
317
+ assert (('s' in options and len(options['s']) == 1) or
318
+ (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
319
+
320
+ return BlockArgs(
321
+ num_repeat=int(options['r']),
322
+ kernel_size=int(options['k']),
323
+ stride=[int(options['s'][0])],
324
+ expand_ratio=int(options['e']),
325
+ input_filters=int(options['i']),
326
+ output_filters=int(options['o']),
327
+ se_ratio=float(options['se']) if 'se' in options else None,
328
+ id_skip=('noskip' not in block_string))
329
+
330
+ @staticmethod
331
+ def _encode_block_string(block):
332
+ """Encode a block to a string.
333
+ Args:
334
+ block (namedtuple): A BlockArgs type argument.
335
+ Returns:
336
+ block_string: A String form of BlockArgs.
337
+ """
338
+ args = [
339
+ 'r%d' % block.num_repeat,
340
+ 'k%d' % block.kernel_size,
341
+ 's%d%d' % (block.stride[0], block.stride[0]), # decode() stores a one-element stride list
342
+ 'e%s' % block.expand_ratio,
343
+ 'i%d' % block.input_filters,
344
+ 'o%d' % block.output_filters
345
+ ]
346
+ if block.se_ratio is not None and 0 < block.se_ratio <= 1: # guard: se_ratio may be None
347
+ args.append('se%s' % block.se_ratio)
348
+ if block.id_skip is False:
349
+ args.append('noskip')
350
+ return '_'.join(args)
351
+
352
+ @staticmethod
353
+ def decode(string_list):
354
+ """Decode a list of string notations to specify blocks inside the network.
355
+ Args:
356
+ string_list (list[str]): A list of strings, each string is a notation of block.
357
+ Returns:
358
+ blocks_args: A list of BlockArgs namedtuples of block args.
359
+ """
360
+ assert isinstance(string_list, list)
361
+ blocks_args = []
362
+ for block_string in string_list:
363
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
364
+ return blocks_args
365
+
366
+ @staticmethod
367
+ def encode(blocks_args):
368
+ """Encode a list of BlockArgs to a list of strings.
369
+ Args:
370
+ blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
371
+ Returns:
372
+ block_strings: A list of strings, each string is a notation of block.
373
+ """
374
+ block_strings = []
375
+ for block in blocks_args:
376
+ block_strings.append(BlockDecoder._encode_block_string(block))
377
+ return block_strings
378
+
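A hedged example of the string notation: repeat 1, 3x3 kernel, stride 1, expand ratio 1, 32 -> 16 filters, SE ratio 0.25.

```python
args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])[0]
print(args.kernel_size, args.stride, args.se_ratio)  # 3 [1] 0.25
```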
379
+
380
+ class SwishImplementation(torch.autograd.Function):
381
+ @staticmethod
382
+ def forward(ctx, i):
383
+ result = i * torch.sigmoid(i)
384
+ ctx.save_for_backward(i)
385
+ return result
386
+
387
+ @staticmethod
388
+ def backward(ctx, grad_output):
389
+ i = ctx.saved_tensors[0]
390
+ sigmoid_i = torch.sigmoid(i)
391
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
392
+
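The custom backward implements the closed-form Swish derivative; for f(x) = x sigma(x):

```latex
f'(x) = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr)
      = \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr)
```

which is exactly the expression returned in backward().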
393
+
394
+ def get_width_and_height_from_size(x):
395
+ """Obtain height and width from x.
396
+ Args:
397
+ x (int, tuple or list): Data size.
398
+ Returns:
399
+ size: A tuple or list (H,W).
400
+ """
401
+ if isinstance(x, int):
402
+ return x, x
403
+ if isinstance(x, list) or isinstance(x, tuple):
404
+ return x
405
+ else:
406
+ raise TypeError()
407
+
408
+
409
+ def calculate_output_image_size(input_image_size, stride):
410
+ """Calculates the output image size when using Conv2dSamePadding with a stride.
411
+ Necessary for static padding. Thanks to mannatsingh for pointing this out.
412
+ Args:
413
+ input_image_size (int, tuple or list): Size of input image.
414
+ stride (int, tuple or list): Conv2d operation's stride.
415
+ Returns:
416
+ output_image_size: A list [H,W].
417
+ """
418
+ if input_image_size is None:
419
+ return None
420
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
421
+ stride = stride if isinstance(stride, int) else stride[0]
422
+ image_height = int(math.ceil(image_height / stride))
423
+ image_width = int(math.ceil(image_width / stride))
424
+ return [image_height, image_width]
425
+
426
+
427
+ class MemoryEfficientSwish(nn.Module):
428
+ def forward(self, x):
429
+ return SwishImplementation.apply(x)
430
+
431
+
432
+ url_map_advprop = {
433
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
434
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
435
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
436
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
437
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
438
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
439
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
440
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
441
+ 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
442
+ }
443
+
444
+
445
+ url_map = {
446
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
447
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
448
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
449
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
450
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
451
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
452
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
453
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
454
+ }
455
+
456
+
457
+ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
458
+ """Loads pretrained weights from weights path or download using url.
459
+ Args:
460
+ model (Module): The whole model of efficientnet.
461
+ model_name (str): Model name of efficientnet.
462
+ weights_path (None or str):
463
+ str: path to pretrained weights file on the local disk.
464
+ None: use pretrained weights downloaded from the Internet.
465
+ load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
466
+ advprop (bool): Whether to load pretrained weights
467
+ trained with advprop (valid when weights_path is None).
468
+ """
469
+ if isinstance(weights_path, str):
470
+ state_dict = torch.load(weights_path)
471
+ else:
472
+ # AutoAugment or Advprop (different preprocessing)
473
+ url_map_ = url_map_advprop if advprop else url_map
474
+ state_dict = model_zoo.load_url(url_map_[model_name])
475
+
476
+ if load_fc:
477
+ ret = model.load_state_dict(state_dict, strict=False)
478
+ assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
479
+ else:
480
+ state_dict.pop('_fc.weight')
481
+ state_dict.pop('_fc.bias')
482
+ ret = model.load_state_dict(state_dict, strict=False)
483
+
484
+ # if len(ret.missing_keys):
485
+ # assert set(ret.missing_keys) == set(
486
+ # ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
487
+ assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
488
+
489
+ if verbose:
490
+ print('Loaded pretrained weights for {}'.format(model_name))
models/networks/mrsa_resnet.py ADDED
@@ -0,0 +1,464 @@
1
+ #-*- coding: utf-8 -*-
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn.modules.activation import ReLU
12
+ from torch.nn.modules.batchnorm import BatchNorm2d
13
+ from torch.nn.modules.pooling import MaxPool2d
14
+ import torch.utils.model_zoo as model_zoo
15
+
16
+ from ..builder import MODELS, build_model
17
+ from .common import (
18
+ BN_MOMENTUM,
19
+ conv_block,
20
+ point_wise_block,
21
+ InceptionBlock,
22
+ )
23
+
24
+
25
+ model_urls = {
26
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
27
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
28
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
29
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
30
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
31
+ }
32
+
33
+
34
+ def conv3x3(in_planes, out_planes, stride=1):
35
+ """3x3 convolution with padding"""
36
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
37
+ padding=1, bias=False)
38
+
39
+
40
+ class BasicBlock(nn.Module):
41
+ expansion = 1
42
+
43
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
44
+ super(BasicBlock, self).__init__()
45
+ self.conv1 = conv3x3(inplanes, planes, stride)
46
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
47
+ self.relu = nn.ReLU(inplace=True)
48
+ self.conv2 = conv3x3(planes, planes)
49
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
50
+ self.downsample = downsample
51
+ self.stride = stride
52
+
53
+ def forward(self, x):
54
+ residual = x
55
+
56
+ out = self.conv1(x)
57
+ out = self.bn1(out)
58
+ out = self.relu(out)
59
+
60
+ out = self.conv2(out)
61
+ out = self.bn2(out)
62
+
63
+ if self.downsample is not None:
64
+ residual = self.downsample(x)
65
+
66
+ out += residual
67
+ out = self.relu(out)
68
+
69
+ return out
70
+
71
+ # class-level name used by PoseResNet to map config strings to block classes
+ @staticmethod
72
+ def __repr__():
73
+ return 'BasicBlock'
74
+
75
+
76
+ class Bottleneck(nn.Module):
77
+ expansion = 4
78
+
79
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
80
+ super(Bottleneck, self).__init__()
81
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
83
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
84
+ padding=1, bias=False)
85
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
86
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
87
+ bias=False)
88
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
89
+ momentum=BN_MOMENTUM)
90
+ self.relu = nn.ReLU(inplace=True)
91
+ self.downsample = downsample
92
+ self.stride = stride
93
+
94
+ def forward(self, x):
95
+ residual = x
96
+
97
+ out = self.conv1(x)
98
+ out = self.bn1(out)
99
+ out = self.relu(out)
100
+
101
+ out = self.conv2(out)
102
+ out = self.bn2(out)
103
+ out = self.relu(out)
104
+
105
+ out = self.conv3(out)
106
+ out = self.bn3(out)
107
+
108
+ if self.downsample is not None:
109
+ residual = self.downsample(x)
110
+
111
+ out += residual
112
+ out = self.relu(out)
113
+
114
+ return out
115
+
116
+ @staticmethod
117
+ def __repr__():
118
+ return 'Bottleneck'
119
+
120
+
121
+ @MODELS.register_module()
122
+ class PoseResNet(nn.Module):
123
+ def __init__(self,
124
+ block,
125
+ layers,
126
+ heads,
127
+ head_conv,
128
+ dropout_prob,
129
+ fpn=False,
130
+ cls_based_hm=True,
131
+ use_c2=False,
132
+ **kwargs):
133
+ self.inplanes = 64
134
+ self.deconv_with_bias = False
135
+ self.heads = heads
136
+ self.fpn = fpn
137
+ self.cls_based_hm = cls_based_hm
138
+ self.use_c2 = use_c2
139
+
140
+ # Convert block class name (string) into the class object
141
+ if isinstance(block, str):
142
+ for bl in [BasicBlock, Bottleneck]:
143
+ if block == bl.__repr__():
144
+ block = bl
145
+
146
+ for k, v in kwargs.items():
147
+ if v is None:
148
+ raise ValueError(f'The {k} argument received a None value, please check!')
149
+ self.__setattr__(k, v)
150
+
151
+ super(PoseResNet, self).__init__()
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
153
+ bias=False)
154
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
155
+ self.relu = nn.ReLU(inplace=True)
156
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
157
+ self.layer1 = self._make_layer(block, 64, layers[0])
158
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
159
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
160
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
161
+
162
+ # Custom dropout layer
163
+ self.dropout_layer = nn.Dropout(dropout_prob)
164
+
165
+ if self.fpn:
166
+ # Adding sigmoid layer
167
+ self.sigmoid_layer = nn.Sigmoid()
168
+
169
+ # Adding pointwise block
170
+ self.pw_block_1 = self._point_wise_block(2048, 1024)
171
+
172
+ # used for deconv layers
173
+ deconv_filters = [256, 128, 256] if self.fpn else [256, 256, 256]
174
+ self.deconv_layers = self._make_deconv_layer(
175
+ 3,
176
+ deconv_filters,
177
+ [4, 4, 4],
178
+ )
179
+
180
+ # Adding inception block
181
+ if self.fpn:
182
+ for idx, deconv_layer in enumerate(self.deconv_layers):
183
+ self.__setattr__(f'deconv_layer_{idx}', nn.Sequential(deconv_layer))
184
+ self.pw_block_2 = self._point_wise_block(512, 512)
185
+ if self.use_c2:
186
+ self.pw_block_3 = self._point_wise_block(512, 256)
187
+ self.pw_block_c3 = self._point_wise_block(1024, 256)
188
+ self.pw_block_c2 = self._point_wise_block(512, 128)
189
+ self.inception_block = InceptionBlock(256, 256, stride=1, pool_size=3)
190
+
191
+ for head in sorted(self.heads):
192
+ num_output = self.heads[head]
193
+ if head_conv > 0:
194
+ if head != 'cls':
195
+ fc = nn.Sequential(
196
+ nn.Conv2d(256, head_conv,
197
+ kernel_size=3, padding=1, bias=True),
198
+ nn.BatchNorm2d(head_conv),
199
+ nn.ReLU(inplace=True),
200
+ nn.Conv2d(head_conv, num_output,
201
+ kernel_size=1, stride=1, padding=0)
202
+ )
203
+ else:
204
+ if self.cls_based_hm:
205
+ fc = nn.Sequential(
206
+ nn.AdaptiveMaxPool2d(head_conv//4),
207
+ nn.Flatten(),
208
+ nn.Linear(num_output*((head_conv//4)**2), head_conv, bias=True),
209
+ nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
210
+ nn.ReLU(inplace=True),
211
+ nn.Linear(head_conv, 1, bias=True),
212
+ nn.Sigmoid()
213
+ )
214
+ else:
215
+ fc = nn.Sequential(
216
+ nn.Conv2d(256, head_conv, kernel_size=3,
217
+ padding=1, bias=True),
218
+ nn.BatchNorm2d(head_conv, momentum=BN_MOMENTUM),
219
+ nn.ReLU(inplace=True),
220
+ # nn.Conv2d(head_conv, num_output, kernel_size=1,
221
+ # stride=1, padding=0, bias=True),
222
+ # nn.BatchNorm2d(num_output),
223
+ # nn.ReLU(inplace=True),
224
+ # nn.AdaptiveMaxPool2d(head_conv//4),
225
+ nn.AdaptiveAvgPool2d(1),
226
+ nn.Flatten(),
227
+ # nn.Linear((head_conv//4)**2, head_conv, bias=True),
228
+ # nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
229
+ # nn.ReLU(inplace=True),
230
+ nn.Linear(head_conv, 1, bias=True),
231
+ # nn.Sigmoid()
232
+ )
233
+ else:
234
+ fc = nn.Conv2d(
235
+ in_channels=256,
236
+ out_channels=num_output,
237
+ kernel_size=1,
238
+ stride=1,
239
+ padding=0
240
+ )
241
+ self.__setattr__(head, fc)
242
+
243
+ def _point_wise_block(self, inplanes, outplanes):
244
+ self.inplanes = outplanes
245
+ module = point_wise_block(inplanes, outplanes)
246
+ return module
247
+
248
+ def _conv_block(self, inplanes, outplanes, kernel_size, stride=1):
249
+ self.inplanes = outplanes
250
+ module = conv_block(inplanes, outplanes, kernel_size=kernel_size, stride=stride)
251
+ return module
252
+
253
+ def _make_layer(self, block, planes, blocks, stride=1):
254
+ downsample = None
255
+ if stride != 1 or self.inplanes != planes * block.expansion:
256
+ downsample = nn.Sequential(
257
+ nn.Conv2d(self.inplanes, planes * block.expansion,
258
+ kernel_size=1, stride=stride, bias=False),
259
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
260
+ )
261
+
262
+ layers = []
263
+ layers.append(block(self.inplanes, planes, stride, downsample))
264
+ self.inplanes = planes * block.expansion
265
+ for i in range(1, blocks):
266
+ layers.append(block(self.inplanes, planes))
267
+
268
+ return nn.Sequential(*layers)
269
+
270
+ def _get_deconv_cfg(self, deconv_kernel, index):
271
+ if deconv_kernel == 4:
272
+ padding = 1
273
+ output_padding = 0
274
+ elif deconv_kernel == 3:
275
+ padding = 1
276
+ output_padding = 1
277
+ elif deconv_kernel == 2:
278
+ padding = 0
279
+ output_padding = 0
280
+
281
+ return deconv_kernel, padding, output_padding
282
+
283
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
284
+ assert num_layers == len(num_filters), \
285
+ 'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
286
+ assert num_layers == len(num_kernels), \
287
+ 'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
288
+
289
+ layers = []
290
+ for i in range(num_layers):
291
+ kernel, padding, output_padding = \
292
+ self._get_deconv_cfg(num_kernels[i], i)
293
+
294
+ planes = num_filters[i]
295
+ layers.append(nn.Sequential(
296
+ nn.ConvTranspose2d(
297
+ in_channels=self.inplanes,
298
+ out_channels=planes,
299
+ kernel_size=kernel,
300
+ stride=2,
301
+ padding=padding,
302
+ output_padding=output_padding,
303
+ bias=self.deconv_with_bias),
304
+ nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
305
+ )
306
+ if (not self.fpn):
307
+ layers.append(nn.ReLU(inplace=True))
308
+
309
+ self.inplanes = planes if not self.fpn else planes * 2
310
+
311
+ if self.fpn:
312
+ return layers
313
+ else:
314
+ return nn.Sequential(*layers)
315
+
316
+ def forward(self, x):
317
+ x = self.conv1(x)
318
+ x = self.bn1(x)
319
+ x = self.relu(x)
320
+ x = self.maxpool(x)
321
+
322
+ x1 = self.layer1(x) #256 x 64 x 64
323
+ x2 = self.layer2(x1) #512 x 32 x 32
324
+ x3 = self.layer3(x2) #1024 x 16 x 16
325
+ x4 = self.layer4(x3) #2048 x 8 x 8
326
+
327
+ # Custom dropout layer
328
+ x = self.dropout_layer(x4) #B x 8 x 8 x 2048
329
+ x3 = self.dropout_layer(x3)
330
+ x2 = self.dropout_layer(x2)
331
+ x1 = self.dropout_layer(x1)
332
+
333
+ # Custom FPN
334
+ if self.fpn:
335
+ assert isinstance(self.deconv_layers, list), "For the custom FPN, the deconv layers must be kept as a list!"
336
+ x = self.pw_block_1(x) # B x 1024 x 8 x 8
337
+ x = self.deconv_layer_0(x) # B x 256 x 16 x 16
338
+ # x = self.relu(x) # B x 256 x 16 x 16
339
+
340
+ x_weighted = self.sigmoid_layer(x) # B x 256 x 16 x 16
341
+ x_inverse = torch.sub(1, x_weighted, alpha=1) # B x 256 x 16 x 16
342
+ x3 = self.pw_block_c3(x3) #B x 256 x 16 x 16
343
+ x3_ = torch.multiply(x3, x_inverse) #B x 256 x 16 x 16
344
+ x = torch.cat((x, x3_), dim=1) #B x 512 x 16 x 16
345
+
346
+ x = self.pw_block_2(x) #B x 512 x 16 x 16
347
+ x = self.deconv_layer_1(x) #B x 128 x 32 x 32
348
+ # x = self.relu(x) #B x 128 x 32 x 32
349
+
350
+ x_weighted = self.sigmoid_layer(x) #B x 128 x 32 x 32
351
+ x_inverse = torch.sub(1, x_weighted, alpha=1) #B x 128 x 32 x 32
352
+ x2 = self.pw_block_c2(x2)
353
+ x2_ = torch.multiply(x2, x_inverse) #B x 128 x 32 x 32
354
+ x = torch.cat((x, x2_), dim=1) #B x 256 x 32 x 32
355
+
356
+ x = self.inception_block(x) #B x 256 x 64 x 64
357
+ x = self.deconv_layer_2(x) #B x 256 x 64 x 64
358
+
359
+ if self.use_c2:
360
+ x_weighted = self.sigmoid_layer(x)
361
+ x_inverse = torch.sub(1, x_weighted, alpha=1)
362
+ x1_ = torch.multiply(x1, x_inverse)
363
+ x = torch.cat((x, x1_), dim=1)
364
+ x = self.pw_block_3(x)
365
+ else:
366
+ x = self.relu(x) #B x 256 x 64 x 64
367
+ else:
368
+ assert isinstance(self.deconv_layers, nn.Module), "Deconv layers must be an nn.Module when FPN is disabled!"
369
+ x = self.deconv_layers(x)
370
+
371
+ ret = {}
372
+ x1_hm = None
373
+ for head in self.heads:
374
+ if self.cls_based_hm and head == 'cls' and x1_hm is not None:
375
+ x = x1_hm
376
+ elif head == 'hm':
377
+ x1_hm = x
378
+
379
+ ret[head] = self.__getattr__(head)(x)
380
+
381
+ return [ret]
382
+
383
+ def init_weights(self, pretrained=True, **kwargs):
384
+ num_layers = kwargs.get('num_layers')
385
+ if pretrained:
386
+ if self.fpn:
387
+ for bl in [self.pw_block_1, self.pw_block_2]:
388
+ for _, l in bl.named_parameters():
389
+ if isinstance(l, nn.Conv2d):
390
+ nn.init.normal_(l.weight, std=0.001)
391
+ nn.init.constant_(l.bias, 0)
392
+
393
+ for _, l in self.inception_block.named_parameters():
394
+ if isinstance(l, nn.Conv2d):
395
+ nn.init.normal_(l.weight, std=0.001)
396
+ nn.init.constant_(l.bias, 0)
397
+
398
+ # print('=> init resnet deconv weights from normal distribution')
399
+ if isinstance(self.deconv_layers, nn.Module):
400
+ for _, m in self.deconv_layers.named_modules():
401
+ if isinstance(m, nn.ConvTranspose2d):
402
+ # print('=> init {}.weight as normal(0, 0.001)'.format(name))
403
+ # print('=> init {}.bias as 0'.format(name))
404
+ nn.init.normal_(m.weight, std=0.001)
405
+ if self.deconv_with_bias:
406
+ nn.init.constant_(m.bias, 0)
407
+ elif isinstance(m, nn.BatchNorm2d):
408
+ # print('=> init {}.weight as 1'.format(name))
409
+ # print('=> init {}.bias as 0'.format(name))
410
+ nn.init.constant_(m.weight, 1)
411
+ nn.init.constant_(m.bias, 0)
412
+ else:
413
+ for layer in [self.deconv_layer_0, self.deconv_layer_1, self.deconv_layer_2]:
414
+ for _, m in layer.named_modules():
415
+ if isinstance(m, nn.ConvTranspose2d):
416
+ # print('=> init {}.weight as normal(0, 0.001)'.format(name))
417
+ # print('=> init {}.bias as 0'.format(name))
418
+ nn.init.normal_(m.weight, std=0.001)
419
+ if self.deconv_with_bias:
420
+ nn.init.constant_(m.bias, 0)
421
+ elif isinstance(m, nn.BatchNorm2d):
422
+ # print('=> init {}.weight as 1'.format(name))
423
+ # print('=> init {}.bias as 0'.format(name))
424
+ nn.init.constant_(m.weight, 1)
425
+ nn.init.constant_(m.bias, 0)
426
+
427
+ # print('=> init final conv weights from normal distribution')
428
+ for head in self.heads:
429
+ final_layer = self.__getattr__(head)
430
+ for i, m in enumerate(final_layer.modules()):
431
+ if isinstance(m, nn.Conv2d):
432
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
433
+ # print('=> init {}.weight as normal(0, 0.001)'.format(name))
434
+ # print('=> init {}.bias as 0'.format(name))
435
+ if m.weight.shape[0] == self.heads[head]:
436
+ if 'hm' in head:
437
+ nn.init.constant_(m.bias, -2.19)
438
+ else:
439
+ nn.init.normal_(m.weight, std=0.001)
440
+ nn.init.constant_(m.bias, 0)
441
+ # if isinstance(m, nn.Linear):
442
+ # if m.weight.shape[0] == self.heads[head]:
443
+ # prior = 1/71
444
+ # nn.init.constant_(m.bias, -math.log((1-prior)/prior))
445
+ # else:
446
+ # nn.init.normal_(m.weight, std=0.001)
447
+ # nn.init.constant_(m.bias, 0)
448
+
449
+ #pretrained_state_dict = torch.load(pretrained)
450
+ url = model_urls['resnet{}'.format(num_layers)]
451
+ pretrained_state_dict = model_zoo.load_url(url)
452
+ print('=> loading pretrained model {}'.format(url))
453
+ self.load_state_dict(pretrained_state_dict, strict=False)
454
+ else:
455
+ print('=> imagenet pretrained model does not exist')
456
+ print('=> please download it first')
457
+ raise ValueError('imagenet pretrained model does not exist')
458
+
459
+
460
+ resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
461
+ 34: (BasicBlock, [3, 4, 6, 3]),
462
+ 50: (Bottleneck, [3, 4, 6, 3]),
463
+ 101: (Bottleneck, [3, 4, 23, 3]),
464
+ 152: (Bottleneck, [3, 8, 36, 3])}
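A hedged construction sketch using resnet_spec (the heads and head_conv values are illustrative, not from a repo config):

```python
block, layers = resnet_spec[50]
net = PoseResNet(block, layers,
                 heads={'hm': 71, 'cls': 1},  # illustrative head spec
                 head_conv=64,
                 dropout_prob=0.2,
                 fpn=False)
net.init_weights(pretrained=True, num_layers=50)  # downloads the ImageNet resnet50 weights
```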
models/networks/pose_efficientNet.py ADDED
@@ -0,0 +1,788 @@
1
+ #-*- coding: utf-8 -*-
2
+ import math
3
+ import sys
4
+ import os
5
+ if os.getcwd() not in sys.path:
6
+ sys.path.append(os.getcwd())
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch.utils import model_zoo
12
+
13
+ from ..builder import MODELS, build_model
14
+ from .efficientNet import (
15
+ round_filters,
16
+ round_repeats,
17
+ drop_connect,
18
+ get_same_padding_conv2d,
19
+ get_model_params,
20
+ efficientnet_params,
21
+ load_pretrained_weights,
22
+ Swish,
23
+ MemoryEfficientSwish,
24
+ calculate_output_image_size,
25
+ url_map_advprop,
26
+ url_map
27
+ )
28
+ from .common import (
29
+ InceptionBlock,
30
+ conv_block,
31
+ BN_MOMENTUM,
32
+ SELayer
33
+ )
34
+
35
+
36
+ VALID_MODELS = (
37
+ 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
38
+ 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
39
+ 'efficientnet-b8',
40
+
41
+ # Support the construction of 'efficientnet-l2' without pretrained weights
42
+ 'efficientnet-l2'
43
+ )
44
+
45
+
46
+ class MBConvBlock(nn.Module):
47
+ """Mobile Inverted Residual Bottleneck Block.
48
+ Args:
49
+ block_args (namedtuple): BlockArgs, defined in utils.py.
50
+ global_params (namedtuple): GlobalParam, defined in utils.py.
51
+ image_size (tuple or list): [image_height, image_width].
52
+ References:
53
+ [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
54
+ [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
55
+ [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
56
+ """
57
+
58
+ def __init__(self, block_args, global_params, image_size=None):
59
+ super().__init__()
60
+ self._block_args = block_args
61
+ self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
62
+ self._bn_eps = global_params.batch_norm_epsilon
63
+ self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
64
+ self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
65
+
66
+ # Expansion phase (Inverted Bottleneck)
67
+ inp = self._block_args.input_filters # number of input channels
68
+ oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
69
+ if self._block_args.expand_ratio != 1:
70
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
71
+ self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
72
+ self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
73
+ # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
74
+
75
+ # Depthwise convolution phase
76
+ k = self._block_args.kernel_size
77
+ s = self._block_args.stride
78
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
79
+ self._depthwise_conv = Conv2d(
80
+ in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
81
+ kernel_size=k, stride=s, bias=False)
82
+ self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
83
+ image_size = calculate_output_image_size(image_size, s)
84
+
85
+ # Squeeze and Excitation layer, if desired
86
+ if self.has_se:
87
+ Conv2d = get_same_padding_conv2d(image_size=(1, 1))
88
+ num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
89
+ self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
90
+ self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
91
+
92
+ # Pointwise convolution phase
93
+ final_oup = self._block_args.output_filters
94
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
95
+ self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
96
+ self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
97
+ self._swish = MemoryEfficientSwish()
98
+
99
+ def forward(self, inputs, drop_connect_rate=None):
100
+ """MBConvBlock's forward function.
101
+ Args:
102
+ inputs (tensor): Input tensor.
103
+ drop_connect_rate (float): Drop connect rate, between 0 and 1.
104
+ Returns:
105
+ Output of this block after processing.
106
+ """
107
+
108
+ # Expansion and Depthwise Convolution
109
+ x = inputs
110
+ if self._block_args.expand_ratio != 1:
111
+ x = self._expand_conv(inputs)
112
+ x = self._bn0(x)
113
+ x = self._swish(x)
114
+
115
+ x = self._depthwise_conv(x)
116
+ x = self._bn1(x)
117
+ x = self._swish(x)
118
+
119
+ # Squeeze and Excitation
120
+ if self.has_se:
121
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
122
+ x_squeezed = self._se_reduce(x_squeezed)
123
+ x_squeezed = self._swish(x_squeezed)
124
+ x_squeezed = self._se_expand(x_squeezed)
125
+ x = torch.sigmoid(x_squeezed) * x
126
+
127
+ # Pointwise Convolution
128
+ x = self._project_conv(x)
129
+ x = self._bn2(x)
130
+
131
+ # Skip connection and drop connect
132
+ input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
133
+ if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
134
+ # The combination of skip connection and drop connect brings about stochastic depth.
135
+ if drop_connect_rate:
136
+ x = drop_connect(x, p=drop_connect_rate, training=self.training)
137
+ x = x + inputs # skip connection
138
+ return x
139
+
140
+ def set_swish(self, memory_efficient=True):
141
+ """Sets swish function as memory efficient (for training) or standard (for export).
142
+ Args:
143
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
144
+ """
145
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
146
+
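A hedged shape check for a single MBConv block; the import path mirrors this repo's layout (an assumption) and the block args mirror the first b0 stage (expand ratio 1, 3x3 kernel, stride 1, 32 -> 16 channels).

```python
import torch
from models.networks.efficientNet import BlockArgs, efficientnet  # assumed import path

ba = BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1,
               input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True)
_, gp = efficientnet()  # default (b0-style) global params
blk = MBConvBlock(ba, gp, image_size=112).eval()
print(blk(torch.randn(1, 32, 112, 112)).shape)  # torch.Size([1, 16, 112, 112])
```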
147
+
148
+ @MODELS.register_module()
149
+ class EfficientNet(nn.Module):
150
+ """EfficientNet model.
151
+ Most easily loaded with the .from_name or .from_pretrained methods.
152
+ Args:
153
+ blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
154
+ global_params (namedtuple): A set of GlobalParams shared between blocks.
155
+ References:
156
+ [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
157
+ Example:
158
+ >>> import torch
159
+ >>> from efficientnet.model import EfficientNet
160
+ >>> inputs = torch.rand(1, 3, 224, 224)
161
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
162
+ >>> model.eval()
163
+ >>> outputs = model(inputs)
164
+ """
165
+
166
+ def __init__(self, blocks_args=None, global_params=None):
167
+ super().__init__()
168
+ assert isinstance(blocks_args, list), 'blocks_args should be a list'
169
+ assert len(blocks_args) > 0, 'block args must be greater than 0'
170
+ self._global_params = global_params
171
+ self._blocks_args = blocks_args
172
+
173
+ # Batch norm parameters
174
+ bn_mom = 1 - self._global_params.batch_norm_momentum
175
+ bn_eps = self._global_params.batch_norm_epsilon
176
+
177
+ # Get stem static or dynamic convolution depending on image size
178
+ image_size = global_params.image_size
179
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
180
+
181
+ # Stem
182
+ in_channels = 3 # rgb
183
+ out_channels = round_filters(32, self._global_params) # number of output channels
184
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
185
+ self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
186
+ image_size = calculate_output_image_size(image_size, 2)
187
+
188
+ # Build blocks
189
+ self._blocks = nn.ModuleList([])
190
+ for block_args in self._blocks_args:
191
+
192
+ # Update block input and output filters based on depth multiplier.
193
+ block_args = block_args._replace(
194
+ input_filters=round_filters(block_args.input_filters, self._global_params),
195
+ output_filters=round_filters(block_args.output_filters, self._global_params),
196
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params)
197
+ )
198
+
199
+ # The first block needs to take care of stride and filter size increase.
200
+ self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
201
+ image_size = calculate_output_image_size(image_size, block_args.stride)
202
+ if block_args.num_repeat > 1: # modify block_args to keep same output size
203
+ block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
204
+ for _ in range(block_args.num_repeat - 1):
205
+ self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
206
+ # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
207
+
208
+ # Head
209
+ in_channels = block_args.output_filters # output of final block
210
+ out_channels = round_filters(1280, self._global_params)
211
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
212
+ self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
213
+ self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
214
+
215
+ # Final linear layer
216
+ self._avg_pooling = nn.AdaptiveAvgPool2d(1)
217
+ if self._global_params.include_top:
218
+ self._dropout = nn.Dropout(self._global_params.dropout_rate)
219
+ self._fc = nn.Linear(out_channels, self._global_params.num_classes)
220
+
221
+ # Heatmap Decoder Construction
222
+ if self._global_params.include_hm_decoder:
223
+ print("Constructing the heatmap Decoder!")
224
+ self.efpn = self._global_params.efpn
225
+ self.tfpn = self._global_params.tfpn
226
+
227
+ assert not (self.efpn and self.tfpn), "Only one of E-FPN or T-FPN can be integrated!"
228
+
229
+ self.se_layer = self._global_params.se_layer
230
+ # self.hm_decoder_filters = [1792, 448, 160, 56] if self.fpn else [1792, 256, 256, 128]
231
+ self.hm_decoder_filters = [1792, 448, 160, 56]
232
+ num_kernels = [4, 4, 4, 4] if (self.efpn or self.tfpn) else [4, 4, 4]
233
+ self._dropout = nn.Dropout(self._global_params.dropout_rate)
234
+ self._sigmoid = nn.Sigmoid()
235
+ self._relu = nn.ReLU(inplace=True)
236
+ self._relu1 = nn.ReLU(inplace=False)
237
+ self.deconv_with_bias = False
238
+ if self._global_params.use_c3:
239
+ self.inception_block = InceptionBlock(112, 112, stride=1, pool_size=3)
240
+ else:
241
+ self.inception_block = InceptionBlock(56, 56, stride=1, pool_size=3)
242
+ self.heads = self._global_params.heads
243
+ n_deconv = len(self.hm_decoder_filters)
244
+ self.fpn_layers = [self._global_params.use_c51, self._global_params.use_c4, self._global_params.use_c3]
245
+
246
+ if self.efpn or self.tfpn:
247
+ for idx in range(n_deconv):
248
+ in_decod_filters = self.hm_decoder_filters[idx]
249
+
250
+ if idx == 0:
251
+ out_decod_filters = self.hm_decoder_filters[idx+1]
252
+ deconv = nn.Sequential(
253
+ conv_block(in_decod_filters, out_decod_filters, (3,3), stride=1, padding=1),
254
+ )
255
+ else:
256
+ in_decod_filters = in_decod_filters*2 if self.fpn_layers[idx-1] else in_decod_filters
257
+ kernel, padding, output_padding = self._get_deconv_cfg(num_kernels[idx])
258
+
259
+ if idx+1 < n_deconv:
260
+ out_decod_filters = self.hm_decoder_filters[idx+1]
261
+ deconv = nn.Sequential(
262
+ conv_block(in_decod_filters, out_decod_filters, (3,3), stride=1, padding=1),
263
+ nn.ConvTranspose2d(
264
+ in_channels=out_decod_filters,
265
+ out_channels=out_decod_filters,
266
+ kernel_size=kernel,
267
+ stride=2,
268
+ padding=padding,
269
+ output_padding=output_padding,
270
+ bias=self.deconv_with_bias),
271
+ nn.BatchNorm2d(out_decod_filters, momentum=BN_MOMENTUM),
272
+ )
273
+ else:
274
+ out_decod_filters = in_decod_filters
275
+ deconv = nn.Sequential(
276
+ self.inception_block,
277
+ nn.ConvTranspose2d(
278
+ in_channels=out_decod_filters,
279
+ out_channels=out_decod_filters,
280
+ kernel_size=kernel,
281
+ stride=2,
282
+ padding=padding,
283
+ output_padding=output_padding,
284
+ bias=self.deconv_with_bias),
285
+ nn.BatchNorm2d(out_decod_filters, momentum=BN_MOMENTUM),
286
+ )
287
+
288
+ # In case of using C2, this conv to apply to C2 features to get the same filters of the last deconv
289
+ if self._global_params.use_c2:
290
+ self.conv_c2 = conv_block(32, out_decod_filters, (3,3), stride=1, padding=1)
291
+ if self.se_layer:
292
+ se = SELayer(channel=out_decod_filters*2)
293
+ self.__setattr__(f'se_layer_{idx+1}', se)
294
+
295
+ self.__setattr__(f'deconv_{idx+1}', deconv)
296
+ else:
297
+ self.deconv_layers = self._make_deconv_layer(
298
+ len(num_kernels),
299
+ self.hm_decoder_filters,
300
+ num_kernels,
301
+ )
302
+
303
+ for head, num_output in self.heads.items():
304
+ head_conv = int(self._global_params.head_conv)
305
+ num_output = int(num_output)
306
+ if self._global_params.use_c2:
307
+ assert self._global_params.efpn or self._global_params.tfpn, "An FPN design (E-FPN or T-FPN) must be active!"
308
+ assert self._global_params.use_c3, "C3 must be utilized for FPN integration of C2"
309
+ in_head_filters = self.hm_decoder_filters[-1]*4
310
+ elif self._global_params.use_c3:
311
+ in_head_filters = self.hm_decoder_filters[-1]*2
312
+ else:
313
+ in_head_filters = self.hm_decoder_filters[-1]
314
+
315
+ if head_conv > 0:
316
+ if head != 'cls':
317
+ fc = nn.Sequential(
318
+ nn.Conv2d(in_head_filters, head_conv,
319
+ kernel_size=3, padding=1, bias=True),
320
+ nn.BatchNorm2d(head_conv),
321
+ nn.ReLU(inplace=True),
322
+ nn.Conv2d(head_conv, num_output,
323
+ kernel_size=1, stride=1, padding=0)
324
+ )
325
+ else:
326
+ fc = nn.Sequential(
327
+ nn.Conv2d(in_head_filters, head_conv, kernel_size=3,
328
+ padding=1, bias=True),
329
+ nn.BatchNorm2d(head_conv, momentum=BN_MOMENTUM),
330
+ nn.ReLU(inplace=True),
331
+ # nn.Conv2d(head_conv, num_output, kernel_size=1,
332
+ # stride=1, padding=0, bias=True),
333
+ # nn.BatchNorm2d(num_output),
334
+ # nn.ReLU(inplace=True),
335
+ # nn.AdaptiveMaxPool2d(head_conv//4),
336
+ nn.AdaptiveAvgPool2d(1),
337
+ nn.Flatten(),
338
+ # nn.Linear((head_conv//4)**2, head_conv, bias=True),
339
+ # nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
340
+ # nn.ReLU(inplace=True),
341
+ nn.Linear(head_conv, num_output, bias=True),
342
+ # nn.Sigmoid(),
343
+ # nn.Softmax(dim=-1)
344
+ )
345
+ else:
346
+ fc = nn.Conv2d(
347
+ in_channels=in_head_filters,
348
+ out_channels=num_output,
349
+ kernel_size=1,
350
+ stride=1,
351
+ padding=0
352
+ )
353
+ self.__setattr__(head, fc)
354
+
355
+ # set activation to memory efficient swish by default
356
+ self._swish = MemoryEfficientSwish()
357
+
358
+ def _get_deconv_cfg(self, deconv_kernel):
359
+ if deconv_kernel == 4:
360
+ padding = 1
361
+ output_padding = 0
362
+ elif deconv_kernel == 3:
363
+ padding = 1
364
+ output_padding = 1
365
+ elif deconv_kernel == 2:
366
+ padding = 0
367
+ output_padding = 0
+ else:
+ # Fail fast instead of returning unbound padding values
+ raise ValueError('Unsupported deconv_kernel: {}'.format(deconv_kernel))
368
+
369
+ return deconv_kernel, padding, output_padding
370
+
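+ # Sanity check on the configs above (a sketch, assuming stride 2 everywhere):
+ # out = (in - 1)*2 - 2*padding + kernel + output_padding, so kernel=4 (pad=1),
+ # kernel=3 (pad=1, outpad=1) and kernel=2 (pad=0) all give out = 2*in.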
371
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
372
+ assert num_layers == (len(num_filters) - 1), \
373
+ 'ERROR: num_layers must equal len(num_filters) - 1'
374
+ assert num_layers == len(num_kernels), \
375
+ 'ERROR: num_layers must equal len(num_kernels)'
376
+
377
+ layers = []
378
+ for i in range(num_layers):
379
+ kernel, padding, output_padding = \
380
+ self._get_deconv_cfg(num_kernels[i])
381
+
382
+ in_planes = num_filters[i]
383
+ out_planes = num_filters[i+1]
384
+
385
+ layers.append(nn.Sequential(
386
+ nn.ConvTranspose2d(
387
+ in_channels=in_planes,
388
+ out_channels=out_planes,
389
+ kernel_size=kernel,
390
+ stride=2,
391
+ padding=padding,
392
+ output_padding=output_padding,
393
+ bias=self.deconv_with_bias),
394
+ nn.BatchNorm2d(out_planes, momentum=BN_MOMENTUM),
395
+ nn.ReLU(inplace=True))
396
+ )
397
+
398
+ return nn.Sequential(*layers)
399
+
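+ # Shape sketch for the plain (non-FPN) decoder, assuming the b4 config from
+ # __main__ with 384x384 inputs: reduction_6 is 12x12x1792 (stride 32), and the
+ # three stride-2 deconvs give 24x24x448 -> 48x48x160 -> 96x96x56 (stride 4).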
400
+ def set_swish(self, memory_efficient=True):
401
+ """Sets swish function as memory efficient (for training) or standard (for export).
402
+ Args:
403
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
404
+ """
405
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
406
+ for block in self._blocks:
407
+ block.set_swish(memory_efficient)
408
+
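+ # Hypothetical usage: call model.set_swish(memory_efficient=False) before tracing
+ # or ONNX export, since the autograd-based MemoryEfficientSwish cannot be traced.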
409
+ def extract_endpoints(self, inputs):
410
+ """Use convolution layer to extract features
411
+ from reduction levels i in [1, 2, 3, 4, 5].
412
+ Args:
413
+ inputs (tensor): Input tensor.
414
+ Returns:
415
+ Dictionary of last intermediate features
416
+ with reduction levels i in [1, 2, 3, 4, 5, 6].
417
+ Example:
418
+ >>> import torch
419
+ >>> from efficientnet.model import EfficientNet
420
+ >>> inputs = torch.rand(1, 3, 224, 224)
421
+ >>> model = EfficientNet.from_pretrained('efficientnet-b0')
422
+ >>> endpoints = model.extract_endpoints(inputs)
423
+ >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
424
+ >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
425
+ >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
426
+ >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
427
+ >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
428
+ >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
429
+ """
430
+ endpoints = dict()
431
+
432
+ # Stem
433
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
434
+ prev_x = x
435
+
436
+ # Blocks
437
+ for idx, block in enumerate(self._blocks):
438
+ drop_connect_rate = self._global_params.drop_connect_rate
439
+ if drop_connect_rate:
440
+ drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate linearly over depth
441
+ x = block(x, drop_connect_rate=drop_connect_rate)
442
+ # print('Prev', prev_x.size())
443
+ # print('X', x.size())
444
+ if prev_x.size(2) > x.size(2):
445
+ endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
446
+ elif idx == len(self._blocks) - 1:
447
+ endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
448
+ prev_x = x
449
+
450
+ # Head
451
+ x = self._swish(self._bn1(self._conv_head(x)))
452
+ endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
453
+
454
+ return endpoints
455
+
456
+ def extract_features(self, inputs):
457
+ """use convolution layer to extract feature .
458
+ Args:
459
+ inputs (tensor): Input tensor.
460
+ Returns:
461
+ Output of the final convolution
462
+ layer in the efficientnet model.
463
+ """
464
+ # Stem
465
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
466
+
467
+ # Blocks
468
+ for idx, block in enumerate(self._blocks):
469
+ drop_connect_rate = self._global_params.drop_connect_rate
470
+ if drop_connect_rate:
471
+ drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate linearly over depth
472
+ x = block(x, drop_connect_rate=drop_connect_rate)
473
+
474
+ # Head
475
+ x = self._swish(self._bn1(self._conv_head(x)))
476
+
477
+ return x
478
+
479
+ def forward(self, inputs):
480
+ """EfficientNet's forward function.
481
+ Calls extract_features to extract features, applies final linear layer, and returns logits.
482
+ Args:
483
+ inputs (tensor): Input tensor.
484
+ Returns:
485
+ Output of this model after processing.
486
+ """
487
+ # Convolution layers
488
+ # x = self.extract_features(inputs)
489
+ endpoints = self.extract_endpoints(inputs)
490
+ x1 = endpoints['reduction_6']
491
+ x2 = endpoints['reduction_5']
492
+ x3 = endpoints['reduction_4']
493
+ x4 = endpoints['reduction_3']
494
+ x5 = endpoints['reduction_2']
495
+ x = x1
496
+
497
+ if self._global_params.include_top:
498
+ # Pooling and final linear layer
499
+ x = self._avg_pooling(x)
500
+
501
+ x = x.flatten(start_dim=1)
502
+ x = self._dropout(x)
503
+ x = self._fc(x)
504
+ return x
505
+
506
+ if self._global_params.include_hm_decoder:
507
+ x1 = self._dropout(x1)
508
+ x2 = self._dropout(x2)
509
+ x3 = self._dropout(x3)
510
+ x4 = self._dropout(x4)
511
+
512
+ if self.efpn:
513
+ assert self._global_params.use_c51, "C51 must be utilized for FPN integration"
514
+
515
+ x = self.__getattr__('deconv_1')(x1)
516
+
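+ # E-FPN fusion sketch: x_inv = 1 - sigmoid(x) down-weights the skip features where
+ # the decoder is already confident, and the gated skip is concatenated channel-wise.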
517
+ if self._global_params.use_c51:
518
+ x_weighted = self._sigmoid(x)
519
+ x_inv = torch.sub(1, x_weighted, alpha=1)
520
+ x2_ = torch.multiply(x_inv, x2)
521
+ x = torch.cat([x, x2_], dim=1)
522
+
523
+ if self.se_layer:
524
+ x = self.__getattr__('se_layer_1')(x)
525
+ else:
526
+ x = self._relu(x)
527
+
528
+ x = self.__getattr__('deconv_2')(x)
529
+
530
+ if self._global_params.use_c4:
531
+ x_weighted = self._sigmoid(x)
532
+ x_inv = torch.sub(1, x_weighted, alpha=1)
533
+ x3_ = torch.multiply(x_inv, x3)
534
+ x = torch.cat([x, x3_], dim=1)
535
+
536
+ if self.se_layer:
537
+ x = self.__getattr__('se_layer_2')(x)
538
+ else:
539
+ x = self._relu(x)
540
+
541
+ x = self.__getattr__('deconv_3')(x)
542
+
543
+ if self._global_params.use_c3:
544
+ assert self._global_params.use_c4, "C4 must be utilized for FPN integration of C3"
545
+
546
+ x_weighted = self._sigmoid(x)
547
+ x_inv = torch.sub(1, x_weighted, alpha=1)
548
+ x4_ = torch.multiply(x_inv, x4)
549
+ x = torch.cat([x, x4_], dim=1)
550
+
551
+ if self.se_layer:
552
+ x = self.__getattr__('se_layer_3')(x)
553
+ else:
554
+ x = self._relu(x)
555
+
556
+ x = self.__getattr__('deconv_4')(x)
557
+
558
+ if not self._global_params.use_c2:
559
+ x = self._relu(x)
560
+ else:
561
+ assert self._global_params.use_c3, "C3 must be utilized for FPN integration of C2"
562
+
563
+ x5 = self._dropout(x5)
564
+ x5_ = self.conv_c2(x5)
565
+ x_weighted = self._sigmoid(x)
566
+ x_inv = torch.sub(1, x_weighted, alpha=1)
567
+ x5_ = torch.multiply(x_inv, x5_)
568
+ x = torch.cat([x, x5_], dim=1)
569
+
570
+ if self.se_layer:
571
+ x = self.__getattr__('se_layer_4')(x)
572
+ elif self.tfpn:
573
+ assert self._global_params.use_c51, "C51 must be utilized for FPN integration"
574
+ x = self.__getattr__('deconv_1')(x1)
575
+ x = self._relu1(x)
576
+ x = torch.cat([x, x2], dim=1)
577
+
578
+ x = self.__getattr__('deconv_2')(x)
579
+ if not self._global_params.use_c4:
580
+ x = self._relu1(x)
581
+ else:
582
+ x = torch.cat([x, x3], dim=1)
583
+
584
+ x = self.__getattr__('deconv_3')(x)
585
+ if not self._global_params.use_c3:
586
+ x = self._relu1(x)
587
+ else:
588
+ assert self._global_params.use_c4, "C4 must be utilized for FPN integration of C3"
589
+ x = torch.cat([x, x4], dim=1)
590
+
591
+ x = self.__getattr__('deconv_4')(x)
592
+ if not self._global_params.use_c2:
593
+ x = self._relu(x)
594
+ else:
595
+ assert self._global_params.use_c3, "C3 must be utilized for FPN integration of C2"
596
+ x5 = self._dropout(x5)
597
+ x5 = self.conv_c2(x5)
598
+ x = self._relu1(x)
599
+ x = torch.cat([x, x5], dim=1)
600
+ else:
601
+ x = self.deconv_layers(x1)
602
+
603
+ ret = {}
604
+ for head in self.heads:
605
+ ret[head] = self.__getattr__(head)(x)
606
+
607
+ return [ret]
608
+
609
+ @classmethod
610
+ def from_name(cls, model_name, in_channels=3, **override_params):
611
+ """Create an efficientnet model according to name.
612
+ Args:
613
+ model_name (str): Name for efficientnet.
614
+ in_channels (int): Input data's channel number.
615
+ override_params (other key word params):
616
+ Params to override model's global_params.
617
+ Optional key:
618
+ 'width_coefficient', 'depth_coefficient',
619
+ 'image_size', 'dropout_rate',
620
+ 'num_classes', 'batch_norm_momentum',
621
+ 'batch_norm_epsilon', 'drop_connect_rate',
622
+ 'depth_divisor', 'min_depth'
623
+ Returns:
624
+ An efficientnet model.
625
+ """
626
+ cls._check_model_name_is_valid(model_name)
627
+ blocks_args, global_params = get_model_params(model_name, override_params)
628
+ model = cls(blocks_args, global_params)
629
+ model._change_in_channels(in_channels)
630
+ return model
631
+
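+ # A minimal usage sketch (values are illustrative):
+ # model = EfficientNet.from_name('efficientnet-b0', in_channels=1, num_classes=10)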
632
+ @classmethod
633
+ def from_pretrained(cls, model_name, weights_path=None, advprop=False,
634
+ in_channels=3, num_classes=1000, **override_params):
635
+ """Create an efficientnet model according to name.
636
+ Args:
637
+ model_name (str): Name for efficientnet.
638
+ weights_path (None or str):
639
+ str: path to pretrained weights file on the local disk.
640
+ None: use pretrained weights downloaded from the Internet.
641
+ advprop (bool):
642
+ Whether to load pretrained weights
643
+ trained with advprop (valid when weights_path is None).
644
+ in_channels (int): Input data's channel number.
645
+ num_classes (int):
646
+ Number of categories for classification.
647
+ It controls the output size for final linear layer.
648
+ override_params (other key word params):
649
+ Params to override model's global_params.
650
+ Optional key:
651
+ 'width_coefficient', 'depth_coefficient',
652
+ 'image_size', 'dropout_rate',
653
+ 'batch_norm_momentum',
654
+ 'batch_norm_epsilon', 'drop_connect_rate',
655
+ 'depth_divisor', 'min_depth'
656
+ Returns:
657
+ A pretrained efficientnet model.
658
+ """
659
+ model = cls.from_name(model_name, num_classes=num_classes, **override_params)
660
+ load_pretrained_weights(model, model_name, weights_path=weights_path,
661
+ load_fc=((num_classes == 1000) and (model._global_params.include_top)), advprop=advprop)
662
+ model._change_in_channels(in_channels)
663
+ return model
664
+
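+ # A minimal usage sketch (values are illustrative):
+ # model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=2)
+ # Note: the final FC weights are only loaded when num_classes == 1000 and include_top is set.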
665
+ @classmethod
666
+ def get_image_size(cls, model_name):
667
+ """Get the input image size for a given efficientnet model.
668
+ Args:
669
+ model_name (str): Name for efficientnet.
670
+ Returns:
671
+ Input image size (resolution).
672
+ """
673
+ cls._check_model_name_is_valid(model_name)
674
+ _, _, res, _ = efficientnet_params(model_name)
675
+ return res
676
+
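+ # e.g. EfficientNet.get_image_size('efficientnet-b4') returns the native training
+ # resolution from efficientnet_params (380 for b4).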
677
+ @classmethod
678
+ def _check_model_name_is_valid(cls, model_name):
679
+ """Validates model name.
680
+ Args:
681
+ model_name (str): Name for efficientnet.
682
+ Returns:
683
+ bool: Is a valid name or not.
684
+ """
685
+ if model_name not in VALID_MODELS:
686
+ raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
687
+
688
+ def _change_in_channels(self, in_channels):
689
+ """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
690
+ Args:
691
+ in_channels (int): Input data's channel number.
692
+ """
693
+ if in_channels != 3:
694
+ Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
695
+ out_channels = round_filters(32, self._global_params)
696
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
697
+
698
+
699
+ @MODELS.register_module()
700
+ class PoseEfficientNet(EfficientNet):
701
+ def __init__(self, model_name, in_channels=3, **override_params):
702
+ self.model_name = model_name
703
+ self.in_channels = in_channels
704
+
705
+ # Initialize Parent Class
706
+ super()._check_model_name_is_valid(model_name)
707
+ blocks_args, global_params = get_model_params(model_name, override_params)
708
+ super().__init__(blocks_args, global_params)
709
+
710
+ @classmethod
711
+ def from_name(cls, model_name, in_channels, **override_params):
712
+ return NotImplemented
713
+
714
+ @classmethod
715
+ def from_pretrained(cls, model_name, weights_path, advprop, in_channels, num_classes, **override_params):
716
+ return NotImplemented
717
+
718
+ def _change_in_channels(self, in_channels):
719
+ return NotImplemented
720
+
721
+ def init_weights(self, pretrained=False, advprop=False, verbose=True):
722
+ if pretrained:
723
+ url_map_ = url_map_advprop if advprop else url_map
724
+ state_dict = model_zoo.load_url(url_map_[self.model_name])
725
+ self.load_state_dict(state_dict, strict=False)
726
+
727
+ # Initialize weights for Deconvolution Layer
728
+ if self._global_params.include_hm_decoder:
729
+ if self.efpn or self.tfpn:
730
+ deconv_layers = [self.deconv_1, self.deconv_2, self.deconv_3, self.deconv_4]
731
+ else:
732
+ deconv_layers = self.deconv_layers
733
+
734
+ for layer in deconv_layers:
735
+ for _, m in layer.named_modules():
736
+ if isinstance(m, nn.ConvTranspose2d):
737
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
738
+ m.weight.data.normal_(0, math.sqrt(2. / n))
739
+ if self.deconv_with_bias:
740
+ nn.init.constant_(m.bias, 0)
741
+ elif isinstance(m, nn.BatchNorm2d):
742
+ nn.init.constant_(m.weight, 1)
743
+ nn.init.constant_(m.bias, 0)
744
+
745
+ # Init head parameters
746
+ for head in self.heads:
747
+ final_layer = self.__getattr__(head)
748
+ for i, m in enumerate(final_layer.modules()):
749
+ if isinstance(m, nn.Conv2d):
750
+ if m.weight.shape[0] == self.heads[head]:
751
+ if 'hm' in head:
752
+ nn.init.constant_(m.bias, -2.19)
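+ # -2.19 ~= -log((1 - 0.1) / 0.1): biases the heatmap logits toward a low
+ # foreground prior at init (the focal-loss bias trick), so early training
+ # is not swamped by easy negatives.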
753
+ else:
754
+ # nn.init.normal_(m.weight, std=0.001)
755
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
756
+ m.weight.data.normal_(0, math.sqrt(2. / n))
757
+ nn.init.constant_(m.bias, 0)
758
+
759
+ self._change_in_channels(in_channels=self.in_channels)
760
+ if pretrained and verbose:
761
+ print('Loaded pretrained weights for {}'.format(self.model_name))
762
+
763
+
764
+ if __name__ == '__main__':
765
+ cfg = dict(type='PoseEfficientNet',
766
+ model_name='efficientnet-b4',
767
+ include_top=False,
768
+ include_hm_decoder=True,
769
+ head_conv=64,
770
+ heads={'hm':1, 'cls':1, 'cstency':256},
771
+ use_c2=True)
772
+ model = build_model(cfg, MODELS)
773
+ model.init_weights(pretrained=True)
774
+ model.eval()
775
+ inputs = torch.rand((1, 3, 384, 384))
776
+
777
+ for i, (n, p) in enumerate(model.named_parameters()):
778
+ print(i, n)
779
+
780
+ # Show the output shapes of the whole pose EFN model
781
+ x = model(inputs)[0]
782
+ for head in x.keys():
783
+ print(f'{head} shape is --- {x[head].shape}')
784
+
785
+ # Show the endpoint feature shapes
786
+ # endpoints = model.extract_endpoints(inputs)
787
+ # for k in endpoints.keys():
788
+ # print(endpoints[k].shape)
models/networks/pose_hrnet.py ADDED
@@ -0,0 +1,515 @@
1
+ #-*- coding: utf-8 -*-
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import logging
8
+ import re
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from ..builder import MODELS
13
+
14
+ from .common import conv3x3, BN_MOMENTUM
15
+
16
+
17
+ class BasicBlock(nn.Module):
18
+ expansion = 1
19
+
20
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
21
+ super(BasicBlock, self).__init__()
22
+ self.conv1 = conv3x3(inplanes, planes, stride)
23
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
24
+ self.relu = nn.ReLU(inplace=True)
25
+ self.conv2 = conv3x3(planes, planes)
26
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
27
+ self.downsample = downsample
28
+ self.stride = stride
29
+
30
+ def forward(self, x):
31
+ residual = x
32
+
33
+ out = self.conv1(x)
34
+ out = self.bn1(out)
35
+ out = self.relu(out)
36
+
37
+ out = self.conv2(out)
38
+ out = self.bn2(out)
39
+
40
+ if self.downsample is not None:
41
+ residual = self.downsample(x)
42
+
43
+ out += residual
44
+ out = self.relu(out)
45
+
46
+ return out
47
+
48
+
49
+ class Bottleneck(nn.Module):
50
+ expansion = 4
51
+
52
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
53
+ super(Bottleneck, self).__init__()
54
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
55
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
56
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
57
+ padding=1, bias=False)
58
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
59
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
60
+ bias=False)
61
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
62
+ momentum=BN_MOMENTUM)
63
+ self.relu = nn.ReLU(inplace=True)
64
+ self.downsample = downsample
65
+ self.stride = stride
66
+
67
+ def forward(self, x):
68
+ residual = x
69
+
70
+ out = self.conv1(x)
71
+ out = self.bn1(out)
72
+ out = self.relu(out)
73
+
74
+ out = self.conv2(out)
75
+ out = self.bn2(out)
76
+ out = self.relu(out)
77
+
78
+ out = self.conv3(out)
79
+ out = self.bn3(out)
80
+
81
+ if self.downsample is not None:
82
+ residual = self.downsample(x)
83
+
84
+ out += residual
85
+ out = self.relu(out)
86
+
87
+ return out
88
+
89
+
90
+ class HighResolutionModule(nn.Module):
91
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
92
+ num_channels, fuse_method, multi_scale_output=True):
93
+ super(HighResolutionModule, self).__init__()
94
+ self._check_branches(
95
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
96
+
97
+ self.num_inchannels = num_inchannels
98
+ self.fuse_method = fuse_method
99
+ self.num_branches = num_branches
100
+
101
+ self.multi_scale_output = multi_scale_output
102
+
103
+ self.branches = self._make_branches(
104
+ num_branches, blocks, num_blocks, num_channels)
105
+ self.fuse_layers = self._make_fuse_layers()
106
+ self.relu = nn.ReLU(True)
107
+
108
+ def _check_branches(self, num_branches, blocks, num_blocks,
109
+ num_inchannels, num_channels):
110
+ if num_branches != len(num_blocks):
111
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
112
+ num_branches, len(num_blocks))
113
+ # logger.error(error_msg)
114
+ raise ValueError(error_msg)
115
+
116
+ if num_branches != len(num_channels):
117
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
118
+ num_branches, len(num_channels))
119
+ # logger.error(error_msg)
120
+ raise ValueError(error_msg)
121
+
122
+ if num_branches != len(num_inchannels):
123
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
124
+ num_branches, len(num_inchannels))
125
+ # logger.error(error_msg)
126
+ raise ValueError(error_msg)
127
+
128
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
129
+ stride=1):
130
+ downsample = None
131
+ if stride != 1 or \
132
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
133
+ downsample = nn.Sequential(
134
+ nn.Conv2d(
135
+ self.num_inchannels[branch_index],
136
+ num_channels[branch_index] * block.expansion,
137
+ kernel_size=1, stride=stride, bias=False
138
+ ),
139
+ nn.BatchNorm2d(
140
+ num_channels[branch_index] * block.expansion,
141
+ momentum=BN_MOMENTUM
142
+ ),
143
+ )
144
+
145
+ layers = []
146
+ layers.append(
147
+ block(
148
+ self.num_inchannels[branch_index],
149
+ num_channels[branch_index],
150
+ stride,
151
+ downsample
152
+ )
153
+ )
154
+ self.num_inchannels[branch_index] = \
155
+ num_channels[branch_index] * block.expansion
156
+ for i in range(1, num_blocks[branch_index]):
157
+ layers.append(
158
+ block(
159
+ self.num_inchannels[branch_index],
160
+ num_channels[branch_index]
161
+ )
162
+ )
163
+
164
+ return nn.Sequential(*layers)
165
+
166
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
167
+ branches = []
168
+
169
+ for i in range(num_branches):
170
+ branches.append(
171
+ self._make_one_branch(i, block, num_blocks, num_channels)
172
+ )
173
+
174
+ return nn.ModuleList(branches)
175
+
176
+ def _make_fuse_layers(self):
177
+ if self.num_branches == 1:
178
+ return None
179
+
180
+ num_branches = self.num_branches
181
+ num_inchannels = self.num_inchannels
182
+ fuse_layers = []
183
+ for i in range(num_branches if self.multi_scale_output else 1):
184
+ fuse_layer = []
185
+ for j in range(num_branches):
186
+ if j > i:
187
+ fuse_layer.append(
188
+ nn.Sequential(
189
+ nn.Conv2d(
190
+ num_inchannels[j],
191
+ num_inchannels[i],
192
+ 1, 1, 0, bias=False
193
+ ),
194
+ nn.BatchNorm2d(num_inchannels[i]),
195
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')
196
+ )
197
+ )
198
+ elif j == i:
199
+ fuse_layer.append(None)
200
+ else:
201
+ conv3x3s = []
202
+ for k in range(i-j):
203
+ if k == i - j - 1:
204
+ num_outchannels_conv3x3 = num_inchannels[i]
205
+ conv3x3s.append(
206
+ nn.Sequential(
207
+ nn.Conv2d(
208
+ num_inchannels[j],
209
+ num_outchannels_conv3x3,
210
+ 3, 2, 1, bias=False
211
+ ),
212
+ nn.BatchNorm2d(num_outchannels_conv3x3)
213
+ )
214
+ )
215
+ else:
216
+ num_outchannels_conv3x3 = num_inchannels[j]
217
+ conv3x3s.append(
218
+ nn.Sequential(
219
+ nn.Conv2d(
220
+ num_inchannels[j],
221
+ num_outchannels_conv3x3,
222
+ 3, 2, 1, bias=False
223
+ ),
224
+ nn.BatchNorm2d(num_outchannels_conv3x3),
225
+ nn.ReLU(True)
226
+ )
227
+ )
228
+ fuse_layer.append(nn.Sequential(*conv3x3s))
229
+ fuse_layers.append(nn.ModuleList(fuse_layer))
230
+
231
+ return nn.ModuleList(fuse_layers)
232
+
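+ # Fusion sketch: for output branch i and input branch j, j > i is upsampled by
+ # 2^(j-i) (1x1 conv + nearest upsample), j < i is downsampled by (i-j) stride-2
+ # 3x3 convs, and j == i passes through; forward() sums the aligned maps, then ReLU.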
233
+ def get_num_inchannels(self):
234
+ return self.num_inchannels
235
+
236
+ def forward(self, x):
237
+ if self.num_branches == 1:
238
+ return [self.branches[0](x[0])]
239
+
240
+ for i in range(self.num_branches):
241
+ x[i] = self.branches[i](x[i])
242
+
243
+ x_fuse = []
244
+
245
+ for i in range(len(self.fuse_layers)):
246
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
247
+ for j in range(1, self.num_branches):
248
+ if i == j:
249
+ y = y + x[j]
250
+ else:
251
+ y = y + self.fuse_layers[i][j](x[j])
252
+ x_fuse.append(self.relu(y))
253
+
254
+ return x_fuse
255
+
256
+
257
+ blocks_dict = {
258
+ 'BASIC': BasicBlock,
259
+ 'BOTTLENECK': Bottleneck
260
+ }
261
+
262
+
263
+ @MODELS.register_module()
264
+ class PoseHighResolutionNet(nn.Module):
265
+ def __init__(self,
266
+ cfg,
267
+ **kwargs):
268
+ self.inplanes = 64
269
+ extra = cfg.MODEL.EXTRA
270
+ self.cls_based_hm = cfg.MODEL.cls_based_hm
271
+ self.heads = cfg.MODEL.heads
272
+ super(PoseHighResolutionNet, self).__init__()
273
+
274
+ # stem net
275
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
276
+ bias=False)
277
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
278
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
279
+ bias=False)
280
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
281
+ self.relu = nn.ReLU(inplace=True)
282
+ self.layer1 = self._make_layer(Bottleneck, 64, 4)
283
+
284
+ self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
285
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
286
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
287
+ num_channels = [
288
+ num_channels[i] * block.expansion for i in range(len(num_channels))
289
+ ]
290
+ self.transition1 = self._make_transition_layer([256], num_channels)
291
+ self.stage2, pre_stage_channels = self._make_stage(
292
+ self.stage2_cfg, num_channels)
293
+
294
+ self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
295
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
296
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
297
+ num_channels = [
298
+ num_channels[i] * block.expansion for i in range(len(num_channels))
299
+ ]
300
+ self.transition2 = self._make_transition_layer(
301
+ pre_stage_channels, num_channels)
302
+ self.stage3, pre_stage_channels = self._make_stage(
303
+ self.stage3_cfg, num_channels)
304
+
305
+ self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
306
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
307
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
308
+ num_channels = [
309
+ num_channels[i] * block.expansion for i in range(len(num_channels))
310
+ ]
311
+ self.transition3 = self._make_transition_layer(
312
+ pre_stage_channels, num_channels)
313
+ self.stage4, pre_stage_channels = self._make_stage(
314
+ self.stage4_cfg, num_channels, multi_scale_output=False)
315
+
316
+ self.final_layer = nn.Conv2d(
317
+ in_channels=pre_stage_channels[0],
318
+ out_channels=cfg.MODEL.NUM_JOINTS,
319
+ kernel_size=extra.FINAL_CONV_KERNEL,
320
+ stride=1,
321
+ padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
322
+ )
323
+
324
+ self.final_layer_cls = nn.Sequential(
325
+ nn.BatchNorm2d(cfg.MODEL.NUM_JOINTS, momentum=BN_MOMENTUM),
326
+ nn.AdaptiveMaxPool2d(cfg.MODEL.HEATMAP_SIZE[0]//4),
327
+ nn.Flatten(),
328
+ nn.Linear((cfg.MODEL.HEATMAP_SIZE[0]//4)**2, cfg.MODEL.NUM_JOINTS, bias=True),
329
+ nn.Sigmoid()
330
+ )
331
+
332
+ self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS']
333
+
334
+ def _make_transition_layer(
335
+ self, num_channels_pre_layer, num_channels_cur_layer):
336
+ num_branches_cur = len(num_channels_cur_layer)
337
+ num_branches_pre = len(num_channels_pre_layer)
338
+
339
+ transition_layers = []
340
+ for i in range(num_branches_cur):
341
+ if i < num_branches_pre:
342
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
343
+ transition_layers.append(
344
+ nn.Sequential(
345
+ nn.Conv2d(
346
+ num_channels_pre_layer[i],
347
+ num_channels_cur_layer[i],
348
+ 3, 1, 1, bias=False
349
+ ),
350
+ nn.BatchNorm2d(num_channels_cur_layer[i]),
351
+ nn.ReLU(inplace=True)
352
+ )
353
+ )
354
+ else:
355
+ transition_layers.append(None)
356
+ else:
357
+ conv3x3s = []
358
+ for j in range(i+1-num_branches_pre):
359
+ inchannels = num_channels_pre_layer[-1]
360
+ outchannels = num_channels_cur_layer[i] \
361
+ if j == i-num_branches_pre else inchannels
362
+ conv3x3s.append(
363
+ nn.Sequential(
364
+ nn.Conv2d(
365
+ inchannels, outchannels, 3, 2, 1, bias=False
366
+ ),
367
+ nn.BatchNorm2d(outchannels),
368
+ nn.ReLU(inplace=True)
369
+ )
370
+ )
371
+ transition_layers.append(nn.Sequential(*conv3x3s))
372
+
373
+ return nn.ModuleList(transition_layers)
374
+
375
+ def _make_layer(self, block, planes, blocks, stride=1):
376
+ downsample = None
377
+ if stride != 1 or self.inplanes != planes * block.expansion:
378
+ downsample = nn.Sequential(
379
+ nn.Conv2d(
380
+ self.inplanes, planes * block.expansion,
381
+ kernel_size=1, stride=stride, bias=False
382
+ ),
383
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
384
+ )
385
+
386
+ layers = []
387
+ layers.append(block(self.inplanes, planes, stride, downsample))
388
+ self.inplanes = planes * block.expansion
389
+ for i in range(1, blocks):
390
+ layers.append(block(self.inplanes, planes))
391
+
392
+ return nn.Sequential(*layers)
393
+
394
+ def _make_stage(self, layer_config, num_inchannels,
395
+ multi_scale_output=True):
396
+ num_modules = layer_config['NUM_MODULES']
397
+ num_branches = layer_config['NUM_BRANCHES']
398
+ num_blocks = layer_config['NUM_BLOCKS']
399
+ num_channels = layer_config['NUM_CHANNELS']
400
+ block = blocks_dict[layer_config['BLOCK']]
401
+ fuse_method = layer_config['FUSE_METHOD']
402
+
403
+ modules = []
404
+ for i in range(num_modules):
405
+ # multi_scale_output is only used in the last module
406
+ if not multi_scale_output and i == num_modules - 1:
407
+ reset_multi_scale_output = False
408
+ else:
409
+ reset_multi_scale_output = True
410
+
411
+ modules.append(
412
+ HighResolutionModule(
413
+ num_branches,
414
+ block,
415
+ num_blocks,
416
+ num_inchannels,
417
+ num_channels,
418
+ fuse_method,
419
+ reset_multi_scale_output
420
+ )
421
+ )
422
+ num_inchannels = modules[-1].get_num_inchannels()
423
+
424
+ return nn.Sequential(*modules), num_inchannels
425
+
426
+ def forward(self, x):
427
+ x = self.conv1(x)
428
+ x = self.bn1(x)
429
+ x = self.relu(x)
430
+ x = self.conv2(x)
431
+ x = self.bn2(x)
432
+ x = self.relu(x)
433
+ x = self.layer1(x)
434
+
435
+ x_list = []
436
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
437
+ if self.transition1[i] is not None:
438
+ x_list.append(self.transition1[i](x))
439
+ else:
440
+ x_list.append(x)
441
+ y_list = self.stage2(x_list)
442
+
443
+ x_list = []
444
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
445
+ if self.transition2[i] is not None:
446
+ x_list.append(self.transition2[i](y_list[-1]))
447
+ else:
448
+ x_list.append(y_list[i])
449
+ y_list = self.stage3(x_list)
450
+
451
+ x_list = []
452
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
453
+ if self.transition3[i] is not None:
454
+ x_list.append(self.transition3[i](y_list[-1]))
455
+ else:
456
+ x_list.append(y_list[i])
457
+ y_list = self.stage4(x_list)
458
+
459
+ x = self.final_layer(y_list[0])
460
+
461
+ ret = {}
462
+ for head in self.heads.keys():
463
+ if head == 'hm':
464
+ ret[head] = x
465
+ else:
466
+ x1 = self.final_layer_cls(x)
467
+ ret[head] = x1
468
+ return [ret]
469
+
470
+ def init_weights(self, pretrained='', **kwargs):
471
+ for m in self.modules():
472
+ if isinstance(m, nn.Conv2d):
473
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
474
+ nn.init.normal_(m.weight, std=0.001)
475
+ for name, _ in m.named_parameters():
476
+ if name in ['bias']:
477
+ nn.init.constant_(m.bias, 0)
478
+ elif isinstance(m, nn.BatchNorm2d):
479
+ nn.init.constant_(m.weight, 1)
480
+ nn.init.constant_(m.bias, 0)
481
+ elif isinstance(m, nn.ConvTranspose2d):
482
+ nn.init.normal_(m.weight, std=0.001)
483
+ for name, _ in m.named_parameters():
484
+ if name in ['bias']:
485
+ nn.init.constant_(m.bias, 0)
486
+
487
+ if os.path.isfile(pretrained):
488
+ pretrained_state_dict = torch.load(pretrained)
489
+
490
+ need_init_state_dict = {}
491
+ for name, m in pretrained_state_dict.items():
492
+ if name.split('.')[0] in self.pretrained_layers \
493
+ or self.pretrained_layers[0] == '*':
494
+ need_init_state_dict[name] = m
495
+ self.load_state_dict(need_init_state_dict, strict=False)
496
+ elif pretrained:
497
+ raise ValueError('{} does not exist!'.format(pretrained))
498
+
499
+
500
+ def get_pose_net(cfg, is_train, **kwargs):
501
+ model = PoseHighResolutionNet(cfg, **kwargs)
502
+
503
+ if is_train and cfg.MODEL.INIT_WEIGHTS:
504
+ model.init_weights(cfg.MODEL.PRETRAINED)
505
+
506
+ return model
507
+
508
+
509
+ if __name__ == "__main__":
510
+ from configs.get_config import load_config
511
+ from builder import build_model
512
+ cfg = load_config("configs/hrnet_sbi.yaml")
513
+
514
+ hrnet = build_model(cfg.MODEL, MODELS, default_args=dict(cfg=cfg))
515
+ print(hrnet)
models/networks/xception.py ADDED
@@ -0,0 +1,338 @@
1
+ """
2
+ Creates an Xception Model as defined in:
3
+
4
+ Francois Chollet
5
+ Xception: Deep Learning with Depthwise Separable Convolutions
6
+ https://arxiv.org/pdf/1610.02357.pdf
7
+
8
+ These weights were ported from the Keras implementation. They achieve the following performance on the validation set:
9
+
10
+ Loss:0.9173 Prec@1:78.892 Prec@5:94.292
11
+
12
+ REMEMBER to set your image size to 3x299x299 for both test and validation
13
+
14
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
15
+ std=[0.5, 0.5, 0.5])
16
+
17
+ The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
18
+ """
19
+ import math
20
+
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.utils.model_zoo as model_zoo
24
+ from torch.nn import init
25
+ import torch
26
+
27
+ from ..builder import MODELS
28
+ from .common import conv_block, BN_MOMENTUM
29
+
30
+
31
+ model_urls = {
32
+ 'xception':'https://www.dropbox.com/s/1hplpzet9d7dv29/xception-c0a72b38.pth.tar?dl=1'
33
+ }
34
+
35
+
36
+ class SeparableConv2d(nn.Module):
37
+ def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
38
+ super(SeparableConv2d,self).__init__()
39
+
40
+ self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
41
+ self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
42
+
43
+ def forward(self,x):
44
+ x = self.conv1(x)
45
+ x = self.pointwise(x)
46
+ return x
47
+
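+ # Parameter sketch: the depthwise 3x3 (groups=in_channels) plus 1x1 pointwise pair
+ # costs in*9 + in*out weights versus in*out*9 for a dense 3x3 conv, roughly a 9x
+ # saving when out is large; this factorization is the core idea of Xception.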
48
+
49
+ class Block(nn.Module):
50
+ def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
51
+ super(Block, self).__init__()
52
+
53
+ if out_filters != in_filters or strides!=1:
54
+ self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
55
+ self.skipbn = nn.BatchNorm2d(out_filters)
56
+ else:
57
+ self.skip=None
58
+
59
+ self.relu = nn.ReLU(inplace=True)
60
+ rep=[]
61
+
62
+ filters=in_filters
63
+ if grow_first:
64
+ rep.append(self.relu)
65
+ rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
66
+ rep.append(nn.BatchNorm2d(out_filters))
67
+ filters = out_filters
68
+
69
+ for i in range(reps-1):
70
+ rep.append(self.relu)
71
+ rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
72
+ rep.append(nn.BatchNorm2d(filters))
73
+
74
+ if not grow_first:
75
+ rep.append(self.relu)
76
+ rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
77
+ rep.append(nn.BatchNorm2d(out_filters))
78
+
79
+ if not start_with_relu:
80
+ rep = rep[1:]
81
+ else:
82
+ rep[0] = nn.ReLU(inplace=False)
83
+
84
+ if strides != 1:
85
+ rep.append(nn.MaxPool2d(3,strides,1))
86
+ self.rep = nn.Sequential(*rep)
87
+
88
+ def forward(self,inp):
89
+ x = self.rep(inp)
90
+
91
+ if self.skip is not None:
92
+ skip = self.skip(inp)
93
+ skip = self.skipbn(skip)
94
+ else:
95
+ skip = inp
96
+
97
+ x+=skip
98
+ return x
99
+
100
+
101
+ @MODELS.register_module()
102
+ class Xception(nn.Module):
103
+ """
104
+ Xception optimized for the ImageNet dataset, as specified in
105
+ https://arxiv.org/pdf/1610.02357.pdf
106
+ """
107
+ def __init__(self,
108
+ heads,
109
+ head_conv=64,
110
+ cls_based_hm=True,
111
+ dropout_prob=0.5,
112
+ **kwargs):
113
+ """ Constructor
114
+ Args:
115
+ num_classes: number of classes
116
+ """
117
+ self.heads = heads
118
+ self.head_conv = head_conv
119
+ self.cls_based_hm = cls_based_hm
120
+ self.dropout_prob = dropout_prob
121
+ super(Xception, self).__init__()
122
+
123
+ self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)
124
+ self.bn1 = nn.BatchNorm2d(32)
125
+ self.relu = nn.ReLU(inplace=True)
126
+
127
+ self.conv2 = nn.Conv2d(32,64,3,bias=False)
128
+ self.bn2 = nn.BatchNorm2d(64)
129
+ #do relu here
130
+
131
+ self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
132
+ self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
133
+ self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)
134
+
135
+ self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
136
+ self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
137
+ self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
138
+ self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)
139
+
140
+ self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
141
+ self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
142
+ self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
143
+ self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)
144
+
145
+ self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
146
+
147
+ self.conv3 = SeparableConv2d(1024,1536,3,1,1)
148
+ self.bn3 = nn.BatchNorm2d(1536)
149
+
150
+ #do relu here
151
+ self.conv4 = SeparableConv2d(1536,2048,3,1,1)
152
+ self.bn4 = nn.BatchNorm2d(2048)
153
+
154
+ self.dropout = nn.Dropout2d(p=self.dropout_prob)
155
+
156
+ self.conv_block_1 = conv_block(2048, 256, (3,3), padding=1)
157
+ self.deconv_1 = nn.Sequential(
158
+ nn.ConvTranspose2d(
159
+ in_channels=256,
160
+ out_channels=256,
161
+ kernel_size=(4,4),
162
+ stride=2,
163
+ padding=1,
164
+ output_padding=0,
165
+ bias=False),
166
+ nn.BatchNorm2d(256, momentum=BN_MOMENTUM),
167
+ nn.ReLU(inplace=True)
168
+ )
169
+
170
+ self.conv_block_2 = conv_block(256, 256, (3,3), padding=1)
171
+ self.deconv_2 = nn.Sequential(
172
+ nn.ConvTranspose2d(
173
+ in_channels=256,
174
+ out_channels=128,
175
+ kernel_size=(4,4),
176
+ stride=2,
177
+ padding=1,
178
+ output_padding=0,
179
+ bias=False),
180
+ nn.BatchNorm2d(128, momentum=BN_MOMENTUM),
181
+ nn.ReLU(inplace=True)
182
+ )
183
+
184
+ self.conv_block_3 = conv_block(128, 128, (3,3), padding=1)
185
+ self.deconv_3 = nn.Sequential(
186
+ nn.ConvTranspose2d(
187
+ in_channels=128,
188
+ out_channels=64,
189
+ kernel_size=(4,4),
190
+ stride=2,
191
+ padding=1,
192
+ output_padding=0,
193
+ bias=False),
194
+ nn.BatchNorm2d(64, momentum=BN_MOMENTUM),
195
+ nn.ReLU(inplace=True)
196
+ )
197
+
198
+ for head in sorted(self.heads):
199
+ num_output = self.heads[head]
200
+ if self.head_conv > 0:
201
+ if head != 'cls':
202
+ fc = nn.Sequential(
203
+ nn.Conv2d(64, self.head_conv,
204
+ kernel_size=3, padding=1, bias=False),
205
+ nn.BatchNorm2d(self.head_conv),
206
+ nn.ReLU(inplace=True),
207
+ nn.Conv2d(self.head_conv, num_output,
208
+ kernel_size=1, stride=1, padding=0)
209
+ )
210
+ else:
211
+ if self.cls_based_hm:
212
+ fc = nn.Sequential(
213
+ nn.AdaptiveAvgPool2d(head_conv//4),
214
+ nn.Flatten(),
215
+ nn.Linear((head_conv//4)**2, head_conv, bias=False),
216
+ nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
217
+ nn.ReLU(inplace=True),
218
+ nn.Linear(head_conv, num_output, bias=True),
219
+ nn.Sigmoid()
220
+ )
221
+ else:
222
+ fc = nn.Sequential(
223
+ nn.Conv2d(64, head_conv, kernel_size=3,
224
+ padding=1, bias=False),
225
+ nn.BatchNorm2d(head_conv, momentum=BN_MOMENTUM),
226
+ nn.ReLU(inplace=True),
227
+ nn.Conv2d(head_conv, num_output, kernel_size=1,
228
+ stride=1, padding=0, bias=False),
229
+ nn.BatchNorm2d(num_output),
230
+ # nn.ReLU(inplace=True),
231
+ nn.AdaptiveAvgPool2d(head_conv//4),
232
+ nn.Flatten(),
233
+ nn.Linear((head_conv//4)**2, head_conv, bias=False),
234
+ nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
235
+ nn.ReLU(inplace=True),
236
+ nn.Linear(head_conv, num_output, bias=True),
237
+ nn.Sigmoid()
238
+ )
239
+ else:
240
+ fc = nn.Conv2d(
241
+ in_channels=64,
242
+ out_channels=num_output,
243
+ kernel_size=1,
244
+ stride=1,
245
+ padding=0
246
+ )
247
+ self.__setattr__(head, fc)
248
+
249
+ def forward(self, x):
250
+ x = self.conv1(x)
251
+ x = self.bn1(x)
252
+ x = self.relu(x)
253
+
254
+ x = self.conv2(x)
255
+ x = self.bn2(x)
256
+ x = self.relu(x)
257
+
258
+ x = self.block1(x)
259
+ x = self.block2(x)
260
+ x = self.block3(x)
261
+ x = self.block4(x)
262
+ x = self.block5(x)
263
+ x = self.block6(x)
264
+ x = self.block7(x)
265
+ x = self.block8(x)
266
+ x = self.block9(x)
267
+ x = self.block10(x)
268
+ x = self.block11(x)
269
+ x = self.block12(x)
270
+
271
+ x = self.conv3(x)
272
+ x = self.bn3(x)
273
+ x = self.relu(x)
274
+
275
+ x = self.conv4(x)
276
+ x = self.bn4(x)
277
+ x = self.relu(x)
278
+
279
+ x = self.dropout(x)
280
+
281
+ x = self.conv_block_1(x)
282
+ x = self.deconv_1(x)
283
+
284
+ x = self.conv_block_2(x)
285
+ x = self.deconv_2(x)
286
+
287
+ x = self.conv_block_3(x)
288
+ x = self.deconv_3(x)
289
+
290
+ ret = {}
291
+ x1_hm = None
292
+ for head in self.heads:
293
+ if not self.cls_based_hm or head != 'cls':
294
+ ret[head] = self.__getattr__(head)(x)
295
+ if head == 'hm':
296
+ x1_hm = ret[head]
297
+ else:
298
+ assert 'hm' in ret.keys(), "The cls head consumes heatmap features, so 'hm' must be computed first!"
299
+ ret[head] = self.__getattr__(head)(x1_hm)
300
+ return [ret]
301
+
302
+ def init_weights(self, pretrained=False):
303
+ if not pretrained:
304
+ for m in self.modules():
305
+ if isinstance(m, nn.Conv2d):
306
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
307
+ m.weight.data.normal_(0, math.sqrt(2. / n))
308
+ elif isinstance(m, nn.BatchNorm2d):
309
+ m.weight.data.fill_(1)
310
+ m.bias.data.zero_()
311
+ elif isinstance(m, nn.ConvTranspose2d):
312
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
313
+ m.weight.data.normal_(0, math.sqrt(2. / n))
314
+ if m.bias is not None: # Xception never defines deconv_with_bias; guard on the bias itself
315
+ nn.init.constant_(m.bias, 0)
316
+ else:
317
+ self.load_state_dict(model_zoo.load_url(model_urls['xception']), strict=False)
318
+
319
+ # Init head parameters
320
+ for head in self.heads:
321
+ final_layer = self.__getattr__(head)
322
+ for i, m in enumerate(final_layer.modules()):
323
+ prior = 1/71
324
+ # if isinstance(m, nn.Conv2d):
325
+ # if m.weight.shape[0] == self.heads[head]:
326
+ # if 'hm' in head:
327
+ # # nn.init.constant_(m.bias, -2.19)
328
+ # nn.init.constant_(m.bias, -math.log((1-prior)/prior))
329
+ # else:
330
+ # nn.init.normal_(m.weight, std=0.001)
331
+ # # nn.init.constant_(m.bias, 0)
332
+ if isinstance(m, nn.Linear):
333
+ if m.weight.shape[0] == self.heads[head]:
334
+ nn.init.constant_(m.bias, -math.log((1-prior)/prior))
335
+ # else:
336
+ # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
337
+ # m.weight.data.normal_(0, math.sqrt(2. / n))
338
+ # # nn.init.constant_(m.bias, 0)
models/utils.py ADDED
@@ -0,0 +1,138 @@
1
+ #-*- coding: utf-8 -*-
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import os
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
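+ # Index of the last backbone parameter (in named_parameters() order) to freeze
+ # during warm-up; freeze_backbone falls back to this table when the model has
+ # no .backbone attribute.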
12
+ layers_position = {
13
+ 'PoseResNet_50': 158,
14
+ 'PoseResNet_101': 311,
15
+ 'PoseEfficientNet_B4': 415,
16
+ }
17
+
18
+
19
+ def preset_model(cfg, model, optimizer=None):
20
+ # Load the model from config; make sure the pretrained path matches the model name
21
+ start_epoch = 0
22
+ if 'pretrained' in cfg.TRAIN and os.path.isfile(cfg.TRAIN.pretrained):
23
+ model, optimizer, start_epoch = load_model(model,
24
+ cfg.TRAIN.pretrained,
25
+ optimizer=optimizer,
26
+ resume=cfg.TRAIN.resume,
27
+ lr=cfg.TRAIN.lr,
28
+ lr_step=cfg.TRAIN.lr_scheduler.milestones,
29
+ gamma=cfg.TRAIN.lr_scheduler.gamma)
30
+ else:
31
+ model.init_weights(**cfg.MODEL.INIT_WEIGHTS)
32
+ print('Loading model successfully -- {}'.format(cfg.MODEL.type))
33
+
34
+ # Freeze the backbone if start_epoch < warm-up epochs
35
+ if cfg.TRAIN.freeze_backbone and start_epoch < cfg.TRAIN.warm_up:
36
+ freeze_backbone(cfg.MODEL, model)
37
+
38
+ print('Number of parameters', sum(p.numel() for p in model.parameters()))
39
+ print('Number of trainable parameters', sum(p.numel() for p in model.parameters() if p.requires_grad))
40
+ return model, optimizer, start_epoch
41
+
42
+
43
+ def load_pretrained(model, weight_path):
44
+ '''
45
+ This function only loads the model's state dict.
46
+ For optimizer state and resume logic, please refer to @load_model.
47
+ '''
48
+ state_dict = torch.load(weight_path)['state_dict']
49
+ model.load_state_dict(state_dict, strict=True)
50
+ return model
51
+
52
+
53
+ def freeze_backbone(cfg, model):
54
+ '''
55
+ Freeze specific backbone layers to warm up the model.
56
+ '''
57
+ if hasattr(model, 'backbone'):
58
+ backbone = model.backbone
59
+ for param in backbone.parameters():
60
+ param.requires_grad = False
61
+ else:
62
+ for i, (n, p) in enumerate(model.named_parameters()):
63
+ if (i <= layers_position[f'{cfg.type}_{cfg.num_layers}']):
64
+ p.requires_grad = False
65
+
66
+
67
+ def unfreeze_backbone(model):
68
+ '''
69
+ Unfreeze all model layers.
70
+ '''
71
+ for param in model.parameters():
72
+ if not param.requires_grad:
73
+ param.requires_grad = True
74
+
75
+
76
+ def load_model(model, model_path, optimizer=None, resume=False,
77
+ lr=None, lr_step=None, gamma=None):
78
+ start_epoch = 0
79
+ checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
80
+ print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
81
+ state_dict_ = checkpoint['state_dict']
82
+ state_dict = {}
83
+
84
+ # convert DataParallel ('module.'-prefixed) keys to plain model keys
85
+ for k in state_dict_:
86
+ if k.startswith('module') and not k.startswith('module_list'):
87
+ state_dict[k[7:]] = state_dict_[k]
88
+ else:
89
+ state_dict[k] = state_dict_[k]
90
+ model_state_dict = model.state_dict()
91
+
92
+ # check loaded parameters and created model parameters
93
+ msg = 'If you see this, your model does not fully load the ' + \
94
+ 'pre-trained weight. Please make sure ' + \
95
+ 'you have correctly specified --arch xxx ' + \
96
+ 'or set the correct --num_classes for your own dataset.'
97
+ for k in state_dict:
98
+ if k in model_state_dict:
99
+ if state_dict[k].shape != model_state_dict[k].shape:
100
+ print('Skip loading parameter {}, required shape {}, '\
101
+ 'loaded shape {}. {}'.format(
102
+ k, model_state_dict[k].shape, state_dict[k].shape, msg))
103
+ state_dict[k] = model_state_dict[k]
104
+ else:
105
+ print('Drop parameter {}.'.format(k) + msg)
106
+ for k in model_state_dict:
107
+ if not (k in state_dict):
108
+ print('No param {}.'.format(k) + msg)
109
+ state_dict[k] = model_state_dict[k]
110
+ model.load_state_dict(state_dict, strict=False)
111
+
112
+ # resume optimizer parameters
113
+ if optimizer is not None and resume:
114
+ if 'optimizer' in checkpoint:
115
+ optimizer.load_state_dict(checkpoint['optimizer'])
116
+ start_epoch = checkpoint['epoch'] + 1
117
+ start_lr = lr
118
+ for step in lr_step:
119
+ if start_epoch >= step:
120
+ start_lr *= gamma
121
+ for param_group in optimizer.param_groups:
122
+ param_group['lr'] = start_lr
123
+ print('Resumed optimizer with start lr', start_lr)
124
+ else:
125
+ print('No optimizer parameters in checkpoint.')
126
+ return model, optimizer, start_epoch
127
+
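+ # A minimal usage sketch (paths and hyperparameters are illustrative):
+ # model, optimizer, start_epoch = load_model(model, 'ckpt.pth', optimizer=opt,
+ # resume=True, lr=1e-3, lr_step=[90, 120], gamma=0.1)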
128
+
129
+ def save_model(path, epoch, model, optimizer=None):
130
+ if isinstance(model, torch.nn.DataParallel):
131
+ state_dict = model.module.state_dict()
132
+ else:
133
+ state_dict = model.state_dict()
134
+ data = {'epoch': epoch,
135
+ 'state_dict': state_dict}
136
+ if not (optimizer is None):
137
+ data['optimizer'] = optimizer.state_dict()
138
+ torch.save(data, path)
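+ # e.g. save_model('epoch_10.pth', 10, model, optimizer) stores the epoch, the
+ # state_dict (unwrapped from DataParallel if needed) and the optimizer state.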
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ opencv-python>=4.8.0
4
+ numpy>=1.24.0
5
+ Pillow>=10.0.0
6
+ gradio>=3.50.0
7
+ detectron2>=0.6.0; platform_system!="Darwin" # Detectron2 not available for macOS
8
+ fvcore>=0.1.5.post20221221; platform_system!="Darwin" # Required for detectron2
9
+ iopath>=0.1.9; platform_system!="Darwin" # Required for detectron2
10
+ pycocotools>=2.0.6; platform_system!="Darwin" # Required for detectron2