File size: 2,106 Bytes
22df1ea
 
 
 
 
 
 
 
 
6f3fe10
22df1ea
 
 
 
 
 
 
 
 
6f3fe10
 
0142f2c
22df1ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3d0932
22df1ea
 
 
6f3fe10
22df1ea
 
d3d0932
6f3fe10
22df1ea
 
 
d3d0932
6f3fe10
22df1ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Shared lazy singleton for SAM 2.1 (hiera-small checkpoint): model + processor.

Both card detection (prompt-based) and hand segmentation use the same
Hugging Face weights, so loading them once per process halves cold-start
cost and keeps only one copy of the encoder in memory.
"""

from __future__ import annotations

import logging
import os
import time
from typing import Tuple

# Raise the Hugging Face Hub network timeouts (both default to 10s) before
# transformers -- and with it huggingface_hub, which reads these variables
# when it is imported -- is loaded. On flaky networks a 10s limit triggers a
# retry storm even when the weights are already cached locally.
#   - HF_HUB_ETAG_TIMEOUT governs the metadata (HEAD/etag) check.
#   - HF_HUB_DOWNLOAD_TIMEOUT governs the actual file download.
# setdefault keeps any value the operator exported explicitly.
# NOTE(review): this only takes effect if huggingface_hub has not already been
# imported elsewhere in the process before this module.
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "60")
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "60")

logger = logging.getLogger(__name__)

# Hugging Face Hub repository of the checkpoint shared by every caller.
SAM2_MODEL_ID = "facebook/sam2.1-hiera-small"

# Upper bound for the longer image side passed to the model. SAM resizes its
# input to 1024 internally, so anything larger only burns extra CPU in the
# image encoder.
INFERENCE_MAX_SIDE = 1024

# Process-wide cache filled in by get_sam2(); both stay None until first load.
_model = _processor = None


def _load_sam2(local_files_only: bool) -> Tuple[object, object]:
    """Load a fresh ``(model, processor)`` pair for ``SAM2_MODEL_ID`` on CPU.

    With ``local_files_only=True`` the weights are resolved strictly from the
    local HF cache; a cache miss raises instead of contacting the Hub.
    """
    # Deferred import: transformers is heavy and only needed on first load.
    from transformers import Sam2Model, Sam2Processor

    processor = Sam2Processor.from_pretrained(SAM2_MODEL_ID, local_files_only=local_files_only)
    model = (
        Sam2Model.from_pretrained(SAM2_MODEL_ID, local_files_only=local_files_only)
        .to("cpu")
        .eval()
    )
    return model, processor


def get_sam2() -> Tuple[object, object]:
    """Return the ``(model, processor)`` singletons, loading them on first call.

    Tries the local HF cache first (``local_files_only=True``). This avoids
    the HEAD-request retry storm that happens when huggingface.co is slow or
    unreachable but the weights are already on disk. On a cache miss
    (``OSError`` / ``ValueError``) it falls through to a normal online load.

    Both globals are published only after the whole load succeeds, so a
    failure never leaves a half-populated cache behind. The next call then
    retries from scratch.

    Returns:
        ``(model, processor)``. The model is on CPU and in eval mode.

    Raises:
        Whatever the online ``from_pretrained`` calls raise if the fallback
        download also fails.

    Note: there is no locking, so concurrent first calls from several threads
    may each load their own copy.
    """
    global _model, _processor
    if _model is not None and _processor is not None:
        return _model, _processor

    t0 = time.perf_counter()  # monotonic; immune to wall-clock adjustments
    logger.info("loading SAM 2.1 (%s)", SAM2_MODEL_ID)
    try:
        model, processor = _load_sam2(local_files_only=True)
        source = "offline cache"
    except (OSError, ValueError) as exc:
        # Cache miss: record why before falling back to a network download.
        logger.info("SAM 2.1 not available from local cache (%s); downloading", exc)
        model, processor = _load_sam2(local_files_only=False)
        source = "online"

    _model, _processor = model, processor
    logger.info("SAM 2.1 loaded (%s) in %.1fs", source, time.perf_counter() - t0)
    return _model, _processor