Abid Ali Awan commited on
Commit
87fce81
·
1 Parent(s): aedffb4

Enhance input handling and suppress warnings in app_savta.py

Browse files

- Added a warning filter to suppress deprecation warnings for cleaner output.
- Improved input image handling by removing unnecessary mode specification in Image.fromarray calls.
- Added a check to ensure the loaded model is a valid fastai learner, with error handling for incompatible formats.

These changes improve the user experience and maintain the robustness of the application.

Files changed (1) hide show
  1. app/app_savta.py +16 -7
app/app_savta.py CHANGED
@@ -1,8 +1,12 @@
1
  import os, sys, tempfile, subprocess
 
2
  from pathlib import Path
3
  import torch
4
  import gradio as gr
5
 
 
 
 
6
  # Try to import fastai components
7
  try:
8
  from fastai.vision.all import *
@@ -64,7 +68,7 @@ if not MODEL_PATH.exists():
64
  if len(input_img.shape) == 3 and input_img.shape[2] == 3:
65
  input_img = Image.fromarray(input_img.astype('uint8'))
66
  elif len(input_img.shape) == 2:
67
- input_img = Image.fromarray(input_img.astype('uint8'), mode='L')
68
  img_gray = input_img.convert('L')
69
 
70
  # Simple edge detection for depth
@@ -83,7 +87,7 @@ if not MODEL_PATH.exists():
83
 
84
  # Convert back to PIL Image
85
  depth_array = (depth_factor * 255).astype(np.uint8)
86
- return Image.fromarray(depth_array, mode='L')
87
 
88
  learner = SimpleDepthEstimator()
89
  else:
@@ -92,7 +96,12 @@ else:
92
  # Simple approach for inference only (without training data)
93
  if FASTAI_AVAILABLE:
94
  learn = load_learner(MODEL_PATH)
95
- learner = learn
 
 
 
 
 
96
  else:
97
  raise ImportError("FastAI not available")
98
  except Exception as e:
@@ -114,7 +123,7 @@ else:
114
  if len(input_img.shape) == 3 and input_img.shape[2] == 3:
115
  input_img = Image.fromarray(input_img.astype('uint8'))
116
  elif len(input_img.shape) == 2:
117
- input_img = Image.fromarray(input_img.astype('uint8'), mode='L')
118
  img_gray = input_img.convert('L')
119
 
120
  # Simple edge detection for depth
@@ -133,7 +142,7 @@ else:
133
 
134
  # Convert back to PIL Image
135
  depth_array = (depth_factor * 255).astype(np.uint8)
136
- return Image.fromarray(depth_array, mode='L')
137
 
138
  learner = SimpleDepthEstimator()
139
 
@@ -176,7 +185,7 @@ def predict_depth(input_img):
176
  input_img = Image.fromarray(input_img.astype('uint8'))
177
  # Grayscale numpy array
178
  elif len(input_img.shape) == 2:
179
- input_img = Image.fromarray(input_img.astype('uint8'), mode='L')
180
 
181
  # Use our simple depth estimation
182
  return learner.predict(input_img)
@@ -218,7 +227,7 @@ with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo:
218
  fn=predict_depth,
219
  inputs=input_component,
220
  outputs=output_component,
221
- allow_flagging=allow_flagging,
222
  flagging_options=["incorrect", "worst", "ambiguous"],
223
  flagging_callback=hf_writer,
224
  examples=examples,
 
1
  import os, sys, tempfile, subprocess
2
+ import warnings
3
  from pathlib import Path
4
  import torch
5
  import gradio as gr
6
 
7
+ # Suppress deprecation warnings for cleaner output
8
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
9
+
10
  # Try to import fastai components
11
  try:
12
  from fastai.vision.all import *
 
68
  if len(input_img.shape) == 3 and input_img.shape[2] == 3:
69
  input_img = Image.fromarray(input_img.astype('uint8'))
70
  elif len(input_img.shape) == 2:
71
+ input_img = Image.fromarray(input_img.astype('uint8'))
72
  img_gray = input_img.convert('L')
73
 
74
  # Simple edge detection for depth
 
87
 
88
  # Convert back to PIL Image
89
  depth_array = (depth_factor * 255).astype(np.uint8)
90
+ return Image.fromarray(depth_array)
91
 
92
  learner = SimpleDepthEstimator()
93
  else:
 
96
  # Simple approach for inference only (without training data)
97
  if FASTAI_AVAILABLE:
98
  learn = load_learner(MODEL_PATH)
99
+ # Check if it's actually a learner object or just a dict
100
+ if hasattr(learn, 'dls') and hasattr(learn, 'predict'):
101
+ learner = learn
102
+ else:
103
+ print("⚠️ Loaded model is not a valid fastai learner")
104
+ raise ValueError("Model file format incompatible")
105
  else:
106
  raise ImportError("FastAI not available")
107
  except Exception as e:
 
123
  if len(input_img.shape) == 3 and input_img.shape[2] == 3:
124
  input_img = Image.fromarray(input_img.astype('uint8'))
125
  elif len(input_img.shape) == 2:
126
+ input_img = Image.fromarray(input_img.astype('uint8'))
127
  img_gray = input_img.convert('L')
128
 
129
  # Simple edge detection for depth
 
142
 
143
  # Convert back to PIL Image
144
  depth_array = (depth_factor * 255).astype(np.uint8)
145
+ return Image.fromarray(depth_array)
146
 
147
  learner = SimpleDepthEstimator()
148
 
 
185
  input_img = Image.fromarray(input_img.astype('uint8'))
186
  # Grayscale numpy array
187
  elif len(input_img.shape) == 2:
188
+ input_img = Image.fromarray(input_img.astype('uint8'))
189
 
190
  # Use our simple depth estimation
191
  return learner.predict(input_img)
 
227
  fn=predict_depth,
228
  inputs=input_component,
229
  outputs=output_component,
230
+ flagging_mode=allow_flagging,
231
  flagging_options=["incorrect", "worst", "ambiguous"],
232
  flagging_callback=hf_writer,
233
  examples=examples,