import os
import sys
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Optional, Union

import torch
import yaml
from transformers import BertConfig, BertModel, PretrainedConfig

# from masked_data_modeling_loss import MaskedDataLossWithSoftmax
# from ..utils.contrastive_loss import ContrastiveLoss
from .yaml_util import MyLoader


@dataclass
class FoundationOutput:
    """Bundle of model outputs and per-objective losses; every field defaults to ``None``."""

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    num_output: Optional[torch.Tensor] = None
    est_err_output: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    masked_loss: Optional[torch.Tensor] = None
    num_loss: Optional[torch.Tensor] = None
    est_err_loss: Optional[torch.Tensor] = None


@dataclass
class FoundationBertConfig:
    """Hyper-parameters for :class:`FoundationBert`.

    ``vocab_size`` .. ``max_position_embeddings`` (except ``classifier_dropout``)
    are forwarded to ``BertConfig`` by ``FoundationBert.__init__``. The
    ``use_*_loss`` flags select the training objectives, and
    ``transform_numeric`` is overwritten by the ``transform_numeric`` argument
    of ``FoundationBert.__init__``.
    """

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False

    def to_dict(self):
        """Return a shallow ``{field_name: value}`` dict of all dataclass fields."""
        return {f.name: getattr(self, f.name) for f in fields(self)}


