|
|
import os, re, math, json, argparse, warnings |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple, Dict, Optional |
|
|
import numpy as np, h5py, pyedflib |
|
|
|
|
|
|
|
|
# Approximate 2-D layout coordinates for the international 10-20 EEG montage.
# x runs left (-) to right (+); y runs frontal (+) to occipital (-).  These are
# schematic positions, not physical head measurements — they are only used for
# relative distances when building the spatial channel graph.
E1020 = {
    "FP1":(-0.5, 1.0), "FP2":( 0.5, 1.0),
    "F7": (-1.0, 0.6), "F3":(-0.3, 0.6), "FZ":(0.0,0.8), "F4":( 0.3, 0.6), "F8":( 1.0, 0.6),
    # Legacy (T3/T4/T5/T6) and modern (T7/T8/P7/P8) temporal names share coordinates.
    "T3":(-1.2, 0.2), "T7":(-1.2, 0.2), "C3":(-0.4, 0.2), "CZ":(0.0,0.3), "C4":( 0.4, 0.2), "T4":( 1.2, 0.2), "T8":( 1.2, 0.2),
    "T5":(-1.1, -0.2), "P7":(-1.1, -0.2), "P3":(-0.3, -0.2), "PZ":(0.0,-0.1), "P4":( 0.3, -0.2), "T6":( 1.1, -0.2), "P8":( 1.1, -0.2),
    "O1":(-0.5, -0.8), "O2":( 0.5, -0.8)
}

# Map legacy 10-20 temporal electrode names to their modern equivalents.
ALIASES = {"T3":"T7","T4":"T8","T5":"P7","T6":"P8"}

# Channel labels to discard: anything containing ECG or VNS, or ending in '-'
# (presumably a dangling/incomplete bipolar pair label — TODO confirm against data).
_BLOCKLIST = re.compile(r'(?:ECG|VNS|-$)', re.IGNORECASE)
|
|
|
|
|
def _norm_e(name: str) -> Optional[str]:
    """Normalize an electrode label to its canonical 10-20 name.

    Uppercases and strips the label, rewrites legacy names via ALIASES,
    and returns None when the result is not a known 10-20 electrode.
    """
    label = name.upper().strip()
    candidate = ALIASES.get(label, label)
    if candidate in E1020:
        return candidate
    return None
|
|
|
|
|
def pair_midpoint(ch_label: str) -> Optional[Tuple[float, float]]:
    """
    CHB-MIT channels are typically bipolar 'A-B' pairs; use the midpoint of
    the two electrode coordinates as the channel position.  Returns None when
    the label is not bipolar or either electrode cannot be mapped to 10-20.
    """
    if '-' not in ch_label:
        return None
    left_raw, right_raw = ch_label.split('-', 1)
    left = _norm_e(left_raw.strip().upper())
    right = _norm_e(right_raw.strip().upper())
    if left is None or right is None:
        return None
    (x1, y1), (x2, y2) = E1020[left], E1020[right]
    return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
