File size: 1,667 Bytes
77e37fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import annotations

import os
from pathlib import Path
from typing import Callable

from huggingface_hub import snapshot_download

# Upstream repo id of the target model; overridable via env for forks/testing.
TARGET_OMNI_MODEL = os.getenv("PB3D_TARGET_MODEL", "tencent/Hunyuan3D-Omni")
# Root directory under which model snapshots are cached locally.
DEFAULT_CACHE_ROOT = Path(os.getenv("PB3D_MODEL_CACHE", "./models"))


def get_target_model_dir(model_id: str = TARGET_OMNI_MODEL) -> Path:
    """Return the local cache directory for *model_id*.

    The repo id's ``/`` separator is flattened to ``--`` so the whole id
    maps to a single directory name under the cache root.
    """
    return DEFAULT_CACHE_ROOT / model_id.replace("/", "--")


def ensure_target_model_cached(
    model_id: str = TARGET_OMNI_MODEL,
    progress: Callable[[str], None] | None = None,
) -> dict:
    """
    Best-effort local cache of the upstream target model repo.

    This fetches the model repo into the Space filesystem so later integration work
    can call the upstream inference stack directly. It does not assume internal file
    names beyond the official repo id.

    Args:
        model_id: Hugging Face repo id to download and cache.
        progress: Optional callback invoked with human-readable status text.

    Returns:
        A dict with keys ``ok`` (bool), ``model_id``, ``local_path`` and
        ``message``. Never raises: any download failure is reported via
        ``ok=False`` instead.
    """
    target_dir = get_target_model_dir(model_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    if progress:
        progress(f"Checking local cache for {model_id}…")

    try:
        # NOTE: the deprecated `local_dir_use_symlinks` and `resume_download`
        # kwargs were removed. In current huggingface_hub releases, passing a
        # `local_dir` already copies real files (no symlinks) and downloads
        # always resume; the old kwargs only emit FutureWarnings, and
        # `resume_download` is dropped entirely in huggingface_hub >= 1.0.
        local_path = snapshot_download(
            repo_id=model_id,
            local_dir=str(target_dir),
        )
        return {
            "ok": True,
            "model_id": model_id,
            "local_path": local_path,
            "message": f"Cached {model_id} in {local_path}",
        }
    except Exception as exc:  # pragma: no cover - best-effort, network-dependent
        return {
            "ok": False,
            "model_id": model_id,
            "local_path": str(target_dir),
            "message": f"Could not cache {model_id}: {exc}",
        }