import glob
import os
from dataclasses import dataclass
from typing import Any, Optional

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import ruptures as rpt
import torch
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

from TaikoChartEstimator.data.tokenizer import EventTokenizer
from TaikoChartEstimator.model.model import TaikoChartEstimator


@dataclass
class ParsedCourse:
    """One parsed COURSE section of a TJA file."""

    # Raw COURSE value as written in the file (e.g. "Oni", "3").
    name: str
    # LEVEL header as an int, or None when missing/unparseable.
    level: Optional[int]
    # Dataset-style measure dicts: timestamp/measure_num/measure_den/notes.
    segments: list[dict]
    # Canonical difficulty name ("easy"/"normal"/"hard"/"oni") or None.
    difficulty_hint: Optional[str]


@dataclass
class ParsedTJA:
    """Parsed TJA file: global header metadata plus all parsed courses."""

    meta: dict[str, Any]
    courses: dict[str, ParsedCourse]


# TJA chart digit -> note type name used by the dataset/tokenizer.
NOTE_DIGIT_TO_TYPE = {
    "1": "Don",
    "2": "Ka",
    "3": "DonBig",
    "4": "KaBig",
    "5": "Roll",
    "6": "RollBig",
    "7": "Balloon",
    "8": "EndOf",
    "9": "BalloonAlt",
}


def _strip_comment(line: str) -> str:
    """Drop a trailing ``//`` comment and surrounding whitespace."""
    if "//" in line:
        line = line.split("//", 1)[0]
    return line.strip()


