import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
import yaml
from transformers import BertConfig, BertModel, PretrainedConfig

from .yaml_util import MyLoader


@dataclass
class FoundationOutput:
    """Container for the model's losses and per-head outputs."""
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    num_output: Optional[torch.Tensor] = None
    est_err_output: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    masked_loss: Optional[torch.Tensor] = None
    num_loss: Optional[torch.Tensor] = None
    est_err_loss: Optional[torch.Tensor] = None


@dataclass
class FoundationBertConfig:
    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False

    def to_dict(self):
        return {k: getattr(self, k) for k in self.__dataclass_fields__.keys()}
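
# A minimal config sketch (values here are illustrative assumptions, not
# project defaults):
#
#   config = FoundationBertConfig(
#       vocab_size=32000, hidden_size=256, num_hidden_layers=4,
#       num_attention_heads=4, intermediate_size=1024,
#       hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
#       pad_token_id=0, classifier_dropout=0.1,
#       max_position_embeddings=512, contrastive_temperature=0.07,
#       loss_weights={'mlm': 1.0, 'regression': 1.0},
#   )

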
class FoundationBert(BertModel):
    def __init__(self,
                 config: Optional[FoundationBertConfig] = None,
                 use_mlm_loss: bool = False,
                 use_regression_loss: bool = True,
                 use_contrastive_loss: bool = False,
                 use_xval_loss: bool = False,
                 transform_numeric: bool = False,
                 *args,
                 **kwargs):
        self.gconfig = config

        bert_conf = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            _attn_implementation='sdpa',
        )
        self.gconfig.transform_numeric = transform_numeric
        super().__init__(bert_conf)

        # Use the loss flags defined on the config; fall back to the
        # constructor arguments when the config does not carry them.
        if not hasattr(self.gconfig, 'use_mlm_loss'):
            self.gconfig.use_mlm_loss = use_mlm_loss
            self.gconfig.use_regression_loss = use_regression_loss
            self.gconfig.use_contrastive_loss = use_contrastive_loss
            self.gconfig.use_xval_loss = use_xval_loss

        if not (self.gconfig.use_mlm_loss or self.gconfig.use_regression_loss
                or self.gconfig.use_contrastive_loss):
            raise ValueError("At least one loss must be enabled")

        # Number of enabled losses, used to normalise the combined loss.
        self.loss_mod = (float(self.gconfig.use_mlm_loss)
                         + float(self.gconfig.use_regression_loss)
                         + float(self.gconfig.use_contrastive_loss)
                         + float(self.gconfig.use_xval_loss))

        self.dataset_path = kwargs.get('dataset_path', None)
        self.modalities = kwargs['modalities']
        self.mask_token = kwargs['mask_token']

        self.scalar_keys = [
            'redshift',
            'halo_mass',
            'stellar_mass',
        ]
        self.vector_keys = [
            'SED',
            'SFH',
            'mag_{band}_spherex',
            'mag_{band}_lsst',
        ]
        # Group every scalar modality under a single shared 'scalars' entry,
        # keeping vector modalities separate and preserving first-seen order.
        self.modalscalars = [m if m in self.vector_keys else 'scalars' for m in self.modalities]
        self.modalscalars = list(dict.fromkeys(self.modalscalars))
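
        # For illustration (a hypothetical modality list, not a default):
        #   modalities   = ['SED', 'redshift', 'halo_mass']
        #   modalscalars = ['SED', 'scalars']
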
        # One scalar-value embedding and one numeric head per (grouped)
        # modality.
        self.embedding = torch.nn.ModuleDict()
        self.num_head = torch.nn.ModuleDict()

        for modality in self.modalscalars:
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1),
            )

        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.xval_loss = torch.nn.MSELoss(reduction='none')

        self.distributed_loss = False
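
    # Constructing the model directly (a sketch; the modality names and
    # mask token are assumptions, and ``config`` is as in the example
    # after FoundationBertConfig):
    #   model = FoundationBert(config,
    #                          modalities=['SED', 'redshift'],
    #                          mask_token=0.0)
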
    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                        *model_args,
                        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                        cache_dir: Optional[Union[str, os.PathLike]] = None,
                        ignore_mismatched_sizes: bool = False,
                        force_download: bool = False,
                        local_files_only: bool = False,
                        token: Optional[Union[str, bool]] = None,
                        revision: str = "main",
                        use_safetensors: Optional[bool] = None,
                        **kwargs,
                        ):
        from huggingface_hub import hf_hub_download

        # Fetch the training config from the Hub, falling back to a local
        # checkpoint directory when the download fails.
        try:
            model_config = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="train_config.yaml",
                revision=revision,
            )
        except Exception:
            model_config = os.path.join(pretrained_model_name_or_path, "train_config.yaml")

        with open(model_config, 'r') as f:
            train_config = yaml.load(f, Loader=MyLoader)

        kwargs['modalities'] = train_config['modalities']
        kwargs['dataset_path'] = train_config['dataset_path']
        kwargs['mask_token'] = train_config['mask_token']

        return super().from_pretrained(
            pretrained_model_name_or_path,
            **train_config['model_config'],
            **kwargs,
        )
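
    # Loading a checkpoint (hypothetical repo id; assumes a
    # train_config.yaml is shipped alongside the weights, as read above):
    #   model = FoundationBert.from_pretrained("my-org/foundation-bert")
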
    def pool_output(self,
                    embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    use_last: bool = False
                    ) -> torch.Tensor:
        """Average pool the hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).
        use_last : bool
            Selects which end-of-sequence position is excluded from the
            pool (index ``seq_len - 1`` if True, ``seq_len - 2`` otherwise).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        sl_mod = 1 if use_last else 2
        seq_lengths = attention_mask.sum(dim=1)
        batch_idx = torch.arange(attention_mask.shape[0], device=attention_mask.device)

        # Exclude the first token and the chosen end-of-sequence token
        # from the pool, one position per row.
        new_attention = attention_mask.clone()
        new_attention[:, 0] = 0
        new_attention[batch_idx, seq_lengths - sl_mod] = 0

        pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)
        sum_embeds = torch.sum(embeddings * pool_mask, 1)

        # Clamp to avoid division by zero for fully padded rows.
        seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1)
        return sum_embeds / seq_lengths
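
    # Shape sketch (illustrative values):
    #   hidden = torch.randn(2, 8, 256)           # (B, SeqLen, HiddenDim)
    #   mask = torch.ones(2, 8, dtype=torch.long)
    #   pooled = model.pool_output(hidden, mask)  # -> (2, 256)
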
    def last_token_pool(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Pool the last hidden states using the attention mask.

        Parameters
        ----------
        embeddings : torch.Tensor
            The last hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return embeddings[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = embeddings.shape[0]
            return embeddings[
                torch.arange(batch_size, device=embeddings.device),
                sequence_lengths,
            ]
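
    # Sketch: with right padding this selects embeddings[i, len_i - 1]
    # for each row i; with left padding the last column is returned.
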
    def forward(self, inputs, return_input_label_mapping=False):
        """Forward pass that computes predictions for each modality.

        Parameters
        ----------
        inputs : dict
            Inputs and labels for each (grouped) modality, keyed as
            ``input_{modality}`` and ``labels_{modality}``.
        return_input_label_mapping : bool
            If True, also return the per-modality input/label mapping.

        Returns
        -------
        dict
            The logits for each modality, keyed as ``{modality}_logits``.
        """
        input_label_mapping = {}
        combined = []
        for src_modality in self.modalscalars:
            input_label_mapping[src_modality] = {
                'input': inputs[f"input_{src_modality}"],
                'labels': inputs[f"labels_{src_modality}"]
            }

            # Replace the positions flagged by the labels mask with the mask
            # token before embedding.
            input_data = input_label_mapping[src_modality]['input']
            label = input_label_mapping[src_modality]['labels']
            input_data = torch.where(label, self.mask_token, input_data)

            x = self.embedding[src_modality](input_data.unsqueeze(-1))
            x = torch.nn.functional.silu(x)
            combined.append(x)

        # Concatenate the per-modality token streams along the sequence
        # axis and add learned position embeddings.
        combined = torch.cat(combined, dim=1)
        position_ids = torch.arange(combined.size(1), device=combined.device).unsqueeze(0)
        combined = combined + self.position_embeddings(position_ids)
        combined = self.embed_dropout(combined)

        x = self.encoder(combined, output_hidden_states=True).last_hidden_state

        # Slice the shared hidden states back into per-modality segments and
        # apply the matching numeric head to each.
        start = 0
        outputs = {}
        for tgt_modality in self.modalscalars:
            length = input_label_mapping[tgt_modality]['input'].shape[1]
            x_t = x[:, start:start + length, :]
            outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t)
            start += length

        # Optionally save sequence-averaged embeddings (e.g. for UMAP).
        if getattr(self, 'save_umap_for', None):
            pooled = x.mean(dim=1)
            self.save_pooled_embedding(pooled)

        return (outputs, input_label_mapping) if return_input_label_mapping else outputs
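
    # Example forward inputs (hypothetical shapes; keys follow the
    # input_{modality} / labels_{modality} convention used above):
    #   inputs = {
    #       'input_SED': torch.randn(2, 16),
    #       'labels_SED': torch.zeros(2, 16, dtype=torch.bool),
    #       'input_scalars': torch.randn(2, 3),
    #       'labels_scalars': torch.zeros(2, 3, dtype=torch.bool),
    #   }
    #   outputs = model(inputs)  # {'SED_logits': ..., 'scalars_logits': ...}
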
    def save_pooled_embedding(self, features):
        """Append the pooled hidden states to an HDF5 file."""
        import h5py

        fname = Path(self.save_umap_for)
        fname.parent.mkdir(parents=True, exist_ok=True)

        features = features.detach().cpu().numpy()

        if fname.exists():
            # Grow the existing dataset and append the new batch.
            with h5py.File(fname, 'r+') as f:
                old_size = f['features'].shape[0]
                new_size = old_size + features.shape[0]
                f['features'].resize((new_size, features.shape[-1]))
                f['features'][old_size:] = features
        else:
            # Create a resizable dataset on first write.
            with h5py.File(fname, 'w') as f:
                f.create_dataset('features', data=features, maxshape=(None, features.shape[-1]), chunks=True)
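
    # Reading the accumulated embeddings back (sketch; assumes h5py):
    #   with h5py.File(model.save_umap_for, 'r') as f:
    #       feats = f['features'][:]  # (N, HiddenDim)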