Feature Extraction
sentence-transformers
Safetensors
Transformers
mteb
modernbert
custom_code
Eval Results (legacy)
Instructions to use jxm/cde-small-v2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use jxm/cde-small-v2 with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("jxm/cde-small-v2", trust_remote_code=True)
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium."
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [3, 3]
```
- Transformers
How to use jxm/cde-small-v2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="jxm/cde-small-v2", trust_remote_code=True)
```
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("jxm/cde-small-v2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import Callable, Optional, Tuple | |
| import copy | |
| import json | |
| import math | |
| import multiprocessing | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import transformers | |
class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """Dummy configuration class that stores whatever kwargs it is given.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json. Values that cannot be serialized to
    JSON are skipped.
    """

    def __init__(self, **kwargs):
        # Initialize the base config FIRST. ``PretrainedConfig.__init__``
        # assigns defaults to many standard fields (``architectures``,
        # ``torch_dtype``, ``return_dict``, ...); running it after the loop
        # below would silently overwrite any of those values supplied in kwargs.
        super().__init__()
        for key, value in kwargs.items():
            try:
                json.dumps(value)
            except (TypeError, ValueError):
                # TypeError: value is not JSON-serializable.
                # ValueError: e.g. a container with a circular reference.
                continue
            setattr(self, key, value)
def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    """Load a backbone embedding model and its tokenizer.

    Args:
        name: A Hugging Face model id, or one of the aliases handled below
            (``"gtr-base"`` / ``"gtr_base"``, ``"pile-t5-base-encoder"``,
            ``"pile-t5-base-decoder"``).

    Returns:
        ``(model, tokenizer)``. For some names only a sub-module of the loaded
        model is returned (``.bert``, ``.encoder`` or ``.decoder``).

    Raises:
        AssertionError: if ``name`` is None.
    """
    assert name is not None, "name must be provided to load_embedder_and_tokenizer"
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        model = transformers.AutoModelForMaskedLM.from_pretrained(name, trust_remote_code=True).bert
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ("gtr-base", "gtr_base"):
        hub_name = "sentence-transformers/gtr-t5-base"
        model = transformers.AutoModel.from_pretrained(hub_name).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(hub_name)
    elif name in ("pile-t5-base-encoder", "pile-t5-base-decoder"):
        hub_name = "EleutherAI/pile-t5-base"
        full_model = transformers.AutoModel.from_pretrained(hub_name)
        # Keep only the requested half of the encoder-decoder model.
        model = full_model.encoder if name.endswith("encoder") else full_model.decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(hub_name)
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith(("gpt2", "meta-llama")) or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
        )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        # `padding_side` is a tokenizer setting; assigning it on the model
        # (as before) had no effect on how inputs were padded.
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    else:
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    return model, tokenizer
def get_world_size() -> int:
    """Return the number of distributed workers, or 1 when not running distributed.

    Also returns 1 when this PyTorch build lacks distributed support or no
    default process group has been initialized.
    """
    if not torch.distributed.is_available():
        # Builds without distributed support do not define get_world_size.
        return 1
    try:
        return torch.distributed.get_world_size()
    except (RuntimeError, ValueError):
        # Raised when no default process group has been initialized.
        return 1
def get_rank() -> int:
    """Return this process's distributed rank, or 0 when not running distributed.

    Also returns 0 when this PyTorch build lacks distributed support or no
    default process group has been initialized.
    """
    if not torch.distributed.is_available():
        # Builds without distributed support do not define get_rank.
        return 0
    try:
        return torch.distributed.get_rank()
    except (RuntimeError, ValueError):
        # Raised when no default process group has been initialized.
        return 0
def gather(t: torch.Tensor) -> torch.Tensor:
    """Concatenate ``t`` from every distributed worker along dim 0.

    Returns ``t`` unchanged for a single worker. A 0-d tensor is treated as a
    1-element vector before gathering.
    """
    # torch.distributed.nn.all_gather scales by world size since the reduce op is SUM
    # https://github.com/pytorch/pytorch/issues/58005
    # only should use torch.distributed.nn.all_gather if we implement a `local_loss`
    # like: https://github.com/mlfoundations/open_clip/issues/616
    world_size = get_world_size()
    if world_size == 1:
        return t
    local = t if t.ndim else t.unsqueeze(0)
    pieces = [torch.empty_like(local) for _ in range(world_size)]
    torch.distributed.all_gather(pieces, local)
    # Put the original local tensor back into its slot, presumably so autograd
    # can flow through this worker's slice (the all_gather outputs are not
    # connected to the graph).
    pieces[get_rank()] = local
    return torch.cat(pieces, dim=0)