def parse_tja(text: str) -> ParsedTJA:
    """Parse a (single-song) TJA into dataset-like `segments` per course.

    Supported (best-effort): COURSE/LEVEL, BPM, OFFSET, #START/#END,
    #BPMCHANGE, #MEASURE, #SCROLL, #DELAY, #GOGOSTART/#GOGOEND.
    Branching commands are ignored.

    Args:
        text: Full TJA file content (BOM and CRLF tolerated).

    Returns:
        ParsedTJA with header metadata and one ParsedCourse per COURSE.

    Raises:
        ValueError: if `text` is empty or whitespace-only.
    """
    if not text or not text.strip():
        raise ValueError("Empty TJA input")
    text = text.replace("\ufeff", "")
    lines = [_strip_comment(l) for l in text.replace("\r\n", "\n").split("\n")]
    lines = [l for l in lines if l]

    meta: dict[str, Any] = {}
    courses: dict[str, dict[str, Any]] = {}
    current_course: Optional[dict[str, Any]] = None
    in_chart = False
    bpm = 120.0
    offset = 0.0
    measure_num = 4
    measure_den = 4
    scroll = 1.0
    gogo = False
    current_time = 0.0
    measure_start_time = 0.0
    measure_digits: list[str] = []

    def beats_per_measure() -> float:
        # TJA: #MEASURE a/b means measure length = 4 * a / b quarter-note beats
        return 4.0 * float(measure_num) / float(measure_den)

    def measure_duration_sec(local_bpm: float) -> float:
        return beats_per_measure() * 60.0 / max(local_bpm, 1e-6)

    def flush_measure_if_any() -> None:
        """Emit the accumulated measure digits as one segment and advance time."""
        nonlocal current_time, measure_start_time, measure_digits
        if current_course is None:
            return
        digits = "".join(measure_digits).strip()
        if not digits:
            return
        dur = measure_duration_sec(bpm)
        # Notes are spread evenly across the measure by digit position.
        step = dur / max(len(digits), 1)
        notes: list[dict] = []
        for i, ch in enumerate(digits):
            if ch == "0":
                continue
            note_type = NOTE_DIGIT_TO_TYPE.get(ch)
            if not note_type:
                continue
            t = measure_start_time + i * step
            notes.append(
                {
                    "note_type": note_type,
                    "timestamp": float(t),
                    "bpm": float(bpm),
                    "scroll": float(scroll),
                    "gogo": bool(gogo),
                }
            )
        current_course["segments"].append(
            {
                "timestamp": float(measure_start_time),
                "measure_num": int(measure_num),
                "measure_den": int(measure_den),
                "notes": notes,
            }
        )
        # Advance time by exactly one measure
        current_time = measure_start_time + dur
        measure_start_time = current_time
        measure_digits = []

    def finalize_long_note_durations() -> None:
        """Attach roll/balloon duration (in beats) to each opener as `delay`."""
        if current_course is None:
            return
        # Flatten notes across all segments, then pair openers with "EndOf".
        flat: list[dict] = []
        for seg in current_course["segments"]:
            for n in seg.get("notes", []):
                flat.append(n)
        flat.sort(key=lambda n: n.get("timestamp", 0.0))
        open_idx: list[int] = []
        for i, n in enumerate(flat):
            nt = n.get("note_type")
            if nt in {"Roll", "RollBig", "Balloon", "BalloonAlt"}:
                open_idx.append(i)
            elif nt == "EndOf" and open_idx:
                start_i = open_idx.pop()
                start = flat[start_i]
                start_bpm = float(start.get("bpm", 120.0))
                dt = float(n.get("timestamp", 0.0)) - float(start.get("timestamp", 0.0))
                # Duration expressed in quarter-note beats at the opener's BPM.
                dur_beats = max(0.0, dt * start_bpm / 60.0)
                start["delay"] = float(dur_beats)

    def ensure_course(name: str) -> dict[str, Any]:
        """Get (or lazily create) the mutable course dict for `name`."""
        if name not in courses:
            courses[name] = {
                "name": name,
                "level": None,
                "segments": [],
                "difficulty_hint": None,
            }
        return courses[name]

    for raw in lines:
        line = raw.strip()
        # Header lines ("KEY:value") are only honored outside chart bodies.
        if not in_chart and ":" in line and not line.startswith("#"):
            k, v = [p.strip() for p in line.split(":", 1)]
            ku = k.upper()
            meta[ku] = v
            if ku == "BPM":
                try:
                    bpm = float(v)
                except ValueError:
                    pass
            elif ku == "OFFSET":
                try:
                    offset = float(v)
                except ValueError:
                    pass
            elif ku == "COURSE":
                current_course = ensure_course(v)
                # Reset per-course chart state
                in_chart = False
            elif ku == "LEVEL" and current_course is not None:
                try:
                    current_course["level"] = int(float(v))
                except ValueError:
                    current_course["level"] = None
            continue

        if line.startswith("#START"):
            if current_course is None:
                current_course = ensure_course("(default)")
            # Reset chart state at start of each chart body.
            in_chart = True
            # Guard both header parses: a malformed BPM/OFFSET header must not
            # abort parsing (previously only OFFSET was guarded).
            try:
                bpm = float(meta.get("BPM", bpm) or bpm)
            except ValueError:
                pass
            try:
                offset = float(meta.get("OFFSET", offset) or offset)
            except ValueError:
                pass
            measure_num, measure_den = 4, 4
            scroll = 1.0
            gogo = False
            current_time = 0.0
            measure_start_time = 0.0
            measure_digits = []
            # Apply offset as a global shift (best-effort)
            current_time += float(offset)
            measure_start_time = current_time
            continue

        if not in_chart:
            continue

        if line.startswith("#END"):
            flush_measure_if_any()
            finalize_long_note_durations()
            in_chart = False
            continue

        if line.startswith("#"):
            cmd = line[1:].strip()
            cmd_u = cmd.upper()
            if cmd_u.startswith("BPMCHANGE"):
                flush_measure_if_any()
                try:
                    bpm = float(cmd.split(maxsplit=1)[1])
                except Exception:
                    pass
            elif cmd_u.startswith("MEASURE"):
                flush_measure_if_any()
                try:
                    frac = cmd.split(maxsplit=1)[1].strip()
                    a, b = frac.split("/", 1)
                    measure_num = int(a)
                    measure_den = int(b)
                except Exception:
                    pass
            elif cmd_u.startswith("SCROLL"):
                flush_measure_if_any()
                try:
                    scroll = float(cmd.split(maxsplit=1)[1])
                except Exception:
                    pass
            elif cmd_u.startswith("DELAY"):
                flush_measure_if_any()
                try:
                    current_time += float(cmd.split(maxsplit=1)[1])
                except Exception:
                    pass
                measure_start_time = current_time
            elif cmd_u.startswith("GOGOSTART"):
                flush_measure_if_any()
                gogo = True
            elif cmd_u.startswith("GOGOEND"):
                flush_measure_if_any()
                gogo = False
            else:
                # Ignore other commands (branching etc.)
                pass
            continue

        # Note data: may contain multiple commas (one per measure boundary).
        for ch in line:
            if ch.isdigit():
                measure_digits.append(ch)
            elif ch == ",":
                flush_measure_if_any()

    # Build ParsedTJA
    parsed_courses: dict[str, ParsedCourse] = {}
    difficulty_map = {
        "0": "easy",
        "easy": "easy",
        "1": "normal",
        "normal": "normal",
        "2": "hard",
        "hard": "hard",
        "3": "oni",
        "oni": "oni",
        "4": "oni",
        "ura": "oni",
        "edit": "oni",
    }
    for name, c in courses.items():
        name_l = name.strip().lower()
        hint = difficulty_map.get(name_l)
        parsed_courses[name] = ParsedCourse(
            name=name,
            level=c.get("level"),
            segments=c.get("segments", []),
            difficulty_hint=hint,
        )
    return ParsedTJA(meta=meta, courses=parsed_courses)


def _discover_checkpoints() -> list[str]:
    """List local trained checkpoint dirs, falling back to known HF repos."""
    # Prefer local trained outputs
    paths = []
    for p in glob.glob("outputs/*/pretrained/*"):
        if os.path.isdir(p) and os.path.exists(os.path.join(p, "config.json")):
            paths.append(p)
    # Also accept HF / user-provided paths via manual input
    if not paths:
        return [
            "JacobLinCool/TaikoChartEstimator-20251228",
            "JacobLinCool/TaikoChartEstimator-20251229",
        ]
    return sorted(paths)


# Cache of loaded models keyed by "checkpoint::device".
_MODEL_CACHE: dict[str, TaikoChartEstimator] = {}


def _resolve_device(device: str) -> str:
    """Map a requested device string to one actually available, else 'cpu'."""
    device = (device or "cpu").lower()
    if device == "cuda" and torch.cuda.is_available():
        return "cuda"
    if (
        device == "mps"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
    ):
        return "mps"
    return "cpu"
def _load_model(checkpoint_path: str, device: str) -> TaikoChartEstimator:
    """Load (or fetch from cache) a model on the resolved device."""
    device = _resolve_device(device)
    key = f"{checkpoint_path}::{device}"
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]
    model = TaikoChartEstimator.from_pretrained(checkpoint_path)
    model.eval()
    model.to(torch.device(device))
    _MODEL_CACHE[key] = model
    return model


