kingabzpro committed on
Commit
c70f7dd
·
verified ·
1 Parent(s): 334a409

Update app/app_savta.py

Browse files
Files changed (1) hide show
  1. app/app_savta.py +126 -80
app/app_savta.py CHANGED
@@ -1,106 +1,152 @@
1
- import torch
2
  import os
 
 
 
3
  from fastai.vision.all import *
4
  import gradio as gr
5
 
6
- ############### HF ###########################
7
-
8
- HF_TOKEN = os.getenv('HF_TOKEN')
9
 
 
 
10
  hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "savtadepth-flags-V2")
11
 
12
- ############## DVC ################################
 
 
13
 
14
  PROD_MODEL_PATH = "src/models"
15
  TRAIN_PATH = "src/data/processed/train/bathroom"
16
  TEST_PATH = "src/data/processed/test/bathroom"
17
 
18
- if os.path.isdir(".dvc"):
19
  print("Running DVC")
20
- # os.system("dvc config cache.type copy")
21
- # os.system("dvc config core.no_scm true")
22
- if os.system(f"dvc pull {PROD_MODEL_PATH} {TRAIN_PATH } {TEST_PATH }") != 0:
23
- exit("dvc pull failed")
24
- os.system("rm -r .dvc")
25
- # .apt/usr/lib/dvc
26
 
27
- ############## Inference ##############################
 
 
28
 
29
  class ImageImageDataLoaders(DataLoaders):
30
- """Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"""
 
31
  @classmethod
32
  @delegates(DataLoaders.from_dblock)
33
- def from_label_func(cls, path, filenames, label_func, valid_pct=0.2, seed=None, item_transforms=None,
34
- batch_transforms=None, **kwargs):
35
- """Create from list of `fnames` in `path`s with `label_func`."""
36
- datablock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
37
- get_y=label_func,
38
- splitter=RandomSplitter(valid_pct, seed=seed),
39
- item_tfms=item_transforms,
40
- batch_tfms=batch_transforms)
41
- res = cls.from_dblock(datablock, filenames, path=path, **kwargs)
42
- return res
43
-
44
-
45
- def get_y_fn(x):
46
- y = str(x.absolute()).replace('.jpg', '_depth.png')
47
- y = Path(y)
48
-
49
- return y
50
-
51
-
52
- def create_data(data_path):
53
- fnames = get_files(data_path/'train', extensions='.jpg')
54
- data = ImageImageDataLoaders.from_label_func(data_path/'train', seed=42, bs=4, num_workers=0, filenames=fnames, label_func=get_y_fn)
55
- return data
56
-
57
- data = create_data(Path('src/data/processed'))
58
- learner = unet_learner(data,resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(), path='src/')
59
- learner.load('model')
60
-
61
- def gen(input_img):
62
- return PILImageBW.create((learner.predict(input_img))[0]).convert('L')
63
-
64
- ################### Gradio Web APP ################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  title = "SavtaDepth WebApp"
67
 
68
- description = """
69
- <p>
70
- <center>
71
- Savta Depth is a collaborative Open Source Data Science project for monocular depth estimation - Turn 2d photos into 3d photos. To test the model and code please check out the link bellow.
72
- <img src="https://huggingface.co/spaces/kingabzpro/savtadepth/resolve/main/examples/cover.png" alt="logo" width="250"/>
73
- </center>
74
- </p>
75
- """
76
- article = "<p style='text-align: center'><a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>SavtaDepth Project from OperationSavta</a></p><p style='text-align: center'><a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a></p></center><center><img src='https://visitor-badge.glitch.me/badge?page_id=kingabzpro/savtadepth' alt='visitor badge'></center></p>"
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  examples = [
79
  ["examples/00008.jpg"],
80
  ["examples/00045.jpg"],
81
  ]
82
- favicon = "examples/favicon.ico"
83
- thumbnail = "examples/SavtaDepth.png"
84
-
85
-
86
- def main():
87
- iface = gr.Interface(
88
- gen,
89
- gr.inputs.Image(shape=(640,480),type='numpy'),
90
- "image",
91
- title = title,
92
- flagging_options=["incorrect", "worst","ambiguous"],
93
- allow_flagging = "manual",
94
- flagging_callback=hf_writer,
95
- description = description,
96
- article = article,
97
- examples = examples,
98
- theme ="peach",
99
- allow_screenshot=True
100
- )
101
-
102
- iface.launch(enable_queue=True)
103
- # enable_queue=True,auth=("admin", "pass1234")
104
-
105
- if __name__ == '__main__':
106
- main()
 
 
1
  import os
2
+ from pathlib import Path
3
+
4
+ import torch
5
  from fastai.vision.all import *
6
  import gradio as gr
7
 
8
+ ######################
9
+ # Hugging Face flags #
10
+ ######################
11
 
12
+ HF_TOKEN = os.getenv("HF_TOKEN")
13
+ # `HuggingFaceDatasetSaver` is still available in Gradio ≥ 5
14
  hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "savtadepth-flags-V2")
15
 
