# rest_eeg_seizure_analysis/code/rest_prep_1020.py
# Uploaded by uyen1109 in commit 23f06d2 ("Add code: rest_prep_1020.py").
import os, re, math, json, argparse, warnings
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import numpy as np, h5py, pyedflib
# ---------- 10-20 electrode 2D coords (x, y), normalized to a relative scale ----------
# Electrode name -> (x, y) position on a relative 2-D scalp layout.
# Odd-numbered electrodes sit at x < 0 and even-numbered ones at x > 0; y grows
# toward the frontal-pole (FP) row. Old and new names for the same site
# (T3/T7, T4/T8, T5/P7, T6/P8) are listed with identical coordinates.
E1020 = {
    # frontal pole
    "FP1": (-0.5, 1.0), "FP2": (0.5, 1.0),
    # frontal row
    "F7": (-1.0, 0.6), "F3": (-0.3, 0.6), "FZ": (0.0, 0.8),
    "F4": (0.3, 0.6), "F8": (1.0, 0.6),
    # temporal / central row
    "T3": (-1.2, 0.2), "T7": (-1.2, 0.2), "C3": (-0.4, 0.2),
    "CZ": (0.0, 0.3), "C4": (0.4, 0.2), "T4": (1.2, 0.2), "T8": (1.2, 0.2),
    # temporal / parietal row
    "T5": (-1.1, -0.2), "P7": (-1.1, -0.2), "P3": (-0.3, -0.2),
    "PZ": (0.0, -0.1), "P4": (0.3, -0.2), "T6": (1.1, -0.2), "P8": (1.1, -0.2),
    # occipital
    "O1": (-0.5, -0.8), "O2": (0.5, -0.8),
}

# Map the older temporal names onto a single canonical spelling.
ALIASES = {"T3": "T7", "T4": "T8", "T5": "P7", "T6": "P8"}

# Channel labels to ignore: anything containing "ECG" or "VNS" (case-insensitive)
# or ending in "-".
_BLOCKLIST = re.compile(r'(?:ECG|VNS|-$)', flags=re.IGNORECASE)
def _norm_e(name:str)->Optional[str]:
    """
    Canonicalise a single electrode name.

    The name is upper-cased and stripped, then old temporal names are mapped
    through ``ALIASES`` (e.g. "T3" -> "T7"). The canonical name is returned if
    it is a key of ``E1020``; otherwise ``None``.
    """
    key = name.upper().strip()
    key = ALIASES.get(key, key)
    if key not in E1020:
        return None
    return key
def pair_midpoint(ch_label:str)->Optional[Tuple[float,float]]:
    """
    Return the 2-D position of a bipolar channel label of the form 'A-B'.

    CHB-MIT typically uses bipolar channels 'A-B'. The channel is placed at the
    midpoint of electrodes A and B. The label is split at its first '-'.
    Returns ``None`` if there is no '-', or if either side is not a known
    10-20 electrode (after alias normalisation).
    """
    head, sep, tail = ch_label.partition('-')
    if not sep:
        return None
    first = _norm_e(head.strip().upper())
    second = _norm_e(tail.strip().upper())
    if first is None or second is None:
        return None
    (x1, y1), (x2, y2) = E1020[first], E1020[second]
    return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