def _build_instances_from_segments(
    segments: list[dict],
    max_tokens_per_instance: int,
    window_measures: list[int],
    hop_measures: int,
    max_instances_per_chart: int,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, list[tuple[float, float]], list[int]
]:
    """Tokenize `segments` and slice them into padded MIL instances.

    Returns:
        instances [1, N, L, 6], masks [1, N, L], counts [1],
        per-instance (start, end) timestamps, and per-instance token counts.

    Raises:
        ValueError: if no note events could be parsed at all.
    """
    tokenizer = EventTokenizer()
    tokens = tokenizer.tokenize_chart(segments)
    all_instances: list[torch.Tensor] = []
    all_masks: list[torch.Tensor] = []
    all_times: list[tuple[float, float]] = []
    all_token_counts: list[int] = []
    for window_size in window_measures:
        windows = tokenizer.create_windows(
            tokens, window_measures=window_size, hop_measures=hop_measures
        )
        for window_tokens in windows:
            if not window_tokens:
                continue
            tensor, mask = tokenizer.tokens_to_tensor(
                window_tokens, max_length=max_tokens_per_instance
            )
            # Count real tokens before padding so density stats stay exact.
            all_token_counts.append(int(mask.sum().item()))
            tensor, mask = tokenizer.pad_sequence(tensor, mask, max_tokens_per_instance)
            all_instances.append(tensor)
            all_masks.append(mask)
            all_times.append(
                (float(window_tokens[0].timestamp), float(window_tokens[-1].timestamp))
            )
    if not all_instances:
        raise ValueError("No note events parsed (empty chart or unsupported format)")
    if len(all_instances) > max_instances_per_chart:
        # Subsample uniformly in index order to respect the instance budget.
        idx = np.linspace(
            0, len(all_instances) - 1, max_instances_per_chart, dtype=int
        ).tolist()
        all_instances = [all_instances[i] for i in idx]
        all_masks = [all_masks[i] for i in idx]
        all_times = [all_times[i] for i in idx]
        all_token_counts = [all_token_counts[i] for i in idx]
    instances = torch.stack(all_instances).unsqueeze(0)  # [1, N, L, 6]
    masks = torch.stack(all_masks).unsqueeze(0)  # [1, N, L]
    counts = torch.tensor([len(all_instances)], dtype=torch.long)  # [1]
    return instances, masks, counts, all_times, all_token_counts


def _plot_attention(
    times: list[tuple[float, float]],
    avg_attention: np.ndarray,
    topk_mask: Optional[np.ndarray],
    title: str,
):
    """Scatter/line plot of per-instance attention over time, top-k highlighted."""
    # Sort by time to avoid misleading zig-zag lines when windows are generated in mixed order.
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0
    order = np.argsort(mids)
    mids_s = mids[order]
    attn_s = avg_attention[order]
    topk_s = topk_mask[order] if topk_mask is not None else None
    fig, ax = plt.subplots(figsize=(10, 3.2))
    ax.scatter(mids_s, attn_s, s=14, alpha=0.8, label="Instance")
    ax.plot(mids_s, attn_s, linewidth=1.5, alpha=0.6)
    if topk_s is not None:
        sel = topk_s.astype(bool)
        ax.scatter(
            mids_s[sel],
            attn_s[sel],
            s=40,
            marker="o",
            edgecolors="black",
            linewidths=0.4,
            label="Top-k",
        )
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Avg attention (weight)")
    ax.set_title(title)
    ax.grid(True, alpha=0.25)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig


def _plot_branch_heatmap(branch_attn: np.ndarray, title: str):
    """Heatmap of branch attention. `branch_attn`: [n_branches, n_instances]."""
    fig, ax = plt.subplots(figsize=(10, 3.2))
    im = ax.imshow(branch_attn, aspect="auto", interpolation="nearest")
    ax.set_title(title)
    ax.set_xlabel("Instance (time-sorted)")
    ax.set_ylabel("Branch")
    cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.04)
    cbar.set_label("Attention weight")
    fig.tight_layout()
    return fig


def _plot_density_and_attention(
    times: list[tuple[float, float]],
    token_counts: list[int],
    avg_attention: np.ndarray,
    topk_mask: Optional[np.ndarray],
    title: str,
):
    """Dual-axis plot: token density (left axis) vs attention (right axis)."""
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0
    durations = np.maximum(t1 - t0, 1e-6)
    token_counts_np = np.array(token_counts[: len(times)], dtype=np.float64)
    density = token_counts_np / durations
    order = np.argsort(mids)
    mids_s = mids[order]
    dens_s = density[order]
    attn_s = avg_attention[order]
    topk_s = topk_mask[order] if topk_mask is not None else None
    fig, ax1 = plt.subplots(figsize=(10, 3.2))
    ax1.plot(mids_s, dens_s, linewidth=1.8, color="tab:blue", label="Token density")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Tokens / sec", color="tab:blue")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.grid(True, alpha=0.25)
    ax2 = ax1.twinx()
    ax2.scatter(
        mids_s, attn_s, s=14, color="tab:orange", alpha=0.75, label="Avg attention"
    )
    if topk_s is not None:
        sel = topk_s.astype(bool)
        ax2.scatter(
            mids_s[sel],
            attn_s[sel],
            s=40,
            marker="o",
            edgecolors="black",
            linewidths=0.4,
            color="tab:orange",
            label="Top-k attention",
        )
    ax2.set_ylabel("Avg attention", color="tab:orange")
    ax2.tick_params(axis="y", labelcolor="tab:orange")
    ax1.set_title(title)
    # Merge legends from both axes into one.
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, loc="best")
    fig.tight_layout()
    return fig


