usingcolor commited on
Commit
0f440c7
·
1 Parent(s): f33c68f

chore: initialize git repository and project structure

Browse files
Files changed (1) hide show
  1. app.py +39 -19
app.py CHANGED
@@ -53,9 +53,6 @@ from mambaeye.positional_encoding import sinusoidal_position_encoding_2d
53
  from mamba_ssm.utils.generation import InferenceParams
54
 
55
  # Global Configuration
56
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
- print(f"Using device: {DEVICE}")
58
-
59
  TARGET_CANVAS_SIZE = 512
60
  PATCH_SIZE = 16
61
  CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
@@ -79,22 +76,41 @@ _GLOBAL_MODEL = None
79
 
80
  def get_model():
81
  global _GLOBAL_MODEL
 
82
  if _GLOBAL_MODEL is None:
83
  print(f"Downloading {MODEL_FILENAME} from {MODEL_REPO}...")
84
  try:
85
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
86
  model = MambaEye(**MODEL_CONFIG)
87
 
88
- # On zero_gpu, downloading weights might happen on CPU first
89
- map_loc = torch.device('cpu')
90
- model.load_state_dict(torch.load(checkpoint_path, map_location=map_loc))
91
  model.eval()
92
  _GLOBAL_MODEL = model
93
  print("Model loaded successfully.")
94
  except Exception as e:
95
  print(f"Failed to load model: {e}")
96
  raise
97
- return _GLOBAL_MODEL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def _compute_move_embedding(patch_location: torch.Tensor, cur_location: torch.Tensor = None) -> torch.Tensor:
100
  if cur_location is None:
@@ -184,8 +200,7 @@ def run_auto_scan(image, scan_pattern, sequence_length):
184
  if image is None:
185
  return None, {"Upload Image": 1.0}, None, "Upload Image"
186
 
187
- model = get_model()
188
- model.to(DEVICE)
189
 
190
  state = init_state_for_image(image)
191
 
@@ -203,18 +218,17 @@ def run_auto_scan(image, scan_pattern, sequence_length):
203
  )
204
 
205
  inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
206
- state['inference_params'] = inference_params
207
 
208
  patches_list = []
209
  moves_list = []
210
  cur_location = None
211
 
212
  for px, py in positions:
213
- loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=DEVICE)
214
  move_emb = _compute_move_embedding(loc_tensor, cur_location)
215
  cur_location = loc_tensor
216
 
217
- patch = extract_patch(state['canvas_tensor'], px, py).to(DEVICE)
218
  patches_list.append(patch)
219
  moves_list.append(move_emb.squeeze(0))
220
 
@@ -229,8 +243,10 @@ def run_auto_scan(image, scan_pattern, sequence_length):
229
  state['cur_location'] = cur_location.cpu()
230
  state['drawn_positions'] = positions
231
  state['sequence_length'] = sequence_length
232
- # On ZeroGPU spaces safely store Tensors back to CPU State
 
233
  state['canvas_tensor'] = state['canvas_tensor'].cpu()
 
234
 
