Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Dataset | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
class MIMICECGDataset(Dataset):
    """PyTorch Dataset for MIMIC-IV-ECG waveforms with report-derived labels.

    Each item pairs one recording, read from a raw ``.dat`` file below
    ``data_dir``, with labels derived from the ``report_*`` text columns of
    the metadata dataframe.
    """

    # (lower-case phrase, class index) pairs used by ``get_labels``.
    _KEYWORD_CLASSES = (
        ('sinus rhythm', 0),
        ('atrial fibrillation', 1),
        ('sinus tachycardia', 2),
        ('sinus bradycardia', 3),
        ('ventricular tachycardia', 4),
    )

    def __init__(self, df, data_dir, transform=False, label_func=None):
        """
        Args:
            df (pd.DataFrame): Dataframe with subject_id, study_id, and
                report columns (any column whose name contains 'report_').
            data_dir (str): Root directory of the dataset (containing the
                'files' folder).
            transform (callable, optional): Applied to the sample dict just
                before it is returned. Any non-callable value (including the
                default ``False``) disables it.
            label_func (callable, optional): Custom function that receives the
                joined report text of a row and returns its labels. When
                omitted, ``get_labels`` is used.
        """
        self.df = df
        self.data_dir = data_dir
        self.transform = transform
        self.label_func = label_func

        # MIMIC-ECG Constants
        self.n_leads = 12
        self.n_samples = 5000
        self.fs = 500
        # Standard gain in ADU/mV. NOTE(review): the per-record .hea header
        # is never read, so gain and a zero baseline are assumed for all leads.
        self.gain = 200.0

        # Reference mapping of report phrases to class indices. get_labels
        # matches via _KEYWORD_CLASSES and does not read this dict.
        self.class_mapping = {
            'Normally filtered': 0,  # Not a diagnosis, but often present
            'Sinus rhythm': 0,
            'Atrial fibrillation': 1,
            'Sinus tachycardia': 2,
            'Sinus bradycardia': 3,
            'Ventricular tachycardia': 4,
            # Add more as needed
        }
        # One slot per distinct class index in the mapping (5 for now).
        self.num_classes = max(self.class_mapping.values()) + 1

    def __len__(self):
        """Return the number of rows (studies) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample for row ``idx``.

        Returns:
            dict: ``{'signal': tensor, 'labels': ..., 'study_id': str}``,
            passed through ``transform`` when a callable was supplied.
            ``labels`` is whatever ``label_func`` returns, or the multi-hot
            tensor from ``get_labels`` when no ``label_func`` is set.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        subj_id = self._format_id(row['subject_id'])
        study_id = self._format_id(row['study_id'])

        # Layout: files/p<first 4 chars of subject_id>/p<subject_id>/
        #         s<study_id>/<study_id>.dat
        # e.g.    files/p1000/p10000032/s40689238/40689238.dat
        file_path = os.path.join(
            self.data_dir,
            'files',
            f"p{subj_id[:4]}",
            f"p{subj_id}",
            f"s{study_id}",
            f"{study_id}.dat",
        )

        signal = self.load_signal_numpy(file_path)

        if self.label_func:
            # label_func works on plain text so that it stays a pure function.
            labels = self.label_func(self._report_text(row))
        else:
            labels = self.get_labels(row)

        sample = {
            'signal': signal,
            'labels': labels,
            'study_id': study_id,
        }
        if callable(self.transform):
            sample = self.transform(sample)
        return sample

    @staticmethod
    def _format_id(value):
        """Render an identifier as text, dropping a spurious float '.0'."""
        # pandas upcasts integer ID columns to float in several situations
        # (e.g. missing values), which would otherwise produce '10000032.0'.
        if isinstance(value, (float, np.floating)) and float(value).is_integer():
            return str(int(value))
        return str(value)

    def _report_text(self, row):
        """Join the non-missing 'report_*' cells of ``row`` with spaces."""
        parts = []
        for col in self.df.columns:
            if not (isinstance(col, str) and 'report_' in col):
                continue
            cell = row[col]
            # Skip NaN/None so the literal strings 'nan'/'None' never end up
            # in the text handed to the label extractors.
            if pd.notna(cell):
                parts.append(str(cell))
        return ' '.join(parts)

    def load_signal_numpy(self, path):
        """Read a raw binary ``.dat`` file using numpy.

        The file is interpreted as interleaved little-endian 16-bit samples
        (s1L1, s1L2, ... s1L12, s2L1, ...). Shorter files are zero-padded and
        longer files are truncated to ``n_leads * n_samples`` values.

        Args:
            path (str): Location of the ``.dat`` file.

        Returns:
            torch.Tensor: Contiguous float32 tensor of shape
            ``(n_leads, n_samples)`` with values divided by ``gain``. If the
            file is missing or unreadable, an all-zero tensor of the same
            shape is returned so a training loop is not interrupted.
        """
        expected_size = self.n_leads * self.n_samples
        try:
            # '<i2': fix the byte order instead of relying on the host's.
            raw_data = np.fromfile(path, dtype='<i2')
        except (OSError, ValueError) as exc:
            import logging

            logging.getLogger(__name__).warning(
                "Could not load %s (%s); substituting an all-zero signal",
                path,
                exc,
            )
            return torch.zeros(
                (self.n_leads, self.n_samples), dtype=torch.float32
            )

        if raw_data.size < expected_size:
            padded = np.zeros(expected_size, dtype=raw_data.dtype)
            padded[:raw_data.size] = raw_data
            raw_data = padded
        elif raw_data.size > expected_size:
            raw_data = raw_data[:expected_size]

        # (samples, leads) on disk -> float32 scaled by gain -> (leads, samples).
        signal = raw_data.reshape((self.n_samples, self.n_leads))
        signal = signal.astype(np.float32)
        signal /= self.gain
        return torch.from_numpy(np.ascontiguousarray(signal.T))

    def get_labels(self, row):
        """Extract labels from the report columns by substring matching.

        Args:
            row (pd.Series): One row of ``self.df``.

        Returns:
            torch.Tensor: Multi-hot float32 tensor of shape ``(num_classes,)``;
            entry ``i`` is 1.0 when a phrase mapped to class ``i`` occurs
            (case-insensitively) in the joined report text.
        """
        full_text = self._report_text(row).lower()
        label_vec = torch.zeros(self.num_classes, dtype=torch.float32)
        for phrase, class_idx in self._KEYWORD_CLASSES:
            if phrase in full_text:
                label_vec[class_idx] = 1.0
        return label_vec