Abid Ali Awan committed on
Commit
d49d935
·
1 Parent(s): 87b61df

Enhance app_savta.py with improved model loading and fallback depth estimation

Browse files

- Added support for Hugging Face flagging with a dataset saver.
- Implemented a fallback depth estimation method using simple edge detection when the model is not found.
- Updated inference logic and Gradio UI to include flagging options and a footer with project links.
- Streamlined model loading process with error handling for better user experience.

Files changed (1) hide show
  1. app/app_savta.py +122 -10
app/app_savta.py CHANGED
@@ -1,34 +1,141 @@
 
1
  from pathlib import Path
2
-
3
- import gradio as gr
4
  from fastai.vision.all import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Model setup
 
 
 
 
7
  MODEL_PATH = Path(__file__).parent.parent / "models" / "model.pth"
8
 
 
9
  if not MODEL_PATH.exists():
10
- raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- learner = load_learner(MODEL_PATH)
13
 
 
 
 
14
 
15
- # Inference function
16
  def predict_depth(input_img: PILImage) -> PILImageBW:
17
  depth, *_ = learner.predict(input_img)
18
  return PILImageBW.create(depth).convert("L")
19
 
 
 
 
20
 
21
- # Gradio UI
22
  title = "📷 SavtaDepth WebApp"
23
 
24
- description_md = """
 
25
  <p style="text-align:center;font-size:1.05rem;max-width:760px;margin:auto;">
26
  Upload an RGB image on the left and get a grayscale depth map on the right.
27
  </p>
28
  """
 
29
 
30
- examples_dir = Path(__file__).parent.parent / "examples"
31
- examples = [[str(examples_dir / "00008.jpg")], [str(examples_dir / "00045.jpg")]]
 
 
 
 
 
 
 
 
32
 
33
  input_component = gr.Image(width=640, height=480, label="Input RGB")
34
  output_component = gr.Image(label="Predicted Depth", image_mode="L")
@@ -41,9 +148,14 @@ with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo:
41
  fn=predict_depth,
42
  inputs=input_component,
43
  outputs=output_component,
 
 
 
44
  examples=examples,
45
  cache_examples=False,
46
  )
47
 
 
 
48
  if __name__ == "__main__":
49
  demo.queue().launch()
 
1
+ import os, sys, tempfile, subprocess
2
  from pathlib import Path
3
+ import torch
 
4
  from fastai.vision.all import *
5
+ import gradio as gr
6
+
7
# --- Hugging Face flagging setup -------------------------------------------
# When the HuggingFaceDatasetSaver is available, flagged samples are written
# to an HF dataset; otherwise flagging is disabled entirely.

HF_TOKEN = os.getenv("HF_TOKEN")

try:
    from gradio.flagging import HuggingFaceDatasetSaver  # type: ignore

    # NOTE(review): kwargs assume a gradio release exposing repo_id/token on
    # this saver — confirm against the pinned gradio version.
    hf_writer: gr.FlaggingCallback | None = HuggingFaceDatasetSaver(
        token=HF_TOKEN, repo_id="savtadepth-flags-V2"
    )
    allow_flagging: str | bool = "manual"
except (ImportError, AttributeError):
    # gradio version without this saver: run with flagging turned off.
    hf_writer = None
    allow_flagging = "never"
22
 
23
######################################
# Model setup (local file, no DVC)   #
######################################

# Trained model is expected at ../models/model.pth relative to this file.
MODEL_PATH = Path(__file__).parent.parent / "models" / "model.pth"


def fallback_depth_from_gray(gray):
    """Heuristic pseudo-depth from a float32 grayscale array.

    Combines normalized edge magnitude (60%) with inverted brightness (40%):
    strong edges and darker pixels map to larger depth values. Returns a
    uint8 array with the same shape as *gray*. This is a visual stand-in,
    not a real depth estimate.
    """
    import numpy as np

    # Gradient magnitude via first differences; `prepend` keeps the shape.
    grad_x = np.abs(np.diff(gray, axis=1, prepend=gray[:, :1]))
    grad_y = np.abs(np.diff(gray, axis=0, prepend=gray[:1, :]))
    edges = np.sqrt(grad_x ** 2 + grad_y ** 2)

    # Normalize edges to 0-255 (skip when the image is perfectly flat).
    if edges.max() > 0:
        edges = (edges - edges.min()) / (edges.max() - edges.min()) * 255

    # 1e-8 guards against division by zero on constant images.
    brightness = (gray - gray.min()) / (gray.max() - gray.min() + 1e-8)
    depth = 0.6 * (edges / 255.0) + 0.4 * (1 - brightness)
    return (np.clip(depth, 0, 1) * 255).astype(np.uint8)


class SimpleDepthEstimator:
    """Fallback "model" used when no trained learner can be loaded.

    predict() mirrors fastai's Learner.predict by returning a tuple, so the
    shared call site ``depth, *_ = learner.predict(img)`` works with either
    the real learner or this fallback.
    """

    def predict(self, input_img):
        from PIL import Image
        import numpy as np

        # Gradio may hand us a numpy array (its default image type) —
        # normalize to PIL first. NOTE(review): confirm the configured
        # gr.Image input type.
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        gray = input_img if input_img.mode == "L" else input_img.convert("L")
        depth_array = fallback_depth_from_gray(np.array(gray, dtype=np.float32))
        # Return a 1-tuple: a bare Image would break the star-unpacking
        # (`depth, *_ = ...`) performed in predict_depth.
        return (Image.fromarray(depth_array, mode="L"),)


if not MODEL_PATH.exists():
    print("Model not found at", MODEL_PATH)
    print("Using fallback depth estimation...")
    learner = SimpleDepthEstimator()
else:
    try:
        # Inference-only load of the exported fastai learner.
        learner = load_learner(MODEL_PATH)
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        print("Using fallback depth estimation...")
        learner = SimpleDepthEstimator()
105
 
 
106
 
107
#####################
#  Inference Logic  #
#####################

def predict_depth(input_img: PILImage) -> PILImageBW:
    """Run the active depth model on *input_img*; return an 8-bit grayscale map."""
    prediction = learner.predict(input_img)
    depth_map = prediction[0]
    return PILImageBW.create(depth_map).convert("L")
114
 
115
#####################
#     Gradio UI     #
#####################

# Page title shown in the app header / browser tab.
title = "📷 SavtaDepth WebApp"

# Centered blurb rendered above the interface (raw HTML).
description_md = """
<p style="text-align:center;font-size:1.05rem;max-width:760px;margin:auto;">
Upload an RGB image on the left and get a grayscale depth map on the right.
</p>
"""

# Footer with project links, injected below the interface.
footer_html = """
<p style='text-align:center;font-size:0.9rem;'>
<a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>Project on DAGsHub</a> •
<a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a>
</p>
"""

# Bundled sample images (relative paths, one per example row).
examples = [[f"examples/{stem}.jpg"] for stem in ("00008", "00045")]
139
 
140
# Fixed 640x480 RGB input; output rendered as single-channel ("L") grayscale.
input_component = gr.Image(label="Input RGB", width=640, height=480)
output_component = gr.Image(label="Predicted Depth", image_mode="L")
 
148
  fn=predict_depth,
149
  inputs=input_component,
150
  outputs=output_component,
151
+ allow_flagging=allow_flagging,
152
+ flagging_options=["incorrect", "worst", "ambiguous"],
153
+ flagging_callback=hf_writer,
154
  examples=examples,
155
  cache_examples=False,
156
  )
157
 
158
+ gr.HTML(footer_html)
159
+
160
  if __name__ == "__main__":
161
  demo.queue().launch()