# rest_eeg / app.py — Hugging Face Space file by uyen1109
# Last commit: 85829c7 "Update app.py" (verified)
# app.py — REST EEG Seizure Demo (GraphConv, supports W2/U, stats + top-k + span)
import os, io, json
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py, matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from torch_geometric.nn import GraphConv
# ---------------------------
# 0) Config & device
# ---------------------------
MODEL_REPO = os.getenv("MODEL_REPO", "uyen1109/rest_eeg_seizure_analysis")
SPACE_DIR = "space_infer"
device = torch.device("cpu")
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 2 so the division below still yields at least one thread instead of
# raising TypeError at import time.
torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))
# ==========================================================
# 1) Model (GraphConv; parameter names match the checkpoint, optional W2/U)
# Required: W1, gc1, gc2, fc
# Optional: W2 (present when training used it) & U (transition applied to S)
# ==========================================================
class RESTNet(nn.Module):
    """Recurrent graph model that scores every frame of a clip.

    For each frame ``t`` the per-node state ``S`` ([N, state_q]) is updated from
    the frame's features, passed through two ``GraphConv`` layers, and read out
    by ``fc``; the node-averaged logit is the frame's score.

    Attribute names (``W1``, ``gc1``, ``gc2``, ``fc``, ``W2``, ``U``) mirror the
    checkpoint keys so a strict ``load_state_dict`` can succeed.
    """

    def __init__(
        self,
        in_dim: int,
        state_q: int = 64,
        w2_in: int | None = None,  # None => no W2; otherwise in_dim (acts on x_t) or state_q (acts on S)
        u_in: int | None = None,   # None => no U; otherwise must equal state_q
    ):
        super().__init__()
        self.in_dim = in_dim
        self.state_q = state_q
        # Names must match the checkpoint.
        self.W1 = nn.Linear(in_dim, state_q)
        self.gc1 = GraphConv(state_q, state_q, aggr="mean")
        self.gc2 = GraphConv(state_q, state_q, aggr="mean")
        self.fc = nn.Linear(state_q, 1)
        # Optional W2 & U, only built when the checkpoint provides them.
        self._use_W2 = False
        self._use_U = False
        # Which input W2 consumes is decided once here, so forward() does not
        # have to re-guess it from runtime shapes (ambiguous if in_dim == state_q).
        self._w2_on_x = False
        if w2_in is not None:
            if w2_in == in_dim:
                # W2 acts on the frame features x_t.
                self.W2 = nn.Linear(in_dim, state_q)
                self._use_W2 = True
                self._w2_on_x = True
            elif w2_in == state_q:
                # W2 acts on the state S.
                self.W2 = nn.Linear(state_q, state_q)
                self._use_W2 = True
        if u_in is not None and u_in == state_q:
            # No bias, so the only parameter is 'U.weight' as in the checkpoint.
            self.U = nn.Linear(state_q, state_q, bias=False)
            self._use_U = True

    @torch.no_grad()
    def forward(self, x_ntf, edge_index, edge_weight=None):
        """Run the recurrence over all frames of one clip.

        Args:
            x_ntf: node features of shape [N, T, F]; F must equal ``in_dim``.
            edge_index: graph connectivity forwarded to both ``GraphConv`` layers.
            edge_weight: optional edge weights forwarded to both ``GraphConv`` layers.

        Returns:
            ``(clip_prob, frame_probs)``: ``frame_probs`` is a 1-D tensor of
            shape [T] with the sigmoid of each frame's node-averaged logit, and
            ``clip_prob`` is their mean as a Python float. When T == 0 the
            result is ``(nan, empty tensor)`` instead of a crash in ``torch.stack``.
        """
        N, T, _ = x_ntf.shape
        S = torch.zeros(N, self.state_q, device=x_ntf.device, dtype=x_ntf.dtype)
        frame_logits = []
        for t in range(T):
            x_t = x_ntf[:, t, :]
            upd = self.W1(x_t)  # always present
            if self._use_W2:
                # W2 sees either the current frame or the previous state.
                upd = upd + self.W2(x_t if self._w2_on_x else S)
            S = (self.U(S) if self._use_U else S) + upd
            # Message passing (GraphConv, as in training).
            S = F.relu(self.gc1(S, edge_index, edge_weight))
            S = F.relu(self.gc2(S, edge_index, edge_weight))
            # fc(S) is [N, 1]; average over nodes -> one scalar logit per frame.
            frame_logits.append(self.fc(S).mean())
        if not frame_logits:
            return float("nan"), torch.empty(0, device=x_ntf.device, dtype=x_ntf.dtype)
        frame_probs = torch.sigmoid(torch.stack(frame_logits))  # [T]
        return frame_probs.mean().item(), frame_probs
