MogensR commited on
Commit
c137c1a
·
1 Parent(s): 2769fce

lunch done

Browse files
Files changed (1) hide show
  1. models/matanyone_loader.py +106 -60
models/matanyone_loader.py CHANGED
@@ -1,72 +1,118 @@
1
- # models/matanyone_loader.py
2
- import os, logging, torch, gc
3
- import numpy as np
4
- from typing import Optional, Tuple
 
 
 
5
 
6
- log = logging.getLogger("matany_loader")
 
 
 
 
7
 
8
- def _import_inference_core():
9
- try:
10
- # Check the actual import path from pq-yang/MatAnyone repo
11
- from matanyone.inference_core import InferenceCore
12
- return InferenceCore
13
- except Exception as e:
14
- log.error("MatAnyone import failed (vendoring/repo path?): %s", e)
15
- return None
16
 
17
- def _to_chw01(img):
18
- # img: HWC uint8 or float01 -> CHW float01
19
- if img.dtype != np.float32:
20
- img = img.astype("float32")/255.0
21
- return np.transpose(img, (2,0,1))
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- def _to_1hw01(mask):
24
- # mask: HxW [0,1]
25
- m = mask.astype("float32")
26
- return m[None, ...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
 
28
  class MatAnyoneSession:
29
- def __init__(self, device: torch.device, precision: str = "fp16"):
30
- self.device = device
31
- self.precision = precision
32
  self.core = None
 
33
 
34
- def load(self, ckpt_path: Optional[str] = None, repo_id: Optional[str] = None, filename: Optional[str] = None):
35
- InferenceCore = _import_inference_core()
36
- if InferenceCore is None:
37
- raise RuntimeError("MatAnyone not importable")
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- if ckpt_path is None and repo_id and filename:
40
- from huggingface_hub import hf_hub_download
41
- ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=os.environ.get("HF_HOME"))
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # init model
44
- self.core = InferenceCore(ckpt_path, device=str(self.device))
45
- return self
 
 
46
 
47
- @torch.inference_mode()
48
- def step(self, image_rgb, seed_mask: Optional[np.ndarray]=None):
49
- """
50
- image_rgb: HxWx3 uint8/float01
51
- seed_mask: HxW float01 for first frame, else None
52
- returns alpha HxW float01
53
- """
54
- assert self.core is not None, "MatAnyone not loaded"
55
- img = _to_chw01(image_rgb) # CHW
56
- if seed_mask is not None:
57
- mask = _to_1hw01(seed_mask) # 1HW
58
- alpha = self.core.step(img, mask)
59
- else:
60
- alpha = self.core.step(img, None)
61
- # ensure HxW
62
- if isinstance(alpha, np.ndarray):
63
- return alpha.astype("float32")
64
- if torch.is_tensor(alpha):
65
- return alpha.detach().float().cpu().numpy()
66
- raise RuntimeError("MatAnyone returned unknown alpha type")
67
 