def _plot_local_difficulty(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    token_counts: list[int],
    title: str,
):
    """Plot estimated local difficulty (star rating) over time."""
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0
    durations = np.maximum(t1 - t0, 1e-6)
    token_counts_np = np.array(token_counts[: len(times)], dtype=np.float64)
    density = token_counts_np / durations
    order = np.argsort(mids)
    mids_s = mids[order]
    stars_s = local_stars[order]
    dens_s = density[order]
    # EMA smoothing of the (noisy) per-window star curve.
    alpha = 0.3
    if len(stars_s) > 0:
        stars_smooth = np.zeros_like(stars_s)
        stars_smooth[0] = stars_s[0]
        for i in range(1, len(stars_s)):
            stars_smooth[i] = alpha * stars_s[i] + (1 - alpha) * stars_smooth[i - 1]
    else:
        stars_smooth = stars_s
    fig, ax1 = plt.subplots(figsize=(10, 3.5))
    color = "tab:red"
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Estimated Local Stars", color=color)
    # Raw curve faint, smoothed curve prominent.
    ax1.plot(mids_s, stars_s, color=color, linewidth=1, alpha=0.3, label="Raw")
    ax1.plot(mids_s, stars_smooth, color=color, linewidth=2.5, label="Smoothed (EMA)")
    ax1.tick_params(axis="y", labelcolor=color)
    ax1.grid(True, alpha=0.25)
    ax1.fill_between(mids_s, stars_smooth, alpha=0.1, color=color)
    # Note density on a secondary axis for context.
    ax2 = ax1.twinx()
    color2 = "tab:blue"
    ax2.set_ylabel("Density (notes/s)", color=color2)
    ax2.plot(
        mids_s,
        dens_s,
        color=color2,
        linewidth=1,
        linestyle="--",
        alpha=0.5,
        label="Note Density",
    )
    ax2.tick_params(axis="y", labelcolor=color2)
    ax1.set_title(title)
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
    fig.tight_layout()
    return fig


def _smooth_embeddings(embeddings: np.ndarray, window_size: int = 3) -> np.ndarray:
    """Apply temporal smoothing (moving average) to [N, D] embeddings."""
    if len(embeddings) < window_size:
        return embeddings
    # Simple moving-average kernel, applied per dimension with edge padding
    # so the output keeps the input's length (dependency-free, no scipy).
    kernel = np.ones(window_size) / window_size
    smoothed = np.zeros_like(embeddings)
    for dim in range(embeddings.shape[1]):
        x = embeddings[:, dim]
        pad_width = window_size // 2
        padded = np.pad(x, pad_width, mode="edge")
        s = np.convolve(padded, kernel, mode="valid")
        # Guard against length drift from even/odd window sizes.
        if len(s) > len(x):
            s = s[: len(x)]
        elif len(s) < len(x):
            s = np.pad(s, (0, len(x) - len(s)), mode="edge")
        smoothed[:, dim] = s
    return smoothed


def _smooth_labels(labels: np.ndarray, window_size: int = 3) -> np.ndarray:
    """Apply a sliding-window mode filter to enforce temporal continuity."""
    if len(labels) < window_size:
        return labels
    n = len(labels)
    smoothed = labels.copy()
    pad = window_size // 2
    for i in range(n):
        start = max(0, i - pad)
        end = min(n, i + pad + 1)
        window = labels[start:end]
        counts = np.bincount(window)
        smoothed[i] = np.argmax(counts)
    return smoothed


def _perform_clustering(
    embeddings: np.ndarray,
    min_k: int = 3,
    max_k: int = 8,
    smoothing_window: int = 3,
    label_smoothing_window: int = 3,
    random_state: int = 42,
) -> tuple[np.ndarray, int, dict]:
    """
    Perform K-Means clustering with automatic K selection using Silhouette Score.
    Applies temporal smoothing to stabilize clusters.

    Args:
        embeddings: [N, D] data points
        min_k: Minimum number of clusters
        max_k: Maximum number of clusters

    Returns:
        labels: [N] cluster labels
        best_k: Selected number of clusters
        stats: Info about clustering quality (key "silhouette")
    """
    N = embeddings.shape[0]
    if N < min_k:
        # Too few points to cluster; report one cluster. Key unified with the
        # normal path (was inconsistently "score" before).
        return np.zeros(N, dtype=int), 1, {"silhouette": 0.0}

    # 1. Temporal smoothing of the embedding sequence.
    if smoothing_window > 1:
        work_embeddings = _smooth_embeddings(embeddings, window_size=smoothing_window)
    else:
        work_embeddings = embeddings

    best_score = -1.0
    best_k = min_k
    best_model = None
    # silhouette_score needs 2 <= k <= N - 1.
    effective_max_k = min(max_k, N - 1)
    if effective_max_k < min_k:
        effective_max_k = min_k
    for k in range(min_k, effective_max_k + 1):
        kmeans = KMeans(n_clusters=k, random_state=random_state, n_init=10)
        labels = kmeans.fit_predict(work_embeddings)
        try:
            score = silhouette_score(work_embeddings, labels)
            if score > best_score:
                best_score = score
                best_k = k
                best_model = kmeans
        except ValueError:
            # silhouette is undefined (e.g. a degenerate labeling); skip this k.
            pass
    if best_model is None:
        # Fallback: no k produced a valid silhouette score.
        kmeans = KMeans(n_clusters=min_k, random_state=random_state, n_init=10)
        kmeans.fit(work_embeddings)
        best_model = kmeans
        best_k = min_k
    labels = best_model.labels_

    # 2. Label smoothing (post-processing).
    if label_smoothing_window > 1:
        labels = _smooth_labels(labels, window_size=label_smoothing_window)
    return labels, best_k, {"silhouette": best_score}