# ==========================================================
# 2) Load config + weights (strict load, auto-enable W2/U)
# ==========================================================
def hub_download(fname: str) -> str:
    """Fetch ``<SPACE_DIR>/<fname>`` from the model repo and return the local cached path."""
    remote_path = f"{SPACE_DIR}/{fname}"
    return hf_hub_download(repo_id=MODEL_REPO, filename=remote_path, repo_type="model")
CFG_PATH = hub_download("rest_config.json")
STATE_PATH = hub_download("rest_state.pt")
with open(CFG_PATH, "r", encoding="utf-8") as f:
    CFG = json.load(f)
# The checkpoint is used as a flat state_dict (indexed by parameter name and fed
# to load_state_dict below), so weights_only=True is sufficient and refuses to
# unpickle arbitrary objects from the downloaded file.
sd = torch.load(STATE_PATH, map_location="cpu", weights_only=True)
# Infer F (in_dim) and Q (state_q) from the checkpoint; fall back to the config.
if "W1.weight" in sd:
    Q_MODEL, F_MODEL = sd["W1.weight"].shape  # nn.Linear weight layout: [out=Q, in=F]
else:
    F_MODEL = int(CFG.get("in_feat", 128))
    Q_MODEL = int(CFG.get("state_q", 32))
# Detect optional W2/U in the checkpoint and take their in_features so the
# matching layers get built.
w2_in = sd["W2.weight"].shape[1] if "W2.weight" in sd else None
u_in = sd["U.weight"].shape[1] if "U.weight" in sd else None
MODEL = RESTNet(in_dim=F_MODEL, state_q=Q_MODEL, w2_in=w2_in, u_in=u_in).to(device).eval()
# Load strictly; if that fails (e.g. extra unused keys), fall back to a
# non-strict load so the Space still starts.
try:
    missing, unexpected = MODEL.load_state_dict(sd, strict=True)
except RuntimeError as e:  # load_state_dict reports key/shape mismatches as RuntimeError
    print("Strict load failed:", e)
    missing, unexpected = MODEL.load_state_dict(sd, strict=False)
    print("Fallback non-strict load. Missing:", missing, "Unexpected:", unexpected)
else:
    print(f"Loaded strict: F={F_MODEL}, Q={Q_MODEL}, W2_in={w2_in}, U_in={u_in}")
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)
# ==========================================================
# 3) Data utilities + viz (stats/top-k/spans)
# ==========================================================
def _normalize_x_shape(x: np.ndarray) -> np.ndarray:
"""
Đưa x về [N, T, F] theo heuristic:
- N ~ [8..128] (số kênh EEG)
- T >= 2
- F >= 1
"""
assert x.ndim == 3, "Input x phải có 3 chiều."
perms = [(0,1,2),(0,2,1),(1,0,2),(1,2,0),(2,0,1),(2,1,0)]
for p in perms:
y = np.transpose(x, p)
N, T, Fdim = y.shape
if 8 <= N <= 128 and T >= 2 and Fdim >= 1:
return y
return x # fallback nếu không đoán được
def _adapt_features_to_model(x_ntf: np.ndarray, target_F: int) -> np.ndarray:
"""
Pad hoặc cắt trục F để khớp target_F (model.in_dim).
"""
N, T, Fdim = x_ntf.shape
if Fdim == target_F:
return x_ntf
if Fdim > target_F:
return x_ntf[..., :target_F]
pad = target_F - Fdim
return np.pad(x_ntf, ((0,0),(0,0),(0,pad)), mode="constant")
def load_npz(npz_file):
    """Load ``(x, edge_index, edge_weight)`` from an ``.npz`` archive.

    Args:
        npz_file: a filesystem path (``str`` / ``os.PathLike``) or a readable
            binary file-like object.

    Returns:
        ``x`` passed through ``_normalize_x_shape``, ``edge_index`` as an
        ndarray, and ``edge_weight`` as an ndarray or ``None`` when the archive
        has no ``edge_weight`` entry.

    Raises:
        gr.Error: if ``x`` or ``edge_index`` is missing from the archive.
    """
    if isinstance(npz_file, (str, os.PathLike)):
        source = npz_file
    else:
        source = io.BytesIO(npz_file.read())
    # SECURITY: allow_pickle=False. This handles user uploads, and unpickling
    # lets a crafted file execute code. Every array used here is numeric; an
    # object array could not be converted by torch.tensor downstream anyway.
    # The context manager closes the archive once the arrays are materialized.
    with np.load(source, allow_pickle=False) as npz:
        for key in ("x", "edge_index"):
            if key not in npz:
                raise gr.Error(f"Không tìm thấy '{key}' trong NPZ.")
        x = np.asarray(npz["x"])
        edge_index = np.asarray(npz["edge_index"])
        edge_weight = np.asarray(npz["edge_weight"]) if "edge_weight" in npz else None
    return _normalize_x_shape(x), edge_index, edge_weight
