# mosaic / foundation_bert.py
# Provenance (from the HF file-browser page this was copied from):
#   author Xsmos, commit 989ce21 (verified) — "Fix import error and add source_files to config"
import sys
import os
from pathlib import Path
import torch
import yaml
# from masked_data_modeling_loss import MaskedDataLossWithSoftmax
# from ..utils.contrastive_loss import ContrastiveLoss
from .yaml_util import MyLoader
from dataclasses import dataclass
from transformers import BertModel, BertConfig, PretrainedConfig
from typing import Optional, Union
@dataclass
class FoundationOutput:
    """Bundle of tensors produced by a FoundationBert training/eval step.

    Every field defaults to None so callers can populate only the pieces
    relevant to the losses that are enabled.
    """

    loss: Optional[torch.Tensor] = None            # combined total loss
    logits: Optional[torch.Tensor] = None          # token/classification logits
    num_output: Optional[torch.Tensor] = None      # numeric (regression) predictions
    est_err_output: Optional[torch.Tensor] = None  # predicted error estimates
    hidden_states: Optional[torch.Tensor] = None   # encoder hidden states
    masked_loss: Optional[torch.Tensor] = None     # masked-modeling loss component
    num_loss: Optional[torch.Tensor] = None        # numeric regression loss component
    est_err_loss: Optional[torch.Tensor] = None    # error-estimate loss component
@dataclass
class FoundationBertConfig:
    """Configuration for FoundationBert.

    The first group of fields mirrors the standard ``BertConfig`` knobs;
    the ``use_*`` flags select which loss terms are active, and
    ``transform_numeric`` toggles numeric-input preprocessing.
    """

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False

    def to_dict(self) -> dict:
        """Return all dataclass fields as a plain ``{name: value}`` dict."""
        return {name: getattr(self, name) for name in self.__dataclass_fields__}