def build_1020_graph(channels:List[str], k:int=8, sigma:float=0.5, radius:float=1.6):
    """
    Build a static distance graph over bipolar channels using 10-20 positions.

    Each retained channel is placed at ``pair_midpoint(channel)``. Edge weights are
        w_ij = exp(-||v_i - v_j||^2 / sigma^2)   if ||v_i - v_j|| <= radius, else 0.
    Each node then keeps its (at most) k largest positive-weight neighbours. The
    union of these selections is made undirected by emitting both directions.

    Args:
        channels: channel labels. Labels matching ``_BLOCKLIST`` or without a
            10-20 midpoint are dropped.
        k: neighbours kept per node, clamped to [1, C-1].
        sigma: Gaussian kernel width, in the units of the ``E1020`` coordinates.
        radius: distance cutoff beyond which no edge is created.

    Returns:
        (keep, edge_index, edge_weight):
            keep:        retained labels in input order; graph node i is keep[i].
            edge_index:  int64 array of shape (2, E); contains (i, j) and (j, i).
                         An empty graph has shape (2, 0).
            edge_weight: float32 array of shape (E,), aligned with edge_index.

    Raises:
        RuntimeError: if fewer than two channels can be mapped to 10-20 positions.
    """
    keep = []
    coords = []
    for ch in channels:
        if _BLOCKLIST.search(ch):
            continue
        mid = pair_midpoint(ch)
        if mid is not None:
            coords.append(mid)
            keep.append(ch)
    C = len(keep)
    if C < 2:
        raise RuntimeError("Không đủ kênh ánh xạ được sang 10-20")

    coords = np.array(coords, dtype=np.float32)
    D = np.sqrt(((coords[None, :, :] - coords[:, None, :]) ** 2).sum(axis=-1))  # (C, C)
    W = np.exp(-(D ** 2) / (sigma ** 2))
    W[D > radius] = 0.0
    np.fill_diagonal(W, 0.0)

    n_neigh = max(1, min(k, C - 1))  # loop-invariant neighbour budget
    edges = set()
    for i in range(C):
        for j in np.argsort(-W[i])[:n_neigh]:
            a, b = (i, j) if i < j else (j, i)
            # argsort may pick zero-weight entries (self, out of radius); skip them.
            if W[a, b] > 0:
                edges.add((a, b))
    edges = sorted(edges)

    if not edges:
        # No pair lies within `radius`. Return a well-formed empty graph: shape
        # (2, 0), not the shape-(0,) array that np.array([]).T would give.
        return keep, np.zeros((2, 0), dtype=np.int64), np.zeros((0,), dtype=np.float32)

    ei = np.array(edges, dtype=np.int64).T   # (2, E/2), rows (a, b) with a < b
    ei = np.hstack([ei, ei[::-1, :]])        # append the reversed direction
    ew = np.array([W[i, j] for (i, j) in edges], dtype=np.float32)
    ew = np.concatenate([ew, ew])            # W is symmetric, so reuse the weights
    return keep, ei, ew
# ---------- EDF utils ----------
def list_edf_files(patient_dir: Path):
    """Return the regular files matching ``*.edf`` directly inside ``patient_dir``, sorted."""
    candidates = patient_dir.glob("*.edf")
    return sorted(filter(Path.is_file, candidates))
def edf_channel_labels(edf_path: Path):
    """
    Return the stripped signal labels of an EDF file, excluding ``_BLOCKLIST`` matches.

    The reader is closed even if reading a label fails.
    """
    f = pyedflib.EdfReader(str(edf_path))
    try:
        labels = [f.getLabel(i).strip() for i in range(f.signals_in_file)]
    finally:
        f._close()  # always release the file handle
    return [ch for ch in labels if not _BLOCKLIST.search(ch)]
def intersection_channels(edf_paths):
    """
    Return the sorted channel labels shared by the given EDF files.

    Files with no usable (non-blocklisted) labels are ignored rather than
    emptying the intersection. Returns [] when nothing is shared or no file
    contributes any labels.
    """
    shared = None
    for path in edf_paths:
        labels = set(edf_channel_labels(path))
        if not labels:
            continue
        shared = labels if shared is None else shared.intersection(labels)
    if not shared:
        return []
    return sorted(shared)
