import logging
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


class MIMICECGDataset(Dataset):
    """PyTorch Dataset for MIMIC-IV-ECG.

    Each item is a dict with three keys:

    * ``'signal'``: float32 tensor of shape ``(n_leads, n_samples)`` in mV,
    * ``'labels'``: output of ``label_func`` if one was given, otherwise the
      multi-hot tensor produced by :meth:`get_labels`,
    * ``'study_id'``: the study identifier as a string.
    """

    # (keyword, class index) pairs used by ``get_labels``. Keywords are
    # matched as lower-case substrings of the concatenated report text.
    _LABEL_KEYWORDS = (
        ('sinus rhythm', 0),
        ('atrial fibrillation', 1),
        ('sinus tachycardia', 2),
        ('sinus bradycardia', 3),
        ('ventricular tachycardia', 4),
    )

    def __init__(self, df, data_dir, transform=False, label_func=None):
        """
        Args:
            df (pd.DataFrame): Dataframe with subject_id, study_id, and
                report columns (any column whose name contains 'report_').
            data_dir (str): Root directory of the dataset (containing the
                'files' folder).
            transform (callable, optional): Transform applied to the sample
                dict before it is returned. Any non-callable value
                (including the default ``False``) disables it.
            label_func (callable, optional): Custom function that receives
                the concatenated report text of a row and returns its labels.
        """
        super().__init__()
        self.df = df
        self.data_dir = data_dir
        self.transform = transform
        self.label_func = label_func

        # MIMIC-ECG constants
        self.n_leads = 12
        self.n_samples = 5000
        self.fs = 500
        self.gain = 200.0  # Standard gain in ADU/mV

        # Report phrases mapped to class indices. Kept as a public attribute
        # for backward compatibility; ``get_labels`` matches on
        # ``_LABEL_KEYWORDS``, which omits the non-diagnostic
        # 'Normally filtered' entry.
        self.class_mapping = {
            'Normally filtered': 0,  # Not a diagnosis, but often present
            'Sinus rhythm': 0,
            'Atrial fibrillation': 1,
            'Sinus tachycardia': 2,
            'Sinus bradycardia': 3,
            'Ventricular tachycardia': 4,
            # Add more as needed
        }
        self.num_classes = 5  # For now

    def __len__(self):
        """Return the number of rows (studies) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the signal and labels for the row at position ``idx``.

        Returns:
            dict: ``{'signal', 'labels', 'study_id'}``, passed through
            ``self.transform`` when that attribute is callable.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        subj_id = self._format_id(row['subject_id'])
        study_id = self._format_id(row['study_id'])

        # Path layout: files/p{first 4 chars of subject_id}/p{subject_id}/
        #              s{study_id}/{study_id}.dat
        # e.g. data_dir/files/p1000/p10000032/s40689238/40689238.dat
        file_path = os.path.join(
            self.data_dir,
            'files',
            f"p{subj_id[:4]}",
            f"p{subj_id}",
            f"s{study_id}",
            f"{study_id}.dat",
        )

        # 1. Load signal, shape (n_leads, n_samples)
        signal = self.load_signal_numpy(file_path)

        # 2. Get labels. label_func receives plain text so it stays a pure
        #    function that is independent of the dataframe layout.
        if self.label_func:
            labels = self.label_func(self._report_text(row))
        else:
            labels = self.get_labels(row)

        # 3. Assemble the sample
        sample = {
            'signal': signal,
            'labels': labels,
            'study_id': study_id,
        }

        # The constructor documents ``transform`` as applied to the sample;
        # the callable() guard keeps the default ``False`` a no-op.
        if callable(self.transform):
            sample = self.transform(sample)

        return sample

    @staticmethod
    def _format_id(value):
        """Render an identifier as text, dropping a spurious '.0' suffix.

        ``df.iloc[idx]`` upcasts integer ids to float when every column is
        numeric and at least one is float; ``str(10000032.0)`` would then
        produce a path that does not exist.
        """
        if isinstance(value, (float, np.floating)) and float(value).is_integer():
            return str(int(value))
        return str(value)

    def _report_text(self, row):
        """Join the non-missing 'report_*' cells of ``row`` with spaces."""
        parts = []
        for col in self.df.columns:
            if not (isinstance(col, str) and 'report_' in col):
                continue
            value = row[col]
            # Skip missing cells so 'nan'/'None' never leak into the text.
            if pd.api.types.is_scalar(value) and pd.isna(value):
                continue
            parts.append(str(value))
        return ' '.join(parts)

    def load_signal_numpy(self, path):
        """Read a binary .dat file using numpy.

        The file is interpreted as little-endian 16-bit integers, interleaved
        by sample (s1L1, s1L2 ... s1L12, s2L1 ...), and divided by
        ``self.gain`` to convert to mV. Files with the wrong number of values
        are zero-padded or truncated to ``n_leads * n_samples``.

        Returns:
            torch.Tensor: float32 tensor of shape ``(n_leads, n_samples)``.
            An all-zero tensor is returned (and a warning is logged) when the
            file is missing or unreadable, so that a single bad record does
            not abort a training loop.
        """
        expected_size = self.n_leads * self.n_samples

        try:
            # '<i2' rather than np.int16: the on-disk byte order is fixed,
            # whereas np.int16 follows the host byte order.
            raw_data = np.fromfile(path, dtype='<i2')
        except (OSError, ValueError) as exc:
            logger.warning(
                "Could not read ECG signal %s (%s); returning zeros", path, exc
            )
            return torch.zeros(
                (self.n_leads, self.n_samples), dtype=torch.float32
            )

        if raw_data.size != expected_size:
            logger.warning(
                "ECG file %s holds %d values, expected %d; padding/truncating",
                path, raw_data.size, expected_size,
            )
            fixed = np.zeros(expected_size, dtype=raw_data.dtype)
            usable = min(raw_data.size, expected_size)
            fixed[:usable] = raw_data[:usable]
            raw_data = fixed

        # (samples, leads) -> (leads, samples); the copy makes the result
        # C-contiguous and converts to the host's native float32.
        signal = raw_data.reshape((self.n_samples, self.n_leads)).T
        signal = np.ascontiguousarray(signal, dtype=np.float32) / self.gain
        return torch.from_numpy(signal)

    def get_labels(self, row):
        """Extract labels from the report columns of ``row``.

        Matching is a case-insensitive substring search for each keyword in
        ``_LABEL_KEYWORDS``; negations such as "no atrial fibrillation" are
        not recognised and will still set the corresponding class.

        Returns:
            torch.Tensor: multi-hot float32 tensor of shape ``(num_classes,)``.
        """
        full_text = self._report_text(row).lower()

        label_vec = torch.zeros(self.num_classes, dtype=torch.float32)
        for keyword, class_idx in self._LABEL_KEYWORDS:
            if keyword in full_text:
                label_vec[class_idx] = 1.0
        return label_vec