MogensR committed on
Commit
42d7c0b
·
1 Parent(s): 84cb6bd

Phase 1: Add SAM2/MatAnyone optimization infrastructure

Browse files
Dockerfile CHANGED
@@ -1,11 +1,57 @@
1
- # syntax=docker/dockerfile:1
2
- FROM python:3.11-slim
3
 
4
- WORKDIR /code
 
 
 
 
5
 
6
- COPY requirements.txt .
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  RUN pip install --no-cache-dir -r requirements.txt
8
 
 
 
 
 
 
 
 
 
 
 
9
  COPY . .
10
 
11
- CMD ["python", "app.py"]
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Production Dockerfile for BackgroundFX Pro with SAM2 + MatAnyone
2
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
3
 
4
+ # System dependencies
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ git ffmpeg libglib2.0-0 libgl1 libglib2.0-0 libsm6 libxrender1 libxext6 \
7
+ python3.10 python3.10-venv python3-pip \
8
+ && rm -rf /var/lib/apt/lists/*
9
 
10
+ # Upgrade pip
11
+ RUN python3 -m pip install --upgrade pip
12
+
13
+ # Environment variables for caching and performance
14
+ ENV HF_HOME=/home/user/.cache/huggingface \
15
+ TORCH_HOME=/home/user/.cache/torch \
16
+ TRANSFORMERS_CACHE=/home/user/.cache/transformers \
17
+ MPLCONFIGDIR=/home/user/.config/matplotlib
18
+
19
+ # CUDA and memory optimizations for T4
20
+ ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,expandable_segments:True \
21
+ CUDA_LAUNCH_BLOCKING=0 \
22
+ OMP_NUM_THREADS=2 \
23
+ MKL_NUM_THREADS=2 \
24
+ HF_HUB_ENABLE_HF_TRANSFER=1 \
25
+ TOKENIZERS_PARALLELISM=false
26
+
27
+ # Create working directory
28
+ WORKDIR /home/user/app
29
+
30
+ # Copy and install Python dependencies
31
+ COPY requirements.txt ./requirements.txt
32
  RUN pip install --no-cache-dir -r requirements.txt
33
 
34
+ # Vendor SAM2 and MatAnyone at build time (more reliable than runtime git)
35
+ # SAM2
36
+ RUN git clone --depth=1 https://github.com/facebookresearch/segment-anything-2 /home/user/app/third_party/sam2
37
+ ENV PYTHONPATH=/home/user/app/third_party/sam2:${PYTHONPATH}
38
+
39
+ # MatAnyone (official repo)
40
+ RUN git clone --depth=1 https://github.com/pq-yang/MatAnyone /home/user/app/third_party/matanyone
41
+ ENV PYTHONPATH=/home/user/app/third_party/matanyone:${PYTHONPATH}
42
+
43
+ # Copy application code
44
  COPY . .
45
 
46
+ # Create cache directories
47
+ RUN mkdir -p /home/user/.cache/huggingface /home/user/.cache/torch /home/user/.cache/transformers
48
+
49
+ # Expose Gradio port
50
+ EXPOSE 7860
51
+
52
+ # Environment for Gradio
53
+ ENV GRADIO_SERVER_NAME=0.0.0.0 \
54
+ GRADIO_SERVER_PORT=7860
55
+
56
+ # Run the application
57
+ CMD ["python3", "app.py"]
models/__init__.py ADDED
File without changes
models/matanyone_loader.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/matanyone_loader.py
2
+ import os, logging, torch, gc
3
+ import numpy as np
4
+ from typing import Optional, Tuple
5
+
6
+ log = logging.getLogger("matany_loader")
7
+
8
+ def _import_inference_core():
9
+ try:
10
+ # Check the actual import path from pq-yang/MatAnyone repo
11
+ from matanyone.inference_core import InferenceCore
12
+ return InferenceCore
13
+ except Exception as e:
14
+ log.error("MatAnyone import failed (vendoring/repo path?): %s", e)
15
+ return None
16
+
17
+ def _to_chw01(img):
18
+ # img: HWC uint8 or float01 -> CHW float01
19
+ if img.dtype != np.float32:
20
+ img = img.astype("float32")/255.0
21
+ return np.transpose(img, (2,0,1))
22
+
23
+ def _to_1hw01(mask):
24
+ # mask: HxW [0,1]
25
+ m = mask.astype("float32")
26
+ return m[None, ...]
27
+
28
+ class MatAnyoneSession:
29
+ def __init__(self, device: torch.device, precision: str = "fp16"):
30
+ self.device = device
31
+ self.precision = precision
32
+ self.core = None
33
+
34
+ def load(self, ckpt_path: Optional[str] = None, repo_id: Optional[str] = None, filename: Optional[str] = None):
35
+ InferenceCore = _import_inference_core()
36
+ if InferenceCore is None:
37
+ raise RuntimeError("MatAnyone not importable")
38
+
39
+ if ckpt_path is None and repo_id and filename:
40
+ from huggingface_hub import hf_hub_download
41
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=os.environ.get("HF_HOME"))
42
+
43
+ # init model
44
+ self.core = InferenceCore(ckpt_path, device=str(self.device))
45
+ return self
46
+
47
+ @torch.inference_mode()
48
+ def step(self, image_rgb, seed_mask: Optional[np.ndarray]=None):
49
+ """
50
+ image_rgb: HxWx3 uint8/float01
51
+ seed_mask: HxW float01 for first frame, else None
52
+ returns alpha HxW float01
53
+ """
54
+ assert self.core is not None, "MatAnyone not loaded"
55
+ img = _to_chw01(image_rgb) # CHW
56
+ if seed_mask is not None:
57
+ mask = _to_1hw01(seed_mask) # 1HW
58
+ alpha = self.core.step(img, mask)
59
+ else:
60
+ alpha = self.core.step(img, None)
61
+ # ensure HxW
62
+ if isinstance(alpha, np.ndarray):
63
+ return alpha.astype("float32")
64
+ if torch.is_tensor(alpha):
65
+ return alpha.detach().float().cpu().numpy()
66
+ raise RuntimeError("MatAnyone returned unknown alpha type")
67
+
68
+ def reset(self):
69
+ if self.core and hasattr(self.core, "reset"):
70
+ self.core.reset()
71
+ torch.cuda.empty_cache()
72
+ gc.collect()
models/sam2_loader.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/sam2_loader.py
2
+ import os, logging, torch
3
+ from huggingface_hub import hf_hub_download
4
+ from pathlib import Path
5
+ import numpy as np
6
+
7
+ log = logging.getLogger("sam2_loader")
8
+
9
+ DEFAULT_MODEL_ID = os.environ.get("SAM2_MODEL_ID", "facebook/sam2")
10
+ DEFAULT_VARIANT = os.environ.get("SAM2_VARIANT", "sam2_hiera_large")
11
+
12
+ # Map variant -> filenames (SAM2 releases follow this pattern)
13
+ VARIANT_FILES = {
14
+ "sam2_hiera_small": ("sam2_hiera_small.pt", "configs/sam2/sam2_hiera_s.yaml"),
15
+ "sam2_hiera_base": ("sam2_hiera_base.pt", "configs/sam2/sam2_hiera_b.yaml"),
16
+ "sam2_hiera_large": ("sam2_hiera_large.pt", "configs/sam2/sam2_hiera_l.yaml"),
17
+ }
18
+
19
+ def _download_checkpoint(model_id: str, ckpt_name: str) -> str:
20
+ return hf_hub_download(repo_id=model_id, filename=ckpt_name, local_dir=os.environ.get("HF_HOME"))
21
+
22
+ def _find_sam2_build():
23
+ try:
24
+ from sam2.build_sam import build_sam2
25
+ return build_sam2
26
+ except Exception as e:
27
+ log.error("SAM2 not importable (check Dockerfile vendoring): %s", e)
28
+ return None
29
+
30
+ class SAM2Predictor:
31
+ def __init__(self, device: torch.device):
32
+ self.device = device
33
+ self.model = None
34
+ self.predictor = None
35
+
36
+ def load(self, variant: str = DEFAULT_VARIANT, model_id: str = DEFAULT_MODEL_ID):
37
+ build_sam2 = _find_sam2_build()
38
+ if build_sam2 is None:
39
+ raise RuntimeError("SAM2 build function not available")
40
+
41
+ ckpt_name, cfg_path = VARIANT_FILES.get(variant, VARIANT_FILES["sam2_hiera_large"])
42
+ ckpt = _download_checkpoint(model_id, ckpt_name)
43
+
44
+ # Compose config via hydra-free path (using explicit path args)
45
+ model = build_sam2(config_file=cfg_path, ckpt_path=ckpt, device=str(self.device))
46
+ model.eval()
47
+ self.model = model
48
+
49
+ try:
50
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
51
+ self.predictor = SAM2VideoPredictor(self.model)
52
+ except Exception:
53
+ # Fallback to image predictor if video predictor missing
54
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
55
+ self.predictor = SAM2ImagePredictor(self.model)
56
+
57
+ return self
58
+
59
+ @torch.inference_mode()
60
+ def first_frame_mask(self, image_rgb01):
61
+ """
62
+ Returns an initial binary-ish mask for the foreground subject from first frame.
63
+ You can refine prompts here (points/boxes) if you add UI hooks later.
64
+ """
65
+ if hasattr(self.predictor, "set_image"):
66
+ self.predictor.set_image((image_rgb01*255).astype("uint8"))
67
+ # simple auto-box prompt (tight box)
68
+ h, w = image_rgb01.shape[:2]
69
+ box = np.array([1, 1, w-2, h-2])
70
+ masks, _, _ = self.predictor.predict(box=box, multimask_output=False)
71
+ mask = masks[0] # HxW bool/float
72
+ else:
73
+ # video predictor path: run_single_frame if available
74
+ mask = (image_rgb01[...,0] > -1) # dummy, should not happen
75
+ return mask.astype("float32")
requirements.txt CHANGED
@@ -1,8 +1,43 @@
 
 
1
  torch==2.2.2
2
  torchvision==0.17.2
 
 
 
 
 
 
 
 
 
3
  opencv-python-headless==4.10.0.84
 
 
 
 
 
4
  numpy==1.26.4
5
- pillow==10.4.0
 
6
  gradio==5.42.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  requests==2.31.0
8
- huggingface-hub>=0.33.5
 
1
+ # ===== Core runtime =====
2
+ # Option A: Keep your current Torch stack (safe for existing builds)
3
  torch==2.2.2
4
  torchvision==0.17.2
5
+ torchaudio==2.2.2
6
+
7
+ # Option B: Faster CUDA 12.1 wheels for T4 (uncomment to use instead)
8
+ # torch==2.3.1+cu121
9
+ # torchvision==0.18.1+cu121
10
+ # torchaudio==2.3.1+cu121
11
+ # --extra-index-url https://download.pytorch.org/whl/cu121
12
+
13
+ # ===== Video / image IO =====
14
  opencv-python-headless==4.10.0.84
15
+ imageio==2.35.1
16
+ imageio-ffmpeg==0.5.1
17
+ moviepy==1.0.3
18
+ decord==0.6.0
19
+ Pillow==10.4.0
20
  numpy==1.26.4
21
+
22
+ # ===== Gradio UI =====
23
  gradio==5.42.0
24
+
25
+ # ===== SAM2 Dependencies =====
26
+ hydra-core==1.3.2
27
+ omegaconf==2.3.0
28
+ einops==0.8.0
29
+ timm==1.0.9
30
+ pyyaml==6.0.2
31
+ matplotlib==3.9.2
32
+
33
+ # ===== MatAnyone Dependencies =====
34
+ kornia==0.7.3
35
+ scikit-image==0.24.0
36
+ tqdm==4.66.5
37
+
38
+ # ===== Helpers / caching =====
39
+ huggingface_hub==0.24.6
40
+ ffmpeg-python==0.2.0
41
+ psutil==6.0.0
42
  requests==2.31.0
43
+ scikit-learn==1.5.1
utils/__init__.py ADDED
File without changes
utils/accelerator.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/accelerator.py
2
+ import os, torch, logging, psutil, gc
3
+
4
+ log = logging.getLogger("accelerator")
5
+
6
+ def pick_device():
7
+ if torch.cuda.is_available():
8
+ dev = torch.device("cuda")
9
+ name = torch.cuda.get_device_name(0)
10
+ log.info(f"Using GPU: {name}")
11
+ return dev
12
+ log.warning("CUDA not available; falling back to CPU.")
13
+ return torch.device("cpu")
14
+
15
+ def torch_global_tuning():
16
+ # better matmul perf without crazy memory
17
+ try:
18
+ torch.set_float32_matmul_precision("high")
19
+ except Exception:
20
+ pass
21
+
22
+ def memory_checkpoint(tag=""):
23
+ try:
24
+ if torch.cuda.is_available():
25
+ mem = torch.cuda.memory_allocated() / (1024**2)
26
+ log.info(f"[CUDA mem] {tag}: {mem:.1f} MB")
27
+ except Exception:
28
+ pass
29
+
30
+ def cleanup():
31
+ if torch.cuda.is_available():
32
+ torch.cuda.synchronize()
33
+ torch.cuda.empty_cache()
34
+ gc.collect()