235
  img_display, _ = draw_patches_on_image(
236
  state['original_image'], state['drawn_positions'],
@@ -244,14 +260,15 @@ def on_click(evt: gr.SelectData, original_image, state):
244
  if original_image is None:
245
  return None, {"Upload Image": 1.0}, state, "Upload Image"
246
 
247
- model = get_model()
248
- model.to(DEVICE)
249
 
250
  if state is None or state.get('inference_params') is None:
251
- # Initialize state to begin a new purely user-guided sequence
252
  state = init_state_for_image(original_image)
253
  state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
254
 
 
 
 
255
  x_orig, y_orig = evt.index
256
  orig_h, orig_w = state['original_image'].shape[:2]
257
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
@@ -262,11 +279,11 @@ def on_click(evt: gr.SelectData, original_image, state):
262
  px = (canvas_x // PATCH_SIZE) * PATCH_SIZE
263
  py = (canvas_y // PATCH_SIZE) * PATCH_SIZE
264
 
265
- cur_loc = state['cur_location'].to(DEVICE) if state['cur_location'] is not None else None
266
- loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=DEVICE)
267
  move_emb = _compute_move_embedding(loc_tensor, cur_loc)
268
 
269
- patch = extract_patch(state['canvas_tensor'], px, py).to(DEVICE)
270
 
271
  img_seq = patch.unsqueeze(0).unsqueeze(0) # (1, 1, 768)
272
  move_seq = move_emb.unsqueeze(0) # (1, 1, 512)
@@ -280,6 +297,9 @@ def on_click(evt: gr.SelectData, original_image, state):
280
  state['drawn_positions'].append((px, py))
281
  state['sequence_length'] += 1
282
 
 
 
 
283
  img_display, _ = draw_patches_on_image(
284
  state['original_image'], state['drawn_positions'],
285
  state['x_offset'], state['y_offset'], state['h'], state['w']
 
53
  from mamba_ssm.utils.generation import InferenceParams
54
 
55
  # Global Configuration
 
 
 
56
  TARGET_CANVAS_SIZE = 512
57
  PATCH_SIZE = 16
58
  CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
 
76
 
77
  def get_model():
78
  global _GLOBAL_MODEL
79
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
  if _GLOBAL_MODEL is None:
81
  print(f"Downloading {MODEL_FILENAME} from {MODEL_REPO}...")
82
  try:
83
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
84
  model = MambaEye(**MODEL_CONFIG)
85
 
86
+ # Since this runs inside ZeroGPU worker, load directly to device
87
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
88
+ model.to(device)
89
  model.eval()
90
  _GLOBAL_MODEL = model
91
  print("Model loaded successfully.")
92
  except Exception as e:
93
  print(f"Failed to load model: {e}")
94
  raise
95
+ return _GLOBAL_MODEL, device
96
+
97
+ def transfer_inference_params(params, device):
98
+ """Recursively moves the KV cache state of MambaEye InferenceParams to CPU or CUDA."""
99
+ if params is None or getattr(params, "key_value_memory_dict", None) is None:
100
+ return params
101
+
102
+ for k, v in params.key_value_memory_dict.items():
103
+ if isinstance(v, torch.Tensor):
104
+ params.key_value_memory_dict[k] = v.to(device)
105
+ elif isinstance(v, tuple):
106
+ params.key_value_memory_dict[k] = tuple(x.to(device) if isinstance(x, torch.Tensor) else x for x in v)
107
+ elif isinstance(v, list):
108
+ params.key_value_memory_dict[k] = [x.to(device) if isinstance(x, torch.Tensor) else x for x in v]
109
+ elif isinstance(v, dict): # E.g., layers map
110
+ for k2, v2 in v.items():
111
+ if hasattr(v2, "to"):
112
+ params.key_value_memory_dict[k][k2] = v2.to(device)
113
+ return params
114
 
115
  def _compute_move_embedding(patch_location: torch.Tensor, cur_location: torch.Tensor = None) -> torch.Tensor:
116
  if cur_location is None:
 
200
  if image is None:
201
  return None, {"Upload Image": 1.0}, None, "Upload Image"
202
 
203
+ model, device = get_model()
 
204
 
205
  state = init_state_for_image(image)
206
 
 
218
  )
219
 
220
  inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
 
221
 
222
  patches_list = []
223
  moves_list = []
224
  cur_location = None
225
 
226
  for px, py in positions:
227
+ loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
228
  move_emb = _compute_move_embedding(loc_tensor, cur_location)
229
  cur_location = loc_tensor
230
 
231
+ patch = extract_patch(state['canvas_tensor'], px, py).to(device)
232
  patches_list.append(patch)
233
  moves_list.append(move_emb.squeeze(0))
234
 
 
243
  state['cur_location'] = cur_location.cpu()
244
  state['drawn_positions'] = positions
245
  state['sequence_length'] = sequence_length
246
+
247
+ # On ZeroGPU spaces securely move Tensors back to CPU State
248
  state['canvas_tensor'] = state['canvas_tensor'].cpu()
249
+ state['inference_params'] = transfer_inference_params(inference_params, torch.device('cpu'))
250
 
251
  img_display, _ = draw_patches_on_image(
252
  state['original_image'], state['drawn_positions'],
 
260
  if original_image is None:
261
  return None, {"Upload Image": 1.0}, state, "Upload Image"
262
 
263
+ model, device = get_model()
 
264
 
265
  if state is None or state.get('inference_params') is None:
 
266
  state = init_state_for_image(original_image)
267
  state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
268
 
269
+ # Move InferenceParams back to the functional device correctly!
270
+ state['inference_params'] = transfer_inference_params(state['inference_params'], device)
271
+
272
  x_orig, y_orig = evt.index
273
  orig_h, orig_w = state['original_image'].shape[:2]
274
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
 
279
  px = (canvas_x // PATCH_SIZE) * PATCH_SIZE
280
  py = (canvas_y // PATCH_SIZE) * PATCH_SIZE
281
 
282
+ cur_loc = state['cur_location'].to(device) if state['cur_location'] is not None else None
283
+ loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
284
  move_emb = _compute_move_embedding(loc_tensor, cur_loc)
285
 
286
+ patch = extract_patch(state['canvas_tensor'], px, py).to(device)
287
 
288
  img_seq = patch.unsqueeze(0).unsqueeze(0) # (1, 1, 768)
289
  move_seq = move_emb.unsqueeze(0) # (1, 1, 512)
 
297
  state['drawn_positions'].append((px, py))
298
  state['sequence_length'] += 1
299
 
300
+ # Strip back to CPU for Gradio Session Memory
301
+ state['inference_params'] = transfer_inference_params(state['inference_params'], torch.device('cpu'))
302
+
303
  img_display, _ = draw_patches_on_image(
304
  state['original_image'], state['drawn_positions'],
305
  state['x_offset'], state['y_offset'], state['h'], state['w']