MogensR commited on
Commit
6ca9173
·
1 Parent(s): e6ecd1a

soon fixed

Browse files
Files changed (1) hide show
  1. models/matany_compat_patch.py +18 -10
models/matany_compat_patch.py CHANGED
@@ -49,17 +49,25 @@ def apply_matany_t1_squeeze_guard() -> bool:
49
  # Store original method
50
  orig_method = getattr(MatAnyone, method_name)
51
 
52
- def method_compat(self, img, *args, **kwargs):
53
- try:
54
- if isinstance(img, torch.Tensor) and img.dim() == 5 and img.shape[1] == 1:
55
- log.info(f"[MatAnyCompat] Squeezing 5D {img.shape} to 4D {img.squeeze(1).shape} in {method_name}")
56
- img = img.squeeze(1) # [B,1,C,H,W] [B,C,H,W]
57
- except Exception as e:
58
- log.warning(f"[MatAnyCompat] Failed to process input shape in {method_name}: %s", e)
59
- return orig_method(self, img, *args, **kwargs)
 
 
 
 
 
 
 
 
60
 
61
- setattr(MatAnyone, method_name, method_compat)
62
- setattr(MatAnyone, f"_{method_name}_patched", True)
63
  log.info(f"[MatAnyCompat] Applied T=1 squeeze guard in MatAnyone.{method_name}")
64
  patched = True
65
 
 
49
  # Store original method
50
  orig_method = getattr(MatAnyone, method_name)
51
 
52
+ def create_method_compat(orig_fn, name):
53
+ def method_compat(self, img, *args, **kwargs):
54
+ try:
55
+ if isinstance(img, torch.Tensor):
56
+ # Handle [B, 1, C, H, W] -> [B, C, H, W]
57
+ if img.dim() == 5 and img.shape[1] == 1:
58
+ log.info(f"[MatAnyCompat] Squeezing 5D {img.shape} to 4D {img.squeeze(1).shape} in {name}")
59
+ img = img.squeeze(1)
60
+ # Handle [1, 1, C, H, W] -> [C, H, W] (remove both batch and time dims)
61
+ elif img.dim() == 5 and img.shape[0] == 1 and img.shape[1] == 1:
62
+ log.info(f"[MatAnyCompat] Squeezing 5D {img.shape} to 3D {img.squeeze(0).squeeze(0).shape} in {name}")
63
+ img = img.squeeze(0).squeeze(0) # Remove batch and time dimensions
64
+ except Exception as e:
65
+ log.warning(f"[MatAnyCompat] Failed to process input shape in {name}: %s", e)
66
+ return orig_fn(self, img, *args, **kwargs)
67
+ return method_compat
68
 
69
+ setattr(MatAnyone, method_name, create_method_compat(orig_method, method_name))
70
+ setattr(MatAnyone, f"_{method_name}_patched", True) # FIXED: Changed MatAny to MatAnyone
71
  log.info(f"[MatAnyCompat] Applied T=1 squeeze guard in MatAnyone.{method_name}")
72
  patched = True
73