uyen1109 committed on
Commit
23f06d2
·
verified ·
1 Parent(s): 319d909

Add code: rest_prep_1020.py

Browse files
Files changed (1) hide show
  1. code/rest_prep_1020.py +251 -0
code/rest_prep_1020.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, math, json, argparse, warnings
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Dict, Optional
4
+ import numpy as np, h5py, pyedflib
5
+
6
# ---------- 10-20 electrode 2D coordinates (x, y), relatively normalised ----------
# Legacy names (T3/T4/T5/T6) and their modern counterparts (T7/T8/P7/P8) are
# both listed and share the same position.
E1020 = {
    # frontal pole
    "FP1": (-0.5, 1.0),
    "FP2": (0.5, 1.0),
    # frontal row
    "F7": (-1.0, 0.6),
    "F3": (-0.3, 0.6),
    "FZ": (0.0, 0.8),
    "F4": (0.3, 0.6),
    "F8": (1.0, 0.6),
    # temporal / central row
    "T3": (-1.2, 0.2),
    "T7": (-1.2, 0.2),
    "C3": (-0.4, 0.2),
    "CZ": (0.0, 0.3),
    "C4": (0.4, 0.2),
    "T4": (1.2, 0.2),
    "T8": (1.2, 0.2),
    # posterior temporal / parietal row
    "T5": (-1.1, -0.2),
    "P7": (-1.1, -0.2),
    "P3": (-0.3, -0.2),
    "PZ": (0.0, -0.1),
    "P4": (0.3, -0.2),
    "T6": (1.1, -0.2),
    "P8": (1.1, -0.2),
    # occipital
    "O1": (-0.5, -0.8),
    "O2": (0.5, -0.8),
}

# Map legacy electrode names onto one canonical spelling.
ALIASES = {
    "T3": "T7",
    "T4": "T8",
    "T5": "P7",
    "T6": "P8",
}

# Channel labels to discard: anything containing ECG or VNS, or ending in '-'.
_BLOCKLIST = re.compile(r'(?:ECG|VNS|-$)', re.IGNORECASE)
17
+
18
def _norm_e(name: str) -> Optional[str]:
    """Return the canonical 10-20 electrode name for *name*.

    The name is upper-cased and stripped, legacy aliases are mapped through
    ``ALIASES``, and ``None`` is returned when the result is not a key of
    ``E1020``.
    """
    key = name.upper().strip()
    canonical = ALIASES.get(key, key)
    if canonical not in E1020:
        return None
    return canonical
22
+
23
def pair_midpoint(ch_label: str) -> Optional[Tuple[float, float]]:
    """Return the 2D midpoint of a bipolar channel label of the form ``'A-B'``.

    CHB-MIT commonly uses bipolar 'A-B' channels; the channel position is
    taken as the midpoint of electrodes A and B.

    Returns ``None`` when the label contains no ``'-'`` or when either side
    (split at the first ``'-'``) is not a known 10-20 electrode.
    """
    left, dash, right = ch_label.partition('-')
    if not dash:
        return None
    first = _norm_e(left.strip().upper())
    second = _norm_e(right.strip().upper())
    if first is None or second is None:
        return None
    (x1, y1), (x2, y2) = E1020[first], E1020[second]
    return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
