import os
from typing import List

import torch
import torch.nn as nn
from einops import rearrange
from transformers import AutoTokenizer, Qwen3ForCausalLM

from mmgp import offload
from shared.utils import files_locator as fl

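# Intermediate decoder layers whose hidden states are concatenated into the
# final text embedding, and the fixed token length every prompt is padded or
# truncated to.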
OUTPUT_LAYERS = [9, 18, 27]
MAX_LENGTH = 512


class Qwen3Embedder(nn.Module):
    """Text embedder that returns concatenated hidden states from a Qwen3 LM."""

    def __init__(
        self,
        model_spec: str | None = None,
        tokenizer_path: str | None = None,
        torch_dtype: str = "bfloat16",  # NOTE: accepted for API compatibility, currently unused
    ):
        super().__init__()
        if model_spec is None:
            raise ValueError("model_spec must point to a Qwen3 checkpoint file")
        file_path = model_spec
        default_config = os.path.join(os.path.dirname(file_path), "config.json")
        self.model = offload.fast_load_transformers_model(
            file_path,
            writable_tensors=False,
            modelClass=Qwen3ForCausalLM,
            defaultConfigPath=default_config,
        )

        # Resolve the tokenizer directory: an explicit tokenizer_path wins,
        # otherwise look next to the checkpoint. Relative paths go through the
        # files locator, and a "tokenizer" subfolder is preferred when present.
        tokenizer_root = tokenizer_path or os.path.dirname(file_path)
        if tokenizer_root and not os.path.isabs(tokenizer_root):
            tokenizer_root = fl.locate_folder(tokenizer_root)
        tokenizer_subdir = os.path.join(tokenizer_root, "tokenizer")
        if os.path.isdir(tokenizer_subdir):
            tokenizer_root = tokenizer_subdir
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_root, trust_remote_code=True)
        self.max_length = MAX_LENGTH

    @torch.no_grad()
    def forward(self, txt: List[str]):
        all_input_ids = []
        all_attention_masks = []

        for prompt in txt:
            # Render each prompt through the Qwen3 chat template (thinking
            # disabled) so the embedder sees the same token stream the chat
            # model would.
            messages = [{"role": "user", "content": prompt}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )

            # Pad every prompt to the same fixed length so the per-prompt
            # tensors can be concatenated into a single batch below.
            model_inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )
            all_input_ids.append(model_inputs["input_ids"])
            all_attention_masks.append(model_inputs["attention_mask"])

        input_ids = torch.cat(all_input_ids, dim=0).to(self.model.device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(self.model.device)

        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # Stack the selected intermediate layers into (b, c, l, d), with c the
        # number of layers, then fold the layer axis into the feature axis,
        # giving (b, l, c * d).
        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
        return rearrange(out, "b c l d -> b l (c d)")
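
# ------------------------------------------------------------------------------
# Usage sketch (illustrative only): how the embedder might be driven. The
# checkpoint path and prompts below are placeholder assumptions, not values
# taken from this repo.
#
#   embedder = Qwen3Embedder(model_spec="ckpts/qwen3/model.safetensors")
#   emb = embedder(["a cat on a windowsill", "a rainy street at night"])
#   # emb: (2, MAX_LENGTH, 3 * hidden_size) -- one row per prompt, one
#   # concatenated 3-layer feature vector per token position.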