class FoundationBert(BertModel):
    """Multimodal BERT encoder for foundation-model training.

    Each input modality (the vector modalities listed in ``vector_keys``,
    plus one shared ``'scalars'`` bucket for scalar quantities) gets its own
    linear embedding layer and regression head. ``forward`` embeds every
    modality, concatenates the sequences, adds a learned position embedding,
    and runs the result through the shared BERT encoder.
    """

    def __init__(self,
                 config: FoundationBertConfig = None,
                 use_mlm_loss: bool = False,
                 use_regression_loss: bool = True,
                 use_contrastive_loss: bool = False,
                 use_xval_loss: bool = False,
                 transform_numeric: bool = False,
                 *args,
                 **kwargs):
        """Build the model.

        Parameters
        ----------
        config : FoundationBertConfig
            Foundation-level configuration; its BERT-relevant fields are
            copied into a ``BertConfig`` for the parent class.
        use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss : bool
            Fallback loss switches, used only when ``config`` does not carry
            its own ``use_*`` attributes (see the try/except below).
        transform_numeric : bool
            Stored onto ``config.transform_numeric``.
        **kwargs
            Must contain ``'modalities'`` and ``'mask_token'``;
            ``'dataset_path'`` is optional.
        """
        # Plain (non-module) attribute, so setting it before super().__init__
        # is tolerated by nn.Module.__setattr__.
        self.gconfig = config
        # print(f"⚠️ FoundationBert.__init__: {self.gconfig=}")
        bert_conf = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            _attn_implementation='sdpa'
        )
        self.gconfig.transform_numeric = transform_numeric
        super().__init__(bert_conf,)
        try:
            # Prefer the loss switches carried by the config object itself.
            if not self.gconfig.use_mlm_loss and not self.gconfig.use_regression_loss and not self.gconfig.use_contrastive_loss:
                raise ValueError("At least one loss must be enabled")
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)
        except:
            # NOTE(review): bare except — presumably meant to catch
            # AttributeError when config lacks the use_* fields, but it also
            # swallows the ValueError raised just above, silently replacing
            # the "all losses disabled" error with the constructor defaults.
            # Confirm intent.
            self.gconfig.use_mlm_loss = use_mlm_loss
            self.gconfig.use_regression_loss = use_regression_loss
            self.gconfig.use_contrastive_loss = use_contrastive_loss
            self.gconfig.use_xval_loss = use_xval_loss
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)
        self.dataset_path = kwargs.get('dataset_path', None)
        self.modalities = kwargs['modalities']  # ordered list of modality names
        self.mask_token = kwargs['mask_token']  # fill value substituted at masked positions in forward()
        # Scalar quantities; anything not in vector_keys is folded into the
        # shared 'scalars' bucket below.
        self.scalar_keys = [
            'redshift',
            'halo_mass',
            'stellar_mass',
        ]
        self.vector_keys = [
            'SED',
            'SFH',
            'mag_{band}_spherex',
            'mag_{band}_lsst',
        ]
        # Collapse scalar modalities into one 'scalars' entry, de-duplicated
        # while preserving first-seen order.
        self.modalscalars = [m if m in self.vector_keys else 'scalars' for m in self.modalities]
        self.modalscalars = list(dict.fromkeys(self.modalscalars))
        # print(f"✅ FoundationBert.__init__ is called with {kwargs=}, {self.modalscalars=}, {self.dataset_path=} ✅")
        self.embedding = torch.nn.ModuleDict()  # modality specific embedding layers
        self.num_head = torch.nn.ModuleDict()  # modality specific regression heads
        # create modality specific layers
        for modality in self.modalscalars:
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)  # input.shape -> output.shape: (B, L, 1) -> (B, L, H)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1)
            )
        # NOTE(review): separate from BertModel.embeddings.position_embeddings;
        # forward() applies this top-level table, not BERT's own.
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # isn't used currently
        self.xval_loss = torch.nn.MSELoss(reduction='none')  # isn't used currently
        #self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none')  # isn't used currently
        self.distributed_loss = False

    @classmethod
    def from_pretrained(self,
                        # NOTE(review): first parameter of a classmethod is
                        # conventionally named `cls`; here `self` is bound to
                        # the class object.
                        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                        *model_args,
                        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                        cache_dir: Optional[Union[str, os.PathLike]] = None,
                        ignore_mismatched_sizes: bool = False,
                        force_download: bool = False,
                        local_files_only: bool = False,
                        token: Optional[Union[str, bool]] = None,
                        revision: str = "main",
                        use_safetensors: bool = None,
                        **kwargs,
                        ):
        """Load a pretrained FoundationBert, sourcing ``train_config.yaml``
        from the Hub repo (or, on failure, from a local checkpoint directory)
        to recover modalities, dataset path, and mask token.
        """
        from huggingface_hub import hf_hub_download
        try:
            # Try the Hub first.
            # NOTE(review): `revision` is a named parameter above, so
            # kwargs.get("revision", "main") can never see a caller-supplied
            # revision — this always downloads "main". Confirm.
            model_config = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="train_config.yaml",
                revision=kwargs.get("revision", "main")
            )
        except Exception as e:
            # Fall back to a local checkpoint directory.
            model_config = os.path.join(pretrained_model_name_or_path, "train_config.yaml")
        # print(f"✅ Successfully located config at: {model_config}")
        with open(model_config, 'r') as f:
            # NOTE(review): rebinds the `config` parameter — any
            # PretrainedConfig passed in by the caller is ignored from here on.
            config = yaml.load(f, Loader=MyLoader)
        kwargs['modalities'] = config['modalities']
        kwargs['dataset_path'] = config['dataset_path']
        kwargs['mask_token'] = config['mask_token']
        # print(f"✅ Foundationbert.from_pretrained is called with {model_config=} and {kwargs=} ✅")
        return super().from_pretrained(
            pretrained_model_name_or_path,
            **config['model_config'],
            **kwargs
        )

    def pool_output(self,
                    embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    use_last: bool = False
                    ) -> torch.Tensor:
        """Average pool the hidden states using the attention mask.

        Start and end tokens are excluded from the average.

        Parameters
        ----------
        embeddings : torch.Tensor
            The hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).
        use_last : bool
            If True, zero out the final attended position (offset 1);
            otherwise the one before it (offset 2).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        # Get the sequence lengths
        sl_mod = 1 if use_last else 2
        seq_lengths = attention_mask.sum(axis=1)
        # Set the attention mask to 0 for start and end tokens
        new_attention = attention_mask.clone()
        new_attention[:, 0] = attention_mask[:,0] * 0
        # NOTE(review): `[:, seq_lengths - sl_mod]` indexes with a vector, so
        # EVERY row is zeroed at ALL batch elements' end positions, not each
        # row at its own end — confirm whether a per-row gather
        # (arange(B), seq_lengths - sl_mod) was intended.
        new_attention[:, seq_lengths - sl_mod] = 0 * attention_mask[:, seq_lengths - sl_mod]
        # Create a mask for the pooling operation (B, SeqLen, HiddenDim)
        pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)
        # Sum the embeddings over the sequence length (use the mask to avoid
        # pad, start, and stop tokens)
        sum_embeds = torch.sum(embeddings * pool_mask, 1)
        # Avoid division by zero for zero length sequences by clamping
        # sum_mask = torch.clamp(pool_mask.sum(1), min=1e-9)
        seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1) # Shape (B, 1) to broadcast
        # Compute mean pooled embeddings for each sequence
        # NOTE(review): divisor is the full attended length, although the
        # start/end positions were removed from the sum above — confirm.
        return sum_embeds / seq_lengths

    def last_token_pool(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Pool the last hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The last hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        # If every sequence attends at its final position the batch is
        # left-padded, so the last column holds each sequence's last token.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return embeddings[:, -1]
        else:
            # Right padding: gather each row's final attended position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = embeddings.shape[0]
            return embeddings[
                torch.arange(batch_size, device=embeddings.device),
                sequence_lengths,
            ]

    def forward(self, inputs, return_input_label_mapping: bool = False):
        """Forward pass that computes regression predictions per modality.

        Args:
            inputs (dict): must contain ``input_<modality>`` and
                ``labels_<modality>`` tensors for every entry of
                ``self.modalscalars``. The labels are used as a boolean mask:
                positions where the label is truthy are replaced with
                ``self.mask_token`` before embedding.
            return_input_label_mapping (bool): when True, also return the
                assembled per-modality input/label dict.

        Returns:
            outputs (dict): ``<modality>_logits`` tensors, one per modality
            (plus the input_label_mapping when requested).
        """
        # Initialize the dictionary for the dynamic input-label mapping
        input_label_mapping = {}
        combined = []
        for src_modality in self.modalscalars:
            # Add the modality's input and label data to the input_label_mapping
            input_label_mapping[src_modality] = {
                'input': inputs[f"input_{src_modality}"],  # Input data
                'labels': inputs[f"labels_{src_modality}"]  # Corresponding labels
            }
            input_data = input_label_mapping[src_modality]['input']  # get input data
            label = input_label_mapping[src_modality]['labels']  # get label data (for masking)
            input_data = torch.where(label, self.mask_token, input_data)  # apply masking
            x = self.embedding[src_modality](input_data.unsqueeze(-1))  # shape: (B, L, H)
            x = torch.nn.functional.silu(x)
            combined.append(x)  # combine all modalities
        combined = torch.cat(combined, dim=1)  # Concatenate along the sequence length dimension
        # NOTE(review): position_ids is rebuilt and stored on self every call;
        # it is not a registered buffer, so it won't move with .to(device) or
        # be saved in the state dict.
        self.position_ids = torch.arange(combined.size(1)).unsqueeze(0).to(combined.device)  # shape: (1, L)
        combined += self.position_embeddings(self.position_ids)  # add position embedding
        combined = self.embed_dropout(combined)
        # Calls the BERT encoder stack directly (no attention mask is passed,
        # so every position attends to every other).
        x = self.encoder(combined, output_hidden_states=True).last_hidden_state  # encode the combined input
        start = 0
        outputs = {}
        # Iterate over each target modality to compute logits
        for tgt_modality in self.modalscalars:
            length = input_label_mapping[tgt_modality]['input'].shape[1]  # get sequence length of the modality
            x_t = x[:, start:start+length, :]  # slice the encoded output for each modality
            outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t)  # modality specific regression head
            start += length  # update start index for next modality
        # NOTE(review): x_t here is the slice for the LAST modality only —
        # confirm whether the UMAP embedding is meant to cover all modalities.
        if getattr(self, 'save_umap_for', None):
            pooled = x_t.mean(dim=1)  # Mean pooling over the sequence length dimension
            self.save_pooled_embedding(pooled)  # saved for UMAP visualization
        return (outputs, input_label_mapping) if return_input_label_mapping else outputs

    def save_pooled_embedding(self, features: torch.Tensor):
        """Append pooled embeddings to the HDF5 file at ``self.save_umap_for``.

        Creates the file (with a resizable 'features' dataset) on first call;
        subsequent calls grow the dataset and append.
        """
        import h5py
        fname = Path(self.save_umap_for)
        fname.parent.mkdir(parents=True, exist_ok=True)
        features = features.detach().cpu().numpy()
        if fname.exists():
            with h5py.File(fname, 'r+') as f:
                old_size = f['features'].shape[0]  # get current size
                new_size = old_size + features.shape[0]  # calculate new size
                f['features'].resize((new_size, features.shape[-1]))  # resize dataset
                f['features'][old_size:] = features  # append new features
        else:
            with h5py.File(fname, 'w') as f:
                f.create_dataset('features', data=features, maxshape=(None, features.shape[-1]), chunks=True)