|
|
|
|
|
def build_1020_graph(channels: List[str], k: int = 8, sigma: float = 0.5, radius: float = 1.6):
    """Build a static distance graph over the 10-20 electrode layout.

    Edge weights follow a truncated Gaussian kernel:
        w_ij = exp(-||vi - vj||^2 / sigma^2)  if ||vi - vj|| <= radius, else 0.
    Each node then keeps only its k strongest neighbours (undirected), and the
    edge list is mirrored so ``edge_index`` contains both directions.

    Parameters
    ----------
    channels : bipolar channel labels ('A-B'); blocklisted or unmappable ones are dropped.
    k : maximum neighbours kept per node.
    sigma : Gaussian kernel bandwidth.
    radius : hard distance cutoff.

    Returns
    -------
    (kept_channels, edge_index of shape (2, E) int64, edge_weight of shape (E,) float32)

    Raises
    ------
    RuntimeError if fewer than two channels map onto the 10-20 layout.
    """
    coords = []
    keep = []
    for ch in channels:
        if _BLOCKLIST.search(ch):
            continue
        m = pair_midpoint(ch)
        if m is not None:
            coords.append(m)
            keep.append(ch)
    C = len(keep)
    if C < 2:
        raise RuntimeError("Không đủ kênh ánh xạ được sang 10-20")
    coords = np.array(coords, dtype=np.float32)
    # Pairwise Euclidean distances, then truncated Gaussian weights.
    D = np.sqrt(((coords[None, :, :] - coords[:, None, :]) ** 2).sum(axis=-1))
    W = np.exp(-(D ** 2) / (sigma ** 2))
    W[D > radius] = 0.0
    np.fill_diagonal(W, 0.0)  # no self-loops

    # Per-node top-k selection; edges stored with i < j so they are undirected.
    edges = set()
    for i in range(C):
        idx = np.argsort(-W[i])[:max(1, min(k, C - 1))]
        for j in idx:
            a, b = (i, j) if i < j else (j, i)
            if W[a, b] > 0:
                edges.add((a, b))
    edges = sorted(edges)
    if edges:
        ei = np.array(edges, dtype=np.int64).T      # (2, E)
        ei = np.hstack([ei, ei[::-1, :]])           # mirror to get both directions
        ew = np.array([W[i, j] for (i, j) in edges], dtype=np.float32)
        ew = np.concatenate([ew, ew])
    else:
        # BUGFIX: previously an empty edge list yielded a rank-1 array of shape
        # (0,); downstream code expects a well-formed (2, 0) edge_index.
        ei = np.zeros((2, 0), dtype=np.int64)
        ew = np.zeros((0,), dtype=np.float32)
    return keep, ei, ew
|
|
|
|
|
|
|
|
def list_edf_files(patient_dir: Path):
    """Return the patient's *.edf files (regular files only), sorted by path."""
    return sorted(entry for entry in patient_dir.glob("*.edf") if entry.is_file())
|
|
|
|
|
def edf_channel_labels(edf_path: Path):
    """Return the EDF file's channel labels with blocklisted channels removed.

    Opens the file read-only via pyedflib and always releases the handle,
    even if reading a label raises (the original leaked the handle on error).
    """
    reader = pyedflib.EdfReader(str(edf_path))
    try:
        labels = [reader.getLabel(i).strip() for i in range(reader.signals_in_file)]
    finally:
        # pyedflib's reader releases the file descriptor via _close().
        reader._close()
        del reader
    return [ch for ch in labels if not _BLOCKLIST.search(ch)]
|
|
|
|
|
def intersection_channels(edf_paths):
    """Return the sorted set of channels present in every EDF file.

    Files yielding no readable channels are skipped; returns [] when nothing
    is common (or no file contributed channels).
    """
    common = None
    for path in edf_paths:
        labels = set(edf_channel_labels(path))
        if not labels:
            continue
        if common is None:
            common = labels
        else:
            common &= labels
    return sorted(common) if common else []
|
|
|
|
|
def read_edf_signals(edf_path: Path, keep_channels):
    """Read the requested channels from an EDF file.

    Parameters
    ----------
    edf_path : path to the EDF recording.
    keep_channels : channel labels to extract, in the desired output order.

    Returns
    -------
    (signals (C, N) float32, fs, out_labels) where fs is the rounded median
    sampling rate across all signals in the file (fallback 256 when empty).

    Raises
    ------
    RuntimeError when a requested (non-blocklisted) channel is missing.

    The reader handle is released in all paths via try/finally (the original
    leaked it if readSignal or vstack raised).
    """
    reader = pyedflib.EdfReader(str(edf_path))
    try:
        labels = [reader.getLabel(i).strip() for i in range(reader.signals_in_file)]
        fs_all = [int(round(reader.getSampleFrequency(i))) for i in range(reader.signals_in_file)]
        fs = int(round(np.median(fs_all))) if fs_all else 256
        idxs = []
        out_labels = []
        for ch in keep_channels:
            # Skip blocklisted channels even if a caller requested them.
            if _BLOCKLIST.search(ch):
                continue
            try:
                i = labels.index(ch)
            except ValueError:
                raise RuntimeError(f"Channel {ch} not found in {edf_path.name}")
            idxs.append(i)
            out_labels.append(ch)
        sigs = np.vstack([reader.readSignal(i) for i in idxs]).astype(np.float32)
    finally:
        reader._close()
        del reader
    return sigs, fs, out_labels