33
+
34
def build_1020_graph(channels: List[str], k: int = 8, sigma: float = 0.5, radius: float = 1.6):
    """Build a static distance graph over channels using 10-20 coordinates.

    Edge weights are ``w_ij = exp(-||vi - vj||^2 / sigma^2)`` when
    ``||vi - vj|| <= radius`` and 0 otherwise.  For each node the (at most)
    ``k`` highest-weight neighbours are kept, and the resulting edge set is
    made undirected.

    Args:
        channels: Channel labels; blocklisted labels and labels that cannot
            be mapped to a midpoint are dropped.
        k: Neighbours kept per node (clamped to ``[1, C - 1]``).
        sigma: Gaussian kernel width.
        radius: Distance cut-off beyond which weights are zeroed.

    Returns:
        ``(keep, edge_index, edge_weight)`` where ``keep`` is the list of
        retained labels, ``edge_index`` is an int64 array of shape ``(2, E)``
        listing every undirected edge in both directions, and ``edge_weight``
        is the matching float32 array of shape ``(E,)``.  When no edge
        survives, ``edge_index`` has shape ``(2, 0)``.

    Raises:
        RuntimeError: If fewer than two channels can be mapped.
    """
    coords = []
    keep = []
    for ch in channels:
        if _BLOCKLIST.search(ch):
            continue
        midpoint = pair_midpoint(ch)
        if midpoint is not None:
            coords.append(midpoint)
            keep.append(ch)

    C = len(keep)
    if C < 2:
        raise RuntimeError("Không đủ kênh ánh xạ được sang 10-20")

    pts = np.array(coords, dtype=np.float32)
    # Pairwise Euclidean distances, shape (C, C).
    D = np.sqrt(((pts[None, :, :] - pts[:, None, :]) ** 2).sum(axis=-1))
    W = np.exp(-(D ** 2) / (sigma ** 2))
    W[D > radius] = 0.0
    np.fill_diagonal(W, 0.0)

    n_neighbours = max(1, min(k, C - 1))
    edges = set()
    for i in range(C):
        for j in np.argsort(-W[i])[:n_neighbours]:
            j = int(j)
            a, b = (i, j) if i < j else (j, i)
            # Zero weight means "outside radius" (or self) - not an edge.
            if W[a, b] > 0:
                edges.add((a, b))
    edges = sorted(edges)

    if not edges:
        # Keep the (2, E) layout even with E == 0 so consumers can index rows.
        return keep, np.zeros((2, 0), dtype=np.int64), np.zeros((0,), dtype=np.float32)

    half = np.array(edges, dtype=np.int64).T  # (2, E/2), a < b
    ei = np.hstack([half, half[::-1, :]])  # append the reversed direction
    w_half = W[half[0], half[1]].astype(np.float32)
    ew = np.concatenate([w_half, w_half])
    return keep, ei, ew
68
+
69
+ # ---------- EDF utils ----------
70
def list_edf_files(patient_dir: Path):
    """Return the regular files matching ``*.edf`` directly under *patient_dir*, sorted."""
    found = []
    for entry in patient_dir.glob("*.edf"):
        if entry.is_file():
            found.append(entry)
    found.sort()
    return found
72
+
73
def edf_channel_labels(edf_path: Path):
    """Return the stripped signal labels of an EDF file, minus blocklisted ones.

    The reader is used as a context manager so the file handle is released
    even if reading a label raises.
    """
    with pyedflib.EdfReader(str(edf_path)) as reader:
        labels = [reader.getLabel(i).strip() for i in range(reader.signals_in_file)]
    return [ch for ch in labels if not _BLOCKLIST.search(ch)]
78
+
79
def intersection_channels(edf_paths):
    """Return the sorted channel labels present in every EDF of *edf_paths*.

    Files that yield no labels are ignored rather than emptying the
    intersection.  Returns ``[]`` when nothing is shared (or no file
    contributed any label).
    """
    shared = None
    for path in edf_paths:
        labels = set(edf_channel_labels(path))
        if len(labels) == 0:
            continue
        if shared is None:
            shared = labels
        else:
            shared = shared & labels
    if not shared:
        return []
    return sorted(shared)
86
+
87
def read_edf_signals(edf_path: Path, keep_channels):
    """Read the requested channels of an EDF file as a float32 matrix.

    Args:
        edf_path: Path of the EDF file.
        keep_channels: Labels to read, in output order.  Blocklisted labels
            are skipped silently.

    Returns:
        ``(sigs, fs, out_labels)``: ``sigs`` is a float32 array of shape
        ``(C, N)``, ``fs`` is the rounded median sample frequency over all
        signals in the file (256 if the file reports no signals), and
        ``out_labels`` are the labels actually read.

    Raises:
        RuntimeError: If a requested label is not present in the file.
        ValueError: If no channel remains to be read.
    """
    # Context manager guarantees the handle is released on every exit path,
    # including errors raised while reading samples.
    with pyedflib.EdfReader(str(edf_path)) as reader:
        n_signals = reader.signals_in_file
        labels = [reader.getLabel(i).strip() for i in range(n_signals)]
        fs_all = [int(round(reader.getSampleFrequency(i))) for i in range(n_signals)]
        fs = int(round(np.median(fs_all))) if fs_all else 256

        # First occurrence wins, matching list.index() semantics.
        first_index = {}
        for i, label in enumerate(labels):
            first_index.setdefault(label, i)

        idxs = []
        out_labels = []
        for ch in keep_channels:
            if ch not in first_index:
                raise RuntimeError(f"Channel {ch} not found in {edf_path.name}")
            if _BLOCKLIST.search(ch):
                continue
            idxs.append(first_index[ch])
            out_labels.append(ch)

        if not idxs:
            raise ValueError(f"No channels left to read from {edf_path.name}")

        sigs = np.vstack([reader.readSignal(i) for i in idxs]).astype(np.float32)  # (C, N)
    return sigs, fs, out_labels