def read_edf_signals(edf_path: Path, keep_channels):
    """
    Read the requested channels from an EDF file.

    Args:
        edf_path: path of the EDF file.
        keep_channels: labels to read, in output order. Labels are matched
            exactly against the file's stripped signal labels. Requested labels
            that match ``_BLOCKLIST`` are skipped (but must still exist).

    Returns:
        (sigs, fs, out_labels):
            sigs:       float32 array of shape (C, N), one row per returned channel.
            fs:         sampling rate in Hz (rounded to int) of the returned
                        channels. Falls back to the median over all signals (or
                        256) when no channel is selected.
            out_labels: labels of the rows of ``sigs``.

    Raises:
        RuntimeError: if a requested channel is missing, or the selected
            channels do not share one sampling rate.
    """
    f = pyedflib.EdfReader(str(edf_path))
    try:
        n_sig = f.signals_in_file
        labels = [f.getLabel(i).strip() for i in range(n_sig)]
        fs_all = [int(round(f.getSampleFrequency(i))) for i in range(n_sig)]
        idxs = []
        out_labels = []
        for ch in keep_channels:
            try:
                i = labels.index(ch)
            except ValueError:
                raise RuntimeError(f"Channel {ch} not found in {edf_path.name}") from None
            if _BLOCKLIST.search(ch):
                continue
            idxs.append(i)
            out_labels.append(ch)
        # The rate must describe the channels actually read. The median over all
        # signals could be skewed by unread channels (e.g. ECG) sampled differently.
        fs_sel = sorted({fs_all[i] for i in idxs})
        if len(fs_sel) > 1:
            raise RuntimeError(
                f"Channels in {edf_path.name} have mixed sampling rates {fs_sel}")
        if fs_sel:
            fs = fs_sel[0]
        else:
            fs = int(round(np.median(fs_all))) if fs_all else 256
        sigs = np.vstack([f.readSignal(i) for i in idxs]).astype(np.float32)  # (C, N)
    finally:
        f._close()  # always release the file handle, also on error
    return sigs, fs, out_labels
# ---------- seizure labels ----------
def parse_summary(summary_path: Path) -> Dict[str, List[Tuple[float,float]]]:
    """
    Parse a CHB-MIT-style ``*summary*.txt`` file into seizure intervals per EDF file.

    The parser looks for ``File Name: <name>.edf`` lines and, after each, lines of
    the form ``Seizure Start Time: <t> seconds`` / ``Seizure End Time: <t> seconds``.
    The numbered form ``Seizure <n> Start Time: <t> seconds`` is also accepted;
    only the value after ``Time`` is used, never the seizure index.

    Args:
        summary_path: path of the summary file. ``None`` or a missing path
            yields an empty mapping.

    Returns:
        EDF file name -> list of (start_sec, end_sec). The i-th start is paired
        with the i-th end; an unmatched trailing start or end is dropped. Files
        without any complete pair are absent from the mapping.
    """
    mapping: Dict[str, List[Tuple[float, float]]] = {}
    if not (summary_path and summary_path.exists()):
        return mapping

    file_re = re.compile(r'File Name:\s*(\S+\.edf)', re.IGNORECASE)
    # Capture only the time value. A bare number scan also picks up the seizure
    # index in "Seizure 1 Start Time: ..." lines and corrupts the pairing.
    time_re = re.compile(
        r'Seizure(?:\s+\d+)?\s+(Start|End)\s+Time\s*:?\s*(\d+(?:\.\d+)?)',
        re.IGNORECASE)

    def _emit(name, starts, ends):
        # Record the completed (start, end) pairs collected for one file.
        pairs = list(zip(starts, ends))
        if name and pairs:
            mapping.setdefault(name, []).extend(pairs)

    curr = None
    starts: List[float] = []
    ends: List[float] = []
    with summary_path.open("r", errors="ignore") as f:
        for line in f:
            line = line.strip()
            mfile = file_re.search(line)
            if mfile:
                _emit(curr, starts, ends)
                curr = mfile.group(1)
                starts, ends = [], []
                continue
            mtime = time_re.search(line)
            if mtime:
                target = starts if mtime.group(1).lower() == "start" else ends
                target.append(float(mtime.group(2)))
    _emit(curr, starts, ends)
    return mapping
def parse_seizures_file(seiz_file: Path) -> List[Tuple[float,float]]:
    """
    Parse seizure intervals from a plain-text sidecar file.

    Each line whose numbers include at least two values contributes
    (first, second) as a (start_sec, end_sec) interval.

    Returns [] if the file is missing or looks binary (contains NUL bytes). In
    that case the caller falls back to the summary file.
    NOTE(review): the CHB-MIT ``*.edf.seizures`` files are, to my knowledge,
    binary WFDB annotation files. Scanning them for digit characters produced
    bogus intervals, so they are skipped here rather than decoded.
    """
    intervals = []
    if not seiz_file.exists():
        return intervals
    raw = seiz_file.read_bytes()
    if b"\x00" in raw:
        return intervals
    num_re = re.compile(r'[-+]?\d*\.?\d+')
    for line in raw.decode("utf-8", errors="ignore").splitlines():
        nums = [float(x) for x in num_re.findall(line)]
        if len(nums) >= 2:
            intervals.append((nums[0], nums[1]))
    return intervals
