|
|
|
|
|
|
|
|
import os, io, json |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import h5py, matplotlib.pyplot as plt |
|
|
from huggingface_hub import hf_hub_download |
|
|
from torch_geometric.nn import GraphConv |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# Hub repo holding the exported model artifacts (overridable via env var).
MODEL_REPO = os.getenv("MODEL_REPO", "uyen1109/rest_eeg_seizure_analysis")

# Sub-folder of the model repo that holds the Space's inference artifacts.
SPACE_DIR = "space_infer"

# CPU-only Space: pin the device and cap torch's thread pool at half the
# visible cores. os.cpu_count() may return None (e.g. in some containers),
# so fall back to 2 before dividing to avoid a TypeError at import time.
device = torch.device("cpu")
torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RESTNet(nn.Module):
    """Recurrent graph network for clip-level seizure scoring.

    Per frame t, a shared linear layer (W1, plus optional W2/U depending on
    the checkpoint) updates a per-channel hidden state S, which is then mixed
    across the electrode graph by two GraphConv layers; a final linear head
    yields one logit per frame (channel-mean).

    The optional W2 / U layers are only instantiated when the checkpoint's
    weight shapes (``w2_in`` / ``u_in``) match this architecture, so state
    dicts exported from slightly different variants can still load strictly.
    """

    def __init__(
        self,
        in_dim: int,
        state_q: int = 64,
        w2_in: int | None = None,
        u_in: int | None = None,
    ):
        """
        Args:
            in_dim: feature dimension F of each input frame.
            state_q: hidden-state width per channel.
            w2_in: input width of the checkpoint's W2 layer (None = absent).
            u_in: input width of the checkpoint's U layer (None = absent).
        """
        super().__init__()
        self.in_dim = in_dim
        self.state_q = state_q

        # Core layers present in every checkpoint variant.
        self.W1 = nn.Linear(in_dim, state_q)
        self.gc1 = GraphConv(state_q, state_q, aggr="mean")
        self.gc2 = GraphConv(state_q, state_q, aggr="mean")
        self.fc = nn.Linear(state_q, 1)

        # Flags recording which optional layers the checkpoint provides.
        self._use_W2 = False
        self._use_U = False

        if w2_in is not None:
            if w2_in == in_dim:
                # W2 acts on the raw frame features.
                self.W2 = nn.Linear(in_dim, state_q)
                self._use_W2 = True
            elif w2_in == state_q:
                # W2 acts on the hidden state instead.
                self.W2 = nn.Linear(state_q, state_q)
                self._use_W2 = True

        if u_in is not None and u_in == state_q:
            # Recurrent state transition; bias-free to match the checkpoint.
            self.U = nn.Linear(state_q, state_q, bias=False)
            self._use_U = True

    @torch.no_grad()  # inference-only model: never build autograd graphs
    def forward(self, x_ntf, edge_index, edge_weight=None):
        """Score one clip.

        Args:
            x_ntf: float tensor [N, T, F] — N channels, T frames, F features.
            edge_index: [2, E] long tensor of graph edges (PyG convention).
            edge_weight: optional [E] float tensor of edge weights.

        Returns:
            Tuple of (clip probability as a Python float — the mean of the
            per-frame probabilities — and the [T, 1] frame-probability tensor).
        """
        N, T, Fdim = x_ntf.shape
        S = torch.zeros(N, self.state_q, device=x_ntf.device)
        frame_logits = []
        for t in range(T):
            upd = self.W1(x_ntf[:, t, :])
            if self._use_W2:
                # Dispatch on W2's input width: frame features vs hidden state.
                if self.W2.in_features == Fdim:
                    upd = upd + self.W2(x_ntf[:, t, :])
                else:
                    upd = upd + self.W2(S)

            if self._use_U:
                S = self.U(S) + upd
            else:
                S = S + upd

            # Graph smoothing over the electrode montage, twice.
            S = F.relu(self.gc1(S, edge_index, edge_weight))
            S = F.relu(self.gc2(S, edge_index, edge_weight))

            # One logit per frame: channel-mean of the per-channel head.
            frame_logits.append(self.fc(S).mean(dim=0, keepdim=True).squeeze(-1))

        frame_logits = torch.stack(frame_logits, dim=0)
        frame_probs = torch.sigmoid(frame_logits)
        return frame_probs.mean().item(), frame_probs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def hub_download(fname: str) -> str:
    """Fetch ``fname`` from the Space's inference folder on MODEL_REPO and
    return the local cached path."""
    remote_path = f"{SPACE_DIR}/{fname}"
    return hf_hub_download(MODEL_REPO, remote_path, repo_type="model")
