About finetuning and dataset loading

#1
by alveare - opened

Hi Author,

I’m trying to run your model and ran into two blocking issues. I’d appreciate your guidance:

  1. Import error in utils.py
    When loading the model, utils.py runs
    from pretrain.pre_utils import compute_power
    but I can't find pretrain/pre_utils.py anywhere in the repository.

  2. MAYO preprocessing format
    The MAYO preprocessing pipeline produces x_22500.npy and y_22500.npy.
    However, train1 seems to expect a different layout. It loads
    _data = np.load(os.path.join(_dir, f'{pat_file}/data.npy'))
    and
    _y = np.load(os.path.join(_dir, f'{pat_file}/label.npy')) if need_y else None,
    i.e. a separate {pat_file} folder containing data.npy and label.npy, but I cannot find any {pat_file} folders in the raw MAYO dataset.
    Could you let me know how to resolve this?
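Until the author clarifies, here is a minimal sketch of the directory layout that train1's two np.load calls imply: _dir/<pat_file>/data.npy and _dir/<pat_file>/label.npy. The helper names save_patient and load_patient, the folder name "patient_01", and the idea of splitting x_22500.npy / y_22500.npy into per-{pat_file} pieces are all assumptions; how the preprocessed arrays should actually be grouped into {pat_file} folders is exactly the open question above.

```python
import os
import tempfile
from typing import Optional, Tuple

import numpy as np


def save_patient(_dir: str, pat_file: str, data: np.ndarray, label: np.ndarray) -> None:
    """Hypothetical helper: write <_dir>/<pat_file>/data.npy and label.npy."""
    pat_dir = os.path.join(_dir, pat_file)
    os.makedirs(pat_dir, exist_ok=True)
    np.save(os.path.join(pat_dir, "data.npy"), data)
    np.save(os.path.join(pat_dir, "label.npy"), label)


def load_patient(
    _dir: str, pat_file: str, need_y: bool = True
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Hypothetical wrapper mirroring the two np.load calls quoted from train1."""
    _data = np.load(os.path.join(_dir, f"{pat_file}/data.npy"))
    _y = np.load(os.path.join(_dir, f"{pat_file}/label.npy")) if need_y else None
    return _data, _y


if __name__ == "__main__":
    # Placeholder arrays standing in for one {pat_file}'s slice of the
    # preprocessed x/y files (random data, shapes chosen arbitrarily).
    x = np.random.randn(4, 1000).astype(np.float32)
    y = np.array([0, 1, 0, 1])
    with tempfile.TemporaryDirectory() as root:
        save_patient(root, "patient_01", x, y)
        data, label = load_patient(root, "patient_01")
        print(data.shape, label.shape)  # (4, 1000) (4,)
```

If the preprocessed arrays can be split this way, train1 should be able to read them without modification; if not, the loader in train1 is the part that would need adapting.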

pretrain.pre_utils is indeed missing, so I wrote my own version of compute_power:

from typing import List, Optional, Tuple

import torch


def compute_power(
        x: torch.Tensor, fs: int = 256,
        bands: Optional[List[Tuple[int, int]]] = None
    ) -> torch.Tensor:
    """
    Compute the band power of each segment in the input tensor
    from its one-sided periodogram (|FFT|^2 / seg_len).

    Standard bands (as specified in the Brant paper):
    theta (4-8 Hz), alpha (8-13 Hz), beta (13-30 Hz), gamma1 (30-50 Hz),
    gamma2 (50-70 Hz), gamma3 (70-90 Hz), gamma4 (90-110 Hz), gamma5 (110-128 Hz).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (..., seg_len), where the last dimension
        is the time axis to be converted to band power.
    fs : int, optional
        Sampling frequency in Hz, by default 256.
    bands : Optional[List[Tuple[int, int]]], optional
        List of frequency bands as half-open [low, high) tuples in Hz.
        If None, defaults to the standard bands above.

    Returns
    -------
    torch.Tensor
        Power tensor of shape (..., band_num), where band_num is the
        number of bands (8 for the standard bands).
    """
    if bands is None:
        bands = [(4, 8), (8, 13), (13, 30), (30, 50), (50, 70), (70, 90), (90, 110), (110, 128)]

    band_num = len(bands)
    seg_len = x.shape[-1]

    # reshape (not view) also works for non-contiguous inputs
    x_flat = x.reshape(-1, seg_len)

    # rfft keeps only the non-negative frequencies (seg_len // 2 + 1 bins) and
    # rfftfreq labels them 0 ... fs / 2; fftfreq would instead label the
    # Nyquist bin of an even-length segment as -fs / 2.
    power_pos = torch.abs(torch.fft.rfft(x_flat, dim=-1)) ** 2 / seg_len
    freqs_pos = torch.fft.rfftfreq(seg_len, d=1 / fs, device=x.device)

    # Sum the power inside each [low, high) band
    band_powers = []
    for low, high in bands:
        mask = (freqs_pos >= low) & (freqs_pos < high)
        band_powers.append(power_pos[:, mask].sum(dim=-1))

    power_tensor = torch.stack(band_powers, dim=-1)
    # Restore the leading dimensions: (..., band_num)
    return power_tensor.reshape(*x.shape[:-1], band_num)
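To sanity-check compute_power (or the official pretrain.pre_utils version, if the author publishes it), here is a NumPy reference sketch of the same quantity. band_power_np and DEFAULT_BANDS are hypothetical names, and the scaling (|FFT|^2 / seg_len, one-sided, no doubling of non-DC bins) plus the half-open [low, high) bands are copied from the torch code above, not from any confirmed definition in Brant.

```python
from typing import List, Optional, Tuple

import numpy as np

# Same default bands as compute_power above (Hz, half-open [low, high))
DEFAULT_BANDS = [(4, 8), (8, 13), (13, 30), (30, 50), (50, 70), (70, 90), (90, 110), (110, 128)]


def band_power_np(
    x: np.ndarray, fs: int = 256, bands: Optional[List[Tuple[int, int]]] = None
) -> np.ndarray:
    """Hypothetical NumPy reference: periodogram band power, shape (..., n_bands)."""
    if bands is None:
        bands = DEFAULT_BANDS
    seg_len = x.shape[-1]
    # One-sided spectrum with the same |FFT|^2 / seg_len scaling as the torch code
    power = np.abs(np.fft.rfft(x, axis=-1)) ** 2 / seg_len
    freqs = np.fft.rfftfreq(seg_len, d=1.0 / fs)
    # Sum the bins of each band along the frequency axis
    return np.stack(
        [power[..., (freqs >= lo) & (freqs < hi)].sum(axis=-1) for lo, hi in bands],
        axis=-1,
    )


if __name__ == "__main__":
    fs = 256
    t = np.arange(fs) / fs              # 1 s of samples -> 1 Hz bin spacing
    x = np.sin(2 * np.pi * 10 * t)      # pure 10 Hz tone, inside the 8-13 Hz band
    p = band_power_np(x, fs)
    print(p.shape, int(np.argmax(p)))   # (8,) 1
```

With the default bands, band_power_np(x) should agree with compute_power(torch.from_numpy(x)).numpy() up to floating-point error. For a unit-amplitude 10 Hz tone over one second at 256 Hz, essentially all of the power (64 under this scaling) lands in the alpha band, index 1.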