68
- def reset(self):
69
- if self.core and hasattr(self.core, "reset"):
70
- self.core.reset()
71
- torch.cuda.empty_cache()
72
- gc.collect()
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MatAnyone Loader (compact)
4
+ - Uses top-level wrapper: `from matanyone import InferenceCore`
5
+ - Constructor takes a model/repo id string (e.g. "PeiqingYang/MatAnyone")
6
+ - Normalizes inputs: image -> CHW float32 [0,1], mask -> 1HW float32 [0,1]
7
+ """
8
 
9
+ from __future__ import annotations
10
+ import os, logging, time
11
+ from typing import Iterable, Optional
12
+ import numpy as np
13
+ import torch
14
 
15
+ logger = logging.getLogger("backgroundfx_pro")
 
 
 
 
 
 
 
16
 
17
+ # ---------- tiny helpers ----------
18
+ def _to_chw_float01(x: np.ndarray | torch.Tensor) -> torch.Tensor:
19
+ if isinstance(x, np.ndarray):
20
+ t = torch.from_numpy(x)
21
+ else:
22
+ t = x
23
+ if t.ndim == 3 and t.shape[-1] in (1, 3, 4): # HWC
24
+ t = t.permute(2, 0, 1) # -> CHW
25
+ elif t.ndim == 2: # HW -> 1HW
26
+ t = t.unsqueeze(0)
27
+ elif t.ndim != 3:
28
+ raise ValueError(f"image: bad shape {tuple(t.shape)}")
29
+ t = t.contiguous().to(torch.float32)
30
+ with torch.no_grad():
31
+ if t.numel() and (torch.nanmax(t) > 1.0 or torch.nanmin(t) < 0.0):
32
+ t = t / 255.0
33
+ t.clamp_(0.0, 1.0)
34
+ return t
35
 
36
+ def _to_1hw_float01(m: np.ndarray | torch.Tensor) -> torch.Tensor:
37
+ if isinstance(m, np.ndarray):
38
+ t = torch.from_numpy(m)
39
+ else:
40
+ t = m
41
+ if t.ndim == 2: # HW
42
+ t = t.unsqueeze(0) # -> 1HW
43
+ elif t.ndim == 3:
44
+ if t.shape[0] in (1, 3): # CHW
45
+ t = t[:1, ...]
46
+ elif t.shape[-1] in (1, 3): # HWC
47
+ t = t[..., 0]
48
+ t = t.unsqueeze(0)
49
+ else:
50
+ raise ValueError(f"mask: bad shape {tuple(t.shape)}")
51
+ else:
52
+ raise ValueError(f"mask: bad shape {tuple(t.shape)}")
53
+ t = t.contiguous().to(torch.float32)
54
+ with torch.no_grad():
55
+ if t.numel() and (torch.nanmax(t) > 1.0 or torch.nanmin(t) < 0.0):
56
+ t = t / 255.0
57
+ t.clamp_(0.0, 1.0)
58
+ return t
59
 
60
+ # ---------- session ----------
61
  class MatAnyoneSession:
62
+ def __init__(self, device: Optional[str] = None, repo_id: Optional[str] = None) -> None:
63
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
64
+ self.repo_id = repo_id or os.getenv("MATANY_REPO_ID", "PeiqingYang/MatAnyone")
65
  self.core = None
66
+ self.loaded = False
67
 
68
+ def load(self) -> bool:
69
+ t0 = time.time()
70
+ try:
71
+ # ✅ top-level wrapper (accepts model/repo id string)
72
+ from matanyone import InferenceCore
73
+ logger.info("[MatA] init: repo_id=%s device=%s", self.repo_id, self.device)
74
+ self.core = InferenceCore(self.repo_id)
75
+ self.loaded = True
76
+ logger.info("[MatA] init OK (%.2fs)", time.time() - t0)
77
+ return True
78
+ except TypeError as e:
79
+ logger.error("MatAnyone constructor mismatch: %s (fork expects network=...)", e)
80
+ except Exception as e:
81
+ logger.error("MatAnyone init error: %s", e)
82
+ self.loaded = False
83
+ return False
84
 
85
+ def step(self, image: np.ndarray | torch.Tensor, seed_mask: np.ndarray | torch.Tensor) -> np.ndarray:
86
+ if not self.loaded or self.core is None:
87
+ raise RuntimeError("MatAnyone not loaded")
88
+ img = _to_chw_float01(image).to(self.device, non_blocking=True)
89
+ msk = _to_1hw_float01(seed_mask).to(self.device, non_blocking=True)
90
+ out = self.core.step(img, msk)
91
+ alpha = out[0] if isinstance(out, (tuple, list)) else out
92
+ if not isinstance(alpha, torch.Tensor):
93
+ alpha = torch.as_tensor(alpha)
94
+ if alpha.ndim == 3 and alpha.shape[0] == 1:
95
+ alpha = alpha[0]
96
+ if alpha.ndim != 2:
97
+ raise ValueError(f"alpha: bad shape {tuple(alpha.shape)}")
98
+ return alpha.detach().to("cpu", torch.float32).clamp_(0.0, 1.0).contiguous().numpy()
99
 
100
+ def process_video(self, frames: Iterable[np.ndarray | torch.Tensor], seed_mask_hw, every: int = 50):
101
+ for i, f in enumerate(frames, 1):
102
+ yield self.step(f, seed_mask_hw)
103
+ if every and (i % every == 0):
104
+ logger.info("[MatA] processed %d frames", i)
105
 
106
+ def close(self) -> None:
107
+ self.core = None
108
+ self.loaded = False
109
+ if torch.cuda.is_available():
110
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
# ---------- factory ----------
def get_matanyone_session(enable: bool = True) -> Optional[MatAnyoneSession]:
    """Build and load a MatAnyoneSession.

    Returns the loaded session, or None when disabled or loading fails.
    """
    if not enable:
        logger.info("[MatA] disabled.")
        return None
    session = MatAnyoneSession()
    if session.load():
        return session
    return None