|
|
|
|
|
|
|
|
def parse_summary(summary_path: Path) -> Dict[str, List[Tuple[float, float]]]:
    """Parse a CHB-MIT *-summary.txt into {edf_name: [(start_s, end_s), ...]}.

    Handles both annotation formats found in CHB-MIT:
      * "Seizure Start Time: 2996 seconds"     (most patients)
      * "Seizure 1 Start Time: 480 seconds"    (numbered seizures, e.g. chb24)

    BUGFIX: the previous regex missed the numbered form entirely, and
    buf.extend(all numbers on the line) would have absorbed the seizure index
    into the (start, end) pairing.  Now only the last number on a matching
    line (the time in seconds) is kept.

    Returns an empty dict when the path is None or does not exist.
    """
    mapping: Dict[str, List[Tuple[float, float]]] = {}
    if not (summary_path and summary_path.exists()):
        return mapping

    def _flush(name, times):
        # Pair up accumulated [start, end, start, end, ...]; odd trailing
        # values are dropped, matching the original behaviour.
        if name and times:
            pairs = [(times[i], times[i + 1]) for i in range(0, len(times) - 1, 2)]
            mapping.setdefault(name, []).extend(pairs)

    # Optional seizure index allowed between "Seizure" and "Start/End".
    time_line = re.compile(r'Seizure\s+\d*\s*(Start|End)\s+Time', re.IGNORECASE)
    curr = None
    buf: List[float] = []
    with summary_path.open("r", errors="ignore") as f:
        for line in f:
            line = line.strip()
            mfile = re.search(r'File Name:\s*(\S+\.edf)', line, re.IGNORECASE)
            if mfile:
                _flush(curr, buf)
                curr = mfile.group(1)
                buf = []
                continue
            if time_line.search(line):
                nums = re.findall(r'[\d.]+', line)
                if nums:
                    # The time is always the last number; a leading number is
                    # the seizure index in the numbered format.
                    buf.append(float(nums[-1]))
    _flush(curr, buf)
    return mapping
|
|
|
|
|
def parse_seizures_file(seiz_file: Path) -> List[Tuple[float, float]]:
    """Parse a textual '.seizures' sidecar file.

    Each line containing at least two numbers contributes one
    (start_seconds, end_seconds) interval (first two numbers on the line).
    Missing files yield an empty list.
    """
    if not seiz_file.exists():
        return []
    intervals: List[Tuple[float, float]] = []
    with seiz_file.open("r", errors="ignore") as fh:
        for raw_line in fh:
            found = re.findall(r'[-+]?\d*\.?\d+', raw_line)
            if len(found) >= 2:
                intervals.append((float(found[0]), float(found[1])))
    return intervals
|
|
|
|
|
def slice_starts(N, fs, clip_sec, hop_sec):
    """Start indices (in samples) for clips of clip_sec seconds every hop_sec seconds.

    Returns an empty int64 array when the signal is shorter than one clip.
    """
    win = int(fs * clip_sec)
    step = int(fs * hop_sec)
    if N < win:
        return np.zeros((0,), dtype=np.int64)
    return np.arange(0, N - win + 1, step, dtype=np.int64)
|
|
|
|
|
def zscore_perclip(clip):
    """Z-score each channel of a (C, T) clip independently.

    A small epsilon on the standard deviation keeps flat channels finite.
    """
    centered = clip - clip.mean(axis=1, keepdims=True)
    return centered / (clip.std(axis=1, keepdims=True) + 1e-8)