104
+
105
+ # ---------- seizure labels ----------
106
def parse_summary(summary_path: Path) -> Dict[str, List[Tuple[float, float]]]:
    """Parse a patient summary text file into per-EDF seizure intervals.

    Recognises ``File Name: <x>.edf`` headers followed by seizure lines in
    either the plain form (``Seizure Start Time: 2996 seconds``) or the
    numbered form (``Seizure 1 Start Time: 2996 seconds``).  Only the value
    after ``Time`` is read, so a seizure index or trailing punctuation on
    the line is never mistaken for a timestamp.

    Args:
        summary_path: Path of the summary file; may be ``None`` or missing.

    Returns:
        Mapping from EDF file name to a list of ``(start, end)`` pairs.
        Files without a complete start/end pair are absent from the mapping.
        An empty mapping is returned when the summary file does not exist.
    """
    mapping: Dict[str, List[Tuple[float, float]]] = {}
    if not (summary_path and summary_path.exists()):
        return mapping

    file_re = re.compile(r'File Name:\s*(\S+\.edf)', re.IGNORECASE)
    event_re = re.compile(
        r'Seizure\s*(?:\d+\s*)?(Start|End)\s*Time\s*:?\s*(\d+(?:\.\d+)?)',
        re.IGNORECASE,
    )

    curr = None      # EDF name of the section being read
    pending = None   # start time waiting for its matching end
    with summary_path.open("r", errors="ignore") as f:
        for raw in f:
            line = raw.strip()
            header = file_re.search(line)
            if header:
                curr = header.group(1)
                pending = None  # never pair a start across file sections
                continue
            event = event_re.search(line)
            if event is None or curr is None:
                continue
            value = float(event.group(2))
            if event.group(1).lower() == "start":
                pending = value
            elif pending is not None:
                mapping.setdefault(curr, []).append((pending, value))
                pending = None
    return mapping
127
+
128
def parse_seizures_file(seiz_file: Path) -> List[Tuple[float, float]]:
    """Parse a text ``.seizures`` side-car file into ``(start, end)`` pairs.

    Every line that contains at least two numbers contributes one interval
    made of its first two numbers.

    Files containing NUL bytes are treated as binary and skipped with a
    warning: scanning binary data for digits would yield bogus intervals.
    NOTE(review): the side-car files distributed with CHB-MIT appear to be
    binary annotations rather than text - confirm the expected format.

    Args:
        seiz_file: Path of the side-car file; may not exist.

    Returns:
        List of ``(start, end)`` floats; empty if the file is missing,
        binary, or has no line with two numbers.
    """
    intervals = []
    if not seiz_file.exists():
        return intervals

    with seiz_file.open("rb") as raw:
        if b"\x00" in raw.read():
            warnings.warn(f"{seiz_file.name} looks binary; ignoring it")
            return intervals

    number_re = re.compile(r'[-+]?\d*\.?\d+')
    with seiz_file.open("r", errors="ignore") as f:
        for line in f:
            nums = [float(x) for x in number_re.findall(line)]
            if len(nums) >= 2:
                intervals.append((nums[0], nums[1]))
    return intervals
136
+
137
def slice_starts(N, fs, clip_sec, hop_sec):
    """Return the int64 start offsets of every full window in a signal.

    Windows are ``int(fs * clip_sec)`` samples long and advance by
    ``int(fs * hop_sec)`` samples.  A signal of ``N`` samples shorter than
    one window yields an empty array.
    """
    window = int(fs * clip_sec)
    stride = int(fs * hop_sec)
    if N >= window:
        return np.arange(0, N - window + 1, stride, dtype=np.int64)
    return np.zeros((0,), dtype=np.int64)
141
+
142
def zscore_perclip(clip):
    """Standardise each row of *clip* along axis 1.

    Each row has its mean subtracted and is divided by its standard
    deviation plus ``1e-8``, which keeps constant rows finite.
    """
    row_mean = clip.mean(axis=1, keepdims=True)
    row_std = clip.std(axis=1, keepdims=True)
    centered = clip - row_mean
    return centered / (row_std + 1e-8)
145
+
146
def label_for_window(start, T, fs, intervals, min_overlap_sec):
    """Return 1 if the window overlaps the given intervals enough, else 0.

    The window spans samples ``[start, start + T)``.  Each ``(u, v)`` in
    *intervals* is in seconds and converted to samples with ``fs``.  Overlap
    is accumulated across intervals in order, and 1 is returned as soon as
    the running total reaches ``round(min_overlap_sec * fs)`` samples.
    """
    win_lo = start
    win_hi = start + T
    needed = int(round(min_overlap_sec * fs))
    covered = 0
    for sz_start, sz_end in intervals:
        lo = max(win_lo, int(round(sz_start * fs)))
        hi = min(win_hi, int(round(sz_end * fs)))
        if hi > lo:
            covered += hi - lo
        if covered >= needed:
            return 1
    return 0
