# Commit 1fc3b76 (ryandt): Removed streaming
"""
Model loading for ZSInvert.
Loads the generator LLM (Qwen2.5-0.5B-Instruct) and selectable
embedding encoders (GTE-base, GTR-T5-base, Contriever).
Part of E04: ZSInvert.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
# Hugging Face model ID for the generator LLM used by the inverter.
GENERATOR_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
# Short name -> Hugging Face model ID for the selectable embedding encoders.
# Keys are the names accepted by load_encoder().
ENCODERS = {
    "gte": "thenlper/gte-base",
    "gtr": "sentence-transformers/gtr-t5-base",
    "contriever": "facebook/contriever",
    "mini": "sentence-transformers/all-MiniLM-L6-v2",
}
# Single shared device for all models: GPU when available, else CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
# Lazily-populated caches: LLM + tokenizer are singletons (load_llm),
# encoders are cached one instance per short name (load_encoder).
_llm: AutoModelForCausalLM | None = None
_llm_tokenizer: AutoTokenizer | None = None
_encoders: dict[str, SentenceTransformer] = {}
def load_llm() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Return the (model, tokenizer) pair for the generator LLM.

    Loaded lazily on first call and cached in module globals, so repeated
    calls share one copy of the weights. The model is put in eval mode,
    cast to bfloat16, and moved to the module-wide device.
    """
    global _llm, _llm_tokenizer
    if _llm is not None:
        return _llm, _llm_tokenizer
    _llm_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        GENERATOR_MODEL,
        dtype=torch.bfloat16,
    )
    _llm = model.eval().to(_device)
    return _llm, _llm_tokenizer
def load_encoder(name: str = "gte") -> SentenceTransformer:
    """Return the embedding encoder registered under ``name``.

    Instances are created on first use and cached per name, so a given
    encoder is only loaded once per process.

    Raises:
        ValueError: if ``name`` is not one of the keys in ENCODERS.
    """
    if name not in ENCODERS:
        raise ValueError(f"Unknown encoder '{name}'. Choose from: {list(ENCODERS.keys())}")
    encoder = _encoders.get(name)
    if encoder is None:
        # First request for this name: instantiate and cache.
        encoder = SentenceTransformer(ENCODERS[name], device=_device)
        _encoders[name] = encoder
    return encoder
def encode_text(text: str, encoder: SentenceTransformer) -> torch.Tensor:
    """Embed ``text`` with ``encoder`` and return a unit-norm vector.

    Returns:
        Tensor of shape (1, hidden_dim); embeddings are L2-normalized
        by the encoder itself (normalize_embeddings=True).
    """
    vector = encoder.encode(
        text,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    # encode() on a single string yields a 1-D tensor; add a batch axis.
    return vector.unsqueeze(0)
def get_chat_format(tokenizer: AutoTokenizer) -> tuple[list[int], list[int]]:
    """Extract chat prefix/suffix token IDs from the Qwen2.5 chat template.

    The prefix is everything the template adds before the user content.
    The suffix is everything after the user content through the generation prompt.

    For Qwen2.5 the structure is:
        <|im_start|>system\\n...system prompt...<|im_end|>\\n
        <|im_start|>user\\n{CONTENT}<|im_end|>\\n
        <|im_start|>assistant\\n

    We split so that: prefix + prompt_tokens + suffix = full template.

    Returns:
        (prefix, suffix) — two lists of token IDs.
    """
    # Template with empty content (no gen prompt) — only used by the fallback
    # path below to locate where the user content would be inserted.
    empty = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": ""}],
        add_generation_prompt=False,
    ))
    # Full template (with generation prompt) around a known marker string;
    # the marker's position tells us where user content sits in the template.
    marker = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": "hello"}],
        add_generation_prompt=True,
    ))
    # "hello" encoded standalone. NOTE(review): this assumes the marker
    # tokenizes identically inside the template as it does alone — holds for
    # Qwen2.5's tokenizer here; confirm before reusing with other tokenizers.
    marker_tokens = list(tokenizer.encode("hello", add_special_tokens=False))
    # Scan for the marker inside the full template: everything before it is
    # the prefix, everything after (including the generation prompt) the suffix.
    marker_len = len(marker_tokens)
    for i in range(len(marker)):
        if marker[i : i + marker_len] == marker_tokens:
            prefix = marker[:i]
            suffix = marker[i + marker_len :]
            return prefix, suffix
    # Fallback: derive the split from the empty-content template instead.
    # Empty template has <|im_end|>\n right after user\n — drop those 2 tokens.
    # NOTE(review): the "-2" hard-codes that exactly two tokens close the user
    # turn; Qwen2.5-specific — verify if the chat template ever changes.
    prefix = empty[:-2]
    full_gen = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": ""}],
        add_generation_prompt=True,
    ))
    # The suffix is whatever the generation-prompt template adds past the prefix.
    suffix = full_gen[len(prefix):]
    return prefix, suffix