def load_h5(h5_file, clip_idx=0):
    """Load ``(x, edge_index, edge_weight)`` for one clip from an HDF5 file.

    Args:
        h5_file: a filesystem path (``str`` / ``os.PathLike``) or a readable
            binary file-like object.
        clip_idx: index of the clip to read when the feature dataset has more
            than 3 dimensions; ignored for a 3-D dataset. Cast to ``int``
            because UI sliders may deliver floats.

    Returns:
        ``x`` passed through ``_normalize_x_shape``, ``edge_index`` as read,
        and ``edge_weight`` or ``None`` if no weight dataset exists.

    Raises:
        gr.Error: if no feature / edge_index dataset is found, or ``clip_idx``
            is out of range.
    """
    if isinstance(h5_file, (str, os.PathLike)):
        source = h5_file
    else:
        source = io.BytesIO(h5_file.read())
    # Context manager guarantees the file is closed even when gr.Error is raised.
    with h5py.File(source, "r") as f:
        x_key = next((k for k in ("x", "clips", "X") if k in f), None)
        if x_key is None:
            raise gr.Error("Không tìm thấy dataset 'x'/'clips'/'X' trong H5.")
        X = f[x_key]
        # Select the requested clip.
        if X.ndim == 3:
            x = X[:]  # the dataset already holds a single clip
        else:
            idx = int(clip_idx)  # h5py rejects float indices
            n_clips = X.shape[0]
            if not 0 <= idx < n_clips:
                raise gr.Error(f"clip_idx={idx} ngoài phạm vi [0, {n_clips - 1}].")
            x = X[idx]
        x = _normalize_x_shape(x)
        ei_key = next((k for k in ("edge_index", "edge_idx", "edges") if k in f), None)
        if ei_key is None:
            raise gr.Error("Không có 'edge_index' trong H5.")
        edge_index = f[ei_key][:]
        ew_key = next((k for k in ("edge_weight", "edge_w", "weights") if k in f), None)
        edge_weight = f[ew_key][:] if ew_key is not None else None
    return x, edge_index, edge_weight
def _cluster_spans_from_top(top_idx: np.ndarray, T: int, span_half: int, merge_gap: int):
"""
Từ danh sách frame top-k (đã sort tăng), tạo các span [l,r] với padding 'span_half',
và merge nếu khoảng cách giữa các span liền kề <= merge_gap.
"""
if top_idx is None or len(top_idx) == 0:
return []
span_half = max(0, int(span_half))
merge_gap = max(0, int(merge_gap))
# Tạo các interval cơ bản
spans = []
for t in top_idx:
l = max(0, int(t) - span_half)
r = min(T - 1, int(t) + span_half)
spans.append([l, r])
# Merge các interval nếu gần nhau
spans.sort(key=lambda x: x[0])
merged = []
cur_l, cur_r = spans[0]
for l, r in spans[1:]:
if l <= cur_r + merge_gap:
cur_r = max(cur_r, r)
else:
merged.append([cur_l, cur_r])
cur_l, cur_r = l, r
merged.append([cur_l, cur_r])
return merged
def plot_frame_probs(frame_probs: np.ndarray, top_idx: np.ndarray | None = None,
                     spans: list[list[int]] | None = None) -> np.ndarray:
    """Render the per-frame probability curve as an RGB image.

    Args:
        frame_probs: per-frame probabilities; flattened to 1-D, so [T] and
            [T, 1] inputs are both accepted.
        top_idx: optional frame indices to mark with dots.
        spans: optional ``[l, r]`` frame intervals to shade.

    Returns:
        ``uint8`` array of shape (H, W, 3), suitable for ``gr.Image(type="numpy")``.
    """
    p = np.asarray(frame_probs, dtype=float).reshape(-1)
    fig = plt.figure(figsize=(7, 2.8))
    try:
        ax = fig.add_subplot(111)
        ax.plot(np.arange(p.size), p, lw=1.5)
        # Shade spans; patches default to a lower zorder than lines, so the
        # curve stays on top regardless of call order.
        if spans:
            for l, r in spans:
                ax.axvspan(l, r, alpha=0.15)
        # Mark the top-k frames.
        if top_idx is not None and len(top_idx) > 0:
            idx = np.asarray(top_idx, dtype=int)
            ax.scatter(idx, p[idx], s=24, marker="o")
        ax.set_title("Frame-wise seizure probability")
        ax.set_xlabel("Frame index (t)")
        ax.set_ylabel("p(seizure)")
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() is already shaped (H, W, 4) in device pixels; rebuilding
        # the shape from get_width_height() breaks when device_pixel_ratio != 1.
        img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()  # drop alpha -> RGB
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)
    return img
