NoelShin
committed on
Commit
·
82d5d16
1
Parent(s):
b8f8de8
Add application file
Browse files- .DS_Store +0 -0
- .idea/vcs.xml +6 -0
- .idea/workspace.xml +57 -0
- README.md +5 -5
- app.py +96 -0
- description.html +19 -0
- images/2007_002260.jpg +0 -0
- images/2008_002536.jpg +0 -0
- images/2008_003499.jpg +0 -0
- images/2008_007814.jpg +0 -0
- images/2009_004801.jpg +0 -0
- images/2010_001079.jpg +0 -0
- images/2010_005063.jpg +0 -0
- networks/__init__.py +2 -0
- networks/_deeplab.py +190 -0
- networks/backbone/__init__.py +3 -0
- networks/backbone/hrnetv2.py +330 -0
- networks/backbone/mobilenetv2.py +188 -0
- networks/backbone/resnet.py +335 -0
- networks/modeling.py +181 -0
- networks/utils.py +90 -0
- requirements.txt +5 -0
- utils.py +59 -0
- voc_val_n500_cp2_ex.yaml +50 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
.idea/workspace.xml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="AutoImportSettings">
|
| 4 |
+
<option name="autoReloadType" value="SELECTIVE" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ChangeListManager">
|
| 7 |
+
<list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
|
| 8 |
+
<change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
|
| 9 |
+
</list>
|
| 10 |
+
<option name="SHOW_DIALOG" value="false" />
|
| 11 |
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
| 12 |
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
| 13 |
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
| 14 |
+
</component>
|
| 15 |
+
<component name="Git.Settings">
|
| 16 |
+
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
| 17 |
+
</component>
|
| 18 |
+
<component name="ProjectId" id="2FJJUIvRiY0OO5Dz2zvs0pNxkhb" />
|
| 19 |
+
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
|
| 20 |
+
<component name="ProjectViewState">
|
| 21 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
| 22 |
+
<option name="showLibraryContents" value="true" />
|
| 23 |
+
<option name="showMembers" value="true" />
|
| 24 |
+
</component>
|
| 25 |
+
<component name="PropertiesComponent">
|
| 26 |
+
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
|
| 27 |
+
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
|
| 28 |
+
<property name="WebServerToolWindowFactoryState" value="false" />
|
| 29 |
+
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
|
| 30 |
+
<property name="node.js.detected.package.eslint" value="true" />
|
| 31 |
+
<property name="node.js.detected.package.tslint" value="true" />
|
| 32 |
+
<property name="node.js.selected.package.eslint" value="(autodetect)" />
|
| 33 |
+
<property name="node.js.selected.package.tslint" value="(autodetect)" />
|
| 34 |
+
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
|
| 35 |
+
</component>
|
| 36 |
+
<component name="RecentsManager">
|
| 37 |
+
<key name="CopyFile.RECENT_KEYS">
|
| 38 |
+
<recent name="$PROJECT_DIR$" />
|
| 39 |
+
</key>
|
| 40 |
+
</component>
|
| 41 |
+
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
| 42 |
+
<component name="TaskManager">
|
| 43 |
+
<task active="true" id="Default" summary="Default task">
|
| 44 |
+
<changelist id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="" />
|
| 45 |
+
<created>1664204268713</created>
|
| 46 |
+
<option name="number" value="Default" />
|
| 47 |
+
<option name="presentableId" value="Default" />
|
| 48 |
+
<updated>1664204268713</updated>
|
| 49 |
+
<workItem from="1664204270261" duration="37000" />
|
| 50 |
+
<workItem from="1664204316867" duration="4389000" />
|
| 51 |
+
</task>
|
| 52 |
+
<servers />
|
| 53 |
+
</component>
|
| 54 |
+
<component name="TypeScriptGeneratedFilesManager">
|
| 55 |
+
<option name="version" value="3" />
|
| 56 |
+
</component>
|
| 57 |
+
</project>
|
README.md
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 1 |
---
|
| 2 |
+
title: namedmask
|
| 3 |
+
emoji: 😷
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: gray
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 2.9.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
app.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple
import codecs
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_network, colourise_mask
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained weights (currently disabled):
# state_dict: dict = torch.hub.load_state_dict_from_url(
#     "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
#     map_location=device  # "cuda" if torch.cuda.is_available() else "cpu"
# )["model"]

parser = ArgumentParser("NamedMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="voc_val_n500_cp2_ex.yaml"
)

args: Namespace = parser.parse_args()
# Merge the YAML base configuration into the CLI arguments.
# (Use a context manager so the config file handle is closed.)
with open(f"{args.config}", 'r') as f:
    base_args = yaml.safe_load(f)
base_args.pop("dataset_name")
args: dict = vars(args)
args.update(base_args)
args: Namespace = Namespace(**args)

model = get_network().to(device)
# model.load_state_dict(state_dict)
model.eval()

# Preprocessing constants (ImageNet mean/std; shorter side resized to `size`,
# longer side capped at `max_size`).
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)


@torch.no_grad()
def main(image: Image.Image):
    """Segment `image` and return it alpha-blended with the coloured prediction."""
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image_tensor: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W

    # logits: b (=1) x n_categories x h x w, torch.float32
    logits: torch.Tensor = model(image_tensor[None].to(device))

    # Match the logits to the (resized) input resolution so the overlay shapes
    # agree. NOTE(review): the model may already upsample internally, in which
    # case this is a no-op safeguard — confirm against networks/utils.py.
    logits = F.interpolate(logits, size=image_tensor.shape[-2:], mode="bilinear", align_corners=False)

    # pred: H x W category indices.
    # Bug fix: previously `pred` was converted with .cpu().numpy() and then
    # .cpu().numpy() was called again on the resulting ndarray, which raises
    # AttributeError. Convert exactly once.
    pred: np.ndarray = logits.squeeze(dim=0).argmax(dim=0).cpu().numpy()
    coloured_pred: np.ndarray = colourise_mask(mask=pred)
    super_imposed_img = cv2.addWeighted(coloured_pred, 0.5, np.array(pil_image), 0.5, 0)
    return super_imposed_img


demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="prediction"),  # "image",
    examples=[f"images/{fname}.jpg" for fname in [
        "2007_002260",
        "2008_002536",
        "2008_003499",
        "2008_007814",
        "2009_004801",
        "2010_001079",
        "2010_005063"
    ]],
    examples_per_page=10,
    description=codecs.open("description.html", 'r', "utf-8").read(),
    title="NamedMask: Distilling Segmenters from Complementary Foundation Models",
    allow_flagging="never",
    analytics_enabled=False
)

demo.launch(
    # share=True
)
|
description.html
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<title>Title</title>
|
| 6 |
+
</head>
|
| 7 |
+
<body>
|
| 8 |
+
This is a demo of <a href="https://arxiv.org/pdf/2209.11228.pdf">NamedMask: Distilling Segmenters from Complementary Foundation Models</a>.</br>
|
| 9 |
+
The goal of this work is to segment and name regions of images without access to pixel-level labels during training.
|
| 10 |
+
To tackle this task, we construct segmenters by distilling the complementary strengths of two foundation models.
|
| 11 |
+
The first, CLIP (Radford et al. 2021), exhibits the ability to assign names to image content but lacks an accessible representation of object structure.
|
| 12 |
+
The second, DINO (Caron et al. 2021), captures the spatial extent of objects but has no knowledge of object names.
|
| 13 |
+
Our method, termed NamedMask, begins by using CLIP to construct category-specific archives of images.
|
| 14 |
+
These images are pseudo-labelled with a category-agnostic salient object detector bootstrapped from DINO, then refined by category-specific segmenters using the CLIP archive labels.
|
| 15 |
+
Thanks to the high quality of the refined masks, we show that a standard segmentation architecture trained on these archives with appropriate data augmentation achieves impressive semantic segmentation abilities for both single-object and multi-object images.
|
| 16 |
+
As a result, our proposed NamedMask performs favourably against a range of prior work on five benchmarks including the VOC2012, COCO and large-scale ImageNet-S datasets.
|
| 17 |
+
Code is publicly available at <a href="https://github.com/NoelShin/namedmask">our repo</a>.
|
| 18 |
+
</body>
|
| 19 |
+
</html>
|
images/2007_002260.jpg
ADDED
|
images/2008_002536.jpg
ADDED
|
images/2008_003499.jpg
ADDED
|
images/2008_007814.jpg
ADDED
|
images/2009_004801.jpg
ADDED
|
images/2010_001079.jpg
ADDED
|
images/2010_005063.jpg
ADDED
|
networks/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from networks.modeling import *
|
| 2 |
+
from networks._deeplab import convert_to_separable_conv, set_bn_momentum
|
networks/_deeplab.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from torch import nn
from torch.nn import functional as F

# Bug fix: this repository stores the segmentation utilities at
# networks/utils.py (see the commit's file list); there is no
# networks/deeplab package, so the original
# `from networks.deeplab.utils import ...` could never resolve.
from networks.utils import _SimpleSegmentationModel