|
|
|
|
|
def label_for_window(start, T, fs, intervals, min_overlap_sec):
    """Return 1 if the window [start, start+T) (samples) overlaps the seizure
    intervals (given in seconds) by at least min_overlap_sec, else 0.

    Overlap is accumulated across intervals; the check happens after each
    interval, so the threshold can be met by several partial overlaps.
    """
    lo, hi = start, start + T
    need = int(round(min_overlap_sec * fs))
    covered = 0
    for u, v in intervals:
        seg_lo = int(round(u * fs))
        seg_hi = int(round(v * fs))
        covered += max(0, min(hi, seg_hi) - max(lo, seg_lo))
        if covered >= need:
            return 1
    return 0
|
|
|
|
|
class H5Appender:
    """Incremental HDF5 writer for EEG clips.

    Creates three resizable, gzip-compressed datasets (clips, labels,
    file_ids) plus static metadata (fs, T, patient, channel names, and the
    graph's edge_index/edge_weight).  `append` grows all three datasets in
    lockstep; `close` releases the file.
    """

    def __init__(self, out_path: Path, C: int, T: int, fs: int, channels, edge_index, edge_weight, gzip_level=4):
        self.f = h5py.File(str(out_path), "w")
        # Resizable along axis 0; chunked so appends stay cheap.
        self.ds_clips = self.f.create_dataset(
            "clips", shape=(0, C, T, 1), maxshape=(None, C, T, 1),
            dtype="float32", chunks=(16, C, T, 1),
            compression="gzip", compression_opts=gzip_level)
        self.ds_labels = self.f.create_dataset(
            "labels", shape=(0,), maxshape=(None,), dtype="i8",
            chunks=True, compression="gzip", compression_opts=gzip_level)
        self.ds_fileids = self.f.create_dataset(
            "file_ids", shape=(0,), maxshape=(None,), dtype="i4",
            chunks=True, compression="gzip", compression_opts=gzip_level)
        # Static, write-once metadata.
        self.f.attrs["fs"] = fs
        self.f.attrs["T"] = T
        self.f.attrs["patient"] = out_path.stem
        self.f.create_dataset("channels", data=np.array([c.encode() for c in channels]))
        self.f.create_dataset("edge_index", data=edge_index.astype(np.int64))
        self.f.create_dataset("edge_weight", data=edge_weight.astype(np.float32))
        self.n = 0  # number of clips written so far

    def append(self, clips_CT, labels, file_id: int):
        """Append a batch of (M, C, T) clips with their labels and source file id."""
        if clips_CT.size == 0:
            return
        M, _C, _T = clips_CT.shape
        end = self.n + M
        for ds in (self.ds_clips, self.ds_labels, self.ds_fileids):
            ds.resize(end, axis=0)
        # Trailing singleton axis matches the (N, C, T, 1) dataset layout.
        self.ds_clips[self.n:end] = clips_CT[..., None].astype(np.float32)
        self.ds_labels[self.n:end] = labels.astype(np.int64)
        self.ds_fileids[self.n:end] = np.full((M,), file_id, dtype=np.int32)
        self.n = end

    def close(self):
        """Close the underlying HDF5 file."""
        self.f.close()