def summarize_probs(frame_probs: np.ndarray, top_k: int = 10,
                    span_half: int = 12, merge_gap: int = 5) -> tuple[str, np.ndarray, list[list[int]]]:
    """Summarize per-frame probabilities for display.

    Args:
        frame_probs: per-frame probabilities (any shape; flattened to 1-D).
        top_k: number of highest-probability frames to report; clamped to
            ``[1, number of frames]``. Non-numeric values fall back to 10.
        span_half: half-width used to pad each top frame into a span (>= 0).
        merge_gap: spans are merged when separated by at most this (>= 0).

    Returns:
        ``(markdown, top_idx, spans)``: a markdown report, the top-k frame
        indices ordered by probability (descending), and the merged spans.
        Safe for zero frames and for ``top_k`` larger than the frame count.
    """
    # Force 1-D float.
    p = np.asarray(frame_probs, dtype=float).reshape(-1)
    M = int(p.size)
    if M == 0:
        md = "**Frame stats** \n- (no frames)\n"
        return md, np.array([], dtype=int), []
    # Clamp k to the number of frames.
    try:
        k = int(top_k)
    except (TypeError, ValueError, OverflowError):
        k = 10
    k = max(1, min(k, M))
    # Clamp once and reuse for both span building and the report, so the
    # markdown never shows a negative pad/gap that was not actually applied.
    pad = max(0, int(span_half))
    gap = max(0, int(merge_gap))
    # argpartition selects the k largest in O(M); then order them by p
    # descending (mergesort is stable, so ties keep the argpartition order).
    top_idx = np.argpartition(-p, k - 1)[:k]
    top_idx = top_idx[np.argsort(-p[top_idx], kind="mergesort")]
    # Spans are built from the top-k sorted by frame index.
    spans = _cluster_spans_from_top(top_idx=np.sort(top_idx), T=M,
                                    span_half=pad, merge_gap=gap)
    # Statistics.
    mean = float(np.mean(p))
    std = float(np.std(p))
    pmin = float(np.min(p))
    imin = int(np.argmin(p))
    pmax = float(np.max(p))
    imax = int(np.argmax(p))
    rows = "\n".join(f"| {i} | {p[i]:.4f} |" for i in top_idx)
    span_rows = "\n".join(f"- [{l}, {r}] (len={r-l+1})" for l, r in spans) if spans else "- (none)"
    md = (
        f"**Frame stats** \n"
        f"- frames: **{M}** \n"
        f"- mean: **{mean:.4f}** · std: **{std:.4f}** \n"
        f"- max: **{pmax:.4f}** tại **t={imax}** · min: **{pmin:.4f}** tại **t={imin}** \n"
        f"- top-{k} frames (of {M}): \n\n"
        f"| frame t | p(t) |\n|---:|---:|\n{rows if rows else '| - | - |'}\n\n"
        f"**Merged spans** từ top-k (pad=±{pad}, merge_gap≤{gap}):\n{span_rows}\n"
    )
    return md, top_idx, spans
# ==========================================================
# 4) Demo files (from the model repo)
# ==========================================================
# Best-effort: a demo clip that cannot be fetched must not stop the Space, but
# the failure is logged instead of being silently swallowed.
DEMO_FILES = []
for demo_idx in range(3):
    try:
        DEMO_FILES.append(hub_download(f"demo_clip{demo_idx}.npz"))
    except Exception as e:
        print(f"Demo clip {demo_idx} unavailable: {e}")