def gather_sum(t: torch.Tensor) -> torch.Tensor:
    """Element-wise sum of ``t`` across all distributed workers.

    Returns ``t`` unchanged for a single worker. With several workers, a 0-d
    tensor is first treated as a 1-element vector.
    """
    # torch.distributed.nn.all_gather scales by world size since the reduce op is SUM
    # https://github.com/pytorch/pytorch/issues/58005
    # only should use torch.distributed.nn.all_gather if we implement a `local_loss`
    # like: https://github.com/mlfoundations/open_clip/issues/616
    world_size = get_world_size()
    if world_size == 1:
        return t
    local = t if t.ndim else t.unsqueeze(0)
    pieces = [torch.empty_like(local) for _ in range(world_size)]
    torch.distributed.all_gather(pieces, local)
    stacked = torch.stack(pieces, dim=0)  # (world_size, *local.shape)
    return stacked.sum(dim=0)
def get_num_proc() -> int:
    """Return the number of CPUs available to each distributed worker (at least 1)."""
    world_size: int = get_world_size()
    try:
        # os.sched_getaffinity respects schedulers, unlike cpu_count(), but it's only available
        # on some Unix platforms, so we support both!
        n_cpus = len(os.sched_getaffinity(0))  # type: ignore[attr-defined]
    except AttributeError:
        n_cpus = multiprocessing.cpu_count()
    # Integer division yields 0 when there are more workers than CPUs, and a
    # process count of 0 is rejected by process pools; always allow one.
    return max(1, n_cpus // world_size)
def torch_main_worker_finish_first(func: Callable):
    """Decorator: run ``func`` on the main worker first, then on all other workers.

    Under torch.distributed the rank-0 process calls ``func`` while the others
    wait at a barrier; then the remaining ranks call it, and everyone meets at
    a second barrier. Without a process group, ``func`` simply runs once.
    Returns whatever ``func`` returned in the calling process.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Support non-DDP runs. Note: this is the *global* rank from
        # torch.distributed.get_rank(), not a per-node local rank.
        ddp_enabled = torch.distributed.is_available() and torch.distributed.is_initialized()
        rank = torch.distributed.get_rank() if ddp_enabled else -1
        is_main_worker = rank <= 0
        # Run on main worker first.
        if is_main_worker:
            result = func(*args, **kwargs)
        # Then everyone waits.
        if ddp_enabled:
            torch.distributed.barrier()
        # Run on other workers now.
        if not is_main_worker:
            result = func(*args, **kwargs)
        # Now everyone waits again.
        if ddp_enabled:
            torch.distributed.barrier()
        return result

    return wrapper
def print0(*args, **kwargs) -> None:
    """Forward to ``print`` only from the rank-0 process (always, when not distributed)."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    """Assert that every parameter and its gradient match across all workers.

    A wrapper exposing ``.module`` is unwrapped first. The check is skipped
    entirely for more than 8 workers, and skipped per-parameter when the
    parameter has no gradient. Raises AssertionError on the first mismatch
    larger than ``atol``.
    """
    model = getattr(model, "module", model)
    world_size = get_world_size()
    if world_size > 8:
        print0(f"[verify_ddp_weights_equal] Skipping with world_size={world_size} ⚠️")
        return
    for name, param in model.named_parameters():
        if param is None:
            continue
        if param.grad is None:
            print0(f"[verify_ddp_weights_equal] Skipping param [{name}] with no grad")
            continue
        # One row per rank; compare every row against rank 0's.
        per_rank = gather(param).reshape((world_size, -1))
        diffs = (per_rank[None, 0, :] - per_rank).abs()
        assert (diffs < atol).all(), f"❌ param [{name}] not equal - got max_absolute_diff={diffs.max()}"

        per_rank_grad = gather(param.grad).reshape((world_size, -1))
        grad_diffs = (per_rank_grad[None, 0, :] - per_rank_grad).abs()
        assert (grad_diffs < atol).all(), f"❌ param [{name}] grad not equal - got max_absolute_diff={grad_diffs.max()}"
    print0("[verify_ddp_weights_equal] Verified DDP parameter correctness ✅")
def mean_pool_3d(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the third axis of a 4-D tensor.

    Args:
        hidden_states: shape ``(B, T, S, D)``.
        attention_mask: shape ``(B, T, S)``; entries act as per-token weights.

    Returns:
        Shape ``(B, T, D)``. A ``(b, t)`` group with no unmasked tokens gets
        the mean of all ``T * S`` hidden states of batch item ``b`` instead.
    """
    B, T, S, D = hidden_states.shape
    token_counts = attention_mask.sum(dim=2)[..., None]  # (B, T, 1)
    masked_sum = (hidden_states * attention_mask[..., None]).sum(dim=2)
    pooled = masked_sum / (token_counts + 1e-9)
    # fix for gradient flow: fill empty rows with the mean of the rest of the sequence
    fallback = (
        hidden_states.reshape((B, S * T, D))
        .mean(dim=1, keepdim=True)
        .expand(-1, T, -1)
    )
    pooled = torch.where(token_counts > 0, pooled, fallback)
    assert pooled.shape == (B, T, D)
    return pooled
def mean_pool(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the sequence axis.

    Args:
        hidden_states: shape ``(B, S, D)``.
        attention_mask: shape ``(B, S)``; entries act as per-token weights.

    Returns:
        Shape ``(B, D)``. Rows with an all-zero mask come out as zeros.
    """
    batch_size, _seq_len, hidden_dim = hidden_states.shape
    numerator = (hidden_states * attention_mask[..., None]).sum(dim=1)
    # Tiny epsilon keeps fully-masked rows at 0 instead of 0/0.
    denominator = attention_mask.sum(dim=1)[:, None] + 1e-20
    pooled = numerator / denominator
    assert pooled.shape == (batch_size, hidden_dim)
    return pooled
def mean_pool_weighted(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Position-weighted masked mean over the sequence axis.

    Each token's weight is ``mask * cumsum(mask)``, so later unmasked tokens
    count more: ``[0,1,1,1,0,0] -> [0,1,2,3,0,0]``.

    Args:
        hidden_states: shape ``(B, S, D)``.
        attention_mask: shape ``(B, S)``. It is NOT modified (the previous
            version multiplied it in place, corrupting the caller's tensor).

    Returns:
        Shape ``(B, D)``. Rows whose weights sum to zero return zeros rather
        than NaN.
    """
    weights = attention_mask * attention_mask.cumsum(dim=1)
    weighted_sum = torch.sum(hidden_states * weights.unsqueeze(-1).float(), dim=1)
    total_weight = weights.sum(dim=1, keepdim=True).float()
    # Fully-masked rows have weighted_sum == 0; divide by 1 to avoid 0/0 = NaN.
    total_weight = torch.where(total_weight != 0, total_weight, torch.ones_like(total_weight))
    return weighted_sum / total_weight
def slice_sparse_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Return rows ``[min_row, max_row)`` of a 2-D sparse COO tensor.

    The result is a new coalesced sparse tensor of shape
    ``(max_row - min_row, t.shape[1])`` whose row ``0`` is ``t``'s row
    ``min_row``, matching dense slicing ``t[min_row:max_row]``.

    Raises:
        AssertionError: if ``min_row >= max_row``.
    """
    assert min_row < max_row, f"can't slice from row {min_row} to {max_row}"
    t = t.coalesce()
    indices = t.indices()
    row_idxs = indices[0]
    keep = (min_row <= row_idxs) & (row_idxs < max_row)
    # Boolean indexing returns a copy, so the in-place shift below does not touch `t`.
    kept_indices = indices[:, keep]
    # Re-base rows into the slice's coordinate frame; without this shift any
    # min_row > 0 leaves indices outside the (num_rows, num_cols) shape.
    kept_indices[0] -= min_row
    num_rows = max_row - min_row
    num_cols = t.shape[1]
    kept_values = t.values()[keep]
    return torch.sparse_coo_tensor(kept_indices, kept_values, size=(num_rows, num_cols)).coalesce()
def slice_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Rows ``[min_row, max_row)`` of ``t``; sparse tensors go through the sparse helper."""
    if not t.is_sparse:
        return t[min_row:max_row]
    return slice_sparse_tensor_rows(t=t, min_row=min_row, max_row=max_row)
def maxsim(
    X: torch.Tensor, y: torch.Tensor,
    maximize: bool, chunk_size: int = 8_000,
    debug_mem_usage: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """For each row of ``X``, find the best-scoring column of ``X @ y``.

    Rows are split evenly across distributed workers. Each worker processes
    only its own rows, in chunks of ``chunk_size``, and leaves zeros elsewhere.
    The per-worker results are then combined with ``gather_sum``.

    Args:
        X: ``(n_samples, d)``, dense or sparse.
        y: ``(d, k)``.
        maximize: take the max similarity if True, else the min.
        chunk_size: number of rows multiplied at once.
        debug_mem_usage: print CUDA memory info for every chunk.

    Returns:
        ``(values, indices)``, each of shape ``(n_samples,)``; indices lie in
        ``[0, k)``.
    """
    device = X.device
    n_samples = X.shape[0]
    max_sim_v = torch.zeros(n_samples, device=device, dtype=X.dtype)
    max_sim_i = torch.zeros(n_samples, device=device, dtype=torch.int64)
    # TODO: Implement faster max (without going to dense tensors).
    rank = get_rank()
    world_size = get_world_size()
    worker_worklist_size = int(math.ceil(n_samples / world_size))
    # Clamp to n_samples: trailing ranks may otherwise start/end past the data.
    splits_start_idx = min(worker_worklist_size * rank, n_samples)
    splits_end_idx = min(worker_worklist_size * (rank + 1), n_samples)
    for start in range(splits_start_idx, splits_end_idx, chunk_size):
        # Stop at this worker's split boundary: rows beyond it belong to the
        # next rank, and computing them here would double-count them in
        # gather_sum.
        end = min(start + chunk_size, splits_end_idx)
        sub_x = slice_tensor_rows(X, start, end)
        if debug_mem_usage:
            print(f"[maxsim] step {start} cuda mem free/total = {torch.cuda.mem_get_info()}")
            print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
        sub_sim = sub_x @ y  # TODO – Implement sparse max here to save mem!
        if maximize:
            sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1)
        else:
            sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().min(dim=-1)
        del sub_sim
        del sub_x
        torch.cuda.empty_cache()  # needs to happen after maxsim for some reason.
        max_sim_v[start:end] = sub_max_sim_v
        max_sim_i[start:end] = sub_max_sim_i
    # Each worker filled only its own rows (others are 0), so a sum merges them.
    max_sim_v = gather_sum(max_sim_v)
    max_sim_i = gather_sum(max_sim_i)
    k = y.shape[1]
    assert max_sim_v.shape == (n_samples,)
    assert max_sim_i.shape == (n_samples,)
    assert max_sim_i.min() >= 0
    assert max_sim_i.max() < k
    return max_sim_v, max_sim_i
def forward_batched(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    dataset_input_ids: Optional[torch.Tensor] = None,
    dataset_attention_mask: Optional[torch.Tensor] = None,
    **second_stage_model_kwargs,
) -> torch.Tensor:
    """Run ``model`` over ``input_ids`` in chunks of ``batch_size`` rows and concatenate.

    A wrapper exposing ``.module`` is unwrapped first. If the model then has a
    ``first_stage_model`` it is treated as a two-stage contextual model:

    1. ``first_stage_model`` embeds ``dataset_input_ids`` chunk by chunk. A 2-D
       input is treated as one group; for a 3-D input each group along dim 0
       is embedded separately and the per-group embeddings are averaged.
    2. ``second_stage_model`` embeds ``input_ids`` chunk by chunk, given those
       ``dataset_embeddings`` plus ``second_stage_model_kwargs``.

    Otherwise ``model`` itself is called on each chunk with
    ``second_stage_model_kwargs``.
    """
    if hasattr(model, "module"):
        model = model.module

    def run_in_chunks(fn, ids, mask, /, **kwargs):
        # Call ``fn`` on consecutive row slices and concatenate along dim 0.
        pieces = []
        offset = 0
        while offset < len(ids):
            pieces.append(
                fn(
                    input_ids=ids[offset:offset + batch_size],
                    attention_mask=mask[offset:offset + batch_size],
                    **kwargs,
                )
            )
            offset += batch_size
        return torch.cat(pieces, dim=0)

    if not hasattr(model, "first_stage_model"):
        return run_in_chunks(model, input_ids, attention_mask, **second_stage_model_kwargs)

    # Support pooling over 3D dataset_input_ids inputs: lift 2-D input to a single group.
    if len(dataset_input_ids.shape) == 2:
        dataset_input_ids = dataset_input_ids[None]
        dataset_attention_mask = dataset_attention_mask[None]
    per_group_embeddings = [
        run_in_chunks(model.first_stage_model, dataset_input_ids[g], dataset_attention_mask[g])
        for g in range(len(dataset_input_ids))
    ]
    # Automatically pool over 3D dataset_input_ids.
    dataset_embeddings = torch.stack(per_group_embeddings, dim=0).mean(dim=0)
    return run_in_chunks(
        model.second_stage_model,
        input_ids,
        attention_mask,
        dataset_embeddings=dataset_embeddings,
        **second_stage_model_kwargs,
    )
def last_token_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select, per row, the hidden state at the last position whose mask is 1.

    Args:
        hidden_state: shape ``(b, n, d)``.
        attention_mask: shape ``(b, n)``.

    Returns:
        Shape ``(b, d)``. Rows with an all-zero mask return zeros.
    """
    # https://github.com/ContextualAI/gritlm/blob/main/gritlm/gritlm.py#L190
    b, n, d = hidden_state.size()
    # Locate the last `1` as the first `1` of the flipped mask. Unlike
    # `argmin(mask) - 1`, this also works when the row is all 1s or has
    # leading 0s (left padding). For an all-zero row argmax returns 0, so the
    # index lands on position n-1 and is zeroed by the mask below.
    first_one_from_end = torch.argmax(torch.flip(attention_mask, dims=(1,)), dim=1, keepdim=False)
    last_positions = torch.clamp(attention_mask.size(1) - first_one_from_end - 1, min=0)
    # [b] -> [b, 1, d] so torch.gather picks one full hidden vector per row.
    index = last_positions[:, None, None].repeat(1, 1, d)
    assert index.shape == (b, 1, d)
    # Multiply by the mask so an index that points at padding yields zeros.
    mask_3d = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
    return torch.gather(hidden_state * mask_3d, 1, index).squeeze(dim=1)
def print0(*args, **kwargs) -> None:
    """Forward to ``print`` only from the rank-0 process (always, when not distributed).

    NOTE(review): this re-binds an identical ``print0`` defined earlier in the
    module; one of the two definitions could be dropped.
    """
    if get_rank() != 0:
        return
    print(*args, **kwargs)
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
    """Truncate ``model`` in place to its first ``n_layers`` layers.

    Handles ``model.transformer.h`` (gpt2), ``model.transformer.layer``,
    ``model.encoder.layers`` and ``model.encoder.layer``.

    Raises:
        RuntimeError: if the model has neither ``transformer`` nor ``encoder``.
    """
    if hasattr(model, 'transformer'):
        container = model.transformer
        attr_name = 'h' if hasattr(container, 'h') else 'layer'
    elif hasattr(model, 'encoder'):
        container = model.encoder
        attr_name = 'layers' if hasattr(container, 'layers') else 'layer'
    else:
        raise RuntimeError(f"unknown how to limit layers of model {type(model)}")
    setattr(container, attr_name, getattr(container, attr_name)[:n_layers])
def disable_dropout(model: torch.nn.Module):
    """Set ``p = 0`` on every ``torch.nn.Dropout`` submodule of ``model`` (in place)."""
    n_disabled = 0
    for module in model.modules():
        if not isinstance(module, torch.nn.Dropout):
            continue
        module.p = 0.0
        n_disabled += 1
    print0(
        f"Disabled {n_disabled} dropout modules from model type {type(model)}"
    )
def disable_causality(model: torch.nn.Module):
    """Set ``is_causal = False`` on every submodule that has an ``is_causal`` attribute."""
    causal_modules = [m for m in model.modules() if hasattr(m, "is_causal")]
    for module in causal_modules:
        module.is_causal = False
    print0(
        f"Set is_causal=False in {len(causal_modules)} modules from model type {type(model)}"
    )
class ContextualModelMixin(nn.Module):
    """Shared logic for second-stage models conditioned on dataset (corpus) embeddings.

    Users must set ``self.hidden_size`` and ``self.config`` before calling
    :meth:`contextual_init`, which reads both.
    """

    @property
    def num_corpus_tokens(self) -> int:
        """Number of dataset-embedding positions in the context window."""
        # Must be a property: every use site reads it as an attribute
        # (comparisons and ``reshape`` shapes), never calls it.
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def contextual_init(self):
        """Create the soft-prompt and output projections and read corpus settings from ``self.config``."""
        self.n_soft_prompt = 8
        self.prompt_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt),
        )
        config_values = vars(self.config)
        self.transductive_corpus_size = config_values.get("transductive_corpus_size", 1)
        self.transductive_tokens_per_document = config_values.get("transductive_tokens_per_document", 1)
        self.randomize_dataset_sequence_order = True
        self.sequence_dropout_prob = config_values.get("transductive_sequence_dropout_prob", 0.0)
        if self.sequence_dropout_prob > 0.0:
            # Learned vector that replaces dropped-out dataset positions.
            self.sequence_dropout_null_embedding = torch.nn.Parameter(
                torch.randn(self.hidden_size) * 0.01,
                requires_grad=True,
            )
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size),
        )

    def _prepare_dataset_embeddings(
        self,
        input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Build the context sequence that precedes the text tokens.

        Returns the dataset embeddings (truncated to ``num_corpus_tokens``,
        and, in training, possibly resampled / dropped out / shuffled)
        concatenated along dim 1 with ``n_soft_prompt`` learned soft-prompt
        vectors. When the dataset embeddings have a leading dimension of 1
        they are broadcast to ``input_ids.shape[0]``.
        """
        if not isinstance(dataset_embeddings, torch.Tensor):
            dataset_embeddings = torch.tensor(dataset_embeddings)
        if len(dataset_embeddings.shape) == 2:
            # Auto-expand for a batch: (n, d) -> (1, n, d).
            dataset_embeddings = dataset_embeddings[None, :, :]
        dataset_embeddings = dataset_embeddings.to(input_ids.device)
        batch_size = input_ids.shape[0]

        if self.transductive_tokens_per_document > 1:
            if self.training:
                # Choose N random documents to fill our context window with.
                # Sampling indices per example gives a different batch *per-document*.
                assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
                R = torch.randint(
                    low=0,
                    high=len(dataset_embeddings),
                    size=(batch_size, self.transductive_corpus_size),
                    device=dataset_embeddings.device,
                )
                # TODO make this deterministic somehow for evaluation?
                dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
            else:
                dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))

        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
            # If too many dataset embeddings are passed in, just take the first N until
            # we have the proper number.
            dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]

        n_context_rows, corpus_size, _hidden_size = dataset_embeddings.shape
        if n_context_rows == 1:
            # One shared context: broadcast it to every example in the batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))

        if self.training and self.sequence_dropout_prob > 0.0:
            sequence_dropout_mask = (
                torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
            )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = torch.where(
                sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
            )
        elif null_dataset_embedding:
            # NOTE(review): sequence_dropout_null_embedding only exists when
            # transductive_sequence_dropout_prob > 0; otherwise this line
            # raises AttributeError.
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = null_embeddings

        # Learned soft prompt: a projection of a constant all-ones vector,
        # split into n_soft_prompt vectors and appended after the corpus tokens.
        soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
        soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
        soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1))
        soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)

        if self.training and self.randomize_dataset_sequence_order:
            # Shuffle corpus positions independently per example; the
            # soft-prompt positions stay at the end.
            randomized_order = torch.stack(
                [
                    torch.cat(
                        (
                            torch.randperm(corpus_size, device=soft_prompt.device),
                            torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size,
                        ),
                        dim=0,
                    )
                    for _ in range(batch_size)
                ]
            )
            soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))
        return soft_prompt
