About finetuning and dataset loading
#1 opened by alveare
Hi Author,
I’m trying to run your model and ran into two blocking issues. I’d appreciate your guidance:
Import error in utils.py
When loading the model, `utils.py` contains

```python
from pretrain.pre_utils import compute_power
```

but I can't find `pretrain/pre_utils.py` anywhere in the repository.

MAYO preprocessing format
The MAYO preprocessing pipeline produces `x_22500.npy` and `y_22500.npy`. However, `train1` seems to load data differently. It expects one folder per patient:

```python
_data = np.load(os.path.join(_dir, f'{pat_file}/data.npy'))
_y = np.load(os.path.join(_dir, f'{pat_file}/label.npy')) if need_y else None
```

but I cannot find any `{pat_file}` folders in the raw MAYO dataset.
Could you let me know how I can resolve this issue?
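Until the intended layout is confirmed, one possible bridge is to split the preprocessed arrays into the `{pat_file}/data.npy` and `{pat_file}/label.npy` layout that `train1` reads. This is only a sketch under assumptions: `write_patient_folders` is a hypothetical helper, not part of the repository, and it assumes you can supply a `patient_ids` array naming the patient of each sample. It is not clear whether `x_22500.npy` preserves that information. Only the two file names come from `train1`.

```python
import os
from typing import List, Optional

import numpy as np


def write_patient_folders(
    x: np.ndarray,
    y: Optional[np.ndarray],
    patient_ids,
    out_dir: str,
) -> List[str]:
    """HYPOTHETICAL helper: write <out_dir>/<pat_file>/data.npy (and label.npy).

    ASSUMPTION: patient_ids[i] identifies the patient of sample x[i]; the
    MAYO preprocessing output may not provide such an array.
    """
    patient_ids = np.asarray(patient_ids)
    if len(patient_ids) != len(x):
        raise ValueError("patient_ids must have one entry per sample in x")
    if y is not None and len(y) != len(x):
        raise ValueError("y must have one entry per sample in x")

    written = []
    for pid in np.unique(patient_ids):  # sorted unique patient ids
        pat_file = str(pid)
        mask = patient_ids == pid  # samples belonging to this patient, in order
        pat_dir = os.path.join(out_dir, pat_file)
        os.makedirs(pat_dir, exist_ok=True)
        np.save(os.path.join(pat_dir, "data.npy"), x[mask])
        if y is not None:
            np.save(os.path.join(pat_dir, "label.npy"), y[mask])
        written.append(pat_file)
    return written
```

Usage would be something like `write_patient_folders(np.load("x_22500.npy"), np.load("y_22500.npy"), patient_ids, _dir)`. Whether the resulting per-patient arrays have the shape `train1` expects still depends on the preprocessing, which is what the question above is asking about.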
`pretrain.pre_utils` is indeed missing, so I wrote my own version of `compute_power`:
```python
from typing import List, Optional, Tuple

import torch


def compute_power(
    x: torch.Tensor,
    fs: int = 256,
    bands: Optional[List[Tuple[int, int]]] = None,
) -> torch.Tensor:
    """
    Compute the band power of each segment in the input tensor.

    The power spectrum is |FFT(x)|^2 / seg_len over the non-negative
    frequencies; the power of a band (low, high) is the sum over the bins
    whose frequency f satisfies low <= f < high.

    Standard bands (as specified in the Brant paper):
    theta (4-8 Hz), alpha (8-13 Hz), beta (13-30 Hz), gamma1 (30-50 Hz),
    gamma2 (50-70 Hz), gamma3 (70-90 Hz), gamma4 (90-110 Hz), gamma5 (110-128 Hz).

    Parameters
    ----------
    x : torch.Tensor
        Real-valued input tensor of shape (..., seg_len), where the last
        dimension is the temporal dimension to be converted to band power.
    fs : int, optional
        Sampling frequency in Hz, by default 256.
    bands : Optional[List[Tuple[int, int]]], optional
        List of frequency bands as (low, high) tuples.
        If None, defaults to the standard bands.

    Returns
    -------
    torch.Tensor
        Power tensor of shape (..., band_num), where band_num is the number
        of bands (8 for the standard bands).
    """
    if bands is None:
        bands = [(4, 8), (8, 13), (13, 30), (30, 50), (50, 70), (70, 90), (90, 110), (110, 128)]
    band_num = len(bands)
    seg_len = x.shape[-1]
    # reshape (unlike view) also works on non-contiguous inputs
    x_flat = x.reshape(-1, seg_len)
    # rfft keeps only the non-negative frequencies (seg_len // 2 + 1 bins)
    fft = torch.fft.rfft(x_flat, dim=-1)
    power_pos = torch.abs(fft) ** 2 / seg_len
    freqs_pos = torch.fft.rfftfreq(seg_len, d=1 / fs, device=x.device)
    # Sum the power of the bins that fall inside each band
    band_powers = []
    for low, high in bands:
        mask = (freqs_pos >= low) & (freqs_pos < high)
        band_powers.append(power_pos[:, mask].sum(dim=-1))
    power_tensor = torch.stack(band_powers, dim=-1)
    # Restore the leading dimensions: (..., band_num)
    return power_tensor.reshape(*x.shape[:-1], band_num)
```
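One thing worth checking with an FFT-based band power is how many frequency bins actually land in each band, since the bin spacing is fs / seg_len. The sketch below assumes the default fs = 256 Hz and the default bands from `compute_power`; `bins_per_band` is a hypothetical helper, not part of the repository. It uses NumPy's `rfftfreq`, which yields the same non-negative bin frequencies as `torch.fft.rfftfreq`.

```python
from typing import List, Optional, Tuple

import numpy as np

# Default bands copied from the compute_power docstring (fs = 256 Hz assumed)
DEFAULT_BANDS = [(4, 8), (8, 13), (13, 30), (30, 50), (50, 70), (70, 90), (90, 110), (110, 128)]


def bins_per_band(
    seg_len: int,
    fs: int = 256,
    bands: Optional[List[Tuple[int, int]]] = None,
) -> List[int]:
    """Count the one-sided FFT bins whose frequency f satisfies low <= f < high."""
    if bands is None:
        bands = DEFAULT_BANDS
    # Bin frequencies 0, fs/seg_len, 2*fs/seg_len, ..., up to fs/2
    freqs = np.fft.rfftfreq(seg_len, d=1.0 / fs)
    return [int(np.count_nonzero((freqs >= low) & (freqs < high))) for low, high in bands]


for seg_len in (32, 256):
    print(seg_len, bins_per_band(seg_len))
# 32 [0, 1, 2, 3, 2, 3, 2, 2]
# 256 [4, 5, 17, 20, 20, 20, 20, 18]
```

With 32-sample segments (8 Hz resolution), no bin falls in the theta band, so its power would always be 0. One-second segments (256 samples) give every band at least four bins.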