def slice_starts(N, fs, clip_sec, hop_sec):
    """
    Return the start indices of sliding windows over a length-N signal.

    Windows are T = int(fs*clip_sec) samples long and advance by
    hop = int(fs*hop_sec) samples. Only windows that lie entirely inside [0, N)
    are produced. The result is an int64 array, empty when N < T.

    Raises:
        ValueError: if the window length or the hop is shorter than one sample.
    """
    T = int(fs * clip_sec)
    hop = int(fs * hop_sec)
    if T < 1:
        raise ValueError(f"clip_sec={clip_sec} gives a window of {T} samples at fs={fs}")
    if hop < 1:
        # np.arange with a zero step would fail with ZeroDivisionError, and a
        # negative step would silently return no windows.
        raise ValueError(f"hop_sec={hop_sec} gives a hop of {hop} samples at fs={fs}")
    if N < T:
        return np.zeros((0,), dtype=np.int64)
    return np.arange(0, N - T + 1, hop, dtype=np.int64)
def zscore_perclip(clip):
    """
    Standardise a clip row-wise: subtract the mean and divide by the standard deviation.

    Statistics are computed along axis 1, the time axis of the (C, T) clips this
    module passes in. 1e-8 is added to the standard deviation to avoid dividing
    by zero on constant rows.
    """
    row_mean = clip.mean(axis=1, keepdims=True)
    row_std = clip.std(axis=1, keepdims=True) + 1e-8
    centered = clip - row_mean
    return centered / row_std
def label_for_window(start, T, fs, intervals, min_overlap_sec):
    """
    Label a window 1 if its total overlap with seizure intervals is large enough.

    Args:
        start: first sample index of the window; the window covers [start, start+T).
        T: window length in samples.
        fs: sampling rate (Hz) used to convert interval times to sample indices.
        intervals: iterable of (start_sec, end_sec) seizure intervals.
        min_overlap_sec: required overlap in seconds, summed over all intervals.
            At least one sample of real overlap is always required.

    Returns:
        1 if the summed overlap reaches the threshold, else 0.
    """
    a, b = start, start + T
    # Clamp to >= 1 sample. Otherwise min_overlap_sec == 0 (or any value that
    # rounds to 0 samples) labels every window positive once `intervals` is non-empty.
    thr = max(1, int(round(min_overlap_sec * fs)))
    overlap = 0
    for (u, v) in intervals:
        u_s = int(round(u * fs))
        v_s = int(round(v * fs))
        overlap += max(0, min(b, v_s) - max(a, u_s))
        if overlap >= thr:
            return 1
    return 0
