| """ | |
| Model loading for ZSInvert. | |
| Loads the generator LLM (Qwen2.5-0.5B-Instruct) and selectable | |
| embedding encoders (GTE-base, GTR-T5-base, Contriever). | |
| Part of E04: ZSInvert. | |
| """ | |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

GENERATOR_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

ENCODERS = {
    "gte": "thenlper/gte-base",
    "gtr": "sentence-transformers/gtr-t5-base",
    "contriever": "facebook/contriever",
    "mini": "sentence-transformers/all-MiniLM-L6-v2",
}

_device = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level singletons so repeated calls reuse already-loaded models.
_llm: AutoModelForCausalLM | None = None
_llm_tokenizer: AutoTokenizer | None = None
_encoders: dict[str, SentenceTransformer] = {}

def load_llm() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the generator LLM and its tokenizer. Singleton."""
    global _llm, _llm_tokenizer
    if _llm is None:
        _llm_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
        _llm = AutoModelForCausalLM.from_pretrained(
            GENERATOR_MODEL,
            torch_dtype=torch.bfloat16,
        ).eval().to(_device)
    return _llm, _llm_tokenizer

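# Illustrative usage sketch, not part of the original module: one chat turn
# through the singleton generator. The prompt handling uses the standard
# transformers API; the max_new_tokens default is an arbitrary assumption.
def _demo_generate(prompt: str, max_new_tokens: int = 64) -> str:
    """Hypothetical example: generate a reply to a single user prompt."""
    model, tokenizer = load_llm()
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(_device)
    with torch.no_grad():
        out = model.generate(input_ids, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
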
def load_encoder(name: str = "gte") -> SentenceTransformer:
    """Load an embedding encoder by name. Cached per name."""
    if name not in ENCODERS:
        raise ValueError(f"Unknown encoder '{name}'. Choose from: {list(ENCODERS)}")
    if name not in _encoders:
        _encoders[name] = SentenceTransformer(ENCODERS[name], device=_device)
    return _encoders[name]

def encode_text(text: str, encoder: SentenceTransformer) -> torch.Tensor:
    """Encode text to a normalized embedding vector. Returns shape (1, hidden_dim)."""
    # For a single string, encode() returns a 1-D tensor of shape (hidden_dim,).
    emb = encoder.encode(
        text,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return emb.unsqueeze(0)

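# Illustrative usage sketch, not part of the original module: because
# encode_text returns unit-norm vectors, a plain dot product is the cosine
# similarity. The "gte" choice here is just the module's default encoder.
def _demo_similarity(a: str, b: str) -> float:
    """Hypothetical example: cosine similarity between two texts."""
    encoder = load_encoder("gte")
    ea = encode_text(a, encoder)  # (1, hidden_dim), L2-normalized
    eb = encode_text(b, encoder)
    return (ea @ eb.T).item()
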
def get_chat_format(tokenizer: AutoTokenizer) -> tuple[list[int], list[int]]:
    """Extract chat prefix/suffix token IDs from the Qwen2.5 chat template.

    The prefix is everything the template adds before the user content.
    The suffix is everything after the user content through the generation
    prompt. For Qwen2.5 the structure is:

        <|im_start|>system\\n...system prompt...<|im_end|>\\n
        <|im_start|>user\\n{CONTENT}<|im_end|>\\n
        <|im_start|>assistant\\n

    We split so that: prefix + prompt_tokens + suffix = full template.
    """
    # Render the template around a known marker ("hello") and locate the
    # marker's tokens to find the split point. This assumes the marker
    # tokenizes the same standalone as it does inside the template.
    marker = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": "hello"}],
        add_generation_prompt=True,
    ))
    marker_tokens = list(tokenizer.encode("hello", add_special_tokens=False))
    marker_len = len(marker_tokens)
    for i in range(len(marker) - marker_len + 1):
        if marker[i : i + marker_len] == marker_tokens:
            prefix = marker[:i]
            suffix = marker[i + marker_len :]
            return prefix, suffix
    # Fallback: derive the split from the empty-content template. Its last
    # two tokens are <|im_end|> and "\n", which sit right after the user
    # content slot, so dropping them leaves the prefix.
    empty = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": ""}],
        add_generation_prompt=False,
    ))
    prefix = empty[:-2]
    full_gen = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": ""}],
        add_generation_prompt=True,
    ))
    suffix = full_gen[len(prefix):]
    return prefix, suffix

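# Illustrative sanity check, not part of the original module: the split
# should round-trip, i.e. prefix + content tokens + suffix reproduces the
# fully rendered template. The sample prompt is an arbitrary assumption.
def _demo_check_chat_format(prompt: str = "Recover the original text.") -> bool:
    """Hypothetical example: verify the prefix/suffix invariant."""
    _, tokenizer = load_llm()
    prefix, suffix = get_chat_format(tokenizer)
    body = list(tokenizer.encode(prompt, add_special_tokens=False))
    full = list(tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
    ))
    return prefix + body + suffix == full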