Spaces:
Running
Running
File size: 5,906 Bytes
e3b4744 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import logging
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class MIMICECGDataset(Dataset):
    """
    PyTorch Dataset for MIMIC-IV-ECG.

    Each item is a dict with:
        'signal'   : float32 tensor of shape (12, 5000); raw ADU values divided
                     by ``self.gain`` (intended to yield mV).
        'labels'   : multi-hot float32 tensor of shape (num_classes,) built by
                     ``get_labels``, or whatever ``label_func`` returns.
        'study_id' : the study id as a string.
    If a callable ``transform`` was given, it is applied to that dict.
    """

    # (label index, lower-case phrase searched for in the report text).
    # Substring matching: a phrase also matches inside longer text.
    _LABEL_PHRASES = (
        (0, 'sinus rhythm'),             # Sinus rhythm (normal-ish)
        (1, 'atrial fibrillation'),
        (2, 'sinus tachycardia'),
        (3, 'sinus bradycardia'),
        (4, 'ventricular tachycardia'),
    )

    def __init__(self, df, data_dir, transform=False, label_func=None):
        """
        Args:
            df (pd.DataFrame): Dataframe with subject_id, study_id, and report
                columns (every column whose name contains 'report_' is treated
                as report text).
            data_dir (str): Root directory of the dataset (containing the
                'files' folder).
            transform (callable, optional): Applied to the sample dict returned
                by ``__getitem__``. Non-callable values (the default ``False``,
                or ``None``) disable it.
            label_func (callable, optional): Maps the joined report text
                (original casing) to labels. When not given, ``get_labels`` is
                used.
        """
        self.df = df
        self.data_dir = data_dir
        self.transform = transform
        self.label_func = label_func
        # MIMIC-ECG constants
        self.n_leads = 12
        self.n_samples = 5000
        self.fs = 500
        # Assumed ADU per mV for every lead -- NOTE(review): confirm against the
        # record's .hea header rather than hard-coding.
        self.gain = 200.0
        # Target classes we want to detect.
        # NOTE: get_labels() does not read this mapping; it matches the phrases
        # in _LABEL_PHRASES. 'Normally filtered' is listed here but never matched.
        self.class_mapping = {
            'Normally filtered': 0,  # Not a diagnosis, but often present
            'Sinus rhythm': 0,
            'Atrial fibrillation': 1,
            'Sinus tachycardia': 2,
            'Sinus bradycardia': 3,
            'Ventricular tachycardia': 4,
            # Add more as needed
        }
        self.num_classes = 5  # For now

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load one record: its signal, labels and study id (optionally transformed)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.df.iloc[idx]
        subj_id = str(row['subject_id'])
        study_id = str(row['study_id'])

        # Layout: files/p{first 4 digits of subject_id}/p{subject_id}/s{study_id}/{study_id}.dat
        # e.g. data_dir/files/p1000/p10000032/s40689238/40689238.dat
        file_path = os.path.join(
            self.data_dir,
            'files',
            f"p{subj_id[:4]}",
            f"p{subj_id}",
            f"s{study_id}",
            f"{study_id}.dat",
        )

        # 1. Load signal, shape (12, 5000)
        signal = self.load_signal_numpy(file_path)

        # 2. Labels: label_func receives the raw (not lower-cased) report text,
        # so it stays a pure text -> labels function.
        if self.label_func:
            labels = self.label_func(self._report_text(row))
        else:
            labels = self.get_labels(row)

        # 3. Build the sample and apply the documented transform (previously it
        # was stored but silently ignored).
        sample = {
            'signal': signal,
            'labels': labels,
            'study_id': study_id,
        }
        if callable(self.transform):
            sample = self.transform(sample)
        return sample

    def _report_text(self, row):
        """Join every 'report_*' column of ``row`` into one space-separated string."""
        cols = [c for c in self.df.columns if isinstance(c, str) and 'report_' in c]
        return ' '.join(str(row[c]) for c in cols)

    def load_signal_numpy(self, path):
        """
        Read a binary .dat file into a float32 torch tensor of shape (12, 5000).

        Values are read as little-endian int16 (WFDB format 16 stores the least
        significant byte first). They are assumed to be interleaved by frame
        (s1L1, s1L2, ... s1L12, s2L1, ...), so they are reshaped to
        (samples, leads), transposed to (leads, samples) and divided by
        ``self.gain``. Short files are zero-padded; long files are truncated.

        Best effort: a missing or unreadable file yields an all-zero tensor, so a
        training loop does not crash. The failure is logged at DEBUG level.
        """
        expected_size = self.n_leads * self.n_samples
        try:
            # Pin byte order instead of relying on the host's native int16.
            raw_data = np.fromfile(path, dtype='<i2')
        except (OSError, ValueError) as err:
            logging.getLogger(__name__).debug("Error loading %s: %s", path, err)
            return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)

        # Handle truncated or oversized files by padding or cutting.
        if raw_data.size < expected_size:
            raw_data = np.pad(raw_data, (0, expected_size - raw_data.size))
        elif raw_data.size > expected_size:
            raw_data = raw_data[:expected_size]

        # (samples, leads) -> (leads, samples), as a C-contiguous float32 copy.
        signal = raw_data.reshape((self.n_samples, self.n_leads)).T.astype(np.float32, order='C')
        signal /= self.gain  # scale ADU -> mV (see gain note in __init__)
        return torch.from_numpy(signal)

    def get_labels(self, row):
        """
        Extract multi-hot labels by case-insensitive substring search of the reports.

        Returns:
            torch.Tensor: float32 tensor of shape (num_classes,).
        """
        full_text = self._report_text(row).lower()
        label_vec = torch.zeros(self.num_classes, dtype=torch.float32)
        for index, phrase in self._LABEL_PHRASES:
            if phrase in full_text:
                label_vec[index] = 1.0
        return label_vec
|