"""HuggingFace `trust_remote_code` wrapper for DiffRetriever. Lets a released checkpoint load with a single call: from transformers import AutoModel model = AutoModel.from_pretrained( "ielabgroup/diffretriever-dream-7b-single", trust_remote_code=True) model.eval() ids, mask = model.tokenize(["a query"], is_query=True) out = model.encode(ids, mask, is_query=True) # {'repr_hidden', ...} It wraps `TrainableDiffusionRetriever` (shipped in the same repo) and exposes `.tokenize()` / `.encode()` / `.backbone`. The base diffusion backbone (Dream / LLaDA / ...) is pulled from its own Hub repo at load time; this repo carries only the LoRA adapter + tokenizer + retriever_config.json + this code. Shipped inside each model repo, so keep the import surface minimal. """ from __future__ import annotations import os import shutil import tempfile import torch # noqa: F401 (used by the wrapped retriever; keep import explicit) from transformers import PreTrainedModel from .configuration_diffretriever import DiffRetrieverConfig from .diffretriever_trainable import TrainableDiffusionRetriever class DiffRetrieverModel(PreTrainedModel): config_class = DiffRetrieverConfig def __init__(self, config: DiffRetrieverConfig, retriever=None): super().__init__(config) # Registered as a submodule so .to()/.eval()/.parameters() recurse into it. self.retriever = retriever # ── Loading ────────────────────────────────────────────────────────────── @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # We do NOT call super().from_pretrained: the checkpoint is a LoRA # adapter that must be attached on top of a base backbone pulled from a # different Hub repo, which TrainableDiffusionRetriever.load() handles. kwargs.pop("trust_remote_code", None) config = kwargs.pop("config", None) path = str(pretrained_model_name_or_path) if not os.path.isdir(path): from huggingface_hub import snapshot_download dl = {k: kwargs[k] for k in ("cache_dir", "revision", "token", "local_files_only", "proxies") if k in kwargs and kwargs[k] is not None} path = snapshot_download(path, **dl) # Build a temp view of the snapshot for TrainableDiffusionRetriever.load: # - DROP config.json: load() detects a released adapter repo only when # adapter_config.json is present AND config.json is absent; this repo # ships a config.json for the auto_map. # - FLATTEN the adapter/ subdir into the root: the LoRA adapter is # stored under adapter/ (not the repo root) on purpose, so that # transformers' PEFT auto-loader does NOT hijack AutoModel and load # the base model directly instead of this wrapper. load() still needs # adapter_config.json at the top level, so we link it up here. tmp = tempfile.mkdtemp(prefix="diffretriever_") def _link(src, dst): if os.path.isfile(src) and not os.path.exists(dst): os.symlink(os.path.abspath(src), dst) try: for fn in os.listdir(path): if fn == "config.json": continue _link(os.path.join(path, fn), os.path.join(tmp, fn)) adapter_dir = os.path.join(path, "adapter") if os.path.isdir(adapter_dir): for fn in os.listdir(adapter_dir): _link(os.path.join(adapter_dir, fn), os.path.join(tmp, fn)) retriever = TrainableDiffusionRetriever.load(tmp) finally: shutil.rmtree(tmp, ignore_errors=True) if config is None: config = DiffRetrieverConfig( base_model=getattr(retriever, "model_name", None), backbone_type=getattr(retriever, "model_type", None), ) return cls(config, retriever=retriever) # ── Retrieval API (delegates to the wrapped retriever) ───────────────────── @property def backbone(self): return self.retriever.backbone def tokenize(self, *args, **kwargs): return self.retriever.tokenize(*args, **kwargs) def encode(self, *args, **kwargs): return self.retriever.encode(*args, **kwargs) def forward(self, *args, **kwargs): return self.retriever(*args, **kwargs)