def _analyze_clusters(
    cluster_labels: np.ndarray,
    local_stars: np.ndarray,
    note_density: np.ndarray,
    avg_attention: Optional[np.ndarray] = None,
) -> list[dict]:
    """
    Analyze properties of each cluster to create a profile.

    Returns list of dicts: [{Cluster ID, Count, Avg Stars, Avg Density, Avg Attention}]
    """
    unique_labels = np.unique(cluster_labels)
    profiles = []
    for label in unique_labels:
        mask = cluster_labels == label
        count = mask.sum()
        avg_s = local_stars[mask].mean() if len(local_stars) > 0 else 0
        avg_d = note_density[mask].mean() if len(note_density) > 0 else 0
        avg_a = avg_attention[mask].mean() if avg_attention is not None else 0
        profiles.append(
            {
                "Cluster ID": int(label),
                "Count": int(count),
                "Avg Stars": round(float(avg_s), 2),
                "Avg Density": round(float(avg_d), 2),
                "Avg Attention": round(float(avg_a), 4),
            }
        )
    # Sorting by ID keeps the table consistent with plot colors.
    profiles.sort(key=lambda x: x["Cluster ID"])
    return profiles


def _plot_clusters(
    times: list[tuple[float, float]],
    cluster_labels: np.ndarray,
    local_stars: np.ndarray,
    title: str,
):
    """Plot timeline colored by cluster ID."""
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0
    order = np.argsort(mids)
    mids_s = mids[order]
    stars_s = local_stars[order]
    labels_s = cluster_labels[order]
    unique_labels = np.unique(labels_s)
    n_clusters = len(unique_labels)
    # Use a distinct colormap
    cmap = plt.get_cmap("tab10" if n_clusters <= 10 else "tab20")
    fig, ax = plt.subplots(figsize=(10, 3.5))
    # Windows are time-sorted; a scatter per cluster plus a faint connecting
    # line is enough to show structure.
    for i, label in enumerate(unique_labels):
        mask = labels_s == label
        ax.scatter(
            mids_s[mask],
            stars_s[mask],
            color=cmap(i),
            label=f"Cluster {label}",
            s=20,
            alpha=0.8,
        )
    ax.plot(mids_s, stars_s, color="gray", alpha=0.2, linewidth=1)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Local Stars")
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    return fig


def _detect_segments(
    local_stars: np.ndarray,
    times: list[tuple[float, float]],
    min_segment_size: int = 3,
    penalty_scale: float = 1.0,
) -> list[dict]:
    """
    Detect segments using Change Point Detection (ruptures PELT, l2 model).

    IMPORTANT: Windows may not be in temporal order (e.g., mixed window sizes).
    We sort by midpoint time first to ensure temporal coherence.
    """
    n = len(local_stars)
    if n < min_segment_size * 2:
        # Too little data: one segment spanning the whole chart.
        return [
            {
                "start_time": times[0][0],
                "end_time": times[-1][1],
                "avg_stars": float(local_stars.mean()),
                "n_windows": n,
            }
        ]
    mids = np.array([(t0 + t1) / 2 for t0, t1 in times])
    # SORT by midpoint time (critical for temporal coherence!)
    order = np.argsort(mids)
    mids_sorted = mids[order]
    stars_sorted = local_stars[order]
    times_sorted = [times[i] for i in order]
    # Build cell boundaries (1D Voronoi on SORTED windows)
    cell_bounds = [times_sorted[0][0]]  # Song start
    for i in range(len(mids_sorted) - 1):
        cell_bounds.append((mids_sorted[i] + mids_sorted[i + 1]) / 2)
    cell_bounds.append(times_sorted[-1][1])  # Song end
    # Ruptures detection (on SORTED data); penalty scales with signal variance.
    signal = stars_sorted.reshape(-1, 1)
    penalty = np.var(stars_sorted) * penalty_scale
    algo = rpt.Pelt(model="l2", min_size=min_segment_size).fit(signal)
    change_points = algo.predict(pen=penalty)
    # Build segments from consecutive change points.
    segments = []
    prev_idx = 0
    for cp in change_points:
        seg_stars = stars_sorted[prev_idx:cp]
        start_t = cell_bounds[prev_idx]
        end_t = cell_bounds[cp]
        segments.append(
            {
                "start_time": float(start_t),
                "end_time": float(end_t),
                "avg_stars": float(seg_stars.mean()),
                "n_windows": cp - prev_idx,
            }
        )
        prev_idx = cp
    return segments


