ECG / dataset.py
IFMedTechdemo's picture
Upload folder using huggingface_hub
e3b4744 verified
import logging
import os
import re

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class MIMICECGDataset(Dataset):
    """
    PyTorch Dataset for MIMIC-IV-ECG.

    Each item is a dict with keys:
        'signal'   -- float32 tensor of shape (n_leads, n_samples), in mV
        'labels'   -- output of ``label_func`` if given, else of ``get_labels``
        'study_id' -- the study id as a string
    If a ``transform`` callable was given, it receives that dict and its
    return value is what ``__getitem__`` returns.
    """

    # Rhythm phrases searched in the lower-cased report text by get_labels();
    # the position in this tuple is the label index. Matching whole words keeps
    # e.g. "supraventricular tachycardia" from counting as "ventricular tachycardia".
    _LABEL_PATTERNS = tuple(
        re.compile(rf'\b{re.escape(phrase)}\b')
        for phrase in (
            'sinus rhythm',             # 0
            'atrial fibrillation',      # 1
            'sinus tachycardia',        # 2
            'sinus bradycardia',        # 3
            'ventricular tachycardia',  # 4
        )
    )

    def __init__(self, df, data_dir, transform=False, label_func=None):
        """
        Args:
            df (pd.DataFrame): Dataframe with subject_id, study_id, and report
                columns (every column whose name contains 'report_' is treated
                as report text).
            data_dir (str): Root directory of the dataset (containing the 'files' folder).
            transform (callable, optional): Applied to each sample dict built by
                __getitem__; non-callable values (the default False, or None)
                disable it.
            label_func (callable, optional): Receives the space-joined report
                text of a row and returns its labels; overrides get_labels().
        """
        self.df = df
        self.data_dir = data_dir
        self.transform = transform
        self.label_func = label_func
        # MIMIC-ECG constants
        self.n_leads = 12
        self.n_samples = 5000
        self.fs = 500  # sampling rate in Hz; not used inside this class
        self.gain = 200.0  # Standard gain in ADU/mV; raw samples are divided by it
        # Target classes keyed by report phrase.
        # NOTE(review): this mapping is not consulted by get_labels(), and it maps
        # the non-diagnostic 'Normally filtered' to the same index as 'Sinus rhythm'.
        self.class_mapping = {
            'Normally filtered': 0,  # Not a diagnosis, but often present
            'Sinus rhythm': 0,
            'Atrial fibrillation': 1,
            'Sinus tachycardia': 2,
            'Sinus bradycardia': 3,
            'Ventricular tachycardia': 4,
            # Add more as needed
        }
        self.num_classes = 5  # must cover every index in _LABEL_PATTERNS

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load the signal and labels for row *idx* of ``df``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.df.iloc[idx]
        subj_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        # Layout: files/p{first 4 chars of subject_id}/p{subject_id}/s{study_id}/{study_id}.dat
        # e.g. data_dir/files/p1000/p10000032/s40689238/40689238.dat
        file_path = os.path.join(
            self.data_dir,
            'files',
            f"p{subj_id[:4]}",
            f"p{subj_id}",
            f"s{study_id}",
            f"{study_id}.dat",
        )
        signal = self.load_signal_numpy(file_path)
        if self.label_func:
            labels = self.label_func(self._report_text(row))
        else:
            labels = self.get_labels(row)
        sample = {
            'signal': signal,
            'labels': labels,
            'study_id': study_id,
        }
        # The transform was previously stored but never applied.
        if callable(self.transform):
            sample = self.transform(sample)
        return sample

    def _report_text(self, row):
        """Join the non-missing values of every 'report_' column of *row* with spaces."""
        cols = [c for c in self.df.columns if 'report_' in c]
        # Skip NaN/None cells so they don't leak in as the literal text "nan"/"None".
        return ' '.join(str(row[c]) for c in cols if pd.notna(row[c]))

    def load_signal_numpy(self, path):
        """
        Read a 16-bit .dat signal file with numpy.

        Returns a float32 torch tensor of shape (n_leads, n_samples), in mV.
        A missing or unreadable file yields an all-zero tensor, so one bad
        record does not stop a training loop. A file of the wrong size is
        zero-padded or truncated to n_leads * n_samples samples.
        """
        log = logging.getLogger(__name__)
        if not os.path.exists(path):
            log.debug("ECG file not found, returning zeros: %s", path)
            return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)
        try:
            # WFDB format 16 is little-endian; spell it out so big-endian hosts read it correctly.
            raw_data = np.fromfile(path, dtype='<i2')
        except (OSError, ValueError) as err:
            log.warning("Error loading %s, returning zeros: %s", path, err)
            return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)

        expected_size = self.n_leads * self.n_samples
        if raw_data.size < expected_size:
            raw_data = np.pad(raw_data, (0, expected_size - raw_data.size))
        else:
            raw_data = raw_data[:expected_size]

        # Assumes samples are interleaved frame by frame (s1L1, s1L2, ..., s1L12, s2L1, ...):
        # reshape to (samples, leads), then transpose to (leads, samples).
        signal = raw_data.reshape((self.n_samples, self.n_leads)).T
        # NOTE(review): gain and baseline (0) are hard-coded rather than read from the
        # .hea header, and WFDB's invalid-sample value (-32768) is not masked -- confirm.
        signal = np.ascontiguousarray(signal, dtype=np.float32) / self.gain
        return torch.from_numpy(signal)

    def get_labels(self, row):
        """
        Extract rhythm labels from the report columns of *row*.

        Returns a float32 multi-hot tensor of shape (num_classes,):
        0 sinus rhythm, 1 atrial fibrillation, 2 sinus tachycardia,
        3 sinus bradycardia, 4 ventricular tachycardia.
        """
        full_text = self._report_text(row).lower()
        label_vec = torch.zeros(self.num_classes, dtype=torch.float32)
        for index, pattern in enumerate(self._LABEL_PATTERNS):
            if pattern.search(full_text):
                label_vec[index] = 1.0
        return label_vec