154
+
155
class H5Appender:
    """Incrementally write clips, labels and file ids to an HDF5 file.

    The file is created (mode ``"w"``) with three resizable datasets -
    ``clips`` of shape ``(n, C, T, 1)``, ``labels`` and ``file_ids`` of shape
    ``(n,)`` - plus static ``channels``, ``edge_index`` and ``edge_weight``
    datasets and the attributes ``fs``, ``T`` and ``patient`` (the stem of
    *out_path*).

    Can be used as a context manager; the file is closed on exit.
    """

    def __init__(self, out_path: Path, C: int, T: int, fs: int, channels, edge_index, edge_weight, gzip_level=4):
        self.f = h5py.File(str(out_path), "w")
        try:
            self.ds_clips = self.f.create_dataset(
                "clips", shape=(0, C, T, 1), maxshape=(None, C, T, 1),
                dtype="float32", chunks=(16, C, T, 1),
                compression="gzip", compression_opts=gzip_level)
            self.ds_labels = self.f.create_dataset(
                "labels", shape=(0,), maxshape=(None,), dtype="i8",
                chunks=True, compression="gzip", compression_opts=gzip_level)
            self.ds_fileids = self.f.create_dataset(
                "file_ids", shape=(0,), maxshape=(None,), dtype="i4",
                chunks=True, compression="gzip", compression_opts=gzip_level)
            self.f.attrs["fs"] = fs
            self.f.attrs["T"] = T
            self.f.attrs["patient"] = out_path.stem
            self.f.create_dataset("channels", data=np.array([c.encode() for c in channels]))
            self.f.create_dataset("edge_index", data=np.asarray(edge_index, dtype=np.int64))
            self.f.create_dataset("edge_weight", data=np.asarray(edge_weight, dtype=np.float32))
        except Exception:
            # Do not leave a dangling open handle if setup fails midway.
            self.f.close()
            raise
        self.n = 0  # number of clips written so far

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def append(self, clips_CT, labels, file_id: int):
        """Append ``M`` clips of shape ``(M, C, T)`` with their labels.

        Does nothing when *clips_CT* is empty.

        Raises:
            ValueError: If *labels* is 1-D and its length differs from ``M``.
        """
        if clips_CT.size == 0:
            return
        M, C, T = clips_CT.shape
        labels = np.asarray(labels, dtype=np.int64)
        # Validate before resizing so a mismatch cannot leave the datasets
        # grown but only partially filled.
        if labels.ndim == 1 and labels.shape[0] != M:
            raise ValueError(f"labels has length {labels.shape[0]}, expected {M}")
        clips = clips_CT[..., None].astype(np.float32)

        end = self.n + M
        self.ds_clips.resize(end, axis=0)
        self.ds_labels.resize(end, axis=0)
        self.ds_fileids.resize(end, axis=0)
        self.ds_clips[self.n:end] = clips
        self.ds_labels[self.n:end] = labels
        self.ds_fileids[self.n:end] = np.full((M,), file_id, dtype=np.int32)
        self.n = end

    def close(self):
        """Close the underlying HDF5 file."""
        self.f.close()