class H5Appender:
    """
    Append fixed-size clips, labels and file ids to a new HDF5 file in batches.

    ``out_path`` is opened with mode "w" (an existing file is truncated) and gets:
        clips        float32, shape (N, C, T, 1), resizable along axis 0,
                     chunked 16 clips at a time
        labels       int64,   shape (N,)
        file_ids     int32,   shape (N,)
        channels     channel labels encoded with ``str.encode()``
        edge_index   int64 copy of ``edge_index``
        edge_weight  float32 copy of ``edge_weight``
    plus attributes ``fs``, ``T`` and ``patient`` (the stem of ``out_path``).
    The three per-clip datasets are gzip-compressed at ``gzip_level``.

    Instances can be used as context managers so the file is closed even when
    processing raises.
    """
    def __init__(self, out_path: Path, C: int, T: int, fs: int, channels, edge_index, edge_weight, gzip_level=4):
        self.f = h5py.File(str(out_path), "w")
        self.ds_clips = self.f.create_dataset(
            "clips", shape=(0, C, T, 1), maxshape=(None, C, T, 1),
            dtype="float32", chunks=(16, C, T, 1),
            compression="gzip", compression_opts=gzip_level)
        self.ds_labels = self.f.create_dataset(
            "labels", shape=(0,), maxshape=(None,), dtype="i8",
            chunks=True, compression="gzip", compression_opts=gzip_level)
        self.ds_fileids = self.f.create_dataset(
            "file_ids", shape=(0,), maxshape=(None,), dtype="i4",
            chunks=True, compression="gzip", compression_opts=gzip_level)
        self.f.attrs["fs"] = fs
        self.f.attrs["T"] = T
        self.f.attrs["patient"] = out_path.stem
        self.f.create_dataset("channels", data=np.array([c.encode() for c in channels]))
        self.f.create_dataset("edge_index", data=edge_index.astype(np.int64))
        self.f.create_dataset("edge_weight", data=edge_weight.astype(np.float32))
        self.n = 0  # number of clips written so far

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()  # returns None, so any exception still propagates

    def append(self, clips_CT, labels, file_id:int):
        """
        Append M clips with their labels, all tagged with ``file_id``.

        Args:
            clips_CT: array of shape (M, C, T); a trailing singleton axis is added
                when writing. An empty array is a no-op.
            labels: M integer labels (any array-like).
            file_id: id stored for every appended clip.

        Raises:
            ValueError: if the shapes do not match the datasets. This is checked
                before any dataset is resized, so the file stays consistent.
        """
        if clips_CT.size == 0:
            return
        M, C, T = clips_CT.shape
        expected = tuple(self.ds_clips.shape[1:3])
        if (C, T) != expected:
            raise ValueError(f"clips have (C, T)={(C, T)}, dataset expects {expected}")
        labels = np.asarray(labels, dtype=np.int64)
        if labels.shape != (M,):
            raise ValueError(f"labels must have shape ({M},), got {labels.shape}")
        clips = clips_CT[..., None].astype(np.float32)
        end = self.n + M
        for ds in (self.ds_clips, self.ds_labels, self.ds_fileids):
            ds.resize(end, axis=0)
        self.ds_clips[self.n:end] = clips
        self.ds_labels[self.n:end] = labels
        self.ds_fileids[self.n:end] = np.full((M,), file_id, dtype=np.int32)
        self.n = end

    def close(self):
        """Close the underlying HDF5 file."""
        self.f.close()
def _resample_linear(sigs, fs, fs_target):
    """
    Resample (C, N) signals from ``fs`` to ``fs_target`` Hz by linear interpolation.

    The output has round(N * fs_target / fs) samples spread evenly over input
    positions [0, N-1], and is returned as float32.
    NOTE(review): no anti-aliasing filter is applied, so downsampling can alias
    content above the new Nyquist frequency.
    """
    C, N = sigs.shape
    N_new = int(round(N * (fs_target / fs)))
    if N == 0 or N_new == 0:
        return np.zeros((C, N_new), dtype=np.float32)
    # Use float64 sample positions. float32 cannot resolve fractional positions
    # in long recordings (its spacing near 1e6 is 0.0625 samples).
    t_old = np.arange(N, dtype=np.float64)
    t_new = np.linspace(0, N - 1, N_new, dtype=np.float64)
    return np.stack([np.interp(t_new, t_old, ch) for ch in sigs], axis=0).astype(np.float32)


