File size: 3,581 Bytes
5689bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fc3b76
5689bad
 
1fc3b76
5689bad
1fc3b76
5689bad
 
1fc3b76
f06d2ef
5689bad
 
 
 
 
 
 
 
 
 
 
 
1fc3b76
5689bad
 
1fc3b76
5689bad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Model loading for ZSInvert.

Loads the generator LLM (Qwen2.5-0.5B-Instruct) and selectable
embedding encoders (GTE-base, GTR-T5-base, Contriever, all-MiniLM-L6-v2).

Part of E04: ZSInvert.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

# Hugging Face model ID of the generator LLM used for inversion.
GENERATOR_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

# Short name -> Hugging Face model ID for each selectable embedding encoder.
ENCODERS = {
    "gte": "thenlper/gte-base",
    "gtr": "sentence-transformers/gtr-t5-base",
    "contriever": "facebook/contriever",
    "mini": "sentence-transformers/all-MiniLM-L6-v2",
}

# Single device shared by the generator and all encoders.
_device = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialized singletons; populated by load_llm() / load_encoder().
_llm: AutoModelForCausalLM | None = None
_llm_tokenizer: AutoTokenizer | None = None
_encoders: dict[str, SentenceTransformer] = {}


def load_llm() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Return the generator LLM and its tokenizer, loading them on first call.

    Later calls reuse the module-level singletons, so the model is
    initialized at most once per process.
    """
    global _llm, _llm_tokenizer
    if _llm is not None:
        return _llm, _llm_tokenizer
    _llm_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        GENERATOR_MODEL,
        dtype=torch.bfloat16,
    )
    # Inference-only: switch off dropout etc., then move to the shared device.
    _llm = model.eval().to(_device)
    return _llm, _llm_tokenizer


def load_encoder(name: str = "gte") -> SentenceTransformer:
    """Return the embedding encoder registered under ``name``.

    Each encoder is constructed at most once and cached by name.

    Raises:
        ValueError: if ``name`` is not a key of ENCODERS.
    """
    if name not in ENCODERS:
        raise ValueError(f"Unknown encoder '{name}'. Choose from: {list(ENCODERS.keys())}")
    try:
        return _encoders[name]
    except KeyError:
        encoder = SentenceTransformer(ENCODERS[name], device=_device)
        _encoders[name] = encoder
        return encoder


def encode_text(text: str, encoder: SentenceTransformer) -> torch.Tensor:
    """Embed ``text`` with ``encoder``; returns a unit-norm tensor of shape (1, hidden_dim)."""
    vector = encoder.encode(
        text,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    # encode() on a single string yields a 1-D tensor; add the batch axis.
    return vector.unsqueeze(0)


def get_chat_format(tokenizer: AutoTokenizer) -> tuple[list[int], list[int]]:
    """Extract chat prefix/suffix token IDs from the Qwen2.5 chat template.

    The prefix is everything the template adds before the user content.
    The suffix is everything after the user content through the generation prompt.

    For Qwen2.5 the structure is:
        <|im_start|>system\\n...system prompt...<|im_end|>\\n
        <|im_start|>user\\n{CONTENT}<|im_end|>\\n
        <|im_start|>assistant\\n

    We split so that: prefix + prompt_tokens + suffix = full template.
    """
    # Render the template around a sentinel user message, then locate the
    # sentinel's tokens to find where user content gets inserted.
    rendered = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": "hello"}],
        add_generation_prompt=True,
    ))
    sentinel = list(tokenizer.encode("hello", add_special_tokens=False))
    width = len(sentinel)

    pos = 0
    while pos < len(rendered):
        if rendered[pos : pos + width] == sentinel:
            return rendered[:pos], rendered[pos + width :]
        pos += 1

    # Fallback: the sentinel tokens never appeared contiguously, so derive
    # the split from a template rendered with empty content instead.
    bare = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": ""}],
        add_generation_prompt=False,
    ))
    # With empty content, <|im_end|>\n immediately follows user\n — dropping
    # those 2 trailing tokens leaves just the prefix.
    prefix = bare[:-2]
    with_gen = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": ""}],
        add_generation_prompt=True,
    ))
    return prefix, with_gen[len(prefix):]