File size: 3,270 Bytes
bf44358
 
 
9e2ce74
bf44358
 
9e2ce74
 
 
 
 
 
bf44358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3726763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf44358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Singleton model loader — downloads reflow-1-big from HF Hub once."""

import os
import sys
import torch
import tiktoken

# Try to import from models/ directory first, then from root
try:
    from models.reflow import GPT, GPTConfig
except ImportError:
    from reflow import GPT, GPTConfig

# Module-level singleton cache: populated once by get_model(), then reused
# by every subsequent call in this process.
_model = None         # the loaded GPT model (eval mode, moved to _device)
_enc = None           # tiktoken encoder ("gpt2")
_device = None        # "cuda" if available at load time, else "cpu"
_W_v2s = None         # cached copy of wte.vocab_to_signals.weight.data
_signal_basis = None  # cached copy of wte.signal_basis.data

# Local paths to check before downloading (relative to this file's directory)
_LOCAL_CKPT_CANDIDATES = [
    "model/ckpt.pt",
]


def _find_local_ckpt():
    """Search for a local checkpoint file relative to the project root."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    for rel in _LOCAL_CKPT_CANDIDATES:
        path = os.path.join(base_dir, rel)
        if os.path.isfile(path):
            return path
    return None


def get_model():
    """Return (model, enc, device).  First call downloads & loads the checkpoint."""
    global _model, _enc, _device, _W_v2s, _signal_basis

    # Fast path: a previous call already populated the module-level cache.
    if _model is not None:
        return _model, _enc, _device

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[model_loader] device = {_device}")

    # Resolve a checkpoint path: local file first, then cloud fallbacks.
    ckpt_path = _find_local_ckpt()
    if ckpt_path is None:
        print("[model_loader] No local checkpoint found, downloading from cloud ...")
        try:
            from huggingface_hub import hf_hub_download
            print("[model_loader] Trying Hugging Face Hub ...")
            ckpt_path = hf_hub_download(
                repo_id="reuAC/reFlow",
                filename="out/reflow-1-big/ckpt.pt",
            )
            print(f"[model_loader] Downloaded from Hugging Face: {ckpt_path}")
        except Exception as e:
            print(f"[model_loader] Hugging Face download failed: {e}")
            print("[model_loader] Trying ModelScope ...")
            from modelscope.hub.file_download import model_file_download
            # NOTE(review): spelling differs from the HF repo_id above
            # ("recuAC" vs "reuAC") — confirm which account name is intended.
            ckpt_path = model_file_download(
                model_id="recuAC/reFlow",
                file_path="out/reflow-1-big/ckpt.pt",
            )
            print(f"[model_loader] Downloaded from ModelScope: {ckpt_path}")
    else:
        print(f"[model_loader] Found local checkpoint: {ckpt_path}")

    # weights_only=False deserializes arbitrary pickled objects (the saved
    # config); only safe because the checkpoint source is trusted — do not
    # point this at untrusted files.
    checkpoint = torch.load(ckpt_path, map_location=_device, weights_only=False)

    # Rebuild the architecture from the config stored alongside the weights.
    cfg = GPTConfig(**checkpoint["model_args"])
    _model = GPT(cfg)

    # torch.compile saves parameters under an "_orig_mod." prefix; strip it
    # so the keys match the uncompiled module.
    prefix = "_orig_mod."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["model"].items()
    }

    _model.load_state_dict(state_dict)
    _model.eval().to(_device)

    # Cache tensors that inference code reads on every request.
    _W_v2s = _model.transformer.wte.vocab_to_signals.weight.data
    _signal_basis = _model.transformer.wte.signal_basis.data

    _enc = tiktoken.get_encoding("gpt2")
    print("[model_loader] Model ready.")
    return _model, _enc, _device


def get_cached_tensors():
    """Return the cached (W_v2s, signal_basis) pair, loading the model on demand."""
    if _W_v2s is None:
        get_model()  # populates _W_v2s and _signal_basis as a side effect
    return _W_v2s, _signal_basis