"""Singleton model loader — downloads reflow-1-big from HF Hub once.""" import os import sys import torch import tiktoken # Try to import from models/ directory first, then from root try: from models.reflow import GPT, GPTConfig except ImportError: from reflow import GPT, GPTConfig _model = None _enc = None _device = None _W_v2s = None _signal_basis = None # Local paths to check before downloading (relative to this file's directory) _LOCAL_CKPT_CANDIDATES = [ "model/ckpt.pt", ] def _find_local_ckpt(): """Search for a local checkpoint file relative to the project root.""" base_dir = os.path.dirname(os.path.abspath(__file__)) for rel in _LOCAL_CKPT_CANDIDATES: path = os.path.join(base_dir, rel) if os.path.isfile(path): return path return None def get_model(): """Return (model, enc, device). First call downloads & loads the checkpoint.""" global _model, _enc, _device, _W_v2s, _signal_basis if _model is not None: return _model, _enc, _device _device = "cuda" if torch.cuda.is_available() else "cpu" print(f"[model_loader] device = {_device}") # 1. Try local checkpoint first ckpt_path = _find_local_ckpt() if ckpt_path: print(f"[model_loader] Found local checkpoint: {ckpt_path}") else: # 2. Fall back to download from HF Hub or ModelScope print("[model_loader] No local checkpoint found, downloading from cloud ...") try: from huggingface_hub import hf_hub_download print("[model_loader] Trying Hugging Face Hub ...") ckpt_path = hf_hub_download( repo_id="reuAC/reFlow", filename="out/reflow-1-big/ckpt.pt", ) print(f"[model_loader] Downloaded from Hugging Face: {ckpt_path}") except Exception as e: print(f"[model_loader] Hugging Face download failed: {e}") print("[model_loader] Trying ModelScope ...") from modelscope.hub.file_download import model_file_download ckpt_path = model_file_download( model_id="recuAC/reFlow", file_path="out/reflow-1-big/ckpt.pt", ) print(f"[model_loader] Downloaded from ModelScope: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location=_device, weights_only=False) # Build model from saved config model_args = checkpoint["model_args"] config = GPTConfig(**model_args) _model = GPT(config) # Strip _orig_mod. prefix (torch.compile artifact) state_dict = checkpoint["model"] for k in list(state_dict.keys()): if k.startswith("_orig_mod."): state_dict[k[len("_orig_mod."):]] = state_dict.pop(k) _model.load_state_dict(state_dict) _model.eval().to(_device) # Cache frequently-used tensors _W_v2s = _model.transformer.wte.vocab_to_signals.weight.data _signal_basis = _model.transformer.wte.signal_basis.data _enc = tiktoken.get_encoding("gpt2") print("[model_loader] Model ready.") return _model, _enc, _device def get_cached_tensors(): """Return (W_v2s, signal_basis) — call get_model() first.""" if _W_v2s is None: get_model() return _W_v2s, _signal_basis