Spaces:
Runtime error
Runtime error
astrosbd
committed on
Commit
·
5c783e4
0
Parent(s):
Initial commit
Browse files- .gitattributes +36 -0
- .gitignore +51 -0
- README.md +60 -0
- app.py +551 -0
- configs/.ipynb_checkpoints/__init__-checkpoint.py +7 -0
- configs/__init__.py +7 -0
- configs/get_config.py +16 -0
- configs/test_config.yaml +38 -0
- configs/train_config copie.yaml +43 -0
- configs/train_config.yaml +46 -0
- loss/__init__.py +11 -0
- loss/abstract_loss_func.py +17 -0
- loss/cross_entropy_loss.py +26 -0
- metrics/__init__.py +7 -0
- metrics/base_metrics_class.py +204 -0
- metrics/registry.py +19 -0
- metrics/utils.py +92 -0
- models/__init__.py +29 -0
- models/builder.py +45 -0
- models/networks/arcface.py +384 -0
- models/networks/common.py +75 -0
- models/networks/efficientNet.py +490 -0
- models/networks/mrsa_resnet.py +464 -0
- models/networks/pose_efficientNet.py +788 -0
- models/networks/pose_hrnet.py +515 -0
- models/networks/xception.py +338 -0
- models/utils.py +138 -0
- requirements.txt +10 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual Environment
|
| 24 |
+
venv/
|
| 25 |
+
env/
|
| 26 |
+
ENV/
|
| 27 |
+
|
| 28 |
+
# IDE
|
| 29 |
+
.idea/
|
| 30 |
+
.vscode/
|
| 31 |
+
*.swp
|
| 32 |
+
*.swo
|
| 33 |
+
|
| 34 |
+
# OS
|
| 35 |
+
.DS_Store
|
| 36 |
+
Thumbs.db
|
| 37 |
+
|
| 38 |
+
# Logs
|
| 39 |
+
*.log
|
| 40 |
+
|
| 41 |
+
# Model files
|
| 42 |
+
*.pth
|
| 43 |
+
*.pt
|
| 44 |
+
*.ckpt
|
| 45 |
+
*.bin
|
| 46 |
+
|
| 47 |
+
# Config files
|
| 48 |
+
*.yaml
|
| 49 |
+
*.yml
|
| 50 |
+
!configs/*.yaml
|
| 51 |
+
!configs/*.yml
|
README.md
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Car Damage Insurance Fraud Detector
|
| 3 |
+
emoji: 🚗
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.50.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Car Damage Insurance Fraud Detector
|
| 14 |
+
|
| 15 |
+
A sophisticated AI-powered system that detects car damage and potential insurance fraud using deep learning models.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- Damage Detection: Identifies and localizes car damage using Detectron2
|
| 20 |
+
- Deepfake Detection: Analyzes images for potential manipulation
|
| 21 |
+
- User-friendly Interface: Built with Gradio for easy interaction
|
| 22 |
+
- Multi-device Support: Works on CPU, CUDA, and MPS (Apple Silicon)
|
| 23 |
+
|
| 24 |
+
## Requirements
|
| 25 |
+
|
| 26 |
+
- Python 3.8+
|
| 27 |
+
- PyTorch
|
| 28 |
+
- OpenCV
|
| 29 |
+
- Gradio
|
| 30 |
+
- Detectron2 (optional, not available for macOS)
|
| 31 |
+
|
| 32 |
+
## Installation
|
| 33 |
+
|
| 34 |
+
1. Clone the repository
|
| 35 |
+
2. Install dependencies:
|
| 36 |
+
```bash
|
| 37 |
+
pip install -r requirements.txt
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Usage
|
| 41 |
+
|
| 42 |
+
1. Run the application:
|
| 43 |
+
```bash
|
| 44 |
+
python app.py
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
2. Open your browser and navigate to the provided local URL
|
| 48 |
+
|
| 49 |
+
## Model Requirements
|
| 50 |
+
|
| 51 |
+
- Damage detection model (Detectron2 format)
|
| 52 |
+
- Deepfake detection model (custom format)
|
| 53 |
+
|
| 54 |
+
## License
|
| 55 |
+
|
| 56 |
+
Apache 2.0
|
| 57 |
+
|
| 58 |
+
## Note
|
| 59 |
+
|
| 60 |
+
This application requires pre-trained models for both damage detection and deepfake detection. Make sure to have the appropriate model files in the correct locations before running the application.
|
app.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Gradio app combining car-damage detection (Detectron2) with deepfake verification."""
import os
import sys
import time
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms

# Add current directory to path
# so the project-local `configs` and `models` packages resolve regardless of
# the directory the app is launched from.
if not os.getcwd() in sys.path:
    sys.path.append(os.getcwd())

# Detectron2 imports - wrapped in try-except to make them optional
# (Detectron2 ships no macOS wheels; without it the app degrades to
# deepfake-detection-only mode — see setup_damage_detector).
try:
    from detectron2.engine import DefaultPredictor
    from detectron2.config import get_cfg
    from detectron2.utils.visualizer import Visualizer, ColorMode
    from detectron2 import model_zoo
    DETECTRON2_AVAILABLE = True  # consulted before every Detectron2 call
except ImportError:
    print("Warning: Detectron2 is not installed. Damage detection will not be available.")
    DETECTRON2_AVAILABLE = False

# Check for custom path for models
# NOTE(review): `build_model` and `MODELS` used later presumably come from the
# star import of `models` — verify against that package's __init__.
try:
    from configs.get_config import load_config
    from models import *
    MODELS_IMPORTED = True  # consulted before loading the deepfake model
except ImportError:
    print("Warning: Custom models couldn't be imported. Only damage detection will work.")
    MODELS_IMPORTED = False
|
| 36 |
+
|
| 37 |
+
def setup_device(device_str):
    """Return a ``torch.device`` for *device_str* ('auto', 'cuda', 'mps', 'cpu').

    'auto' picks the best available backend (CUDA > MPS > CPU). A request for
    an unavailable accelerator falls back to CPU with a warning. Requesting
    'cpu' explicitly is always honoured silently.
    """
    def _mps_available():
        # Older torch builds lack torch.backends.mps entirely, so guard with
        # hasattr before calling is_available().
        return (hasattr(torch, 'backends')
                and hasattr(torch.backends, 'mps')
                and torch.backends.mps.is_available())

    if device_str == 'auto':
        if torch.cuda.is_available():
            return torch.device('cuda:0')
        if _mps_available():
            return torch.device('mps')
        return torch.device('cpu')
    if device_str == 'cuda' and torch.cuda.is_available():
        return torch.device('cuda:0')
    if device_str == 'mps' and _mps_available():
        return torch.device('mps')
    if device_str != 'cpu':
        # Bug fix: the original also emitted this warning when the caller
        # explicitly asked for CPU, which is not a fallback.
        print(f"Warning: Device {device_str} not available, using CPU instead.")
    return torch.device('cpu')
|
| 53 |
+
|
| 54 |
+
def setup_damage_detector(model_path, threshold=0.7):
    """Build a Detectron2 predictor for car-damage instance segmentation.

    Args:
        model_path: Path to the trained Mask R-CNN weights (.pth).
        threshold: Minimum score for a detection to be reported.

    Returns:
        (predictor, cfg); (None, None) when Detectron2 or the weights are
        missing, and (None, cfg) when predictor construction fails.
    """
    if not DETECTRON2_AVAILABLE:
        print("Detectron2 is not installed. Cannot set up damage detector.")
        return None, None

    if model_path is None or not os.path.exists(model_path):
        print("No damage model specified or file not found. Skipping damage detection.")
        return None, None

    # Start from the stock COCO Mask R-CNN config, then point it at our weights.
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # Only one class (damage)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold

    # Explicitly set to use CPU if on Mac (MPS) — Detectron2 has no MPS backend.
    # Bug fix: guard the mps attribute the same way setup_device does, so torch
    # builds without torch.backends.mps don't raise AttributeError here.
    if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        cfg.MODEL.DEVICE = "cpu"
        print("Mac MPS detected - forcing Detectron2 to use CPU")

    try:
        predictor = DefaultPredictor(cfg)
        return predictor, cfg
    except Exception as e:
        print(f"Error setting up damage detector: {e}")
        return None, cfg
|
| 81 |
+
|
| 82 |
+
def load_deepfake_model(model_path, cfg_path, device):
    """Build the deepfake classifier from a YAML config and load its weights.

    Returns a ``(model, cfg)`` pair with the model in eval mode on *device*,
    or ``(None, None)`` when prerequisites are missing or loading fails.
    """
    # Guard clauses: bail out early before doing any expensive work.
    if not MODELS_IMPORTED:
        print("Custom models module not imported. Cannot load deepfake model.")
        return None, None

    if model_path is None or not os.path.exists(model_path):
        print("No deepfake model specified or file not found. Skipping deepfake detection.")
        return None, None

    if cfg_path is None or not os.path.exists(cfg_path):
        print("No deepfake config specified or file not found. Skipping deepfake detection.")
        return None, None

    try:
        config = load_config(cfg_path)
        net = build_model(config.MODEL, MODELS)

        print(f"Loading deepfake model from: {model_path}")
        ckpt = torch.load(model_path, map_location='cpu')

        # Checkpoints either wrap the weights under 'state_dict' or are a raw
        # state dict themselves.
        state = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
        net.load_state_dict(state)

        # Move to the target device, honour fp64 precision if configured, and
        # switch to inference mode.
        net = net.to(device)
        if hasattr(config.MODEL, 'precision') and config.MODEL.precision == 'fp64':
            net = net.to(torch.float64)
        net.eval()

        return net, config
    except Exception as exc:
        print(f"Error loading deepfake model: {exc}")
        import traceback
        traceback.print_exc()
        return None, None
|
| 124 |
+
|
| 125 |
+
def preprocess_for_deepfake(image, cfg, device):
    """Turn a BGR image array into a normalized batch tensor for the classifier.

    Returns a ``(1, C, H, W)`` tensor on *device*, or ``None`` on failure.
    """
    try:
        # 3-channel input is treated as BGR and converted to RGB; anything
        # else is passed through untouched.
        if len(image.shape) == 3 and image.shape[2] == 3:
            if image.dtype != np.uint8:
                # assumes float images are scaled to [0, 1] — TODO confirm
                image = (image * 255).astype(np.uint8)
            rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            rgb = image

        # Resize to the model's expected input size from the config.
        target = (cfg.DATASET.IMAGE_SIZE[0], cfg.DATASET.IMAGE_SIZE[1])
        resized = cv2.resize(rgb, target)

        # ToTensor followed by Normalize, applied explicitly.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=cfg.DATASET.TRANSFORM.normalize.mean,
            std=cfg.DATASET.TRANSFORM.normalize.std,
        )
        tensor = normalize(to_tensor(Image.fromarray(resized)))
        tensor = tensor.unsqueeze(0).to(device)  # add batch dimension

        # Match the model's precision when it runs in double precision.
        if hasattr(cfg.MODEL, 'precision') and cfg.MODEL.precision == 'fp64':
            tensor = tensor.to(torch.float64)

        return tensor
    except Exception as exc:
        print(f"Error preprocessing image for deepfake detection: {exc}")
        import traceback
        traceback.print_exc()
        return None
|
| 161 |
+
|
| 162 |
+
def detect_damage(img, damage_detector):
    """Run the damage detector on *img* and return ``(image, outputs, regions)``.

    Each region is a dict with ``box`` (x1, y1, x2, y2), ``score`` and
    ``mask``. When no detector is available, or nothing is detected, the whole
    image is returned as a single region so downstream steps still have
    something to analyse. On unrecoverable errors returns ``(None, None, [])``.
    """
    def _full_frame(frame):
        # Fallback: one region covering the entire image.
        height, width = frame.shape[:2]
        return [{"box": (0, 0, width, height), "score": 1.0, "mask": None}]

    try:
        if img is None:
            raise ValueError("Invalid image")

        if damage_detector is None:
            print("No damage detector available. Using whole image as region.")
            return img, None, _full_frame(img)

        outputs = damage_detector(img)

        # Pull predictions back to the CPU and unpack them defensively.
        instances = outputs["instances"].to("cpu")
        boxes = instances.pred_boxes.tensor.numpy() if instances.has("pred_boxes") else []
        scores = instances.scores.numpy() if instances.has("scores") else []
        masks = instances.pred_masks.numpy() if instances.has("pred_masks") else []

        regions = []
        for idx, box in enumerate(boxes):
            left, top, right, bottom = map(int, box)
            regions.append({
                "box": (left, top, right, bottom),
                "score": float(scores[idx]),
                "mask": masks[idx] if len(masks) > idx else None,
            })

        if not regions:
            print("No damage detected. Using whole image.")
            regions = _full_frame(img)

        return img, outputs, regions
    except Exception as exc:
        print(f"Error detecting damage: {exc}")
        # Best effort: fall back to the whole image when we still have one.
        if img is not None:
            return img, None, _full_frame(img)
        return None, None, []
|
| 219 |
+
|
| 220 |
+
def _score_deepfake(deepfake_model, img_tensor, threshold):
    """Run the deepfake classifier on one preprocessed tensor.

    Returns ``(is_fake, confidence)`` or ``None`` when the model produced an
    empty output. Extracted because the full-image and per-region paths of
    check_deepfake previously duplicated this logic verbatim.
    """
    with torch.no_grad():
        outputs = deepfake_model(img_tensor)

    # Some models return a list of heads; use the first.
    if isinstance(outputs, list):
        outputs = outputs[0]

    # Either a dict with a 'cls' head, or the raw classification output.
    if isinstance(outputs, dict) and 'cls' in outputs:
        cls_prob = outputs['cls'].sigmoid().cpu().numpy()
    else:
        cls_prob = outputs.sigmoid().cpu().numpy() if hasattr(outputs, 'sigmoid') else outputs.cpu().numpy()

    if cls_prob.size == 0:
        return None
    confidence = cls_prob[0][0] if cls_prob.ndim > 1 else cls_prob[0]
    return bool(confidence > threshold), float(confidence)


def check_deepfake(image, damage_regions, deepfake_model, deepfake_cfg, device, threshold=0.5):
    """Check whether each damage region (or the whole image) is a deepfake.

    Args:
        image: BGR image as a numpy array.
        damage_regions: Region dicts from detect_damage; empty means analyse
            the whole image.
        deepfake_model / deepfake_cfg: Model and config from load_deepfake_model.
        device: torch.device for inference.
        threshold: Probability above which a region is flagged as fake.

    Returns:
        A list of result dicts ("region"/"region_id", "deepfake_prob",
        "is_fake"); empty when no model is available or an error occurs.
    """
    results = []

    if deepfake_model is None:
        print("No deepfake model available. Skipping deepfake detection.")
        return []

    try:
        # No regions: classify the entire image as a single sample.
        if not damage_regions:
            img_tensor = preprocess_for_deepfake(image, deepfake_cfg, device)
            if img_tensor is None:
                return []
            scored = _score_deepfake(deepfake_model, img_tensor, threshold)
            if scored is not None:
                is_fake, confidence = scored
                results.append({
                    "region": "full_image",
                    "deepfake_prob": confidence,
                    "is_fake": is_fake,
                })
            return results

        # Otherwise classify each damage region's crop independently.
        for i, region in enumerate(damage_regions):
            x1, y1, x2, y2 = region["box"]
            # Clamp coordinates to the image bounds.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
            if x2 <= x1 or y2 <= y1:
                continue  # degenerate box after clamping

            roi = image[y1:y2, x1:x2]
            img_tensor = preprocess_for_deepfake(roi, deepfake_cfg, device)
            if img_tensor is None:
                continue

            scored = _score_deepfake(deepfake_model, img_tensor, threshold)
            if scored is None:
                continue
            is_fake, confidence = scored
            results.append({
                "region_id": i,
                "box": (x1, y1, x2, y2),
                "deepfake_prob": confidence,
                "is_fake": is_fake,
            })

        return results
    except Exception as e:
        print(f"Error in deepfake detection: {e}")
        import traceback
        traceback.print_exc()
        return []
|
| 311 |
+
|
| 312 |
+
def visualize_results(image, damage_outputs, deepfake_results, damage_threshold):
    """Overlay damage-detection and deepfake-verification results on *image*.

    Returns an annotated copy of the image as a uint8 numpy array; the input
    is never modified. On any unrecoverable error the original image is
    returned unannotated.
    """
    try:
        canvas = image.copy()

        # Overlay Detectron2 instance predictions when we have them.
        if damage_outputs is not None and DETECTRON2_AVAILABLE:
            try:
                viz = Visualizer(canvas[:, :, ::-1], scale=1.0, instance_mode=ColorMode.IMAGE_BW)
                viz = viz.draw_instance_predictions(damage_outputs["instances"].to("cpu"))
                # Back to BGR and to a plain numpy array for OpenCV drawing.
                canvas = np.array(viz.get_image()[:, :, ::-1], dtype=np.uint8)
            except Exception as err:
                print(f"Error visualizing damage detection: {err}")

        # Annotate each deepfake verdict: per-region boxes or a whole-image label.
        for result in deepfake_results:
            try:
                if "box" in result:
                    x1, y1, x2, y2 = result["box"]
                    fake_prob = result["deepfake_prob"]
                    is_fake = result["is_fake"]
                    region_id = result.get("region_id", 0)

                    label = f"R{region_id}: {'FAKE' if is_fake else 'REAL'} ({fake_prob*100:.1f}%)"
                    color = (0, 0, 255) if is_fake else (0, 255, 0)  # BGR: red fake, green real

                    if not isinstance(canvas, np.ndarray):
                        canvas = np.array(canvas, dtype=np.uint8)

                    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(canvas, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
                elif "region" in result and result["region"] == "full_image":
                    fake_prob = result["deepfake_prob"]
                    is_fake = result["is_fake"]

                    label = f"Image: {'FAKE' if is_fake else 'REAL'} ({fake_prob*100:.1f}%)"
                    color = (0, 0, 255) if is_fake else (0, 255, 0)  # BGR: red fake, green real

                    if not isinstance(canvas, np.ndarray):
                        canvas = np.array(canvas, dtype=np.uint8)

                    cv2.putText(canvas, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
            except Exception as err:
                print(f"Error drawing result {result}: {err}")

        return canvas
    except Exception as err:
        print(f"Error visualizing results: {err}")
        import traceback
        traceback.print_exc()
        return np.array(image, dtype=np.uint8)  # original image, unannotated
|
| 380 |
+
|
| 381 |
+
def process_image(input_image, damage_model_path, deepfake_model_path, deepfake_cfg_path,
                  damage_threshold, deepfake_threshold, skip_damage, device_str):
    """Process an image through the car damage and deepfake detection pipeline.

    Gradio callback. Returns ``(result_image_rgb, progress_text)`` on success
    or ``(None, error_message)`` on failure. Models are (re)built on every
    call from the supplied paths.
    """
    # Human-readable progress log returned to the UI.
    progress_info = []

    # Convert Gradio image to numpy array.
    # Gradio may hand us a dict with a file path, a plain path string, or a
    # numpy array (RGB) depending on component configuration.
    if isinstance(input_image, dict) and "path" in input_image:
        img = cv2.imread(input_image["path"])
    elif isinstance(input_image, str):
        img = cv2.imread(input_image)
    elif isinstance(input_image, np.ndarray):
        # Make a copy to avoid modifying the original.
        img = input_image.copy()
        # Convert from RGB to BGR (OpenCV format).
        if len(img.shape) == 3 and img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    else:
        return None, "Error: Unsupported image format"

    # cv2.imread returns None on unreadable files.
    if img is None:
        return None, "Error: Could not read the image"

    # Progress update
    progress_info.append("Image loaded successfully")

    # Setup device
    device = setup_device(device_str)
    progress_info.append(f"Using device: {device}")

    # Initialize models
    damage_detector = None
    deepfake_model = None
    deepfake_cfg = None

    # Setup damage detector if not skipped
    if not skip_damage and damage_model_path:
        progress_info.append("Setting up damage detector...")
        # NOTE(review): detector_cfg is currently unused after this call.
        damage_detector, detector_cfg = setup_damage_detector(damage_model_path, float(damage_threshold))
        if damage_detector is None and DETECTRON2_AVAILABLE:
            progress_info.append("Failed to initialize damage detector")
        else:
            progress_info.append("Damage detector initialized successfully")

    # Setup deepfake detector
    if deepfake_model_path and deepfake_cfg_path:
        progress_info.append("Setting up deepfake detector...")
        deepfake_model, deepfake_cfg = load_deepfake_model(deepfake_model_path, deepfake_cfg_path, device)
        if deepfake_model is None:
            progress_info.append("Failed to initialize deepfake detector")
        else:
            progress_info.append("Deepfake detector initialized successfully")

    # Ensure at least one detector is working
    if damage_detector is None and deepfake_model is None:
        return None, "Error: Neither damage nor deepfake detector is available"

    # Step 1: Detect damage or use whole image
    progress_info.append("Detecting damage regions...")
    start_time = time.time()
    img, damage_outputs, damage_regions = detect_damage(img, damage_detector)
    damage_time = time.time() - start_time

    # detect_damage returns (None, None, []) on unrecoverable errors.
    if img is None:
        return None, "Error: Failed to process image"

    # Print damage detection results
    if damage_detector is not None and damage_regions:
        progress_info.append(f"Detected {len(damage_regions)} damage regions in {damage_time:.3f} seconds")
    else:
        progress_info.append("Using the whole image for analysis")

    # Step 2: Check if damage is deepfake
    deepfake_results = []
    if deepfake_model is not None:
        progress_info.append("Performing deepfake detection...")
        start_time = time.time()
        deepfake_results = check_deepfake(
            img, damage_regions, deepfake_model, deepfake_cfg, device, float(deepfake_threshold)
        )
        deepfake_time = time.time() - start_time

        if deepfake_results:
            progress_info.append(f"Deepfake detection completed in {deepfake_time:.3f} seconds")

            # Generate report: one line per region (or one for the whole image).
            for result in deepfake_results:
                if "region_id" in result:
                    region_id = result["region_id"]
                    fake_prob = result["deepfake_prob"]
                    is_fake = result["is_fake"]
                    progress_info.append(f"Region {region_id}: {'FAKE' if is_fake else 'REAL'} (Probability: {fake_prob*100:.2f}%)")
                elif "region" in result and result["region"] == "full_image":
                    fake_prob = result["deepfake_prob"]
                    is_fake = result["is_fake"]
                    progress_info.append(f"Whole image: {'FAKE' if is_fake else 'REAL'} (Probability: {fake_prob*100:.2f}%)")
        else:
            progress_info.append("No deepfake detection results")

    # Step 3: Visualize final results
    progress_info.append("Generating visualization...")
    result_img = visualize_results(img, damage_outputs, deepfake_results, float(damage_threshold))

    # Convert back to RGB for Gradio
    if len(result_img.shape) == 3 and result_img.shape[2] == 3:
        result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)

    progress_info.append("Processing complete!")

    return result_img, "\n".join(progress_info)
|
| 490 |
+
|
| 491 |
+
def create_gradio_interface():
    """Build the Gradio UI for car-damage detection plus deepfake verification.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio application.
    """
    with gr.Blocks(title="Car Damage & Deepfake Detection") as app:
        gr.Markdown("# Car Damage Detection & Deepfake Verification")
        gr.Markdown("Upload an image to detect car damage and check if it's a deepfake")

        with gr.Tab("Basic Interface"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="numpy", label="Input Image")

                    # Simple controls
                    skip_damage = gr.Checkbox(label="Skip Damage Detection", value=False)
                    damage_threshold = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05,
                                                 label="Damage Detection Threshold")
                    deepfake_threshold = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05,
                                                   label="Deepfake Detection Threshold")
                    device = gr.Dropdown(choices=["auto", "cuda", "cpu", "mps"], value="auto",
                                         label="Computation Device")

                    process_btn = gr.Button("Process Image", variant="primary")

                with gr.Column(scale=1):
                    output_image = gr.Image(type="numpy", label="Result")
                    output_text = gr.Textbox(label="Detection Results", lines=10)

        with gr.Tab("Advanced Settings"):
            with gr.Row():
                with gr.Column():
                    # Model/config paths are plain textboxes; empty values are
                    # presumably handled by process_image — TODO confirm.
                    damage_model_path = gr.Textbox(label="Damage Model Path",
                                                   placeholder="Path to damage detection model (.pth)")
                    deepfake_model_path = gr.Textbox(label="Deepfake Model Path",
                                                     placeholder="Path to deepfake detection model (.pth)")
                    deepfake_cfg_path = gr.Textbox(label="Deepfake Config Path",
                                                   placeholder="Path to deepfake model config (.yaml)")

        # Connect the process function
        # NOTE(review): the order of `inputs` must match process_image's
        # positional parameter order — verify against its signature.
        process_btn.click(
            fn=process_image,
            inputs=[
                input_image,
                damage_model_path,
                deepfake_model_path,
                deepfake_cfg_path,
                damage_threshold,
                deepfake_threshold,
                skip_damage,
                device
            ],
            outputs=[output_image, output_text]
        )

        # Examples
        gr.Markdown("## Examples")
        gr.Markdown("Note: Examples will only work if you have the appropriate models installed.")

    return app
|
| 547 |
+
|
| 548 |
+
if __name__ == "__main__":
    # Entry point when run as a script: build the UI and start the server.
    # Create and launch the Gradio interface
    app = create_gradio_interface()
    # share=True exposes a temporary public gradio.live URL.
    app.launch(share=True)  # Set share=False in production
|
configs/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

# Make sibling packages importable no matter where the interpreter started.
current_file_path = os.path.abspath(__file__)
# Directory two levels above this file (the repository source directory).
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)
# Guard against duplicates: packages can be imported many times per process,
# and the original unconditional appends grew sys.path on every import.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)
|
configs/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

# Make sibling packages importable no matter where the interpreter started.
current_file_path = os.path.abspath(__file__)
# Directory two levels above this file (the repository source directory).
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)
# Guard against duplicates: packages can be imported many times per process,
# and the original unconditional appends grew sys.path on every import.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)
|
configs/get_config.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from yaml import load, dump
|
| 5 |
+
try:
|
| 6 |
+
from yaml import CLoader as Loader, CDumper as Dumper
|
| 7 |
+
except ImportError:
|
| 8 |
+
from yaml import Loader, Dumper
|
| 9 |
+
from box import Box as edict
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_config(cfg):
    """Read a YAML configuration file and return it as a dot-accessible dict.

    Args:
        cfg: Path to the YAML file.

    Returns:
        Box: the parsed configuration with attribute-style access.
    """
    # Explicit encoding so parsing does not depend on the platform default.
    with open(cfg, encoding="utf-8") as f:
        # NOTE(review): the full Loader can construct arbitrary Python objects
        # from YAML tags — only use on trusted config files (or switch to
        # SafeLoader if no custom tags are needed).
        config = load(f, Loader=Loader)

    return edict(config)
|
configs/test_config.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: test
|
| 2 |
+
lmdb: False
|
| 3 |
+
rgb_dir: '/ssd_scratch/deep_fake_dataset/'
|
| 4 |
+
lmdb_dir: '/ssd_scratch/deep_fake_dataset/datasets_lmdbs/'
|
| 5 |
+
dataset_json_folder: './preprocessing/dataset_json_v6/'
|
| 6 |
+
label_dict:
|
| 7 |
+
# DFD
|
| 8 |
+
DFD_fake: 1
|
| 9 |
+
DFD_real: 0
|
| 10 |
+
# FF++ + FaceShifter(FF-real+FF-FH)
|
| 11 |
+
FF-SH: 1
|
| 12 |
+
FF-F2F: 1
|
| 13 |
+
FF-DF: 1
|
| 14 |
+
FF-FS: 1
|
| 15 |
+
FF-NT: 1
|
| 16 |
+
FF-FH: 1
|
| 17 |
+
FF-real: 0
|
| 18 |
+
# CelebDF
|
| 19 |
+
CelebDFv1_real: 0
|
| 20 |
+
CelebDFv1_fake: 1
|
| 21 |
+
CelebDFv2_real: 0
|
| 22 |
+
CelebDFv2_fake: 1
|
| 23 |
+
# DFDCP
|
| 24 |
+
DFDCP_Real: 0
|
| 25 |
+
DFDCP_FakeA: 1
|
| 26 |
+
DFDCP_FakeB: 1
|
| 27 |
+
# DFDC
|
| 28 |
+
DFDC_Fake: 1
|
| 29 |
+
DFDC_Real: 0
|
| 30 |
+
# DeeperForensics-1.0
|
| 31 |
+
DF_fake: 1
|
| 32 |
+
DF_real: 0
|
| 33 |
+
# UADFV
|
| 34 |
+
UADFV_Fake: 1
|
| 35 |
+
UADFV_Real: 0
|
| 36 |
+
# Roop
|
| 37 |
+
roop_Real: 0
|
| 38 |
+
roop_Fake: 1
|
configs/train_config copie.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: train
|
| 2 |
+
lmdb: False
|
| 3 |
+
dry_run: False
|
| 4 |
+
rgb_dir: '/ssd_scratch/deep_fake_dataset/'
|
| 5 |
+
lmdb_dir: '/ssd_scratch/deep_fake_dataset/datasets_lmdbs/'
|
| 6 |
+
dataset_json_folder: './preprocessing/dataset_json_v6/'
|
| 7 |
+
SWA: False
|
| 8 |
+
save_avg: True
|
| 9 |
+
log_dir: ./logs/training/
|
| 10 |
+
# label settings
|
| 11 |
+
label_dict:
|
| 12 |
+
# DFD
|
| 13 |
+
DFD_fake: 1
|
| 14 |
+
DFD_real: 0
|
| 15 |
+
# FF++ + FaceShifter(FF-real+FF-FH)
|
| 16 |
+
FF-SH: 1
|
| 17 |
+
FF-F2F: 1
|
| 18 |
+
FF-DF: 1
|
| 19 |
+
FF-FS: 1
|
| 20 |
+
FF-NT: 1
|
| 21 |
+
FF-FH: 1
|
| 22 |
+
FF-real: 0
|
| 23 |
+
# CelebDF
|
| 24 |
+
CelebDFv1_real: 0
|
| 25 |
+
CelebDFv1_fake: 1
|
| 26 |
+
CelebDFv2_real: 0
|
| 27 |
+
CelebDFv2_fake: 1
|
| 28 |
+
# DFDCP
|
| 29 |
+
DFDCP_Real: 0
|
| 30 |
+
DFDCP_FakeA: 1
|
| 31 |
+
DFDCP_FakeB: 1
|
| 32 |
+
# DFDC
|
| 33 |
+
DFDC_Fake: 1
|
| 34 |
+
DFDC_Real: 0
|
| 35 |
+
# DeeperForensics-1.0
|
| 36 |
+
DF_fake: 1
|
| 37 |
+
DF_real: 0
|
| 38 |
+
# UADFV
|
| 39 |
+
UADFV_Fake: 1
|
| 40 |
+
UADFV_Real: 0
|
| 41 |
+
# Roop
|
| 42 |
+
roop_Real: 0
|
| 43 |
+
roop_Fake: 1
|
configs/train_config.yaml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mode: train
|
| 2 |
+
lmdb: False
|
| 3 |
+
dry_run: False
|
| 4 |
+
rgb_dir: '/ssd_scratch/deep_fake_dataset/'
|
| 5 |
+
lmdb_dir: '/ssd_scratch/deep_fake_dataset/datasets_lmdbs/'
|
| 6 |
+
dataset_json_folder: './preprocessing/dataset_json_v6/'
|
| 7 |
+
SWA: False
|
| 8 |
+
save_avg: True
|
| 9 |
+
log_dir: ./logs/training/
|
| 10 |
+
# label settings
|
| 11 |
+
label_dict:
|
| 12 |
+
# iFakeFaceDB labels
|
| 13 |
+
real: 0
|
| 14 |
+
fake: 1
|
| 15 |
+
# DFD
|
| 16 |
+
DFD_fake: 1
|
| 17 |
+
DFD_real: 0
|
| 18 |
+
# FF++ + FaceShifter(FF-real+FF-FH)
|
| 19 |
+
FF-SH: 1
|
| 20 |
+
FF-F2F: 1
|
| 21 |
+
FF-DF: 1
|
| 22 |
+
FF-FS: 1
|
| 23 |
+
FF-NT: 1
|
| 24 |
+
FF-FH: 1
|
| 25 |
+
FF-real: 0
|
| 26 |
+
# CelebDF
|
| 27 |
+
CelebDFv1_real: 0
|
| 28 |
+
CelebDFv1_fake: 1
|
| 29 |
+
CelebDFv2_real: 0
|
| 30 |
+
CelebDFv2_fake: 1
|
| 31 |
+
# DFDCP
|
| 32 |
+
DFDCP_Real: 0
|
| 33 |
+
DFDCP_FakeA: 1
|
| 34 |
+
DFDCP_FakeB: 1
|
| 35 |
+
# DFDC
|
| 36 |
+
DFDC_Fake: 1
|
| 37 |
+
DFDC_Real: 0
|
| 38 |
+
# DeeperForensics-1.0
|
| 39 |
+
DF_fake: 1
|
| 40 |
+
DF_real: 0
|
| 41 |
+
# UADFV
|
| 42 |
+
UADFV_Fake: 1
|
| 43 |
+
UADFV_Real: 0
|
| 44 |
+
# Roop
|
| 45 |
+
roop_Real: 0
|
| 46 |
+
roop_Fake: 1
|
loss/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

# Make sibling packages (e.g. ``metrics``) importable regardless of the
# directory the interpreter was started from.
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)
# Guard against duplicates: the original unconditional appends grew
# sys.path on every import of this package.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)

from metrics.registry import LOSSFUNC

# Importing the module registers the loss into LOSSFUNC via its decorator.
from .cross_entropy_loss import CrossEntropyLoss
|
loss/abstract_loss_func.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class AbstractLossClass(nn.Module):
    """Base class all loss functions in this package derive from."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, label):
        """Compute the loss of ``pred`` against the ground-truth ``label``.

        Subclasses must override this method; the base implementation
        only raises.
        """
        raise NotImplementedError('Each subclass should implement the forward method.')
|
loss/cross_entropy_loss.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .abstract_loss_func import AbstractLossClass
|
| 3 |
+
from metrics.registry import LOSSFUNC
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@LOSSFUNC.register_module(module_name="cross_entropy")
class CrossEntropyLoss(AbstractLossClass):
    """Plain cross-entropy classification loss, registered as "cross_entropy"."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Compute the cross-entropy loss.

        Args:
            inputs: (batch_size, num_classes) tensor of predicted scores.
            targets: (batch_size,) tensor of ground-truth class indices.

        Returns:
            Scalar tensor holding the cross-entropy loss.
        """
        return self.loss_fn(inputs, targets)
|
metrics/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

# Make sibling packages importable no matter where the interpreter started.
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_file_path))
project_root_dir = os.path.dirname(parent_dir)
# Guard against duplicates: the original unconditional appends grew
# sys.path on every import of this package.
for _path in (parent_dir, project_root_dir):
    if _path not in sys.path:
        sys.path.append(_path)
|
metrics/base_metrics_class.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from sklearn import metrics
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_accracy(output, label):
    """Fraction of samples whose argmax prediction equals ``label``.

    (Name kept as-is — callers depend on the existing spelling.)
    """
    _, predicted = torch.max(output, 1)  # class index per sample
    n_correct = (predicted == label).sum().item()
    return n_correct / predicted.size(0)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_prediction(output, label):
    """Pair each sample's positive-class probability with its label.

    Returns a (batch, 2) tensor: column 0 is softmax P(class==1),
    column 1 is the label cast to float.
    """
    fake_prob = nn.functional.softmax(output, dim=1)[:, 1].view(-1, 1)
    label_col = label.view(-1, 1).float()
    return torch.cat((fake_prob, label_col), dim=1)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def calculate_metrics_for_train(label, output):
    """Compute AUC, EER, accuracy and AP for a single training batch.

    Args:
        label: (batch,) tensor of ground-truth binary labels.
        output: (batch, num_classes) logits. When num_classes == 2 the
            positive-class softmax probability is used; otherwise ``output``
            is treated as an already-computed probability tensor.

    Returns:
        Tuple ``(auc, eer, accuracy, ap)``. ``auc``/``eer`` are None when
        the ROC curve is undefined (single sample or single-class batch).
    """
    if output.size(1) == 2:
        prob = torch.softmax(output, dim=1)[:, 1]
    else:
        prob = output

    # Accuracy: argmax over classes vs. ground truth.
    _, prediction = torch.max(output, 1)
    correct = (prediction == label).sum().item()
    accuracy = correct / prediction.size(0)

    # Average Precision
    y_true = label.cpu().detach().numpy()
    y_pred = prob.cpu().detach().numpy()
    ap = metrics.average_precision_score(y_true, y_pred)

    # AUC and EER. roc_curve raises ValueError on degenerate input (e.g. a
    # single sample squeezed to 0-d); the previous bare ``except:`` also
    # swallowed KeyboardInterrupt/SystemExit and hid real bugs.
    try:
        fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(),
                                                 prob.squeeze().cpu().numpy(),
                                                 pos_label=1)
    except ValueError:
        # for the case when we only have one sample
        return None, None, accuracy, ap

    if np.isnan(fpr[0]) or np.isnan(tpr[0]):
        # for the case when all the samples within a batch are fake/real
        auc, eer = None, None
    else:
        auc = metrics.auc(fpr, tpr)
        fnr = 1 - tpr
        # EER: the operating point where FNR and FPR are (nearly) equal.
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]

    return auc, eer, accuracy, ap
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ------------ compute average metrics of batches---------------------
|
| 60 |
+
class Metrics_batch():
    """Accumulates per-batch classification metrics (acc/AUC/EER/AP) and
    reports their averages over all batches seen since the last ``clear``.
    """

    def __init__(self):
        # Per-batch TPR curves interpolated onto a fixed FPR grid,
        # so ROC curves can be averaged across batches.
        self.tprs = []
        self.mean_fpr = np.linspace(0, 1, 100)
        self.aucs = []
        self.eers = []
        self.aps = []

        # Exact running counts for accuracy (not a per-batch average).
        self.correct = 0
        self.total = 0
        self.losses = []

    def update(self, label, output):
        """Record one batch; returns that batch's (acc, auc, eer, ap).

        ``output`` is (batch, 2) logits (softmaxed here); any other second
        dimension is treated as an already-computed probability tensor.
        """
        acc = self._update_acc(label, output)
        if output.size(1) == 2:
            prob = torch.softmax(output, dim=1)[:, 1]
        else:
            prob = output
        #label = 1-label
        #prob = torch.softmax(output, dim=1)[:, 1]
        auc, eer = self._update_auc(label, prob)
        ap = self._update_ap(label, prob)

        return acc, auc, eer, ap

    def _update_auc(self, lab, prob):
        # Returns (-1, -1) when the ROC is undefined (single-class batch);
        # such batches are not appended to the running curves.
        fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(),
                                                 prob.squeeze().cpu().numpy(),
                                                 pos_label=1)
        if np.isnan(fpr[0]) or np.isnan(tpr[0]):
            return -1, -1

        auc = metrics.auc(fpr, tpr)
        # Interpolate TPR onto the shared FPR grid for later curve averaging.
        interp_tpr = np.interp(self.mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        self.tprs.append(interp_tpr)
        self.aucs.append(auc)

        # return auc

        # EER: FPR at the point where FNR and FPR are closest.
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        self.eers.append(eer)

        return auc, eer

    def _update_acc(self, lab, output):
        # Batch accuracy; also folds the counts into the running totals.
        _, prediction = torch.max(output, 1) # argmax
        correct = (prediction == lab).sum().item()
        accuracy = correct / prediction.size(0)
        # self.accs.append(accuracy)
        self.correct = self.correct+correct
        self.total = self.total+lab.size(0)
        return accuracy

    def _update_ap(self, label, prob):
        # Average precision for this batch (sklearn AP on probabilities).
        y_true = label.cpu().detach().numpy()
        y_pred = prob.cpu().detach().numpy()
        ap = metrics.average_precision_score(y_true,y_pred)
        self.aps.append(ap)

        return np.mean(ap)

    def get_mean_metrics(self):
        """Return {'acc','auc','eer','ap'} averaged over recorded batches."""
        mean_acc, std_acc = self.correct/self.total, 0
        mean_auc, std_auc = self._mean_auc()
        mean_err, std_err = np.mean(self.eers), np.std(self.eers)
        mean_ap, std_ap = np.mean(self.aps), np.std(self.aps)

        return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap}

    def _mean_auc(self):
        # AUC of the *averaged* ROC curve (not the mean of per-batch AUCs).
        mean_tpr = np.mean(self.tprs, axis=0)
        mean_tpr[-1] = 1.0
        mean_auc = metrics.auc(self.mean_fpr, mean_tpr)
        std_auc = np.std(self.aucs)
        return mean_auc, std_auc

    def clear(self):
        """Reset all accumulated state for a new evaluation run."""
        self.tprs.clear()
        self.aucs.clear()
        # self.accs.clear()
        self.correct=0
        self.total=0
        self.eers.clear()
        self.aps.clear()
        self.losses.clear()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# ------------ compute average metrics of all data ---------------------
|
| 151 |
+
class Metrics_all():
    """Buffers every sample's probability/label and computes dataset-level
    metrics once at the end (exact, unlike the batch-averaged variant).
    """

    def __init__(self):
        self.probs = []    # per-batch arrays of P(class==1)
        self.labels = []   # per-batch arrays of ground-truth labels
        self.correct = 0
        self.total = 0

    def store(self, label, output):
        """Record one batch of (batch, 2) logits and its labels."""
        prob = torch.softmax(output, dim=1)[:, 1]
        _, prediction = torch.max(output, 1) # argmax
        correct = (prediction == label).sum().item()
        self.correct += correct
        self.total += label.size(0)
        # NOTE(review): squeeze() turns a batch of size 1 into a 0-d array,
        # which would break np.concatenate below — confirm batch sizes > 1.
        self.labels.append(label.squeeze().cpu().numpy())
        self.probs.append(prob.squeeze().cpu().numpy())

    def get_metrics(self):
        """Compute {'acc','auc','eer','ap'} over everything stored so far."""
        y_pred = np.concatenate(self.probs)
        y_true = np.concatenate(self.labels)
        # auc
        fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1)
        auc = metrics.auc(fpr, tpr)
        # eer: FPR where FNR and FPR are closest
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        # ap
        ap = metrics.average_precision_score(y_true,y_pred)
        # acc
        acc = self.correct / self.total
        return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap}

    def clear(self):
        """Drop all stored predictions and counters."""
        self.probs.clear()
        self.labels.clear()
        self.correct = 0
        self.total = 0
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# only used to record a series of scalar value
|
| 190 |
+
class Recorder:
    """Tracks the weighted running average of a scalar (e.g. a loss value)."""

    def __init__(self):
        self.sum = 0
        self.num = 0

    def update(self, item, num=1):
        """Add ``item`` with weight ``num``; ``None`` items are ignored."""
        if item is None:
            return
        self.sum += item * num
        self.num += num

    def average(self):
        """Weighted mean of everything recorded, or None when empty."""
        return None if self.num == 0 else self.sum / self.num

    def clear(self):
        """Reset the accumulator."""
        self.sum = 0
        self.num = 0
|
metrics/registry.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class Registry(object):
    """Minimal name -> class registry populated via a decorator."""

    def __init__(self):
        self.data = {}

    def register_module(self, module_name=None):
        """Decorator factory: registers the decorated class under
        ``module_name`` (or the class's own name when omitted) and
        returns the class unchanged."""
        def _register(cls):
            key = module_name if module_name is not None else cls.__name__
            self.data[key] = cls
            return cls
        return _register

    def __getitem__(self, key):
        return self.data[key]


# Global registries shared across the project.
DETECTOR = Registry()
TRAINER = Registry()
LOSSFUNC = Registry()
|
metrics/utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn import metrics
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
def parse_metric_for_print(metric_dict):
    """Format a nested best-metric dict into a human-readable report.

    ``metric_dict`` maps dataset names to {metric: value} dicts; the special
    key 'avg' holds the averaged metrics (optionally with a 'dataset_dict'
    sub-dict listing per-dataset values).

    Returns:
        The formatted multi-line string ("\n" when metric_dict is None).
    """
    if metric_dict is None:
        return "\n"
    # Local renamed from ``str`` — the original shadowed the builtin.
    out = "\n"
    out += "================================ Each dataset best metric ================================ \n"
    for key, value in metric_dict.items():
        if key != 'avg':
            out = out + f"| {key}: "
            for k, v in value.items():
                out = out + f" {k}={v} "
            out = out + "| \n"
        else:
            out += "============================================================================================= \n"
            out += "================================== Average best metric ====================================== \n"
            avg_dict = value
            for avg_key, avg_value in avg_dict.items():
                if avg_key == 'dataset_dict':
                    # Renamed from key/value — the original shadowed the
                    # outer loop's variables.
                    for ds_key, ds_value in avg_value.items():
                        out = out + f"| {ds_key}: {ds_value} | \n"
                else:
                    out = out + f"| avg {avg_key}: {avg_value} | \n"
    out += "============================================================================================="
    return out
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_test_metrics(y_pred, y_true, img_names):
    """Compute frame-level test metrics plus a video-level AUC.

    Args:
        y_pred: array-like of predicted probabilities (squeezed below).
        y_true: array-like of integer labels. NOTE: mutated in place —
            every label >= 1 is clamped to 1.
        img_names: per-sample image paths (frame-level methods) or lists
            (video-level methods, in which case video AUC == frame AUC).

    Returns:
        Dict with 'acc', 'auc', 'eer', 'ap', 'pred', 'video_auc', 'label'.
    """
    def get_video_metrics(image, pred, label):
        # Groups frames by their parent directory (assumed to identify the
        # video — TODO confirm path layout), averages predictions/labels
        # per video, then computes a video-level ROC.
        result_dict = {}
        new_label = []
        new_pred = []
        # print(image[0])
        # print(pred.shape)
        # print(label.shape)
        for item in np.transpose(np.stack((image, pred, label)), (1, 0)):

            s = item[0]
            # Handle both Windows and POSIX separators in stored paths.
            if '\\' in s:
                parts = s.split('\\')
            else:
                parts = s.split('/')
            a = parts[-2]
            b = parts[-1]

            if a not in result_dict:
                result_dict[a] = []

            result_dict[a].append(item)
        image_arr = list(result_dict.values())

        for video in image_arr:
            pred_sum = 0
            label_sum = 0
            leng = 0
            for frame in video:
                pred_sum += float(frame[1])
                label_sum += int(frame[2])
                leng += 1
            # Video score = mean frame score; video label = rounded-down
            # mean frame label (frames of one video share a label anyway).
            new_pred.append(pred_sum / leng)
            new_label.append(int(label_sum / leng))
        fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred)
        v_auc = metrics.auc(fpr, tpr)
        fnr = 1 - tpr
        # EER: FPR at the point where FNR and FPR are closest.
        v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
        return v_auc, v_eer


    y_pred = y_pred.squeeze()
    # For UCF, where labels for different manipulations are not consistent.
    y_true[y_true >= 1] = 1
    # auc
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    # eer
    fnr = 1 - tpr
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    # ap
    ap = metrics.average_precision_score(y_true, y_pred)
    # acc (fixed 0.5 decision threshold)
    prediction_class = (y_pred > 0.5).astype(int)
    correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item()
    acc = correct / len(prediction_class)
    if type(img_names[0]) is not list:
        # calculate video-level auc for the frame-level methods.
        v_auc, _ = get_video_metrics(img_names, y_pred, y_true)
    else:
        # video-level methods
        v_auc=auc

    return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'video_auc': v_auc, 'label': y_true}
|
models/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
from .builder import MODELS, build_model
|
| 3 |
+
from .networks.arcface import (
|
| 4 |
+
SimpleClassificationDF,
|
| 5 |
+
)
|
| 6 |
+
from .networks.mrsa_resnet import (
|
| 7 |
+
PoseResNet, resnet_spec, Bottleneck
|
| 8 |
+
)
|
| 9 |
+
from .networks.pose_hrnet import (
|
| 10 |
+
PoseHighResolutionNet
|
| 11 |
+
)
|
| 12 |
+
from .networks.xception import (
|
| 13 |
+
Xception
|
| 14 |
+
)
|
| 15 |
+
from.networks.pose_efficientNet import (
|
| 16 |
+
PoseEfficientNet
|
| 17 |
+
)
|
| 18 |
+
from .networks.common import *
|
| 19 |
+
from .utils import (
|
| 20 |
+
load_pretrained, freeze_backbone,
|
| 21 |
+
load_model, save_model, unfreeze_backbone,
|
| 22 |
+
preset_model,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
__all__=['SimpleClassificationDF', 'PoseResNet', 'MODELS', 'build_model',
|
| 27 |
+
'load_pretrained', 'freeze_backbone', 'resnet_spec',
|
| 28 |
+
'load_model', 'save_model', 'unfreeze_backbone', 'Bottleneck',
|
| 29 |
+
'preset_model', 'PoseHighResolutionNet', 'Xception', 'PoseEfficientNet']
|
models/builder.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
from typing import Dict, Any, Optional
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
if not os.getcwd() in sys.path:
|
| 7 |
+
sys.path.append(os.getcwd())
|
| 8 |
+
|
| 9 |
+
from torch.nn import Sequential
|
| 10 |
+
|
| 11 |
+
from register.register import Registry, build_from_cfg
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s).

    Unlike ``build_from_cfg``, a list of configs yields an ``nn.Sequential``
    wrapping the modules built from each entry.

    Args:
        cfg (dict, list[dict]): The config of modules — either a single
            config dict or a list of config dicts.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    built = [build_from_cfg(item, registry, default_args) for item in cfg]
    return Sequential(*built)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Single shared registry for all model parts: heads, backbones and full
# models register into (and are built from) the same namespace.
MODELS = Registry('model', build_func=build_model_from_cfg)
HEADS = MODELS
BACKBONES = MODELS
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def build_model(cfg: Dict,
                model: Registry,
                build_func=build_model_from_cfg,
                default_args: Optional[Dict] = None) -> Any:
    """Thin wrapper delegating model construction to ``build_func``.

    Args:
        cfg: Model config dict (or list of dicts for a Sequential).
        model: Registry to resolve the config's type name against.
        build_func: Builder to invoke (defaults to ``build_model_from_cfg``).
        default_args: Optional defaults forwarded to the builder.

    Returns:
        The built ``nn.Module``.
    """
    return build_func(cfg, model, default_args)
|
models/networks/arcface.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
|
| 6 |
+
from torch.nn import (Linear, Conv2d, BatchNorm1d, Softmax,
|
| 7 |
+
BatchNorm2d, PReLU, ReLU, Sigmoid,
|
| 8 |
+
Dropout2d, Dropout, AvgPool2d, MaxPool2d,
|
| 9 |
+
AdaptiveAvgPool2d, Sequential, Module, Parameter)
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from ..builder import (
|
| 14 |
+
MODELS, HEADS, BACKBONES,
|
| 15 |
+
build_model,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
################################## Original Arcface Model #############################################################
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Flatten(Module):
    """Flattens every dimension after the batch dimension into one."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def l2_norm(input, axis=1):
    """Scale ``input`` to unit L2 norm along ``axis``."""
    norms = torch.norm(input, 2, axis, True)  # keepdim so division broadcasts
    return input / norms
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SEModule(Module):
    """Squeeze-and-Excitation channel-attention block: a global pooling
    'squeeze' followed by a two-conv bottleneck producing per-channel
    gates in [0, 1] that rescale the input."""

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1,
                          padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1,
                          padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)      # squeeze: (N, C, 1, 1)
        gate = self.fc1(gate)
        gate = self.relu(gate)
        gate = self.fc2(gate)
        gate = self.sigmoid(gate)    # excitation: per-channel gate
        return x * gate
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class bottleneck_IR(Module):
    """Improved-residual (IR) bottleneck used by ArcFace-style ResNets.

    Residual path: BN -> 3x3 conv -> PReLU -> 3x3 strided conv -> BN,
    added to a shortcut.  When the channel count changes the shortcut is
    a strided 1x1 conv + BN; otherwise a 1x1 max-pool merely applies the
    spatial stride.
    """

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            # Identity shortcut; MaxPool2d(1, stride) only downsamples spatially.
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride ,bias=False), BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1 ,bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1 ,bias=False), BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class bottleneck_IR_SE(Module):
    """Improved-residual bottleneck with Squeeze-and-Excitation.

    Identical to :class:`bottleneck_IR` except that an :class:`SEModule`
    (reduction 16) rescales the residual branch before the shortcut
    addition.
    """

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            # Identity shortcut; MaxPool2d(1, stride) only downsamples spatially.
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride ,bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3,3), (1,1),1 ,bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3,3), stride, 1 ,bias=False),
            BatchNorm2d(depth),
            # Channel attention applied to the residual branch only.
            SEModule(depth,16)
            )

    def forward(self,x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing one ResNet block: input channel count,
    output channel count ("depth"), and spatial stride.'''
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_block(in_channel, depth, num_units, stride = 2):
    """Describe one ResNet stage.

    The first block applies ``stride`` and the channel transition; the
    remaining ``num_units - 1`` blocks keep depth and stride 1.
    """
    stage = [Bottleneck(in_channel, depth, stride)]
    for _ in range(num_units - 1):
        stage.append(Bottleneck(depth, depth, 1))
    return stage
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_blocks(num_layers):
    """Return the four-stage block layout for an IR-ResNet.

    Args:
        num_layers (int): overall depth; one of 50, 100 or 152.

    Returns:
        list: four lists of ``Bottleneck`` descriptions, one per stage.

    Raises:
        ValueError: if ``num_layers`` is not a supported depth.
    """
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units = 3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        # Previously an unsupported depth fell through and raised a confusing
        # UnboundLocalError on ``blocks``; fail fast with a clear message.
        raise ValueError(
            'num_layers should be 50, 100 or 152, got {}'.format(num_layers))
    return blocks
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@BACKBONES.register_module()
class ResNet(Module):
    def __init__(self, num_layers=50, drop_ratio=0.6, mode='ir', **kwargs):
        """
        Implementation for ResNet 50, 100, 152 with/out SE module.

        Args:
            num_layers (int): network depth, one of 50, 100, 152.
            drop_ratio (float): dropout probability in the output layer.
            mode (str): 'ir' for plain improved-residual blocks, 'ir_se'
                to add a Squeeze-and-Excitation module to each block.
        """
        super(ResNet, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        # Stem: 3x3 stride-1 conv keeps the spatial resolution.
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1 ,bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        # assumes a 112x112 input so the final feature map is 512x7x7 — TODO confirm
        self.output_layer = Sequential(BatchNorm2d(512),
                                       Dropout(drop_ratio),
                                       Flatten(),
                                       Linear(512 * 7 * 7, 512),
                                       BatchNorm1d(512))
        modules = []
        # Flatten the per-stage block descriptions into one sequential body.
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel,
                                bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self,x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        # Embeddings are L2-normalised onto the unit hypersphere.
        x = l2_norm(x)
        return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@HEADS.register_module()
class SimpleClassificationHead(Module):
    """MLP binary-classification head.

    Maps an ``in_planes``-dim feature to a single sigmoid probability
    through the chain 256 -> 128 -> 64 -> 32 -> 1, with BatchNorm after
    each hidden Linear and Dropout before it.
    """

    def __init__(self, drop_ratio=0.6, in_planes=512, **kwargs):
        super(SimpleClassificationHead, self).__init__()
        self.classification_head = Sequential(Dropout(drop_ratio),
                                              Linear(in_planes, 256),
                                              BatchNorm1d(256),
                                              Dropout(drop_ratio),
                                              Linear(256, 128),
                                              BatchNorm1d(128),
                                              Dropout(drop_ratio),
                                              Linear(128, 64),
                                              BatchNorm1d(64),
                                              Dropout(drop_ratio),
                                              Linear(64, 32),
                                              BatchNorm1d(32),
                                              # Dropout(drop_ratio),
                                              Linear(32, 1),
                                              Sigmoid())

    def forward(self, x):
        x = self.classification_head(x)
        return x
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@MODELS.register_module()
class SimpleClassificationDF(Module):
    """Classifier assembled from registry-built backbone + head.

    ``cfg`` must provide ``backbone`` and ``head`` sub-configs, each
    carrying a ``type`` key naming a registered class; the remaining
    keys are forwarded to that class's constructor.
    """

    def __init__(self, cfg: dict, **kwargs):
        super(SimpleClassificationDF, self).__init__()
        assert 'backbone' in cfg, 'Config for Backbones is mandatory!'
        assert 'head' in cfg, 'Config for Heads is mandatory!'

        # NOTE(review): attribute access (cfg.backbone.type) assumes an
        # addict/EasyDict-style config object, not a plain dict — confirm.
        self.backbone = BACKBONES.get(cfg.backbone.type)(**cfg.backbone)
        self.head = HEADS.get(cfg.head.type)(**cfg.head)
        self.model = Sequential(*[self.backbone,
                                  self.head])

    def forward(self, x):
        x = self.model(x)
        return x
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
################################## MobileFaceNet #############################################################
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class Conv_block(Module):
    """Conv2d -> BatchNorm -> PReLU (bias-free convolution).

    ``groups`` > 1 makes the convolution grouped/depthwise.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class Linear_block(Module):
    """Conv2d -> BatchNorm with no activation (hence "linear").

    Used as the projection step of the depthwise-separable bottleneck.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class Depth_Wise(Module):
    """Depthwise-separable bottleneck (MobileFaceNet style).

    1x1 expansion to ``groups`` channels -> depthwise (grouped) conv ->
    1x1 linear projection to ``out_c``, with an optional identity
    residual.  Note that ``groups`` doubles as the expanded channel
    width of the intermediate representation.
    """

    def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        # 1x1 expansion: output width equals ``groups``.
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        # Depthwise conv: one filter per expanded channel.
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        # Linear projection back down to ``out_c`` (no activation).
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class Residual(Module):
    """Chain of ``num_block`` stride-1 :class:`Depth_Wise` bottlenecks
    with identity residuals, all at constant width ``c``."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class MobileFaceNet(Module):
    """MobileFaceNet embedding network.

    A stack of depthwise-separable bottlenecks mapping a face crop to an
    L2-normalised ``embedding_size``-dim vector.  The final depthwise
    conv uses a 7x7 kernel with no padding, so the input must be
    112x112 (four stride-2 stages: 112 / 16 = 7).
    """

    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        # 7x7 depthwise with no padding collapses the 7x7 map to 1x1.
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7,7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2_dw(out)
        out = self.conv_23(out)
        out = self.conv_3(out)
        out = self.conv_34(out)
        out = self.conv_4(out)
        out = self.conv_45(out)
        out = self.conv_5(out)
        out = self.conv_6_sep(out)
        out = self.conv_6_dw(out)
        out = self.conv_6_flatten(out)
        out = self.linear(out)
        out = self.bn(out)

        # Project embeddings onto the unit hypersphere.
        return l2_norm(out)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
################################## Arcface head #############################################################
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class Arcface(Module):
    """Additive-angular-margin classification head.

    Produces ``s * cos(theta + m)`` logits at the ground-truth class and
    ``s * cos(theta)`` elsewhere, where ``theta`` is the angle between
    the embedding and the class weight column.
    """
    # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
    def __init__(self, embedding_size=512, classnum=51332, s=64., m=0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        # Weight matrix: one column per class, compared by cosine similarity.
        self.kernel = Parameter(torch.Tensor(embedding_size,classnum))
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
        self.m = m # the margin value, default is 0.5
        self.s = s # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369
        # Precomputed terms for the angle-addition formula
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m # issue 1
        self.threshold = math.cos(math.pi - m)

    def forward(self, embbedings, label):
        """Return scaled, margin-adjusted cosine logits.

        Args:
            embbedings: (B, embedding_size) embeddings — presumably
                L2-normalised by the backbone; verify against caller.
            label: (B,) ground-truth class indices.
        """
        # weights norm
        nB = len(embbedings)
        kernel_norm = l2_norm(self.kernel,axis=0)
        # cos(theta+m)
        cos_theta = torch.mm(embbedings,kernel_norm)
        # output = torch.mm(embbedings,kernel_norm)
        cos_theta = cos_theta.clamp(-1,1) # for numerical stability
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        # this condition controls the theta+m should in range [0, pi]
        # 0<=theta+m<=pi
        # -m<=theta<=pi-m
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        keep_val = (cos_theta - self.mm) # when theta not in [0,pi], use cosface instead
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
        idx_ = torch.arange(0, nB, dtype=torch.long)
        # Apply the margin only at each sample's ground-truth class.
        output[idx_, label] = cos_theta_m[idx_, label]
        output *= self.s # scale up in order to make softmax work, first introduced in normface
        return output
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
################################## Cosface head #############################################################
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class Am_softmax(Module):
    """Additive-margin softmax (CosFace-style) head.

    Learns an ``(embedding_size, classnum)`` weight matrix; at the
    ground-truth class the cosine logit is reduced by margin ``m``
    before scaling by ``s``.
    """
    # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
    def __init__(self,embedding_size=512,classnum=51332):
        super(Am_softmax, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size,classnum))
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
        self.m = 0.35 # additive margin recommended by the paper
        self.s = 30. # see normface https://arxiv.org/abs/1704.06369

    def forward(self,embbedings,label):
        """Return scaled cosine logits with the additive margin applied
        at each sample's ground-truth class."""
        kernel_norm = l2_norm(self.kernel,axis=0)
        cos_theta = torch.mm(embbedings,kernel_norm)
        cos_theta = cos_theta.clamp(-1,1) # for numerical stability
        phi = cos_theta - self.m
        label = label.view(-1,1) #size=(B,1)
        # One-hot mask of the ground-truth positions.
        index = cos_theta.data * 0.0 #size=(B,Classnum)
        index.scatter_(1,label.data.view(-1,1),1)
        # ``.byte()`` masks are deprecated for tensor indexing in modern
        # PyTorch; boolean masks are the supported form.
        index = index.bool()
        output = cos_theta * 1.0
        output[index] = phi[index] #only change the correct predicted output
        output *= self.s # scale up in order to make softmax work, first introduced in normface
        return output
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
if __name__ == "__main__":
    # Smoke test: build a backbone through the registry and print it.
    # NOTE(review): ``type='Backbone'`` must match a name registered in
    # MODELS, but ResNet above registers itself into BACKBONES under its
    # class name — confirm the registry wiring before relying on this demo.
    cfg = dict(num_layers=50, drop_ratio=0.6, mode='ir', type='Backbone')
    backbone = MODELS.build(cfg)
    print(backbone)
|
models/networks/common.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
BN_MOMENTUM = 0.1
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def point_wise_block(inplanes, outplanes):
    """1x1 conv -> BatchNorm -> ReLU (channel mixing, no spatial change)."""
    layers = [
        nn.Conv2d(in_channels=inplanes, out_channels=outplanes, kernel_size=1, padding=0, stride=1, bias=False),
        nn.BatchNorm2d(outplanes, momentum=BN_MOMENTUM),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def conv_block(inplanes, outplanes, kernel_size, stride=1, padding=0):
    """Conv2d -> BatchNorm -> ReLU with a configurable kernel/stride/padding."""
    layers = [
        nn.Conv2d(in_channels=inplanes, out_channels=outplanes, kernel_size=kernel_size, padding=padding, stride=stride, bias=False),
        nn.BatchNorm2d(outplanes, momentum=BN_MOMENTUM),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 bias-free convolution with padding 1 (shape-preserving at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class InceptionBlock(nn.Module):
    """Inception-style block: four parallel branches concatenated on the
    channel axis (1x1; 1x1 -> 3x3; 1x1 -> 5x5; maxpool -> 1x1), each
    producing ``outplanes // 4`` channels.

    NOTE(review): all four branches reuse the *same* ``pw_block`` module,
    so their 1x1 convolutions share weights — classic Inception uses an
    independent 1x1 conv per branch; confirm this is intentional.
    NOTE(review): with ``stride != 1`` only the pooling branch is strided
    while the conv branches keep stride 1, so ``torch.cat`` would fail on
    mismatched spatial sizes — appears safe only for stride=1.
    """

    def __init__(self, inplanes, outplanes, stride=1, pool_size=3):
        # Plain attributes set before super().__init__() are fine for
        # non-module values (nn.Module.__setattr__ falls through).
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.stride = stride
        self.pool_size = pool_size
        super(InceptionBlock, self).__init__()

        self.pw_block = point_wise_block(self.inplanes, self.outplanes//4)
        self.mp_layer = nn.MaxPool2d(kernel_size=self.pool_size, stride=stride, padding=1)
        self.conv3_block = conv_block(self.outplanes//4, self.outplanes//4, kernel_size=3, stride=1, padding=1)
        self.conv5_block = conv_block(self.outplanes//4, self.outplanes//4, kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        # Branch 1: 1x1 only.
        x1 = self.pw_block(x)

        # Branch 2: 1x1 then 3x3.
        x2 = self.pw_block(x)
        x2 = self.conv3_block(x2)

        # Branch 3: 1x1 then 5x5.
        x3 = self.pw_block(x)
        x3 = self.conv5_block(x3)

        # Branch 4: max-pool then 1x1.
        x4 = self.mp_layer(x)
        x4 = self.pw_block(x4)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        return x
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SELayer(nn.Module):
    """Squeeze-and-Excitation layer with a fully-connected bottleneck.

    Global-average-pools each channel, runs the (B, C) descriptor
    through Linear(C -> C/reduction) -> ReLU -> Linear(-> C) -> Sigmoid,
    and rescales the input channel-wise by the resulting gates.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: (B, C, H, W) -> (B, C).
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel gates in (0, 1), reshaped for broadcasting.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
|
models/networks/efficientNet.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
import math
|
| 3 |
+
import re
|
| 4 |
+
import collections
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.utils import model_zoo
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple('GlobalParams', [
    'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
    'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
    'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top',
    'include_hm_decoder', 'head_conv', 'heads', 'num_layers', 'INIT_WEIGHTS',
    'use_c2', 'use_c3', 'use_c4', 'use_c51', 'efpn', 'se_layer', 'tfpn'])

# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
    'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
    'input_filters', 'output_filters', 'se_ratio', 'id_skip'])

# Set GlobalParams and BlockArgs's defaults to None so callers may supply
# only the fields they care about.
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Swish activation function
if hasattr(nn, 'SiLU'):
    # Recent PyTorch ships a native SiLU, which computes the same function.
    Swish = nn.SiLU
else:
    # For compatibility with old PyTorch versions
    class Swish(nn.Module):
        def forward(self, x):
            # swish(x) = x * sigmoid(x)
            return x * torch.sigmoid(x)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round it.

    The scaled value is rounded to the nearest multiple of
    ``global_params.depth_divisor``, never dropping below
    ``global_params.min_depth`` (or the divisor when ``min_depth`` is
    unset) and never rounding down by more than 10%.

    Args:
        filters (int): base filter count.
        global_params (namedtuple): provides ``width_coefficient``,
            ``depth_divisor`` and ``min_depth``.

    Returns:
        int: the adjusted filter count.
    """
    multiplier = global_params.width_coefficient
    if not multiplier:
        # No width scaling configured: leave the count untouched.
        return filters
    divisor = global_params.depth_divisor
    floor_filters = global_params.min_depth or divisor
    scaled = filters * multiplier
    # Round to the nearest multiple of ``divisor`` (official TF formula).
    rounded = max(floor_filters, int(scaled + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * scaled:  # prevent rounding down by more than 10%
        rounded += divisor
    return int(rounded)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier (ceiling).

    Args:
        repeats (int): base number of block repetitions.
        global_params (namedtuple): provides ``depth_coefficient``.

    Returns:
        int: ``ceil(depth_coefficient * repeats)``, or ``repeats``
        unchanged when no depth coefficient is set.
    """
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def drop_connect(inputs, p, training):
    """Drop-connect / stochastic depth: zero whole samples with prob ``p``.

    Args:
        inputs (Tensor): batch-first tensor, shape (B, ...).
        p (float): drop probability in [0, 1].
        training (bool): when False the input is returned unchanged.

    Returns:
        Tensor: inputs with each sample independently kept (rescaled by
        ``1 / (1 - p)``) or zeroed out.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    keep_prob = 1 - p
    # One Bernoulli(keep_prob) draw per sample: floor(keep_prob + U[0,1))
    # is 1 with probability keep_prob and 0 otherwise.
    noise = torch.rand([inputs.shape[0], 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_mask = torch.floor(keep_prob + noise)

    # Rescale survivors so the expected activation magnitude is unchanged.
    return inputs / keep_prob * binary_mask
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_same_padding_conv2d(image_size=None):
    """Select a 'SAME'-padding Conv2d implementation.

    With a known ``image_size`` the padding is precomputed once (static —
    required for ONNX export); otherwise it is computed per forward pass.

    Args:
        image_size (int or tuple): fixed input size, or None for dynamic.

    Returns:
        Conv2dDynamicSamePadding, or Conv2dStaticSamePadding bound to
        ``image_size`` via ``functools.partial``.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size.
    The padding is operated in forward function by calculating dynamically.
    """

    # Tips for 'SAME' mode padding.
    #     Given the following:
    #         i: width or height
    #         s: stride
    #         k: kernel size
    #         d: dilation
    #         p: padding
    #     Output after Conv2d:
    #         o = floor((i+p-((k-1)*d+1))/s+1)
    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
    # => p = (i-1)*s+((k-1)*d+1)-i

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # Base conv is constructed with padding=0; padding is applied manually.
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        # Normalise stride to a 2-element sequence.
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        """Pad asymmetrically (extra pixel on bottom/right) so the output
        spatial size equals ``ceil(input / stride)``, then convolve."""
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
    The padding mudule is calculated in construction function, then used in forward.
    """

    # With the same calculation as Conv2dDynamicSamePadding

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        # Normalise stride to a 2-element sequence.
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # Asymmetric padding (extra pixel on bottom/right) precomputed once.
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
                                                pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    Args:
        model_name (str): Model's name (must start with 'efficientnet').
        override_params (dict): fields to overwrite on the global params.

    Returns:
        (blocks_args, global_params) tuple.

    Raises:
        NotImplementedError: for an unknown model family.
        ValueError: if ``override_params`` contains unknown fields
            (raised by ``namedtuple._replace``).
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))

    w, d, s, p = efficientnet_params(model_name)
    # note: all models have drop connect rate = 0.2
    blocks_args, global_params = efficientnet(
        width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)

    if override_params:
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def efficientnet_params(model_name):
    """Map an EfficientNet variant name to its scaling coefficients.

    Args:
        model_name (str): e.g. ``'efficientnet-b0'``.

    Returns:
        tuple: ``(width, depth, resolution, dropout)``.

    Raises:
        KeyError: for an unknown variant name.
    """
    coefficients = {
        # name: (width, depth, resolution, dropout)
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    return coefficients[model_name]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000,
                 include_top=True, include_hm_decoder=False, head_conv=None,
                 heads=None, use_c2=False, use_c3=False, use_c4=False, use_c51=False,
                 num_layers=None, INIT_WEIGHTS=None, efpn=False, se_layer=False, tfpn=False):
    """Create BlockArgs and GlobalParams for an EfficientNet model.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)
            Meaning as the name suggests; the remaining keyword arguments are
            forwarded unchanged into ``GlobalParams``.

    Returns:
        blocks_args, global_params.
    """
    # Stage definitions for the baseline (efficientnet-b0) topology; they are
    # rescaled later, during EfficientNet construction, according to the
    # width/depth coefficients of the chosen model.
    stage_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    blocks_args = BlockDecoder.decode(stage_strings)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        # BN settings match the official TensorFlow implementation.
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
        include_hm_decoder=include_hm_decoder,
        head_conv=head_conv,
        heads=heads,
        use_c2=use_c2,
        use_c3=use_c3,
        use_c4=use_c4,
        use_c51=use_c51,
        efpn=efpn,
        tfpn=tfpn,
        se_layer=se_layer,
        num_layers=num_layers,
        INIT_WEIGHTS=INIT_WEIGHTS,
    )

    return blocks_args, global_params
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            # Split each token into its letter key and numeric value,
            # e.g. 'k3' -> ('k', '3'), 'se0.25' -> ('se', '0.25').
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: either a single digit or two identical digits.
        assert (('s' in options and len(options['s']) == 1) or
                (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A String form of BlockArgs.
        """
        # BlockArgs stores ``stride`` (a one-element list, see
        # _decode_block_string) — the upstream code read a non-existent
        # ``strides`` attribute and indexed element 1, which always failed.
        stride = block.stride if isinstance(block.stride, (list, tuple)) else [block.stride]
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride[0], stride[-1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        # se_ratio may be None (block decoded without an 'se' token); guard
        # before the numeric comparison to avoid a TypeError.
        if (block.se_ratio is not None) and (0 < block.se_ratio <= 1):
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class SwishImplementation(torch.autograd.Function):
    """Memory-efficient Swish: ``y = x * sigmoid(x)``.

    Saves only the input tensor and recomputes the sigmoid in the backward
    pass instead of storing intermediate activations.
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors
        sig = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sig * (1 + i * (1 - sig)))
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A ``(H, W)`` tuple when ``x`` is an int, otherwise ``x`` itself
        (assumed to already be an ``(H, W)`` pair — not validated here).

    Raises:
        TypeError: If ``x`` is neither an int nor a list/tuple.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    # Previously raised a bare TypeError() with no message; include the
    # offending type so failures are diagnosable.
    raise TypeError(
        'size must be an int, list or tuple, got {}'.format(type(x).__name__))
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.

    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride.

    Returns:
        output_image_size: A list [H,W], or None if the input size is None.
    """
    if input_image_size is None:
        return None
    height, width = get_width_and_height_from_size(input_image_size)
    # A sequence stride is assumed symmetric; only the first entry matters.
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(height / step)), int(math.ceil(width / step))]
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class MemoryEfficientSwish(nn.Module):
    # nn.Module wrapper around SwishImplementation so the activation can be
    # used inside nn.Sequential; see SwishImplementation for the math.
    def forward(self, x):
        """Apply Swish (x * sigmoid(x)) via the memory-efficient autograd Function."""
        return SwishImplementation.apply(x)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
# Pretrained checkpoints trained with AdvProp (adversarial-example
# augmentation); these expect a different input preprocessing than the
# standard `url_map` weights below.
url_map_advprop = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
    'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
}
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
# Standard (AutoAugment) pretrained checkpoints; note there is no b8 entry —
# b8 is only available as an AdvProp checkpoint (see url_map_advprop).
url_map = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).
        verbose (bool): Whether to print a confirmation message once loaded.

    Raises:
        AssertionError: If required keys are missing, or the checkpoint
            contains keys the model does not expect.
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        # Drop the classifier weights so only the backbone is loaded.
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        ret = model.load_state_dict(state_dict, strict=False)

        # if len(ret.missing_keys):
        #     assert set(ret.missing_keys) == set(
        #         ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    # Fixed message: this checks *unexpected* keys (the original said "Missing").
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    if verbose:
        print('Loaded pretrained weights for {}'.format(model_name))
|
models/networks/mrsa_resnet.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
from __future__ import absolute_import
|
| 3 |
+
from __future__ import division
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn.modules.activation import ReLU
|
| 12 |
+
from torch.nn.modules.batchnorm import BatchNorm2d
|
| 13 |
+
from torch.nn.modules.pooling import MaxPool2d
|
| 14 |
+
import torch.utils.model_zoo as model_zoo
|
| 15 |
+
|
| 16 |
+
from ..builder import MODELS, build_model
|
| 17 |
+
from .common import (
|
| 18 |
+
BN_MOMENTUM,
|
| 19 |
+
conv_block,
|
| 20 |
+
point_wise_block,
|
| 21 |
+
InceptionBlock,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# torchvision-hosted ImageNet checkpoints, keyed by 'resnet{num_layers}';
# consumed by PoseResNet.init_weights via model_zoo.load_url.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 convolution with padding 1 and no bias (BN follows it)."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class BasicBlock(nn.Module):
    """Two 3x3-conv residual block (ResNet-18/34 style)."""

    expansion = 1  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Build the block.

        Args:
            inplanes (int): Input channel count.
            planes (int): Internal/output channel count.
            stride (int): Stride of the first conv (spatial downsampling).
            downsample (nn.Module or None): Projection applied to the residual
                when shape/stride of the main path differ from the input.
        """
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """conv-bn-relu, conv-bn, add residual, relu."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

    @staticmethod
    def __repr__(*_args):
        # PoseResNet matches block names by calling ``bl.__repr__()`` on the
        # *class* (no arguments). The original zero-arg staticmethod made
        # ``repr(instance)`` — and hence ``print(model)`` — raise TypeError,
        # because Python passes the instance as an argument. Accepting and
        # ignoring ``*_args`` supports both call forms unchanged.
        return 'BasicBlock'
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Bottleneck(nn.Module):
    """1x1 / 3x3 / 1x1 bottleneck residual block (ResNet-50/101/152 style)."""

    expansion = 4  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Build the block.

        Args:
            inplanes (int): Input channel count.
            planes (int): Bottleneck width; output is ``planes * 4``.
            stride (int): Stride of the middle 3x3 conv.
            downsample (nn.Module or None): Projection applied to the residual
                when the main path changes shape.
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Reduce (1x1) - transform (3x3) - expand (1x1), add residual, relu."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

    @staticmethod
    def __repr__(*_args):
        # PoseResNet matches block names by calling ``bl.__repr__()`` on the
        # *class* (no arguments). The original zero-arg staticmethod made
        # ``repr(instance)`` — and hence ``print(model)`` — raise TypeError,
        # because Python passes the instance as an argument. Accepting and
        # ignoring ``*_args`` supports both call forms unchanged.
        return 'Bottleneck'
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@MODELS.register_module()
class PoseResNet(nn.Module):
    """ResNet backbone with deconvolutional decoder heads (CenterNet-style pose
    network), optionally wired as a custom FPN with sigmoid-gated skip fusion.

    NOTE(review): indentation of this class was reconstructed from a mangled
    source; nesting of the fpn/use_c2 setup blocks was inferred from the
    attributes used in ``forward`` — confirm against the original file.
    """

    def __init__(self,
                 block,
                 layers,
                 heads,
                 head_conv,
                 dropout_prob,
                 fpn=False,
                 cls_based_hm=True,
                 use_c2=False,
                 **kwargs):
        # block: BasicBlock/Bottleneck class, or its name as a string.
        # layers: per-stage block counts, e.g. [3, 4, 6, 3] for resnet50.
        # heads: dict head_name -> number of output channels.
        # head_conv: hidden width of each head; <= 0 means a single 1x1 conv.
        # NOTE: plain (non-module) attributes are assigned before
        # super().__init__() below; safe because they are not nn.Modules.
        self.inplanes = 64
        self.deconv_with_bias = False
        self.heads = heads
        self.fpn = fpn
        self.cls_based_hm = cls_based_hm
        self.use_c2 = use_c2

        #Convert Cls name into Cls Object
        # (relies on BasicBlock/Bottleneck defining __repr__ as a no-arg
        # staticmethod returning the class name)
        if isinstance(block, str):
            for bl in [BasicBlock, Bottleneck]:
                if block == bl.__repr__():
                    block = bl

        # Any extra config is attached verbatim as attributes; None is
        # rejected to surface misconfigured YAML early.
        for k, v in kwargs.items():
            if v is None:
                raise ValueError(f'The {k} argument receive a None value, Please check!')
            self.__setattr__(k, v)

        super(PoseResNet, self).__init__()
        # Standard ResNet stem: 7x7/2 conv + BN + ReLU + 3x3/2 maxpool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Custom dropout layer (applied to every backbone stage output)
        self.dropout_layer = nn.Dropout(dropout_prob)

        if self.fpn:
            # Adding sigmoid layer used to gate the lateral connections
            self.sigmoid_layer = nn.Sigmoid()

            # Adding pointwise block to compress layer4 before deconv
            self.pw_block_1 = self._point_wise_block(2048, 1024)

        # used for deconv layers (fpn variant narrows the middle stage)
        deconv_filters = [256, 128, 256] if self.fpn else [256, 256, 256]
        self.deconv_layers = self._make_deconv_layer(
            3,
            deconv_filters,
            [4, 4, 4],
        )

        # Adding inception block and per-stage lateral projections (fpn only)
        if self.fpn:
            # _make_deconv_layer returned a plain list in fpn mode; register
            # each stage as its own submodule so forward can interleave fusion.
            for idx, deconv_layer in enumerate(self.deconv_layers):
                self.__setattr__(f'deconv_layer_{idx}', nn.Sequential(deconv_layer))
            self.pw_block_2 = self._point_wise_block(512, 512)
            if self.use_c2:
                self.pw_block_3 = self._point_wise_block(512, 256)
            self.pw_block_c3 = self._point_wise_block(1024, 256)
            self.pw_block_c2 = self._point_wise_block(512, 128)
            self.inception_block = InceptionBlock(256, 256, stride=1, pool_size=3)

        # One output head per entry in `heads`; registered by name.
        for head in sorted(self.heads):
            num_output = self.heads[head]
            if head_conv > 0:
                if head != 'cls':
                    # Dense prediction head: 3x3 conv + BN + ReLU + 1x1 conv.
                    fc = nn.Sequential(
                        nn.Conv2d(256, head_conv,
                                  kernel_size=3, padding=1, bias=True),
                        nn.BatchNorm2d(head_conv),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(head_conv, num_output,
                                  kernel_size=1, stride=1, padding=0)
                    )
                else:
                    if self.cls_based_hm:
                        # Classification head fed from the 'hm' head output
                        # (see forward): pool -> MLP -> sigmoid scalar.
                        fc = nn.Sequential(
                            nn.AdaptiveMaxPool2d(head_conv//4),
                            nn.Flatten(),
                            nn.Linear(num_output*((head_conv//4)**2), head_conv, bias=True),
                            nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True),
                            nn.Linear(head_conv, 1, bias=True),
                            nn.Sigmoid()
                        )
                    else:
                        # Classification head fed from the shared feature map;
                        # emits a raw logit (no final sigmoid).
                        fc = nn.Sequential(
                            nn.Conv2d(256, head_conv, kernel_size=3,
                                      padding=1, bias=True),
                            nn.BatchNorm2d(head_conv, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True),
                            # nn.Conv2d(head_conv, num_output, kernel_size=1,
                            #           stride=1, padding=0, bias=True),
                            # nn.BatchNorm2d(num_output),
                            # nn.ReLU(inplace=True),
                            # nn.AdaptiveMaxPool2d(head_conv//4),
                            nn.AdaptiveAvgPool2d(1),
                            nn.Flatten(),
                            # nn.Linear((head_conv//4)**2, head_conv, bias=True),
                            # nn.BatchNorm1d(head_conv, momentum=BN_MOMENTUM),
                            # nn.ReLU(inplace=True),
                            nn.Linear(head_conv, 1, bias=True),
                            # nn.Sigmoid()
                        )
            else:
                # Minimal head: a single 1x1 projection.
                fc = nn.Conv2d(
                    in_channels=256,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            self.__setattr__(head, fc)

    def _point_wise_block(self, inplanes, outplanes):
        """Build a 1x1 projection block and record its output width in self.inplanes."""
        self.inplanes = outplanes
        module = point_wise_block(inplanes, outplanes)
        return module

    def _conv_block(self, inplanes, outplanes, kernel_size, stride=1):
        """Build a conv block and record its output width in self.inplanes."""
        self.inplanes = outplanes
        module = conv_block(inplanes, outplanes, kernel_size=kernel_size, stride=stride)
        return module

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one ResNet stage of `blocks` residual blocks (first may downsample)."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the residual matches the main path's shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel, index):
        """Map a deconv kernel size to its (kernel, padding, output_padding).

        NOTE(review): kernels other than 2/3/4 leave `padding` unbound and
        would raise UnboundLocalError — callers only pass 4 here.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0

        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build the stride-2 upsampling stack.

        Returns a list of (deconv, BN) stages in fpn mode (so forward can
        interleave fusion), otherwise a single nn.Sequential with ReLUs.
        """
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias),
                nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            )
            if (not self.fpn):
                layers.append(nn.ReLU(inplace=True))

            # In fpn mode the lateral concat doubles the channel count
            # feeding the next deconv stage.
            self.inplanes = planes if not self.fpn else planes * 2

        if self.fpn:
            return layers
        else:
            return nn.Sequential(*layers)

    def forward(self, x):
        """Run backbone (+ optional FPN fusion) and every registered head.

        Returns a single-element list holding {head_name: output} — the list
        wrapper mirrors multi-stage network APIs.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Shape comments assume a 256x256 input with a Bottleneck backbone.
        x1 = self.layer1(x)  #256 x 64 x 64
        x2 = self.layer2(x1)  #512 x 32 x 32
        x3 = self.layer3(x2)  #1024 x 16 x 16
        x4 = self.layer4(x3)  #2048 x 8 x 8

        # Custom dropout layer applied to every stage output
        x = self.dropout_layer(x4)  #B x 8 x 8 x 2048
        x3 = self.dropout_layer(x3)
        x2 = self.dropout_layer(x2)
        x1 = self.dropout_layer(x1)

        # Custom FPN: each upsampled map gates its lateral skip via
        # (1 - sigmoid(x)), so the skip fills in what the top-down path lacks.
        if self.fpn:
            assert isinstance(self.deconv_layers, list), "To custom FPN, decompose deconv layers as a list!"
            x = self.pw_block_1(x)  # B x 1024 x 8 x 8
            x = self.deconv_layer_0(x)  # B x 256 x 16 x 16
            # x = self.relu(x) # B x 256 x 16 x 16

            x_weighted = self.sigmoid_layer(x)  # B x 256 x 16 x 16
            x_inverse = torch.sub(1, x_weighted, alpha=1)  # B x 256 x 16 x 16
            x3 = self.pw_block_c3(x3)  #B x 256 x 16 x 16
            x3_ = torch.multiply(x3, x_inverse)  #B x 256 x 16 x 16
            x = torch.cat((x, x3_), dim=1)  #B x 512 x 16 x 16

            x = self.pw_block_2(x)  #B x 512 x 16 x 16
            x = self.deconv_layer_1(x)  #B x 128 x 32 x 32
            # x = self.relu(x) #B x 128 x 32 x 32

            x_weighted = self.sigmoid_layer(x)  #B x 128 x 32 x 32
            x_inverse = torch.sub(1, x_weighted, alpha=1)  #B x 128 x 32 x 32
            x2 = self.pw_block_c2(x2)
            x2_ = torch.multiply(x2, x_inverse)  #B x 128 x 32 x 32
            x = torch.cat((x, x2_), dim=1)  #B x 256 x 32 x 32

            x = self.inception_block(x)  #B x 256 x 64 x 64
            x = self.deconv_layer_2(x)  #B x 256 x 64 x 64

            if self.use_c2:
                # Fuse the highest-resolution skip (layer1) as well.
                x_weighted = self.sigmoid_layer(x)
                x_inverse = torch.sub(1, x_weighted, alpha=1)
                x1_ = torch.multiply(x1, x_inverse)
                x = torch.cat((x, x1_), dim=1)
                x = self.pw_block_3(x)
            else:
                x = self.relu(x)  #B x 256 x 64 x 64
        else:
            assert isinstance(self.deconv_layers, nn.Module), "Deconv Layer must be nn Module to compute!"
            x = self.deconv_layers(x)

        ret = {}
        # When cls_based_hm is set, the 'cls' head consumes the 'hm' head's
        # output rather than the shared feature map; this relies on 'hm'
        # being visited before 'cls' in the heads dict iteration order.
        x1_hm = None
        for head in self.heads:
            if self.cls_based_hm and head == 'cls' and x1_hm is not None:
                x = x1_hm
            elif head == 'hm':
                x1_hm = x

            ret[head] = self.__getattr__(head)(x)

        return [ret]

    def init_weights(self, pretrained=True, **kwargs):
        """Initialize decoder/head weights, then load ImageNet backbone weights.

        Args:
            pretrained (bool): must be True; False raises.
            **kwargs: must contain 'num_layers' to select the torchvision URL.

        Raises:
            ValueError: If ``pretrained`` is falsy.
        """
        num_layers = kwargs.get('num_layers')
        if pretrained:
            if self.fpn:
                # NOTE(review): named_parameters() yields tensors, never
                # nn.Conv2d, so these isinstance branches never run — the
                # fpn projection blocks keep their default init. Confirm
                # whether named_modules() was intended.
                for bl in [self.pw_block_1, self.pw_block_2]:
                    for _, l in bl.named_parameters():
                        if isinstance(l, nn.Conv2d):
                            nn.init.normal_(l.weight, std=0.001)
                            nn.init.constant_(l.bias, 0)

                for _, l in self.inception_block.named_parameters():
                    if isinstance(l, nn.Conv2d):
                        nn.init.normal_(l.weight, std=0.001)
                        nn.init.constant_(l.bias, 0)

            # print('=> init resnet deconv weights from normal distribution')
            if isinstance(self.deconv_layers, nn.Module):
                for _, m in self.deconv_layers.named_modules():
                    if isinstance(m, nn.ConvTranspose2d):
                        # print('=> init {}.weight as normal(0, 0.001)'.format(name))
                        # print('=> init {}.bias as 0'.format(name))
                        nn.init.normal_(m.weight, std=0.001)
                        if self.deconv_with_bias:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.BatchNorm2d):
                        # print('=> init {}.weight as 1'.format(name))
                        # print('=> init {}.bias as 0'.format(name))
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)
            else:
                # fpn mode: the stages were registered individually.
                for layer in [self.deconv_layer_0, self.deconv_layer_1, self.deconv_layer_2]:
                    for _, m in layer.named_modules():
                        if isinstance(m, nn.ConvTranspose2d):
                            # print('=> init {}.weight as normal(0, 0.001)'.format(name))
                            # print('=> init {}.bias as 0'.format(name))
                            nn.init.normal_(m.weight, std=0.001)
                            if self.deconv_with_bias:
                                nn.init.constant_(m.bias, 0)
                        elif isinstance(m, nn.BatchNorm2d):
                            # print('=> init {}.weight as 1'.format(name))
                            # print('=> init {}.bias as 0'.format(name))
                            nn.init.constant_(m.weight, 1)
                            nn.init.constant_(m.bias, 0)

            # print('=> init final conv weights from normal distribution')
            for head in self.heads:
                final_layer = self.__getattr__(head)
                for i, m in enumerate(final_layer.modules()):
                    if isinstance(m, nn.Conv2d):
                        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                        # print('=> init {}.weight as normal(0, 0.001)'.format(name))
                        # print('=> init {}.bias as 0'.format(name))
                        if m.weight.shape[0] == self.heads[head]:
                            if 'hm' in head:
                                # Focal-loss prior bias (CenterNet convention).
                                nn.init.constant_(m.bias, -2.19)
                            else:
                                nn.init.normal_(m.weight, std=0.001)
                                nn.init.constant_(m.bias, 0)
                # if isinstance(m, nn.Linear):
                #     if m.weight.shape[0] == self.heads[head]:
                #         prior = 1/71
                #         nn.init.constant_(m.bias, -math.log((1-prior)/prior))
                #     else:
                #         nn.init.normal_(m.weight, std=0.001)
                #         nn.init.constant_(m.bias, 0)

            #pretrained_state_dict = torch.load(pretrained)
            url = model_urls['resnet{}'.format(num_layers)]
            pretrained_state_dict = model_zoo.load_url(url)
            print('=> loading pretrained model {}'.format(url))
            self.load_state_dict(pretrained_state_dict, strict=False)
        else:
            print('=> imagenet pretrained model dose not exist')
            print('=> please download it first')
            raise ValueError('imagenet pretrained model does not exist')
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# Maps ResNet depth -> (residual block class, per-stage block counts),
# mirroring the standard torchvision ResNet variants.
resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
               34: (BasicBlock, [3, 4, 6, 3]),
               50: (Bottleneck, [3, 4, 6, 3]),
               101: (Bottleneck, [3, 4, 23, 3]),
               152: (Bottleneck, [3, 8, 36, 3])}
|
models/networks/pose_efficientNet.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
import math
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
if not os.getcwd() in sys.path:
|
| 6 |
+
sys.path.append(os.getcwd())
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
from torch.utils import model_zoo
|
| 12 |
+
|
| 13 |
+
from ..builder import MODELS, build_model
|
| 14 |
+
from .efficientNet import (
|
| 15 |
+
round_filters,
|
| 16 |
+
round_repeats,
|
| 17 |
+
drop_connect,
|
| 18 |
+
get_same_padding_conv2d,
|
| 19 |
+
get_model_params,
|
| 20 |
+
efficientnet_params,
|
| 21 |
+
load_pretrained_weights,
|
| 22 |
+
Swish,
|
| 23 |
+
MemoryEfficientSwish,
|
| 24 |
+
calculate_output_image_size,
|
| 25 |
+
url_map_advprop,
|
| 26 |
+
url_map
|
| 27 |
+
)
|
| 28 |
+
from .common import (
|
| 29 |
+
InceptionBlock,
|
| 30 |
+
conv_block,
|
| 31 |
+
BN_MOMENTUM,
|
| 32 |
+
SELayer
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Model names accepted by EfficientNet.from_name / from_pretrained.
VALID_MODELS = (
    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
    'efficientnet-b8',

    # Support the construction of 'efficientnet-l2' without pretrained weights
    'efficientnet-l2'
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParam, defined in utils.py.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # pytorch's BatchNorm momentum is (1 - tensorflow momentum)
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        # Squeeze-and-excitation only when se_ratio is a valid fraction in (0, 1]
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase (projection back down to output_filters)
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float): Drop connect rate (between 0 and 1), or None.

        Returns:
            Output of this block after processing.
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation: channel-wise gating from globally pooled features
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            # The combination of skip connection and drop connect brings about stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@MODELS.register_module()
|
| 149 |
+
class EfficientNet(nn.Module):
|
| 150 |
+
"""EfficientNet model.
|
| 151 |
+
Most easily loaded with the .from_name or .from_pretrained methods.
|
| 152 |
+
Args:
|
| 153 |
+
blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
|
| 154 |
+
global_params (namedtuple): A set of GlobalParams shared between blocks.
|
| 155 |
+
References:
|
| 156 |
+
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)
|
| 157 |
+
Example:
|
| 158 |
+
>>> import torch
|
| 159 |
+
>>> from efficientnet.model import EfficientNet
|
| 160 |
+
>>> inputs = torch.rand(1, 3, 224, 224)
|
| 161 |
+
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
|
| 162 |
+
>>> model.eval()
|
| 163 |
+
>>> outputs = model(inputs)
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def __init__(self, blocks_args=None, global_params=None):
    """Build the EfficientNet backbone and, optionally, a heatmap decoder.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    NOTE(review): this block was recovered from a whitespace-mangled source;
    the nesting inside the decoder-construction section was reconstructed from
    the surrounding logic — confirm against the original file.
    """
    super().__init__()
    assert isinstance(blocks_args, list), 'blocks_args should be a list'
    assert len(blocks_args) > 0, 'block args must be greater than 0'
    self._global_params = global_params
    self._blocks_args = blocks_args

    # Batch norm parameters (pytorch momentum = 1 - tensorflow momentum)
    bn_mom = 1 - self._global_params.batch_norm_momentum
    bn_eps = self._global_params.batch_norm_epsilon

    # Get stem static or dynamic convolution depending on image size
    image_size = global_params.image_size
    Conv2d = get_same_padding_conv2d(image_size=image_size)

    # Stem
    in_channels = 3  # rgb
    out_channels = round_filters(32, self._global_params)  # number of output channels
    self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
    self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
    image_size = calculate_output_image_size(image_size, 2)

    # Build blocks
    self._blocks = nn.ModuleList([])
    for block_args in self._blocks_args:
        # Update block input and output filters based on depth multiplier.
        block_args = block_args._replace(
            input_filters=round_filters(block_args.input_filters, self._global_params),
            output_filters=round_filters(block_args.output_filters, self._global_params),
            num_repeat=round_repeats(block_args.num_repeat, self._global_params)
        )

        # The first block needs to take care of stride and filter size increase.
        self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
        image_size = calculate_output_image_size(image_size, block_args.stride)
        if block_args.num_repeat > 1:  # modify block_args to keep same output size
            block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
        for _ in range(block_args.num_repeat - 1):
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1

    # Head
    in_channels = block_args.output_filters  # output of final block
    out_channels = round_filters(1280, self._global_params)
    Conv2d = get_same_padding_conv2d(image_size=image_size)
    self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
    self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

    # Final linear layer (classification top)
    self._avg_pooling = nn.AdaptiveAvgPool2d(1)
    if self._global_params.include_top:
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)

    # Heatmap Decoder Construction
    if self._global_params.include_hm_decoder:
        print("Constructing the heatmap Decoder!")
        self.efpn = self._global_params.efpn
        self.tfpn = self._global_params.tfpn

        assert not (self.efpn and self.tfpn), "Only one of E-FPN or FPN is intergrated!"

        self.se_layer = self._global_params.se_layer
        # self.hm_decoder_filters = [1792, 448, 160, 56] if self.fpn else [1792, 256, 256, 128]
        self.hm_decoder_filters = [1792, 448, 160, 56]
        # One extra upsampling stage when an FPN variant is active.
        num_kernels = [4, 4, 4, 4] if (self.efpn or self.tfpn) else [4, 4, 4]
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._sigmoid = nn.Sigmoid()
        self._relu = nn.ReLU(inplace=True)
        self._relu1 = nn.ReLU(inplace=False)  # non-inplace variant used before torch.cat
        self.deconv_with_bias = False
        if self._global_params.use_c3:
            self.inception_block = InceptionBlock(112, 112, stride=1, pool_size=3)
        else:
            self.inception_block = InceptionBlock(56, 56, stride=1, pool_size=3)
        self.heads = self._global_params.heads
        n_deconv = len(self.hm_decoder_filters)
        # Which skip connections (C5', C4, C3) double the next stage's input channels.
        self.fpn_layers = [self._global_params.use_c51, self._global_params.use_c4, self._global_params.use_c3]

        if self.efpn or self.tfpn:
            for idx in range(n_deconv):
                in_decod_filters = self.hm_decoder_filters[idx]

                if idx == 0:
                    out_decod_filters = self.hm_decoder_filters[idx + 1]
                    deconv = nn.Sequential(
                        conv_block(in_decod_filters, out_decod_filters, (3, 3), stride=1, padding=1),
                    )
                else:
                    # Previous stage's skip concat doubled our input channels.
                    in_decod_filters = in_decod_filters * 2 if self.fpn_layers[idx - 1] else in_decod_filters
                    kernel, padding, output_padding = self._get_deconv_cfg(num_kernels[idx])

                    if idx + 1 < n_deconv:
                        out_decod_filters = self.hm_decoder_filters[idx + 1]
                        deconv = nn.Sequential(
                            conv_block(in_decod_filters, out_decod_filters, (3, 3), stride=1, padding=1),
                            nn.ConvTranspose2d(
                                in_channels=out_decod_filters,
                                out_channels=out_decod_filters,
                                kernel_size=kernel,
                                stride=2,
                                padding=padding,
                                output_padding=output_padding,
                                bias=self.deconv_with_bias),
                            nn.BatchNorm2d(out_decod_filters, momentum=BN_MOMENTUM),
                        )
                    else:
                        out_decod_filters = in_decod_filters
                        deconv = nn.Sequential(
                            self.inception_block,
                            nn.ConvTranspose2d(
                                in_channels=out_decod_filters,
                                out_channels=out_decod_filters,
                                kernel_size=kernel,
                                stride=2,
                                padding=padding,
                                output_padding=output_padding,
                                bias=self.deconv_with_bias),
                            nn.BatchNorm2d(out_decod_filters, momentum=BN_MOMENTUM),
                        )

                # In case of using C2, this conv is applied to C2 features to get the
                # same number of filters as the last deconv stage. Re-assigned every
                # iteration; only the final (last-stage) value survives.
                if self._global_params.use_c2:
                    self.conv_c2 = conv_block(32, out_decod_filters, (3, 3), stride=1, padding=1)
                if self.se_layer:
                    # Channel count after concatenating the deconv output with its skip.
                    se = SELayer(channel=out_decod_filters * 2)
                    self.__setattr__(f'se_layer_{idx + 1}', se)

                self.__setattr__(f'deconv_{idx + 1}', deconv)
        else:
            self.deconv_layers = self._make_deconv_layer(
                len(num_kernels),
                self.hm_decoder_filters,
                num_kernels,
            )

        # Per-task output heads (heatmaps, offsets, classification, ...).
        for head, num_output in self.heads.items():
            head_conv = int(self._global_params.head_conv)
            num_output = int(num_output)
            if self._global_params.use_c2:
                assert self._global_params.efpn or self._global_params.tfpn, "FPN Design must be set active!"
                assert self._global_params.use_c3, "C3 must be utilized for FPN intergration of C2"
                in_head_filters = self.hm_decoder_filters[-1] * 4
            elif self._global_params.use_c3:
                in_head_filters = self.hm_decoder_filters[-1] * 2
            else:
                in_head_filters = self.hm_decoder_filters[-1]

            if head_conv > 0:
                if head != 'cls':
                    # Dense (spatial) head: conv -> BN -> ReLU -> 1x1 conv.
                    fc = nn.Sequential(
                        nn.Conv2d(in_head_filters, head_conv,
                                  kernel_size=3, padding=1, bias=True),
                        nn.BatchNorm2d(head_conv),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(head_conv, num_output,
                                  kernel_size=1, stride=1, padding=0)
                    )
                else:
                    # Classification head: conv stack pooled to a vector, then linear.
                    fc = nn.Sequential(
                        nn.Conv2d(in_head_filters, head_conv, kernel_size=3,
                                  padding=1, bias=True),
                        nn.BatchNorm2d(head_conv, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True),
                        nn.AdaptiveAvgPool2d(1),
                        nn.Flatten(),
                        nn.Linear(head_conv, num_output, bias=True),
                    )
            else:
                fc = nn.Conv2d(
                    in_channels=in_head_filters,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            self.__setattr__(head, fc)

    # set activation to memory efficient swish by default
    self._swish = MemoryEfficientSwish()
|
| 357 |
+
|
| 358 |
+
def _get_deconv_cfg(self, deconv_kernel):
    """Return the (kernel, padding, output_padding) triple for a stride-2
    ConvTranspose2d so that the spatial size is exactly doubled.

    Args:
        deconv_kernel (int): Transposed-convolution kernel size (2, 3 or 4).

    Returns:
        tuple[int, int, int]: (kernel_size, padding, output_padding).

    Raises:
        ValueError: If ``deconv_kernel`` is not one of the supported sizes.
            (The original code left ``padding``/``output_padding`` unbound in
            that case, producing a confusing UnboundLocalError.)
    """
    if deconv_kernel == 4:
        padding, output_padding = 1, 0
    elif deconv_kernel == 3:
        padding, output_padding = 1, 1
    elif deconv_kernel == 2:
        padding, output_padding = 0, 0
    else:
        raise ValueError(f'Unsupported deconv kernel size: {deconv_kernel}')

    return deconv_kernel, padding, output_padding
|
| 370 |
+
|
| 371 |
+
def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
    """Stack ``num_layers`` stride-2 upsampling stages.

    Each stage is ConvTranspose2d -> BatchNorm2d -> ReLU, mapping
    ``num_filters[i]`` channels to ``num_filters[i + 1]``.

    Args:
        num_layers (int): Number of deconv stages; must equal
            ``len(num_filters) - 1`` and ``len(num_kernels)``.
        num_filters (list[int]): Channel counts, one more entry than stages.
        num_kernels (list[int]): Transposed-conv kernel size per stage.

    Returns:
        nn.Sequential: The stacked upsampling stages.
    """
    assert num_layers == (len(num_filters) - 1), \
        'ERROR: num_deconv_layers is different len(num_deconv_filters)'
    assert num_layers == len(num_kernels), \
        'ERROR: num_deconv_layers is different len(num_deconv_filters)'

    stages = []
    for stage_idx in range(num_layers):
        k_size, pad, out_pad = self._get_deconv_cfg(num_kernels[stage_idx])
        src_channels = num_filters[stage_idx]
        dst_channels = num_filters[stage_idx + 1]

        stage = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=src_channels,
                out_channels=dst_channels,
                kernel_size=k_size,
                stride=2,
                padding=pad,
                output_padding=out_pad,
                bias=self.deconv_with_bias),
            nn.BatchNorm2d(dst_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True))
        stages.append(stage)

    return nn.Sequential(*stages)
|
| 399 |
+
|
| 400 |
+
def set_swish(self, memory_efficient=True):
    """Select the swish implementation for this model and all of its blocks.

    Args:
        memory_efficient (bool): True for the memory-efficient (training)
            variant, False for the standard (export-friendly) one.
    """
    if memory_efficient:
        self._swish = MemoryEfficientSwish()
    else:
        self._swish = Swish()
    for mb_block in self._blocks:
        mb_block.set_swish(memory_efficient)
|
| 408 |
+
|
| 409 |
+
def extract_endpoints(self, inputs):
    """Use convolution layers to extract features
    from reduction levels i in [1, 2, 3, 4, 5].

    Args:
        inputs (tensor): Input tensor.

    Returns:
        Dictionary of last intermediate features
        with reduction levels i in [1, 2, 3, 4, 5].

    Example:
        >>> import torch
        >>> from efficientnet.model import EfficientNet
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> endpoints = model.extract_endpoints(inputs)
        >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
        >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
        >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
        >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
        >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
        >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
    """
    endpoints = dict()

    # Stem
    x = self._swish(self._bn0(self._conv_stem(inputs)))
    prev_x = x

    # Blocks
    for idx, block in enumerate(self._blocks):
        drop_connect_rate = self._global_params.drop_connect_rate
        if drop_connect_rate:
            drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
        x = block(x, drop_connect_rate=drop_connect_rate)
        # A drop in spatial size marks a stage boundary: record the last feature
        # map that still had the larger resolution as that stage's endpoint.
        if prev_x.size(2) > x.size(2):
            endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
        elif idx == len(self._blocks) - 1:
            # The final block never downsamples, so capture its output explicitly.
            endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
        prev_x = x

    # Head
    x = self._swish(self._bn1(self._conv_head(x)))
    endpoints['reduction_{}'.format(len(endpoints) + 1)] = x

    return endpoints
|
| 455 |
+
|
| 456 |
+
def extract_features(self, inputs):
    """Run the convolutional backbone and return the final feature map.

    Args:
        inputs (tensor): Input tensor.

    Returns:
        Output of the final convolution layer in the efficientnet model.
    """
    # Stem
    feats = self._swish(self._bn0(self._conv_stem(inputs)))

    # Blocks, with drop-connect scaled linearly by block depth
    total_blocks = len(self._blocks)
    base_rate = self._global_params.drop_connect_rate
    for block_idx, mb_block in enumerate(self._blocks):
        rate = base_rate
        if rate:
            rate *= float(block_idx) / total_blocks
        feats = mb_block(feats, drop_connect_rate=rate)

    # Head
    return self._swish(self._bn1(self._conv_head(feats)))
|
| 478 |
+
|
| 479 |
+
def forward(self, inputs):
    """EfficientNet's forward function.

    Extracts multi-scale endpoints, then either classifies (include_top) or
    runs the heatmap decoder (include_hm_decoder) and applies the task heads.

    Args:
        inputs (tensor): Input tensor.

    Returns:
        Classification logits when ``include_top`` is set; otherwise a
        single-element list containing a dict of per-head outputs.

    NOTE(review): recovered from a whitespace-mangled source; branch nesting
    in the decoder section was reconstructed from the channel arithmetic —
    confirm against the original file.
    """
    endpoints = self.extract_endpoints(inputs)
    x1 = endpoints['reduction_6']
    x2 = endpoints['reduction_5']
    x3 = endpoints['reduction_4']
    x4 = endpoints['reduction_3']
    x5 = endpoints['reduction_2']
    x = x1

    if self._global_params.include_top:
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.flatten(start_dim=1)
        x = self._dropout(x)
        x = self._fc(x)
        return x

    if self._global_params.include_hm_decoder:
        x1 = self._dropout(x1)
        x2 = self._dropout(x2)
        x3 = self._dropout(x3)
        x4 = self._dropout(x4)

        if self.efpn:
            # E-FPN: each skip is gated by (1 - sigmoid) of the decoder feature
            # before concatenation ("exclusive" feature fusion).
            assert self._global_params.use_c51, "C51 must be utilized for FPN intergration"

            x = self.__getattr__('deconv_1')(x1)

            if self._global_params.use_c51:
                x_weighted = self._sigmoid(x)
                x_inv = torch.sub(1, x_weighted, alpha=1)
                x2_ = torch.multiply(x_inv, x2)
                x = torch.cat([x, x2_], dim=1)

                if self.se_layer:
                    x = self.__getattr__('se_layer_1')(x)
                else:
                    x = self._relu(x)

            x = self.__getattr__('deconv_2')(x)

            if self._global_params.use_c4:
                x_weighted = self._sigmoid(x)
                x_inv = torch.sub(1, x_weighted, alpha=1)
                x3_ = torch.multiply(x_inv, x3)
                x = torch.cat([x, x3_], dim=1)

                if self.se_layer:
                    x = self.__getattr__('se_layer_2')(x)
                else:
                    x = self._relu(x)

            x = self.__getattr__('deconv_3')(x)

            if self._global_params.use_c3:
                assert self._global_params.use_c4, "C4 must be utilized for FPN intergration of C3"

                x_weighted = self._sigmoid(x)
                x_inv = torch.sub(1, x_weighted, alpha=1)
                x4_ = torch.multiply(x_inv, x4)
                x = torch.cat([x, x4_], dim=1)

                if self.se_layer:
                    x = self.__getattr__('se_layer_3')(x)
                else:
                    x = self._relu(x)

            x = self.__getattr__('deconv_4')(x)

            if not self._global_params.use_c2:
                x = self._relu(x)
            else:
                assert self._global_params.use_c3, "C3 must be utilized for FPN intergration of C2"

                x5 = self._dropout(x5)
                x5_ = self.conv_c2(x5)
                x_weighted = self._sigmoid(x)
                x_inv = torch.sub(1, x_weighted, alpha=1)
                x5_ = torch.multiply(x_inv, x5_)
                x = torch.cat([x, x5_], dim=1)

                if self.se_layer:
                    x = self.__getattr__('se_layer_4')(x)
        elif self.tfpn:
            # Plain FPN: skips are concatenated directly (no sigmoid gating).
            assert self._global_params.use_c51, "C51 must be utilized for FPN intergration"
            x = self.__getattr__('deconv_1')(x1)
            x = self._relu1(x)
            x = torch.cat([x, x2], dim=1)

            x = self.__getattr__('deconv_2')(x)
            if not self._global_params.use_c4:
                x = self._relu1(x)
            else:
                x = torch.cat([x, x3], dim=1)

            x = self.__getattr__('deconv_3')(x)
            if not self._global_params.use_c3:
                x = self._relu1(x)
            else:
                assert self._global_params.use_c4, "C4 must be utilized for FPN intergration of C3"
                x = torch.cat([x, x4], dim=1)

            x = self.__getattr__('deconv_4')(x)
            if not self._global_params.use_c2:
                x = self._relu(x)
            else:
                assert self._global_params.use_c3, "C3 must be utilized for FPN intergration of C2"
                x5 = self._dropout(x5)
                x5 = self.conv_c2(x5)
                x = self._relu1(x)
                x = torch.cat([x, x5], dim=1)
        else:
            # No FPN variant: plain stack of deconv layers.
            x = self.deconv_layers(x1)

    # Apply every registered task head to the decoded feature map.
    ret = {}
    for head in self.heads:
        ret[head] = self.__getattr__(head)(x)

    return [ret]
|
| 608 |
+
|
| 609 |
+
@classmethod
def from_name(cls, model_name, in_channels=3, **override_params):
    """Create an efficientnet model according to name.

    Args:
        model_name (str): Name for efficientnet.
        in_channels (int): Input data's channel number.
        override_params (other key word params):
            Params to override model's global_params.
            Optional keys:
                'width_coefficient', 'depth_coefficient',
                'image_size', 'dropout_rate',
                'num_classes', 'batch_norm_momentum',
                'batch_norm_epsilon', 'drop_connect_rate',
                'depth_divisor', 'min_depth'

    Returns:
        An efficientnet model.
    """
    cls._check_model_name_is_valid(model_name)
    blocks_args, global_params = get_model_params(model_name, override_params)
    net = cls(blocks_args, global_params)
    net._change_in_channels(in_channels)
    return net
|
| 631 |
+
|
| 632 |
+
@classmethod
|
| 633 |
+
def from_pretrained(cls, model_name, weights_path=None, advprop=False,
|
| 634 |
+
in_channels=3, num_classes=1000, **override_params):
|
| 635 |
+
"""Create an efficientnet model according to name.
|
| 636 |
+
Args:
|
| 637 |
+
model_name (str): Name for efficientnet.
|
| 638 |
+
weights_path (None or str):
|
| 639 |
+
str: path to pretrained weights file on the local disk.
|
| 640 |
+
None: use pretrained weights downloaded from the Internet.
|
| 641 |
+
advprop (bool):
|
| 642 |
+
Whether to load pretrained weights
|
| 643 |
+
trained with advprop (valid when weights_path is None).
|
| 644 |
+
in_channels (int): Input data's channel number.
|
| 645 |
+
num_classes (int):
|
| 646 |
+
Number of categories for classification.
|
| 647 |
+
It controls the output size for final linear layer.
|
| 648 |
+
override_params (other key word params):
|
| 649 |
+
Params to override model's global_params.
|
| 650 |
+
Optional key:
|
| 651 |
+
'width_coefficient', 'depth_coefficient',
|
| 652 |
+
'image_size', 'dropout_rate',
|
| 653 |
+
'batch_norm_momentum',
|
| 654 |
+
'batch_norm_epsilon', 'drop_connect_rate',
|
| 655 |
+
'depth_divisor', 'min_depth'
|
| 656 |
+
Returns:
|
| 657 |
+
A pretrained efficientnet model.
|
| 658 |
+
"""
|
| 659 |
+
model = cls.from_name(model_name, num_classes=num_classes, **override_params)
|
| 660 |
+
load_pretrained_weights(model, model_name, weights_path=weights_path,
|
| 661 |
+
load_fc=((num_classes == 1000) and (model._global_params.include_top)), advprop=advprop)
|
| 662 |
+
model._change_in_channels(in_channels)
|
| 663 |
+
return model
|
| 664 |
+
|
| 665 |
+
@classmethod
|
| 666 |
+
def get_image_size(cls, model_name):
|
| 667 |
+
"""Get the input image size for a given efficientnet model.
|
| 668 |
+
Args:
|
| 669 |
+
model_name (str): Name for efficientnet.
|
| 670 |
+
Returns:
|
| 671 |
+
Input image size (resolution).
|
| 672 |
+
"""
|
| 673 |
+
cls._check_model_name_is_valid(model_name)
|
| 674 |
+
_, _, res, _ = efficientnet_params(model_name)
|
| 675 |
+
return res
|
| 676 |
+
|
| 677 |
+
    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name.

        Args:
            model_name (str): Name for efficientnet.

        Raises:
            ValueError: If ``model_name`` is not listed in ``VALID_MODELS``.

        Note:
            Returns None on success; despite the name this is a check, not a
            predicate returning bool.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
|
| 687 |
+
|
| 688 |
+
def _change_in_channels(self, in_channels):
|
| 689 |
+
"""Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
|
| 690 |
+
Args:
|
| 691 |
+
in_channels (int): Input data's channel number.
|
| 692 |
+
"""
|
| 693 |
+
if in_channels != 3:
|
| 694 |
+
Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
|
| 695 |
+
out_channels = round_filters(32, self._global_params)
|
| 696 |
+
self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
@MODELS.register_module()
class PoseEfficientNet(EfficientNet):
    """EfficientNet backbone specialized for pose/heatmap prediction.

    Instances are built through ``__init__`` directly; the parent's factory
    classmethods are deliberately disabled below.
    """

    def __init__(self, model_name, in_channels=3, **override_params):
        # Saved so init_weights can look up the pretrained URL for this
        # variant and re-apply the input-channel setting after loading.
        self.model_name = model_name
        self.in_channels = in_channels

        # Initialize Parent Class
        super()._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        super().__init__(blocks_args, global_params)

    @classmethod
    def from_name(cls, model_name, in_channels, **override_params):
        # Disabled factory: construct via PoseEfficientNet(...) instead.
        # NOTE(review): returns the NotImplemented singleton rather than
        # raising NotImplementedError, so callers get a sentinel, not an error.
        return NotImplemented

    @classmethod
    def from_pretrained(cls, model_name, weights_path, advprop, in_channels, num_classes, **override_params):
        # Disabled factory: pretrained loading is handled by init_weights.
        return NotImplemented

    def _change_in_channels(self, in_channels):
        # Disabled: parent stem rebuilding is suppressed for this subclass.
        return NotImplemented

    def init_weights(self, pretrained=False, advprop=False, verbose=True):
        """Load ImageNet weights (optionally AdvProp) and initialize the
        pose decoder and head layers.

        Args:
            pretrained (bool): Download and load ImageNet weights.
            advprop (bool): Use the AdvProp-trained weight URLs instead.
            verbose (bool): Print a confirmation message at the end.
        """
        if pretrained:
            url_map_ = url_map_advprop if advprop else url_map
            state_dict = model_zoo.load_url(url_map_[self.model_name])
            # strict=False: decoder/head parameters are absent from the
            # ImageNet checkpoint and keep their fresh initialization.
            self.load_state_dict(state_dict, strict=False)

        # Initialize weights for Deconvolution Layer
        if self._global_params.include_hm_decoder:
            # NOTE(review): self.efpn / self.tfpn / deconv_* / deconv_layers /
            # deconv_with_bias are presumably created by the parent when
            # include_hm_decoder is set — confirm against EfficientNet.__init__.
            if self.efpn or self.tfpn:
                deconv_layers = [self.deconv_1, self.deconv_2, self.deconv_3, self.deconv_4]
            else:
                deconv_layers = self.deconv_layers

            for layer in deconv_layers:
                for _, m in layer.named_modules():
                    if isinstance(m, nn.ConvTranspose2d):
                        # Normal init scaled by fan-out (He-style variance).
                        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                        m.weight.data.normal_(0, math.sqrt(2. / n))
                        if self.deconv_with_bias:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)

            # Init head parameters
            for head in self.heads:
                final_layer = self.__getattr__(head)
                for i, m in enumerate(final_layer.modules()):
                    if isinstance(m, nn.Conv2d):
                        # Only the head's output conv matches its channel count.
                        if m.weight.shape[0] == self.heads[head]:
                            if 'hm' in head:
                                # -2.19 bias prior for heatmap outputs —
                                # presumably a focal-loss prior; TODO confirm.
                                nn.init.constant_(m.bias, -2.19)
                            else:
                                # nn.init.normal_(m.weight, std=0.001)
                                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                                m.weight.data.normal_(0, math.sqrt(2. / n))
                                nn.init.constant_(m.bias, 0)

        # Effectively a no-op: _change_in_channels is overridden above to
        # return NotImplemented without touching the stem.
        self._change_in_channels(in_channels=self.in_channels)
        if verbose:
            print('Loaded pretrained weights for {}'.format(self.model_name))
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
if __name__ == '__main__':
    # Smoke test: build a pose EfficientNet-B4 with heatmap / classification /
    # consistency heads, load ImageNet weights, and run one random forward pass.
    cfg = dict(type='PoseEfficientNet',
               model_name='efficientnet-b4',
               include_top=False,
               include_hm_decoder=True,
               head_conv=64,
               heads={'hm':1, 'cls':1, 'cstency':256},
               use_c2=True)
    model = build_model(cfg, MODELS)
    model.init_weights(pretrained=True)
    model.eval()
    inputs = torch.rand((1, 3, 384, 384))

    # List every learnable parameter with its index.
    for i, (n, p) in enumerate(model.named_parameters()):
        print(i, n)

    # To show the whole pose EFN model outputs shape
    x = model(inputs)[0]
    for head in x.keys():
        print(f'{head} shape is --- {x[head].shape}')

    # To show the endpoints features shape
    # endpoints = model.extract_endpoints(inputs)
    # for k in endpoints.keys():
    #     print(endpoints[k].shape)
|
models/networks/pose_hrnet.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
from __future__ import absolute_import
|
| 3 |
+
from __future__ import division
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from ..builder import MODELS
|
| 13 |
+
|
| 14 |
+
from .common import conv3x3, BN_MOMENTUM
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet 'basic' variant)."""

    # Output channels == planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shape changes.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Residual sum followed by the final activation.
        return self.relu(out + shortcut)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet style)."""

    # The final 1x1 conv widens channels by this factor.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        width = planes
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(width, width * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(width * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shape changes.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Residual sum followed by the final activation.
        return self.relu(out + shortcut)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class HighResolutionModule(nn.Module):
    """One HRNet exchange unit.

    Runs ``num_branches`` parallel residual-block stacks (one per
    resolution), then fuses the branches: higher-resolution features are
    downsampled with strided 3x3 convs, lower-resolution ones are projected
    with a 1x1 conv and upsampled, and the results are summed per branch.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        # Fail fast on inconsistent per-branch configuration lists.
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # NOTE: _make_branches (via _make_one_branch) mutates
        # self.num_inchannels in place to reflect block expansion, so
        # _make_fuse_layers below sees the updated widths.
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        """Raise ValueError if any per-branch list disagrees with num_branches."""
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            # logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            # logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            # logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        """Build the residual stack for one branch.

        Side effect: updates ``self.num_inchannels[branch_index]`` to the
        branch's output width (num_channels * block.expansion).
        """
        downsample = None
        # A projection shortcut is needed when the stride or width changes.
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM
                ),
            )

        layers = []
        layers.append(
            block(
                self.num_inchannels[branch_index],
                num_channels[branch_index],
                stride,
                downsample
            )
        )
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index]
                )
            )

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build one residual stack per branch."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels)
            )

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the cross-resolution fusion layers.

        fuse_layers[i][j] transforms branch j's output to branch i's
        resolution/width: j > i upsamples (1x1 conv + nearest upsample),
        j == i is identity (stored as None), j < i downsamples with a chain
        of stride-2 3x3 convs.
        """
        if self.num_branches == 1:
            # Nothing to fuse with a single branch.
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower resolution -> project channels, then upsample 2^(j-i).
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1, 1, 0, bias=False
                            ),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2**(j-i), mode='nearest')
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher resolution -> chain of stride-2 3x3 convs;
                    # only the last step changes the channel count (no ReLU
                    # on the final step, since fusion sums precede the ReLU).
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3)
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(True)
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the (possibly updated) per-branch output channel counts."""
        return self.num_inchannels

    def forward(self, x):
        """x is a list with one tensor per branch; returns the fused list."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Run each branch's residual stack.
        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        # Sum all branches into each output resolution, then ReLU.
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# Maps config 'BLOCK' strings to their residual block classes.
blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@MODELS.register_module()
class PoseHighResolutionNet(nn.Module):
    """HRNet-based pose network with a heatmap head and a classification head.

    Expects ``cfg.MODEL`` to expose ``heads``, ``cls_based_hm``,
    ``NUM_JOINTS``, ``HEATMAP_SIZE`` and ``cfg.MODEL.EXTRA`` to expose
    ``STAGE2``/``STAGE3``/``STAGE4``, ``FINAL_CONV_KERNEL`` and
    ``PRETRAINED_LAYERS``.
    """

    def __init__(self,
                 cfg,
                 **kwargs):
        # Running channel count for _make_layer (stem output width).
        self.inplanes = 64
        extra = cfg.MODEL.EXTRA
        self.cls_based_hm = cfg.MODEL.cls_based_hm
        self.heads = cfg.MODEL.heads
        super(PoseHighResolutionNet, self).__init__()

        # stem net: two stride-2 3x3 convs -> 1/4 input resolution, 64 channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        # Stage 1: 4 Bottleneck blocks -> 256 channels (64 * expansion 4).
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        # multi_scale_output=False: stage 4 emits only the highest-resolution
        # branch, which feeds the final heatmap conv.
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=False)

        # Heatmap head: 1 conv from the high-res branch to NUM_JOINTS maps.
        self.final_layer = nn.Conv2d(
            in_channels=pre_stage_channels[0],
            out_channels=cfg.MODEL.NUM_JOINTS,
            kernel_size=extra.FINAL_CONV_KERNEL,
            stride=1,
            padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
        )

        # Classification head applied on top of the heatmaps.
        # NOTE(review): Flatten yields NUM_JOINTS * (HEATMAP_SIZE//4)**2
        # features but the Linear expects (HEATMAP_SIZE//4)**2, so this only
        # lines up when NUM_JOINTS == 1 — confirm intended usage.
        self.final_layer_cls = nn.Sequential(
            nn.BatchNorm2d(cfg.MODEL.NUM_JOINTS, momentum=BN_MOMENTUM),
            nn.AdaptiveMaxPool2d(cfg.MODEL.HEATMAP_SIZE[0]//4),
            nn.Flatten(),
            nn.Linear((cfg.MODEL.HEATMAP_SIZE[0]//4)**2, cfg.MODEL.NUM_JOINTS, bias=True),
            nn.Sigmoid()
        )

        self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS']

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        """Build per-branch transitions between consecutive stages.

        Existing branches get a 3x3 conv only when the width changes
        (otherwise None = identity); each new, lower-resolution branch is
        created from the previous stage's last branch with stride-2 convs.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3, 1, 1, bias=False
                            ),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True)
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                # New branch: downsample the previous stage's last output.
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False
                            ),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True)
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a ResNet-style layer of ``blocks`` residual blocks.

        Side effect: updates self.inplanes to the layer's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes, planes * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        """Build a stage of HighResolutionModules from its config dict.

        Returns:
            (nn.Sequential of modules, output channel list per branch).
        """
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    reset_multi_scale_output
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        """Run stem, four stages, and both heads.

        Returns:
            A one-element list holding a dict: heatmaps for the 'hm' head,
            classification-head output for every other head key.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                # Non-identity transitions consume the previous stage's last
                # (lowest-resolution) branch — same as the reference HRNet.
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        # Heatmaps from the highest-resolution branch.
        x = self.final_layer(y_list[0])

        ret = {}
        for head in self.heads.keys():
            if head == 'hm':
                ret[head] = x
            else:
                # Every non-'hm' head shares the same cls output computed
                # from the heatmaps.
                x1 = self.final_layer_cls(x)
                ret[head] = x1
        return [ret]

    def init_weights(self, pretrained='', **kwargs):
        """Initialize all layers, then optionally load a checkpoint.

        Args:
            pretrained (str): Path to a checkpoint file; '' skips loading.

        Raises:
            ValueError: If ``pretrained`` is non-empty but not a file.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.normal_(m.weight, std=0.001)
                for name, _ in m.named_parameters():
                    if name in ['bias']:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
                for name, _ in m.named_parameters():
                    if name in ['bias']:
                        nn.init.constant_(m.bias, 0)

        if os.path.isfile(pretrained):
            pretrained_state_dict = torch.load(pretrained)

            # Keep only the layers listed in PRETRAINED_LAYERS ('*' = all).
            need_init_state_dict = {}
            for name, m in pretrained_state_dict.items():
                if name.split('.')[0] in self.pretrained_layers \
                   or self.pretrained_layers[0] == '*':
                    need_init_state_dict[name] = m
            self.load_state_dict(need_init_state_dict, strict=False)
        elif pretrained:
            raise ValueError('{} is not exist!'.format(pretrained))
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def get_pose_net(cfg, is_train, **kwargs):
    """Construct a PoseHighResolutionNet, loading pretrained weights when
    training with cfg.MODEL.INIT_WEIGHTS enabled."""
    net = PoseHighResolutionNet(cfg, **kwargs)

    # Pretrained initialization only applies to training runs.
    if is_train and cfg.MODEL.INIT_WEIGHTS:
        net.init_weights(cfg.MODEL.PRETRAINED)

    return net
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
if __name__ == "__main__":
|
| 510 |
+
from configs.get_config import load_config
|
| 511 |
+
from builder import build_model
|
| 512 |
+
cfg = load_config("configs/hrnet_sbi.yaml")
|
| 513 |
+
|
| 514 |
+
hrnet = build_model(cfg.MODEL, MODELS, default_args=dict(cfg=cfg))
|
| 515 |
+
print(hrnet)
|
models/networks/xception.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Creates an Xception Model as defined in:
|
| 3 |
+
|
| 4 |
+
Francois Chollet
|
| 5 |
+
Xception: Deep Learning with Depthwise Separable Convolutions
|
| 6 |
+
https://arxiv.org/pdf/1610.02357.pdf
|
| 7 |
+
|
| 8 |
+
This weights ported from the Keras implementation. Achieves the following performance on the validation set:
|
| 9 |
+
|
| 10 |
+
Loss:0.9173 Prec@1:78.892 Prec@5:94.292
|
| 11 |
+
|
| 12 |
+
REMEMBER to set your image size to 3x299x299 for both test and validation
|
| 13 |
+
|
| 14 |
+
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
|
| 15 |
+
std=[0.5, 0.5, 0.5])
|
| 16 |
+
|
| 17 |
+
The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
|
| 18 |
+
"""
|
| 19 |
+
import math
|
| 20 |
+
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import torch.utils.model_zoo as model_zoo
|
| 24 |
+
from torch.nn import init
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
from ..builder import MODELS
|
| 28 |
+
from .common import conv_block, BN_MOMENTUM
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# URL of ImageNet-pretrained Xception weights (ported from the Keras
# implementation); consumed by Xception.init_weights(pretrained=True)
# via model_zoo.load_url.
model_urls = {
    'xception':'https://www.dropbox.com/s/1hplpzet9d7dv29/xception-c0a72b38.pth.tar?dl=1'
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SeparableConv2d(nn.Module):
    """Depthwise separable convolution.

    A per-channel (depthwise) spatial convolution followed by a 1x1
    (pointwise) convolution that mixes channels — the building block of
    Xception.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        # Depthwise stage: groups == in_channels gives one spatial filter
        # per input channel.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                               padding, dilation, groups=in_channels, bias=bias)
        # Pointwise stage: 1x1 convolution mapping to out_channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1,
                                   bias=bias)

    def forward(self, x):
        return self.pointwise(self.conv1(x))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Block(nn.Module):
    """Xception residual block.

    A stack of `reps` separable 3x3 convolutions with an identity shortcut;
    the shortcut is replaced by a strided 1x1 projection when the channel
    count or spatial size changes.
    """
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        # Projection shortcut when shape changes; identity otherwise.
        if out_filters != in_filters or strides!=1:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip=None

        self.relu = nn.ReLU(inplace=True)
        rep=[]

        filters=in_filters
        if grow_first:
            # Widen to out_filters with the FIRST separable conv.
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        # Remaining convs keep the channel count constant.
        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            # Widen to out_filters with the LAST separable conv instead.
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            # Drop the leading ReLU entirely (e.g. the very first block).
            rep = rep[1:]
        else:
            # The leading ReLU runs directly on the block input, which the
            # shortcut path still needs — so it must NOT be in-place.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            # Spatial downsampling at the end of the block.
            rep.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*rep)

    def forward(self,inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        # Residual addition.
        x+=skip
        return x
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@MODELS.register_module()
class Xception(nn.Module):
    """Xception backbone (https://arxiv.org/pdf/1610.02357.pdf) with a
    three-stage deconvolutional decoder and per-task prediction heads
    (CenterNet-style: e.g. 'hm' heatmap head, 'cls' classification head).

    forward() returns ``[ret]`` where ``ret`` maps each head name to its
    output tensor.
    """

    def __init__(self,
                 heads,
                 head_conv=64,
                 cls_based_hm=True,
                 dropout_prob=0.5,
                 **kwargs):
        """Constructor.

        Args:
            heads (dict): head name -> number of output channels/classes.
            head_conv (int): hidden width of each head; <= 0 builds a single
                1x1 conv per head instead.
            cls_based_hm (bool): if True, the 'cls' head consumes the output
                of the 'hm' head rather than the decoder features.
            dropout_prob (float): Dropout2d probability applied after the
                encoder.
        """
        super(Xception, self).__init__()
        self.heads = heads
        self.head_conv = head_conv
        self.cls_based_hm = cls_based_hm
        self.dropout_prob = dropout_prob

        # ----- entry flow -----
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # relu is applied again in forward()

        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # ----- middle flow -----
        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # ----- exit flow -----
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        # relu applied in forward()
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.dropout = nn.Dropout2d(p=self.dropout_prob)

        # ----- decoder: three (conv_block -> 2x deconv) stages,
        # channels 2048 -> 256 -> 128 -> 64.  NOTE: attribute names and
        # creation order are kept identical to the original for state_dict
        # compatibility.
        self.conv_block_1 = conv_block(2048, 256, (3, 3), padding=1)
        self.deconv_1 = self._make_deconv(256, 256)

        self.conv_block_2 = conv_block(256, 256, (3, 3), padding=1)
        self.deconv_2 = self._make_deconv(256, 128)

        self.conv_block_3 = conv_block(128, 128, (3, 3), padding=1)
        self.deconv_3 = self._make_deconv(128, 64)

        # ----- prediction heads (one module attribute per head name) -----
        for head in sorted(self.heads):
            self.__setattr__(head, self._make_head(head, self.heads[head]))

    @staticmethod
    def _make_deconv(in_channels, out_channels):
        # One 2x-upsampling decoder stage: transpose conv + BN + ReLU.
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(4, 4),
                stride=2,
                padding=1,
                output_padding=0,
                bias=False),
            nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True)
        )

    def _make_head(self, head, num_output):
        # Build one prediction head for `head` with `num_output` outputs.
        if self.head_conv <= 0:
            # Minimal head: a single 1x1 conv over the 64-channel decoder map.
            return nn.Conv2d(
                in_channels=64,
                out_channels=num_output,
                kernel_size=1,
                stride=1,
                padding=0
            )
        if head != 'cls':
            # Dense per-pixel head (heatmaps, offsets, ...).
            return nn.Sequential(
                nn.Conv2d(64, self.head_conv,
                          kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(self.head_conv),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.head_conv, num_output,
                          kernel_size=1, stride=1, padding=0)
            )
        if self.cls_based_hm:
            # Classification head fed by the 'hm' head's output (see
            # forward()).  NOTE(review): Linear in_features assumes the
            # pooled input has a single channel — confirm heads['hm'] == 1.
            return nn.Sequential(
                nn.AdaptiveAvgPool2d(self.head_conv // 4),
                nn.Flatten(),
                nn.Linear((self.head_conv // 4) ** 2, self.head_conv, bias=False),
                nn.BatchNorm1d(self.head_conv, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Linear(self.head_conv, num_output, bias=True),
                nn.Sigmoid()
            )
        # Classification head computed directly from decoder features.
        return nn.Sequential(
            nn.Conv2d(64, self.head_conv, kernel_size=3,
                      padding=1, bias=False),
            nn.BatchNorm2d(self.head_conv, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.head_conv, num_output, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_output),
            nn.AdaptiveAvgPool2d(self.head_conv // 4),
            nn.Flatten(),
            nn.Linear((self.head_conv // 4) ** 2, self.head_conv, bias=False),
            nn.BatchNorm1d(self.head_conv, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Linear(self.head_conv, num_output, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run encoder + decoder + all heads; returns [dict of head outputs]."""
        # Entry flow
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # Middle flow
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        # Exit flow
        x = self.block12(x)

        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))

        x = self.dropout(x)

        # Decoder: upsample back to head resolution.
        x = self.deconv_1(self.conv_block_1(x))
        x = self.deconv_2(self.conv_block_2(x))
        x = self.deconv_3(self.conv_block_3(x))

        ret = {}
        x1_hm = None
        for head in self.heads:
            if not self.cls_based_hm or head != 'cls':
                ret[head] = self.__getattr__(head)(x)
                if head == 'hm':
                    x1_hm = ret[head]
            else:
                # 'cls' consumes the heatmap output, which therefore must
                # have been produced earlier in the iteration order.
                assert 'hm' in ret.keys(), "Other heads need features from heatmap, please check it!"
                ret[head] = self.__getattr__(head)(x1_hm)
        return [ret]

    def init_weights(self, pretrained=False):
        """Initialize backbone weights (He init) or load ImageNet weights,
        then initialize the head output layers with a focal-loss prior bias.
        """
        if not pretrained:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # He (fan-out) initialization.
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.ConvTranspose2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                    # BUGFIX: original read self.deconv_with_bias, an
                    # attribute that was never assigned -> AttributeError.
                    # The decoder deconvs are built with bias=False, so just
                    # zero the bias when one exists.
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
        else:
            self.load_state_dict(model_zoo.load_url(model_urls['xception']), strict=False)

        # Init head parameters: bias the final Linear of each head with the
        # focal-loss prior (prior probability of a positive = 1/71).
        prior = 1 / 71
        for head in self.heads:
            final_layer = self.__getattr__(head)
            for m in final_layer.modules():
                if isinstance(m, nn.Linear):
                    if m.weight.shape[0] == self.heads[head]:
                        nn.init.constant_(m.bias, -math.log((1 - prior) / prior))
|
models/utils.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#-*- coding: utf-8 -*-
|
| 2 |
+
from __future__ import absolute_import
|
| 3 |
+
from __future__ import division
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Index of the last backbone parameter (as yielded by named_parameters())
# for each supported architecture; freeze_backbone() freezes all parameters
# up to and including this index when the model exposes no explicit
# `backbone` attribute.  Keys are '{MODEL.type}_{MODEL.num_layers}'.
layers_position = {
    'PoseResNet_50': 158,
    'PoseResNet_101': 311,
    'PoseEfficientNet_B4': 415,
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def preset_model(cfg, model, optimizer=None):
    """Prepare a model (and optionally its optimizer) for training.

    Loads pretrained weights from ``cfg.TRAIN.pretrained`` when that file
    exists (resuming epoch/lr as configured); otherwise runs the model's own
    weight initialization.  Optionally freezes the backbone for warm-up.

    Returns:
        (model, optimizer, start_epoch)
    """
    start_epoch = 0
    has_pretrained = 'pretrained' in cfg.TRAIN and os.path.isfile(cfg.TRAIN.pretrained)
    if has_pretrained:
        model, optimizer, start_epoch = load_model(
            model,
            cfg.TRAIN.pretrained,
            optimizer=optimizer,
            resume=cfg.TRAIN.resume,
            lr=cfg.TRAIN.lr,
            lr_step=cfg.TRAIN.lr_scheduler.milestones,
            gamma=cfg.TRAIN.lr_scheduler.gamma,
        )
    else:
        model.init_weights(**cfg.MODEL.INIT_WEIGHTS)
    print('Loading model successfully -- {}'.format(cfg.MODEL.type))

    # Warm-up phase: keep the backbone frozen until `warm_up` epochs elapse.
    if cfg.TRAIN.freeze_backbone and start_epoch < cfg.TRAIN.warm_up:
        freeze_backbone(cfg.MODEL, model)

    print('Number of parameters', sum(p.numel() for p in model.parameters()))
    print('Number of trainable parameters',
          sum(p.numel() for p in model.parameters() if p.requires_grad))
    return model, optimizer, start_epoch
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_pretrained(model, weight_path):
    '''
    Load a checkpoint's 'state_dict' into `model` (strict key matching).

    This function only care about state dict of model
    For other modules such as optimizer, resume learning, please refer @load_model

    Args:
        model: module whose weights are loaded in place.
        weight_path: path to a checkpoint file containing a 'state_dict' key.

    Returns:
        The same `model` instance with weights loaded.
    '''
    # BUGFIX: map tensors to their saved storage type on the local machine,
    # so CUDA-saved checkpoints load on CPU-only hosts (consistent with
    # load_model below, which already uses this map_location).
    checkpoint = torch.load(weight_path, map_location=lambda storage, loc: storage)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict, strict=True)
    return model
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def freeze_backbone(cfg, model):
    '''
    Freeze (requires_grad=False) the model's backbone for warm-up training.

    If the model exposes a `backbone` submodule, freeze exactly that.
    Otherwise fall back to freezing the leading parameters up to the
    per-architecture cutoff index stored in `layers_position`.
    '''
    if hasattr(model, 'backbone'):
        for p in model.backbone.parameters():
            p.requires_grad = False
    else:
        cutoff = layers_position[f'{cfg.type}_{cfg.num_layers}']
        for idx, (name, p) in enumerate(model.named_parameters()):
            if idx <= cutoff:
                p.requires_grad = False
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def unfreeze_backbone(model):
    '''
    Re-enable gradients on every parameter of the model
    (undoes freeze_backbone after the warm-up phase).
    '''
    for p in model.parameters():
        if not p.requires_grad:
            p.requires_grad = True
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None, gamma=None):
    """Load a training checkpoint into `model`, tolerating key/shape
    mismatches, and optionally restore the optimizer state for resuming.

    Args:
        model: module to load weights into (loaded with strict=False).
        model_path: checkpoint path; must contain 'state_dict' and 'epoch'.
        optimizer: optimizer whose state is restored when `resume` is True.
        resume: if True (and the checkpoint has one), restore the optimizer
            and continue from the saved epoch.
        lr, lr_step, gamma: base lr, milestone epochs and decay factor used
            to recompute the lr reached at the resumed epoch.

    Returns:
        (model, optimizer, start_epoch)
    """
    start_epoch = 0
    # Map storages locally so CUDA-saved checkpoints load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}

    # convert data_parallal to model
    # (strip the 'module.' prefix that DataParallel adds to parameter keys)
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()

    # check loaded parameters and created model parameters
    msg = 'If you see this, your model does not fully load the ' + \
          'pre-trained weight. Please make sure ' + \
          'you have correctly specified --arch xxx ' + \
          'or set the correct --num_classes for your own dataset.'
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                # Shape mismatch: keep the model's freshly initialized tensor.
                print('Skip loading parameter {}, required shape{}, '\
                      'loaded shape{}. {}'.format(
                    k, model_state_dict[k].shape, state_dict[k].shape, msg))
                state_dict[k] = model_state_dict[k]
        else:
            # Checkpoint key unknown to the current model; it will be ignored.
            print('Drop parameter {}.'.format(k) + msg)
    for k in model_state_dict:
        if not (k in state_dict):
            # Model parameter missing from the checkpoint: keep its init.
            print('No param {}.'.format(k) + msg)
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)

    # resume optimizer parameters
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            # Replay the step-lr schedule up to the resumed epoch so the
            # optimizer restarts at the correct learning rate.
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    return model, optimizer, start_epoch
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def save_model(path, epoch, model, optimizer=None):
    """Serialize a training checkpoint: epoch, model weights and (optionally)
    optimizer state, via torch.save.
    """
    # Unwrap DataParallel so parameter keys are saved without 'module.'.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    data = {'epoch': epoch,
            'state_dict': net.state_dict()}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
opencv-python>=4.8.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
Pillow>=10.0.0
|
| 6 |
+
gradio>=3.50.0
|
| 7 |
+
detectron2>=0.6.0; platform_system!="Darwin" # Detectron2 not available for macOS
|
| 8 |
+
fvcore>=0.1.5.post20221221; platform_system!="Darwin" # Required for detectron2
|
| 9 |
+
iopath>=0.1.9; platform_system!="Darwin" # Required for detectron2
|
| 10 |
+
pycocotools>=2.0.6; platform_system!="Darwin" # Required for detectron2
|