# mosaic-light / foundation_bert.py
# Last commit by Xsmos: "Fix remote train_config loading" (86ec786, verified)
import torch
import os
import yaml
from pathlib import Path
# from ..utils.masked_data_modeling_loss import MaskedDataLossWithSoftmax
# from ..utils.contrastive_loss import ContrastiveLoss
# from ..utils.yaml_util import MyLoader
from dataclasses import dataclass
from transformers import ModernBertModel, ModernBertConfig, PretrainedConfig
from transformers.utils import cached_file
from typing import Optional, Union
# import yaml
class MyLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` that loads nodes with unknown tags instead of failing.

    Any node whose tag has no registered constructor is returned as a
    :class:`Tagged` ``(tag, value)`` pair, where ``value`` is built from the
    node's plain content. Python-specific tags (e.g. ``!!python/tuple``) can
    therefore be read without instantiating arbitrary objects.

    The fallback constructor is registered once, at import time, right after
    ``construct_undefined`` is defined below.
    """


import typing


class Tagged(typing.NamedTuple):
    """Result of loading a YAML node whose tag ``MyLoader`` cannot construct."""
    tag: str  # the node's tag string (``node.tag``)
    value: object  # str, list or dict built from the node's content


def construct_undefined(self, node):
    """Fallback YAML constructor: wrap a node with an unknown tag as ``Tagged``.

    Parameters
    ----------
    self : MyLoader
        The loader instance (PyYAML calls constructors as ``constructor(loader, node)``).
    node : yaml.nodes.Node
        Scalar, sequence or mapping node carrying an unregistered tag.

    Returns
    -------
    Tagged
        ``Tagged(node.tag, value)`` where ``value`` is the plainly constructed
        scalar string, list or dict.

    Raises
    ------
    yaml.constructor.ConstructorError
        If ``node`` is not one of the three YAML node kinds.
    """
    if isinstance(node, yaml.nodes.ScalarNode):
        value = self.construct_scalar(node)
    elif isinstance(node, yaml.nodes.SequenceNode):
        value = self.construct_sequence(node)
    elif isinstance(node, yaml.nodes.MappingNode):
        value = self.construct_mapping(node)
    else:
        # Previously an ``assert`` (stripped under ``python -O``).
        raise yaml.constructor.ConstructorError(
            None, None, f"unexpected node: {node!r}", node.start_mark)
    return Tagged(node.tag, value)


# Register the fallback once, on MyLoader only (add_constructor copies the
# constructor table into the subclass, so SafeLoader itself is untouched).
# Doing this at import time means unknown tags are handled anywhere in a
# document -- including at the root or inside a top-level sequence, which the
# previous "register inside construct_mapping" approach missed until the first
# mapping had been constructed.
MyLoader.add_constructor(None, construct_undefined)
@dataclass
class FoundationOutput:
    """Bundle of optional output tensors.

    Every field defaults to ``None``, so callers only fill in what they
    computed. The annotations are ``Optional`` to match those ``None``
    defaults (a bare ``torch.Tensor`` annotation with a ``None`` default is
    rejected by type checkers).
    """
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    num_output: Optional[torch.Tensor] = None
    est_err_output: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    masked_loss: Optional[torch.Tensor] = None
    num_loss: Optional[torch.Tensor] = None
    est_err_loss: Optional[torch.Tensor] = None
@dataclass
class FoundationBertConfig:
    """Hyper-parameters consumed by ``FoundationBert``.

    The first twelve fields are required; the trailing boolean switches have
    defaults. ``FoundationBert.__init__`` copies the size/dropout/position
    fields into a ``ModernBertConfig``.
    """
    # ---- required -------------------------------------------------------
    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    # ---- optional switches ------------------------------------------------
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    # NOTE(review): FoundationBert.__init__ overwrites this with its own
    # ``transform_numeric`` argument.
    transform_numeric: bool = False
    # NOTE(review): not read by FoundationBert.__init__, which always requests SDPA.
    use_sdpa_attention: bool = True

    def to_dict(self):
        """Return a shallow ``{field_name: value}`` dict in declaration order."""
        return dict((name, getattr(self, name)) for name in self.__dataclass_fields__)
class FoundationBert(ModernBertModel):
    """ModernBERT encoder over continuous, multi-modal inputs.

    Each modality named in ``vector_shape`` gets its own ``Linear(1, H)``
    embedding and MLP regression head; all modalities in ``scalar_shape`` share
    a single ``'scalars'`` embedding/head. Per-modality token sequences are
    concatenated along the sequence axis and passed through the inherited
    ModernBERT encoder layers (``self.layers``) and ``self.final_norm``.
    """

    def __init__(self,
                 config: FoundationBertConfig = None,
                 use_mlm_loss: bool = False,
                 use_regression_loss: bool = True,
                 use_contrastive_loss: bool = False,
                 use_xval_loss: bool = False,
                 transform_numeric: bool = False,
                 *args,
                 **kwargs):
        """Build the encoder plus the modality-specific embeddings and heads.

        Parameters
        ----------
        config : FoundationBertConfig
            Architecture hyper-parameters. Any object exposing the same
            attributes works; when built through ``from_pretrained`` this is
            presumably the ``transformers`` config object, which may lack the
            ``use_*_loss`` flags (hence the fallback below).
        use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss : bool
            Fallback loss switches. They are used (and written onto ``config``)
            only when ``config`` does not define all four ``use_*_loss``
            attributes.
        transform_numeric : bool
            Always written onto ``config.transform_numeric``.
        **kwargs
            Must contain ``vector_shape``, ``scalar_shape`` and ``mask_token``;
            ``dataset_path`` is optional. Other keys are ignored here.

        Raises
        ------
        ValueError
            If none of the four losses is enabled.
        KeyError
            If a required keyword argument is missing.
        """
        self.gconfig = config
        bert_conf = ModernBertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            # NOTE(review): ``config.use_sdpa_attention`` is not consulted; SDPA
            # is always requested -- confirm this is intended.
            _attn_implementation='sdpa'
        )
        # NOTE(review): this unconditionally overrides any transform_numeric
        # already set on ``config`` -- confirm this is intended.
        self.gconfig.transform_numeric = transform_numeric
        super().__init__(bert_conf,)

        # Resolve the loss switches. Only a missing attribute triggers the
        # fallback to the constructor arguments. The previous bare ``except:``
        # also swallowed the "no loss enabled" ValueError and silently replaced
        # the config's flags (e.g. an xval-only config became regression-only).
        loss_flags = ('use_mlm_loss', 'use_regression_loss',
                      'use_contrastive_loss', 'use_xval_loss')
        try:
            enabled = [getattr(self.gconfig, name) for name in loss_flags]
        except AttributeError:
            enabled = [use_mlm_loss, use_regression_loss,
                       use_contrastive_loss, use_xval_loss]
            for name, value in zip(loss_flags, enabled):
                setattr(self.gconfig, name, value)
        if not any(enabled):
            raise ValueError("At least one loss must be enabled")
        # Number of enabled losses, as a float.
        self.loss_mod = sum(float(flag) for flag in enabled)

        self.dataset_path = kwargs.get('dataset_path', None)
        self.vector_shape = kwargs['vector_shape']
        self.scalar_shape = kwargs['scalar_shape']
        self.mask_token = kwargs['mask_token']

        # Vector modalities keep their own name; every scalar modality maps to
        # the shared 'scalars' bucket. dict.fromkeys de-duplicates in order.
        self.modalscalars = list(dict.fromkeys(
            m if m in self.vector_shape else 'scalars' for m in self.modalities
        ))
        print(f"✅ FoundationBert.__init__ is called with {kwargs=}, {self.modalscalars=}, {self.dataset_path=} ✅")

        self.embedding = torch.nn.ModuleDict()  # modality-specific input embeddings
        self.num_head = torch.nn.ModuleDict()  # modality-specific regression heads
        for modality in self.modalscalars:
            # (B, L, 1) -> (B, L, H)
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)
            # (B, L, H) -> (B, L, 1)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1)
            )
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.distributed_loss = False

    @property
    def modalities(self):
        """Union (``|``) of ``vector_shape`` and ``scalar_shape``; iterates modality names."""
        return self.vector_shape | self.scalar_shape

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                        *model_args,
                        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                        cache_dir: Optional[Union[str, os.PathLike]] = None,
                        ignore_mismatched_sizes: bool = False,
                        force_download: bool = False,
                        local_files_only: bool = False,
                        token: Optional[Union[str, bool]] = None,
                        revision: str = "main",
                        use_safetensors: bool = None,
                        **kwargs,
                        ):
        """Load weights with ``transformers``, taking extra arguments from ``train_config.yaml``.

        Where ``train_config.yaml`` is read from:

        * an existing local path containing ``'checkpoint'``: its parent directory;
        * any other local directory: that directory;
        * otherwise: the repository, via ``cached_file``.

        The yaml's ``modalities``, ``dataset_path`` and ``mask_token`` entries
        are passed to ``__init__``. ``vector_shape`` and ``scalar_shape`` are
        passed too, unless given explicitly in ``kwargs``. The ``model_config``
        mapping is merged underneath ``kwargs``, so explicit keyword arguments
        win instead of raising a duplicate-keyword ``TypeError``.

        All loading options (``config``, ``cache_dir``, ``revision``, ``token``,
        ``local_files_only``, ...) are forwarded to the parent ``from_pretrained``.
        The weights are therefore fetched from the same source as the yaml;
        previously they were dropped.

        Raises
        ------
        KeyError
            If the yaml lacks ``modalities``, ``dataset_path``, ``mask_token``
            or ``model_config``.
        """
        path = Path(pretrained_model_name_or_path)
        # The 'checkpoint' heuristic only applies to local paths, so remote
        # repo ids that happen to contain "checkpoint" still resolve remotely.
        if path.exists() and 'checkpoint' in str(pretrained_model_name_or_path):
            model_config = path.parent / 'train_config.yaml'
        elif path.is_dir():
            model_config = path / 'train_config.yaml'
        else:
            model_config = cached_file(
                pretrained_model_name_or_path,
                'train_config.yaml',
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
            )
        with open(model_config, 'r', encoding='utf-8') as f:
            # MyLoader is SafeLoader-based; unknown tags become Tagged tuples.
            train_config = yaml.load(f, Loader=MyLoader)
        kwargs['modalities'] = train_config['modalities']
        kwargs['dataset_path'] = train_config['dataset_path']
        kwargs['mask_token'] = train_config['mask_token']
        for key in ('vector_shape', 'scalar_shape'):
            if key not in kwargs and key in train_config:
                kwargs[key] = train_config[key]
        print(f"✅ Foundationbert.from_pretrained is called with {model_config=} and {kwargs=} ✅")
        model_kwargs = {**train_config['model_config'], **kwargs}
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **model_kwargs,
        )

    def pool_output(self,
                    embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    use_last: bool = False
                    ) -> torch.Tensor:
        """Pool hidden states, excluding the start token and one end token per row.

        Parameters
        ----------
        embeddings : torch.Tensor
            The hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).
        use_last : bool
            If True, the excluded end token of each row is at index
            ``length - 1``; otherwise at ``length - 2``.

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        sl_mod = 1 if use_last else 2
        seq_lengths = attention_mask.sum(axis=1)  # valid tokens per row, (B,)

        # Exclude the start token (column 0) and each row's end token.
        new_attention = attention_mask.clone()
        new_attention[:, 0] = 0
        # Pair every row with its OWN end index. The previous
        # ``new_attention[:, seq_lengths - sl_mod] = 0`` indexed only the column
        # axis, which zeroed every row's end position in *all* rows.
        rows = torch.arange(attention_mask.shape[0], device=attention_mask.device)
        end_idx = (seq_lengths - sl_mod).long()
        new_attention[rows, end_idx] = 0

        # (B, SeqLen, HiddenDim) mask that skips pad/start/end tokens.
        pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)
        sum_embeds = torch.sum(embeddings * pool_mask, 1)
        # NOTE(review): the divisor is the full valid length (start/end tokens
        # included), not the number of pooled tokens, so this is not an exact
        # mean. Kept as-is -- confirm intent.
        seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1)  # (B, 1); avoids /0
        return sum_embeds / seq_lengths

    def last_token_pool(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return the hidden state of each row's last valid token.

        Parameters
        ----------
        embeddings : torch.Tensor
            The last hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        # If every row attends to the final column, the batch is left-padded
        # and the last column is each row's last token.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return embeddings[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = embeddings.shape[0]
        return embeddings[
            torch.arange(batch_size, device=embeddings.device),
            sequence_lengths,
        ]

    def _encode_modalities(self, inputs):
        """Mask, embed and encode every modality in ``self.modalscalars``.

        Shared by ``forward`` and ``get_retrieval_embedding``.

        Parameters
        ----------
        inputs : dict
            Must hold ``input_<modality>`` and ``labels_<modality>`` tensors.
            Where the label is true, the input value is replaced by
            ``self.mask_token`` (labels are presumably boolean masks).

        Returns
        -------
        (torch.Tensor, dict)
            The final-normed hidden states for the concatenated sequence,
            (B, total_len, H), and ``{modality: {'input': raw_input,
            'labels': labels}}``.
        """
        input_label_mapping = {}
        embedded = []
        for src_modality in self.modalscalars:
            input_data = inputs[f"input_{src_modality}"]
            label = inputs[f"labels_{src_modality}"]
            input_label_mapping[src_modality] = {'input': input_data, 'labels': label}
            masked = torch.where(label, self.mask_token, input_data)
            x = self.embedding[src_modality](masked.unsqueeze(-1))  # (B, L, H)
            embedded.append(torch.nn.functional.silu(x))
        combined = torch.cat(embedded, dim=1)  # concatenate along the sequence axis
        position_ids = torch.arange(combined.size(1), device=combined.device).unsqueeze(0)  # (1, L)
        hidden_states = self.embed_dropout(combined)
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, position_ids=position_ids)[0]
        return self.final_norm(hidden_states), input_label_mapping

    def forward(self, inputs, return_input_label_mapping=False):
        """Predict a value for every token of every modality.

        Parameters
        ----------
        inputs : dict
            Batch with ``input_<modality>`` and ``labels_<modality>`` tensors
            for each modality in ``self.modalscalars``. Masked positions
            (label true) are replaced by ``self.mask_token`` before embedding.
        return_input_label_mapping : bool
            Also return the per-modality ``{'input', 'labels'}`` mapping.

        Returns
        -------
        dict or (dict, dict)
            ``outputs`` maps ``"<modality>_logits"`` to that modality's
            regression-head output, (B, L_modality, 1). When
            ``return_input_label_mapping`` is true, returns
            ``(outputs, input_label_mapping)``.
        """
        x, input_label_mapping = self._encode_modalities(inputs)
        start = 0
        outputs = {}
        for tgt_modality in self.modalscalars:
            length = input_label_mapping[tgt_modality]['input'].shape[1]
            x_t = x[:, start:start + length, :]  # this modality's slice of the encoding
            outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t)
            start += length
            # NOTE(review): the source's indentation was lost. This assumes the
            # UMAP dump happens per target modality (it pools ``x_t``) --
            # confirm against the original.
            if getattr(self, 'save_umap_for', None):
                self.save_pooled_embedding(x_t.mean(dim=1))
        return (outputs, input_label_mapping) if return_input_label_mapping else outputs

    def save_pooled_embedding(self, features):
        """Append ``features`` as rows of the ``'features'`` dataset in ``self.save_umap_for``.

        The HDF5 file and its parent directories are created on first use;
        later calls resize the dataset along axis 0 and append.
        """
        import h5py
        fname = Path(self.save_umap_for)
        fname.parent.mkdir(parents=True, exist_ok=True)
        features = features.detach().cpu().numpy()
        if fname.exists():
            with h5py.File(fname, 'r+') as f:
                old_size = f['features'].shape[0]
                new_size = old_size + features.shape[0]
                f['features'].resize((new_size, features.shape[-1]))
                f['features'][old_size:] = features
        else:
            with h5py.File(fname, 'w') as f:
                f.create_dataset('features', data=features, maxshape=(None, features.shape[-1]), chunks=True)

    def get_retrieval_embedding(
        self,
        inputs,
        pooling: str = "mean",
        normalize: bool = True,
    ) -> torch.Tensor:
        """Build a single embedding per sample for kNN-style retrieval.

        Parameters
        ----------
        inputs : dict
            Batch dict with `input_<modality>` and `labels_<modality>` entries.
        pooling : str
            ``"last"`` takes the final position of the concatenated sequence;
            any other value (default ``"mean"``) averages over all positions.
        normalize : bool
            L2-normalize output embeddings for cosine/inner-product search.

        Returns
        -------
        torch.Tensor
            One embedding per sample, (B, H).
        """
        hidden_states, _ = self._encode_modalities(inputs)
        if pooling == "last":
            embedding = hidden_states[:, -1, :]
        else:
            embedding = hidden_states.mean(dim=1)
        if normalize:
            embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
        return embedding