class FoundationBert(BertModel):
    """BERT encoder over masked, per-modality numeric sequences.

    Modalities listed in ``vector_keys`` each get their own stream; every other
    modality is folded into one shared ``'scalars'`` stream. Each stream has its
    own ``Linear(1, hidden_size)`` input embedding and an MLP regression head
    that predicts one value per position.
    """

    # Loss-selection flags, in the order they are summed into ``loss_mod``.
    _LOSS_FLAGS = (
        'use_mlm_loss',
        'use_regression_loss',
        'use_contrastive_loss',
        'use_xval_loss',
    )

    def __init__(
        self,
        config: Optional[Union[FoundationBertConfig, PretrainedConfig]] = None,
        use_mlm_loss: bool = False,
        use_regression_loss: bool = True,
        use_contrastive_loss: bool = False,
        use_xval_loss: bool = False,
        transform_numeric: bool = False,
        *args,
        **kwargs,
    ):
        """Build the BERT encoder plus modality-specific embeddings and heads.

        Parameters
        ----------
        config : FoundationBertConfig or PretrainedConfig
            Architecture hyper-parameters (required despite the ``None``
            default). Stored as ``self.gconfig`` and mutated in place.
        use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss : bool
            Fallback loss flags. They are used (and written onto ``config``)
            only when ``config`` does not define all four ``use_*_loss``
            attributes.
        transform_numeric : bool
            Always written to ``config.transform_numeric``.
        *args
            Accepted and ignored.
        **kwargs
            Must contain ``modalities`` (iterable of modality names) and
            ``mask_token`` (value substituted at masked input positions);
            may contain ``dataset_path``.

        Raises
        ------
        ValueError
            If ``config`` defines all four loss flags and none is enabled.
        KeyError
            If ``modalities`` or ``mask_token`` is missing from ``kwargs``.
        """
        bert_conf = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            _attn_implementation='sdpa',
        )
        super().__init__(bert_conf)

        # Keep the caller's config object (it is modified in place below).
        self.gconfig = config
        self.gconfig.transform_numeric = transform_numeric
        self._resolve_loss_flags(
            use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss
        )

        self.dataset_path = kwargs.get('dataset_path', None)
        self.modalities = kwargs['modalities']
        self.mask_token = kwargs['mask_token']

        self.scalar_keys = [
            'redshift',
            'halo_mass',
            'stellar_mass',
        ]
        # NOTE(review): the '{band}' entries are literal strings (no .format()
        # is applied here), so a modality only matches them if its name
        # literally contains '{band}' -- confirm this is intended.
        self.vector_keys = [
            'SED',
            'SFH',
            'mag_{band}_spherex',
            'mag_{band}_lsst',
        ]
        # Vector modalities keep their own stream; everything else shares the
        # 'scalars' stream. dict.fromkeys de-duplicates, keeping first-seen order.
        self.modalscalars = list(dict.fromkeys(
            m if m in self.vector_keys else 'scalars' for m in self.modalities
        ))

        self.embedding = torch.nn.ModuleDict()  # modality-specific embedding layers
        self.num_head = torch.nn.ModuleDict()   # modality-specific regression heads
        for modality in self.modalscalars:
            # (B, L, 1) -> (B, L, H)
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)
            # (B, L, H) -> (B, L, 1)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1),
            )

        # NOTE(review): these layers are created after super().__init__(); if
        # BERT's weight initialization should also apply to them, consider
        # calling self.post_init() at the end of __init__ -- confirm.
        self.position_embeddings = torch.nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # isn't used currently
        self.xval_loss = torch.nn.MSELoss(reduction='none')  # isn't used currently
        # self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none')  # isn't used currently
        self.distributed_loss = False

    def _resolve_loss_flags(self, use_mlm_loss, use_regression_loss,
                            use_contrastive_loss, use_xval_loss):
        """Take loss flags from ``self.gconfig`` (or the fallbacks) and set ``self.loss_mod``.

        ``loss_mod`` is the number of enabled losses, as a float.
        """
        try:
            enabled = [getattr(self.gconfig, name) for name in self._LOSS_FLAGS]
        except AttributeError:
            # The config does not carry the loss flags (presumably a plain
            # BertConfig, e.g. when loading through from_pretrained), so use
            # the constructor arguments and record them on the config.
            self.gconfig.use_mlm_loss = use_mlm_loss
            self.gconfig.use_regression_loss = use_regression_loss
            self.gconfig.use_contrastive_loss = use_contrastive_loss
            self.gconfig.use_xval_loss = use_xval_loss
            enabled = [use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss]
        else:
            if not any(enabled):
                raise ValueError("At least one loss must be enabled")
        self.loss_mod = sum(float(flag) for flag in enabled)

    @staticmethod
    def _locate_train_config(
        pretrained_model_name_or_path,
        *,
        cache_dir=None,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
    ):
        """Return a local path to ``train_config.yaml`` for a directory or Hub repo id.

        A ``train_config.yaml`` inside a local directory is used directly;
        otherwise the file is fetched with ``hf_hub_download`` using the
        caller's download options.

        Raises
        ------
        FileNotFoundError
            If the file is neither present locally nor downloadable (the
            download error is chained as the cause).
        """
        local_file = os.path.join(pretrained_model_name_or_path, "train_config.yaml")
        if os.path.isfile(local_file):
            return local_file

        from huggingface_hub import hf_hub_download

        try:
            return hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="train_config.yaml",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
            )
        except Exception as err:  # invalid repo id, network failure, missing file, ...
            raise FileNotFoundError(
                f"Could not find train_config.yaml at {local_file!r} or download it "
                f"from the Hugging Face Hub repo {pretrained_model_name_or_path!r}"
            ) from err

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        """Load weights plus the ``train_config.yaml`` stored next to them.

        ``modalities``, ``dataset_path`` and ``mask_token`` from the YAML are
        injected into ``kwargs`` (overriding caller values), and its
        ``model_config`` mapping is expanded into keyword arguments for
        ``BertModel.from_pretrained``. All download/loading options are
        forwarded both to the YAML download and to the weight loading.
        """
        train_config_path = cls._locate_train_config(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )
        with open(train_config_path, 'r', encoding='utf-8') as f:
            # NOTE(review): yaml.load with a custom loader on a file that may
            # come from the Hub -- make sure MyLoader cannot construct
            # arbitrary Python objects.
            train_config = yaml.load(f, Loader=MyLoader)

        kwargs['modalities'] = train_config['modalities']
        kwargs['dataset_path'] = train_config['dataset_path']
        kwargs['mask_token'] = train_config['mask_token']

        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **train_config['model_config'],
            **kwargs,
        )

    def pool_output(self,
                    embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    use_last: bool = False
                    ) -> torch.Tensor:
        """Average pool the hidden states using the attention mask.

        The first position and one end position of every sequence are excluded
        from the mean.

        Parameters
        ----------
        embeddings : torch.Tensor
            The hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).
        use_last : bool
            Which end position is excluded: index ``seq_len - 1`` when True,
            ``seq_len - 2`` when False, where ``seq_len`` is the row's
            attention-mask sum.

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim): the mean over the remaining
            attended positions (all-zero rows yield zeros).
        """
        sl_mod = 1 if use_last else 2
        seq_lengths = attention_mask.sum(dim=1)

        # Drop the start position and this row's end position from the mask.
        # Row-wise indexing is required: `mask[:, idx]` with a (B,) index
        # tensor would zero every row's end column in *all* rows.
        new_attention = attention_mask.clone()
        new_attention[:, 0] = 0
        end_positions = (seq_lengths - sl_mod).clamp(min=0).long()  # no negative wrap-around
        rows = torch.arange(attention_mask.shape[0], device=attention_mask.device)
        new_attention[rows, end_positions] = 0

        new_attention = new_attention.to(embeddings.device)
        pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape)  # (B, SeqLen, HiddenDim)
        sum_embeds = torch.sum(embeddings * pool_mask, dim=1)

        # Divide by the number of positions actually summed; clamp avoids
        # division by zero for empty rows.
        token_counts = new_attention.sum(dim=1).clamp(min=1).unsqueeze(-1)  # (B, 1)
        return sum_embeds / token_counts

    def last_token_pool(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Pool the last hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The last hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim): the final column when every
            row attends to it (left padding), otherwise each row's last
            attended position.
        """
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return embeddings[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = embeddings.shape[0]
        return embeddings[
            torch.arange(batch_size, device=embeddings.device),
            sequence_lengths,
        ]

    def forward(self, inputs, return_input_label_mapping=False):
        """Encode all modality streams jointly and predict one value per position.

        Args:
            inputs (dict): For every stream ``m`` in ``self.modalscalars`` it must
                hold ``"input_{m}"`` (values, (B, L_m)) and ``"labels_{m}"``
                (boolean mask of the same shape; True positions are replaced
                by ``self.mask_token`` before embedding).
            return_input_label_mapping (bool): Also return the per-stream
                ``{'input': ..., 'labels': ...}`` mapping.

        Returns:
            outputs (dict): ``{f"{m}_logits": tensor of shape (B, L_m, 1)}``
            from each stream's regression head; ``(outputs, input_label_mapping)``
            when ``return_input_label_mapping`` is True.

        Raises:
            ValueError: If the concatenated length of all streams exceeds the
                position-embedding table size.
        """
        input_label_mapping = {}
        embedded = []
        for src_modality in self.modalscalars:
            input_data = inputs[f"input_{src_modality}"]
            labels = inputs[f"labels_{src_modality}"]
            input_label_mapping[src_modality] = {'input': input_data, 'labels': labels}

            masked = torch.where(labels, self.mask_token, input_data)  # apply masking
            x = self.embedding[src_modality](masked.unsqueeze(-1))  # (B, L, H)
            embedded.append(torch.nn.functional.silu(x))

        # All streams form one sequence: concatenate along the length axis.
        combined = torch.cat(embedded, dim=1)
        seq_len = combined.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Combined sequence length {seq_len} exceeds "
                f"max_position_embeddings={max_positions}"
            )
        self.position_ids = torch.arange(seq_len, device=combined.device).unsqueeze(0)  # (1, L)
        combined = combined + self.position_embeddings(self.position_ids)
        combined = self.embed_dropout(combined)

        # Only the final layer is used, so intermediate hidden states are not requested.
        hidden = self.encoder(combined).last_hidden_state

        outputs = {}
        start = 0
        for tgt_modality in self.modalscalars:
            length = input_label_mapping[tgt_modality]['input'].shape[1]
            x_t = hidden[:, start:start + length, :]  # this stream's slice of the encoding
            outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t)
            start += length

        if getattr(self, 'save_umap_for', None):
            # NOTE(review): x_t is the slice of the *last* stream only, not the
            # whole sequence -- confirm this is the intended UMAP feature.
            pooled = x_t.mean(dim=1)  # mean over the sequence dimension
            self.save_pooled_embedding(pooled)

        return (outputs, input_label_mapping) if return_input_label_mapping else outputs

    def save_pooled_embedding(self, features):
        """Append ``features`` (N, D) to the ``features`` dataset in ``self.save_umap_for``.

        The HDF5 file and its parent directories are created on first use; the
        dataset is created resizable along the first axis and grown on every
        later call.
        """
        import h5py

        fname = Path(self.save_umap_for)
        fname.parent.mkdir(parents=True, exist_ok=True)
        features = features.detach().cpu().numpy()

        # Mode 'a' opens an existing file read/write or creates a new one.
        with h5py.File(fname, 'a') as f:
            if 'features' in f:
                dset = f['features']
                old_size = dset.shape[0]
                dset.resize((old_size + features.shape[0], features.shape[-1]))
                dset[old_size:] = features
            else:
                f.create_dataset(
                    'features',
                    data=features,
                    maxshape=(None, features.shape[-1]),
                    chunks=True,
                )