class BiEncoder(transformers.PreTrainedModel):
    """Single-stage embedder: backbone -> pooling -> MLP.

    ``config`` must provide ``embedder``, ``limit_layers``,
    ``embedding_output_dim``, ``logit_scale`` and ``disable_dropout``;
    ``transductive_tokens_per_document`` and ``pooling_strategy`` are optional.
    """
    embedder: transformers.PreTrainedModel

    def __init__(
        self,
        config,  #: transformers.PreTrainedConfig,
    ):
        super().__init__(config=config)
        embedder, _ = load_embedder_and_tokenizer(
            config.embedder,
        )
        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(embedder, config.limit_layers)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        # Allow pooling to multiple tokens per document
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
        )
        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)
        self.pooling_strategy = vars(config).get("pooling_strategy", "mean")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor] = None,
        dataset_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids = None,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed a batch of texts.

        Args:
            input_ids, attention_mask: token ids and mask for the backbone.
            dataset_input_ids, dataset_attention_mask: accepted for interface
                compatibility with the contextual models; not used here.
            token_type_ids: ignored.
            output_hidden_states: if True, return a dict with the backbone
                ``"hidden_states"`` and the ``"pooled"`` embeddings.

        Returns:
            The MLP-projected pooled embeddings: ``(batch, out_dim)``, or
            ``(batch, transductive_tokens_per_document, out_dim)`` when the
            sequence is pooled into several tokens per document.
        """
        del token_type_ids
        outputs = self.embedder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        if self.transductive_tokens_per_document > 1:
            n_tokens = self.transductive_tokens_per_document
            batch_size, seq_length, output_dim = outputs.shape
            if seq_length % n_tokens != 0:
                # Pad to nearest multiple. Match dtypes so half-precision
                # outputs are not silently promoted to float32.
                n_extra_embeds = n_tokens - (seq_length % n_tokens)
                outputs = torch.cat(
                    (outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device, dtype=outputs.dtype)),
                    dim=1,
                )
                attention_mask = torch.cat(
                    (attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device, dtype=attention_mask.dtype)),
                    dim=1,
                )
                seq_length += n_extra_embeds
                print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")
            # Split the sequence into n_tokens contiguous segments and mean-pool each.
            outputs = outputs.reshape(
                (batch_size, n_tokens, seq_length // n_tokens, output_dim)
            )
            attention_mask = attention_mask.reshape((batch_size, n_tokens, -1))
            document_embeddings = mean_pool_3d(outputs, attention_mask)  # (batch, n_tokens, dim)
        elif self.pooling_strategy == "mean":
            document_embeddings = mean_pool(outputs, attention_mask)
        else:
            # Max-pool over unmasked positions. (This branch previously read
            # `document_embeddings` before assigning it.)
            fill_value = torch.finfo(outputs.dtype).min
            masked_outputs = outputs.masked_fill(~attention_mask.bool()[..., None], fill_value)
            document_embeddings = masked_outputs.max(dim=1).values

        output = self.mlp(document_embeddings)
        if output_hidden_states:
            return {
                "hidden_states": outputs,
                "pooled": output,
            }
        return output
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage model on an autoregressive backbone with causality disabled.

    The dataset-conditioned context from ``_prepare_dataset_embeddings`` is
    flattened, re-chunked to the backbone's hidden width, layer-normed, and
    prepended to the backbone's token embeddings of ``input_ids``.
    """

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
        first_stage_hidden_size: int,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.backbone_hidden_size = self.backbone.config.hidden_size
        self.hidden_size = first_stage_hidden_size  # Input token size
        self.contextual_init()
        disable_causality(self.backbone)

        self.input_ln = torch.nn.LayerNorm(
            self.backbone_hidden_size,
            eps=1e-5
        )

        # Override contextual init: project at the backbone width.
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
        )
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        """Number of dataset-embedding positions in the context window."""
        # A property, because _prepare_dataset_embeddings reads it as an attribute.
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def corpus_token_ratio(self) -> float:
        # How many tokens from the first stage make one token in the second
        # stage?
        return self.backbone_hidden_size / self.hidden_size

    def corpus_token_pad_size(self, n_tokens: int) -> int:
        # NOTE(review): `n_tokens` is ignored; the result depends only on the hidden sizes.
        return self.hidden_size % self.backbone_hidden_size

    def _shift_rotary_embedding(self) -> None:
        # TODO: Can we do this for LLAMA?
        print("Warning: Positional embedding disabling not implemented for LLAMA.")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` conditioned on ``dataset_embeddings``.

        Returns the output projection of the pooled text positions (pooling
        chosen by ``config.pooling_strategy``: ``"last_token"``, ``"mean"``,
        otherwise position-weighted mean), or a dict with ``"hidden_states"``
        and ``"pooled"`` when ``output_hidden_states`` is True.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Flatten each example's context to one vector, pad it with ones, and
        # cut it back into backbone-width "tokens" (the first-stage hidden size
        # may differ from the backbone's).
        n_rows = soft_prompt.shape[0]
        num_soft_elements = math.prod(soft_prompt.shape[1:])
        soft_prompt = soft_prompt.reshape((n_rows, num_soft_elements))
        # NOTE(review): when num_soft_elements is already a multiple of the
        # backbone width this still appends one full row of padding; kept
        # unchanged because trained weights may depend on it.
        num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        # Match soft_prompt's dtype so half-precision runs are not promoted to float32.
        padding = torch.ones((n_rows, num_padding_elements), device=soft_prompt.device, dtype=soft_prompt.dtype)
        soft_prompt = torch.cat((soft_prompt, padding), dim=1)
        soft_prompt = soft_prompt.reshape(
            (n_rows, -1, self.backbone_hidden_size)
        )
        soft_prompt = self.input_ln(soft_prompt)

        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeddings = self.backbone.get_input_embeddings()
        inputs_embeds = token_embeddings(input_ids)  # (b, s) -> (b, s, d)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
            output_hidden_states=True,
        )
        # trim soft prompt: keep only the text positions.
        last_hidden_state = output.hidden_states[-1]
        n_soft_prompt_tokens = soft_prompt.shape[1]
        output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]

        pooling_strategy = vars(self.config).get("pooling_strategy")
        if pooling_strategy == "last_token":
            output_pooled = last_token_pool(output_vectors, output_attention_mask)
        elif pooling_strategy == "mean":
            output_pooled = mean_pool(output_vectors, output_attention_mask)
        else:
            output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)

        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage encoder that conditions each input on dataset embeddings.

    The dataset embeddings are turned into a soft prompt by
    ``_prepare_dataset_embeddings`` (inherited from ContextualModelMixin).
    The soft prompt is prepended to the token embeddings of the input and run
    through ``backbone``. The soft-prompt positions are then dropped, the
    remaining token states are mean-pooled, and the result goes through
    ``output_projection``.
    """

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = dataset_backbone.config.hidden_size
        # NOTE(review): contextual_init comes from ContextualModelMixin (not
        # visible here). It presumably sets `transductive_tokens_per_document`
        # and `output_projection`, which are used below; confirm.
        self.contextual_init()
        self._shift_rotary_embedding()

    def num_corpus_tokens(self) -> int:
        """Return the number of corpus soft-prompt positions.

        This is ``config.transductive_corpus_size * self.transductive_tokens_per_document``.
        """
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        """Start rotary position embeddings after the corpus prefix (nomic only).

        This only acts when the backbone's ``model_type`` starts with
        ``"nomic"`` and ``config.disable_transductive_rotary_embedding`` is
        truthy (the default is True). In that case, every backbone submodule
        that has a ``rotary_emb_dim`` attribute gets its ``rotary_start_pos``
        set to the number of corpus tokens. Positional embeddings then apply
        only to the *text* portion of the sequence.
        """
        disable_transductive_rotary_embedding = vars(self.config).get(
            "disable_transductive_rotary_embedding", True
        )
        if not (
            self.backbone.config.model_type.startswith("nomic")
            and disable_transductive_rotary_embedding
        ):
            return
        self.backbone.config.rotary_start_pos = 0.0
        # num_corpus_tokens is a plain method, so it must be *called*.
        # Otherwise a bound method, not an int, is stored on every rotary module.
        rotary_start_pos = self.num_corpus_tokens()
        rotary_disabled = 0
        for module in self.backbone.modules():
            if hasattr(module, "rotary_emb_dim"):
                module.rotary_start_pos = rotary_start_pos
                rotary_disabled += 1
        print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` conditioned on ``dataset_embeddings``.

        Args:
            input_ids: token ids to embed. Assumed shape is (batch, seq_len);
                TODO confirm.
            attention_mask: mask for ``input_ids``. It is appended after an
                all-ones soft-prompt mask along dim 1.
            dataset_embeddings: first-stage corpus embeddings. They are turned
                into the soft prompt by ``_prepare_dataset_embeddings``.
            output_hidden_states: if True, also return the per-token states.
            null_dataset_embedding: forwarded to ``_prepare_dataset_embeddings``.

        Returns:
            The projected, mean-pooled embedding. When ``output_hidden_states``
            is True, returns a dict with ``"hidden_states"`` (token states with
            the soft prompt removed) and ``"pooled"`` (the projected embedding).
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )
        # The soft prompt is always attended to.
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        inputs_embeds = self.backbone.embeddings(input_ids)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        # Drop the soft-prompt positions so that only real tokens are pooled.
        n_soft_prompt_tokens = soft_prompt.shape[1]
        output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        # TODO: make the pooling strategy configurable.
        output = self.output_projection(output_pooled)
        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Bi-encoder that prepends a randomly sampled dataset document to each input.

    For every input row, one row of ``dataset_input_ids`` is drawn at random
    (with replacement) and concatenated in front of the input tokens. The
    embedder attends over both. Only the original input tokens contribute to
    the pooled embedding.
    """

    def __init__(
        self,
        config,
        embedder: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: torch.Tensor,
        dataset_attention_mask: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` with a random dataset document as prefix.

        Args:
            input_ids: token ids to embed, one row per input.
            attention_mask: mask for ``input_ids``.
            dataset_input_ids: candidate prefix documents, one row per document.
                The number of rows may differ from the number of inputs.
            dataset_attention_mask: mask for ``dataset_input_ids``. Only its
                dtype, device and width are used; the sampled prefix is
                always fully attended.
            output_hidden_states: if True, also return the per-token states of
                the input portion (prefix positions removed).

        Returns:
            The projected pooled embedding. When ``output_hidden_states`` is
            True, returns a dict with ``"hidden_states"`` and ``"pooled"``.
        """
        R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)
        dataset_input_ids = dataset_input_ids[R]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)
        # Build the prefix mask from the *sampled* ids. The incoming
        # dataset_attention_mask has one row per dataset document, so it only
        # lines up with attention_mask when both happen to have the same
        # number of rows.
        prefix_attention_mask = torch.ones_like(
            dataset_input_ids,
            dtype=dataset_attention_mask.dtype,
            device=dataset_attention_mask.device,
        )
        input_attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        # Prefix positions are zeroed so that only input tokens are pooled.
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )
        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )
        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)
        if output_hidden_states:
            S_d = prefix_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output
