|
|
|
|
|
""" |
|
|
This module defines PyTorch Lightning modules for the Tahoeformer project. |
|
|
It includes a base model class (`LitBaseModel`) and the main experimental model |
|
|
(`LitEnformerSMILES`) which combines an Enformer-based DNA sequence model with |
|
|
drug information (a SMILES string encoded as a Morgan fingerprint) and dose information
|
|
to predict gene expression. |
|
|
|
|
|
Key components: |
|
|
- masked_mse: A Mean Squared Error loss utility that ignores NaN entries in the target.
|
|
- LitBaseModel: A base LightningModule providing common training, validation, test steps, |
|
|
optimizer configuration, and basic metric logging hooks. |
|
|
- LitEnformerSMILES: The primary model for predicting drug-induced gene expression changes, |
|
|
using Enformer for DNA and Morgan fingerprints for drugs. |
|
|
- MetricLogger: A PyTorch Lightning Callback for detailed logging of predictions. |
|
|
""" |
|
|
|
|
|
import pandas as pd |
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import lightning.pytorch as pl |
|
|
from enformer_pytorch.finetune import HeadAdapterWrapper |
|
|
from enformer_pytorch import Enformer |
|
|
from torchmetrics.regression import PearsonCorrCoef, R2Score |
|
|
from warnings import warn |
|
|
import wandb |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def masked_mse(y_hat, y): |
|
|
""" |
|
|
Computes Mean Squared Error (MSE) while ignoring NaN values in the target tensor. |
|
|
|
|
|
Args: |
|
|
y_hat (torch.Tensor): The predicted values. |
|
|
y (torch.Tensor): The target values, which may contain NaNs. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: A scalar tensor representing the masked MSE. Returns 0.0 if all targets are NaN. |
|
|
""" |
|
|
    mask = torch.isnan(y)  # True where the target is missing
    if mask.all():
        # Every target is NaN: return a zero loss that still carries a grad
        # graph so the calling training step does not error out.
        return torch.tensor(0.0, device=y_hat.device, requires_grad=True)
    mse = torch.mean((y[~mask] - y_hat[~mask]) ** 2)
    return mse
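
# A worked sketch of masked_mse (values are illustrative, not from the project):
#
#   y_hat = torch.tensor([1.0, 2.0, 3.0])
#   y     = torch.tensor([1.5, float('nan'), 2.0])
#   masked_mse(y_hat, y)
#   # -> ((1.5 - 1.0)**2 + (2.0 - 3.0)**2) / 2 = 0.625
#
# The NaN entry is dropped from the mean rather than poisoning the loss.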
|
|
|
|
|
|
|
|
class LitBaseModel(pl.LightningModule): |
|
|
""" |
|
|
A base PyTorch Lightning module providing common boilerplate for training and evaluation. |
|
|
|
|
|
This class implements a generic training/validation/test step, loss calculation using |
|
|
`masked_mse`, optimizer configuration (AdamW), and hooks for accumulating outputs |
|
|
for detailed metric logging via the `MetricLogger` callback. |
|
|
|
|
|
Derived classes are expected to implement the `forward` method. |
|
|
|
|
|
Hyperparameters: |
|
|
learning_rate (float): The learning rate for the optimizer. |
|
|
loss_alpha (float): A coefficient for the primary loss term (MSE). Useful if |
|
|
additional loss terms were to be added. |
|
|
weight_decay (float, optional): Weight decay for the AdamW optimizer. If None, |
|
|
AdamW's internal default is used. |
|
|
eval_gene_sets (dict, optional): A dictionary where keys are set names (e.g., 'oncogenes') |
|
|
and values are lists of gene IDs. Used by `MetricLogger` |
|
|
to compute metrics for specific gene subsets. |
|
|
""" |
|
|
def __init__(self, learning_rate=5e-6, loss_alpha=1.0, weight_decay=None, |
|
|
eval_gene_sets=None): |
|
|
""" |
|
|
Initializes the LitBaseModel. |
|
|
|
|
|
Args: |
|
|
learning_rate (float, optional): Learning rate. Defaults to 5e-6. |
|
|
loss_alpha (float, optional): Alpha for MSE loss. Defaults to 1.0. |
|
|
weight_decay (float, optional): Weight decay for AdamW. If None, uses optimizer default. |
|
|
Defaults to None. |
|
|
eval_gene_sets (dict, optional): Dictionary of gene sets for targeted evaluation. |
|
|
Keys are names, values are lists of gene IDs. |
|
|
Defaults to None. |
|
|
""" |
|
|
super().__init__() |
|
|
self.save_hyperparameters() |
|
|
self.learning_rate = learning_rate |
|
|
self.loss_alpha = loss_alpha |
|
|
self.weight_decay = weight_decay |
|
|
self.eval_gene_sets = eval_gene_sets if eval_gene_sets else {} |
|
|
|
|
|
|
|
|
        # Per-item predictions accumulated during val/test steps; consumed by MetricLogger.
        self.epoch_outputs = []
|
|
|
|
|
def loss_fn(self, y_hat, y): |
|
|
""" |
|
|
Calculates the loss for the model. |
|
|
|
|
|
Currently uses `masked_mse` scaled by `self.loss_alpha`. |
|
|
|
|
|
Args: |
|
|
y_hat (torch.Tensor): Predicted values from the model. |
|
|
y (torch.Tensor): Ground truth target values. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: The computed loss value. |
|
|
""" |
|
|
mse_term = masked_mse(y_hat, y) |
|
|
|
|
|
return self.loss_alpha * mse_term |
|
|
|
|
|
def _common_step(self, batch, batch_idx, step_type): |
|
|
""" |
|
|
A common step for training, validation, and testing. |
|
|
|
|
|
This method unpacks the batch, performs a forward pass, calculates the loss, |
|
|
logs the loss, and accumulates outputs for epoch-level metric calculation |
|
|
(for validation and test steps). |
|
|
|
|
|
Args: |
|
|
batch: The batch of data from the DataLoader. Expected to be a tuple containing |
|
|
DNA sequence, Morgan fingerprints, dose, target expression, |
|
|
and metadata (gene_id, drug_id, cell_line). |
|
|
batch_idx (int): The index of the current batch. |
|
|
step_type (str): A string indicating the type of step ('train', 'val', or 'test'). |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: The loss for the current batch. |
|
|
""" |
|
|
|
|
|
|
|
|
dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line = batch |
|
|
|
|
|
y_hat = self(dna_seq, morgan_fingerprints, dose) |
|
|
|
|
|
loss = self.loss_fn(y_hat, target_expression) |
|
|
        self.log(
            f'{step_type}_loss', loss,
            batch_size=target_expression.shape[0],
            on_step=False,
            on_epoch=True,
            prog_bar=(step_type != 'train'),
        )
|
|
|
|
|
if step_type != 'train': |
|
|
|
|
|
batch_size = target_expression.shape[0] |
|
|
for i in range(batch_size): |
|
|
item_data = { |
|
|
'pred': y_hat[i].detach(), |
|
|
'target': target_expression[i].detach(), |
|
|
'gene_id': gene_id[i], |
|
|
'drug_id': drug_id[i], |
|
|
'cell_line': cell_line[i], |
|
|
'rank': self.trainer.global_rank |
|
|
} |
|
|
self.epoch_outputs.append(item_data) |
|
|
return loss |
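
    # For reference, a batch consumed by `_common_step` is assumed to unpack as
    # follows (shapes illustrative, matching the defaults used in this file):
    #
    #   dna_seq:             (B, seq_len, 4)   one-hot DNA
    #   morgan_fingerprints: (B, 2048)         precomputed fingerprint bits
    #   dose:                (B,) or (B, 1)    numeric dose
    #   target_expression:   (B, 1)            may contain NaNs
    #   gene_id, drug_id, cell_line: length-B sequences of identifiers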
|
|
|
|
|
def training_step(self, batch, batch_idx): |
|
|
"""PyTorch Lightning training step. Calls `_common_step`.""" |
|
|
return self._common_step(batch, batch_idx, 'train') |
|
|
|
|
|
def validation_step(self, batch, batch_idx, dataloader_idx=0): |
|
|
"""PyTorch Lightning validation step. Calls `_common_step`.""" |
|
|
return self._common_step(batch, batch_idx, 'val') |
|
|
|
|
|
def test_step(self, batch, batch_idx, dataloader_idx=0): |
|
|
"""PyTorch Lightning test step. Calls `_common_step`.""" |
|
|
return self._common_step(batch, batch_idx, 'test') |
|
|
|
|
|
def on_validation_epoch_start(self): |
|
|
"""Clears accumulated outputs at the start of each validation epoch.""" |
|
|
self.epoch_outputs = [] |
|
|
|
|
|
def on_test_epoch_start(self): |
|
|
"""Clears accumulated outputs at the start of each test epoch.""" |
|
|
self.epoch_outputs = [] |
|
|
|
|
|
def configure_optimizers(self): |
|
|
""" |
|
|
Configures the optimizer for the model. |
|
|
|
|
|
Uses AdamW with the specified learning rate and weight decay. |
|
|
|
|
|
Returns: |
|
|
torch.optim.Optimizer: The configured AdamW optimizer. |
|
|
""" |
|
|
if self.weight_decay is None: |
|
|
optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate) |
|
|
else: |
|
|
optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) |
|
|
return optimizer |
|
|
|
|
|
|
|
|
class LitEnformerSMILES(LitBaseModel): |
|
|
""" |
|
|
A PyTorch Lightning module that combines genomic sequence information (via Enformer) |
|
|
with drug chemical structure (represented by Morgan fingerprints) and drug dose |
|
|
to predict gene expression changes. |
|
|
|
|
|
The model architecture consists of three main branches: |
|
|
1. DNA Module: Uses a pre-trained Enformer model (with an adapted head) to extract |
|
|
features from a one-hot encoded DNA sequence centered around a gene's TSS. |
|
|
2. Drug Module: Uses pre-computed Morgan fingerprints as the drug representation. |
|
|
3. Dose Module: Directly uses the numerical dose value. |
|
|
|
|
|
Features from these three branches are concatenated and passed through a multi-layer |
|
|
fusion head (MLP with ReLU, BatchNorm, Dropout) to produce the final prediction |
|
|
of gene expression. |
|
|
|
|
|
Inherits common training and evaluation logic from `LitBaseModel`. |
|
|
""" |
|
|
def __init__(self, |
|
|
enformer_model_name: str = 'EleutherAI/enformer-official-rough', |
|
|
enformer_target_length: int = -1, |
|
|
num_output_tracks_enformer_head: int = 1, |
|
|
morgan_fingerprint_dim: int = 2048, |
|
|
dose_input_dim: int = 1, |
|
|
fusion_hidden_dim: int = 256, |
|
|
final_output_tracks: int = 1, |
|
|
learning_rate=5e-6, |
|
|
loss_alpha=1.0, |
|
|
weight_decay=None, |
|
|
eval_gene_sets=None): |
|
|
""" |
|
|
        Initializes the LitEnformerSMILES model.
|
|
|
|
|
Args: |
|
|
enformer_model_name (str, optional): Name or path of the pre-trained Enformer model. |
|
|
enformer_target_length (int, optional): Target length for Enformer's internal pooling. |
|
|
num_output_tracks_enformer_head (int, optional): Output features from Enformer head. |
|
|
morgan_fingerprint_dim (int, optional): Dimensionality of the Morgan fingerprint vector |
|
|
(e.g., 2048 for ECFP4). Defaults to 2048. |
|
|
dose_input_dim (int, optional): Dimensionality of the drug dose input. Defaults to 1. |
|
|
fusion_hidden_dim (int, optional): Hidden dimension for the fusion MLP. Defaults to 256. |
|
|
final_output_tracks (int, optional): Number of final output values. Defaults to 1. |
|
|
learning_rate (float, optional): Learning rate. Defaults to 5e-6. |
|
|
loss_alpha (float, optional): Weight for MSE loss. Defaults to 1.0. |
|
|
weight_decay (float, optional): Weight decay. Defaults to None. |
|
|
eval_gene_sets (dict, optional): Gene sets for targeted evaluation. Defaults to None. |
|
|
""" |
|
|
super().__init__(learning_rate, loss_alpha, weight_decay, eval_gene_sets) |
|
|
self.save_hyperparameters( |
|
|
"enformer_model_name", "enformer_target_length", |
|
|
"num_output_tracks_enformer_head", "morgan_fingerprint_dim", |
|
|
"dose_input_dim", "fusion_hidden_dim", "final_output_tracks", |
|
|
"learning_rate", "loss_alpha", "weight_decay" |
|
|
) |
|
|
|
|
|
|
|
|
enformer_pretrained = Enformer.from_pretrained( |
|
|
self.hparams.enformer_model_name, |
|
|
target_length=self.hparams.enformer_target_length |
|
|
) |
|
|
self.dna_module = HeadAdapterWrapper( |
|
|
enformer=enformer_pretrained, |
|
|
num_tracks=self.hparams.num_output_tracks_enformer_head, |
|
|
post_transformer_embed=False, |
|
|
output_activation=nn.Identity() |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # DNA features + Morgan fingerprint bits + scalar dose.
        fusion_input_dim = (
            self.hparams.num_output_tracks_enformer_head
            + self.hparams.morgan_fingerprint_dim
            + self.hparams.dose_input_dim
        )
|
|
self.fusion_head = nn.Sequential( |
|
|
nn.Linear(fusion_input_dim, self.hparams.fusion_hidden_dim), |
|
|
nn.ReLU(), |
|
|
nn.BatchNorm1d(self.hparams.fusion_hidden_dim), |
|
|
nn.Dropout(0.25), |
|
|
nn.Linear(self.hparams.fusion_hidden_dim, self.hparams.fusion_hidden_dim // 2), |
|
|
nn.ReLU(), |
|
|
nn.BatchNorm1d(self.hparams.fusion_hidden_dim // 2), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(self.hparams.fusion_hidden_dim // 2, self.hparams.final_output_tracks) |
|
|
) |
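
        # Dimension sketch with the defaults above (an assumption, not enforced):
        #   fusion_input_dim = 1 + 2048 + 1 = 2050, mapped 2050 -> 256 -> 128 -> 1,
        # so e.g. self.fusion_head(torch.randn(8, 2050)) has shape (8, 1).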
|
|
|
|
|
def forward(self, dna_seq, morgan_fingerprints, dose): |
|
|
""" |
|
|
Defines the forward pass of the LitEnformerSMILES model using Morgan Fingerprints. |
|
|
|
|
|
Args: |
|
|
dna_seq (torch.Tensor): Batch of one-hot encoded DNA sequences. |
|
|
Shape: (batch_size, sequence_length, 4). |
|
|
morgan_fingerprints (torch.Tensor): Batch of pre-computed Morgan fingerprint vectors. |
|
|
Shape: (batch_size, morgan_fingerprint_dim). |
|
|
dose (torch.Tensor): Batch of drug dose values. |
|
|
Shape: (batch_size, dose_input_dim). |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: The model's prediction. Shape: (batch_size, final_output_tracks). |
|
|
""" |
|
|
|
|
|
        # Enformer head output has shape (batch, target_length_bins, num_tracks).
        dna_out_intermediate = self.dna_module(dna_seq, freeze_enformer=False)
        # Keep only the central bin, which corresponds to the TSS-centered window.
        center_seq_idx = dna_out_intermediate.shape[1] // 2
        dna_features = dna_out_intermediate[:, center_seq_idx, :]
|
|
|
|
|
|
|
|
|
|
|
        # Drug branch: the precomputed Morgan fingerprints are used directly.
        smiles_features = morgan_fingerprints

        # Ensure dose has an explicit feature dimension: (B,) -> (B, 1).
        if dose.ndim == 1:
            dose = dose.unsqueeze(-1)

        # Concatenate the three branches along the feature dimension.
        combined_features = torch.cat([dna_features, smiles_features, dose], dim=1)
|
|
prediction = self.fusion_head(combined_features) |
|
|
return prediction |
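
    # Minimal usage sketch (assumptions: pretrained Enformer weights are
    # available, and seq_len is a length the official Enformer accepts,
    # e.g. 196_608 bp):
    #
    #   model = LitEnformerSMILES().eval()
    #   dna  = torch.randint(0, 2, (2, 196_608, 4)).float()  # fake one-hot DNA
    #   fp   = torch.randint(0, 2, (2, 2048)).float()        # fake Morgan bits
    #   dose = torch.tensor([0.1, 1.0])
    #   with torch.no_grad():
    #       y_hat = model(dna, fp, dose)   # -> shape (2, 1)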
|
|
|
|
|
|
|
|
class MetricLogger(pl.Callback): |
|
|
""" |
|
|
A PyTorch Lightning Callback for comprehensive metric calculation and logging. |
|
|
|
|
|
This callback accumulates predictions and targets during validation and test epochs. |
|
|
At the end of these epochs, it: |
|
|
1. Processes the accumulated outputs into a pandas DataFrame. |
|
|
2. Saves the raw predictions and targets for the epoch to a CSV file. |
|
|
3. Logs a sample of these raw predictions as a W&B Table if WandbLogger is used. |
|
|
4. Calculates overall performance metrics (MSE, Pearson, R2) for the epoch. |
|
|
5. If `eval_gene_sets` are provided in the LightningModule, calculates metrics for these specific gene subsets. |
|
|
6. Calculates metrics per cell line if 'cell_line' information is available in the outputs. |
|
|
7. Logs all calculated metrics to the LightningModule's logger. |
|
|
|
|
|
Attributes: |
|
|
save_dir_prefix (str): Prefix for the directory where metric CSVs will be saved. |
|
|
current_epoch_data (list): List to accumulate dictionaries of pred/target/metadata per item. |
|
|
""" |
|
|
def __init__(self, save_dir_prefix="results"): |
|
|
""" |
|
|
Initializes the MetricLogger callback. |
|
|
|
|
|
Args: |
|
|
save_dir_prefix (str, optional): Directory prefix for saving metrics files. |
|
|
Defaults to "results". |
|
|
""" |
|
|
super().__init__() |
|
|
self.save_dir_prefix = save_dir_prefix |
|
|
self.current_epoch_data = [] |
|
|
|
|
|
def _process_epoch_outputs(self, pl_module, stage): |
|
|
""" |
|
|
Processes the raw outputs collected during an epoch into a pandas DataFrame. |
|
|
|
|
|
Converts tensor data for 'pred' and 'target' columns to NumPy/Python native types. |
|
|
|
|
|
Args: |
|
|
pl_module (pl.LightningModule): The LightningModule instance. |
|
|
stage (str): The current stage (e.g., "validation", "test"). |
|
|
|
|
|
Returns: |
|
|
pd.DataFrame: A DataFrame containing the processed epoch outputs. |
|
|
Returns an empty DataFrame if no outputs were collected. |
|
|
""" |
|
|
if not hasattr(pl_module, 'epoch_outputs') or not pl_module.epoch_outputs: |
|
|
warn(f"No outputs collected (pl_module.epoch_outputs is missing or empty) during {stage} epoch for MetricLogger.") |
|
|
return pd.DataFrame() |
|
|
|
|
|
df = pd.DataFrame(pl_module.epoch_outputs) |
|
|
|
|
|
for col in ['pred', 'target']: |
|
|
if col in df.columns and not df[col].empty: |
|
|
if isinstance(df[col].iloc[0], torch.Tensor): |
|
|
                    df[col] = df[col].apply(
                        lambda x: x.cpu().float().numpy().item()
                        if x.numel() == 1 else x.cpu().float().numpy()
                    )
|
|
return df |
|
|
|
|
|
def on_validation_epoch_end(self, trainer, pl_module): |
|
|
"""Hook called at the end of the validation epoch.""" |
|
|
if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs: |
|
|
self.current_epoch_data = self._process_epoch_outputs(pl_module, "validation") |
|
|
if not self.current_epoch_data.empty: |
|
|
self._log_and_save_metrics(trainer, pl_module, "validation") |
|
|
else: |
|
|
warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_validation_epoch_end.") |
|
|
|
|
|
def on_test_epoch_end(self, trainer, pl_module): |
|
|
"""Hook called at the end of the test epoch.""" |
|
|
if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs: |
|
|
self.current_epoch_data = self._process_epoch_outputs(pl_module, "test") |
|
|
if not self.current_epoch_data.empty: |
|
|
self._log_and_save_metrics(trainer, pl_module, "test") |
|
|
else: |
|
|
warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_test_epoch_end.") |
|
|
|
|
|
|
|
|
def _log_and_save_metrics(self, trainer, pl_module, stage): |
|
|
""" |
|
|
Calculates, logs, and saves metrics for the current stage and epoch. |
|
|
|
|
|
Args: |
|
|
trainer (pl.Trainer): The PyTorch Lightning Trainer instance. |
|
|
pl_module (pl.LightningModule): The LightningModule instance. |
|
|
stage (str): The current stage (e.g., "validation", "test"). |
|
|
""" |
|
|
epoch = trainer.current_epoch if trainer.current_epoch is not None else -1 |
|
|
save_dir = getattr(pl_module.hparams, 'save_dir', |
|
|
os.path.join(self.save_dir_prefix, f"run_{trainer.logger.version if trainer.logger else 'local'}")) |
|
|
os.makedirs(save_dir, exist_ok=True) |
|
|
|
|
|
raw_preds_path = os.path.join(save_dir, f"{stage}_predictions_epoch_{epoch}.csv") |
|
|
self.current_epoch_data.to_csv(raw_preds_path, index=False) |
|
|
|
|
|
if trainer.logger and hasattr(trainer.logger, 'experiment') and isinstance(trainer.logger.experiment, wandb.sdk.wandb_run.Run): |
|
|
try: |
|
|
trainer.logger.experiment.log({f"{stage}_raw_predictions_epoch_{epoch}": wandb.Table(dataframe=self.current_epoch_data.head(1000))}) |
|
|
except Exception as e: |
|
|
warn(f"MetricLogger: Failed to log raw predictions table to W&B: {e}") |
|
|
|
|
|
overall_metrics = self._calculate_metrics_for_group(self.current_epoch_data, pl_module.device) |
|
|
if overall_metrics: |
|
|
pl_module.log_dict({f"{stage}_{k}_epoch": v for k, v in overall_metrics.items()}, sync_dist=True) |
|
|
|
|
|
if hasattr(pl_module, 'eval_gene_sets') and pl_module.eval_gene_sets and isinstance(pl_module.eval_gene_sets, dict) and 'gene_id' in self.current_epoch_data.columns: |
|
|
for split_name, gene_list in pl_module.eval_gene_sets.items(): |
|
|
if not gene_list: continue |
|
|
split_df = self.current_epoch_data[self.current_epoch_data['gene_id'].isin(gene_list)] |
|
|
if not split_df.empty: |
|
|
split_metrics = self._calculate_metrics_for_group(split_df, pl_module.device) |
|
|
if split_metrics: |
|
|
pl_module.log_dict({f"{stage}_{split_name}_genes_{k}_epoch": v for k, v in split_metrics.items()}, sync_dist=True) |
|
|
|
|
|
if 'cell_line' in self.current_epoch_data.columns: |
|
|
for cell_line, group_df in self.current_epoch_data.groupby('cell_line'): |
|
|
cl_metrics = self._calculate_metrics_for_group(group_df, pl_module.device) |
|
|
if cl_metrics: |
|
|
pl_module.log_dict({f"{stage}_{cell_line}_cell_line_{k}_epoch": v for k,v in cl_metrics.items()}, sync_dist=True) |
|
|
|
|
|
|
|
|
def _calculate_metrics_for_group(self, df_group, device): |
|
|
""" |
|
|
Calculates regression metrics (MSE, Pearson, R2) for a given group of predictions. |
|
|
|
|
|
Args: |
|
|
df_group (pd.DataFrame): DataFrame containing 'pred' and 'target' columns for the group. |
|
|
device (torch.device): The device to perform calculations on. |
|
|
|
|
|
Returns: |
|
|
dict: A dictionary of calculated metrics (mse, pearson, r2). Returns empty if data is insufficient. |
|
|
""" |
|
|
if df_group.empty or 'pred' not in df_group.columns or 'target' not in df_group.columns: |
|
|
return {} |
|
|
|
|
|
preds_np = np.array(df_group['pred'].tolist(), dtype=np.float32) |
|
|
targets_np = np.array(df_group['target'].tolist(), dtype=np.float32) |
|
|
|
|
|
preds = torch.tensor(preds_np).to(device) |
|
|
targets = torch.tensor(targets_np).to(device) |
|
|
|
|
|
        # Collapse trailing singleton dims (e.g., (N, 1) -> (N,)) so that
        # predictions and targets compare as matching 1-D tensors.
        if preds.ndim > 1:
            preds = preds.squeeze(-1)
            targets = targets.squeeze(-1)
|
|
|
|
|
if preds.numel() == 0 or targets.numel() == 0 or preds.shape != targets.shape : |
|
|
warn(f"Skipping metrics calculation for a group due to mismatched or empty preds/targets. Pred shape: {preds.shape}, Target shape: {targets.shape}") |
|
|
return {} |
|
|
|
|
|
mse_val_tensor = masked_mse(preds.unsqueeze(-1) if preds.ndim==1 else preds, |
|
|
targets.unsqueeze(-1) if targets.ndim==1 else targets) |
|
|
calculated_metrics = {'mse': mse_val_tensor.item()} |
|
|
|
|
|
if preds.numel() < 2: |
|
|
warn(f"Skipping Pearson/R2 for a group with < 2 samples. Found {preds.numel()} samples. Only MSE will be reported.") |
|
|
return calculated_metrics |
|
|
|
|
|
preds_for_corr = preds.squeeze() |
|
|
targets_for_corr = targets.squeeze() |
|
|
|
|
|
        if preds_for_corr.shape != targets_for_corr.shape or (preds_for_corr.ndim > 1 and preds_for_corr.shape[1] > 1):
|
|
warn(f"Skipping Pearson/R2 due to incompatible shapes after squeeze for correlation. Pred: {preds_for_corr.shape}, Target: {targets_for_corr.shape}") |
|
|
return calculated_metrics |
|
|
|
|
|
try: |
|
|
pearson_fn = PearsonCorrCoef().to(device) |
|
|
pearson_val = pearson_fn(preds_for_corr, targets_for_corr) |
|
|
calculated_metrics['pearson'] = pearson_val.item() |
|
|
except Exception as e: |
|
|
warn(f"Could not compute Pearson Correlation: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}") |
|
|
|
|
|
try: |
|
|
r2_fn = R2Score().to(device) |
|
|
r2_val = r2_fn(preds_for_corr, targets_for_corr) |
|
|
calculated_metrics['r2'] = r2_val.item() |
|
|
except Exception as e: |
|
|
warn(f"Could not compute R2 Score: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}") |
|
|
|
|
|
return calculated_metrics |
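

# Minimal wiring sketch (assumptions: `dm` is a LightningDataModule yielding the
# 7-tuple batches described in `LitBaseModel._common_step`; all names below are
# hypothetical):
#
#   model = LitEnformerSMILES(learning_rate=5e-6,
#                             eval_gene_sets={"oncogenes": [...]})
#   trainer = pl.Trainer(max_epochs=10, callbacks=[MetricLogger("results")])
#   trainer.fit(model, datamodule=dm)
#   trainer.test(model, datamodule=dm)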