16
+ ############
17
+ # DVC #
18
+ ############
19
 
20
  PROD_MODEL_PATH = "src/models"
21
  TRAIN_PATH = "src/data/processed/train/bathroom"
22
  TEST_PATH = "src/data/processed/test/bathroom"
23
 
24
+ if Path(".dvc").is_dir():
25
  print("Running DVC")
26
+ if os.system(f"dvc pull {PROD_MODEL_PATH} {TRAIN_PATH} {TEST_PATH}") != 0:
27
+ raise SystemExit("dvc pull failed")
28
+ # remove DVC metadata to avoid accidental reuse in the Space
29
+ os.system("rm -rf .dvc")
 
 
30
 
31
+ #######################
32
+ # Data & Learner #
33
+ #######################
34
 
35
  class ImageImageDataLoaders(DataLoaders):
36
+ """Wrapper to create DataLoaders for image→image tasks."""
37
+
38
  @classmethod
39
  @delegates(DataLoaders.from_dblock)
40
+ def from_label_func(
41
+ cls,
42
+ path: Path,
43
+ filenames,
44
+ label_func,
45
+ valid_pct: float = 0.2,
46
+ seed: int | None = None,
47
+ item_transforms=None,
48
+ batch_transforms=None,
49
+ **kwargs,
50
+ ):
51
+ dblock = DataBlock(
52
+ blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
53
+ get_y=label_func,
54
+ splitter=RandomSplitter(valid_pct, seed=seed),
55
+ item_tfms=item_transforms,
56
+ batch_tfms=batch_transforms,
57
+ )
58
+ return cls.from_dblock(dblock, filenames, path=path, **kwargs)
59
+
60
+
61
+ def get_y_fn(x: Path) -> Path:
62
+ """Map an RGB image path to its depth‑map counterpart."""
63
+ return Path(str(x).replace(".jpg", "_depth.png"))
64
+
65
+
66
+ def create_data(data_path: Path):
67
+ fnames = get_files(data_path / "train", extensions=".jpg")
68
+ return ImageImageDataLoaders.from_label_func(
69
+ data_path / "train",
70
+ seed=42,
71
+ bs=4,
72
+ num_workers=0,
73
+ filenames=fnames,
74
+ label_func=get_y_fn,
75
+ )
76
+
77
+
78
+ data = create_data(Path("src/data/processed"))
79
+ learner = unet_learner(
80
+ data,
81
+ resnet34,
82
+ metrics=rmse,
83
+ wd=1e-2,
84
+ n_out=3,
85
+ loss_func=MSELossFlat(),
86
+ path="src/",
87
+ )
88
+ learner.load("model")
89
+
90
+ #####################
91
+ # Inference Logic #
92
+ #####################
93
+
94
+ def predict_depth(input_img: PILImage) -> PILImageBW:
95
+ """Generate a single‑channel depth prediction from an RGB image."""
96
+ depth, *_ = learner.predict(input_img)
97
+ return PILImageBW.create(depth).convert("L")
98
+
99
+ #####################
100
+ # Gradio UI #
101
+ #####################
102
 
103
  title = "SavtaDepth WebApp"
104
 
105
+ description = (
106
+ """
107
+ <p style="text-align:center;">
108
+ Savta Depth is a collaborative OpenSource project for monocular depth estimation turn 2‑D photos into 3‑D. 🏞️<br>
109
+ Try the model below or explore the resources.
110
+ <br><img src="https://huggingface.co/spaces/kingabzpro/savtadepth/resolve/main/examples/cover.png" width="250"/>
111
+ </p>
112
+ """
113
+ )
114
+
115
+ article = (
116
+ """
117
+ <p style='text-align:center'>
118
+ <a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>Project on DAGsHub</a> •
119
+ <a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a>
120
+ <br/>
121
+ <img src='https://visitor-badge.glitch.me/badge?page_id=kingabzpro/savtadepth' alt='visitor badge'/>
122
+ </p>
123
+ """
124
+ )
125
 
126
  examples = [
127
  ["examples/00008.jpg"],
128
  ["examples/00045.jpg"],
129
  ]
130
+
131
+ # Modern Gradio components (v5+)
132
+ input_component = gr.Image(shape=(640, 480), image_mode="RGB", label="Input RGB")
133
+ output_component = gr.Image(image_mode="L", label="Predicted Depth")
134
+
135
+ with gr.Blocks(title=title) as demo:
136
+ gr.Markdown(description)
137
+ gr.Markdown(article)
138
+
139
+ gr.Interface(
140
+ fn=predict_depth,
141
+ inputs=input_component,
142
+ outputs=output_component,
143
+ allow_flagging="manual",
144
+ flagging_options=["incorrect", "worst", "ambiguous"],
145
+ flagging_callback=hf_writer,
146
+ examples=examples,
147
+ cache_examples=False, # use live inference for examples
148
+ theme=gr.themes.Soft(),
149
+ )
150
+
151
+ if __name__ == "__main__":
152
+ demo.queue(concurrency_count=3).launch()