# pl_models.py
"""
This module defines PyTorch Lightning modules for the Tahoeformer project.
It includes a base model class (`LitBaseModel`) and the main experimental model
(`LitEnformerSMILES`) which combines an Enformer-based DNA sequence model with
drug information (SMILES string processed into Morgan Fingerprints) and dose information 
to predict gene expression.

Key components:
- masked_mse: A utility loss function for Mean Squared Error that handles NaN targets.
- LitBaseModel: A base LightningModule providing common training, validation, test steps,
  optimizer configuration, and basic metric logging hooks.
- LitEnformerSMILES: The primary model for predicting drug-induced gene expression changes,
  using Enformer for DNA and Morgan fingerprints for drugs.
- MetricLogger: A PyTorch Lightning Callback for detailed logging of predictions.
"""

import pandas as pd
import os
import torch
import torch.nn as nn
import lightning.pytorch as pl
from enformer_pytorch.finetune import HeadAdapterWrapper
from enformer_pytorch import Enformer
from torchmetrics.regression import PearsonCorrCoef, R2Score
from warnings import warn
import wandb
import numpy as np  # used by MetricLogger when assembling prediction arrays

# --- Utility Functions ---
def masked_mse(y_hat, y):
    """
    Computes Mean Squared Error (MSE) while ignoring NaN values in the target tensor.

    Args:
        y_hat (torch.Tensor): The predicted values.
        y (torch.Tensor): The target values, which may contain NaNs.

    Returns:
        torch.Tensor: A scalar tensor representing the masked MSE. Returns 0.0 if all targets are NaN.
    """
    mask = torch.isnan(y)
    if mask.all():
        # All targets in the batch are NaN: return a zero loss that still
        # participates in autograd so .backward() does not error.
        return torch.tensor(0.0, device=y_hat.device, requires_grad=True)
    mse = torch.mean((y[~mask] - y_hat[~mask])**2)
    return mse
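
# Worked example (illustrative values): with y_hat = [1., 2., 3.] and
# y = [1., nan, 5.], the NaN position is dropped and the loss is the mean of
# (1-1)^2 and (5-3)^2, i.e. 2.0:
#   masked_mse(torch.tensor([1., 2., 3.]), torch.tensor([1., float('nan'), 5.]))
#   # -> tensor(2.)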

# --- Base Lightning Module ---
class LitBaseModel(pl.LightningModule):
    """
    A base PyTorch Lightning module providing common boilerplate for training and evaluation.

    This class implements a generic training/validation/test step, loss calculation using
    `masked_mse`, optimizer configuration (AdamW), and hooks for accumulating outputs
    for detailed metric logging via the `MetricLogger` callback.

    Derived classes are expected to implement the `forward` method.

    Hyperparameters:
        learning_rate (float): The learning rate for the optimizer.
        loss_alpha (float): A coefficient for the primary loss term (MSE). Useful if 
                            additional loss terms were to be added.
        weight_decay (float, optional): Weight decay for the AdamW optimizer. If None,
                                        AdamW's internal default is used.
        eval_gene_sets (dict, optional): A dictionary where keys are set names (e.g., 'oncogenes')
                                         and values are lists of gene IDs. Used by `MetricLogger`
                                         to compute metrics for specific gene subsets.
    """
    def __init__(self, learning_rate=5e-6, loss_alpha=1.0, weight_decay=None,
                 eval_gene_sets=None):  # eval_gene_sets: dict mapping set names (e.g., 'oncogenes') to lists of gene IDs
        """
        Initializes the LitBaseModel.

        Args:
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Alpha for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay for AdamW. If None, uses optimizer default.
                                          Defaults to None.
            eval_gene_sets (dict, optional): Dictionary of gene sets for targeted evaluation.
                                           Keys are names, values are lists of gene IDs.
                                           Defaults to None.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.loss_alpha = loss_alpha # alpha for mse vs. other terms (if any)
        self.weight_decay = weight_decay
        self.eval_gene_sets = eval_gene_sets if eval_gene_sets else {}
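        # Expected structure (hypothetical IDs): set names mapped to gene IDs as
        # they appear in the dataset's gene_id field, e.g.
        #   {'oncogenes': ['ENSG...A', 'ENSG...B'], 'housekeeping': ['ENSG...C']}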

        # Results accumulated per epoch for MetricLogger
        self.epoch_outputs = []

    def loss_fn(self, y_hat, y):
        """
        Calculates the loss for the model.

        Currently uses `masked_mse` scaled by `self.loss_alpha`.

        Args:
            y_hat (torch.Tensor): Predicted values from the model.
            y (torch.Tensor): Ground truth target values.

        Returns:
            torch.Tensor: The computed loss value.
        """
        mse_term = masked_mse(y_hat, y)
        # Potentially: add other loss terms here, weighted by (1-loss_alpha) if desired
        return self.loss_alpha * mse_term

    def _common_step(self, batch, batch_idx, step_type):
        """
        A common step for training, validation, and testing.

        This method unpacks the batch, performs a forward pass, calculates the loss,
        logs the loss, and accumulates outputs for epoch-level metric calculation
        (for validation and test steps).

        Args:
            batch: The batch of data from the DataLoader. Expected to be a tuple containing
                   DNA sequence, Morgan fingerprints, dose, target expression, 
                   and metadata (gene_id, drug_id, cell_line).
            batch_idx (int): The index of the current batch.
            step_type (str): A string indicating the type of step ('train', 'val', or 'test').

        Returns:
            torch.Tensor: The loss for the current batch.
        """
        # Expected batch structure:
        # (dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line)
        dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line = batch
        
        y_hat = self(dna_seq, morgan_fingerprints, dose) # Call forward method of derived class

        loss = self.loss_fn(y_hat, target_expression)
        # Epoch-level logging only; show the loss in the progress bar for val/test.
        self.log(f'{step_type}_loss', loss, batch_size=target_expression.shape[0],
                 on_step=False, on_epoch=True, prog_bar=(step_type != 'train'))
        
        if step_type != 'train':
            # Prepare data for MetricLogger
            batch_size = target_expression.shape[0]
            for i in range(batch_size):
                item_data = {
                    'pred': y_hat[i].detach(),
                    'target': target_expression[i].detach(),
                    'gene_id': gene_id[i], 
                    'drug_id': drug_id[i], 
                    'cell_line': cell_line[i], 
                    'rank': self.trainer.global_rank 
                }
                self.epoch_outputs.append(item_data)
        return loss

    def training_step(self, batch, batch_idx):
        """PyTorch Lightning training step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning validation step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning test step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'test')
    
    def on_validation_epoch_start(self):
        """Clears accumulated outputs at the start of each validation epoch."""
        self.epoch_outputs = []
    
    def on_test_epoch_start(self):
        """Clears accumulated outputs at the start of each test epoch."""
        self.epoch_outputs = []

    def configure_optimizers(self):
        """
        Configures the optimizer for the model.

        Uses AdamW with the specified learning rate and weight decay.

        Returns:
            torch.optim.Optimizer: The configured AdamW optimizer.
        """
        if self.weight_decay is None:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer

# --- Enformer + Morgan Fingerprints Model ---
class LitEnformerSMILES(LitBaseModel): # Consider renaming to LitEnformerMorgan for clarity
    """
    A PyTorch Lightning module that combines genomic sequence information (via Enformer)
    with drug chemical structure (represented by Morgan fingerprints) and drug dose
    to predict gene expression changes.

    The model architecture consists of three main branches:
    1. DNA Module: Uses a pre-trained Enformer model (with an adapted head) to extract
       features from a one-hot encoded DNA sequence centered around a gene's TSS.
    2. Drug Module: Uses pre-computed Morgan fingerprints as the drug representation.
    3. Dose Module: Directly uses the numerical dose value.

    Features from these three branches are concatenated and passed through a multi-layer
    fusion head (MLP with ReLU, BatchNorm, Dropout) to produce the final prediction
    of gene expression.

    Inherits common training and evaluation logic from `LitBaseModel`.
    """
    def __init__(self,
                 enformer_model_name: str = 'EleutherAI/enformer-official-rough',
                 enformer_target_length: int = -1,
                 num_output_tracks_enformer_head: int = 1, 
                 morgan_fingerprint_dim: int = 2048, # dim of the Morgan fingerprint vector
                 dose_input_dim: int = 1, 
                 fusion_hidden_dim: int = 256, 
                 final_output_tracks: int = 1, 
                 learning_rate=5e-6, 
                 loss_alpha=1.0, 
                 weight_decay=None,
                 eval_gene_sets=None):
        """
        Initializes the LitEnformerSMILES (or LitEnformerMorgan) model.

        Args:
            enformer_model_name (str, optional): Name or path of the pre-trained Enformer model.
            enformer_target_length (int, optional): Target length for Enformer's internal pooling.
            num_output_tracks_enformer_head (int, optional): Output features from Enformer head.
            morgan_fingerprint_dim (int, optional): Dimensionality of the Morgan fingerprint vector
                                                   (e.g., 2048 for ECFP4). Defaults to 2048.
            dose_input_dim (int, optional): Dimensionality of the drug dose input. Defaults to 1.
            fusion_hidden_dim (int, optional): Hidden dimension for the fusion MLP. Defaults to 256.
            final_output_tracks (int, optional): Number of final output values. Defaults to 1.
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Weight for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay. Defaults to None.
            eval_gene_sets (dict, optional): Gene sets for targeted evaluation. Defaults to None.
        """
        super().__init__(learning_rate, loss_alpha, weight_decay, eval_gene_sets)
        self.save_hyperparameters(
            "enformer_model_name", "enformer_target_length", 
            "num_output_tracks_enformer_head", "morgan_fingerprint_dim",
            "dose_input_dim", "fusion_hidden_dim", "final_output_tracks",
            "learning_rate", "loss_alpha", "weight_decay"
        )

        # 1. DNA Module (Enformer with HeadAdapter)
        enformer_pretrained = Enformer.from_pretrained(
            self.hparams.enformer_model_name,
            target_length=self.hparams.enformer_target_length 
        )
        self.dna_module = HeadAdapterWrapper(
            enformer=enformer_pretrained,
            num_tracks=self.hparams.num_output_tracks_enformer_head,
            post_transformer_embed=False, 
            output_activation=nn.Identity()
        )

        # 2. Drug Module (Morgan Fingerprints are provided as input directly)
        # No layers needed here as fingerprints are pre-computed.
        # The self.hparams.morgan_fingerprint_dim defines the expected input dimension.

        # 3. Fusion Head 
        # Input dimension uses morgan_fingerprint_dim
        fusion_input_dim = self.hparams.num_output_tracks_enformer_head + self.hparams.morgan_fingerprint_dim + self.hparams.dose_input_dim
        self.fusion_head = nn.Sequential(
            nn.Linear(fusion_input_dim, self.hparams.fusion_hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim),
            nn.Dropout(0.25),
            nn.Linear(self.hparams.fusion_hidden_dim, self.hparams.fusion_hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim // 2),
            nn.Dropout(0.1),
            nn.Linear(self.hparams.fusion_hidden_dim // 2, self.hparams.final_output_tracks)
        )
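        # With the default hyperparameters the fused vector has
        # 1 (Enformer head) + 2048 (Morgan) + 1 (dose) = 2050 dims, so the head
        # is Linear(2050 -> 256) -> Linear(256 -> 128) -> Linear(128 -> 1).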

    def forward(self, dna_seq, morgan_fingerprints, dose):
        """
        Defines the forward pass of the LitEnformerSMILES model using Morgan Fingerprints.

        Args:
            dna_seq (torch.Tensor): Batch of one-hot encoded DNA sequences.
                                    Shape: (batch_size, sequence_length, 4).
            morgan_fingerprints (torch.Tensor): Batch of pre-computed Morgan fingerprint vectors.
                                                Shape: (batch_size, morgan_fingerprint_dim).
            dose (torch.Tensor): Batch of drug dose values.
                                 Shape: (batch_size, dose_input_dim), or (batch_size,),
                                 which is unsqueezed to (batch_size, 1).

        Returns:
            torch.Tensor: The model's prediction. Shape: (batch_size, final_output_tracks).
        """
        # --- DNA Processing ---
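        # The Enformer head returns (batch, n_bins, num_output_tracks); the
        # center bin of the TSS-centered window is taken as the DNA feature.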
        dna_out_intermediate = self.dna_module(dna_seq, freeze_enformer=False)
        center_seq_idx = dna_out_intermediate.shape[1] // 2 
        dna_features = dna_out_intermediate[:, center_seq_idx, :] 
        
        # --- Drug Processing (Morgan Fingerprints) ---
        # Pre-computed Morgan fingerprints are used directly as drug features.
        drug_features = morgan_fingerprints  # (batch_size, morgan_fingerprint_dim)

        # --- Dose Processing ---
        if dose.ndim == 1:
            dose = dose.unsqueeze(-1)
        
        # --- Feature Combination & Final Prediction ---
        combined_features = torch.cat([dna_features, drug_features, dose], dim=1)
        prediction = self.fusion_head(combined_features)
        return prediction

# --- Metrics Logging Callback ---
class MetricLogger(pl.Callback):
    """
    A PyTorch Lightning Callback for comprehensive metric calculation and logging.

    This callback accumulates predictions and targets during validation and test epochs.
    At the end of these epochs, it:
    1. Processes the accumulated outputs into a pandas DataFrame.
    2. Saves the raw predictions and targets for the epoch to a CSV file.
    3. Logs a sample of these raw predictions as a W&B Table if WandbLogger is used.
    4. Calculates overall performance metrics (MSE, Pearson, R2) for the epoch.
    5. If `eval_gene_sets` are provided in the LightningModule, calculates metrics for these specific gene subsets.
    6. Calculates metrics per cell line if 'cell_line' information is available in the outputs.
    7. Logs all calculated metrics to the LightningModule's logger.

    Attributes:
        save_dir_prefix (str): Prefix for the directory where metric CSVs will be saved.
        current_epoch_data: Initialized as an empty list; replaced each epoch by a
            pandas DataFrame of per-item predictions, targets, and metadata.
    """
    def __init__(self, save_dir_prefix="results"):
        """
        Initializes the MetricLogger callback.

        Args:
            save_dir_prefix (str, optional): Directory prefix for saving metrics files.
                                           Defaults to "results".
        """
        super().__init__()
        self.save_dir_prefix = save_dir_prefix
        self.current_epoch_data = [] 

    def _process_epoch_outputs(self, pl_module, stage):
        """
        Processes the raw outputs collected during an epoch into a pandas DataFrame.

        Converts tensor data for 'pred' and 'target' columns to NumPy/Python native types.

        Args:
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").

        Returns:
            pd.DataFrame: A DataFrame containing the processed epoch outputs.
                          Returns an empty DataFrame if no outputs were collected.
        """
        if not hasattr(pl_module, 'epoch_outputs') or not pl_module.epoch_outputs:
            warn(f"No outputs collected (pl_module.epoch_outputs is missing or empty) during {stage} epoch for MetricLogger.")
            return pd.DataFrame()

        df = pd.DataFrame(pl_module.epoch_outputs)
        
        for col in ['pred', 'target']:
            if col in df.columns and not df[col].empty:
                if isinstance(df[col].iloc[0], torch.Tensor):
                    df[col] = df[col].apply(lambda x: x.cpu().float().numpy().item() if x.numel() == 1 else x.cpu().float().numpy())
        return df

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the validation epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "validation")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "validation")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_validation_epoch_end.")

    def on_test_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the test epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "test")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "test")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_test_epoch_end.")


    def _log_and_save_metrics(self, trainer, pl_module, stage):
        """
        Calculates, logs, and saves metrics for the current stage and epoch.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").
        """
        epoch = trainer.current_epoch if trainer.current_epoch is not None else -1 
        save_dir = getattr(pl_module.hparams, 'save_dir', 
                           os.path.join(self.save_dir_prefix, f"run_{trainer.logger.version if trainer.logger else 'local'}"))
        os.makedirs(save_dir, exist_ok=True)

        raw_preds_path = os.path.join(save_dir, f"{stage}_predictions_epoch_{epoch}.csv")
        self.current_epoch_data.to_csv(raw_preds_path, index=False)
        
        if trainer.logger and hasattr(trainer.logger, 'experiment') and isinstance(trainer.logger.experiment, wandb.sdk.wandb_run.Run):
            try:
                trainer.logger.experiment.log({f"{stage}_raw_predictions_epoch_{epoch}": wandb.Table(dataframe=self.current_epoch_data.head(1000))})
            except Exception as e:
                warn(f"MetricLogger: Failed to log raw predictions table to W&B: {e}")

        overall_metrics = self._calculate_metrics_for_group(self.current_epoch_data, pl_module.device)
        if overall_metrics:
            pl_module.log_dict({f"{stage}_{k}_epoch": v for k, v in overall_metrics.items()}, sync_dist=True)

        if hasattr(pl_module, 'eval_gene_sets') and pl_module.eval_gene_sets and isinstance(pl_module.eval_gene_sets, dict) and 'gene_id' in self.current_epoch_data.columns:
            for split_name, gene_list in pl_module.eval_gene_sets.items():
                if not gene_list: continue
                split_df = self.current_epoch_data[self.current_epoch_data['gene_id'].isin(gene_list)]
                if not split_df.empty:
                    split_metrics = self._calculate_metrics_for_group(split_df, pl_module.device)
                    if split_metrics:
                        pl_module.log_dict({f"{stage}_{split_name}_genes_{k}_epoch": v for k, v in split_metrics.items()}, sync_dist=True)
        
        if 'cell_line' in self.current_epoch_data.columns:
            for cell_line, group_df in self.current_epoch_data.groupby('cell_line'):
                cl_metrics = self._calculate_metrics_for_group(group_df, pl_module.device)
                if cl_metrics:
                     pl_module.log_dict({f"{stage}_{cell_line}_cell_line_{k}_epoch": v for k,v in cl_metrics.items()}, sync_dist=True)


    def _calculate_metrics_for_group(self, df_group, device):
        """
        Calculates regression metrics (MSE, Pearson, R2) for a given group of predictions.

        Args:
            df_group (pd.DataFrame): DataFrame containing 'pred' and 'target' columns for the group.
            device (torch.device): The device to perform calculations on.

        Returns:
            dict: A dictionary of calculated metrics (mse, pearson, r2). Returns empty if data is insufficient.
        """
        if df_group.empty or 'pred' not in df_group.columns or 'target' not in df_group.columns:
            return {}

        preds_np = np.array(df_group['pred'].tolist(), dtype=np.float32)
        targets_np = np.array(df_group['target'].tolist(), dtype=np.float32)

        preds = torch.tensor(preds_np).to(device)
        targets = torch.tensor(targets_np).to(device)

        # Flatten single-track (N, 1) outputs to (N,). Squeezing a 1-D tensor
        # would collapse a single-sample group to a 0-d tensor and break the
        # shape checks below, so only tensors with a trailing dim are squeezed.
        if preds.ndim > 1:
            preds = preds.squeeze(-1)
            targets = targets.squeeze(-1)
        
        if preds.numel() == 0 or targets.numel() == 0 or preds.shape != targets.shape :
            warn(f"Skipping metrics calculation for a group due to mismatched or empty preds/targets. Pred shape: {preds.shape}, Target shape: {targets.shape}")
            return {}
        
        mse_val_tensor = masked_mse(preds, targets)  # masked_mse is shape-agnostic
        calculated_metrics = {'mse': mse_val_tensor.item()}

        if preds.numel() < 2: 
             warn(f"Skipping Pearson/R2 for a group with < 2 samples. Found {preds.numel()} samples. Only MSE will be reported.")
             return calculated_metrics 

        preds_for_corr = preds.squeeze()
        targets_for_corr = targets.squeeze()

        if preds_for_corr.shape != targets_for_corr.shape or (preds_for_corr.ndim > 1 and preds_for_corr.shape[1] > 1):
            warn(f"Skipping Pearson/R2 due to incompatible shapes after squeeze for correlation. Pred: {preds_for_corr.shape}, Target: {targets_for_corr.shape}")
            return calculated_metrics

        try:
            pearson_fn = PearsonCorrCoef().to(device) 
            pearson_val = pearson_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['pearson'] = pearson_val.item()
        except Exception as e:
            warn(f"Could not compute Pearson Correlation: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        try:
            r2_fn = R2Score().to(device) 
            r2_val = r2_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['r2'] = r2_val.item()
        except Exception as e:
            warn(f"Could not compute R2 Score: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        return calculated_metrics
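

if __name__ == "__main__":
    # Minimal smoke-test sketch, not a unit test. Assumptions: the pre-trained
    # Enformer weights download on first use, and 196,608 bp matches the
    # official Enformer input window; adjust the window length and dose values
    # to your dataset.
    model = LitEnformerSMILES()
    model.eval()  # BatchNorm/Dropout in eval mode for a deterministic pass
    with torch.no_grad():
        ids = torch.randint(0, 4, (2, 196_608))                         # random A/C/G/T indices
        dna = nn.functional.one_hot(ids, num_classes=4).float()         # (B, L, 4) one-hot
        fps = torch.randint(0, 2, (2, 2048)).float()                    # binary Morgan fingerprints
        dose = torch.tensor([[0.1], [1.0]])                             # (B, 1) dose values
        out = model(dna, fps, dose)
    print(out.shape)  # expected: torch.Size([2, 1])
    # In training, MetricLogger is attached to the Trainer, e.g.:
    #   trainer = pl.Trainer(max_epochs=1, callbacks=[MetricLogger("results")])
    #   trainer.fit(model, datamodule=...)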