Spaces: FocusOnDepth (runtime error)

Commit 6324b4f
Duplicate from ybelkada/FocusOnDepth
Co-authored-by: Younes Belkada <ybelkada@users.noreply.huggingface.co>

Files changed:
- .gitattributes +28 -0
- README.md +10 -0
- app.py +75 -0
- example_image.jpg +0 -0
- focusondepth/__init__.py +0 -0
- focusondepth/fusion.py +41 -0
- focusondepth/head.py +50 -0
- focusondepth/model_config.py +45 -0
- focusondepth/model_definition.py +68 -0
- focusondepth/reassemble.py +115 -0
- requirements.txt +7 -0
.gitattributes
ADDED
@@ -0,0 +1,28 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,10 @@
+---
+title: FocusOnDepth
+emoji: π¨
+colorFrom: pink
+colorTo: pink
+sdk: gradio
+app_file: app.py
+pinned: false
+duplicated_from: ybelkada/FocusOnDepth
+---
app.py
ADDED
@@ -0,0 +1,74 @@
+import torch
+import gradio as gr
+import numpy as np
+
+from PIL import Image
+from torchvision import transforms
+
+from transformers import AutoConfig, AutoModel
+
+from focusondepth.model_config import FocusOnDepthConfig
+from focusondepth.model_definition import FocusOnDepth
+
+# Register the custom architecture so that AutoModel can resolve it.
+AutoConfig.register("focusondepth", FocusOnDepthConfig)
+AutoModel.register(FocusOnDepthConfig, FocusOnDepth)
+
+transform = transforms.Compose([
+    transforms.Resize((384, 384)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+])
+model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
+
+@torch.no_grad()
+def inference(input_image):
+    model.eval()
+    input_image = Image.fromarray(input_image)
+    original_size = input_image.size
+    tensor_image = transform(input_image)
+
+    depth, segmentation = model(tensor_image.unsqueeze(0))
+    depth = 1 - depth  # invert so that closer objects appear brighter
+
+    depth = transforms.ToPILImage()(depth[0, :])
+    segmentation = transforms.ToPILImage()(segmentation.argmax(dim=1).float())
+
+    return [
+        depth.resize(original_size, resample=Image.BICUBIC),
+        segmentation.resize(original_size, resample=Image.NEAREST),
+    ]
+
+description = """
+<center>
+Can a single model predict both segmentation and depth? If the segmentation is constrained to a single class, the answer is yes! <br>
+In this project, we use a DPT model to predict both the depth map of an image and the segmentation mask for the class "human". Such a model could power an autofocus application, which needs the segmentation mask of the humans in the picture as well as a depth estimate of the scene.<br>
+Credits also to <a href='https://github.com/antocad' target='_blank'>@antocad</a>!
+</center>
+"""
+title = """
+FocusOnDepth - A single DPT encoder for Dense Prediction Tasks
+"""
+css = """
+"""
+article = """
+<center>
+Example image taken from <a href="https://www.flickr.com/photos/17423713@N03/29129350066">here</a>. The image is free to share and use. <br>
+</center>
+<div style='text-align: center;'><a href='https://github.com/isl-org/DPT' target='_blank'>Original Paper</a> | <a href='https://github.com/antocad/FocusOnDepth' target='_blank'>Extended Version</a></div>
+"""
+
+iface = gr.Interface(
+    fn=inference,
+    inputs=gr.inputs.Image(label="Input Image"),
+    outputs=[
+        gr.outputs.Image(label="Depth Map:"),
+        gr.outputs.Image(label="Segmentation Map:"),
+    ],
+    examples=['example_image.jpg'],
+    description=description,
+    title=title,
+    css=css,
+    article=article,
+)
+iface.launch(enable_queue=True, cache_examples=True)
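For local testing, a minimal sketch of calling `inference` directly (not part of this commit; it assumes the definitions above are already in scope, e.g. pasted into a notebook without the `iface.launch` line, and that `example_image.jpg` sits next to the script):

    import numpy as np
    from PIL import Image

    # Gradio hands images to `inference` as numpy arrays, so mimic that here.
    image = np.array(Image.open("example_image.jpg").convert("RGB"))
    depth_map, segmentation_map = inference(image)
    depth_map.save("depth.png")
    segmentation_map.save("segmentation.png")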
example_image.jpg
ADDED
focusondepth/__init__.py
ADDED
File without changes
focusondepth/fusion.py
ADDED
@@ -0,0 +1,41 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+class ResidualConvUnit(nn.Module):
+    def __init__(self, features):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True)
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True)
+        self.relu = nn.ReLU(inplace=True)  # note: the in-place ReLU also rewrites x, so the skip below adds relu(x)
+
+    def forward(self, x):
+        """Forward pass.
+        Args:
+            x (tensor): input
+        Returns:
+            tensor: output
+        """
+        out = self.relu(x)
+        out = self.conv1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        return out + x
+
+class Fusion(nn.Module):
+    def __init__(self, resample_dim):
+        super(Fusion, self).__init__()
+        self.res_conv1 = ResidualConvUnit(resample_dim)
+        self.res_conv2 = ResidualConvUnit(resample_dim)
+
+    def forward(self, x, previous_stage=None):
+        if previous_stage is None:
+            previous_stage = torch.zeros_like(x)
+        output_stage1 = self.res_conv1(x)
+        output_stage1 += previous_stage
+        output_stage2 = self.res_conv2(output_stage1)
+        output_stage2 = nn.functional.interpolate(output_stage2, scale_factor=2, mode="bilinear", align_corners=True)
+        return output_stage2
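As a quick sanity check, a minimal sketch (not part of this commit) showing that `Fusion` keeps the channel count and doubles the spatial resolution; the 256-channel input matches the `resample_dim` default used elsewhere in this repo:

    import torch
    from focusondepth.fusion import Fusion

    fusion = Fusion(resample_dim=256)
    x = torch.randn(1, 256, 24, 24)  # features from a Reassemble stage
    out = fusion(x)                  # no previous stage, so zeros are fused in
    print(out.shape)                 # torch.Size([1, 256, 48, 48])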
focusondepth/head.py
ADDED
@@ -0,0 +1,50 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+class Interpolate(nn.Module):
+    def __init__(self, scale_factor, mode, align_corners=False):
+        super(Interpolate, self).__init__()
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        x = self.interp(
+            x,
+            scale_factor=self.scale_factor,
+            mode=self.mode,
+            align_corners=self.align_corners)
+        return x
+
+class HeadDepth(nn.Module):
+    def __init__(self, features):
+        super(HeadDepth, self).__init__()
+        self.head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            # nn.ReLU()
+            nn.Sigmoid()
+        )
+    def forward(self, x):
+        x = self.head(x)
+        # x = (x - x.min())/(x.max()-x.min() + 1e-15)
+        return x
+
+class HeadSeg(nn.Module):
+    def __init__(self, features, nclasses=2):
+        super(HeadSeg, self).__init__()
+        self.head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, nclasses, kernel_size=1, stride=1, padding=0)
+        )
+    def forward(self, x):
+        x = self.head(x)
+        return x
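Another hedged shape sketch (not part of this commit; the 192x192 input assumes the final Fusion output for a 384x384 image with `resample_dim=256`): both heads halve the channels, upsample by 2, and project to their output channels:

    import torch
    from focusondepth.head import HeadDepth, HeadSeg

    features = torch.randn(1, 256, 192, 192)
    depth = HeadDepth(256)(features)          # sigmoid keeps values in (0, 1)
    seg = HeadSeg(256, nclasses=2)(features)  # raw logits; argmax over dim=1 later
    print(depth.shape, seg.shape)             # (1, 1, 384, 384) (1, 2, 384, 384)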
focusondepth/model_config.py
ADDED
@@ -0,0 +1,45 @@
+from transformers import PretrainedConfig
+from typing import List
+
+
+class FocusOnDepthConfig(PretrainedConfig):
+    model_type = "focusondepth"
+
+    def __init__(
+        self,
+        image_size = (3, 384, 384),
+        patch_size = 16,
+        emb_dim = 768,
+        resample_dim = 256,
+        read = 'projection',
+        num_layers_encoder = 24,
+        hooks = [2, 5, 8, 11],
+        reassemble_s = [4, 8, 16, 32],
+        transformer_dropout = 0,
+        nclasses = 2,
+        type_ = "full",
+        model_timm = "vit_base_patch16_384",
+        **kwargs,
+    ):
+        if type_ not in ["full", "depth", "segmentation"]:
+            raise ValueError(f"`type_` must be 'full', 'depth' or 'segmentation', got {type_}.")
+        if read not in ["ignore", "add", "projection"]:
+            raise ValueError(f"`read` must be 'ignore', 'add' or 'projection', got {read}.")
+        if image_size[0] != 3 or image_size[1] != 384 or image_size[2] != 384:
+            raise ValueError(f"`image_size` must be (3, 384, 384), got {image_size}.")
+        if patch_size != 16:
+            raise ValueError(f"`patch_size` must be 16, got {patch_size}.")
+
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.emb_dim = emb_dim
+        self.resample_dim = resample_dim
+        self.read = read
+        self.num_layers_encoder = num_layers_encoder
+        self.hooks = hooks
+        self.reassemble_s = reassemble_s
+        self.transformer_dropout = transformer_dropout
+        self.nclasses = nclasses
+        self.type_ = type_
+        self.model_timm = model_timm
+        super().__init__(**kwargs)
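For illustration, instantiating the config with its defaults or with an override (a sketch, not part of this commit):

    from focusondepth.model_config import FocusOnDepthConfig

    config = FocusOnDepthConfig()  # ViT at 384x384, patch size 16
    print(config.hooks, config.reassemble_s)  # [2, 5, 8, 11] [4, 8, 16, 32]

    seg_only = FocusOnDepthConfig(type_="segmentation", nclasses=2)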
focusondepth/model_definition.py
ADDED
@@ -0,0 +1,68 @@
+from transformers import PreTrainedModel
+import timm
+import torch.nn as nn
+import numpy as np
+
+from .model_config import FocusOnDepthConfig
+from .reassemble import Reassemble
+from .fusion import Fusion
+from .head import HeadDepth, HeadSeg
+
+
+class FocusOnDepth(PreTrainedModel):
+    config_class = FocusOnDepthConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer_encoders = timm.create_model(config.model_timm, pretrained=True)
+        self.type_ = config.type_
+
+        # Register forward hooks to capture intermediate transformer activations
+        self.activation = {}
+        self.hooks = config.hooks
+        self._get_layers_from_hooks(self.hooks)
+
+        # One Reassemble + Fusion stage per hooked layer
+        self.reassembles = []
+        self.fusions = []
+        for s in config.reassemble_s:
+            self.reassembles.append(Reassemble(config.image_size, config.read, config.patch_size, s, config.emb_dim, config.resample_dim))
+            self.fusions.append(Fusion(config.resample_dim))
+        self.reassembles = nn.ModuleList(self.reassembles)
+        self.fusions = nn.ModuleList(self.fusions)
+
+        # Heads
+        if self.type_ == "full":
+            self.head_depth = HeadDepth(config.resample_dim)
+            self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
+        elif self.type_ == "depth":
+            self.head_depth = HeadDepth(config.resample_dim)
+            self.head_segmentation = None
+        else:
+            self.head_depth = None
+            self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
+
+    def forward(self, img):
+        _ = self.transformer_encoders(img)  # run the encoder; the hooks fill self.activation
+        previous_stage = None
+        for i in np.arange(len(self.fusions)-1, -1, -1):  # deepest hooked layer first
+            hook_to_take = 't' + str(self.hooks[i])
+            activation_result = self.activation[hook_to_take]
+            reassemble_result = self.reassembles[i](activation_result)
+            fusion_result = self.fusions[i](reassemble_result, previous_stage)
+            previous_stage = fusion_result
+        out_depth = None
+        out_segmentation = None
+        if self.head_depth is not None:
+            out_depth = self.head_depth(previous_stage)
+        if self.head_segmentation is not None:
+            out_segmentation = self.head_segmentation(previous_stage)
+        return out_depth, out_segmentation
+
+    def _get_layers_from_hooks(self, hooks):
+        def get_activation(name):
+            def hook(model, input, output):
+                self.activation[name] = output
+            return hook
+        for h in hooks:
+            self.transformer_encoders.blocks[h].register_forward_hook(get_activation('t' + str(h)))
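End to end, a hedged sketch (not part of this commit; it assumes the pretrained timm weights download successfully): with the default config, a 384x384 input yields a 1-channel depth map and an `nclasses`-channel segmentation map at the same resolution:

    import torch
    from focusondepth.model_config import FocusOnDepthConfig
    from focusondepth.model_definition import FocusOnDepth

    model = FocusOnDepth(FocusOnDepthConfig()).eval()
    with torch.no_grad():
        depth, seg = model(torch.randn(1, 3, 384, 384))
    print(depth.shape, seg.shape)  # (1, 1, 384, 384) (1, 2, 384, 384)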
focusondepth/reassemble.py
ADDED
@@ -0,0 +1,115 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+class Read_ignore(nn.Module):
+    def __init__(self, start_index=1):
+        super(Read_ignore, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        return x[:, self.start_index:]
+
+
+class Read_add(nn.Module):
+    def __init__(self, start_index=1):
+        super(Read_add, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        if self.start_index == 2:
+            readout = (x[:, 0] + x[:, 1]) / 2
+        else:
+            readout = x[:, 0]
+        return x[:, self.start_index:] + readout.unsqueeze(1)
+
+
+class Read_projection(nn.Module):
+    def __init__(self, in_features, start_index=1):
+        super(Read_projection, self).__init__()
+        self.start_index = start_index
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+    def forward(self, x):
+        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
+        features = torch.cat((x[:, self.start_index:], readout), -1)
+        return self.project(features)
+
+class MyConvTranspose2d(nn.Module):
+    def __init__(self, conv, output_size):
+        super(MyConvTranspose2d, self).__init__()
+        self.output_size = output_size
+        self.conv = conv
+
+    def forward(self, x):
+        x = self.conv(x, output_size=self.output_size)
+        return x
+
+class Resample(nn.Module):
+    def __init__(self, p, s, h, emb_dim, resample_dim):
+        super(Resample, self).__init__()
+        assert (s in [4, 8, 16, 32]), "s must be in [4, 8, 16, 32]"
+        self.conv1 = nn.Conv2d(emb_dim, resample_dim, kernel_size=1, stride=1, padding=0)
+        if s == 4:
+            self.conv2 = nn.ConvTranspose2d(resample_dim,
+                                            resample_dim,
+                                            kernel_size=4,
+                                            stride=4,
+                                            padding=0,
+                                            bias=True,
+                                            dilation=1,
+                                            groups=1)
+        elif s == 8:
+            self.conv2 = nn.ConvTranspose2d(resample_dim,
+                                            resample_dim,
+                                            kernel_size=2,
+                                            stride=2,
+                                            padding=0,
+                                            bias=True,
+                                            dilation=1,
+                                            groups=1)
+        elif s == 16:
+            self.conv2 = nn.Identity()
+        else:
+            self.conv2 = nn.Conv2d(resample_dim, resample_dim, kernel_size=2, stride=2, padding=0, bias=True)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+class Reassemble(nn.Module):
+    def __init__(self, image_size, read, p, s, emb_dim, resample_dim):
+        """
+        p = patch size
+        s = resample coefficient
+        emb_dim <=> D (in the paper)
+        resample_dim <=> ^D (in the paper)
+        read : {"ignore", "add", "projection"}
+        """
+        super(Reassemble, self).__init__()
+        channels, image_height, image_width = image_size
+
+        # Read: drop, add or project the CLS token into the patch tokens
+        self.read = Read_ignore()
+        if read == 'add':
+            self.read = Read_add()
+        elif read == 'projection':
+            self.read = Read_projection(emb_dim)
+
+        # Concat after read: token sequence -> 2D feature map
+        self.concat = Rearrange('b (h w) c -> b c h w',
+                                c=emb_dim,
+                                h=(image_height // p),
+                                w=(image_width // p))
+
+        # Projection + Resample
+        self.resample = Resample(p, s, image_height, emb_dim, resample_dim)
+
+    def forward(self, x):
+        x = self.read(x)
+        x = self.concat(x)
+        x = self.resample(x)
+        return x
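A hedged sketch of the token-to-feature-map path (not part of this commit; it assumes ViT-Base/16 at 384x384, i.e. 24*24 patch tokens plus one CLS token, and `resample_dim=256`):

    import torch
    from focusondepth.reassemble import Reassemble

    reassemble = Reassemble(image_size=(3, 384, 384), read='projection',
                            p=16, s=4, emb_dim=768, resample_dim=256)
    tokens = torch.randn(1, 577, 768)  # 576 patch tokens + 1 CLS token
    features = reassemble(tokens)      # read CLS, reshape to 2D, resample
    print(features.shape)              # torch.Size([1, 256, 96, 96]) for s=4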
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+torch
+torchvision
+timm
+numpy
+pillow
+transformers
+einops