def _plot_segments(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    segments: list[dict],
    title: str,
):
    """
    Plot local difficulty with segment backgrounds (non-overlapping).
    """
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0
    order = np.argsort(mids)
    mids_s = mids[order]
    stars_s = local_stars[order]
    # Colormap: Red=Hard, Green=Easy
    cmap = plt.get_cmap("RdYlGn_r")
    fig, ax = plt.subplots(figsize=(12, 4))
    # Normalize colors over the observed segment-average range.
    max_star = max(s["avg_stars"] for s in segments) if segments else 10
    min_star = min(s["avg_stars"] for s in segments) if segments else 0
    star_range = max(max_star - min_star, 1)
    # Draw segment backgrounds (should NOT overlap now)
    for seg in segments:
        color = cmap((seg["avg_stars"] - min_star) / star_range)
        ax.axvspan(
            seg["start_time"], seg["end_time"], alpha=0.3, color=color, linewidth=0
        )
        # Horizontal line at segment average
        ax.hlines(
            y=seg["avg_stars"],
            xmin=seg["start_time"],
            xmax=seg["end_time"],
            colors=color,
            linewidth=3,
            alpha=0.9,
        )
        # Label only segments wide enough to hold text.
        duration = seg["end_time"] - seg["start_time"]
        if duration > 4:  # Only label if > 4 seconds
            mid_x = (seg["start_time"] + seg["end_time"]) / 2
            ax.text(
                mid_x,
                seg["avg_stars"] + 0.02,
                f"{seg['avg_stars']:.1f}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                color="black",
                alpha=0.8,
            )
    # Raw data on top
    ax.plot(mids_s, stars_s, color="gray", alpha=0.4, linewidth=1)
    # Boundary lines between segments.
    for seg in segments[1:]:
        ax.axvline(
            x=seg["start_time"], color="black", linewidth=1, linestyle="--", alpha=0.5
        )
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Raw Score")
    ax.set_title(title)
    ax.set_ylim(bottom=0, top=max_star + 2)
    ax.grid(True, alpha=0.15, axis="y")
    fig.tight_layout()
    return fig


def _plot_attention_concentration(
    avg_attention: np.ndarray,
    title: str,
):
    """Cumulative attention mass sorted by weight (how concentrated the model is)."""
    attn = np.clip(avg_attention.astype(np.float64), 0.0, None)
    if attn.sum() > 0:
        attn = attn / attn.sum()
    attn_sorted = np.sort(attn)[::-1]
    cum = np.cumsum(attn_sorted)
    k = np.arange(1, len(attn_sorted) + 1)
    fig, ax = plt.subplots(figsize=(10, 3.2))
    ax.plot(k, cum, linewidth=2)
    ax.set_xlabel("Top-k instances (sorted by attention)")
    ax.set_ylabel("Cumulative attention mass")
    ax.set_ylim(0, 1.02)
    ax.set_title(title)
    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    return fig