print("Model in_dim:", MODEL.in_dim, "state_q:", MODEL.state_q)
# ==========================================================
# 5) Inference handlers
# ==========================================================
def _run_model(x, ei, ew, top_k, span_half, merge_gap):
    """Shared inference path: adapt features, run MODEL, summarize and plot.

    Returns the ``(clip_prob_text, stats_markdown, image)`` triple that both
    Gradio handlers emit.
    """
    x = _adapt_features_to_model(x, MODEL.in_dim)
    ei = np.asarray(ei)
    # GraphConv expects edge_index as [2, E]; also accept the common [E, 2] layout.
    if ei.ndim == 2 and ei.shape[0] != 2 and ei.shape[1] == 2:
        ei = ei.T
    x_t = torch.tensor(x, dtype=torch.float32, device=device)
    ei_t = torch.tensor(ei, dtype=torch.long, device=device)
    ew_t = torch.tensor(ew, dtype=torch.float32, device=device) if ew is not None else None
    clip_p, frame_p = MODEL(x_t, ei_t, ew_t)
    p_np = frame_p.cpu().numpy()
    stats_md, top_idx, spans = summarize_probs(p_np, top_k=int(top_k),
                                               span_half=int(span_half), merge_gap=int(merge_gap))
    img = plot_frame_probs(p_np, top_idx=top_idx, spans=spans)
    return f"{clip_p:.4f}", stats_md, img


def infer_demo(demo_id, top_k, span_half, merge_gap):
    """Run the model on one of the bundled demo clips selected in the dropdown."""
    if not DEMO_FILES:
        raise gr.Error("Không có demo_clip*.npz trong repo. Hãy export ở bước A.")
    # A cleared dropdown yields None; reject it (and out-of-range ids) clearly.
    try:
        idx = int(demo_id)
    except (TypeError, ValueError):
        idx = -1
    if not 0 <= idx < len(DEMO_FILES):
        raise gr.Error(f"Demo clip không hợp lệ: {demo_id!r}.")
    x, ei, ew = load_npz(DEMO_FILES[idx])
    return _run_model(x, ei, ew, top_k, span_half, merge_gap)


def infer_custom(file, file_type, clip_idx, top_k, span_half, merge_gap):
    """Run the model on a user-uploaded NPZ or H5 file."""
    if file is None:
        raise gr.Error("Hãy upload 1 file H5/NPZ hoặc chọn demo.")
    if file_type == "npz":
        x, ei, ew = load_npz(file)
    else:
        # Sliders may deliver floats; h5py needs an integer index.
        x, ei, ew = load_h5(file, clip_idx=int(clip_idx))
    return _run_model(x, ei, ew, top_k, span_half, merge_gap)
# ==========================================================
# 6) Gradio UI
# ==========================================================
def _frame_controls():
    """Create the (top-k, span half-width, merge-gap) sliders used by both tabs."""
    return (
        gr.Slider(1, 50, value=10, step=1, label="Top-k frames to highlight"),
        gr.Slider(0, 200, value=12, step=1, label="Span half-width (±frames)"),
        gr.Slider(0, 50, value=5, step=1, label="Merge spans if gap ≤"),
    )


def _result_outputs():
    """Create the (clip probability, frame stats, plot) output components."""
    return (
        gr.Textbox(label="Clip probability"),
        gr.Markdown(label="Frame stats"),
        gr.Image(label="Frame-wise probability", type="numpy"),
    )


with gr.Blocks(title="REST EEG Seizure Demo") as demo:
    gr.Markdown("# REST EEG Seizure – CHB-MIT\nDemo chạy trên CPU (Space). Dữ liệu lớn H5 không tự tải để tiết kiệm tài nguyên.")

    with gr.Tab("Demo"):
        demo_choices = [str(i) for i in range(len(DEMO_FILES))]
        dsel = gr.Dropdown(
            choices=demo_choices,
            label="Chọn demo clip",
            value=demo_choices[0] if demo_choices else None,
        )
        dtopk, dspan, dgap = _frame_controls()
        dbtn = gr.Button("Run demo")
        dout, dstats, dfig = _result_outputs()
        dbtn.click(fn=infer_demo, inputs=[dsel, dtopk, dspan, dgap], outputs=[dout, dstats, dfig])

    with gr.Tab("Upload"):
        ftype = gr.Radio(choices=["npz", "h5"], value="npz", label="Loại file")
        fup = gr.File(label="Upload .npz (x, edge_index, edge_weight) hoặc .h5")
        cidx = gr.Slider(0, 50, value=0, step=1, label="clip_idx (nếu H5)")
        utopk, uspan, ugap = _frame_controls()
        ubtn = gr.Button("Run inference")
        uout, ustats, ufig = _result_outputs()
        ubtn.click(fn=infer_custom, inputs=[fup, ftype, cidx, utopk, uspan, ugap], outputs=[uout, ustats, ufig])

demo.launch()