Spaces:
Sleeping
Sleeping
Commit ·
0f440c7
1
Parent(s): f33c68f
chore: initialize git repository and project structure
Browse files
app.py
CHANGED
|
@@ -53,9 +53,6 @@ from mambaeye.positional_encoding import sinusoidal_position_encoding_2d
|
|
| 53 |
from mamba_ssm.utils.generation import InferenceParams
|
| 54 |
|
| 55 |
# Global Configuration
|
| 56 |
-
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
-
print(f"Using device: {DEVICE}")
|
| 58 |
-
|
| 59 |
TARGET_CANVAS_SIZE = 512
|
| 60 |
PATCH_SIZE = 16
|
| 61 |
CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
|
|
@@ -79,22 +76,41 @@ _GLOBAL_MODEL = None
|
|
| 79 |
|
| 80 |
def get_model():
|
| 81 |
global _GLOBAL_MODEL
|
|
|
|
| 82 |
if _GLOBAL_MODEL is None:
|
| 83 |
print(f"Downloading {MODEL_FILENAME} from {MODEL_REPO}...")
|
| 84 |
try:
|
| 85 |
checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
|
| 86 |
model = MambaEye(**MODEL_CONFIG)
|
| 87 |
|
| 88 |
-
#
|
| 89 |
-
|
| 90 |
-
model.
|
| 91 |
model.eval()
|
| 92 |
_GLOBAL_MODEL = model
|
| 93 |
print("Model loaded successfully.")
|
| 94 |
except Exception as e:
|
| 95 |
print(f"Failed to load model: {e}")
|
| 96 |
raise
|
| 97 |
-
return _GLOBAL_MODEL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
def _compute_move_embedding(patch_location: torch.Tensor, cur_location: torch.Tensor = None) -> torch.Tensor:
|
| 100 |
if cur_location is None:
|
|
@@ -184,8 +200,7 @@ def run_auto_scan(image, scan_pattern, sequence_length):
|
|
| 184 |
if image is None:
|
| 185 |
return None, {"Upload Image": 1.0}, None, "Upload Image"
|
| 186 |
|
| 187 |
-
model = get_model()
|
| 188 |
-
model.to(DEVICE)
|
| 189 |
|
| 190 |
state = init_state_for_image(image)
|
| 191 |
|
|
@@ -203,18 +218,17 @@ def run_auto_scan(image, scan_pattern, sequence_length):
|
|
| 203 |
)
|
| 204 |
|
| 205 |
inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
|
| 206 |
-
state['inference_params'] = inference_params
|
| 207 |
|
| 208 |
patches_list = []
|
| 209 |
moves_list = []
|
| 210 |
cur_location = None
|
| 211 |
|
| 212 |
for px, py in positions:
|
| 213 |
-
loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=
|
| 214 |
move_emb = _compute_move_embedding(loc_tensor, cur_location)
|
| 215 |
cur_location = loc_tensor
|
| 216 |
|
| 217 |
-
patch = extract_patch(state['canvas_tensor'], px, py).to(
|
| 218 |
patches_list.append(patch)
|
| 219 |
moves_list.append(move_emb.squeeze(0))
|
| 220 |
|
|
@@ -229,8 +243,10 @@ def run_auto_scan(image, scan_pattern, sequence_length):
|
|
| 229 |
state['cur_location'] = cur_location.cpu()
|
| 230 |
state['drawn_positions'] = positions
|
| 231 |
state['sequence_length'] = sequence_length
|
| 232 |
-
|
|
|
|
| 233 |
state['canvas_tensor'] = state['canvas_tensor'].cpu()
|
|
|
|
| 234 |
|
| 235 |
img_display, _ = draw_patches_on_image(
|
| 236 |
state['original_image'], state['drawn_positions'],
|
|
@@ -244,14 +260,15 @@ def on_click(evt: gr.SelectData, original_image, state):
|
|
| 244 |
if original_image is None:
|
| 245 |
return None, {"Upload Image": 1.0}, state, "Upload Image"
|
| 246 |
|
| 247 |
-
model = get_model()
|
| 248 |
-
model.to(DEVICE)
|
| 249 |
|
| 250 |
if state is None or state.get('inference_params') is None:
|
| 251 |
-
# Initialize state to begin a new purely user-guided sequence
|
| 252 |
state = init_state_for_image(original_image)
|
| 253 |
state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
|
| 254 |
|
|
|
|
|
|
|
|
|
|
| 255 |
x_orig, y_orig = evt.index
|
| 256 |
orig_h, orig_w = state['original_image'].shape[:2]
|
| 257 |
ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
|
|
@@ -262,11 +279,11 @@ def on_click(evt: gr.SelectData, original_image, state):
|
|
| 262 |
px = (canvas_x // PATCH_SIZE) * PATCH_SIZE
|
| 263 |
py = (canvas_y // PATCH_SIZE) * PATCH_SIZE
|
| 264 |
|
| 265 |
-
cur_loc = state['cur_location'].to(
|
| 266 |
-
loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=
|
| 267 |
move_emb = _compute_move_embedding(loc_tensor, cur_loc)
|
| 268 |
|
| 269 |
-
patch = extract_patch(state['canvas_tensor'], px, py).to(
|
| 270 |
|
| 271 |
img_seq = patch.unsqueeze(0).unsqueeze(0) # (1, 1, 768)
|
| 272 |
move_seq = move_emb.unsqueeze(0) # (1, 1, 512)
|
|
@@ -280,6 +297,9 @@ def on_click(evt: gr.SelectData, original_image, state):
|
|
| 280 |
state['drawn_positions'].append((px, py))
|
| 281 |
state['sequence_length'] += 1
|
| 282 |
|
|
|
|
|
|
|
|
|
|
| 283 |
img_display, _ = draw_patches_on_image(
|
| 284 |
state['original_image'], state['drawn_positions'],
|
| 285 |
state['x_offset'], state['y_offset'], state['h'], state['w']
|
|
|
|
| 53 |
from mamba_ssm.utils.generation import InferenceParams
|
| 54 |
|
| 55 |
# Global Configuration
|
|
|
|
|
|
|
|
|
|
| 56 |
TARGET_CANVAS_SIZE = 512
|
| 57 |
PATCH_SIZE = 16
|
| 58 |
CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
|
|
|
|
| 76 |
|
| 77 |
def get_model():
|
| 78 |
global _GLOBAL_MODEL
|
| 79 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 80 |
if _GLOBAL_MODEL is None:
|
| 81 |
print(f"Downloading {MODEL_FILENAME} from {MODEL_REPO}...")
|
| 82 |
try:
|
| 83 |
checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
|
| 84 |
model = MambaEye(**MODEL_CONFIG)
|
| 85 |
|
| 86 |
+
# Since this runs inside ZeroGPU worker, load directly to device
|
| 87 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
| 88 |
+
model.to(device)
|
| 89 |
model.eval()
|
| 90 |
_GLOBAL_MODEL = model
|
| 91 |
print("Model loaded successfully.")
|
| 92 |
except Exception as e:
|
| 93 |
print(f"Failed to load model: {e}")
|
| 94 |
raise
|
| 95 |
+
return _GLOBAL_MODEL, device
|
| 96 |
+
|
| 97 |
+
def transfer_inference_params(params, device):
|
| 98 |
+
"""Recursively moves the KV cache state of MambaEye InferenceParams to CPU or CUDA."""
|
| 99 |
+
if params is None or getattr(params, "key_value_memory_dict", None) is None:
|
| 100 |
+
return params
|
| 101 |
+
|
| 102 |
+
for k, v in params.key_value_memory_dict.items():
|
| 103 |
+
if isinstance(v, torch.Tensor):
|
| 104 |
+
params.key_value_memory_dict[k] = v.to(device)
|
| 105 |
+
elif isinstance(v, tuple):
|
| 106 |
+
params.key_value_memory_dict[k] = tuple(x.to(device) if isinstance(x, torch.Tensor) else x for x in v)
|
| 107 |
+
elif isinstance(v, list):
|
| 108 |
+
params.key_value_memory_dict[k] = [x.to(device) if isinstance(x, torch.Tensor) else x for x in v]
|
| 109 |
+
elif isinstance(v, dict): # E.g., layers map
|
| 110 |
+
for k2, v2 in v.items():
|
| 111 |
+
if hasattr(v2, "to"):
|
| 112 |
+
params.key_value_memory_dict[k][k2] = v2.to(device)
|
| 113 |
+
return params
|
| 114 |
|
| 115 |
def _compute_move_embedding(patch_location: torch.Tensor, cur_location: torch.Tensor = None) -> torch.Tensor:
|
| 116 |
if cur_location is None:
|
|
|
|
| 200 |
if image is None:
|
| 201 |
return None, {"Upload Image": 1.0}, None, "Upload Image"
|
| 202 |
|
| 203 |
+
model, device = get_model()
|
|
|
|
| 204 |
|
| 205 |
state = init_state_for_image(image)
|
| 206 |
|
|
|
|
| 218 |
)
|
| 219 |
|
| 220 |
inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
|
|
|
|
| 221 |
|
| 222 |
patches_list = []
|
| 223 |
moves_list = []
|
| 224 |
cur_location = None
|
| 225 |
|
| 226 |
for px, py in positions:
|
| 227 |
+
loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
|
| 228 |
move_emb = _compute_move_embedding(loc_tensor, cur_location)
|
| 229 |
cur_location = loc_tensor
|
| 230 |
|
| 231 |
+
patch = extract_patch(state['canvas_tensor'], px, py).to(device)
|
| 232 |
patches_list.append(patch)
|
| 233 |
moves_list.append(move_emb.squeeze(0))
|
| 234 |
|
|
|
|
| 243 |
state['cur_location'] = cur_location.cpu()
|
| 244 |
state['drawn_positions'] = positions
|
| 245 |
state['sequence_length'] = sequence_length
|
| 246 |
+
|
| 247 |
+
# On ZeroGPU spaces securely move Tensors back to CPU State
|
| 248 |
state['canvas_tensor'] = state['canvas_tensor'].cpu()
|
| 249 |
+
state['inference_params'] = transfer_inference_params(inference_params, torch.device('cpu'))
|
| 250 |
|
| 251 |
img_display, _ = draw_patches_on_image(
|
| 252 |
state['original_image'], state['drawn_positions'],
|
|
|
|
| 260 |
if original_image is None:
|
| 261 |
return None, {"Upload Image": 1.0}, state, "Upload Image"
|
| 262 |
|
| 263 |
+
model, device = get_model()
|
|
|
|
| 264 |
|
| 265 |
if state is None or state.get('inference_params') is None:
|
|
|
|
| 266 |
state = init_state_for_image(original_image)
|
| 267 |
state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
|
| 268 |
|
| 269 |
+
# Move InferenceParams back to the functional device correctly!
|
| 270 |
+
state['inference_params'] = transfer_inference_params(state['inference_params'], device)
|
| 271 |
+
|
| 272 |
x_orig, y_orig = evt.index
|
| 273 |
orig_h, orig_w = state['original_image'].shape[:2]
|
| 274 |
ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
|
|
|
|
| 279 |
px = (canvas_x // PATCH_SIZE) * PATCH_SIZE
|
| 280 |
py = (canvas_y // PATCH_SIZE) * PATCH_SIZE
|
| 281 |
|
| 282 |
+
cur_loc = state['cur_location'].to(device) if state['cur_location'] is not None else None
|
| 283 |
+
loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
|
| 284 |
move_emb = _compute_move_embedding(loc_tensor, cur_loc)
|
| 285 |
|
| 286 |
+
patch = extract_patch(state['canvas_tensor'], px, py).to(device)
|
| 287 |
|
| 288 |
img_seq = patch.unsqueeze(0).unsqueeze(0) # (1, 1, 768)
|
| 289 |
move_seq = move_emb.unsqueeze(0) # (1, 1, 512)
|
|
|
|
| 297 |
state['drawn_positions'].append((px, py))
|
| 298 |
state['sequence_length'] += 1
|
| 299 |
|
| 300 |
+
# Strip back to CPU for Gradio Session Memory
|
| 301 |
+
state['inference_params'] = transfer_inference_params(state['inference_params'], torch.device('cpu'))
|
| 302 |
+
|
| 303 |
img_display, _ = draw_patches_on_image(
|
| 304 |
state['original_image'], state['drawn_positions'],
|
| 305 |
state['x_offset'], state['y_offset'], state['h'], state['w']
|