jizhengjia and ybelkada committed
Commit 6324b4f · 0 Parent(s)

Duplicate from ybelkada/FocusOnDepth

Co-authored-by: Younes Belkada <ybelkada@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,28 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: FocusOnDepth
+ emoji: 🐨
+ colorFrom: pink
+ colorTo: pink
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ duplicated_from: ybelkada/FocusOnDepth
+ ---
app.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import gradio as gr
+ import numpy as np
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from torchvision import transforms
+
+ from transformers import AutoConfig, AutoModel
+
+ from focusondepth.model_config import FocusOnDepthConfig
+ from focusondepth.model_definition import FocusOnDepth
+
+ AutoConfig.register("focusondepth", FocusOnDepthConfig)
+ AutoModel.register(FocusOnDepthConfig, FocusOnDepth)
+
+ transform = transforms.Compose([
+     transforms.Resize((384, 384)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
+
+ @torch.no_grad()
+ def inference(input_image):
+     global model, transform
+
+     model.eval()
+     input_image = Image.fromarray(input_image)
+     original_size = input_image.size
+     tensor_image = transform(input_image)
+
+     depth, segmentation = model(tensor_image.unsqueeze(0))
+     depth = 1 - depth  # invert the normalized depth map for display
+
+     depth = transforms.ToPILImage()(depth[0, :])
+     segmentation = transforms.ToPILImage()(segmentation.argmax(dim=1).float())
+
+     return [depth.resize(original_size, resample=Image.BICUBIC), segmentation.resize(original_size, resample=Image.NEAREST)]
+
+ description = """
+ <center>
+ Can a single model predict both segmentation and depth estimation? At least when the segmentation is constrained to a single class, the answer is yes! <br>
+ In this project, we use a DPT model to predict both the depth map of an image and the segmentation mask for the class "human". This model could potentially be used for an autofocus application, where you would need the segmentation mask of the humans in the picture as well as the depth estimation of the scene.<br>
+ Credits also to <a href='https://github.com/antocad' target='_blank'>@antocad</a>!
+ </center>
+ """
+ title = """
+ FocusOnDepth - A single DPT encoder for Dense Prediction Tasks
+ """
+ css = """
+ """
+ article = """
+ <center>
+ Example image taken from <a href="https://www.flickr.com/photos/17423713@N03/29129350066">here</a>. The image is free to share and use. <br>
+ </center>
+ <div style='text-align: center;'><a href='https://github.com/isl-org/DPT' target='_blank'>Original Paper</a> | <a href='https://github.com/antocad/FocusOnDepth' target='_blank'>Extended Version</a></div>
+ """
+
+ iface = gr.Interface(
+     fn=inference,
+     inputs=gr.inputs.Image(label="Input Image"),
+     outputs=[
+         gr.outputs.Image(label="Depth Map:"),
+         gr.outputs.Image(label="Segmentation Map:"),
+     ],
+     examples=['example_image.jpg'],
+     description=description,
+     title=title,
+     css=css,
+     article=article
+ )
+ iface.launch(enable_queue=True, cache_examples=True)
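
For reference, a minimal sketch of running the model without the Gradio UI; it mirrors the setup in app.py above and assumes the repository root is the working directory and the ybelkada/focusondepth checkpoint is downloadable:

    import torch
    from PIL import Image
    from torchvision import transforms
    from transformers import AutoConfig, AutoModel

    from focusondepth.model_config import FocusOnDepthConfig
    from focusondepth.model_definition import FocusOnDepth

    # Same local registration app.py performs before loading the checkpoint.
    AutoConfig.register("focusondepth", FocusOnDepthConfig)
    AutoModel.register(FocusOnDepthConfig, FocusOnDepth)

    model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    image = Image.open('example_image.jpg').convert('RGB')  # bundled example image
    with torch.no_grad():
        depth, segmentation = model(transform(image).unsqueeze(0))

    # depth: (1, 1, 384, 384) with values in [0, 1]; segmentation: (1, 2, 384, 384) logits
    print(depth.shape, segmentation.shape)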
example_image.jpg ADDED
focusondepth/__init__.py ADDED
File without changes
focusondepth/fusion.py ADDED
@@ -0,0 +1,41 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ class ResidualConvUnit(nn.Module):
+     def __init__(self, features):
+         super().__init__()
+
+         self.conv1 = nn.Conv2d(
+             features, features, kernel_size=3, stride=1, padding=1, bias=True)
+         self.conv2 = nn.Conv2d(
+             features, features, kernel_size=3, stride=1, padding=1, bias=True)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         """Forward pass.
+         Args:
+             x (tensor): input
+         Returns:
+             tensor: output
+         """
+         out = self.relu(x)
+         out = self.conv1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         return out + x
+
+ class Fusion(nn.Module):
+     def __init__(self, resample_dim):
+         super(Fusion, self).__init__()
+         self.res_conv1 = ResidualConvUnit(resample_dim)
+         self.res_conv2 = ResidualConvUnit(resample_dim)
+
+     def forward(self, x, previous_stage=None):
+         if previous_stage is None:
+             previous_stage = torch.zeros_like(x)
+         output_stage1 = self.res_conv1(x)
+         output_stage1 += previous_stage
+         output_stage2 = self.res_conv2(output_stage1)
+         output_stage2 = nn.functional.interpolate(output_stage2, scale_factor=2, mode="bilinear", align_corners=True)
+         return output_stage2
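
As a quick sanity check of the shapes, a minimal sketch (illustrative, not part of the commit): each Fusion call merges the current stage with the previous one and doubles the spatial resolution. A single Fusion module is reused here purely for illustration; the model builds one per stage.

    import torch
    from focusondepth.fusion import Fusion

    fusion = Fusion(resample_dim=256)
    coarse = torch.randn(1, 256, 12, 12)  # e.g. the s=32 reassemble output
    finer = torch.randn(1, 256, 24, 24)   # e.g. the s=16 reassemble output

    stage1 = fusion(coarse)          # no previous stage: (1, 256, 24, 24)
    stage2 = fusion(finer, stage1)   # fused and upsampled: (1, 256, 48, 48)
    print(stage1.shape, stage2.shape)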
focusondepth/head.py ADDED
@@ -0,0 +1,50 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ class Interpolate(nn.Module):
+     def __init__(self, scale_factor, mode, align_corners=False):
+         super(Interpolate, self).__init__()
+         self.interp = nn.functional.interpolate
+         self.scale_factor = scale_factor
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         x = self.interp(
+             x,
+             scale_factor=self.scale_factor,
+             mode=self.mode,
+             align_corners=self.align_corners)
+         return x
+
+ class HeadDepth(nn.Module):
+     def __init__(self, features):
+         super(HeadDepth, self).__init__()
+         self.head = nn.Sequential(
+             nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+             Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+             nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+             # nn.ReLU()
+             nn.Sigmoid()
+         )
+     def forward(self, x):
+         x = self.head(x)
+         # x = (x - x.min())/(x.max()-x.min() + 1e-15)
+         return x
+
+ class HeadSeg(nn.Module):
+     def __init__(self, features, nclasses=2):
+         super(HeadSeg, self).__init__()
+         self.head = nn.Sequential(
+             nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+             Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+             nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, nclasses, kernel_size=1, stride=1, padding=0)
+         )
+     def forward(self, x):
+         x = self.head(x)
+         return x
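
A short shape sketch for the two heads (illustrative, not part of the commit), assuming the 256-channel fused feature map the full model produces at 192x192 for a 384x384 input:

    import torch
    from focusondepth.head import HeadDepth, HeadSeg

    features = torch.randn(1, 256, 192, 192)  # final fused stage for a 384x384 input
    depth = HeadDepth(256)(features)          # (1, 1, 384, 384), sigmoid-normalized
    seg = HeadSeg(256, nclasses=2)(features)  # (1, 2, 384, 384), raw logits
    print(depth.shape, seg.shape)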
focusondepth/model_config.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class FocusOnDepthConfig(PretrainedConfig):
+     model_type = "focusondepth"
+
+     def __init__(
+         self,
+         image_size=(3, 384, 384),
+         patch_size=16,
+         emb_dim=768,
+         resample_dim=256,
+         read='projection',
+         num_layers_encoder=24,
+         hooks=[2, 5, 8, 11],
+         reassemble_s=[4, 8, 16, 32],
+         transformer_dropout=0,
+         nclasses=2,
+         type_="full",
+         model_timm="vit_base_patch16_384",
+         **kwargs,
+     ):
+         if type_ not in ["full", "depth", "segmentation"]:
+             raise ValueError(f"`type_` must be 'full', 'depth' or 'segmentation', got {type_}.")
+         if read not in ["ignore", "add", "projection"]:
+             raise ValueError(f"`read` must be 'ignore', 'add' or 'projection', got {read}.")
+         if image_size[0] != 3 or image_size[1] != 384 or image_size[2] != 384:
+             raise ValueError(f"`image_size` must be (3, 384, 384), got {image_size}.")
+         if patch_size != 16:
+             raise ValueError(f"`patch_size` must be 16, got {patch_size}.")
+
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.emb_dim = emb_dim
+         self.resample_dim = resample_dim
+         self.read = read
+         self.num_layers_encoder = num_layers_encoder
+         self.hooks = hooks
+         self.reassemble_s = reassemble_s
+         self.transformer_dropout = transformer_dropout
+         self.nclasses = nclasses
+         self.type_ = type_
+         self.model_timm = model_timm
+         super().__init__(**kwargs)
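
A minimal usage sketch for the config (illustrative, not part of the commit); the defaults correspond to the released checkpoint, and the checks above reject unsupported values:

    from focusondepth.model_config import FocusOnDepthConfig

    config = FocusOnDepthConfig()       # defaults match the released checkpoint
    print(config.hooks, config.type_)   # [2, 5, 8, 11] full

    try:
        FocusOnDepthConfig(patch_size=32)  # rejected by the validation above
    except ValueError as e:
        print(e)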
focusondepth/model_definition.py ADDED
@@ -0,0 +1,68 @@
+ from transformers import PreTrainedModel
+ import timm
+ import torch.nn as nn
+ import numpy as np
+
+ from .model_config import FocusOnDepthConfig
+ from .reassemble import Reassemble
+ from .fusion import Fusion
+ from .head import HeadDepth, HeadSeg
+
+
+ class FocusOnDepth(PreTrainedModel):
+     config_class = FocusOnDepthConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.transformer_encoders = timm.create_model(config.model_timm, pretrained=True)
+         self.type_ = config.type_
+
+         # Register hooks
+         self.activation = {}
+         self.hooks = config.hooks
+         self._get_layers_from_hooks(self.hooks)
+
+         # Reassemble and Fusion blocks
+         self.reassembles = []
+         self.fusions = []
+         for s in config.reassemble_s:
+             self.reassembles.append(Reassemble(config.image_size, config.read, config.patch_size, s, config.emb_dim, config.resample_dim))
+             self.fusions.append(Fusion(config.resample_dim))
+         self.reassembles = nn.ModuleList(self.reassembles)
+         self.fusions = nn.ModuleList(self.fusions)
+
+         # Heads
+         if self.type_ == "full":
+             self.head_depth = HeadDepth(config.resample_dim)
+             self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
+         elif self.type_ == "depth":
+             self.head_depth = HeadDepth(config.resample_dim)
+             self.head_segmentation = None
+         else:
+             self.head_depth = None
+             self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
+
+     def forward(self, img):
+         _ = self.transformer_encoders(img)
+         previous_stage = None
+         for i in np.arange(len(self.fusions)-1, -1, -1):
+             hook_to_take = 't'+str(self.hooks[i])
+             activation_result = self.activation[hook_to_take]
+             reassemble_result = self.reassembles[i](activation_result)
+             fusion_result = self.fusions[i](reassemble_result, previous_stage)
+             previous_stage = fusion_result
+         out_depth = None
+         out_segmentation = None
+         if self.head_depth is not None:
+             out_depth = self.head_depth(previous_stage)
+         if self.head_segmentation is not None:
+             out_segmentation = self.head_segmentation(previous_stage)
+         return out_depth, out_segmentation
+
+     def _get_layers_from_hooks(self, hooks):
+         def get_activation(name):
+             def hook(model, input, output):
+                 self.activation[name] = output
+             return hook
+         for h in hooks:
+             self.transformer_encoders.blocks[h].register_forward_hook(get_activation('t'+str(h)))
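
The hook machinery above is standard PyTorch; a self-contained sketch of the same pattern (illustrative, not part of the commit):

    import torch
    import torch.nn as nn

    # Capture an intermediate activation with a forward hook, as
    # _get_layers_from_hooks does for the selected ViT blocks.
    activation = {}

    def save_to(name):
        def hook(module, inputs, output):
            activation[name] = output
        return hook

    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    net[0].register_forward_hook(save_to('t0'))

    _ = net(torch.randn(2, 8))     # the hook fires during the forward pass
    print(activation['t0'].shape)  # torch.Size([2, 16])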
focusondepth/reassemble.py ADDED
@@ -0,0 +1,115 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+ from einops.layers.torch import Rearrange
+
+ class Read_ignore(nn.Module):
+     def __init__(self, start_index=1):
+         super(Read_ignore, self).__init__()
+         self.start_index = start_index
+
+     def forward(self, x):
+         return x[:, self.start_index:]
+
+
+ class Read_add(nn.Module):
+     def __init__(self, start_index=1):
+         super(Read_add, self).__init__()
+         self.start_index = start_index
+
+     def forward(self, x):
+         if self.start_index == 2:
+             readout = (x[:, 0] + x[:, 1]) / 2
+         else:
+             readout = x[:, 0]
+         return x[:, self.start_index:] + readout.unsqueeze(1)
+
+
+ class Read_projection(nn.Module):
+     def __init__(self, in_features, start_index=1):
+         super(Read_projection, self).__init__()
+         self.start_index = start_index
+         self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+     def forward(self, x):
+         readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
+         features = torch.cat((x[:, self.start_index:], readout), -1)
+         return self.project(features)
+
+ class MyConvTranspose2d(nn.Module):
+     def __init__(self, conv, output_size):
+         super(MyConvTranspose2d, self).__init__()
+         self.output_size = output_size
+         self.conv = conv
+
+     def forward(self, x):
+         x = self.conv(x, output_size=self.output_size)
+         return x
+
+ class Resample(nn.Module):
+     def __init__(self, p, s, h, emb_dim, resample_dim):
+         super(Resample, self).__init__()
+         assert (s in [4, 8, 16, 32]), "s must be in [4, 8, 16, 32]"
+         self.conv1 = nn.Conv2d(emb_dim, resample_dim, kernel_size=1, stride=1, padding=0)
+         if s == 4:
+             self.conv2 = nn.ConvTranspose2d(resample_dim,
+                                             resample_dim,
+                                             kernel_size=4,
+                                             stride=4,
+                                             padding=0,
+                                             bias=True,
+                                             dilation=1,
+                                             groups=1)
+         elif s == 8:
+             self.conv2 = nn.ConvTranspose2d(resample_dim,
+                                             resample_dim,
+                                             kernel_size=2,
+                                             stride=2,
+                                             padding=0,
+                                             bias=True,
+                                             dilation=1,
+                                             groups=1)
+         elif s == 16:
+             self.conv2 = nn.Identity()
+         else:
+             self.conv2 = nn.Conv2d(resample_dim, resample_dim, kernel_size=2, stride=2, padding=0, bias=True)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.conv2(x)
+         return x
+
+ class Reassemble(nn.Module):
+     def __init__(self, image_size, read, p, s, emb_dim, resample_dim):
+         """
+         p = patch size
+         s = resample coefficient
+         emb_dim <=> D (in the paper)
+         resample_dim <=> ^D (in the paper)
+         read : {"ignore", "add", "projection"}
+         """
+         super(Reassemble, self).__init__()
+         channels, image_height, image_width = image_size
+
+         # Read
+         self.read = Read_ignore()
+         if read == 'add':
+             self.read = Read_add()
+         elif read == 'projection':
+             self.read = Read_projection(emb_dim)
+
+         # Concatenate after read
+         self.concat = Rearrange('b (h w) c -> b c h w',
+                                 c=emb_dim,
+                                 h=(image_height // p),
+                                 w=(image_width // p))
+
+         # Projection + Resample
+         self.resample = Resample(p, s, image_height, emb_dim, resample_dim)
+
+     def forward(self, x):
+         x = self.read(x)
+         x = self.concat(x)
+         x = self.resample(x)
+         return x
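
A shape sketch for Reassemble (illustrative, not part of the commit): ViT-Base on a 384x384 image emits 1 + 24*24 tokens of dimension 768; Reassemble strips or merges the readout token, refolds the 576 patch tokens into a 24x24 grid, and resamples it according to s:

    import torch
    from focusondepth.reassemble import Reassemble

    tokens = torch.randn(1, 1 + 24 * 24, 768)  # [CLS] + 576 patch tokens

    for s, side in [(4, 96), (8, 48), (16, 24), (32, 12)]:
        block = Reassemble((3, 384, 384), 'projection', 16, s, 768, 256)
        out = block(tokens)
        assert out.shape == (1, 256, side, side)
        print(s, tuple(out.shape))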
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ timm
+ numpy
+ pillow
+ transformers
+ einops