import os
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
import yaml
from transformers import ModernBertModel, ModernBertConfig, PretrainedConfig
from transformers.utils import cached_file

# from ..utils.masked_data_modeling_loss import MaskedDataLossWithSoftmax
# from ..utils.contrastive_loss import ContrastiveLoss
# from ..utils.yaml_util import MyLoader


class MyLoader(yaml.SafeLoader):
    """yaml.SafeLoader that tolerates unknown tags instead of raising."""

    def construct_mapping(self, *args, **kwargs):
        # When loading we want to skip keys that require custom construction,
        # so register a fallback constructor that wraps them in Tagged tuples.
        super().add_constructor(None, construct_undefined)
        mapping = super().construct_mapping(*args, **kwargs)
        return mapping


class Tagged(typing.NamedTuple):
    tag: str
    value: object


def construct_undefined(self, node):
    """Fallback constructor: wrap nodes with unknown tags in a Tagged tuple."""
    if isinstance(node, yaml.nodes.ScalarNode):
        value = self.construct_scalar(node)
    elif isinstance(node, yaml.nodes.SequenceNode):
        value = self.construct_sequence(node)
    elif isinstance(node, yaml.nodes.MappingNode):
        value = self.construct_mapping(node)
    else:
        assert False, f"unexpected node: {node!r}"
    return Tagged(node.tag, value)


@dataclass
class FoundationOutput:
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    num_output: Optional[torch.Tensor] = None
    est_err_output: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    masked_loss: Optional[torch.Tensor] = None
    num_loss: Optional[torch.Tensor] = None
    est_err_loss: Optional[torch.Tensor] = None


@dataclass
class FoundationBertConfig:
    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False
    use_sdpa_attention: bool = True

    def to_dict(self):
        return {k: getattr(self, k) for k in self.__dataclass_fields__.keys()}
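
# A minimal sketch of how FoundationBertConfig might be filled in and handed to
# FoundationBert; every value below is an illustrative assumption rather than a
# project default, and the vector/scalar shapes only mirror modality names that
# appear elsewhere in this file.
#
#   config = FoundationBertConfig(
#       vocab_size=32000, hidden_size=256, num_hidden_layers=6,
#       num_attention_heads=8, intermediate_size=1024,
#       hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
#       pad_token_id=0, classifier_dropout=0.1, max_position_embeddings=1024,
#       contrastive_temperature=0.07, loss_weights={'mlm': 1.0, 'regression': 1.0},
#   )
#   model = FoundationBert(config,
#                          vector_shape={'SED': 100},
#                          scalar_shape={'redshift': 1, 'stellar_mass': 1},
#                          mask_token=0.0)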


class FoundationBert(ModernBertModel):
    def __init__(self, config: FoundationBertConfig = None, use_mlm_loss: bool = False,
                 use_regression_loss: bool = True, use_contrastive_loss: bool = False,
                 use_xval_loss: bool = False, transform_numeric: bool = False,
                 *args, **kwargs):
        self.gconfig = config
        # print(f"⚠️ FoundationBert.__init__: {self.gconfig=}")
        bert_conf = ModernBertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            _attn_implementation='sdpa',
        )
        self.gconfig.transform_numeric = transform_numeric
        super().__init__(bert_conf)

        # Prefer the loss flags carried by the config; fall back to the
        # constructor arguments if the config does not define them.
        try:
            if (not self.gconfig.use_mlm_loss and not self.gconfig.use_regression_loss
                    and not self.gconfig.use_contrastive_loss):
                raise ValueError("At least one loss must be enabled")
            self.loss_mod = (float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss)
                             + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss))
        except AttributeError:
            self.gconfig.use_mlm_loss = use_mlm_loss
            self.gconfig.use_regression_loss = use_regression_loss
            self.gconfig.use_contrastive_loss = use_contrastive_loss
            self.gconfig.use_xval_loss = use_xval_loss
            self.loss_mod = (float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss)
                             + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss))

        self.dataset_path = kwargs.get('dataset_path', None)
        self.vector_shape = kwargs['vector_shape']
        self.scalar_shape = kwargs['scalar_shape']
        self.mask_token = kwargs['mask_token']
        # self.scalar_keys = [
        #     'redshift',
        #     'halo_mass',
        #     'stellar_mass',
        # ]
        # self.vector_keys = [
        #     'SED',
        #     'SFH',
        #     'mag_{band}_spherex',
        #     'mag_{band}_lsst',
        # ]

        # Convert modality names to 'scalars' unless they are vector modalities,
        # then remove duplicates while preserving order.
        self.modalscalars = [m if m in self.vector_shape else 'scalars' for m in self.modalities]
        self.modalscalars = list(dict.fromkeys(self.modalscalars))
        print(f"✅ FoundationBert.__init__ is called with {kwargs=}, {self.modalscalars=}, {self.dataset_path=} ✅")

        self.embedding = torch.nn.ModuleDict()  # modality specific embedding layers
        self.num_head = torch.nn.ModuleDict()   # modality specific regression heads
        # Create modality specific layers
        for modality in self.modalscalars:
            # input.shape -> output.shape: (B, L, 1) -> (B, L, H)
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1),
            )

        # self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        # self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # isn't used currently
        # self.xval_loss = torch.nn.MSELoss(reduction='none')  # isn't used currently
        # self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none')  # isn't used currently
        self.distributed_loss = False

    @property
    def modalities(self):
        return self.vector_shape | self.scalar_shape

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                        *model_args,
                        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                        cache_dir: Optional[Union[str, os.PathLike]] = None,
                        ignore_mismatched_sizes: bool = False,
                        force_download: bool = False,
                        local_files_only: bool = False,
                        token: Optional[Union[str, bool]] = None,
                        revision: str = "main",
                        use_safetensors: Optional[bool] = None,
                        **kwargs,
                        ):
        """
        Modification to correctly handle loading the extraneous parameters GBert
        needs from the accompanying `train_config.yaml`.
        """
        path = Path(pretrained_model_name_or_path)
        if 'checkpoint' in str(pretrained_model_name_or_path):
            model_config = path.parent / 'train_config.yaml'
        elif path.is_dir():
            model_config = path / 'train_config.yaml'
        else:
            model_config = cached_file(
                pretrained_model_name_or_path,
                'train_config.yaml',
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
            )
        with open(model_config, 'r') as f:
            config = yaml.load(f, Loader=MyLoader)

        kwargs['modalities'] = config['modalities']
        kwargs['dataset_path'] = config['dataset_path']
        kwargs['mask_token'] = config['mask_token']
        if 'vector_shape' not in kwargs and 'vector_shape' in config:
            kwargs['vector_shape'] = config['vector_shape']
        if 'scalar_shape' not in kwargs and 'scalar_shape' in config:
            kwargs['scalar_shape'] = config['scalar_shape']
        print(f"✅ FoundationBert.from_pretrained is called with {model_config=} and {kwargs=} ✅")

        return super().from_pretrained(
            pretrained_model_name_or_path,
            **config['model_config'],
            **kwargs,
        )
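
    # A minimal sketch of restoring a trained model with the overridden
    # from_pretrained; the directory name is a hypothetical placeholder and is
    # only assumed to hold the usual HF weight files next to train_config.yaml.
    #
    #   model = FoundationBert.from_pretrained("runs/foundation_bert_run/")
    #   model.eval()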

    def pool_output(self,
                    embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    use_last: bool = False,
                    ) -> torch.Tensor:
        """Average pool the hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).
        use_last : bool
            If True, mask out only one trailing token per sequence instead of two.

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        # Get the sequence lengths
        sl_mod = 1 if use_last else 2
        seq_lengths = attention_mask.sum(dim=1)
        # Set the attention mask to 0 for the start and end tokens of each sequence
        new_attention = attention_mask.clone()
        new_attention[:, 0] = 0
        batch_idx = torch.arange(attention_mask.shape[0], device=attention_mask.device)
        new_attention[batch_idx, (seq_lengths - sl_mod).long()] = 0
        # Create a mask for the pooling operation (B, SeqLen, HiddenDim)
        pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)
        # Sum the embeddings over the sequence length (use the mask to avoid
        # pad, start, and stop tokens)
        sum_embeds = torch.sum(embeddings * pool_mask, 1)
        # Avoid division by zero for zero length sequences by clamping
        # sum_mask = torch.clamp(pool_mask.sum(1), min=1e-9)
        seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1)  # Shape (B, 1) to broadcast
        # Compute mean pooled embeddings for each sequence
        return sum_embeds / seq_lengths

    def last_token_pool(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Pool the last hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The last hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        # With left padding the final position is always a real token
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return embeddings[:, -1]
        else:
            # Otherwise gather the last non-padded token of each sequence
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = embeddings.shape[0]
            return embeddings[
                torch.arange(batch_size, device=embeddings.device),
                sequence_lengths,
            ]
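
    # A small sketch of how the two pooling helpers can be applied to encoder
    # output; `hidden` and `mask` are hypothetical tensors (random hidden states
    # and an all-ones attention mask), not produced by any method above.
    #
    #   hidden = torch.randn(4, 16, 256)                # (B, SeqLen, HiddenDim)
    #   mask = torch.ones(4, 16, dtype=torch.long)      # (B, SeqLen)
    #   mean_vec = model.pool_output(hidden, mask)      # (B, HiddenDim), skips start/end tokens
    #   last_vec = model.last_token_pool(hidden, mask)  # (B, HiddenDim), last non-pad token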
""" # Initialize the dictionary for the dynamic input-label mapping input_label_mapping = {} combined = [] for src_modality in self.modalscalars: # Add the modality's input and label data to the input_label_mapping input_label_mapping[src_modality] = { 'input': inputs[f"input_{src_modality}"], # Input data 'labels': inputs[f"labels_{src_modality}"] # Corresponding labels } input_data = input_label_mapping[src_modality]['input'] # get input data label = input_label_mapping[src_modality]['labels'] # get label data (for masking) input_data = torch.where(label, self.mask_token, input_data) # apply masking x = self.embedding[src_modality](input_data.unsqueeze(-1)) # shape: (B, L, H) x = torch.nn.functional.silu(x) combined.append(x) # combine all modalities combined = torch.cat(combined, dim=1) # Concatenate along the sequence length dimension position_ids = torch.arange(combined.size(1)).unsqueeze(0).to(combined.device) # shape: (1, L) # combined += self.position_embeddings(position_ids) # add position embedding combined = self.embed_dropout(combined) # x = self.encoder(combined, output_hidden_states=True).last_hidden_state # encode the combined input hidden_states = combined for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states, position_ids = position_ids)[0] x = self.final_norm(hidden_states) start = 0 outputs = {} # Iterate over each target modality to compute logits for tgt_modality in self.modalscalars: length = input_label_mapping[tgt_modality]['input'].shape[1] # get sequence length of the modality x_t = x[:, start:start+length, :] # slice the encoded output for each modality outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t) # modality specific regression head start += length # update start index for next modality if getattr(self, 'save_umap_for', None): pooled = x_t.mean(dim=1) # Mean pooling over the sequence length dimension self.save_pooled_embedding(pooled) # saved for UMAP visualization return (outputs, input_label_mapping) if return_input_label_mapping else outputs def save_pooled_embedding(self, features): """ Save the last hidden state to a file. """ import h5py fname = Path(self.save_umap_for) fname.parent.mkdir(parents=True, exist_ok=True) features = features.detach().cpu().numpy() if fname.exists(): with h5py.File(fname, 'r+') as f: old_size = f['features'].shape[0] # get current size new_size = old_size + features.shape[0] # calculate new size f['features'].resize((new_size, features.shape[-1])) # resize dataset f['features'][old_size:] = features # append new features else: with h5py.File(fname, 'w') as f: f.create_dataset('features', data=features, maxshape=(None, features.shape[-1]), chunks=True) def get_retrieval_embedding( self, inputs, pooling: str = "mean", normalize: bool = True, ) -> torch.Tensor: """ Build a single embedding per sample for kNN-style retrieval. Parameters ---------- inputs : dict Batch dict with `input_` and `labels_` entries. pooling : str `mean` (default) or `last`. normalize : bool L2-normalize output embeddings for cosine/inner-product search. 
""" combined = [] for src_modality in self.modalscalars: input_data = inputs[f"input_{src_modality}"] label = inputs[f"labels_{src_modality}"] input_data = torch.where(label, self.mask_token, input_data) x = self.embedding[src_modality](input_data.unsqueeze(-1)) x = torch.nn.functional.silu(x) combined.append(x) combined = torch.cat(combined, dim=1) position_ids = torch.arange(combined.size(1)).unsqueeze(0).to(combined.device) combined = self.embed_dropout(combined) hidden_states = combined for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states, position_ids=position_ids)[0] hidden_states = self.final_norm(hidden_states) if pooling == "last": embedding = hidden_states[:, -1, :] else: embedding = hidden_states.mean(dim=1) if normalize: embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1) return embedding