Add code: rest_prep_1020.py
Browse files- code/rest_prep_1020.py +251 -0
code/rest_prep_1020.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, re, math, json, argparse, warnings
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Tuple, Dict, Optional
|
| 4 |
+
import numpy as np, h5py, pyedflib
|
| 5 |
+
|
| 6 |
+
# ---------- 10-20 electrode 2D coords (x,y) đã chuẩn hoá tương đối ----------
|
| 7 |
+
E1020 = {
|
| 8 |
+
"FP1":(-0.5, 1.0), "FP2":( 0.5, 1.0),
|
| 9 |
+
"F7": (-1.0, 0.6), "F3":(-0.3, 0.6), "FZ":(0.0,0.8), "F4":( 0.3, 0.6), "F8":( 1.0, 0.6),
|
| 10 |
+
"T3":(-1.2, 0.2), "T7":(-1.2, 0.2), "C3":(-0.4, 0.2), "CZ":(0.0,0.3), "C4":( 0.4, 0.2), "T4":( 1.2, 0.2), "T8":( 1.2, 0.2),
|
| 11 |
+
"T5":(-1.1, -0.2), "P7":(-1.1, -0.2), "P3":(-0.3, -0.2), "PZ":(0.0,-0.1), "P4":( 0.3, -0.2), "T6":( 1.1, -0.2), "P8":( 1.1, -0.2),
|
| 12 |
+
"O1":(-0.5, -0.8), "O2":( 0.5, -0.8)
|
| 13 |
+
}
|
| 14 |
+
ALIASES = {"T3":"T7","T4":"T8","T5":"P7","T6":"P8"} # đồng nhất tên
|
| 15 |
+
|
| 16 |
+
_BLOCKLIST = re.compile(r'(?:ECG|VNS|-$)', re.IGNORECASE)
|
| 17 |
+
|
| 18 |
+
def _norm_e(name:str)->Optional[str]:
|
| 19 |
+
name=name.upper().strip()
|
| 20 |
+
name=ALIASES.get(name,name)
|
| 21 |
+
return name if name in E1020 else None
|
| 22 |
+
|
| 23 |
+
def pair_midpoint(ch_label:str)->Optional[Tuple[float,float]]:
|
| 24 |
+
"""
|
| 25 |
+
CHB-MIT thường dùng kênh lưỡng cực 'A-B'. Lấy toạ độ là trung điểm của hai điện cực A,B.
|
| 26 |
+
"""
|
| 27 |
+
if '-' not in ch_label: return None
|
| 28 |
+
a,b = [x.strip().upper() for x in ch_label.split('-',1)]
|
| 29 |
+
a=_norm_e(a); b=_norm_e(b)
|
| 30 |
+
if (a is None) or (b is None): return None
|
| 31 |
+
ax,ay=E1020[a]; bx,by=E1020[b]
|
| 32 |
+
return ((ax+bx)/2.0, (ay+by)/2.0)
|
| 33 |
+
|
| 34 |
+
def build_1020_graph(channels:List[str], k:int=8, sigma:float=0.5, radius:float=1.6):
|
| 35 |
+
"""
|
| 36 |
+
Đồ thị khoảng cách tĩnh theo 10–20:
|
| 37 |
+
w_ij = exp(-||vi-vj||^2 / sigma^2), nếu ||vi-vj|| <= radius, ngược lại 0.
|
| 38 |
+
Sau đó với mỗi nút, giữ lại k láng giềng w lớn nhất (undirected).
|
| 39 |
+
"""
|
| 40 |
+
coords=[]
|
| 41 |
+
keep=[]
|
| 42 |
+
for ch in channels:
|
| 43 |
+
if _BLOCKLIST.search(ch): continue
|
| 44 |
+
m=pair_midpoint(ch)
|
| 45 |
+
if m is not None:
|
| 46 |
+
coords.append(m); keep.append(ch)
|
| 47 |
+
C=len(keep)
|
| 48 |
+
if C<2:
|
| 49 |
+
raise RuntimeError("Không đủ kênh ánh xạ được sang 10-20")
|
| 50 |
+
coords=np.array(coords, dtype=np.float32)
|
| 51 |
+
D=np.sqrt(((coords[None,:,:]-coords[:,None,:])**2).sum(axis=-1)) # (C,C)
|
| 52 |
+
W=np.exp(-(D**2)/(sigma**2))
|
| 53 |
+
W[D>radius]=0.0
|
| 54 |
+
np.fill_diagonal(W, 0.0)
|
| 55 |
+
|
| 56 |
+
edges=set()
|
| 57 |
+
for i in range(C):
|
| 58 |
+
idx=np.argsort(-W[i])[:max(1,min(k,C-1))]
|
| 59 |
+
for j in idx:
|
| 60 |
+
a,b=(i,j) if i<j else (j,i)
|
| 61 |
+
if W[a,b]>0: edges.add((a,b))
|
| 62 |
+
edges=sorted(list(edges))
|
| 63 |
+
ei=np.array(edges, dtype=np.int64).T
|
| 64 |
+
ei=np.hstack([ei, ei[::-1,:]]) if ei.size else ei
|
| 65 |
+
ew=np.array([W[i,j] for (i,j) in edges], dtype=np.float32)
|
| 66 |
+
ew=np.concatenate([ew, ew]) if ew.size else ew
|
| 67 |
+
return keep, ei, ew
|
| 68 |
+
|
| 69 |
+
# ---------- EDF utils ----------
|
| 70 |
+
def list_edf_files(patient_dir: Path):
|
| 71 |
+
return sorted([p for p in patient_dir.glob("*.edf") if p.is_file()])
|
| 72 |
+
|
| 73 |
+
def edf_channel_labels(edf_path: Path):
|
| 74 |
+
f=pyedflib.EdfReader(str(edf_path))
|
| 75 |
+
labels=[f.getLabel(i).strip() for i in range(f.signals_in_file)]
|
| 76 |
+
f._close(); del f
|
| 77 |
+
return [ch for ch in labels if not _BLOCKLIST.search(ch)]
|
| 78 |
+
|
| 79 |
+
def intersection_channels(edf_paths):
|
| 80 |
+
common=None
|
| 81 |
+
for p in edf_paths:
|
| 82 |
+
chans=set(edf_channel_labels(p))
|
| 83 |
+
if not chans: continue
|
| 84 |
+
common = chans if common is None else (common & chans)
|
| 85 |
+
return sorted(list(common)) if common else []
|
| 86 |
+
|
| 87 |
+
def read_edf_signals(edf_path: Path, keep_channels):
|
| 88 |
+
f=pyedflib.EdfReader(str(edf_path))
|
| 89 |
+
labels=[f.getLabel(i).strip() for i in range(f.signals_in_file)]
|
| 90 |
+
fs_all=[int(round(f.getSampleFrequency(i))) for i in range(f.signals_in_file)]
|
| 91 |
+
fs=int(round(np.median(fs_all))) if fs_all else 256
|
| 92 |
+
idxs=[]; out_labels=[]
|
| 93 |
+
for ch in keep_channels:
|
| 94 |
+
try:
|
| 95 |
+
i=labels.index(ch)
|
| 96 |
+
if _BLOCKLIST.search(ch): continue
|
| 97 |
+
idxs.append(i); out_labels.append(ch)
|
| 98 |
+
except ValueError:
|
| 99 |
+
f._close(); del f
|
| 100 |
+
raise RuntimeError(f"Channel {ch} not found in {edf_path.name}")
|
| 101 |
+
sigs=np.vstack([f.readSignal(i) for i in idxs]).astype(np.float32) # (C,N)
|
| 102 |
+
f._close(); del f
|
| 103 |
+
return sigs, fs, out_labels
|
| 104 |
+
|
| 105 |
+
# ---------- nhãn co giật ----------
|
| 106 |
+
def parse_summary(summary_path: Path) -> Dict[str, List[Tuple[float,float]]]:
|
| 107 |
+
mapping={}
|
| 108 |
+
if not (summary_path and summary_path.exists()):
|
| 109 |
+
return mapping
|
| 110 |
+
curr=None; buf=[]
|
| 111 |
+
with summary_path.open("r", errors="ignore") as f:
|
| 112 |
+
for line in f:
|
| 113 |
+
line=line.strip()
|
| 114 |
+
mfile=re.search(r'File Name:\s*(\S+\.edf)', line, re.IGNORECASE)
|
| 115 |
+
if mfile:
|
| 116 |
+
if curr and buf:
|
| 117 |
+
pairs=[(buf[i], buf[i+1]) for i in range(0,len(buf)-1,2)]
|
| 118 |
+
mapping.setdefault(curr, []).extend(pairs)
|
| 119 |
+
curr=mfile.group(1); buf=[]; continue
|
| 120 |
+
if re.search(r'Seizure (Start|End) Time', line, re.IGNORECASE):
|
| 121 |
+
nums=[float(x) for x in re.findall(r'[\d.]+', line)]
|
| 122 |
+
if nums: buf.extend(nums)
|
| 123 |
+
if curr and buf:
|
| 124 |
+
pairs=[(buf[i], buf[i+1]) for i in range(0,len(buf)-1,2)]
|
| 125 |
+
mapping.setdefault(curr, []).extend(pairs)
|
| 126 |
+
return mapping
|
| 127 |
+
|
| 128 |
+
def parse_seizures_file(seiz_file: Path) -> List[Tuple[float,float]]:
|
| 129 |
+
intervals=[]
|
| 130 |
+
if not seiz_file.exists(): return intervals
|
| 131 |
+
with seiz_file.open("r", errors="ignore") as f:
|
| 132 |
+
for line in f:
|
| 133 |
+
nums=[float(x) for x in re.findall(r'[-+]?\d*\.?\d+', line)]
|
| 134 |
+
if len(nums)>=2: intervals.append((nums[0], nums[1]))
|
| 135 |
+
return intervals
|
| 136 |
+
|
| 137 |
+
def slice_starts(N, fs, clip_sec, hop_sec):
|
| 138 |
+
T=int(fs*clip_sec); hop=int(fs*hop_sec)
|
| 139 |
+
if N<T: return np.zeros((0,), dtype=np.int64)
|
| 140 |
+
return np.arange(0, N-T+1, hop, dtype=np.int64)
|
| 141 |
+
|
| 142 |
+
def zscore_perclip(clip):
|
| 143 |
+
mu=clip.mean(axis=1, keepdims=True); sd=clip.std(axis=1, keepdims=True)+1e-8
|
| 144 |
+
return (clip-mu)/sd
|
| 145 |
+
|
| 146 |
+
def label_for_window(start, T, fs, intervals, min_overlap_sec):
|
| 147 |
+
a,b=start, start+T
|
| 148 |
+
thr=int(round(min_overlap_sec*fs)); overlap=0
|
| 149 |
+
for (u,v) in intervals:
|
| 150 |
+
u_s=int(round(u*fs)); v_s=int(round(v*fs))
|
| 151 |
+
overlap += max(0, min(b, v_s) - max(a, u_s))
|
| 152 |
+
if overlap>=thr: return 1
|
| 153 |
+
return 0
|
| 154 |
+
|
| 155 |
+
class H5Appender:
|
| 156 |
+
def __init__(self, out_path: Path, C: int, T: int, fs: int, channels, edge_index, edge_weight, gzip_level=4):
|
| 157 |
+
self.f=h5py.File(str(out_path),"w")
|
| 158 |
+
self.ds_clips=self.f.create_dataset("clips", shape=(0,C,T,1), maxshape=(None,C,T,1),
|
| 159 |
+
dtype="float32", chunks=(16,C,T,1), compression="gzip", compression_opts=gzip_level)
|
| 160 |
+
self.ds_labels=self.f.create_dataset("labels", shape=(0,), maxshape=(None,), dtype="i8",
|
| 161 |
+
chunks=True, compression="gzip", compression_opts=gzip_level)
|
| 162 |
+
self.ds_fileids=self.f.create_dataset("file_ids", shape=(0,), maxshape=(None,), dtype="i4",
|
| 163 |
+
chunks=True, compression="gzip", compression_opts=gzip_level)
|
| 164 |
+
self.f.attrs["fs"]=fs; self.f.attrs["T"]=T; self.f.attrs["patient"]=out_path.stem
|
| 165 |
+
self.f.create_dataset("channels", data=np.array([c.encode() for c in channels]))
|
| 166 |
+
self.f.create_dataset("edge_index", data=edge_index.astype(np.int64))
|
| 167 |
+
self.f.create_dataset("edge_weight", data=edge_weight.astype(np.float32))
|
| 168 |
+
self.n=0
|
| 169 |
+
def append(self, clips_CT, labels, file_id:int):
|
| 170 |
+
if clips_CT.size==0: return
|
| 171 |
+
M,C,T=clips_CT.shape
|
| 172 |
+
clips=clips_CT[...,None].astype(np.float32)
|
| 173 |
+
self.ds_clips.resize(self.n+M, axis=0)
|
| 174 |
+
self.ds_labels.resize(self.n+M, axis=0)
|
| 175 |
+
self.ds_fileids.resize(self.n+M, axis=0)
|
| 176 |
+
self.ds_clips[self.n:self.n+M]=clips
|
| 177 |
+
self.ds_labels[self.n:self.n+M]=labels.astype(np.int64)
|
| 178 |
+
self.ds_fileids[self.n:self.n+M]=np.full((M,), file_id, dtype=np.int32)
|
| 179 |
+
self.n+=M
|
| 180 |
+
def close(self): self.f.close()
|
| 181 |
+
|
| 182 |
+
def process_patient_1020(root: Path, patient: str, out_path: Path,
|
| 183 |
+
clip_sec: float=4.0, hop_sec: float=2.0,
|
| 184 |
+
fs_target: int=256, min_overlap_sec: float=0.25,
|
| 185 |
+
graph_k: int=8, sigma: float=0.5, radius: float=1.6):
|
| 186 |
+
pat_dir=root/patient
|
| 187 |
+
assert pat_dir.exists(), f"Not found: {pat_dir}"
|
| 188 |
+
edfs=list_edf_files(pat_dir); assert edfs, f"No EDF in {pat_dir}"
|
| 189 |
+
|
| 190 |
+
# kênh giao nhau trong bệnh nhân
|
| 191 |
+
keep_channels=intersection_channels(edfs)
|
| 192 |
+
# chỉ giữ kênh map được sang 10-20 (bỏ kênh lạ)
|
| 193 |
+
keep_channels=[ch for ch in keep_channels if pair_midpoint(ch) is not None]
|
| 194 |
+
if len(keep_channels)<8:
|
| 195 |
+
warnings.warn(f"[{patient}] chỉ còn {len(keep_channels)} kênh sau khi map 10-20")
|
| 196 |
+
|
| 197 |
+
# đồ thị 10-20 tĩnh
|
| 198 |
+
keep_channels, ei, ew = build_1020_graph(keep_channels, k=graph_k, sigma=sigma, radius=radius)
|
| 199 |
+
|
| 200 |
+
summ = next(iter(list(pat_dir.glob("*summary*.txt"))), None)
|
| 201 |
+
summ_map=parse_summary(summ) if summ else {}
|
| 202 |
+
seiz_map={}
|
| 203 |
+
for p in edfs:
|
| 204 |
+
iv=parse_seizures_file(p.with_suffix(p.suffix+".seizures"))
|
| 205 |
+
if iv: seiz_map[p.name]=iv
|
| 206 |
+
def intervals_for(name): return seiz_map.get(name, summ_map.get(name, []))
|
| 207 |
+
|
| 208 |
+
app=None; file_id=0
|
| 209 |
+
total_pos=0; total=0
|
| 210 |
+
|
| 211 |
+
for edf in edfs:
|
| 212 |
+
sigs, fs, chans=read_edf_signals(edf, keep_channels)
|
| 213 |
+
# resample
|
| 214 |
+
if fs!=fs_target:
|
| 215 |
+
ratio=fs_target/fs
|
| 216 |
+
N_new=int(round(sigs.shape[1]*ratio))
|
| 217 |
+
t_old=np.linspace(0, sigs.shape[1]-1, sigs.shape[1], dtype=np.float32)
|
| 218 |
+
t_new=np.linspace(0, sigs.shape[1]-1, N_new, dtype=np.float32)
|
| 219 |
+
sigs=np.stack([np.interp(t_new, t_old, ch) for ch in sigs], axis=0).astype(np.float32)
|
| 220 |
+
|
| 221 |
+
C,N=sigs.shape; T=int(fs_target*clip_sec)
|
| 222 |
+
starts=slice_starts(N, fs_target, clip_sec, hop_sec)
|
| 223 |
+
if app is None:
|
| 224 |
+
app=H5Appender(out_path, C=C, T=T, fs=fs_target, channels=chans, edge_index=ei, edge_weight=ew)
|
| 225 |
+
ivals=intervals_for(edf.name)
|
| 226 |
+
labels=np.array([label_for_window(int(s), T, fs_target, ivals, min_overlap_sec) for s in starts], dtype=np.int64)
|
| 227 |
+
M=len(starts); clips=np.empty((M,C,T), dtype=np.float32)
|
| 228 |
+
for i,s in enumerate(starts): clips[i]=zscore_perclip(sigs[:, s:s+T])
|
| 229 |
+
app.append(clips, labels, file_id=file_id)
|
| 230 |
+
total += M; total_pos += int(labels.sum()); file_id+=1
|
| 231 |
+
|
| 232 |
+
app.close()
|
| 233 |
+
print(f"[Done] {patient} -> {out_path} | clips={total} (pos={total_pos}, neg={total-total_pos}), C={len(keep_channels)}, T={int(fs_target*clip_sec)}")
|
| 234 |
+
|
| 235 |
+
if __name__=="__main__":
|
| 236 |
+
import argparse
|
| 237 |
+
ap=argparse.ArgumentParser()
|
| 238 |
+
ap.add_argument("--root", required=True)
|
| 239 |
+
ap.add_argument("--patient", required=True)
|
| 240 |
+
ap.add_argument("--out", required=True)
|
| 241 |
+
ap.add_argument("--clip-sec", type=float, default=4.0)
|
| 242 |
+
ap.add_argument("--hop-sec", type=float, default=2.0)
|
| 243 |
+
ap.add_argument("--fs", type=int, default=256)
|
| 244 |
+
ap.add_argument("--min-overlap", type=float, default=0.25)
|
| 245 |
+
ap.add_argument("--k", type=int, default=8)
|
| 246 |
+
ap.add_argument("--sigma", type=float, default=0.5)
|
| 247 |
+
ap.add_argument("--radius", type=float, default=1.6)
|
| 248 |
+
args=ap.parse_args()
|
| 249 |
+
process_patient_1020(Path(args.root), args.patient, Path(args.out),
|
| 250 |
+
clip_sec=args.clip_sec, hop_sec=args.hop_sec, fs_target=args.fs,
|
| 251 |
+
min_overlap_sec=args.min_overlap, graph_k=args.k, sigma=args.sigma, radius=args.radius)
|