class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
    """Two-stage contextual document embedding model.

    ``first_stage_model`` (a ``BiEncoder``) embeds the context/corpus
    documents. ``second_stage_model`` then embeds the actual inputs,
    conditioned on those first-stage embeddings.
    """

    config_class = ContextualModelConfig
    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        # Fall back to `config.embedder` when no separate second-stage
        # backbone is configured.
        backbone_name = vars(config).get("dataset_backbone") or config.embedder
        dataset_backbone, _ = load_embedder_and_tokenizer(backbone_name)
        layer_limit = config.limit_layers
        if layer_limit:
            print0(f"Limiting layers to {layer_limit}")
            limit_layers(dataset_backbone, layer_limit)

        # The first stage gets its own config copy. That copy clears
        # embedding_output_dim and uses `limit_layers_first_stage` as its
        # layer limit.
        first_stage_config = copy.deepcopy(config)
        first_stage_config.embedding_output_dim = None
        first_stage_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(config=first_stage_config)

        use_autoregressive = vars(config).get("autoregressive_backbone", False)
        if not use_autoregressive:
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone,
            )
        else:
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )

        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)

        if vars(self.config).get("transductive_tie_token_embeddings", False):
            # Share one token-embedding matrix between both stages.
            shared_weight = self.first_stage_model.embedder.embeddings.word_embeddings.weight
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = shared_weight

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor],
        dataset_attention_mask: Optional[torch.Tensor],
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed the inputs, using the dataset documents as context.

        The dataset documents go through the first stage. The inputs go through
        the second stage, conditioned on the result. The return value is
        whatever the second stage returns.

        Args:
            input_ids: token ids of the documents/queries to embed.
            attention_mask: mask for ``input_ids``.
            dataset_input_ids: token ids of the context documents.
            dataset_attention_mask: mask for ``dataset_input_ids``.
            output_hidden_states: forwarded to the second stage.
        """
        corpus_embeddings = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask,
        )
        return self.second_stage_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            dataset_embeddings=corpus_embeddings,
            output_hidden_states=output_hidden_states,
        )
def get_model_class(name: str):
    """Map a model-type name to its model class.

    Args:
        name: one of ``"transductive"``, ``"biencoder"`` or
            ``"dataset_prefix_biencoder"``.

    Returns:
        The matching model class (not an instance).

    Raises:
        ValueError: if ``name`` is not one of the names above.
    """
    # Exact comparison. A membership test (`name in 'transductive'`) would be
    # a substring check and would accept "", "trans", "t", ...
    if name == "transductive":
        return ContextualDocumentEmbeddingTransformer
    elif name == "biencoder":
        return BiEncoder
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    else:
        raise ValueError(f"unknown model cls {name}")