__all__ = ["DeepLabV3"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.
    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """
    # All behaviour is inherited from _SimpleSegmentationModel; this subclass
    # exists only to give the assembled model a descriptive name.
    pass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ decoder head.

    Projects a low-level backbone feature map to 48 channels, runs ASPP on the
    high-level ("out") feature map, upsamples the ASPP context to the low-level
    resolution, and classifies the concatenation.
    """

    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHeadV3Plus, self).__init__()
        # 1x1 projection reducing the low-level feature map to 48 channels.
        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        self.aspp = ASPP(in_channels, aspp_dilate)

        # 256 (ASPP output) + 48 (projection) = 304 input channels.
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        projected = self.project(feature['low_level'])
        context = self.aspp(feature['out'])
        context = F.interpolate(
            context, size=projected.shape[2:], mode='bilinear', align_corners=False
        )
        return self.classifier(torch.cat([projected, context], dim=1))

    def _init_weight(self):
        # Kaiming init for convolutions; unit scale / zero shift for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class DeepLabHead(nn.Module):
    """Plain DeepLabV3 head: ASPP followed by a 3x3 conv classifier."""

    def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHead, self).__init__()
        self.classifier = nn.Sequential(
            ASPP(in_channels, aspp_dilate),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        # Only the high-level ("out") feature map is used by this head.
        return self.classifier(feature['out'])

    def _init_weight(self):
        # Kaiming init for convolutions; unit scale / zero shift for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class AtrousSeparableConvolution(nn.Module):
    """Depthwise (atrous) convolution followed by a pointwise 1x1 convolution."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        # Depthwise: one filter per input channel (groups=in_channels).
        depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=bias, groups=in_channels,
        )
        # Pointwise: mixes channels with a 1x1 convolution.
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias)
        self.body = nn.Sequential(depthwise, pointwise)

        self._init_weight()

    def forward(self, x):
        return self.body(x)

    def _init_weight(self):
        # Kaiming init for convolutions; unit scale / zero shift for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class ASPPConv(nn.Sequential):
    """3x3 dilated conv -> BatchNorm -> ReLU branch used inside ASPP."""

    def __init__(self, in_channels, out_channels, dilation):
        # padding == dilation keeps the spatial resolution unchanged.
        super(ASPPConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch: global average pool -> 1x1 conv -> BN -> ReLU,
    then bilinear upsampling back to the input resolution."""

    def __init__(self, in_channels, out_channels):
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        super(ASPPPooling, self).__init__(*layers)

    def forward(self, x):
        spatial_size = x.shape[-2:]
        pooled = super(ASPPPooling, self).forward(x)
        # Restore the original spatial size so the branch can be concatenated.
        return F.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=False)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Five parallel branches — a 1x1 conv, three dilated 3x3 convs (one per
    atrous rate), and an image-level pooling branch — concatenated along the
    channel axis and projected back to 256 channels.
    """

    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256

        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        ]
        # Exactly three atrous rates are expected.
        rate_a, rate_b, rate_c = tuple(atrous_rates)
        for rate in (rate_a, rate_b, rate_c):
            branches.append(ASPPConv(in_channels, out_channels, rate))
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # 5 parallel branches concatenated -> project back to out_channels.
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1), )

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def convert_to_separable_conv(module):
    """Recursively replace every Conv2d with kernel size > 1 in *module* by an
    AtrousSeparableConvolution with the same configuration.

    Children are converted in place; the (possibly new) module is returned.
    Note: trained weights are not copied into the replacement convolutions.
    """
    new_module = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        # Bug fix: AtrousSeparableConvolution forwards `bias` to nn.Conv2d,
        # which expects a bool. Passing the bias *tensor* (module.bias) makes
        # nn.Conv2d evaluate `if bias:` on a multi-element Parameter, which
        # raises a RuntimeError. Pass whether a bias exists instead.
        new_module = AtrousSeparableConvolution(module.in_channels,
                                                module.out_channels,
                                                module.kernel_size,
                                                module.stride,
                                                module.padding,
                                                module.dilation,
                                                module.bias is not None)
    for name, child in module.named_children():
        new_module.add_module(name, convert_to_separable_conv(child))
    return new_module
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def set_bn_momentum(model, momentum=0.1):
    """Set the running-statistics momentum of every BatchNorm2d layer in *model*."""
    batch_norms = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for bn in batch_norms:
        bn.momentum = momentum
|
networks/backbone/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Bug fix: the backbone modules live under networks/backbone/ (this package,
# per the commit's file list); there is no networks.deeplab package, so the
# original `from networks.deeplab.backbone import ...` could never resolve.
from networks.backbone import resnet
from networks.backbone import mobilenetv2
from networks.backbone import hrnetv2
|
networks/backbone/hrnetv2.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
__all__ = ['HRNet', 'hrnetv2_48', 'hrnetv2_32']
|
| 7 |
+
|
| 8 |
+
# Checkpoint path of pre-trained backbone (edit to your path). Download backbone pretrained model hrnetv2-32 @
|
| 9 |
+
# https://drive.google.com/file/d/1NxCK7Zgn5PmeS7W1jYLt5J9E0RRZ2oyF/view?usp=sharing .Personally, I added the backbone
|
| 10 |
+
# weights to the folder /checkpoints
|
| 11 |
+
# Bug fix: the original wrapped a plain assignment in try/except — an
# assignment cannot raise, so the bare `except` branch was dead code.
# Check for the checkpoint on disk instead so each message is meaningful.
CKPT_PATH = './checkpoints/hrnetv2_32_model_best_epoch96.pth'
if os.path.exists(CKPT_PATH):
    print(f"Backbone HRNet Pretrained weights at: {CKPT_PATH}, only usable for HRNetv2-32")
else:
    print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")
|
| 16 |
+
|
| 17 |
+
# HRNetv2-48 not available yet, but you can train the whole model from scratch.
|
| 18 |
+
|
| 19 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, with a
    residual connection (optionally transformed by `downsample`)."""

    expansion = 4  # channel multiplier applied by the final 1x1 convolution

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # `downsample` adapts the identity path to the expanded channel count
        # and/or stride when it differs from the main path.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(out + shortcut)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convolutions with a residual connection
    (optionally transformed by `downsample`)."""

    expansion = 1  # basic blocks do not expand the channel count

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Bug fix: conv2 consumes conv1's output, which has `planes` channels,
        # not `inplanes`. The original `nn.Conv2d(inplanes, planes, ...)` only
        # worked by accident when inplanes == planes and crashed otherwise.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Adapt the identity path to the main path's shape.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class StageModule(nn.Module):
    """One HRNet stage: parallel multi-resolution branches followed by a
    full cross-resolution fusion.

    Branch ``i`` runs at 1/2**i of the stage's top resolution with
    ``c * 2**i`` channels. Each branch applies four BasicBlocks, then every
    output branch sums contributions from all input branches, upsampling
    (1x1 conv + nearest upsample) or downsampling (strided 3x3 convs) as
    needed to match resolutions.
    """

    def __init__(self, stage, output_branches, c):
        super(StageModule, self).__init__()

        self.number_of_branches = stage  # number of branches is equivalent to the stage configuration.
        self.output_branches = output_branches

        self.branches = nn.ModuleList()

        # Note: Resolution + Number of channels maintains the same throughout respective branch.
        for i in range(self.number_of_branches):  # Stage scales with the number of branches. Ex: Stage 2 -> 2 branch
            channels = c * (2 ** i)  # Scale channels by 2x for branch with lower resolution,

            # Paper does x4 basic block for each forward sequence in each branch (x4 basic block considered as a block)
            branch = nn.Sequential(*[BasicBlock(channels, channels) for _ in range(4)])

            self.branches.append(branch)  # list containing all forward sequence of individual branches.

        # For each branch requires repeated fusion with all other branches after passing through x4 basic blocks.
        self.fuse_layers = nn.ModuleList()

        for branch_output_number in range(self.output_branches):

            self.fuse_layers.append(nn.ModuleList())

            for branch_number in range(self.number_of_branches):
                if branch_number == branch_output_number:
                    self.fuse_layers[-1].append(nn.Sequential())  # Used in place of "None" because it is callable
                elif branch_number > branch_output_number:
                    # Lower-resolution input -> this output: 1x1 conv to match
                    # channels, then nearest-neighbour upsample to match size.
                    self.fuse_layers[-1].append(nn.Sequential(
                        nn.Conv2d(c * (2 ** branch_number), c * (2 ** branch_output_number), kernel_size=1, stride=1,
                                  bias=False),
                        nn.BatchNorm2d(c * (2 ** branch_output_number), eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                        nn.Upsample(scale_factor=(2.0 ** (branch_number - branch_output_number)), mode='nearest'),
                    ))
                elif branch_number < branch_output_number:
                    # Higher-resolution input -> this output: a chain of
                    # stride-2 3x3 convs; only the last one changes channels.
                    downsampling_fusion = []
                    for _ in range(branch_output_number - branch_number - 1):
                        downsampling_fusion.append(nn.Sequential(
                            nn.Conv2d(c * (2 ** branch_number), c * (2 ** branch_number), kernel_size=3, stride=2,
                                      padding=1,
                                      bias=False),
                            nn.BatchNorm2d(c * (2 ** branch_number), eps=1e-05, momentum=0.1, affine=True,
                                           track_running_stats=True),
                            nn.ReLU(inplace=True),
                        ))
                    downsampling_fusion.append(nn.Sequential(
                        nn.Conv2d(c * (2 ** branch_number), c * (2 ** branch_output_number), kernel_size=3,
                                  stride=2, padding=1,
                                  bias=False),
                        nn.BatchNorm2d(c * (2 ** branch_output_number), eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                    ))
                    self.fuse_layers[-1].append(nn.Sequential(*downsampling_fusion))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run each branch, then fuse.

        `x` is a list with one tensor per branch; returns a list with one
        fused tensor per output branch.
        """

        # input to each stage is a list of inputs for each branch
        x = [branch(branch_input) for branch, branch_input in zip(self.branches, x)]

        x_fused = []
        for branch_output_index in range(
                self.output_branches):  # Amount of output branches == total length of fusion layers
            for input_index in range(self.number_of_branches):  # The inputs of other branches to be fused.
                if input_index == 0:
                    x_fused.append(self.fuse_layers[branch_output_index][input_index](x[input_index]))
                else:
                    x_fused[branch_output_index] = x_fused[branch_output_index] + self.fuse_layers[branch_output_index][
                        input_index](x[input_index])

        # After fusing all streams together, you will need to pass the fused layers
        for i in range(self.output_branches):
            x_fused[i] = self.relu(x_fused[i])

        return x_fused  # returning a list of fused outputs
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class HRNet(nn.Module):
    """HRNetV2 backbone with a classification head.

    Branch ``i`` runs at 1/2**i of the post-stem resolution and carries
    ``c * 2**i`` channels. The layer/attribute names below are kept exactly as in
    the official implementation so pretrained state_dicts load without remapping.
    """

    def __init__(self, c=48, num_blocks=[1, 4, 3], num_classes=1000):
        # c: base channel width of the highest-resolution branch.
        # num_blocks: StageModule repeats for stages 2, 3 and 4 respectively.
        # num_classes: output dimension of the classifier head.
        super(HRNet, self).__init__()

        # Stem: two stride-2 3x3 convs -> overall stride 4 before stage 1.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: four Bottleneck blocks; shortcut projects 64 -> 256 channels.
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(256, eps=1e-05, affine=True, track_running_stats=True),
        )
        # Note that bottleneck module will expand the output channels according to the output channels*block.expansion
        bn_expansion = Bottleneck.expansion  # The channel expansion is set in the bottleneck class.
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),  # Input is 64 for first module connection
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
        )

        # Transition 1 - Creation of the first two branches (one full and one half resolution)
        # Need to transition into high resolution stream and mid resolution stream
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(256, c * 2, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c * 2, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 2: two parallel branches.
        number_blocks_stage2 = num_blocks[0]
        self.stage2 = nn.Sequential(
            *[StageModule(stage=2, output_branches=2, c=c) for _ in range(number_blocks_stage2)])

        # Transition 2 - Creation of the third branch (1/4 resolution)
        self.transition2 = self._make_transition_layers(c, transition_number=2)

        # Stage 3:
        number_blocks_stage3 = num_blocks[1]  # number blocks you want to create before fusion
        self.stage3 = nn.Sequential(
            *[StageModule(stage=3, output_branches=3, c=c) for _ in range(number_blocks_stage3)])

        # Transition - Creation of the fourth branch (1/8 resolution)
        self.transition3 = self._make_transition_layers(c, transition_number=3)

        # Stage 4:
        number_blocks_stage4 = num_blocks[2]  # number blocks you want to create before fusion
        self.stage4 = nn.Sequential(
            *[StageModule(stage=4, output_branches=4, c=c) for _ in range(number_blocks_stage4)])

        # Classifier (extra module if want to use for classification):
        # pool, reduce dimensionality, flatten, connect to linear layer for classification:
        out_channels = sum([c * 2 ** i for i in range(len(num_blocks)+1)])  # total output channels of HRNetV2
        pool_feature_map = 8
        self.bn_classifier = nn.Sequential(
            nn.Conv2d(out_channels, out_channels // 4, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels // 4, eps=1e-05, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(pool_feature_map),
            nn.Flatten(),
            nn.Linear(pool_feature_map * pool_feature_map * (out_channels // 4), num_classes),
        )

    @staticmethod
    def _make_transition_layers(c, transition_number):
        """Build the stride-2 conv that spawns branch ``transition_number`` from the
        previous lowest-resolution branch (doubles channels, halves resolution)."""
        return nn.Sequential(
            nn.Conv2d(c * (2 ** (transition_number - 1)), c * (2 ** transition_number), kernel_size=3, stride=2,
                      padding=1, bias=False),
            nn.BatchNorm2d(c * (2 ** transition_number), eps=1e-05, affine=True,
                           track_running_stats=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return classification logits of shape (N, num_classes)."""
        # Stem:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        # Stage 1
        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # split to 2 branches, form a list.

        # Stage 2
        x = self.stage2(x)
        x.append(self.transition2(x[-1]))

        # Stage 3
        x = self.stage3(x)
        x.append(self.transition3(x[-1]))

        # Stage 4
        x = self.stage4(x)

        # HRNetV2 Example: (follow paper, upsample via bilinear interpolation and to highest resolution size)
        output_h, output_w = x[0].size(2), x[0].size(3)  # Upsample to size of highest resolution stream
        x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False)
        x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False)
        x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False)

        # Upsampling all the other resolution streams and then concatenate all (rather than adding/fusing like HRNetV1)
        x = torch.cat([x[0], x1, x2, x3], dim=1)
        x = self.bn_classifier(x)
        return x
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _hrnet(arch, channels, num_blocks, pretrained, progress, **kwargs):
    """Build an HRNet and optionally load a local pretrained checkpoint.

    Args:
        arch (str): architecture tag (kept for interface parity; unused here).
        channels (int): base width ``c`` of the highest-resolution branch.
        num_blocks (sequence of int): StageModule repeats for stages 2-4.
        pretrained (bool): load weights from the module-level ``CKPT_PATH``.
        progress (bool): kept for interface parity with torchvision builders.
    """
    model = HRNet(channels, num_blocks, **kwargs)
    if pretrained:
        print("Loading pretrained backbone HRNetV2 model .....")
        # CKPT_PATH is only assigned in the __main__ guard of this file; when the
        # module is imported, pretrained=True would otherwise die with a bare
        # NameError — fail with an actionable message instead.
        if "CKPT_PATH" not in globals():
            raise RuntimeError(
                "CKPT_PATH is not defined; cannot load pretrained HRNetV2 weights. "
                "Set CKPT_PATH or call with pretrained=False.")
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
        # load_state_dict then places the tensors wherever the model lives.
        checkpoint = torch.load(CKPT_PATH, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    return model
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def hrnetv2_48(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
    """HRNetV2-W48: HRNet with base branch width c=48."""
    return _hrnet('hrnetv2_48', 48, number_blocks, pretrained, progress, **kwargs)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def hrnetv2_32(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
    """HRNetV2-W32: HRNet with base branch width c=32."""
    return _hrnet('hrnetv2_32', 32, number_blocks, pretrained, progress, **kwargs)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if __name__ == '__main__':

    # Resolve the expected local checkpoint path. The original wrapped this in a
    # bare `except:` — but os.path.join/print never raise for a missing file, so
    # the fallback message could never fire; check existence explicitly instead.
    CKPT_PATH = os.path.join(os.path.abspath("."), '../../checkpoints/hrnetv2_32_model_best_epoch96.pth')
    print("--- Running file as MAIN ---")
    if os.path.isfile(CKPT_PATH):
        print(f"Backbone HRNET Pretrained weights as __main__ at: {CKPT_PATH}")
    else:
        print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")

    # Models
    model = hrnetv2_32(pretrained=True)
    #model = hrnetv2_48(pretrained=False)

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model.to(device)

    # Smoke test: one forward pass on a dummy batch.
    in_ = torch.ones(1, 3, 768, 768).to(device)
    y = model(in_)
    print(y.shape)

    # Calculate total number of parameters:
    # pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # print(pytorch_total_params)
|
networks/backbone/mobilenetv2.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
try: # for torchvision<0.4
|
| 3 |
+
from torchvision.models.utils import load_state_dict_from_url
|
| 4 |
+
except: # for torchvision>=0.4
|
| 5 |
+
from torch.hub import load_state_dict_from_url
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
__all__ = ['MobileNetV2', 'mobilenet_v2']
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
model_urls = {
|
| 12 |
+
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _make_divisible(v, divisor, min_value=None):
|
| 17 |
+
"""
|
| 18 |
+
This function is taken from the original tf repo.
|
| 19 |
+
It ensures that all layers have a channel number that is divisible by 8
|
| 20 |
+
It can be seen here:
|
| 21 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
| 22 |
+
:param v:
|
| 23 |
+
:param divisor:
|
| 24 |
+
:param min_value:
|
| 25 |
+
:return:
|
| 26 |
+
"""
|
| 27 |
+
if min_value is None:
|
| 28 |
+
min_value = divisor
|
| 29 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
| 30 |
+
# Make sure that round down does not go down by more than 10%.
|
| 31 |
+
if new_v < 0.9 * v:
|
| 32 |
+
new_v += divisor
|
| 33 |
+
return new_v
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ConvBNReLU(nn.Sequential):
    """Conv2d (bias-free, zero padding) -> BatchNorm2d -> ReLU6.

    Padding is hard-coded to 0 (the usual ``(kernel_size - 1) // 2`` computation
    is deliberately disabled); callers that need 'SAME' behaviour pad the input
    explicitly beforehand.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0,
                         dilation=dilation, groups=groups, bias=False)
        super(ConvBNReLU, self).__init__(
            conv,
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True),
        )
|
| 44 |
+
|
| 45 |
+
def fixed_padding(kernel_size, dilation):
    """Return ``(left, right, top, bottom)`` padding for ``F.pad``.

    Computes TensorFlow-'SAME'-style padding for a convolution with the given
    kernel size and dilation (the effective kernel spans
    ``(kernel_size - 1) * dilation + 1`` pixels).
    """
    effective_kernel = (kernel_size - 1) * dilation + 1
    total = effective_kernel - 1
    before = total // 2
    after = total - before
    return (before, after, before, after)
|
| 51 |
+
|
| 52 |
+
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block: pointwise expand -> depthwise -> linear project.

    The input is padded explicitly in ``forward`` (TF 'SAME' style for the 3x3
    depthwise conv) because ``ConvBNReLU`` here uses zero padding. The
    ``self.conv`` Sequential layout matches the official pretrained weights —
    do not reorder it.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio):
        # inp/oup: input/output channels; stride in {1, 2}; dilation applies to
        # the depthwise conv; expand_ratio scales the hidden width.
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        # Residual shortcut only when spatial size and channel count are preserved.
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: 1x1 pointwise expansion (skipped when there is nothing to expand)
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))

        layers.extend([
            # dw: depthwise 3x3 conv (groups == channels)
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
            # pw-linear: 1x1 projection with no activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

        # Padding sized for the 3x3 depthwise conv; applied once to the block input.
        self.input_padding = fixed_padding( 3, dilation )

    def forward(self, x):
        # Pad first, then run the whole expand/dw/project stack on the padded tensor.
        x_pad = F.pad(x, self.input_padding)
        if self.use_res_connect:
            return x + self.conv(x_pad)
        else:
            return self.conv(x_pad)
|
| 83 |
+
|
| 84 |
+
class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        MobileNet V2 main class (DeepLab variant: strides are replaced by
        dilation once ``output_stride`` is reached).
        Args:
            num_classes (int): Number of classes
            output_stride (int): target overall stride; deeper blocks switch to
                dilated convs instead of striding once this is reached
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        self.output_stride = output_stride
        current_stride = 1
        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s  (expand ratio, channels, repeats, stride)
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        current_stride *= 2
        dilation=1
        previous_dilation = 1

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            previous_dilation = dilation
            if current_stride == output_stride:
                # Target stride reached: stop striding and dilate instead.
                stride = 1
                dilation *= s
            else:
                stride = s
                current_stride *= s
            # NOTE(review): this overwrites the _make_divisible result above, so
            # channel counts are NOT rounded to round_nearest here — confirm
            # against the pretrained checkpoint before changing.
            output_channel = int(c * width_mult)

            for i in range(n):
                if i==0:
                    # First block of the group carries the stride / previous dilation.
                    features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
                else:
                    features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Features -> global average pool -> linear classifier."""
        x = self.features(x)
        x = x.mean([2, 3])  # global average pooling over H and W
        x = self.classifier(x)
        return x
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if not pretrained:
        return model
    state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], progress=progress)
    model.load_state_dict(state_dict)
    return model
|
networks/backbone/resnet.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
try: # for torchvision<0.4
|
| 4 |
+
from torchvision.models.utils import load_state_dict_from_url
|
| 5 |
+
except: # for torchvision>=0.4
|
| 6 |
+
from torch.hub import load_state_dict_from_url
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
| 10 |
+
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
| 11 |
+
'wide_resnet50_2', 'wide_resnet101_2']
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
model_urls = {
|
| 15 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
| 16 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
| 17 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
| 18 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
| 19 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
| 20 |
+
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
| 21 |
+
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
| 22 |
+
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
| 23 |
+
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution, bias-free, padded by the dilation so spatial size is
    preserved at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution, bias-free (used for channel projection / downsampling)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convs plus an identity/projection shortcut.

    Attribute names (conv1/bn1/conv2/bn2/downsample) match torchvision so
    pretrained state_dicts load unchanged.
    """
    # Output channels == planes * expansion; BasicBlock does not expand.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when shape changes (stride or channel mismatch).
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4), with shortcut.

    ``width`` scales with ``base_width``/``groups`` to also cover ResNeXt and
    Wide-ResNet variants. Attribute names match torchvision checkpoints.
    """
    # Output channels == planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when shape changes (stride or channel mismatch).
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class ResNet(nn.Module):
    """Torchvision-style ResNet.

    ``replace_stride_with_dilation`` lets layers 2-4 trade their stride-2
    downsampling for dilated convolutions (used by segmentation backbones).
    Attribute names match torchvision so pretrained state_dicts load unchanged.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        # block: BasicBlock or Bottleneck; layers: blocks per stage, e.g. [3, 4, 6, 3].
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64   # running input-channel count for _make_layer
        self.dilation = 1    # running dilation, grown when strides are replaced
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack ``blocks`` residual blocks; only the first may stride/project.

        When ``dilate`` is True the stride is folded into ``self.dilation``
        instead, keeping spatial resolution.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Shortcut needs a 1x1 projection to match shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> 4 stages -> global average pool -> fully-connected logits."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate a ResNet and optionally load the ImageNet weights for *arch*."""
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model
    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(state_dict)
    return model
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d from `"Aggregated Residual Transformation for Deep
    Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=4)
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""Build a ResNeXt-101 32x8d network.

    Described in `"Aggregated Residual Transformation for Deep Neural Networks"
    <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True
        progress (bool): show a download progress bar on stderr when True
    """
    # Cardinality 32 with 8 channels per group defines the "32x8d" variant.
    kwargs.update(groups=32, width_per_group=8)
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Build a Wide ResNet-50-2 network.

    Described in `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
    Identical to ResNet-50 except that every bottleneck uses twice as many
    internal channels; the outer 1x1 convolutions keep their widths (e.g. the
    last block is 2048-1024-2048 instead of 2048-512-2048).

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True
        progress (bool): show a download progress bar on stderr when True
    """
    kwargs.update(width_per_group=64 * 2)  # doubled bottleneck width
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Build a Wide ResNet-101-2 network.

    Described in `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
    Identical to ResNet-101 except that every bottleneck uses twice as many
    internal channels; the outer 1x1 convolutions keep their widths (e.g. the
    last block is 2048-1024-2048 instead of 2048-512-2048).

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True
        progress (bool): show a download progress bar on stderr when True
    """
    kwargs.update(width_per_group=64 * 2)  # doubled bottleneck width
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
|
networks/modeling.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from networks.deeplab.utils import IntermediateLayerGetter
|
| 2 |
+
from networks.deeplab._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
|
| 3 |
+
from networks.deeplab.backbone import resnet, mobilenetv2, hrnetv2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _segm_hrnet(name, backbone_name, num_classes, pretrained_backbone):
    """Build a DeepLab segmenter on an HRNetV2 backbone.

    Args:
        name (str): architecture name, 'deeplabv3' or 'deeplabv3plus'.
        backbone_name (str): HRNetV2 variant, e.g. 'hrnetv2_48'; the numeric
            suffix is the highest-resolution stream's channel count.
        num_classes (int): number of output classes.
        pretrained_backbone (bool): whether to load pretrained backbone weights.

    Returns:
        DeepLabV3: the assembled model.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    backbone = hrnetv2.__dict__[backbone_name](pretrained_backbone)
    # HRNetV2 config:
    # the final output channels is dependent on highest resolution channel config (c).
    # output of backbone will be the inplanes to assp:
    hrnet_channels = int(backbone_name.split('_')[-1])
    # The four streams carry c, 2c, 4c, 8c channels and are concatenated.
    inplanes = sum([hrnet_channels * 2 ** i for i in range(4)])
    low_level_planes = 256  # all hrnet version channel output from bottleneck is the same
    aspp_dilate = [12, 24, 36]  # If follow paper trend, can put [24, 48, 72].

    if name == 'deeplabv3plus':
        return_layers = {'stage4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    elif name == 'deeplabv3':
        return_layers = {'stage4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
    else:
        # Fail fast with a clear error instead of hitting a NameError on the
        # undefined `return_layers` below.
        raise NotImplementedError(f"Unknown architecture name: {name}")

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers, hrnet_flag=True)
    model = DeepLabV3(backbone, classifier)
    return model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Build a DeepLab segmenter on a ResNet backbone.

    Args:
        name (str): architecture name, 'deeplabv3' or 'deeplabv3plus'.
        backbone_name (str): a builder name in ``resnet`` (e.g. 'resnet50').
        num_classes (int): number of output classes.
        output_stride (int): overall backbone stride, 8 or 16; stride 8 keeps
            more resolution by dilating the last two stages.
        pretrained_backbone (bool): whether to load pretrained backbone weights.

    Returns:
        DeepLabV3: the assembled model.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    if output_stride == 8:
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    # Bottleneck ResNets end layer4 with 2048 channels; layer1 ends with 256.
    inplanes = 2048
    low_level_planes = 256

    if name == 'deeplabv3plus':
        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    elif name == 'deeplabv3':
        return_layers = {'layer4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
    else:
        # Fail fast with a clear error instead of hitting a NameError on the
        # undefined `return_layers` below.
        raise NotImplementedError(f"Unknown architecture name: {name}")
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    model = DeepLabV3(backbone, classifier)
    return model
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Build a DeepLab segmenter on a MobileNetV2 backbone.

    Args:
        name (str): architecture name, 'deeplabv3' or 'deeplabv3plus'.
        backbone_name (str): unused here; kept for a uniform builder signature.
        num_classes (int): number of output classes.
        output_stride (int): overall backbone stride, 8 or 16.
        pretrained_backbone (bool): whether to load pretrained backbone weights.

    Returns:
        DeepLabV3: the assembled model.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    if output_stride == 8:
        aspp_dilate = [12, 24, 36]
    else:
        aspp_dilate = [6, 12, 18]

    backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)

    # rename layers: split `features` into the shallow (low-level) and deep
    # (high-level) parts that IntermediateLayerGetter will tap.
    backbone.low_level_features = backbone.features[0:4]
    backbone.high_level_features = backbone.features[4:-1]
    backbone.features = None
    backbone.classifier = None

    inplanes = 320       # channels out of high_level_features
    low_level_planes = 24  # channels out of low_level_features

    if name == 'deeplabv3plus':
        return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    elif name == 'deeplabv3':
        return_layers = {'high_level_features': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
    else:
        # Fail fast with a clear error instead of hitting a NameError on the
        # undefined `return_layers` below.
        raise NotImplementedError(f"Unknown architecture name: {name}")
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    model = DeepLabV3(backbone, classifier)
    return model
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
    """Dispatch model construction to the builder matching the backbone family.

    Args:
        arch_type (str): 'deeplabv3' or 'deeplabv3plus'.
        backbone (str): backbone identifier ('mobilenetv2', 'resnet*', 'hrnetv2_*').
        num_classes (int): number of output classes.
        output_stride (int): backbone output stride (ignored for HRNetV2).
        pretrained_backbone (bool): whether to load pretrained backbone weights.

    Raises:
        NotImplementedError: for an unrecognised backbone identifier.
    """
    if backbone == 'mobilenetv2':
        return _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride,
                               pretrained_backbone=pretrained_backbone)
    if backbone.startswith('resnet'):
        return _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride,
                            pretrained_backbone=pretrained_backbone)
    if backbone.startswith('hrnetv2'):
        # HRNetV2 builders do not take an output_stride argument.
        return _segm_hrnet(arch_type, backbone, num_classes, pretrained_backbone=pretrained_backbone)
    raise NotImplementedError
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Deeplab v3
|
| 99 |
+
def deeplabv3_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False):  # no pretrained backbone yet
    """Construct a DeepLabV3 segmenter on an HRNetV2-48 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    # Bug fix: the previous call passed `output_stride` into `_load_model`'s
    # num_classes slot and `num_classes` into its output_stride slot.
    return _load_model('deeplabv3', 'hrnetv2_48', num_classes, output_stride,
                       pretrained_backbone=pretrained_backbone)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def deeplabv3_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True):
    """Construct a DeepLabV3 segmenter on an HRNetV2-32 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    # Bug fix: the previous call passed `output_stride` into `_load_model`'s
    # num_classes slot and `num_classes` into its output_stride slot.
    return _load_model('deeplabv3', 'hrnetv2_32', num_classes, output_stride,
                       pretrained_backbone=pretrained_backbone)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Construct a DeepLabV3 segmenter on a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model(
        'deeplabv3', 'resnet50', num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Construct a DeepLabV3 segmenter on a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model(
        'deeplabv3', 'resnet101', num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
    """Construct a DeepLabV3 segmenter on a MobileNetV2 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.

    Note: `**kwargs` is accepted for interface compatibility but ignored.
    """
    return _load_model(
        'deeplabv3', 'mobilenetv2', num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# Deeplab v3+
|
| 141 |
+
def deeplabv3plus_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False):  # no pretrained backbone yet
    """Construct a DeepLabV3+ segmenter on an HRNetV2-48 backbone."""
    return _load_model(
        'deeplabv3plus', 'hrnetv2_48', num_classes, output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def deeplabv3plus_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True):
    """Construct a DeepLabV3+ segmenter on an HRNetV2-32 backbone."""
    return _load_model(
        'deeplabv3plus', 'hrnetv2_32', num_classes, output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Construct a DeepLabV3+ segmenter on a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model(
        'deeplabv3plus', 'resnet50', num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Construct a DeepLabV3+ segmenter on a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model(
        'deeplabv3plus', 'resnet101', num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Construct a DeepLabV3+ segmenter on a MobileNetV2 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model(
        'deeplabv3plus', 'mobilenetv2', num_classes,
        output_stride=output_stride,
        pretrained_backbone=pretrained_backbone,
    )
|
networks/utils.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class _SimpleSegmentationModel(nn.Module):
|
| 9 |
+
def __init__(self, backbone, classifier):
|
| 10 |
+
super(_SimpleSegmentationModel, self).__init__()
|
| 11 |
+
self.backbone = backbone
|
| 12 |
+
self.classifier = classifier
|
| 13 |
+
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
input_shape = x.shape[-2:]
|
| 16 |
+
features = self.backbone(x)
|
| 17 |
+
x = self.classifier(features)
|
| 18 |
+
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
|
| 19 |
+
return x
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.
    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        hrnet_flag (bool): when True, the forward pass handles HRNetV2's
            multi-stream `transition*` modules and concatenates the four
            upsampled streams at `stage4`.

    Examples::
        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    def __init__(self, model, return_layers, hrnet_flag=False):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        # Plain (non-module) attribute; nn.Module.__setattr__ falls through to
        # object.__setattr__ for such values, so setting it before
        # super().__init__() is safe here.
        self.hrnet_flag = hrnet_flag

        orig_return_layers = return_layers
        # Work on a copy so the caller's dict is not mutated by the deletions below.
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                # All requested layers collected; drop everything after the last one.
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        # Maps the user-chosen output name -> activation tensor.
        # NOTE: in HRNet mode, `x` becomes a *list* of stream tensors after
        # 'transition1'; the branches below keep that list up to date.
        out = OrderedDict()
        for name, module in self.named_children():
            if self.hrnet_flag and name.startswith('transition'): # if using hrnet, you need to take care of transition
                if name == 'transition1': # in transition1, you need to split the module to two streams first
                    x = [trans(x) for trans in module]
                else: # all other transition is just an extra one stream split
                    x.append(module(x[-1]))
            else: # other models (ex:resnet,mobilenet) are convolutions in series.
                x = module(x)

            if name in self.return_layers:
                out_name = self.return_layers[name]
                if name == 'stage4' and self.hrnet_flag: # In HRNetV2, we upsample and concat all outputs streams together
                    output_h, output_w = x[0].size(2), x[0].size(3) # Upsample to size of highest resolution stream
                    x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x = torch.cat([x[0], x1, x2, x3], dim=1)
                    out[out_name] = x
                else:
                    out[out_name] = x
        return out
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
opencv-contrib-python==4.5.5.62
|
| 2 |
+
torch==1.11.0
|
| 3 |
+
torchvision==0.12.0
|
| 4 |
+
timm==0.4.12
|
| 5 |
+
scipy==1.6.2
|
utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Tuple, Union
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from networks import deeplabv3plus_resnet50
|
| 5 |
+
from networks import convert_to_separable_conv, set_bn_momentum
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_network() -> torch.nn.Module:
    """Build the NamedMask segmenter: a DeepLabV3+ (ResNet-50, 21 VOC classes)
    with released weights downloaded from the project page."""
    network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
    state_dict = torch.hub.load_state_dict_from_url(
        "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt"
    )
    # NOTE(review): the checkpoint is loaded into `network.backbone` with
    # strict=True — presumably the released file contains backbone keys only;
    # verify against the checkpoint's actual key layout.
    network.backbone.load_state_dict(state_dict, strict=True)
    # NOTE(review): the classifier is converted to separable convolutions
    # *after* the weight load — confirm the checkpoint carries no classifier
    # weights that would require the converted layout instead.
    convert_to_separable_conv(network.classifier)
    # Low BN momentum (0.01) on the backbone.
    set_bn_momentum(network.backbone, momentum=0.01)
    return network
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def colourise_mask(
    mask: np.ndarray,
) -> np.ndarray:
    """Map a 2D label mask to an RGB image using the Pascal VOC 2012 palette.

    Args:
        mask: (H, W) integer array of label ids in {0..20, 255} (255 = ignore).

    Returns:
        (H, W, 3) uint8 RGB image.

    Raises:
        ValueError: if `mask` is not 2-dimensional.
        KeyError: if a label id has no palette entry.
    """
    # Explicit raise instead of `assert` so the check survives `python -O`
    # (the original asserted with a ValueError instance as the message).
    if len(mask.shape) != 2:
        raise ValueError(mask.shape)
    h, w = mask.shape
    grid = np.zeros((h, w, 3), dtype=np.uint8)

    unique_labels = set(mask.flatten())

    voc2012_palette = {
        0: [0, 0, 0],
        1: [128, 0, 0],
        2: [0, 128, 0],
        3: [128, 128, 0],
        4: [0, 0, 128],
        5: [128, 0, 128],
        6: [0, 128, 128],
        7: [128, 128, 128],
        8: [64, 0, 0],
        9: [192, 0, 0],
        10: [64, 128, 0],
        11: [192, 128, 0],
        12: [64, 0, 128],
        13: [192, 0, 128],
        14: [64, 128, 128],
        15: [192, 128, 128],
        16: [0, 64, 0],
        17: [128, 64, 0],
        18: [0, 192, 0],
        19: [128, 192, 0],
        20: [0, 64, 128],
        255: [255, 255, 255]
    }

    for l in unique_labels:
        # Bug fix: the lookup previously ran once *outside* the try block (a
        # duplicated line), and the handler caught IndexError although a dict
        # lookup raises KeyError — so the informative error path was dead.
        try:
            grid[mask == l] = np.array(voc2012_palette[l])
        except KeyError:
            raise KeyError(f"No colour is found for a label id: {l}")
    return grid
|
voc_val_n500_cp2_ex.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# base directories
|
| 2 |
+
category_to_p_images_fp: "/home/cs-shin1/datasets/ImageNet2012/voc2012_category_to_p_images_n500.json"
|
| 3 |
+
dir_ckpt: "/home/cs-shin1/namedmask/ckpt"
|
| 4 |
+
dir_train_dataset: "/home/cs-shin1/datasets/ImageNet2012"
|
| 5 |
+
dir_val_dataset: "/home/cs-shin1/datasets/VOCdevkit/VOC2012"
|
| 6 |
+
|
| 7 |
+
# augmentations
|
| 8 |
+
max_n_masks: 2
|
| 9 |
+
scale_range: [ 0.1, 1.0 ]
|
| 10 |
+
|
| 11 |
+
use_expert_pseudo_masks: true
|
| 12 |
+
category_agnostic: false
|
| 13 |
+
|
| 14 |
+
n_categories: 21
|
| 15 |
+
categories: [
|
| 16 |
+
"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "dining table",
|
| 17 |
+
"dog", "horse", "motorbike", "person", "potted plant", "sheep", "sofa", "train", "tv/monitor"
|
| 18 |
+
]
|
| 19 |
+
n_images: 500
|
| 20 |
+
|
| 21 |
+
# dataset
|
| 22 |
+
dataset_name: "voc2012"
|
| 23 |
+
split: "val"
|
| 24 |
+
train_image_size: 384
|
| 25 |
+
|
| 26 |
+
# dataloader:
|
| 27 |
+
train_dataloader_kwargs:
|
| 28 |
+
batch_size: 16
|
| 29 |
+
num_workers: 16
|
| 30 |
+
pin_memory: true
|
| 31 |
+
shuffle: true
|
| 32 |
+
|
| 33 |
+
val_dataloader_kwargs:
|
| 34 |
+
batch_size: 1
|
| 35 |
+
num_workers: 4
|
| 36 |
+
pin_memory: true
|
| 37 |
+
|
| 38 |
+
# Segmenter configuration
|
| 39 |
+
# ["deeplabv3plus_resnet101", "deeplabv3plus_resnet50", "deeplabv3plus_mobilenet"]
|
| 40 |
+
segmenter_name: "deeplabv3plus_resnet50"
|
| 41 |
+
|
| 42 |
+
# optimiser
|
| 43 |
+
lr: 0.0005
|
| 44 |
+
momentum: 0.9
|
| 45 |
+
weight_decay: 0.0002
|
| 46 |
+
betas: [0.9, 0.999]
|
| 47 |
+
n_iters: 20000
|
| 48 |
+
|
| 49 |
+
iter_eval: 1000
|
| 50 |
+
iter_log: 100
|