|
|
|
|
|
def process_patient_1020(root: Path, patient: str, out_path: Path,
                         clip_sec: float=4.0, hop_sec: float=2.0,
                         fs_target: int=256, min_overlap_sec: float=0.25,
                         graph_k: int=8, sigma: float=0.5, radius: float=1.6):
    """End-to-end preprocessing for one CHB-MIT patient.

    Reads every EDF in root/patient, keeps the channels common to all files
    that can be placed on the 10-20 layout, builds one static distance graph,
    slices each recording into z-scored clips of clip_sec seconds every
    hop_sec seconds (linearly resampled to fs_target when needed), labels
    each clip from seizure annotations, and appends everything to an HDF5
    file at out_path via H5Appender.

    Raises AssertionError when the patient directory or its EDF files are
    missing; RuntimeError (from build_1020_graph / read_edf_signals) on
    unmappable or missing channels.
    """
    pat_dir=root/patient
    assert pat_dir.exists(), f"Not found: {pat_dir}"
    edfs=list_edf_files(pat_dir); assert edfs, f"No EDF in {pat_dir}"

    # Channels present in every EDF, restricted to those mappable onto 10-20.
    keep_channels=intersection_channels(edfs)
    keep_channels=[ch for ch in keep_channels if pair_midpoint(ch) is not None]
    if len(keep_channels)<8:
        # Warning text is Vietnamese: "only N channels left after 10-20 mapping".
        warnings.warn(f"[{patient}] chỉ còn {len(keep_channels)} kênh sau khi map 10-20")

    # One static spatial graph shared by every clip of this patient.
    keep_channels, ei, ew = build_1020_graph(keep_channels, k=graph_k, sigma=sigma, radius=radius)

    # Seizure intervals: per-file ".seizures" sidecars take precedence over
    # the patient-level summary file (see intervals_for below).
    summ = next(iter(list(pat_dir.glob("*summary*.txt"))), None)
    summ_map=parse_summary(summ) if summ else {}
    seiz_map={}
    for p in edfs:
        iv=parse_seizures_file(p.with_suffix(p.suffix+".seizures"))
        if iv: seiz_map[p.name]=iv
    def intervals_for(name): return seiz_map.get(name, summ_map.get(name, []))

    app=None; file_id=0
    total_pos=0; total=0

    for edf in edfs:
        sigs, fs, chans=read_edf_signals(edf, keep_channels)

        # Resample to fs_target by per-channel linear interpolation over the
        # original sample grid.
        if fs!=fs_target:
            ratio=fs_target/fs
            N_new=int(round(sigs.shape[1]*ratio))
            t_old=np.linspace(0, sigs.shape[1]-1, sigs.shape[1], dtype=np.float32)
            t_new=np.linspace(0, sigs.shape[1]-1, N_new, dtype=np.float32)
            sigs=np.stack([np.interp(t_new, t_old, ch) for ch in sigs], axis=0).astype(np.float32)

        C,N=sigs.shape; T=int(fs_target*clip_sec)
        starts=slice_starts(N, fs_target, clip_sec, hop_sec)
        if app is None:
            # Writer is created lazily so C comes from the first decoded file.
            app=H5Appender(out_path, C=C, T=T, fs=fs_target, channels=chans, edge_index=ei, edge_weight=ew)
        ivals=intervals_for(edf.name)
        labels=np.array([label_for_window(int(s), T, fs_target, ivals, min_overlap_sec) for s in starts], dtype=np.int64)
        M=len(starts); clips=np.empty((M,C,T), dtype=np.float32)
        # Per-clip z-scoring (channel-wise) before writing.
        for i,s in enumerate(starts): clips[i]=zscore_perclip(sigs[:, s:s+T])
        app.append(clips, labels, file_id=file_id)
        total += M; total_pos += int(labels.sum()); file_id+=1

    app.close()
    print(f"[Done] {patient} -> {out_path} | clips={total} (pos={total_pos}, neg={total-total_pos}), C={len(keep_channels)}, T={int(fs_target*clip_sec)}")
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: preprocess one patient directory into a single HDF5 file.
    # (Removed the redundant `import argparse`; it is already imported at module top.)
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True)
    ap.add_argument("--patient", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--clip-sec", type=float, default=4.0)
    ap.add_argument("--hop-sec", type=float, default=2.0)
    ap.add_argument("--fs", type=int, default=256)
    ap.add_argument("--min-overlap", type=float, default=0.25)
    ap.add_argument("--k", type=int, default=8)
    ap.add_argument("--sigma", type=float, default=0.5)
    ap.add_argument("--radius", type=float, default=1.6)
    args = ap.parse_args()
    process_patient_1020(Path(args.root), args.patient, Path(args.out),
                         clip_sec=args.clip_sec, hop_sec=args.hop_sec, fs_target=args.fs,
                         min_overlap_sec=args.min_overlap, graph_k=args.k, sigma=args.sigma, radius=args.radius)
|
|
|