diffretriever-llada-8b-single / modeling_diffretriever.py
wshuai190's picture
Add self-contained DiffRetriever (trust_remote_code: code + config + adapter/)
a7c784c verified
Raw
History Blame Contribute Delete
4.65 kB
"""HuggingFace `trust_remote_code` wrapper for DiffRetriever.

Lets a released checkpoint load with a single call:

    from transformers import AutoModel
    model = AutoModel.from_pretrained(
        "ielabgroup/diffretriever-dream-7b-single", trust_remote_code=True)
    model.eval()
    ids, mask = model.tokenize(["a query"], is_query=True)
    out = model.encode(ids, mask, is_query=True)  # {'repr_hidden', ...}

It wraps `TrainableDiffusionRetriever` (shipped in the same repo) and exposes
`.tokenize()` / `.encode()` / `.backbone`. The base diffusion backbone
(Dream / LLaDA / ...) is pulled from its own Hub repo at load time; this repo
carries only the LoRA adapter (under adapter/), the tokenizer,
config.json + retriever_config.json, and this code.

Shipped inside each model repo, so keep the import surface minimal.
"""
from __future__ import annotations
import os
import shutil
import tempfile
import torch # noqa: F401 (used by the wrapped retriever; keep import explicit)
from transformers import PreTrainedModel
from .configuration_diffretriever import DiffRetrieverConfig
from .diffretriever_trainable import TrainableDiffusionRetriever
class DiffRetrieverModel(PreTrainedModel):
    """`PreTrainedModel` shell around a `TrainableDiffusionRetriever`.

    Tokenization, encoding and the backbone all live in the wrapped
    retriever; this class exists so a released checkpoint can be loaded via
    ``AutoModel.from_pretrained(..., trust_remote_code=True)`` and so the
    retrieval API (``tokenize`` / ``encode`` / ``backbone`` / ``forward``) is
    reachable from the returned model.
    """

    config_class = DiffRetrieverConfig

    # `from_pretrained` keyword arguments forwarded to
    # `huggingface_hub.snapshot_download` when the path is a Hub repo id
    # (only those present and not None are forwarded).
    _HUB_DOWNLOAD_KWARGS = (
        "cache_dir",
        "revision",
        "token",
        "local_files_only",
        "proxies",
        "force_download",
    )

    def __init__(self, config: DiffRetrieverConfig, retriever=None):
        super().__init__(config)
        # Assigned through nn.Module.__setattr__: if the retriever is itself an
        # nn.Module it is registered as a submodule, so .to() / .eval() /
        # .parameters() recurse into it. It stays None when the model is
        # constructed directly from a config instead of via from_pretrained.
        self.retriever = retriever

    # ── Loading ──────────────────────────────────────────────────────────────
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a released DiffRetriever checkpoint.

        Args:
            pretrained_model_name_or_path: A local snapshot directory or a
                Hub repo id (anything that is not an existing directory is
                treated as a repo id and downloaded).
            *model_args: Accepted for signature compatibility; ignored.
            **kwargs: ``config`` -- optional config to attach (otherwise one is
                built from the loaded retriever); ``trust_remote_code`` --
                ignored; the Hub options named in ``_HUB_DOWNLOAD_KWARGS`` --
                forwarded to ``snapshot_download``. Other keyword arguments
                are currently ignored.

        Returns:
            A ``DiffRetrieverModel`` wrapping the loaded retriever.
        """
        # We do NOT call super().from_pretrained: the checkpoint is a LoRA
        # adapter that must be attached on top of a base backbone pulled from a
        # different Hub repo, which TrainableDiffusionRetriever.load() handles.
        kwargs.pop("trust_remote_code", None)
        config = kwargs.pop("config", None)
        path = cls._resolve_snapshot(str(pretrained_model_name_or_path), kwargs)
        retriever = cls._load_retriever(path)
        if config is None:
            config = DiffRetrieverConfig(
                base_model=getattr(retriever, "model_name", None),
                backbone_type=getattr(retriever, "model_type", None),
            )
        return cls(config, retriever=retriever)

    @classmethod
    def _resolve_snapshot(cls, path, hub_kwargs):
        """Return a local directory for *path*, downloading it from the Hub if needed."""
        if os.path.isdir(path):
            return path
        from huggingface_hub import snapshot_download

        download_kwargs = {
            key: hub_kwargs[key]
            for key in cls._HUB_DOWNLOAD_KWARGS
            if hub_kwargs.get(key) is not None
        }
        return snapshot_download(path, **download_kwargs)

    @classmethod
    def _load_retriever(cls, path):
        """Load the retriever from a temporary, reshaped view of snapshot *path*.

        The view differs from the snapshot in two ways:
          - config.json is DROPPED: load() detects a released adapter repo only
            when adapter_config.json is present AND config.json is absent; this
            repo ships a config.json for the auto_map.
          - the adapter/ subdir is FLATTENED into the root: the LoRA adapter is
            stored under adapter/ (not the repo root) on purpose, so that
            transformers' PEFT auto-loader does NOT hijack AutoModel and load
            the base model directly instead of this wrapper. load() still needs
            adapter_config.json at the top level, so it is linked up here.
        Root files are linked first, so on a name clash the root file wins.
        """
        tmp = tempfile.mkdtemp(prefix="diffretriever_")
        try:
            for fn in os.listdir(path):
                if fn != "config.json":
                    cls._link_file(os.path.join(path, fn), os.path.join(tmp, fn))
            adapter_dir = os.path.join(path, "adapter")
            if os.path.isdir(adapter_dir):
                for fn in os.listdir(adapter_dir):
                    cls._link_file(os.path.join(adapter_dir, fn), os.path.join(tmp, fn))
            return TrainableDiffusionRetriever.load(tmp)
        finally:
            # NOTE(review): the view is deleted as soon as load() returns, so
            # this assumes load() has finished reading every file by then.
            shutil.rmtree(tmp, ignore_errors=True)

    @staticmethod
    def _link_file(src, dst):
        """Expose regular file *src* at *dst*; skip non-files and names already taken."""
        if not os.path.isfile(src) or os.path.lexists(dst):
            return
        try:
            os.symlink(os.path.abspath(src), dst)
        except (OSError, NotImplementedError):
            # Symlink creation can be unavailable (e.g. Windows without
            # Developer Mode / admin rights); fall back to a real copy.
            shutil.copy2(src, dst)

    # ── Retrieval API (delegates to the wrapped retriever) ─────────────────────
    def _require_retriever(self):
        """Return the wrapped retriever, raising a clear error if none is attached."""
        retriever = self.retriever
        if retriever is None:
            # RuntimeError rather than AttributeError: inside the `backbone`
            # property an AttributeError would be swallowed by
            # nn.Module.__getattr__ and reported as a missing `backbone`.
            raise RuntimeError(
                "DiffRetrieverModel has no retriever attached; load it with "
                "DiffRetrieverModel.from_pretrained(...) or "
                "AutoModel.from_pretrained(..., trust_remote_code=True)."
            )
        return retriever

    @property
    def backbone(self):
        """The wrapped retriever's ``backbone`` attribute."""
        return self._require_retriever().backbone

    def tokenize(self, *args, **kwargs):
        """Delegate to the wrapped retriever's ``tokenize``."""
        return self._require_retriever().tokenize(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped retriever's ``encode``."""
        return self._require_retriever().encode(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call the wrapped retriever with the given arguments."""
        return self._require_retriever()(*args, **kwargs)