| | from typing import Callable, Optional, Tuple |
| |
|
| | import copy |
| | import json |
| | import math |
| | import multiprocessing |
| | import os |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import transformers |
| |
|
class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """Catch-all configuration that stores every JSON-serializable kwarg as an attribute.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json. Values that cannot be JSON-serialized
    are dropped silently.
    """

    def __init__(self, **kwargs):
        def _store_if_serializable(attr_name, attr_value):
            # json.dumps raises TypeError for values it cannot encode; such values
            # (and any TypeError raised by setattr) are skipped rather than stored.
            try:
                json.dumps(attr_value)
                setattr(self, attr_name, attr_value)
            except TypeError:
                pass

        for attr_name, attr_value in kwargs.items():
            _store_if_serializable(attr_name, attr_value)
        # NOTE(review): the base initializer runs after the attributes are set, so any
        # default it assigns presumably overwrites a same-named kwarg - confirm intended.
        super().__init__()
| |
|
def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    """Load a backbone model and its tokenizer from a short name or a hub id.

    Recognized names:
      * ``nomic*`` or ``bert-base-uncased``: the ``.bert`` submodule of the
        masked-LM model (loaded with ``trust_remote_code=True``).
      * ``gtr-base`` / ``gtr_base``: encoder of ``sentence-transformers/gtr-t5-base``.
      * ``pile-t5-base-encoder`` / ``pile-t5-base-decoder``: encoder or decoder of
        ``EleutherAI/pile-t5-base``; the tokenizer's pad token is set to EOS.
      * ``gpt2*``, ``meta-llama*`` or any name containing ``Llama``: a causal LM
        loaded with ``attn_implementation="flash_attention_2"``; the tokenizer pads
        on the right with the EOS token and has ``add_eos_token`` enabled.
      * anything else: ``AutoModel`` with ``trust_remote_code=True``.

    Returns:
        ``(model, tokenizer)``.
    """
    assert name is not None, "name must be provided to load_embedder_and_tokenizer"
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        model = transformers.AutoModelForMaskedLM.from_pretrained(name, trust_remote_code=True).bert
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ("gtr-base", "gtr_base"):
        model = transformers.AutoModel.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        )
    elif name in ("pile-t5-base-encoder", "pile-t5-base-decoder"):
        t5_model = transformers.AutoModel.from_pretrained("EleutherAI/pile-t5-base")
        model = t5_model.encoder if name == "pile-t5-base-encoder" else t5_model.decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/pile-t5-base")
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith(("gpt2", "meta-llama")) or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
        )
        # Retained for backward compatibility; the model itself does not use it.
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        # padding_side is a tokenizer setting: right padding keeps the real tokens
        # contiguous right after any prefix that is prepended to the sequence.
        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    else:
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)

    return model, tokenizer
def get_world_size() -> int:
    """Return the number of processes in the default process group.

    Returns 1 when torch.distributed is unavailable in this build or no process
    group has been initialized. When distributed support is unavailable,
    torch.distributed exposes no other API, so calling get_world_size directly
    would raise AttributeError.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
| |
|
| |
|
def get_rank() -> int:
    """Return this process's rank in the default process group.

    Returns 0 when torch.distributed is unavailable in this build or no process
    group has been initialized. This makes single-process runs behave as rank 0.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
| | |
def gather(t: torch.Tensor) -> torch.Tensor:
    """Concatenate `t` from every rank along dim 0.

    With a single process, `t` is returned unchanged. Otherwise a 0-d tensor is
    first promoted to shape (1,). The result is the rank-ordered concatenation.
    This rank's slot holds the original tensor rather than the all_gather
    copy, so autograd can flow back into this rank's contribution.
    """
    world_size = get_world_size()
    if world_size == 1:
        return t

    local = t.unsqueeze(0) if t.ndim == 0 else t
    pieces = [torch.empty_like(local) for _ in range(world_size)]
    torch.distributed.all_gather(pieces, local)
    # Swap the received copy of our own tensor for the original (graph-connected) one.
    pieces[get_rank()] = local
    return torch.cat(pieces, dim=0)
| |
|
| |
|
def gather_sum(t: torch.Tensor) -> torch.Tensor:
    """Element-wise sum of `t` over all ranks.

    With a single process, `t` is returned unchanged. Otherwise every rank's
    tensor is all-gathered, stacked and summed over the rank axis. The result
    has the same shape as `t`; a 0-d input yields a 0-d result, matching the
    single-process path. The all_gather outputs are presumably not tracked by
    autograd, so the result should be treated as detached.
    """
    world_size = get_world_size()
    if world_size == 1:
        return t

    was_scalar = t.ndim == 0
    if was_scalar:
        # all_gather needs at least one dimension.
        t = t.unsqueeze(0)

    gathered = [torch.empty_like(t) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, t)
    summed = torch.stack(gathered, dim=0).sum(dim=0)
    return summed.squeeze(0) if was_scalar else summed
