# NOTE(review): the lines above/below this header in the scraped page were
# Hugging Face Spaces UI chrome (status, file size, commit hashes, line gutter),
# not source code — removed so the module parses.
from pathlib import Path
import gradio as gr
from fastai.vision.all import *
#######################
# Data & Learner #
#######################
class ImageImageDataLoaders(DataLoaders):
    """DataLoaders for image-to-image tasks: RGB input, single-channel target."""

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(
        cls,
        path: Path,
        filenames,
        label_func,
        valid_pct: float = 0.2,
        seed: int | None = None,
        item_transforms=None,
        batch_transforms=None,
        **kwargs,
    ):
        """Build DataLoaders whose target image is derived from each input via *label_func*.

        Extra keyword arguments (``bs``, ``num_workers``, …) are forwarded to
        ``DataLoaders.from_dblock`` through the ``@delegates`` decorator.
        """
        # Pair a colour-image input block with a black-and-white target block.
        block = DataBlock(
            blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
            get_y=label_func,
            splitter=RandomSplitter(valid_pct, seed=seed),
            item_tfms=item_transforms,
            batch_tfms=batch_transforms,
        )
        return cls.from_dblock(block, filenames, path=path, **kwargs)
def get_y_fn(x: Path) -> Path:
    """Identity labeller: each image is its own 'target'.

    Only used to satisfy the DataBlock API while wiring up the model
    architecture — no real depth labels are required at inference time.
    """
    return x
def create_data(data_path: Path):
    """Build a minimal ImageImageDataLoaders used only to initialize the network.

    A batch size of 1, no workers, and no validation split keep this as cheap
    as possible: the loader exists solely so the U-Net's layer shapes can be
    constructed before the trained weights are loaded.
    """
    image_files = get_files(data_path, extensions=".jpg")
    return ImageImageDataLoaders.from_label_func(
        data_path,
        filenames=image_files,
        label_func=get_y_fn,
        valid_pct=0.0,  # inference only — no validation set needed
        seed=42,
        bs=1,           # single-item batches are sufficient for shape inference
        num_workers=0,  # no worker processes in the serving environment
    )
# Initialize learner with architecture
# Build the U-Net on a tiny dummy loader (shapes must match the checkpoint),
# then restore the trained weights.  Runs at import time so the app is ready
# before the Gradio server starts.
data = create_data(Path("examples"))
learner = unet_learner(
    data,
    resnet34,
    n_out=3,  # NOTE(review): 3 output channels despite a BW target block — confirm against training code
    loss_func=MSELossFlat(),
    path=".",
    model_dir="models",
)
# Loads ./models/model.pth (fastai appends the .pth extension to the name).
learner.load("model")
#####################
# Inference Logic #
#####################
def predict_depth(input_img: PILImage) -> PILImageBW:
    """Predict a depth map for one RGB image and return it as an 8-bit grayscale image."""
    prediction = learner.predict(input_img)
    depth_map = prediction[0]  # first element is the decoded prediction
    return PILImageBW.create(depth_map).convert("L")
#####################
# Gradio UI #
#####################
title = "📷 SavtaDepth WebApp"
description_md = """
<p style="text-align:center;font-size:1.05rem;max-width:760px;margin:auto;">
Upload an RGB image on the left and get a grayscale depth map on the right.
</p>
"""
footer_html = """
<p style='text-align:center;font-size:0.9rem;'>
<a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>Project on DAGsHub</a> •
<a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a>
</p>
"""
examples = [["examples/00008.jpg"], ["examples/00045.jpg"]]
input_component = gr.Image(width=640, height=480, label="Input RGB")
output_component = gr.Image(label="Predicted Depth", image_mode="L")
# Assemble the page: heading, description, the prediction Interface, footer.
with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<center><h1>{title}</h1></center>")
    gr.HTML(description_md)
    # A gr.Interface created inside a Blocks context renders in place.
    gr.Interface(
        fn=predict_depth,
        inputs=input_component,
        outputs=output_component,
        examples=examples,
        cache_examples=False,  # run examples live instead of precomputing at startup
    )
    gr.HTML(footer_html)
if __name__ == "__main__":
    # queue() enables request queuing so long-running predictions don't drop.
    demo.queue().launch()
# (end of file — trailing "|" was scrape residue from the Spaces line gutter)