File size: 5,906 Bytes
e3b4744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import logging
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class MIMICECGDataset(Dataset):
    """PyTorch Dataset for MIMIC-IV-ECG waveforms with report-derived labels.

    Each item is a dict with keys ``'signal'`` (float32 tensor of shape
    ``(n_leads, n_samples)``, in mV), ``'labels'`` (multi-hot tensor, or
    whatever ``label_func`` returns) and ``'study_id'`` (str).
    """

    # Lower-case phrases searched in the report text, paired with the index
    # of the label-vector entry they switch on.
    _LABEL_PHRASES = (
        ('sinus rhythm', 0),
        ('atrial fibrillation', 1),
        ('sinus tachycardia', 2),
        ('sinus bradycardia', 3),
        ('ventricular tachycardia', 4),
    )

    _logger = logging.getLogger(__name__)

    def __init__(self, df, data_dir, transform=False, label_func=None):
        """Store the metadata table and loading configuration.

        Args:
            df (pd.DataFrame): Dataframe with ``subject_id``, ``study_id`` and
                ``report_*`` columns.
            data_dir (str): Root directory of the dataset (the one containing
                the ``files`` folder).
            transform (callable, optional): Applied to the finished sample
                dict in ``__getitem__``. Any non-callable value (including the
                default ``False``) disables it.
            label_func (callable, optional): Receives the joined report text
                of a row and returns its labels. When omitted, ``get_labels``
                is used.
        """
        super().__init__()
        self.df = df
        self.data_dir = data_dir
        self.transform = transform
        self.label_func = label_func

        # MIMIC-ECG constants: 12 leads, 10 s sampled at 500 Hz.
        self.n_leads = 12
        self.n_samples = 5000
        self.fs = 500
        self.gain = 200.0  # Standard gain in ADU/mV

        # Report phrases of interest and their class index. Kept as a public
        # attribute for callers; get_labels() itself uses _LABEL_PHRASES.
        self.class_mapping = {
            'Normally filtered': 0,  # Not a diagnosis, but often present
            'Sinus rhythm': 0,
            'Atrial fibrillation': 1,
            'Sinus tachycardia': 2,
            'Sinus bradycardia': 3,
            'Ventricular tachycardia': 4,
            # Add more as needed
        }
        self.num_classes = len(self._LABEL_PHRASES)

    def __len__(self):
        """Return the number of rows (studies) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the signal and labels for the row at position ``idx``.

        Returns:
            dict: ``{'signal', 'labels', 'study_id'}``, passed through
            ``transform`` when one was supplied.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        subj_id = self._format_id(row['subject_id'])
        study_id = self._format_id(row['study_id'])

        # Layout: files/p{first 4 digits}/p{subject_id}/s{study_id}/{study_id}.dat
        # e.g.    files/p1000/p10000032/s40689238/40689238.dat
        file_path = os.path.join(
            self.data_dir,
            'files',
            f"p{subj_id[:4]}",
            f"p{subj_id}",
            f"s{study_id}",
            f"{study_id}.dat",
        )

        signal = self.load_signal_numpy(file_path)

        if self.label_func:
            # label_func works on plain text so that it stays a pure function.
            labels = self.label_func(self._report_text(row))
        else:
            labels = self.get_labels(row)

        sample = {
            'signal': signal,
            'labels': labels,
            'study_id': study_id,
        }

        if callable(self.transform):
            sample = self.transform(sample)

        return sample

    @staticmethod
    def _format_id(value):
        """Render an id as text, dropping the '.0' of integral floats.

        pandas upcasts an all-numeric row to float64, which would otherwise
        turn 10000032 into '10000032.0' and break the file path.
        """
        if isinstance(value, (float, np.floating)) and float(value).is_integer():
            return str(int(value))
        return str(value)

    def _report_text(self, row):
        """Join the non-missing ``report_*`` cells of ``row`` with spaces."""
        report_cols = [
            c for c in self.df.columns if isinstance(c, str) and 'report_' in c
        ]
        # Missing cells are skipped so the text never contains a literal "nan".
        return ' '.join(str(row[c]) for c in report_cols if pd.notna(row[c]))

    def _zero_signal(self):
        """Return the all-zero placeholder used when a record is unreadable."""
        return torch.zeros((self.n_leads, self.n_samples), dtype=torch.float32)

    def load_signal_numpy(self, path):
        """Read a binary ``.dat`` file of 16-bit samples using numpy.

        A missing or unreadable file yields an all-zero tensor (and a logged
        warning) so that a training loop is not aborted by a single bad
        record. Files of the wrong length are zero-padded or truncated.

        Returns:
            torch.Tensor: float32 tensor of shape ``(n_leads, n_samples)``.
        """
        expected_size = self.n_leads * self.n_samples

        try:
            raw_data = np.fromfile(path, dtype=np.int16)
        except (OSError, ValueError) as exc:
            self._logger.warning(
                "Could not read ECG signal %s (%s); using zeros", path, exc
            )
            return self._zero_signal()

        if raw_data.size != expected_size:
            self._logger.warning(
                "ECG signal %s has %d samples, expected %d; padding/truncating",
                path, raw_data.size, expected_size,
            )
            if raw_data.size < expected_size:
                raw_data = np.pad(raw_data, (0, expected_size - raw_data.size))
            else:
                raw_data = raw_data[:expected_size]

        # Samples are assumed to be interleaved frame by frame
        # (s1L1, s1L2, ... s1L12, s2L1, ...): reshape to (samples, leads) and
        # transpose to (leads, samples).
        signal = raw_data.reshape((self.n_samples, self.n_leads)).T

        # Convert ADC units to mV (assumes a zero baseline).
        signal = signal.astype(np.float32) / self.gain

        return torch.from_numpy(signal)

    def get_labels(self, row):
        """Build labels by case-insensitive phrase matching on the reports.

        Index meaning: 0 sinus rhythm, 1 atrial fibrillation, 2 sinus
        tachycardia, 3 sinus bradycardia, 4 ventricular tachycardia.

        Returns:
            torch.Tensor: float32 multi-hot tensor of shape ``(num_classes,)``.
        """
        full_text = self._report_text(row).lower()

        label_vec = torch.zeros(self.num_classes, dtype=torch.float32)
        for phrase, index in self._LABEL_PHRASES:
            if phrase in full_text:
                label_vec[index] = 1.0

        return label_vec