AlbeRota committed on
Commit
2c3f571
·
1 Parent(s): 571eb53

ZeroGPU loading fixed

Browse files
Files changed (1) hide show
  1. app.py +50 -27
app.py CHANGED
@@ -108,62 +108,81 @@ _cached_device = None
108
 
109
 
110
  def _get_model(device: str):
111
- """Return the pretrained model, loading it once and reusing."""
112
  global _cached_ura_model, _cached_device
113
  assets = _get_assets()
114
- if _cached_ura_model is not None and _cached_device == device:
115
- return _cached_ura_model
116
  from unreflectanything import model
117
 
118
- _cached_ura_model = model(
119
- pretrained=True,
120
- weights_path=assets.weights_path,
121
- # weights_path="/home/arota/UnReflectAnything/weights/full_model_weights.pt",
122
- config_path=assets.config_path,
123
- device=device,
124
- verbose=False,
125
- skip_path_resolution=True,
126
- )
127
- _cached_device = device
 
 
 
 
 
 
 
 
 
128
  return _cached_ura_model
129
 
130
-
131
  def build_ui():
132
  _get_assets()
133
- device = "cuda" if torch.cuda.is_available() else "cpu"
134
- # Start loading the model in the background so it is ready (or nearly ready) by first use.
135
- print(f"Initializing model on {device}...")
136
- _get_model(device)
137
 
 
138
  @spaces.GPU if spaces else lambda x: x
139
  def run_inference(image: np.ndarray | None) -> np.ndarray | None:
140
- """Run reflection removal using the cached model. Returns RGB numpy [H,W,3] in 0–255 or None."""
141
  if image is None:
142
  return None
 
143
  from torchvision.transforms import functional as TF
 
144
 
 
 
145
  ura_model = _get_model(device)
 
146
  target_side = ura_model.image_size
147
- # image: [H, W, 3] uint8 0–255
148
  h, w = image.shape[:2]
149
- tensor = TF.to_tensor(image).unsqueeze(0) # [1, 3, H, W], [0, 1]
 
 
150
  tensor = TF.resize(tensor, [target_side, target_side], antialias=True)
151
- tensor = tensor.to(ura_model.device, dtype=torch.float32)
152
- mask = tensor.mean(1, keepdim=True) > 0.9 # [1, 1, S, S]
153
- import time
 
154
 
155
  with torch.no_grad():
156
  start_time = time.time()
 
157
  diffuse = ura_model(images=tensor, inpaint_mask_override=mask)
158
  end_time = time.time()
159
- diffuse = diffuse.cpu()
160
  inference_time_ms = (end_time - start_time) * 1000
161
- gr.Success(f"Inference time: {inference_time_ms:.1f} ms")
 
 
 
162
  diffuse = TF.resize(diffuse, [h, w], antialias=True)
163
  out = diffuse[0].numpy().transpose(1, 2, 0)
164
  out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
165
  return out
166
 
 
 
167
  def run_inference_slider(
168
  image: np.ndarray | None,
169
  ) -> tuple[np.ndarray | None, np.ndarray | None] | None:
@@ -277,5 +296,9 @@ def _launch_with_allowed_paths(*args, **kwargs):
277
  demo.launch = _launch_with_allowed_paths
278
 
279
 
 
280
  if __name__ == "__main__":
281
- demo.launch()
 
 
 
 
108
 
109
 
110
  def _get_model(device: str):
111
+ """Return the pretrained model, loading it once and moving to the requested device."""
112
  global _cached_ura_model, _cached_device
113
  assets = _get_assets()
114
+
 
115
  from unreflectanything import model
116
 
117
+ # If the model isn't loaded yet, initialize it
118
+ if _cached_ura_model is None:
119
+ print(f"Loading model initially on {device}...")
120
+ _cached_ura_model = model(
121
+ pretrained=True,
122
+ weights_path=assets.weights_path,
123
+ config_path=assets.config_path,
124
+ device=device,
125
+ verbose=False,
126
+ skip_path_resolution=True,
127
+ )
128
+ _cached_device = device
129
+
130
+ # If the model is loaded but on the wrong device, move it
131
+ if _cached_device != device:
132
+ print(f"Moving model from {_cached_device} to {device}...")
133
+ _cached_ura_model.to(device)
134
+ _cached_device = device
135
+
136
  return _cached_ura_model
137
 
 
138
  def build_ui():
139
  _get_assets()
140
+ # PREVENT: _get_model("cuda") here. It will crash ZeroGPU during startup.
141
+ print("UI building... Model will initialize on first inference.")
 
 
142
 
143
+ # Note: Use the decorator directly on the function that does the heavy lifting
144
  @spaces.GPU if spaces else lambda x: x
145
  def run_inference(image: np.ndarray | None) -> np.ndarray | None:
146
+ """Run reflection removal using the cached model on GPU."""
147
  if image is None:
148
  return None
149
+
150
  from torchvision.transforms import functional as TF
151
+ import time
152
 
153
+ # Now it is safe to request 'cuda' because we are inside the @spaces.GPU wrapper
154
+ device = "cuda" if (torch.cuda.is_available() and spaces) else "cpu"
155
  ura_model = _get_model(device)
156
+
157
  target_side = ura_model.image_size
 
158
  h, w = image.shape[:2]
159
+
160
+ # Pre-processing
161
+ tensor = TF.to_tensor(image).unsqueeze(0) # [1, 3, H, W]
162
  tensor = TF.resize(tensor, [target_side, target_side], antialias=True)
163
+ tensor = tensor.to(device, dtype=torch.float32)
164
+
165
+ # Create mask based on highlights
166
+ mask = tensor.mean(1, keepdim=True) > 0.9
167
 
168
  with torch.no_grad():
169
  start_time = time.time()
170
+ # The model is already on 'device' thanks to _get_model
171
  diffuse = ura_model(images=tensor, inpaint_mask_override=mask)
172
  end_time = time.time()
173
+
174
  inference_time_ms = (end_time - start_time) * 1000
175
+ gr.Info(f"Inference complete in {inference_time_ms:.1f} ms") # Use gr.Info for better UX
176
+
177
+ # Post-processing
178
+ diffuse = diffuse.cpu()
179
  diffuse = TF.resize(diffuse, [h, w], antialias=True)
180
  out = diffuse[0].numpy().transpose(1, 2, 0)
181
  out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
182
  return out
183
 
184
+ # ... keep your run_inference_slider and UI layout code the same ...
185
+
186
  def run_inference_slider(
187
  image: np.ndarray | None,
188
  ) -> tuple[np.ndarray | None, np.ndarray | None] | None:
 
296
  demo.launch = _launch_with_allowed_paths
297
 
298
 
299
+ # Replace your existing launch logic at the very bottom of the file with this:
300
  if __name__ == "__main__":
301
+ demo.launch(ssr_mode=False, server_name="0.0.0.0", server_port=7860)
302
+ else:
303
+ # This handles cases where Hugging Face imports the file
304
+ demo.launch(ssr_mode=False, server_name="0.0.0.0", server_port=7860)