|
|
|
|
|
# ---------------------------------------------------------------------------
# Download the exported config + state dict and rebuild the model to match.
# ---------------------------------------------------------------------------
CFG_PATH = hub_download("rest_config.json")
STATE_PATH = hub_download("rest_state.pt")

with open(CFG_PATH, "r") as f:
    CFG = json.load(f)

sd = torch.load(STATE_PATH, map_location="cpu")

# Prefer the checkpoint's own shapes; fall back to the JSON config defaults.
_w1 = sd.get("W1.weight")
if _w1 is not None:
    Q_MODEL, F_MODEL = _w1.shape
else:
    F_MODEL = int(CFG.get("in_feat", 128))
    Q_MODEL = int(CFG.get("state_q", 32))

# Optional layers: pass their input widths so RESTNet recreates them exactly.
_w2 = sd.get("W2.weight")
_u = sd.get("U.weight")
w2_in = _w2.shape[1] if _w2 is not None else None
u_in = _u.shape[1] if _u is not None else None

MODEL = RESTNet(in_dim=F_MODEL, state_q=Q_MODEL, w2_in=w2_in, u_in=u_in).to(device).eval()

# Try a strict load first; fall back to non-strict so the Space still boots
# when the checkpoint and the reconstructed architecture disagree slightly.
try:
    missing, unexpected = MODEL.load_state_dict(sd, strict=True)
    print(f"Loaded strict: F={F_MODEL}, Q={Q_MODEL}, W2_in={w2_in}, U_in={u_in}")
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)
except Exception as e:
    print("Strict load failed:", e)
    missing, unexpected = MODEL.load_state_dict(sd, strict=False)
    print("Fallback non-strict load. Missing:", missing, "Unexpected:", unexpected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_x_shape(x: np.ndarray) -> np.ndarray: |
|
|
""" |
|
|
Đưa x về [N, T, F] theo heuristic: |
|
|
- N ~ [8..128] (số kênh EEG) |
|
|
- T >= 2 |
|
|
- F >= 1 |
|
|
""" |
|
|
assert x.ndim == 3, "Input x phải có 3 chiều." |
|
|
perms = [(0,1,2),(0,2,1),(1,0,2),(1,2,0),(2,0,1),(2,1,0)] |
|
|
for p in perms: |
|
|
y = np.transpose(x, p) |
|
|
N, T, Fdim = y.shape |
|
|
if 8 <= N <= 128 and T >= 2 and Fdim >= 1: |
|
|
return y |
|
|
return x |
|
|
|
|
|
def _adapt_features_to_model(x_ntf: np.ndarray, target_F: int) -> np.ndarray: |
|
|
""" |
|
|
Pad hoặc cắt trục F để khớp target_F (model.in_dim). |
|
|
""" |
|
|
N, T, Fdim = x_ntf.shape |
|
|
if Fdim == target_F: |
|
|
return x_ntf |
|
|
if Fdim > target_F: |
|
|
return x_ntf[..., :target_F] |
|
|
pad = target_F - Fdim |
|
|
return np.pad(x_ntf, ((0,0),(0,0),(0,pad)), mode="constant") |
|
|
|
|
|
def load_npz(npz_file):
    """Load one clip from an .npz source (filesystem path or file-like).

    Returns:
        (x, edge_index, edge_weight) — x normalized to [N, T, F],
        edge_weight is None when the archive has no 'edge_weight' entry.
    """
    if isinstance(npz_file, str):
        npz = np.load(npz_file, allow_pickle=True)
    else:
        buf = io.BytesIO(npz_file.read())
        npz = np.load(buf, allow_pickle=True)
    try:
        x = _normalize_x_shape(np.asarray(npz["x"]))
        edge_index = np.asarray(npz["edge_index"])
        edge_weight = np.asarray(npz["edge_weight"]) if "edge_weight" in npz else None
    finally:
        # NpzFile keeps its underlying zip handle open until closed;
        # the arrays above are fully materialized, so closing here is safe.
        npz.close()
    return x, edge_index, edge_weight
|
|
|
|
|
def load_h5(h5_file, clip_idx=0):
    """Load one clip from an HDF5 source (filesystem path or file-like).

    Looks for the clip data under 'x'/'clips'/'X', edges under
    'edge_index'/'edge_idx'/'edges', and optional weights under
    'edge_weight'/'edge_w'/'weights'.

    Returns:
        (x, edge_index, edge_weight) — x normalized to [N, T, F].

    Raises:
        gr.Error: when the required datasets are missing.
    """
    if isinstance(h5_file, str):
        f = h5py.File(h5_file, "r")
    else:
        buf = io.BytesIO(h5_file.read())
        f = h5py.File(buf, "r")
    try:
        keys = list(f.keys())
        for k in ["x", "clips", "X"]:
            if k in keys:
                X = f[k]
                break
        else:
            raise gr.Error("Không tìm thấy dataset 'x'/'clips'/'X' trong H5.")

        # A 3-D dataset is a single clip; higher rank holds many clips
        # indexed by clip_idx.
        if X.ndim == 3:
            x = X[:]
        else:
            x = X[clip_idx]
        x = _normalize_x_shape(x)

        for k in ["edge_index", "edge_idx", "edges"]:
            if k in keys:
                edge_index = f[k][:]
                break
        else:
            raise gr.Error("Không có 'edge_index' trong H5.")

        edge_weight = None
        for k in ["edge_weight", "edge_w", "weights"]:
            if k in keys:
                edge_weight = f[k][:]
                break
    finally:
        # Always release the HDF5 handle — previously it leaked whenever a
        # gr.Error was raised before reaching f.close().
        f.close()
    return x, edge_index, edge_weight
|
|
|
|
|
def _cluster_spans_from_top(top_idx: np.ndarray, T: int, span_half: int, merge_gap: int): |
|
|
""" |
|
|
Từ danh sách frame top-k (đã sort tăng), tạo các span [l,r] với padding 'span_half', |
|
|
và merge nếu khoảng cách giữa các span liền kề <= merge_gap. |
|
|
""" |
|
|
if top_idx is None or len(top_idx) == 0: |
|
|
return [] |
|
|
span_half = max(0, int(span_half)) |
|
|
merge_gap = max(0, int(merge_gap)) |
|
|
|
|
|
|
|
|
spans = [] |
|
|
for t in top_idx: |
|
|
l = max(0, int(t) - span_half) |
|
|
r = min(T - 1, int(t) + span_half) |
|
|
spans.append([l, r]) |
|
|
|
|
|
|
|
|
spans.sort(key=lambda x: x[0]) |
|
|
merged = [] |
|
|
cur_l, cur_r = spans[0] |
|
|
for l, r in spans[1:]: |
|
|
if l <= cur_r + merge_gap: |
|
|
cur_r = max(cur_r, r) |
|
|
else: |
|
|
merged.append([cur_l, cur_r]) |
|
|
cur_l, cur_r = l, r |
|
|
merged.append([cur_l, cur_r]) |
|
|
|
|
|
return merged |
|
|
|
|
|
def plot_frame_probs(frame_probs: np.ndarray, top_idx: np.ndarray | None = None,
                     spans: list[list[int]] | None = None) -> np.ndarray:
    """Render the frame-wise probability curve as an RGB uint8 image.

    Optionally shades the merged [l, r] spans and marks the top-k frames.
    """
    fig, ax = plt.subplots(figsize=(7, 2.8))
    xs = np.arange(len(frame_probs))
    ax.plot(xs, frame_probs, lw=1.5)

    # Shade each merged span behind the curve.
    for left, right in (spans or []):
        ax.axvspan(left, right, alpha=0.15)

    # Mark the highlighted frames on the curve.
    if top_idx is not None and len(top_idx) > 0:
        ax.scatter(top_idx, frame_probs[top_idx], s=24, marker="o")

    ax.set_title("Frame-wise seizure probability")
    ax.set_xlabel("Frame index (t)")
    ax.set_ylabel("p(seizure)")
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    # Rasterize the canvas and drop the alpha channel -> (H, W, 3) uint8.
    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(h, w, 4)
    img = rgba[..., :3].copy()
    plt.close(fig)
    return img
|
|
|
|
|
def summarize_probs(frame_probs: np.ndarray, top_k: int = 10,
                    span_half: int = 12, merge_gap: int = 5) -> tuple[str, np.ndarray, list[list[int]]]:
    """Summarize frame probabilities.

    Returns:
        (stats markdown, top frame indices sorted by descending p(t),
        merged spans).  Safe when there are very few frames (T = 0/1)
        and when top_k exceeds T.
    """
    p = np.asarray(frame_probs, dtype=float).reshape(-1)
    M = int(p.size)

    if M == 0:
        return "**Frame stats** \n- (no frames)\n", np.array([], dtype=int), []

    # Clamp top_k into [1, M]; fall back to 10 on a non-numeric value.
    try:
        k = int(top_k)
    except Exception:
        k = 10
    k = max(1, min(k, M))
    kth = max(0, min(k - 1, M - 1))

    if M == 1:
        top_idx = np.array([0], dtype=int)
    else:
        # Partial selection of the k largest, then a stable exact ordering.
        top_idx = np.argpartition(-p, kth)[:k]
        top_idx = top_idx[np.argsort(-p[top_idx], kind="mergesort")]

    spans = _cluster_spans_from_top(top_idx=np.sort(top_idx), T=M,
                                    span_half=int(max(0, span_half)),
                                    merge_gap=int(max(0, merge_gap)))

    # Summary statistics over all frames.
    mean, std = float(np.mean(p)), float(np.std(p))
    imin, imax = int(np.argmin(p)), int(np.argmax(p))
    pmin, pmax = float(p[imin]), float(p[imax])

    rows = "\n".join(f"| {i} | {p[i]:.4f} |" for i in top_idx)
    span_rows = "\n".join(f"- [{l}, {r}] (len={r-l+1})" for l, r in spans) if spans else "- (none)"
    md = "".join([
        f"**Frame stats** \n",
        f"- frames: **{M}** \n",
        f"- mean: **{mean:.4f}** · std: **{std:.4f}** \n",
        f"- max: **{pmax:.4f}** tại **t={imax}** · min: **{pmin:.4f}** tại **t={imin}** \n",
        f"- top-{k} frames (of {M}): \n\n",
        f"| frame t | p(t) |\n|---:|---:|\n{rows if rows else '| - | - |'}\n\n",
        f"**Merged spans** từ top-k (pad=±{int(span_half)}, merge_gap≤{int(merge_gap)}):\n{span_rows}\n",
    ])
    return md, top_idx, spans
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prefetch up to three demo clips exported alongside the model.  Missing
# files are skipped silently so the Space still starts without demos.
DEMO_FILES = []
for clip_no in range(3):
    try:
        local_path = hf_hub_download(MODEL_REPO, f"{SPACE_DIR}/demo_clip{clip_no}.npz", repo_type="model")
    except Exception:
        continue
    DEMO_FILES.append(local_path)

print("Model in_dim:", MODEL.in_dim, "state_q:", MODEL.state_q)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_demo(demo_id, top_k, span_half, merge_gap):
    """Run the model on a bundled demo clip.

    Returns:
        (clip probability as text, stats markdown, probability-plot image).
    """
    if not DEMO_FILES:
        raise gr.Error("Không có demo_clip*.npz trong repo. Hãy export ở bước A.")

    x_np, ei_np, ew_np = load_npz(DEMO_FILES[int(demo_id)])
    x_np = _adapt_features_to_model(x_np, MODEL.in_dim)

    # Move everything onto the inference device as torch tensors.
    x = torch.tensor(x_np, dtype=torch.float32, device=device)
    ei = torch.tensor(ei_np, dtype=torch.long, device=device)
    ew = None if ew_np is None else torch.tensor(ew_np, dtype=torch.float32, device=device)

    clip_p, frame_p = MODEL(x, ei, ew)
    probs = frame_p.cpu().numpy()
    stats_md, top_idx, spans = summarize_probs(probs, top_k=int(top_k),
                                               span_half=int(span_half), merge_gap=int(merge_gap))
    img = plot_frame_probs(probs, top_idx=top_idx, spans=spans)
    return f"{clip_p:.4f}", stats_md, img
|
|
|
|
|
def infer_custom(file, file_type, clip_idx, top_k, span_half, merge_gap):
    """Run the model on an uploaded NPZ/H5 file.

    Returns:
        (clip probability as text, stats markdown, probability-plot image).
    """
    if file is None:
        raise gr.Error("Hãy upload 1 file H5/NPZ hoặc chọn demo.")

    if file_type == "npz":
        x_np, ei_np, ew_np = load_npz(file)
    else:
        x_np, ei_np, ew_np = load_h5(file, clip_idx=clip_idx)

    x_np = _adapt_features_to_model(x_np, MODEL.in_dim)

    # Move everything onto the inference device as torch tensors.
    x = torch.tensor(x_np, dtype=torch.float32, device=device)
    ei = torch.tensor(ei_np, dtype=torch.long, device=device)
    ew = None if ew_np is None else torch.tensor(ew_np, dtype=torch.float32, device=device)

    clip_p, frame_p = MODEL(x, ei, ew)
    probs = frame_p.cpu().numpy()
    stats_md, top_idx, spans = summarize_probs(probs, top_k=int(top_k),
                                               span_half=int(span_half), merge_gap=int(merge_gap))
    img = plot_frame_probs(probs, top_idx=top_idx, spans=spans)
    return f"{clip_p:.4f}", stats_md, img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: a "Demo" tab for bundled clips and an "Upload" tab for user files.
# ---------------------------------------------------------------------------
with gr.Blocks(title="REST EEG Seizure Demo") as demo:
    gr.Markdown("# REST EEG Seizure – CHB-MIT\nDemo chạy trên CPU (Space). Dữ liệu lớn H5 không tự tải để tiết kiệm tài nguyên.")

    with gr.Tab("Demo"):
        demo_choice = gr.Dropdown(choices=[str(i) for i in range(len(DEMO_FILES))],
                                  label="Chọn demo clip", value="0" if DEMO_FILES else None)
        demo_topk = gr.Slider(1, 50, value=10, step=1, label="Top-k frames to highlight")
        demo_span = gr.Slider(0, 200, value=12, step=1, label="Span half-width (±frames)")
        demo_gap = gr.Slider(0, 50, value=5, step=1, label="Merge spans if gap ≤")
        demo_btn = gr.Button("Run demo")
        demo_prob = gr.Textbox(label="Clip probability")
        demo_stats = gr.Markdown(label="Frame stats")
        demo_plot = gr.Image(label="Frame-wise probability", type="numpy")
        demo_btn.click(fn=infer_demo, inputs=[demo_choice, demo_topk, demo_span, demo_gap],
                       outputs=[demo_prob, demo_stats, demo_plot])

    with gr.Tab("Upload"):
        up_type = gr.Radio(choices=["npz","h5"], value="npz", label="Loại file")
        up_file = gr.File(label="Upload .npz (x, edge_index, edge_weight) hoặc .h5")
        up_clip = gr.Slider(0, 50, value=0, step=1, label="clip_idx (nếu H5)")
        up_topk = gr.Slider(1, 50, value=10, step=1, label="Top-k frames to highlight")
        up_span = gr.Slider(0, 200, value=12, step=1, label="Span half-width (±frames)")
        up_gap = gr.Slider(0, 50, value=5, step=1, label="Merge spans if gap ≤")
        up_btn = gr.Button("Run inference")
        up_prob = gr.Textbox(label="Clip probability")
        up_stats = gr.Markdown(label="Frame stats")
        up_plot = gr.Image(label="Frame-wise probability", type="numpy")
        up_btn.click(fn=infer_custom, inputs=[up_file, up_type, up_clip, up_topk, up_span, up_gap],
                     outputs=[up_prob, up_stats, up_plot])

demo.launch()
|
|
|