File size: 3,090 Bytes
7171d2e
 
d49d935
0b9a94a
87fce81
7171d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627d8bd
7171d2e
627d8bd
7171d2e
 
 
 
 
 
 
 
 
 
 
 
81cb914
 
7171d2e
 
c70f7dd
d49d935
 
 
c70f7dd
73c03f9
0b9a94a
 
 
73c03f9
c70f7dd
d49d935
 
 
f43d01b
0b9a94a
6ab1531
0b9a94a
6ab1531
 
c70f7dd
 
 
0b9a94a
d49d935
 
0b9a94a
d49d935
 
 
0b9a94a
c70f7dd
0b9a94a
 
c70f7dd
6ab1531
4c603f2
6ab1531
c70f7dd
 
 
 
 
 
f9dfab2
c70f7dd
 
d49d935
 
c70f7dd
8523363
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from pathlib import Path

import gradio as gr
from fastai.vision.all import *

#######################
#   Data & Learner    #
#######################


class ImageImageDataLoaders(DataLoaders):
    """DataLoaders for tasks that map an input image to an output image."""

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(
        cls,
        path: Path,
        filenames,
        label_func,
        valid_pct: float = 0.2,
        seed: int | None = None,
        item_transforms=None,
        batch_transforms=None,
        **kwargs,
    ):
        """Build image→image DataLoaders where the target image is derived
        from each input file via *label_func*.

        Extra keyword arguments are forwarded to ``DataLoaders.from_dblock``
        (see the ``@delegates`` decorator).
        """
        # RGB input on the x side, single-channel (grayscale) target on the y side.
        pipeline = DataBlock(
            blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
            get_y=label_func,
            splitter=RandomSplitter(valid_pct, seed=seed),
            item_tfms=item_transforms,
            batch_tfms=batch_transforms,
        )
        return cls.from_dblock(pipeline, filenames, path=path, **kwargs)


def get_y_fn(x: Path) -> Path:
    """Identity label function: each image serves as its own target.

    Only needed so the DataBlock can be assembled to shape the network;
    the trained weights are loaded from disk afterwards.
    """
    return x


def create_data(data_path: Path):
    """Build a minimal DataLoaders used only to shape the model.

    ``bs=1``, ``num_workers=0`` and ``valid_pct=0.0`` keep the loader as
    cheap as possible — no training happens with it.
    """
    image_files = get_files(data_path, extensions=".jpg")
    return ImageImageDataLoaders.from_label_func(
        data_path,
        filenames=image_files,
        label_func=get_y_fn,
        valid_pct=0.0,
        seed=42,
        bs=1,
        num_workers=0,
    )


# Initialize learner with architecture.
# A throwaway DataLoaders over the bundled example images gives the U-Net its
# input/output shapes; the actual trained weights are loaded right after.
data = create_data(Path("examples"))
learner = unet_learner(
    data,
    resnet34,  # encoder backbone
    n_out=3,  # NOTE(review): 3 output channels although the target block is single-channel — confirm against training config
    loss_func=MSELossFlat(),
    path=".",
    model_dir="models",  # weights expected at ./models/model.pth
)
learner.load("model")

#####################
#  Inference Logic  #
#####################


def predict_depth(input_img: PILImage) -> PILImageBW:
    """Run the model on one RGB image and return a grayscale depth map."""
    prediction = learner.predict(input_img)
    raw_depth = prediction[0]
    # Wrap the raw prediction as a PIL image in single-channel ("L") mode.
    return PILImageBW.create(raw_depth).convert("L")


#####################
#    Gradio UI      #
#####################

# Heading shown in the browser tab and at the top of the page.
title = "📷 SavtaDepth WebApp"

# Short usage blurb rendered via gr.HTML below the heading.
description_md = """
    <p style="text-align:center;font-size:1.05rem;max-width:760px;margin:auto;">
        Upload an RGB image on the left and get a grayscale depth map on the right.
    </p>
    """

# Footer with project / demo links, rendered via gr.HTML at the bottom.
footer_html = """
    <p style='text-align:center;font-size:0.9rem;'>
        <a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>Project on DAGsHub</a> •
        <a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a>
    </p>
    """

# Clickable example inputs shown under the interface (one path per example row).
examples = [["examples/00008.jpg"], ["examples/00045.jpg"]]

# Input/output widgets wired into gr.Interface below.
input_component = gr.Image(width=640, height=480, label="Input RGB")
output_component = gr.Image(label="Predicted Depth", image_mode="L")

# Assemble the page: heading, the prediction interface, then the footer.
with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<center><h1>{title}</h1></center>")
    gr.HTML(description_md)

    # Embedding gr.Interface inside Blocks renders the standard
    # input/output layout together with the example gallery.
    gr.Interface(
        fn=predict_depth,
        inputs=input_component,
        outputs=output_component,
        examples=examples,
        cache_examples=False,  # don't run the model on examples at startup
    )

    gr.HTML(footer_html)

if __name__ == "__main__":
    # Enable request queueing before serving the app.
    demo.queue().launch()