| |
|
| |
|
def get_num_proc() -> int:
    """Number of worker processes each distributed rank should use (always >= 1).

    Splits the usable CPUs evenly across ranks. os.sched_getaffinity reports the
    CPUs this process is allowed to run on (it honors affinity masks such as
    `taskset`). It does not exist on every platform, in which case the total CPU
    count is used instead.
    """
    world_size: int = get_world_size()
    try:
        n_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        n_cpus = multiprocessing.cpu_count()
    # With more ranks than CPUs the integer division would give 0, which is not a
    # valid process count.
    return max(1, n_cpus // world_size)
| |
|
| |
|
def torch_main_worker_finish_first(func: Callable):
    """Decorator: run `func` on rank 0 first, then on every other rank.

    Under torch.distributed, rank 0 calls `func` while the other ranks wait at a
    barrier. The other ranks then call `func`, and all ranks meet at a second
    barrier before returning. This is useful when rank 0 should populate a cache
    the others then read. Without an initialized process group, `func` is just
    called once. Each process returns its own call's result.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        ddp_enabled = torch.distributed.is_available() and torch.distributed.is_initialized()
        is_main_worker = (not ddp_enabled) or torch.distributed.get_rank() == 0

        if is_main_worker:
            result = func(*args, **kwargs)

        if ddp_enabled:
            torch.distributed.barrier()

        if not is_main_worker:
            result = func(*args, **kwargs)

        if ddp_enabled:
            torch.distributed.barrier()
        return result

    return wrapper
| |
|
| |
|
def print0(*args, **kwargs) -> None:
    """Forward to the built-in `print`, but only on the rank-0 process (see `get_rank`)."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
| |
|
| |
|
def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    """Assert that every parameter and its gradient is identical across ranks.

    A DDP wrapper is unwrapped via its `.module`. For each parameter that has a
    gradient, the value and the gradient are gathered from all ranks. Each rank's
    copy is then compared element-wise against rank 0's copy, within `atol`.
    Parameters without a gradient are reported and skipped. The check is skipped
    entirely for world sizes above 8.

    Raises:
        AssertionError: if any parameter or gradient differs by `atol` or more.
    """
    if hasattr(model, "module"):
        model = model.module

    world_size = get_world_size()
    if world_size > 8:
        print0(f"[verify_ddp_weights_equal] Skipping with world_size={world_size} ⚠️")
        return

    # Pure verification: no_grad keeps gather() from building an autograd graph
    # over every parameter of the model.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.grad is None:
                print0(f"[verify_ddp_weights_equal] Skipping param [{name}] with no grad")
                continue

            gathered_param = gather(param).reshape((world_size, -1))
            # Row 0 (shape (1, N)) broadcasts against all ranks' rows.
            absolute_diffs = (gathered_param[:1] - gathered_param).abs()
            rank_params_eq = (absolute_diffs < atol).all()
            assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"

            gathered_param_grad = gather(param.grad).reshape((world_size, -1))
            absolute_grad_diffs = (gathered_param_grad[:1] - gathered_param_grad).abs()
            rank_grad_params_eq = (absolute_grad_diffs < atol).all()
            assert rank_grad_params_eq, f"❌ param [{name}] grad not equal - got max_absolute_diff={absolute_grad_diffs.max()}"

    print0("[verify_ddp_weights_equal] Verified DDP parameter correctness ✅")
| | |
| |
|
| |
|
def mean_pool_3d(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the innermost sequence axis of a chunked batch.

    Args:
        hidden_states: (B, T, S, D) tensor - T chunks of S positions each.
        attention_mask: mask broadcastable as (B, T, S).

    Returns:
        (B, T, D). For each chunk, the mean of its unmasked positions. A chunk with
        no unmasked positions instead gets the unmasked mean of all S*T positions
        of its batch row.
    """
    B, T, S, D = hidden_states.shape
    chunk_counts = attention_mask.sum(dim=2)[..., None]
    masked_sum = (hidden_states * attention_mask[..., None]).sum(dim=2)
    pooled = masked_sum / (chunk_counts + 1e-9)

    # Fallback for fully-masked chunks: plain mean over every position in the row.
    row_means = hidden_states.reshape((B, S * T, D)).mean(dim=1, keepdim=True)
    fallback = row_means.expand(-1, T, -1)
    pooled = torch.where(chunk_counts > 0, pooled, fallback)

    assert pooled.shape == (B, T, D)
    return pooled
| |
|
def mean_pool(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the sequence axis.

    Args:
        hidden_states: (B, S, D) tensor.
        attention_mask: (B, S) mask; positions with 0 are excluded.

    Returns:
        (B, D). The tiny epsilon in the denominator avoids 0/0, so an all-masked
        row yields zeros.
    """
    batch, _seq_len, dim = hidden_states.shape
    token_counts = attention_mask.sum(dim=1)[:, None] + 1e-20
    summed = (hidden_states * attention_mask[..., None]).sum(dim=1)
    pooled = summed / token_counts
    assert pooled.shape == (batch, dim)
    return pooled
| |
|
| |
|
def mean_pool_weighted(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Position-weighted mean over the sequence axis.

    Each position is weighted by ``mask * cumsum(mask)``. For a 0/1 mask this is
    the 1-based index of the token among the unmasked tokens, so later tokens
    count more. Masked positions get weight 0.

    Args:
        hidden_states: (B, S, D) tensor.
        attention_mask: (B, S) mask. It is NOT modified.

    Returns:
        (B, D) float tensor. An all-masked row yields zeros instead of NaN.
    """
    # Unpacking also validates that hidden_states is 3-D.
    _B, _S, _D = hidden_states.shape
    # Out-of-place on purpose: callers often pass a view of a larger mask, which an
    # in-place update would silently corrupt.
    weights = attention_mask * attention_mask.cumsum(dim=1)
    s = torch.sum(hidden_states * weights.unsqueeze(-1).float(), dim=1)
    d = weights.sum(dim=1, keepdim=True).float()
    # d >= 1 whenever any position is unmasked, so the clamp only affects empty rows.
    return s / d.clamp(min=1e-9)
| |
|
| |
|
def slice_sparse_tensor_rows(t: torch.sparse.Tensor, min_row: int, max_row: int) -> torch.sparse.Tensor:
    """Sparse (COO) equivalent of ``t[min_row:max_row]``.

    Keeps the non-zeros whose row index lies in [min_row, max_row). Their row
    indices are re-based to start at 0, so they fit the new size
    ``(max_row - min_row, *t.shape[1:])``.

    Raises:
        AssertionError: if ``min_row >= max_row``.
    """
    assert min_row < max_row, f"can't slice from row {min_row} to {max_row}"
    t = t.coalesce()
    indices = t.indices()
    row_idxs = indices[0]
    index_mask = (min_row <= row_idxs) & (row_idxs < max_row)

    # Boolean indexing already copies, but clone explicitly before the in-place shift.
    kept_indices = indices[:, index_mask].clone()
    kept_indices[0] -= min_row
    vals = t.values()[index_mask]

    size = (max_row - min_row, *t.shape[1:])
    return torch.sparse_coo_tensor(kept_indices, vals, size=size).coalesce()
| |
|
| |
|
def slice_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Return rows [min_row, max_row) of `t`, for dense or sparse tensors.

    Dense tensors use ordinary slicing. Sparse tensors are delegated to
    `slice_sparse_tensor_rows`.
    """
    if not t.is_sparse:
        return t[min_row:max_row]
    return slice_sparse_tensor_rows(t=t, min_row=min_row, max_row=max_row)
| |
|
| |
|
@torch.no_grad
def maxsim(
    X: torch.Tensor, y: torch.Tensor,
    maximize: bool, chunk_size: int = 8_000,
    debug_mem_usage: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """For every row of `X`, find the best-scoring column of ``X @ y``.

    Rows are split evenly across distributed ranks. Each rank processes only its
    own contiguous range, in chunks of `chunk_size` rows (X may be sparse). The
    per-rank results are zero elsewhere, so summing them over ranks with
    `gather_sum` reassembles the full answer.

    Args:
        X: (n_samples, d) tensor, dense or sparse.
        y: (d, k) tensor.
        maximize: take the max over columns if True, else the min.
        chunk_size: rows of X processed per matmul.
        debug_mem_usage: print CUDA memory info per chunk.

    Returns:
        ``(values, indices)``, each of shape (n_samples,). The indices lie in [0, k).
    """
    device = X.device
    n_samples = X.shape[0]

    max_sim_v = torch.zeros(n_samples, device=device, dtype=X.dtype)
    max_sim_i = torch.zeros(n_samples, device=device, dtype=torch.int64)

    rank = get_rank()
    world_size = get_world_size()

    worker_worklist_size = int(math.ceil(n_samples / world_size))
    splits_start_idx = worker_worklist_size * rank
    # Clamp so trailing ranks with no work get an empty range.
    splits_end_idx = min(worker_worklist_size * (rank + 1), n_samples)

    for i in range(splits_start_idx, splits_end_idx, chunk_size):
        # Never run past this rank's split: overlapping rows would be summed twice
        # by gather_sum below.
        start, end = i, min(i + chunk_size, splits_end_idx)
        sub_x = slice_tensor_rows(X, start, end)
        if debug_mem_usage: print(f"[maxsim] step {i} cuda mem free/total = {torch.cuda.mem_get_info()}")
        if debug_mem_usage: print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
        sub_sim = sub_x @ y
        if maximize:
            sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1)
        else:
            sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().min(dim=-1)
        del sub_sim
        del sub_x
        torch.cuda.empty_cache()
        max_sim_v[start: end] = sub_max_sim_v
        max_sim_i[start: end] = sub_max_sim_i

    # Each row was written by exactly one rank (zeros elsewhere), so a sum merges them.
    max_sim_v = gather_sum(max_sim_v)
    max_sim_i = gather_sum(max_sim_i)
    k = y.shape[1]

    assert max_sim_v.shape == (n_samples,)
    assert max_sim_i.shape == (n_samples,)
    assert max_sim_i.min() >= 0
    assert max_sim_i.max() < k

    return max_sim_v, max_sim_i
| |
|
| |
|
def _run_in_batches(
    fn: Callable,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    extra_kwargs: dict,
) -> torch.Tensor:
    """Call `fn` on consecutive `batch_size`-row slices and concatenate the outputs along dim 0."""
    outputs = [
        fn(
            input_ids=input_ids[start:start + batch_size],
            attention_mask=attention_mask[start:start + batch_size],
            **extra_kwargs,
        )
        for start in range(0, len(input_ids), batch_size)
    ]
    return torch.cat(outputs, dim=0)


def forward_batched(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    dataset_input_ids: Optional[torch.Tensor] = None,
    dataset_attention_mask: Optional[torch.Tensor] = None,
    **second_stage_model_kwargs,
) -> torch.Tensor:
    """Run `model` over `input_ids` in mini-batches of `batch_size` rows.

    A DDP wrapper is unwrapped via `.module`. Single-stage models are called
    directly on each batch, with `second_stage_model_kwargs` forwarded. Two-stage
    models (those with `first_stage_model`) are handled differently:
    `first_stage_model` embeds `dataset_input_ids` in batches. A 2-D dataset is
    treated as one corpus; a 3-D one as several corpora whose embeddings are
    averaged. `second_stage_model` is then applied to `input_ids` in batches with
    `dataset_embeddings=` that result.

    Raises:
        ValueError: if a two-stage model is given without dataset inputs, or if
            `batch_size` is not positive.
    """
    if hasattr(model, "module"):
        model = model.module

    if not hasattr(model, "first_stage_model"):
        return _run_in_batches(model, input_ids, attention_mask, batch_size, second_stage_model_kwargs)

    if dataset_input_ids is None or dataset_attention_mask is None:
        raise ValueError(
            "forward_batched: two-stage models require dataset_input_ids and dataset_attention_mask"
        )

    if dataset_input_ids.ndim == 2:
        # A single corpus: add a leading corpus axis.
        dataset_input_ids = dataset_input_ids[None]
        dataset_attention_mask = dataset_attention_mask[None]

    per_corpus_embeddings = [
        _run_in_batches(
            model.first_stage_model,
            dataset_input_ids[j],
            dataset_attention_mask[j],
            batch_size,
            {},
        )
        for j in range(len(dataset_input_ids))
    ]
    # Average the embeddings of all corpora position-wise.
    dataset_embeddings = torch.stack(per_corpus_embeddings, dim=0).mean(dim=0)

    return _run_in_batches(
        model.second_stage_model,
        input_ids,
        attention_mask,
        batch_size,
        {"dataset_embeddings": dataset_embeddings, **second_stage_model_kwargs},
    )
| |
|
| |
|
def last_token_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick, per row, the hidden state at the last position holding the row's max mask value.

    Args:
        hidden_state: (b, n, d) tensor.
        attention_mask: (b, n) mask. For a 0/1 mask the chosen position is the last 1.

    Returns:
        (b, d) tensor. Hidden states are multiplied by the mask before selection,
        so a row whose mask is all zeros yields a zero vector.
    """
    b, n, d = hidden_state.size()
    seq_len = attention_mask.size(1)

    # argmax returns the first maximum; on the flipped mask that is the last one.
    from_right = torch.argmax(torch.flip(attention_mask, dims=(1,)), dim=1, keepdim=False)
    last_index = (seq_len - 1 - from_right).clamp(min=0)

    index = last_index.view(b, 1, 1).expand(b, 1, d)
    assert index.shape == (b, 1, d)

    masked_states = hidden_state * attention_mask.unsqueeze(-1).expand((b, n, d)).float()
    return torch.gather(masked_states, 1, index).squeeze(dim=1)
| |
|
def print0(*args, **kwargs) -> None:
    """Print only from the rank-0 process; arguments are passed straight to `print`.

    NOTE(review): this re-definition duplicates the earlier `print0` in this
    module and, being later, is the binding callers resolve at runtime.
    """
    is_rank_zero = get_rank() == 0
    if is_rank_zero:
        print(*args, **kwargs)
| |
|
| |
|
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
    """Truncate the model's stack of blocks to the first `n_layers`, in place.

    The layer list is looked up in this order:
    `model.transformer.h`, else `model.transformer.layer`, else
    `model.encoder.layers`, else `model.encoder.layer`.

    Raises:
        RuntimeError: if the model has neither `transformer` nor `encoder`.
    """
    if hasattr(model, 'transformer'):
        container = model.transformer
        layer_attr = 'h' if hasattr(container, 'h') else 'layer'
    elif hasattr(model, 'encoder'):
        container = model.encoder
        layer_attr = 'layers' if hasattr(container, 'layers') else 'layer'
    else:
        raise RuntimeError(f"unknown how to limit layers of model {type(model)}")

    setattr(container, layer_attr, getattr(container, layer_attr)[:n_layers])
| | |
| |
|
| |
|
def disable_dropout(model: torch.nn.Module):
    """Set p=0 on every torch.nn.Dropout submodule of `model` (in place) and log the count."""
    n_disabled = 0
    for module in model.modules():
        if not isinstance(module, torch.nn.Dropout):
            continue
        module.p = 0.0
        n_disabled += 1
    print0(
        f"Disabled {n_disabled} dropout modules from model type {type(model)}"
    )
| |
|
| |
|
def disable_causality(model: torch.nn.Module):
    """Set `is_causal = False` on every submodule that has an `is_causal` attribute, and log the count."""
    causal_modules = [module for module in model.modules() if hasattr(module, "is_causal")]
    for module in causal_modules:
        module.is_causal = False
    print0(
        f"Set is_causal=False in {len(causal_modules)} modules from model type {type(model)}"
    )
| |
|
class ContextualModelMixin(nn.Module):
    """Dataset-conditioning machinery shared by the second-stage models.

    The host class must set `self.config` and `self.hidden_size` before calling
    `contextual_init()`, which reads both.
    """

    @property
    def num_corpus_tokens(self) -> int:
        """Number of contextual positions: corpus size times tokens per document."""
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def contextual_init(self):
        """Create the soft-prompt and projection layers and read the contextual settings from `self.config`."""
        self.n_soft_prompt = 8
        # Maps a constant input vector to `n_soft_prompt` learned prompt vectors.
        self.prompt_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
        )
        self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.randomize_dataset_sequence_order = True
        self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
        if self.sequence_dropout_prob > 0.0:
            # Learned replacement vector for dropped-out (or nulled) dataset embeddings.
            self.sequence_dropout_null_embedding = torch.nn.Parameter(
                torch.randn(self.hidden_size) * 0.01,
                requires_grad = True
            )
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size)
        )

    def _prepare_dataset_embeddings(
        self,
        input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Build the per-example contextual prefix: dataset embeddings followed by the soft prompt.

        Returns a tensor of shape (batch, corpus_size + n_soft_prompt, hidden_size).
        In training mode, the processing may include random corpus sampling (when
        transductive_tokens_per_document > 1), sequence dropout, and shuffling of
        the dataset positions. The soft-prompt positions always stay at the end.

        Raises:
            ValueError: if `null_dataset_embedding` is requested but the model was
                built without a null embedding (transductive_sequence_dropout_prob == 0).
        """
        if not isinstance(dataset_embeddings, torch.Tensor):
            dataset_embeddings = torch.tensor(dataset_embeddings)

        if len(dataset_embeddings.shape) == 2:
            # A single shared corpus: add a leading batch axis.
            dataset_embeddings = dataset_embeddings[None, :, :]
        dataset_embeddings = dataset_embeddings.to(input_ids.device)

        batch_size = input_ids.shape[0]
        if (self.transductive_tokens_per_document > 1):
            if self.training:
                # Sample transductive_corpus_size documents (with replacement) per example.
                assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
                R = torch.randint(
                    low=0,
                    high=len(dataset_embeddings),
                    size=(batch_size, self.config.transductive_corpus_size),
                    device=dataset_embeddings.device
                )
                dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
            else:
                dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))

        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
            # Keep only as many contextual positions as the model was configured for.
            dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]

        corpus_batch, corpus_size, _hidden_size = dataset_embeddings.shape
        if corpus_batch == 1:
            # Share one corpus across the whole batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))

        if self.training and self.sequence_dropout_prob > 0.0:
            sequence_dropout_mask = (
                torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
            )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = torch.where(
                sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
            )
        elif null_dataset_embedding:
            if not hasattr(self, "sequence_dropout_null_embedding"):
                raise ValueError(
                    "null_dataset_embedding=True requires a model built with "
                    "transductive_sequence_dropout_prob > 0 (no null embedding exists)"
                )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = null_embeddings

        # The soft prompt is a learned function of a constant ones vector, i.e. a
        # learned set of n_soft_prompt vectors appended after the dataset positions.
        soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
        soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
        soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1))
        soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)

        if self.training and self.randomize_dataset_sequence_order:
            # Shuffle only the dataset positions; the soft-prompt tail keeps its order.
            randomized_order = torch.stack(
                [
                    torch.cat(
                        (
                            torch.randperm(corpus_size, device=soft_prompt.device),
                            torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
                        ), dim=0)
                    for _ in range(batch_size)])
            soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))

        return soft_prompt
