| """Singleton model loader — downloads reflow-1-big from HF Hub once.""" |
|
|
| import os |
| import sys |
| import torch |
| import tiktoken |
|
|
| |
| try: |
| from models.reflow import GPT, GPTConfig |
| except ImportError: |
| from reflow import GPT, GPTConfig |
|
|
# Module-level singleton caches, populated once by get_model().
_model = None          # loaded GPT instance (eval mode, on _device)
_enc = None            # tiktoken "gpt2" encoder
_device = None         # "cuda" or "cpu"
_W_v2s = None          # cached vocab->signals projection weight tensor
_signal_basis = None   # cached signal-basis tensor


# Checkpoint paths to probe, relative to this file's directory,
# before falling back to a remote download.
_LOCAL_CKPT_CANDIDATES = [
    "model/ckpt.pt",
]
|
|
|
|
| def _find_local_ckpt(): |
| """Search for a local checkpoint file relative to the project root.""" |
| base_dir = os.path.dirname(os.path.abspath(__file__)) |
| for rel in _LOCAL_CKPT_CANDIDATES: |
| path = os.path.join(base_dir, rel) |
| if os.path.isfile(path): |
| return path |
| return None |
|
|
|
|
def get_model():
    """Return (model, enc, device). First call downloads & loads the checkpoint.

    Lazily initializes the module-level singletons:
    - ``_model``: the GPT network, in eval mode on ``_device``
    - ``_enc``: the tiktoken "gpt2" encoder
    - ``_device``: "cuda" if available, else "cpu"
    - ``_W_v2s`` / ``_signal_basis``: embedding tensors cached for
      get_cached_tensors()

    Checkpoint resolution order: local file (see _find_local_ckpt),
    then Hugging Face Hub, then ModelScope as a last resort.
    """
    global _model, _enc, _device, _W_v2s, _signal_basis

    # Fast path: everything was loaded on a previous call.
    if _model is not None:
        return _model, _enc, _device

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[model_loader] device = {_device}")

    # Locate the checkpoint: prefer a local copy, otherwise download.
    ckpt_path = _find_local_ckpt()
    if ckpt_path:
        print(f"[model_loader] Found local checkpoint: {ckpt_path}")
    else:
        print("[model_loader] No local checkpoint found, downloading from cloud ...")
        try:
            from huggingface_hub import hf_hub_download
            print("[model_loader] Trying Hugging Face Hub ...")
            ckpt_path = hf_hub_download(
                repo_id="reuAC/reFlow",
                filename="out/reflow-1-big/ckpt.pt",
            )
            print(f"[model_loader] Downloaded from Hugging Face: {ckpt_path}")
        except Exception as e:
            # Broad catch is deliberate: any HF failure (network, auth,
            # missing package) falls through to the ModelScope mirror.
            print(f"[model_loader] Hugging Face download failed: {e}")
            print("[model_loader] Trying ModelScope ...")
            from modelscope.hub.file_download import model_file_download
            # NOTE(review): repo id here is "recuAC/reFlow" vs "reuAC/reFlow"
            # on HF — confirm which spelling is the intended mirror.
            ckpt_path = model_file_download(
                model_id="recuAC/reFlow",
                file_path="out/reflow-1-big/ckpt.pt",
            )
            print(f"[model_loader] Downloaded from ModelScope: {ckpt_path}")

    # NOTE(review): weights_only=False runs the full pickle machinery —
    # only safe because the checkpoint comes from a trusted repo. The
    # checkpoint is a dict with at least "model_args" and "model" keys.
    checkpoint = torch.load(ckpt_path, map_location=_device, weights_only=False)

    # Rebuild the architecture from the saved constructor kwargs.
    model_args = checkpoint["model_args"]
    config = GPTConfig(**model_args)
    _model = GPT(config)

    # Checkpoints saved from a torch.compile()'d model prefix every key
    # with "_orig_mod."; strip it so load_state_dict matches.
    state_dict = checkpoint["model"]
    for k in list(state_dict.keys()):
        if k.startswith("_orig_mod."):
            state_dict[k[len("_orig_mod."):]] = state_dict.pop(k)

    _model.load_state_dict(state_dict)
    _model.eval().to(_device)

    # Cache embedding tensors for cheap retrieval via get_cached_tensors().
    _W_v2s = _model.transformer.wte.vocab_to_signals.weight.data
    _signal_basis = _model.transformer.wte.signal_basis.data

    _enc = tiktoken.get_encoding("gpt2")
    print("[model_loader] Model ready.")
    return _model, _enc, _device
|
|
|
|
def get_cached_tensors():
    """Return (W_v2s, signal_basis), loading the model first if needed."""
    if _W_v2s is not None:
        return _W_v2s, _signal_basis
    # Cold start: get_model() populates the module-level caches.
    get_model()
    return _W_v2s, _signal_basis
|
|