""" Model loading for ZSInvert. Loads the generator LLM (Qwen2.5-0.5B-Instruct) and selectable embedding encoders (GTE-base, GTR-T5-base, Contriever). Part of E04: ZSInvert. """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer from sentence_transformers import SentenceTransformer GENERATOR_MODEL = "Qwen/Qwen2.5-0.5B-Instruct" ENCODERS = { "gte": "thenlper/gte-base", "gtr": "sentence-transformers/gtr-t5-base", "contriever": "facebook/contriever", "mini": "sentence-transformers/all-MiniLM-L6-v2", } _device = "cuda" if torch.cuda.is_available() else "cpu" _llm: AutoModelForCausalLM | None = None _llm_tokenizer: AutoTokenizer | None = None _encoders: dict[str, SentenceTransformer] = {} def load_llm() -> tuple[AutoModelForCausalLM, AutoTokenizer]: """Load generator LLM. Singleton.""" global _llm, _llm_tokenizer if _llm is None: _llm_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL) _llm = AutoModelForCausalLM.from_pretrained( GENERATOR_MODEL, dtype=torch.bfloat16, ).eval().to(_device) return _llm, _llm_tokenizer def load_encoder(name: str = "gte") -> SentenceTransformer: """Load embedding encoder by name. Cached per name.""" if name not in ENCODERS: raise ValueError(f"Unknown encoder '{name}'. Choose from: {list(ENCODERS.keys())}") if name not in _encoders: model_id = ENCODERS[name] _encoders[name] = SentenceTransformer(model_id, device=_device) return _encoders[name] def encode_text(text: str, encoder: SentenceTransformer) -> torch.Tensor: """Encode text to normalized embedding vector. Returns shape (1, hidden_dim).""" emb = encoder.encode( text, convert_to_tensor=True, normalize_embeddings=True, ) return emb.unsqueeze(0) def get_chat_format(tokenizer: AutoTokenizer) -> tuple[list[int], list[int]]: """Extract chat prefix/suffix token IDs from the Qwen2.5 chat template. The prefix is everything the template adds before the user content. The suffix is everything after the user content through the generation prompt. For Qwen2.5 the structure is: <|im_start|>system\\n...system prompt...<|im_end|>\\n <|im_start|>user\\n{CONTENT}<|im_end|>\\n <|im_start|>assistant\\n We split so that: prefix + prompt_tokens + suffix = full template. """ # Template with empty content (no gen prompt) — find where content is inserted empty = list(tokenizer.apply_chat_template( [{"role": "user", "content": ""}], add_generation_prompt=False, )) # Template with a known marker to locate the split point marker = list(tokenizer.apply_chat_template( [{"role": "user", "content": "hello"}], add_generation_prompt=True, )) marker_tokens = list(tokenizer.encode("hello", add_special_tokens=False)) # Find where the marker content appears in the full template marker_len = len(marker_tokens) for i in range(len(marker)): if marker[i : i + marker_len] == marker_tokens: prefix = marker[:i] suffix = marker[i + marker_len :] return prefix, suffix # Fallback: use the empty template structure # Empty template has <|im_end|>\n right after user\n — drop those 2 tokens prefix = empty[:-2] full_gen = list(tokenizer.apply_chat_template( [{"role": "user", "content": ""}], add_generation_prompt=True, )) suffix = full_gen[len(prefix):] return prefix, suffix