| |
|
class BiEncoder(transformers.PreTrainedModel):
    """Single-stage encoder: backbone embedder, pooling, then an MLP head.

    Also used as the first stage of the contextual model. Config fields read:
    `embedder`, `limit_layers`, `embedding_output_dim`, `logit_scale`,
    `disable_dropout`, plus optional `transductive_tokens_per_document`
    (default 1) and `pooling_strategy` (default "mean"; any other value selects
    masked max pooling).
    """
    embedder: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        embedder, _ = load_embedder_and_tokenizer(
            config.embedder,
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(embedder, config.limit_layers)

        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        # Projection head; its output width defaults to the embedder's hidden size.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
        )
        self.temp = config.logit_scale

        if config.disable_dropout:
            disable_dropout(self)
        self.pooling_strategy = vars(config).get("pooling_strategy", "mean")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor] = None,
        dataset_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids = None,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed a batch of sequences.

        `dataset_input_ids`, `dataset_attention_mask` and `token_type_ids` are
        accepted for signature compatibility and are not used.

        Returns:
            The MLP applied to the pooled embeddings. When
            transductive_tokens_per_document == 1, the shape is (batch, out_dim),
            with mean or masked-max pooling over the sequence. Otherwise the shape
            is (batch, transductive_tokens_per_document, out_dim): the sequence is
            padded to a multiple of that count, split into contiguous chunks, and
            each chunk is mean-pooled. With `output_hidden_states=True`, returns
            {"hidden_states": <backbone outputs, chunked if applicable>,
            "pooled": <the above>} instead.

        Callers typically pass a corpus laid out like
        [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
        (three documents, two hard negatives each); rows are encoded independently.
        """
        del token_type_ids

        outputs = (
            self.embedder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        )

        if self.transductive_tokens_per_document > 1:
            batch_size, seq_length, output_dim = outputs.shape

            if seq_length % self.transductive_tokens_per_document != 0:
                # Pad (with masked-out zeros of matching dtype) so the sequence splits evenly.
                n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
                outputs = torch.cat(
                    (outputs, outputs.new_zeros((batch_size, n_extra_embeds, output_dim))),
                    dim=1
                )
                attention_mask = torch.cat(
                    (attention_mask, attention_mask.new_zeros((batch_size, n_extra_embeds))),
                    dim=1
                )
                seq_length += n_extra_embeds
                print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")

            outputs = outputs.reshape(
                (batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
            )
            attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
            document_embeddings = mean_pool_3d(outputs, attention_mask)
            document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
        else:
            if self.pooling_strategy == "mean":
                document_embeddings = mean_pool(outputs, attention_mask)
            else:
                # Masked max pooling: padded positions can never be selected.
                padding_positions = (attention_mask == 0)[..., None]
                document_embeddings = outputs.masked_fill(
                    padding_positions, torch.finfo(outputs.dtype).min
                ).max(dim=1).values
        output = self.mlp(document_embeddings)

        if output_hidden_states:
            return {
                "hidden_states": outputs,
                "pooled": output,
            }
        else:
            return output
| |
|
| |
|
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage model on a decoder backbone, conditioned on first-stage dataset embeddings.

    The contextual prefix (width `first_stage_hidden_size`) is flattened,
    re-chunked into backbone-width vectors, layer-normed and prepended to the
    token embeddings. `disable_causality` sets `is_causal=False` on the backbone
    modules that expose it.
    """
    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
        first_stage_hidden_size: int,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.backbone_hidden_size = self.backbone.config.hidden_size
        self.hidden_size = first_stage_hidden_size
        self.contextual_init()
        disable_causality(self.backbone)

        self.input_ln = torch.nn.LayerNorm(
            self.backbone_hidden_size,
            eps=1e-5
        )

        # Replaces the projection made by contextual_init(): pooled outputs here
        # have the backbone's width, not the first stage's.
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
        )
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    @property
    def corpus_token_ratio(self) -> float:
        """Backbone hidden size divided by first-stage hidden size."""
        return self.backbone_hidden_size / self.hidden_size

    def corpus_token_pad_size(self, n_tokens: int) -> int:
        # NOTE(review): `n_tokens` is unused and nothing in this file calls this method.
        return self.hidden_size % self.backbone_hidden_size

    def _shift_rotary_embedding(self) -> None:
        print("Warning: Positional embedding disabling not implemented for LLAMA.")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Encode `input_ids` with the contextual prefix prepended.

        Pooling follows config `pooling_strategy`: "last_token", "mean", or
        position-weighted mean for anything else. Returns the projected pooled
        vector, or {"hidden_states", "pooled"} when `output_hidden_states` is set.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Flatten the (batch, n_prefix, hidden_size) prefix and re-chunk it into
        # backbone_hidden_size-wide vectors, padding the tail with ones.
        batch_size = soft_prompt.shape[0]
        num_soft_elements = math.prod(soft_prompt.shape[1:])
        soft_prompt = soft_prompt.reshape((batch_size, num_soft_elements))
        # NOTE: when num_soft_elements is already a multiple of backbone_hidden_size
        # this still appends one full row of padding. Changing it would change the
        # number of prefix positions (presumably breaking trained checkpoints).
        num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        padding = torch.ones(
            (batch_size, num_padding_elements),
            device=soft_prompt.device,
            dtype=soft_prompt.dtype,
        )
        soft_prompt = torch.cat((soft_prompt, padding), dim=1)
        soft_prompt = soft_prompt.reshape(
            (batch_size, -1, self.backbone_hidden_size)
        )
        soft_prompt = self.input_ln(soft_prompt)

        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeddings = self.backbone.get_input_embeddings()
        inputs_embeds = token_embeddings(input_ids)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
            output_hidden_states=True,
        )
        last_hidden_state = output.hidden_states[-1]
        n_soft_prompt_tokens = soft_prompt.shape[1]

        # Drop the prefix positions; pool only over the real input tokens.
        output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]

        pooling_strategy = vars(self.config).get("pooling_strategy")
        if pooling_strategy == "last_token":
            output_pooled = last_token_pool(output_vectors, output_attention_mask)
        elif pooling_strategy == "mean":
            output_pooled = mean_pool(output_vectors, output_attention_mask)
        else:
            output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)

        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output