def run_inference(
    tja_file,
    tja_text: str,
    course_name: str,
    checkpoint_path: str,
    device: str,
    window_measures_text: str,
    hop_measures: int,
    max_instances: int,
):
    """Full inference pipeline: parse TJA, run the model, build all outputs.

    Returns the 11-tuple consumed by `build_app`'s output components.
    """
    # Uploaded file takes precedence over pasted text.
    if tja_file:
        with open(tja_file, "r", encoding="utf-8", errors="ignore") as f:
            tja_text = f.read()
    parsed = parse_tja(tja_text)
    if not parsed.courses:
        raise gr.Error("No COURSE found and no chart parsed.")
    if course_name not in parsed.courses:
        # Fallback to first
        course_name = next(iter(parsed.courses.keys()))
    course = parsed.courses[course_name]
    try:
        window_measures = [
            int(x.strip()) for x in window_measures_text.split(",") if x.strip()
        ]
    except ValueError:
        raise gr.Error(
            "window_measures must be a comma-separated list of integers, e.g. 2,4"
        )
    if not window_measures:
        window_measures = [2, 4]

    device = _resolve_device(device)
    model = _load_model(checkpoint_path, device=device)
    max_tokens = int(getattr(model.config, "max_seq_len", 128))
    instances, masks, counts, times, token_counts = _build_instances_from_segments(
        course.segments,
        max_tokens_per_instance=max_tokens,
        window_measures=window_measures,
        hop_measures=int(hop_measures),
        max_instances_per_chart=int(max_instances),
    )
    instances = instances.to(torch.device(device))
    masks = masks.to(torch.device(device))
    counts = counts.to(torch.device(device))

    difficulty_hint = None
    if course.difficulty_hint is not None:
        mapping = {"easy": 0, "normal": 1, "hard": 2, "oni": 3, "ura": 4}
        difficulty_hint = torch.tensor(
            [mapping[course.difficulty_hint]], device=torch.device(device)
        )
    with torch.no_grad():
        out = model.forward(
            instances,
            masks,
            counts,
            difficulty_hint=difficulty_hint,
            return_attention=True,
        )

    # Scalar predictions
    difficulty_names = ["easy", "normal", "hard", "oni", "ura"]
    pred_class_id = int(out.difficulty_logits.argmax(dim=-1).item())
    pred_class = difficulty_names[pred_class_id]
    raw_score = float(out.raw_score.item())
    raw_star = float(out.raw_star.item())
    display_star = float(out.display_star.item())

    # Attention details
    attn = out.attention_info
    avg_attn = attn.get("average_attention")
    branch_attn = attn.get("branch_attentions")
    topk_mask = attn.get("topk_mask")

    # Local difficulty estimation (probe). Use the predicted class ID for
    # calibration when no explicit hint was provided.
    calib_diff_id = difficulty_hint
    if calib_diff_id is None:
        calib_diff_id = out.difficulty_logits.argmax(dim=-1, keepdim=True)  # [1, 1]
    local_raw, local_stars = model.get_instance_scores(
        out.instance_embeddings, difficulty_class_id=calib_diff_id.view(-1)
    )

    avg_attn_np = (
        avg_attn[0, : counts.item()].detach().cpu().numpy()
        if avg_attn is not None
        else None
    )
    topk_np = (
        topk_mask[0, : counts.item()].detach().cpu().numpy()
        if topk_mask is not None
        else None
    )
    branch_np = (
        branch_attn[0, :, : counts.item()].detach().cpu().numpy()
        if branch_attn is not None
        else None
    )
    local_stars_np = local_stars[0, : counts.item()].detach().cpu().numpy()
    local_raw_np = local_raw[0, : counts.item()].detach().cpu().numpy()

    # Plots
    fig_attn = None
    fig_heat = None
    fig_density = None
    fig_conc = None
    if avg_attn_np is not None:
        # (Previously this was split across two identical `if` blocks.)
        fig_attn = _plot_attention(
            times, avg_attn_np, topk_np, title="MIL average attention over time"
        )
        fig_density = _plot_density_and_attention(
            times,
            token_counts,
            avg_attn_np,
            topk_np,
            title="Token density vs attention (time-sorted)",
        )
        fig_conc = _plot_attention_concentration(
            avg_attn_np,
            title="Attention concentration (how many windows dominate)",
        )

    fig_local_diff = None
    if local_stars_np is not None:
        fig_local_diff = _plot_local_difficulty(
            times,
            local_stars_np,
            token_counts,
            title=f"Estimated Local Difficulty Curve (Assuming {pred_class} calibration)",
        )

    # Segment detection (piecewise-constant change point detection).
    fig_segments = None
    segment_table_df = None
    if local_raw_np is not None and len(times) > 0:
        segments = _detect_segments(
            local_raw_np,  # Use raw score instead of stars
            times,
            min_segment_size=3,
            penalty_scale=0.5,
        )
        seg_rows = []
        for i, seg in enumerate(segments):
            seg_rows.append(
                [
                    i + 1,
                    f"{seg['start_time']:.1f}",
                    f"{seg['end_time']:.1f}",
                    f"{seg['end_time'] - seg['start_time']:.1f}",
                    f"{seg['avg_stars']:.1f}",  # This is now avg_raw
                    seg["n_windows"],
                ]
            )
        segment_table_df = seg_rows
        fig_segments = _plot_segments(
            times,
            local_raw_np,  # Use raw score
            segments,
            title=f"Chart Structure: {len(segments)} Segments Detected",
        )

    # Branch heatmap with time-sorted columns and a few time tick labels.
    if branch_np is not None:
        mids = np.array([(a + b) / 2.0 for a, b in times], dtype=np.float64)
        order = np.argsort(mids)
        branch_sorted = branch_np[:, order]
        fig_heat = _plot_branch_heatmap(
            branch_sorted, title="MIL attention (branches x instances)"
        )
        ax = fig_heat.axes[0]
        if len(order) > 1:
            n_ticks = 6
            tick_pos = np.linspace(0, len(order) - 1, n_ticks, dtype=int)
            tick_labels = [f"{mids[order[p]]:.0f}s" for p in tick_pos]
            ax.set_xticks(tick_pos)
            ax.set_xticklabels(tick_labels)

    # Raw per-window table
    rows = []
    for i, (t0, t1) in enumerate(times):
        rows.append(
            [
                i,
                float(t0),
                float(t1),
                float((t0 + t1) / 2.0),
                int(token_counts[i]) if i < len(token_counts) else None,
                float(avg_attn_np[i]) if avg_attn_np is not None else None,
                int(topk_np[i]) if topk_np is not None else None,
                float(local_stars_np[i]) if i < len(local_stars_np) else None,
            ]
        )

    # More intuitive summary: show top attention windows
    top_md = ""
    if avg_attn_np is not None:
        t0 = np.array([a for a, _ in times], dtype=np.float64)
        t1 = np.array([b for _, b in times], dtype=np.float64)
        mids = (t0 + t1) / 2.0
        durations = np.maximum(t1 - t0, 1e-6)
        token_counts_np = np.array(token_counts[: len(times)], dtype=np.float64)
        density = token_counts_np / durations
        top_n = min(8, len(avg_attn_np))
        top_idx = np.argsort(avg_attn_np)[::-1][:top_n]
        lines = ["### Top segments (by attention)"]
        for rank, idx in enumerate(top_idx, start=1):
            is_topk = int(topk_np[idx]) if topk_np is not None else 0
            lines.append(
                f"{rank}. `[{t0[idx]:.1f}s - {t1[idx]:.1f}s]` "
                f"attn={avg_attn_np[idx]:.4f}, dens={density[idx]:.1f} tok/s, topk={is_topk}"
            )
        top_md = "\n".join(lines)

    # Meta/details
    meta_out = {
        "TITLE": parsed.meta.get("TITLE"),
        "BPM": parsed.meta.get("BPM"),
        "OFFSET": parsed.meta.get("OFFSET"),
        "COURSE": course.name,
        "LEVEL": course.level,
        "difficulty_hint": course.difficulty_hint,
        "n_instances": int(counts.item()),
        "max_tokens_per_instance": int(max_tokens),
        "window_measures": window_measures,
        "hop_measures": int(hop_measures),
        "attention_entropy": (
            float(attn.get("entropy")[0].item())
            if attn.get("entropy") is not None
            else None
        ),
        "attention_effective_n": (
            float(attn.get("effective_n")[0].item())
            if attn.get("effective_n") is not None
            else None
        ),
        "attention_top5_mass": (
            float(attn.get("top5_mass")[0].item())
            if attn.get("top5_mass") is not None
            else None
        ),
    }
    summary_md = (
        f"### Prediction\n"
        f"- predicted difficulty: `{pred_class}`\n"
        f"- raw_score: `{raw_score:.4f}`\n"
        f"- raw_star: `{raw_star:.4f}`\n"
        f"- display_star: `{display_star:.4f}`\n"
    )
    return (
        summary_md,
        meta_out,
        fig_attn,
        fig_density,
        fig_heat,
        fig_conc,
        top_md,
        rows,
        fig_local_diff,
        fig_segments,
        segment_table_df,
    )