181
+
182
def process_patient_1020(root: Path, patient: str, out_path: Path,
                         clip_sec: float = 4.0, hop_sec: float = 2.0,
                         fs_target: int = 256, min_overlap_sec: float = 0.25,
                         graph_k: int = 8, sigma: float = 0.5, radius: float = 1.6):
    """Convert one patient's EDF recordings into a single HDF5 clip file.

    Steps: intersect channel labels across the patient's EDF files, keep
    those that map onto 10-20 coordinates, build the static channel graph,
    then for each EDF resample (linear interpolation) to ``fs_target``, cut
    sliding windows, z-score each window per channel, label it from seizure
    intervals, and append everything to ``out_path``.

    Seizure intervals come from a ``<file>.edf.seizures`` side-car when it
    yields any interval, otherwise from the patient's ``*summary*.txt``.

    Args:
        root: Dataset root directory.
        patient: Patient sub-directory name under *root*.
        out_path: Destination HDF5 path (overwritten).
        clip_sec: Window length in seconds.
        hop_sec: Window stride in seconds.
        fs_target: Output sample rate in Hz.
        min_overlap_sec: Seizure overlap needed for a positive label.
        graph_k, sigma, radius: Forwarded to ``build_1020_graph``.

    Raises:
        AssertionError: If the patient directory or its EDF files are missing.
        RuntimeError: If fewer than two channels map onto 10-20.
    """
    pat_dir = root / patient
    # Explicit raises (rather than assert statements) so the checks survive
    # `python -O`; the exception type is unchanged for existing callers.
    if not pat_dir.exists():
        raise AssertionError(f"Not found: {pat_dir}")
    edfs = list_edf_files(pat_dir)
    if not edfs:
        raise AssertionError(f"No EDF in {pat_dir}")

    # Channels common to every recording of this patient ...
    keep_channels = intersection_channels(edfs)
    # ... restricted to those that map onto 10-20 (unknown channels dropped).
    keep_channels = [ch for ch in keep_channels if pair_midpoint(ch) is not None]
    if len(keep_channels) < 8:
        warnings.warn(f"[{patient}] chỉ còn {len(keep_channels)} kênh sau khi map 10-20")

    # Static 10-20 graph.
    keep_channels, ei, ew = build_1020_graph(keep_channels, k=graph_k, sigma=sigma, radius=radius)

    # Sorted so the choice is deterministic if several files match.
    summary_files = sorted(pat_dir.glob("*summary*.txt"))
    summ = summary_files[0] if summary_files else None
    summ_map = parse_summary(summ) if summ else {}
    seiz_map = {}
    for p in edfs:
        iv = parse_seizures_file(p.with_suffix(p.suffix + ".seizures"))
        if iv:
            seiz_map[p.name] = iv

    def intervals_for(name):
        return seiz_map.get(name, summ_map.get(name, []))

    T = int(fs_target * clip_sec)
    app = None
    total_pos = 0
    total = 0

    try:
        for file_id, edf in enumerate(edfs):
            sigs, fs, chans = read_edf_signals(edf, keep_channels)

            # Resample by linear interpolation.  The time axes are float64:
            # float32 cannot represent fractional sample positions
            # accurately for long recordings.
            if fs != fs_target:
                n_old = sigs.shape[1]
                n_new = int(round(n_old * (fs_target / fs)))
                t_old = np.linspace(0, n_old - 1, n_old)
                t_new = np.linspace(0, n_old - 1, n_new)
                sigs = np.stack([np.interp(t_new, t_old, ch) for ch in sigs], axis=0).astype(np.float32)

            C, N = sigs.shape
            starts = slice_starts(N, fs_target, clip_sec, hop_sec)
            if app is None:
                app = H5Appender(out_path, C=C, T=T, fs=fs_target, channels=chans,
                                 edge_index=ei, edge_weight=ew)

            ivals = intervals_for(edf.name)
            labels = np.array(
                [label_for_window(int(s), T, fs_target, ivals, min_overlap_sec) for s in starts],
                dtype=np.int64)

            M = len(starts)
            clips = np.empty((M, C, T), dtype=np.float32)
            for i, s in enumerate(starts):
                clips[i] = zscore_perclip(sigs[:, s:s + T])

            app.append(clips, labels, file_id=file_id)
            total += M
            total_pos += int(labels.sum())
    finally:
        # Always release the HDF5 handle, even if a recording fails midway.
        if app is not None:
            app.close()

    print(f"[Done] {patient} -> {out_path} | clips={total} (pos={total_pos}, neg={total-total_pos}), C={len(keep_channels)}, T={int(fs_target*clip_sec)}")
234
+
235
if __name__ == "__main__":
    # Command-line entry point: parse options and process a single patient.
    import argparse

    cli = argparse.ArgumentParser()
    for required_flag in ("--root", "--patient", "--out"):
        cli.add_argument(required_flag, required=True)
    for flag, kind, default in (
        ("--clip-sec", float, 4.0),
        ("--hop-sec", float, 2.0),
        ("--fs", int, 256),
        ("--min-overlap", float, 0.25),
        ("--k", int, 8),
        ("--sigma", float, 0.5),
        ("--radius", float, 1.6),
    ):
        cli.add_argument(flag, type=kind, default=default)
    opts = cli.parse_args()

    process_patient_1020(
        Path(opts.root),
        opts.patient,
        Path(opts.out),
        clip_sec=opts.clip_sec,
        hop_sec=opts.hop_sec,
        fs_target=opts.fs,
        min_overlap_sec=opts.min_overlap,
        graph_k=opts.k,
        sigma=opts.sigma,
        radius=opts.radius,
    )