def process_patient_1020(root: Path, patient: str, out_path: Path,
                         clip_sec: float=4.0, hop_sec: float=2.0,
                         fs_target: int=256, min_overlap_sec: float=0.25,
                         graph_k: int=8, sigma: float=0.5, radius: float=1.6):
    """
    Preprocess one patient's EDF recordings into a single HDF5 clip dataset.

    Steps:
      1. Keep the channels common to the patient's EDFs (see
         ``intersection_channels``) that map to 10-20 positions, and build the
         static 10-20 graph over them.
      2. Collect seizure intervals per file from ``<file>.edf.seizures``
         (preferred) or from the first ``*summary*.txt`` in sorted order.
      3. For each EDF:
         - resample to ``fs_target`` if needed;
         - cut windows of ``clip_sec`` seconds every ``hop_sec`` seconds;
         - z-score each window per channel and label it with ``label_for_window``;
         - append the clips to ``out_path``.

    Args:
        root: dataset root containing one directory per patient.
        patient: patient directory name under ``root``.
        out_path: HDF5 file to create (overwritten if it exists).
        clip_sec, hop_sec: window length and stride in seconds.
        fs_target: sampling rate (Hz) of the stored clips.
        min_overlap_sec: seizure overlap needed for a positive label.
        graph_k, sigma, radius: parameters forwarded to ``build_1020_graph``.

    Raises:
        FileNotFoundError: if the patient directory is missing or has no EDF files.
        RuntimeError: propagated from ``build_1020_graph`` / ``read_edf_signals``.
    """
    pat_dir = root / patient
    if not pat_dir.exists():
        raise FileNotFoundError(f"Not found: {pat_dir}")
    edfs = list_edf_files(pat_dir)
    if not edfs:
        raise FileNotFoundError(f"No EDF in {pat_dir}")

    # Channels shared by the patient's recordings...
    keep_channels = intersection_channels(edfs)
    # ...restricted to those that map to 10-20 positions (drop unknown channels).
    keep_channels = [ch for ch in keep_channels if pair_midpoint(ch) is not None]
    if len(keep_channels) < 8:
        warnings.warn(f"[{patient}] chỉ còn {len(keep_channels)} kênh sau khi map 10-20")
    # Static 10-20 graph.
    keep_channels, ei, ew = build_1020_graph(keep_channels, k=graph_k, sigma=sigma, radius=radius)

    # Seizure intervals: per-file sidecars take priority over the summary file.
    summ = next(iter(sorted(pat_dir.glob("*summary*.txt"))), None)
    summ_map = parse_summary(summ) if summ else {}
    seiz_map = {}
    for p in edfs:
        iv = parse_seizures_file(p.with_name(p.name + ".seizures"))
        if iv:
            seiz_map[p.name] = iv

    def intervals_for(name):
        return seiz_map.get(name, summ_map.get(name, []))

    T = int(fs_target * clip_sec)
    total = 0
    total_pos = 0
    app = None
    try:
        for file_id, edf in enumerate(edfs):
            sigs, fs, chans = read_edf_signals(edf, keep_channels)
            if fs != fs_target:
                sigs = _resample_linear(sigs, fs, fs_target)
            C, N = sigs.shape
            starts = slice_starts(N, fs_target, clip_sec, hop_sec)
            if app is None:
                app = H5Appender(out_path, C=C, T=T, fs=fs_target, channels=chans,
                                 edge_index=ei, edge_weight=ew)
            ivals = intervals_for(edf.name)
            labels = np.array([label_for_window(int(s), T, fs_target, ivals, min_overlap_sec)
                               for s in starts], dtype=np.int64)
            M = len(starts)
            clips = np.empty((M, C, T), dtype=np.float32)
            for i, s in enumerate(starts):
                clips[i] = zscore_perclip(sigs[:, s:s + T])
            app.append(clips, labels, file_id=file_id)
            total += M
            total_pos += int(labels.sum())
    finally:
        # Close the HDF5 file even if a recording fails, so what was written stays readable.
        if app is not None:
            app.close()
    print(f"[Done] {patient} -> {out_path} | clips={total} (pos={total_pos}, neg={total-total_pos}), C={len(keep_channels)}, T={int(fs_target*clip_sec)}")
if __name__ == "__main__":
    # Command-line entry point: preprocess one patient directory into one HDF5 file.
    cli = argparse.ArgumentParser()
    for required_flag in ("--root", "--patient", "--out"):
        cli.add_argument(required_flag, required=True)
    optional_flags = (
        ("--clip-sec", float, 4.0),
        ("--hop-sec", float, 2.0),
        ("--fs", int, 256),
        ("--min-overlap", float, 0.25),
        ("--k", int, 8),
        ("--sigma", float, 0.5),
        ("--radius", float, 1.6),
    )
    for flag, flag_type, flag_default in optional_flags:
        cli.add_argument(flag, type=flag_type, default=flag_default)
    opts = cli.parse_args()
    process_patient_1020(
        Path(opts.root), opts.patient, Path(opts.out),
        clip_sec=opts.clip_sec,
        hop_sec=opts.hop_sec,
        fs_target=opts.fs,
        min_overlap_sec=opts.min_overlap,
        graph_k=opts.k,
        sigma=opts.sigma,
        radius=opts.radius,
    )