File size: 4,645 Bytes
a7c784c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""HuggingFace `trust_remote_code` wrapper for DiffRetriever.

Lets a released checkpoint load with a single call:

    from transformers import AutoModel
    model = AutoModel.from_pretrained(
        "ielabgroup/diffretriever-dream-7b-single", trust_remote_code=True)
    model.eval()
    ids, mask = model.tokenize(["a query"], is_query=True)
    out = model.encode(ids, mask, is_query=True)   # {'repr_hidden', ...}

It wraps `TrainableDiffusionRetriever` (shipped in the same repo) and exposes
`.tokenize()` / `.encode()` / `.backbone`. The base diffusion backbone
(Dream / LLaDA / ...) is pulled from its own Hub repo at load time; this repo
carries only the LoRA adapter + tokenizer + retriever_config.json + this code.

Shipped inside each model repo, so keep the import surface minimal.
"""
from __future__ import annotations

import os
import shutil
import tempfile

import torch  # noqa: F401  (used by the wrapped retriever; keep import explicit)
from transformers import PreTrainedModel

from .configuration_diffretriever import DiffRetrieverConfig
from .diffretriever_trainable import TrainableDiffusionRetriever


class DiffRetrieverModel(PreTrainedModel):
    """`PreTrainedModel` shell around a `TrainableDiffusionRetriever`.

    Exists so that `AutoModel.from_pretrained(..., trust_remote_code=True)`
    returns an object exposing the retriever's `tokenize` / `encode` /
    `backbone` API; every retrieval call is forwarded to `self.retriever`.
    """

    config_class = DiffRetrieverConfig

    # Hub download options copied from `from_pretrained` kwargs into
    # `huggingface_hub.snapshot_download` (only when present and not None).
    _HUB_DOWNLOAD_KWARGS = (
        "cache_dir",
        "revision",
        "token",
        "local_files_only",
        "proxies",
        "force_download",
    )

    def __init__(self, config: DiffRetrieverConfig, retriever=None):
        """Store `config` via `PreTrainedModel` and keep `retriever` as an attribute.

        Args:
            config: the wrapper's `DiffRetrieverConfig`.
            retriever: an already-loaded `TrainableDiffusionRetriever` (or None).
        """
        super().__init__(config)
        # When `retriever` is an nn.Module, nn.Module.__setattr__ registers it
        # as a submodule, so .to()/.eval()/.parameters() recurse into it.
        self.retriever = retriever

    # ── Loading ──────────────────────────────────────────────────────────────
    @staticmethod
    def _expose_file(src: str, dst: str) -> None:
        """Make regular file `src` visible at `dst` (symlink, else copy); skip otherwise.

        Nothing happens when `src` is not a regular file (directories are
        ignored) or when something already occupies `dst` (first writer wins).
        """
        if not os.path.isfile(src) or os.path.lexists(dst):
            return
        try:
            os.symlink(os.path.abspath(src), dst)
        except (OSError, NotImplementedError):
            # Symlinks are not always available (e.g. Windows without
            # Developer Mode / admin rights, some network filesystems); the
            # staged files are small, so a plain copy is an acceptable fallback.
            shutil.copy2(src, dst)

    @classmethod
    def _build_load_view(cls, snapshot: str, staging: str) -> None:
        """Fill `staging` with the layout `TrainableDiffusionRetriever.load` expects.

        - config.json is DROPPED: load() detects a released adapter repo only
          when adapter_config.json is present AND config.json is absent, while
          this repo ships a config.json for the auto_map.
        - The adapter/ subdir is FLATTENED into the root: the LoRA adapter is
          stored under adapter/ (not the repo root) on purpose, so that
          transformers' PEFT auto-loader does NOT hijack AutoModel and load the
          base model directly instead of this wrapper; load() still needs
          adapter_config.json at the top level.

        Root-level files are placed first, so on a name clash the root-level
        file wins over the one under adapter/.
        """
        for name in os.listdir(snapshot):
            if name == "config.json":
                continue
            cls._expose_file(os.path.join(snapshot, name), os.path.join(staging, name))

        adapter_dir = os.path.join(snapshot, "adapter")
        if os.path.isdir(adapter_dir):
            for name in os.listdir(adapter_dir):
                cls._expose_file(os.path.join(adapter_dir, name), os.path.join(staging, name))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a released DiffRetriever checkpoint from a local dir or Hub repo id.

        We do NOT call `super().from_pretrained`: the checkpoint is a LoRA
        adapter that must be attached on top of a base backbone pulled from a
        different Hub repo, which `TrainableDiffusionRetriever.load()` handles.

        Args:
            pretrained_model_name_or_path: local directory, or a Hub repo id
                that is fetched with `huggingface_hub.snapshot_download`.
            *model_args: accepted for signature compatibility; ignored.
            **kwargs: `config` (a `DiffRetrieverConfig`; built from the loaded
                retriever when omitted or None); the hub options named in
                `_HUB_DOWNLOAD_KWARGS` are forwarded to `snapshot_download`;
                `trust_remote_code` and any other keys are ignored.

        Returns:
            A `DiffRetrieverModel` wrapping the loaded retriever.
        """
        kwargs.pop("trust_remote_code", None)
        config = kwargs.pop("config", None)

        path = str(pretrained_model_name_or_path)
        if not os.path.isdir(path):
            from huggingface_hub import snapshot_download

            download_kwargs = {
                key: kwargs[key]
                for key in cls._HUB_DOWNLOAD_KWARGS
                if kwargs.get(key) is not None
            }
            path = snapshot_download(path, **download_kwargs)

        staging = tempfile.mkdtemp(prefix="diffretriever_")
        try:
            cls._build_load_view(path, staging)
            # NOTE(review): the staging dir is removed right after load();
            # this assumes load() reads every file eagerly — confirm.
            retriever = TrainableDiffusionRetriever.load(staging)
        finally:
            shutil.rmtree(staging, ignore_errors=True)

        if config is None:
            config = DiffRetrieverConfig(
                base_model=getattr(retriever, "model_name", None),
                backbone_type=getattr(retriever, "model_type", None),
            )
        return cls(config, retriever=retriever)

    # ── Retrieval API (delegates to the wrapped retriever) ─────────────────────
    @property
    def backbone(self):
        """The wrapped retriever's `backbone` attribute."""
        return self.retriever.backbone

    def tokenize(self, *args, **kwargs):
        """Forward to `self.retriever.tokenize` with the same arguments."""
        return self.retriever.tokenize(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Forward to `self.retriever.encode` with the same arguments."""
        return self.retriever.encode(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call the wrapped retriever directly with the same arguments."""
        return self.retriever(*args, **kwargs)