| |
|
| |
|
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage bidirectional encoder conditioned on first-stage dataset embeddings.

    The contextual prefix from `_prepare_dataset_embeddings` is prepended to the
    backbone's token embeddings. The backbone is run over both, and the outputs
    at the real-token positions are mean-pooled and projected.
    """
    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = dataset_backbone.config.hidden_size
        self.contextual_init()
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        """For nomic backbones, set each rotary module's `rotary_start_pos` to the number of corpus tokens.

        Only applies when config `disable_transductive_rotary_embedding` is true
        (the default).
        """
        disable_rotary = vars(self.config).get("disable_transductive_rotary_embedding", True)
        if not (self.backbone.config.model_type.startswith("nomic") and disable_rotary):
            return

        self.backbone.config.rotary_start_pos = 0.0
        rotary_start_pos = self.num_corpus_tokens
        rotary_modules = [m for m in self.backbone.modules() if hasattr(m, "rotary_emb_dim")]
        for module in rotary_modules:
            module.rotary_start_pos = rotary_start_pos
        rotary_disabled = len(rotary_modules)
        print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Encode `input_ids` with the contextual prefix prepended.

        Returns the projected mean-pooled vector, or {"hidden_states", "pooled"}
        when `output_hidden_states` is set. The hidden states exclude the prefix
        positions.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )
        n_soft_prompt_tokens = soft_prompt.shape[1]

        prefix_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeds = self.backbone.embeddings(input_ids)
        full_embeds = torch.cat((soft_prompt, token_embeds), dim=1)
        full_mask = torch.cat((prefix_mask, attention_mask), dim=1)

        backbone_out = self.backbone(
            inputs_embeds=full_embeds,
            attention_mask=full_mask,
        )

        # Keep only the positions that correspond to the real input tokens.
        output_vectors = backbone_out.last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = full_mask[:, n_soft_prompt_tokens:]

        pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(pooled)

        if not output_hidden_states:
            return output
        return {
            "hidden_states": output_vectors,
            "pooled": output,
        }
| |
|
| |
|
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Baseline: prefix each input with a randomly chosen dataset example and encode both together.

    Only the original input positions contribute to the mean-pooled output.
    """
    def __init__(
        self,
        config,
        embedder: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: torch.Tensor,
        dataset_attention_mask: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Encode `[random dataset example ; input]` and pool over the input part.

        Returns the projected pooled vector, or {"hidden_states", "pooled"} when
        `output_hidden_states` is set. The hidden states exclude the prefix.
        """
        # One random dataset example (with replacement) per input row.
        R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)
        dataset_input_ids = dataset_input_ids[R]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)

        # The prefix is attended in full. Index by R so the mask has one row per
        # input row; the un-indexed mask has len(dataset) rows and fails the concat
        # below whenever that differs from the batch size.
        dataset_attention_mask = torch.ones_like(dataset_attention_mask[R])
        input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
        # Zero weight on prefix positions so pooling only sees the real input.
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )

        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )

        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            S_d = dataset_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output
| |
|
| |
|
class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
    """Two-stage contextual embedder.

    The first stage (a `BiEncoder`) embeds the dataset/corpus inputs. The second
    stage is `DatasetConditionedAutoregressive` when config
    `autoregressive_backbone` is set, otherwise `DatasetConditionedBiencoder`.
    It encodes the actual inputs conditioned on those dataset embeddings.
    """
    config_class = ContextualModelConfig
    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        backbone_name = vars(config).get("dataset_backbone") or config.embedder
        dataset_backbone, _ = load_embedder_and_tokenizer(backbone_name)

        n_layers = config.limit_layers
        if n_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(dataset_backbone, n_layers)

        # The first stage gets its own config copy: no output-dim override and
        # its own (optional) layer limit.
        first_stage_config = copy.deepcopy(config)
        first_stage_config.embedding_output_dim = None
        first_stage_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(
            config=first_stage_config,
        )

        use_autoregressive = vars(config).get("autoregressive_backbone", False)
        if not use_autoregressive:
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone
            )
        else:
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )

        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)

        if vars(self.config).get("transductive_tie_token_embeddings", False):
            # Share one word-embedding matrix between the two stages.
            first_stage_word_embeddings = self.first_stage_model.embedder.embeddings.word_embeddings
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
                first_stage_word_embeddings.weight
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor],
        dataset_attention_mask: Optional[torch.Tensor],
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed the dataset with the first stage, then the inputs with the second stage.

        input_ids (long torch.Tensor) - ids of input tokens
        attention_mask (bool torch.Tensor)
        dataset_input_ids / dataset_attention_mask - the corpus fed to the first stage
        """
        first_stage_output = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask
        )
        second_stage_output = self.second_stage_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            dataset_embeddings=first_stage_output,
            output_hidden_states=output_hidden_states,
        )
        return second_stage_output
| |
|
| |
|
| |
|
def get_model_class(name: str):
    """Map a model-type name to its class.

    Accepted names: 'transductive', 'biencoder', 'dataset_prefix_biencoder'.
    Names are matched exactly; a substring such as 'trans' is not accepted.

    Raises:
        ValueError: for any other name.
    """
    if name == 'transductive':
        return ContextualDocumentEmbeddingTransformer
    elif name == 'biencoder':
        return BiEncoder
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    raise ValueError(f'unknown model cls {name}')
| |
|