def _update_course_dropdown(tja_file, tja_text: str):
    """Refresh the COURSE dropdown choices from the current TJA input."""
    if tja_file:
        with open(tja_file, "r", encoding="utf-8", errors="ignore") as f:
            tja_text = f.read()
    try:
        parsed = parse_tja(tja_text)
        choices = list(parsed.courses.keys())
        value = choices[0] if choices else None
        return gr.Dropdown(choices=choices, value=value)
    except Exception:
        # Best-effort: an unparsable/partial TJA just clears the dropdown.
        return gr.Dropdown(choices=[], value=None)


def build_app() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire events to `run_inference`."""
    checkpoints = _discover_checkpoints()
    with gr.Blocks(title="TaikoChartEstimator Inference") as demo:
        gr.Markdown("# TaikoChartEstimator - Inference")
        with gr.Row():
            # Left: Input (Upload/Paste with tabs)
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Upload"):
                        tja_file = gr.File(label="Upload TJA file")
                    with gr.TabItem("Paste"):
                        tja_text = gr.Textbox(label="Paste TJA content", lines=12)
                course = gr.Dropdown(label="COURSE", choices=[], value=None)
                btn = gr.Button("Run Inference", variant="primary", size="lg")
            # Right: Options
            with gr.Column(scale=1):
                gr.Markdown("### Options")
                checkpoint = gr.Dropdown(
                    label="Checkpoint",
                    choices=checkpoints,
                    value=checkpoints[-1] if checkpoints else None,
                    allow_custom_value=True,
                )
                device = gr.Dropdown(
                    label="Device", choices=["cpu", "mps", "cuda"], value="cpu"
                )
                with gr.Accordion("Advanced", open=False):
                    window_measures = gr.Textbox(
                        label="window_measures (comma-separated)", value="2,4"
                    )
                    hop_measures = gr.Slider(
                        label="hop_measures", minimum=1, maximum=8, value=2, step=1
                    )
                    max_instances = gr.Slider(
                        label="max_instances", minimum=1, maximum=512, value=128, step=1
                    )
        with gr.Row():
            with gr.Column(scale=1):
                summary = gr.Markdown()
                top_segments = gr.Markdown()
            with gr.Column(scale=1):
                meta_json = gr.JSON(label="Metadata")
        with gr.Tabs():
            with gr.TabItem("Chart Structure"):
                gr.Markdown("### Automatic Segment Detection")
                gr.Markdown(
                    "Detects distinct sections based on difficulty changes (Piecewise Constant Model)."
                )
                plot_segments = gr.Plot(label="Detected Segments")
                segment_table = gr.Dataframe(
                    headers=[
                        "#",
                        "Start (s)",
                        "End (s)",
                        "Duration",
                        "Avg Raw",
                        "Windows",
                    ],
                    datatype=["number", "str", "str", "str", "str", "number"],
                    label="Segment Details",
                )
            with gr.TabItem("Local Difficulty"):
                plot_local_diff = gr.Plot(label="Local Difficulty Curve")
            with gr.TabItem("Attention & Density"):
                plot_density = gr.Plot(label="Density vs Attention")
            with gr.TabItem("Attention Details"):
                plot_attn = gr.Plot(label="Raw Attention")
            with gr.TabItem("Heatmap"):
                plot_heat = gr.Plot(label="Branch Heatmap")
            with gr.TabItem("Concentration"):
                plot_conc = gr.Plot(label="Concentration")
            with gr.TabItem("Raw Data"):
                # headers needs to match rows
                df = gr.Dataframe(
                    headers=[
                        "id",
                        "start",
                        "end",
                        "mid",
                        "tokens",
                        "attention",
                        "is_topk",
                        "local_stars",
                    ],
                    datatype=[
                        "number",
                        "number",
                        "number",
                        "number",
                        "number",
                        "number",
                        "number",
                        "number",
                    ],
                )
        # Auto-refresh COURSE choices when input changes
        tja_file.change(
            _update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
        )
        tja_text.change(
            _update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
        )
        btn.click(
            run_inference,
            inputs=[
                tja_file,
                tja_text,
                course,
                checkpoint,
                device,
                window_measures,
                hop_measures,
                max_instances,
            ],
            outputs=[
                summary,
                meta_json,
                plot_attn,
                plot_density,
                plot_heat,
                plot_conc,
                top_segments,
                df,
                plot_local_diff,
                plot_segments,
                segment_table,
            ],
        )
    return demo


if __name__ == "__main__":
    app = build_app()
    app.launch()