Spaces:
Sleeping
Sleeping
| import glob | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Any, Optional | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import ruptures as rpt | |
| import torch | |
| from sklearn.cluster import KMeans | |
| from sklearn.metrics import silhouette_score | |
| from TaikoChartEstimator.data.tokenizer import EventTokenizer | |
| from TaikoChartEstimator.model.model import TaikoChartEstimator | |
@dataclass
class ParsedCourse:
    """One ``COURSE:`` section of a TJA file after parsing.

    Must be a dataclass: ``parse_tja`` builds instances with keyword arguments.
    """

    # Course name exactly as written after "COURSE:" ("(default)" if none was given).
    name: str
    # Integer value of "LEVEL:", or None if missing or unparsable.
    level: Optional[int]
    # One dict per measure: timestamp, measure_num, measure_den and a list of note dicts.
    segments: list[dict]
    # Normalized difficulty derived from the course name ("easy"/"normal"/"hard"/"oni"),
    # or None when the name is not recognized.
    difficulty_hint: Optional[str]
@dataclass
class ParsedTJA:
    """Result of parsing a TJA file: header metadata plus the parsed courses.

    Must be a dataclass: ``parse_tja`` builds instances with keyword arguments.
    """

    # Header "KEY:VALUE" pairs seen outside #START/#END. Keys are upper-cased and
    # values are kept as raw strings; a later duplicate key overwrites an earlier one.
    meta: dict[str, Any]
    # Parsed courses keyed by the name given after "COURSE:".
    courses: dict[str, "ParsedCourse"]
# TJA note digit -> note type name stored on parsed note dicts.
# "8" ("EndOf") closes the preceding roll/balloon; "0" (rest) is deliberately absent.
NOTE_DIGIT_TO_TYPE = dict(
    zip(
        "123456789",
        (
            "Don",
            "Ka",
            "DonBig",
            "KaBig",
            "Roll",
            "RollBig",
            "Balloon",
            "EndOf",
            "BalloonAlt",
        ),
    )
)
def _strip_comment(line: str) -> str:
    """Return ``line`` without its ``//`` comment tail and surrounding whitespace."""
    code, _sep, _comment = line.partition("//")
    return code.strip()
def _finalize_long_note_durations(segments: list[dict]) -> None:
    """Annotate rolls/balloons with their length in beats.

    Notes from all ``segments`` are visited in timestamp order. Each ``EndOf``
    note closes the most recently opened Roll/RollBig/Balloon/BalloonAlt, whose
    dict then receives a ``delay`` key: the elapsed seconds converted to beats
    at the BPM stored on the opening note (never negative). Long notes without
    a matching ``EndOf`` are left unchanged.
    """
    long_types = {"Roll", "RollBig", "Balloon", "BalloonAlt"}
    flat = sorted(
        (note for seg in segments for note in seg.get("notes", [])),
        key=lambda note: note.get("timestamp", 0.0),
    )
    open_notes: list[dict] = []
    for note in flat:
        note_type = note.get("note_type")
        if note_type in long_types:
            open_notes.append(note)
        elif note_type == "EndOf" and open_notes:
            start = open_notes.pop()
            start_bpm = float(start.get("bpm", 120.0))
            dt = float(note.get("timestamp", 0.0)) - float(start.get("timestamp", 0.0))
            start["delay"] = float(max(0.0, dt * start_bpm / 60.0))


def parse_tja(text: str, *, branch: Optional[str] = "M") -> ParsedTJA:
    """Parse a (single-song) TJA into dataset-like ``segments`` per course.

    Supported (best-effort): COURSE/LEVEL, BPM, OFFSET, #START/#END,
    #BPMCHANGE, #MEASURE, #SCROLL, #DELAY, #GOGOSTART/#GOGOEND and the branch
    section markers #N/#E/#M/#BRANCHEND. Other commands are ignored.

    Every comma-terminated measure becomes one segment dict with keys
    ``timestamp``, ``measure_num``, ``measure_den`` and ``notes``. The digits
    of a measure are spread evenly over its length (4 * num / den beats).
    Commands appearing between digits of the same measure take effect at that
    position inside the measure instead of splitting it; #MEASURE given
    mid-measure applies from the next measure. A bare ``,`` is one full
    measure of rest, exactly like ``0,``. OFFSET is added to every timestamp
    as written (no sign convention is applied).

    Args:
        text: Raw TJA file contents.
        branch: Branch section to keep for branched charts: ``"N"``, ``"E"``
            or ``"M"`` (default, master). ``None`` keeps every section back to
            back, which repeats the branched measures.

    Returns:
        ``ParsedTJA`` with the header metadata and one ``ParsedCourse`` per
        ``COURSE:`` seen (``"(default)"`` if a chart starts without one).

    Raises:
        ValueError: If ``text`` is empty or whitespace only.
    """
    if not text or not text.strip():
        raise ValueError("Empty TJA input")
    text = text.replace("\ufeff", "")
    lines = [_strip_comment(l) for l in text.replace("\r\n", "\n").split("\n")]
    lines = [l for l in lines if l]

    kept_branch = branch.strip().upper() if branch else None

    meta: dict[str, Any] = {}
    courses: dict[str, dict[str, Any]] = {}
    current_course: Optional[dict[str, Any]] = None
    in_chart = False
    in_kept_branch = True

    # Header values; every #START restarts the chart state from these.
    header_bpm = 120.0
    header_offset = 0.0

    # Chart state mutated by commands while reading a course.
    bpm = header_bpm
    measure_num, measure_den = 4, 4
    scroll = 1.0
    gogo = False
    measure_start_time = 0.0
    # Digits of the measure being read, interleaved with commands that arrived
    # after its first digit: ("digit", ch), ("bpm", x), ("scroll", x),
    # ("gogo", bool), ("delay", seconds), ("measure", (num, den)).
    pending: list[tuple[str, Any]] = []

    def to_float(value: str) -> Optional[float]:
        try:
            return float(value)
        except ValueError:
            return None

    def parse_signature(value: str) -> Optional[tuple[int, int]]:
        num_s, _sep, den_s = value.partition("/")
        try:
            num, den = int(num_s), int(den_s)
        except ValueError:
            return None
        # A zero/negative signature would make the measure length meaningless.
        return (num, den) if num > 0 and den > 0 else None

    def ensure_course(name: str) -> dict[str, Any]:
        if name not in courses:
            courses[name] = {
                "name": name,
                "level": None,
                "segments": [],
                "difficulty_hint": None,
            }
        return courses[name]

    def apply_or_defer(kind: str, value: Any) -> None:
        """Apply a state change now, or queue it inside the half-read measure."""
        nonlocal bpm, scroll, gogo, measure_num, measure_den, measure_start_time
        if pending:
            # Mid-measure: replay it at this position when the measure closes.
            pending.append((kind, value))
        elif kind == "bpm":
            bpm = value
        elif kind == "scroll":
            scroll = value
        elif kind == "gogo":
            gogo = value
        elif kind == "delay":
            measure_start_time += value
        elif kind == "measure":
            measure_num, measure_den = value

    def close_measure() -> None:
        """Turn the buffered measure into a segment and advance the clock."""
        nonlocal bpm, scroll, gogo, measure_num, measure_den, measure_start_time
        if current_course is None:
            pending.clear()
            return
        # A bare "," is a full measure of rest, i.e. exactly like "0,".
        events = list(pending) if pending else [("digit", "0")]
        pending.clear()

        n_slots = sum(1 for kind, _ in events if kind == "digit")
        beats = 4.0 * float(measure_num) / float(measure_den)
        seg_time = measure_start_time
        seg_num, seg_den = measure_num, measure_den
        step = beats * 60.0 / max(bpm, 1e-6) / n_slots
        anchor = measure_start_time  # time of the last in-measure tempo/delay change
        k = 0  # digit slots elapsed since `anchor`
        next_signature: Optional[tuple[int, int]] = None
        notes: list[dict] = []
        for kind, value in events:
            if kind == "digit":
                note_type = NOTE_DIGIT_TO_TYPE.get(value)
                if note_type:
                    notes.append(
                        {
                            "note_type": note_type,
                            "timestamp": float(anchor + k * step),
                            "bpm": float(bpm),
                            "scroll": float(scroll),
                            "gogo": bool(gogo),
                        }
                    )
                k += 1
            elif kind == "bpm":
                anchor, k = anchor + k * step, 0
                bpm = value
                step = beats * 60.0 / max(bpm, 1e-6) / n_slots
            elif kind == "delay":
                anchor, k = anchor + k * step + value, 0
            elif kind == "scroll":
                scroll = value
            elif kind == "gogo":
                gogo = value
            elif kind == "measure":
                # #MEASURE is only meaningful at a measure start; use it next measure.
                next_signature = value

        current_course["segments"].append(
            {
                "timestamp": float(seg_time),
                "measure_num": int(seg_num),
                "measure_den": int(seg_den),
                "notes": notes,
            }
        )
        if k == n_slots:
            # No tempo/delay change inside the measure: exactly one measure long.
            measure_start_time = seg_time + beats * 60.0 / max(bpm, 1e-6)
        else:
            measure_start_time = anchor + k * step
        if next_signature is not None:
            measure_num, measure_den = next_signature

    for line in lines:
        # Header "KEY:VALUE" lines are only read outside #START ... #END.
        if not in_chart and ":" in line and not line.startswith("#"):
            key, value = (part.strip() for part in line.split(":", 1))
            key_u = key.upper()
            meta[key_u] = value
            if key_u == "BPM":
                parsed = to_float(value)
                if parsed is not None:
                    header_bpm = parsed
            elif key_u == "OFFSET":
                parsed = to_float(value)
                if parsed is not None:
                    header_offset = parsed
            elif key_u == "COURSE":
                current_course = ensure_course(value)
            elif key_u == "LEVEL" and current_course is not None:
                try:
                    current_course["level"] = int(float(value))
                except (ValueError, OverflowError):
                    current_course["level"] = None
            continue

        if line.startswith("#START"):
            if current_course is None:
                current_course = ensure_course("(default)")
            in_chart = True
            in_kept_branch = True
            bpm = header_bpm
            measure_num, measure_den = 4, 4
            scroll = 1.0
            gogo = False
            pending.clear()
            # Apply OFFSET as a global shift of the chart start (best-effort).
            measure_start_time = float(header_offset)
            continue

        if not in_chart:
            continue

        if line.startswith("#END"):
            if pending:
                close_measure()
            _finalize_long_note_durations(current_course["segments"])
            in_chart = False
            continue

        if line.startswith("#"):
            cmd = line[1:].strip()
            cmd_u = cmd.upper()
            parts = cmd.split(maxsplit=1)
            arg = parts[1].strip() if len(parts) > 1 else ""
            if cmd_u in ("N", "E", "M"):
                # Branch section marker: keep only the requested branch.
                if kept_branch is not None:
                    in_kept_branch = cmd_u == kept_branch
            elif cmd_u.startswith("BRANCHEND"):
                in_kept_branch = True
            elif not in_kept_branch:
                pass  # commands inside a skipped branch section
            elif cmd_u.startswith("BPMCHANGE"):
                new_bpm = to_float(arg)
                if new_bpm is not None:
                    apply_or_defer("bpm", new_bpm)
            elif cmd_u.startswith("MEASURE"):
                signature = parse_signature(arg)
                if signature is not None:
                    apply_or_defer("measure", signature)
            elif cmd_u.startswith("SCROLL"):
                new_scroll = to_float(arg)
                if new_scroll is not None:
                    apply_or_defer("scroll", new_scroll)
            elif cmd_u.startswith("DELAY"):
                delay = to_float(arg)
                if delay is not None:
                    apply_or_defer("delay", delay)
            elif cmd_u.startswith("GOGOSTART"):
                apply_or_defer("gogo", True)
            elif cmd_u.startswith("GOGOEND"):
                apply_or_defer("gogo", False)
            # Anything else (#BRANCHSTART, #SECTION, #BARLINEOFF, ...) is ignored.
            continue

        if not in_kept_branch:
            continue
        # Note data: a line may hold several measures separated by commas.
        for ch in line:
            if ch.isdigit():
                pending.append(("digit", ch))
            elif ch == ",":
                close_measure()

    if in_chart and current_course is not None:
        # Chart without #END: still emit the trailing measure and roll lengths.
        if pending:
            close_measure()
        _finalize_long_note_durations(current_course["segments"])

    difficulty_map = {
        "0": "easy",
        "easy": "easy",
        "1": "normal",
        "normal": "normal",
        "2": "hard",
        "hard": "hard",
        "3": "oni",
        "oni": "oni",
        "4": "oni",
        "ura": "oni",
        "edit": "oni",
    }
    parsed_courses: dict[str, ParsedCourse] = {
        name: ParsedCourse(
            name=name,
            level=c.get("level"),
            segments=c.get("segments", []),
            difficulty_hint=difficulty_map.get(name.strip().lower()),
        )
        for name, c in courses.items()
    }
    return ParsedTJA(meta=meta, courses=parsed_courses)
def _discover_checkpoints() -> list[str]:
    """List checkpoint choices for the UI.

    Local training outputs (``outputs/*/pretrained/*`` directories that hold
    a ``config.json``) win and are returned sorted. When there are none, the
    published Hugging Face model ids are offered instead; other paths can
    still be typed in manually.
    """
    local = sorted(
        candidate
        for candidate in glob.glob("outputs/*/pretrained/*")
        if os.path.isdir(candidate)
        and os.path.exists(os.path.join(candidate, "config.json"))
    )
    if local:
        return local
    return [
        "JacobLinCool/TaikoChartEstimator-20251228",
        "JacobLinCool/TaikoChartEstimator-20251229",
    ]
# Loaded models keyed by "<checkpoint_path>::<resolved device>", so each
# checkpoint/device pair is loaded only once per process.
_MODEL_CACHE: dict[str, TaikoChartEstimator] = dict()
def _resolve_device(device: str) -> str:
    """Map a requested device name to one torch can actually use here.

    ``"cuda"`` and ``"mps"`` (case-insensitive) are honored only when torch
    reports that backend as available. Anything else, including ``None`` or
    an empty string, resolves to ``"cpu"``.
    """
    wanted = (device or "cpu").lower()
    if wanted == "cuda":
        return "cuda" if torch.cuda.is_available() else "cpu"
    if wanted == "mps":
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
    return "cpu"
def _load_model(checkpoint_path: str, device: str) -> TaikoChartEstimator:
    """Return an eval-mode model for ``checkpoint_path`` on the resolved device.

    The requested device goes through ``_resolve_device`` first. Models are
    memoized in ``_MODEL_CACHE`` per (checkpoint, resolved device), so repeated
    calls reuse the already loaded instance.
    """
    resolved = _resolve_device(device)
    cache_key = f"{checkpoint_path}::{resolved}"
    if cache_key not in _MODEL_CACHE:
        fresh = TaikoChartEstimator.from_pretrained(checkpoint_path)
        fresh.eval()
        fresh.to(torch.device(resolved))
        _MODEL_CACHE[cache_key] = fresh
    return _MODEL_CACHE[cache_key]
def _build_instances_from_segments(
    segments: list[dict],
    max_tokens_per_instance: int,
    window_measures: list[int],
    hop_measures: int,
    max_instances_per_chart: int,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, list[tuple[float, float]], list[int]
]:
    """Tokenize a chart and cut it into padded MIL instances.

    For every size in ``window_measures`` the token stream is windowed with
    ``EventTokenizer.create_windows``. Each non-empty window becomes one
    instance padded to ``max_tokens_per_instance``. If there are more than
    ``max_instances_per_chart`` instances, an evenly spaced subset (indices
    from ``np.linspace``) is kept.

    Returns:
        instances: stacked instance tensors with a batch dim of 1.
        masks: stacked masks with a batch dim of 1.
        counts: ``tensor([N])`` with N the number of kept instances.
        times: (first token timestamp, last token timestamp) per instance.
        token_counts: ``mask.sum()`` of each instance before padding.

    Raises:
        ValueError: If no non-empty window was produced.
    """
    tokenizer = EventTokenizer()
    chart_tokens = tokenizer.tokenize_chart(segments)

    # One row per non-empty window: (tensor, mask, (t_start, t_end), real_tokens).
    rows: list[tuple[torch.Tensor, torch.Tensor, tuple[float, float], int]] = []
    for size in window_measures:
        for win in tokenizer.create_windows(
            chart_tokens, window_measures=size, hop_measures=hop_measures
        ):
            if not win:
                continue
            tensor, mask = tokenizer.tokens_to_tensor(
                win, max_length=max_tokens_per_instance
            )
            real_tokens = int(mask.sum().item())
            tensor, mask = tokenizer.pad_sequence(tensor, mask, max_tokens_per_instance)
            span = (float(win[0].timestamp), float(win[-1].timestamp))
            rows.append((tensor, mask, span, real_tokens))

    if not rows:
        raise ValueError("No note events parsed (empty chart or unsupported format)")

    if len(rows) > max_instances_per_chart:
        picked = np.linspace(
            0, len(rows) - 1, max_instances_per_chart, dtype=int
        ).tolist()
        rows = [rows[i] for i in picked]

    instances = torch.stack([row[0] for row in rows]).unsqueeze(0)  # [1, N, L, 6]
    masks = torch.stack([row[1] for row in rows]).unsqueeze(0)  # [1, N, L]
    counts = torch.tensor([len(rows)], dtype=torch.long)  # [1]
    times = [row[2] for row in rows]
    token_counts = [row[3] for row in rows]
    return instances, masks, counts, times, token_counts
def _plot_attention(
    times: list[tuple[float, float]],
    avg_attention: np.ndarray,
    topk_mask: Optional[np.ndarray],
    title: str,
):
    """Plot per-instance MIL attention against the window midpoint time.

    Instances are re-ordered by midpoint first, so the connecting line follows
    time even when windows of several sizes were generated interleaved.
    Instances flagged in ``topk_mask`` are redrawn as larger outlined markers.
    """
    starts = np.array([span[0] for span in times], dtype=np.float64)
    ends = np.array([span[1] for span in times], dtype=np.float64)
    centers = (starts + ends) / 2.0
    by_time = np.argsort(centers)
    x = centers[by_time]
    y = avg_attention[by_time]
    chosen = topk_mask[by_time] if topk_mask is not None else None

    fig, ax = plt.subplots(figsize=(10, 3.2))
    ax.scatter(x, y, s=14, alpha=0.8, label="Instance")
    ax.plot(x, y, linewidth=1.5, alpha=0.6)
    if chosen is not None:
        hit = chosen.astype(bool)
        ax.scatter(
            x[hit],
            y[hit],
            s=40,
            marker="o",
            edgecolors="black",
            linewidths=0.4,
            label="Top-k",
        )
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Avg attention (weight)")
    ax.set_title(title)
    ax.grid(True, alpha=0.25)
    ax.legend(loc="best")
    fig.tight_layout()
    return fig
def _plot_branch_heatmap(branch_attn: np.ndarray, title: str):
    """Draw per-branch attention weights as a heat map with a colour bar.

    ``branch_attn`` is laid out as [n_branches, n_instances]: rows are
    branches and columns are instances, labelled as time-sorted.
    """
    fig, ax = plt.subplots(figsize=(10, 3.2))
    heat = ax.imshow(branch_attn, aspect="auto", interpolation="nearest")
    ax.set(title=title, xlabel="Instance (time-sorted)", ylabel="Branch")
    colorbar = fig.colorbar(heat, ax=ax, fraction=0.03, pad=0.04)
    colorbar.set_label("Attention weight")
    fig.tight_layout()
    return fig
def _plot_density_and_attention(
    times: list[tuple[float, float]],
    token_counts: list[int],
    avg_attention: np.ndarray,
    topk_mask: Optional[np.ndarray],
    title: str,
):
    """Overlay per-window token density (left axis) and MIL attention (right axis).

    Density is each window's token count divided by its duration (clamped to
    at least 1e-6 s). Windows are drawn in midpoint order; ``topk_mask``
    entries, when given, are redrawn as outlined markers. Both axes share one
    merged legend.
    """
    starts = np.array([span[0] for span in times], dtype=np.float64)
    ends = np.array([span[1] for span in times], dtype=np.float64)
    centers = (starts + ends) / 2.0
    seconds = np.maximum(ends - starts, 1e-6)
    density = np.array(token_counts[: len(times)], dtype=np.float64) / seconds

    by_time = np.argsort(centers)
    x = centers[by_time]
    dens = density[by_time]
    attn = avg_attention[by_time]
    chosen = topk_mask[by_time] if topk_mask is not None else None

    fig, left_ax = plt.subplots(figsize=(10, 3.2))
    left_ax.plot(x, dens, linewidth=1.8, color="tab:blue", label="Token density")
    left_ax.set_xlabel("Time (s)")
    left_ax.set_ylabel("Tokens / sec", color="tab:blue")
    left_ax.tick_params(axis="y", labelcolor="tab:blue")
    left_ax.grid(True, alpha=0.25)

    right_ax = left_ax.twinx()
    right_ax.scatter(
        x, attn, s=14, color="tab:orange", alpha=0.75, label="Avg attention"
    )
    if chosen is not None:
        hit = chosen.astype(bool)
        right_ax.scatter(
            x[hit],
            attn[hit],
            s=40,
            marker="o",
            edgecolors="black",
            linewidths=0.4,
            color="tab:orange",
            label="Top-k attention",
        )
    right_ax.set_ylabel("Avg attention", color="tab:orange")
    right_ax.tick_params(axis="y", labelcolor="tab:orange")
    left_ax.set_title(title)

    # One legend covering both axes.
    handles: list = []
    labels: list = []
    for axis in (left_ax, right_ax):
        axis_handles, axis_labels = axis.get_legend_handles_labels()
        handles += axis_handles
        labels += axis_labels
    left_ax.legend(handles, labels, loc="best")
    fig.tight_layout()
    return fig
def _ema_smooth(values: np.ndarray, alpha: float) -> np.ndarray:
    """Return the exponential moving average of a 1-D series.

    ``out[0] = values[0]`` and ``out[i] = alpha * values[i] + (1 - alpha) * out[i - 1]``.
    The result has the dtype of ``values``; an empty input is returned as is.
    """
    if len(values) == 0:
        return values
    smoothed = np.zeros_like(values)
    smoothed[0] = values[0]
    for i in range(1, len(values)):
        smoothed[i] = alpha * values[i] + (1 - alpha) * smoothed[i - 1]
    return smoothed


def _plot_local_difficulty(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    token_counts: list[int],
    title: str,
    ema_alpha: float = 0.3,
):
    """Plot estimated local difficulty (star rating) over time.

    Windows are ordered by midpoint. The raw per-window stars are drawn
    faintly, and their exponential moving average is drawn bold with a light
    fill underneath. Per-window token density (token count / window duration)
    is overlaid dashed on a secondary axis for context.

    Args:
        times: (start, end) seconds per window.
        local_stars: Per-window star estimate, aligned with ``times``.
        token_counts: Per-window token counts (truncated to ``len(times)``).
        title: Figure title.
        ema_alpha: EMA smoothing factor in (0, 1]. Larger values follow the raw
            curve more closely. Under the span convention alpha = 2 / (span + 1),
            the default 0.3 corresponds to a span of roughly 5-6 windows.

    Returns:
        The matplotlib Figure.
    """
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0
    durations = np.maximum(t1 - t0, 1e-6)
    token_counts_np = np.array(token_counts[: len(times)], dtype=np.float64)
    density = token_counts_np / durations

    order = np.argsort(mids)
    mids_s = mids[order]
    stars_s = local_stars[order]
    dens_s = density[order]

    stars_smooth = _ema_smooth(stars_s, ema_alpha)

    fig, ax1 = plt.subplots(figsize=(10, 3.5))
    color = "tab:red"
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Estimated Local Stars", color=color)
    # Raw curve faint, smoothed curve bold.
    ax1.plot(mids_s, stars_s, color=color, linewidth=1, alpha=0.3, label="Raw")
    ax1.plot(mids_s, stars_smooth, color=color, linewidth=2.5, label="Smoothed (EMA)")
    ax1.tick_params(axis="y", labelcolor=color)
    ax1.grid(True, alpha=0.25)
    ax1.fill_between(mids_s, stars_smooth, alpha=0.1, color=color)

    # Density on a secondary axis. NOTE: this is token density (tokens per
    # second) computed from token_counts, labelled as notes/s.
    ax2 = ax1.twinx()
    color2 = "tab:blue"
    ax2.set_ylabel("Density (notes/s)", color=color2)
    ax2.plot(
        mids_s,
        dens_s,
        color=color2,
        linewidth=1,
        linestyle="--",
        alpha=0.5,
        label="Note Density",
    )
    ax2.tick_params(axis="y", labelcolor=color2)
    ax1.set_title(title)

    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
    fig.tight_layout()
    return fig
def _smooth_embeddings(embeddings: np.ndarray, window_size: int = 3) -> np.ndarray:
    """Apply a moving average along axis 0 (time) with edge-value padding.

    Each row becomes the mean of ``window_size`` consecutive rows of the input
    padded by repeating its first/last row ``window_size // 2`` times. For an
    even ``window_size`` the window covers ``window_size // 2`` earlier rows,
    the row itself, and one fewer later row. Works for any number of trailing
    dimensions (including 1-D input).

    Floating-point inputs keep their dtype; other dtypes are promoted to
    float64 so the averages are not truncated. Inputs shorter than the window,
    or ``window_size <= 1``, are returned unchanged.
    """
    if window_size <= 1 or len(embeddings) < window_size:
        return embeddings
    pad = window_size // 2
    pad_spec = [(pad, pad)] + [(0, 0)] * (embeddings.ndim - 1)
    padded = np.pad(embeddings, pad_spec, mode="edge")
    # Windows are appended as the last axis: [n + 2*pad - w + 1, ..., w].
    windows = np.lib.stride_tricks.sliding_window_view(padded, window_size, axis=0)
    # Even windows yield one extra row; keep the first len(embeddings).
    smoothed = windows.mean(axis=-1)[: len(embeddings)]
    if np.issubdtype(embeddings.dtype, np.floating):
        return smoothed.astype(embeddings.dtype, copy=False)
    return smoothed.astype(np.float64, copy=False)
def _smooth_labels(labels: np.ndarray, window_size: int = 3) -> np.ndarray:
    """Apply a sliding mode filter to labels to enforce temporal continuity.

    Each position takes the most frequent label among the original labels at
    offsets ``-window_size // 2 .. +window_size // 2`` (clipped at the ends).
    On a tie the position keeps its own label, so ambiguous windows are not
    pushed toward the smallest label id. Labels must be non-negative ints
    (they are counted with ``np.bincount``). Inputs shorter than
    ``window_size`` are returned unchanged.
    """
    if len(labels) < window_size:
        return labels
    n = len(labels)
    half = window_size // 2
    smoothed = labels.copy()
    for i in range(n):
        # Always read from the unsmoothed labels so the filter is not cumulative.
        window = labels[max(0, i - half) : min(n, i + half + 1)]
        counts = np.bincount(window)
        if counts[labels[i]] < counts.max():
            smoothed[i] = np.argmax(counts)
    return smoothed
def _perform_clustering(
    embeddings: np.ndarray,
    min_k: int = 3,
    max_k: int = 8,
    smoothing_window: int = 3,
    label_smoothing_window: int = 3,
    random_state: int = 42,
) -> tuple[np.ndarray, int, dict]:
    """Cluster per-window embeddings with K-Means, choosing K by silhouette score.

    Steps:
      1. If ``smoothing_window > 1``, moving-average the embeddings over time
         (``_smooth_embeddings``) to damp window-to-window jitter.
      2. Fit K-Means for every K from ``min_k`` to ``min(max_k, N - 1)`` (at
         least ``min_k``) and keep the fit with the highest silhouette score.
         If no K gives a valid score, fall back to K = ``min_k``.
      3. If ``label_smoothing_window > 1``, mode-filter the chosen labels over
         time (``_smooth_labels``).

    Args:
        embeddings: [N, D] data points, in time order.
        min_k: Minimum number of clusters.
        max_k: Maximum number of clusters.
        smoothing_window: Moving-average window for the embeddings (<= 1 disables).
        label_smoothing_window: Mode-filter window for the labels (<= 1 disables).
        random_state: Seed passed to K-Means.

    Returns:
        labels: [N] cluster labels.
        best_k: Selected number of clusters (1 when N < min_k).
        stats: ``{"silhouette": s, "score": s}``, where s is the best silhouette
            score (-1.0 if none was valid, 0.0 when N < min_k). Both keys are
            always present.
    """
    n_samples = embeddings.shape[0]
    if n_samples < min_k:
        # Too few windows to form min_k clusters: everything is cluster 0.
        return np.zeros(n_samples, dtype=int), 1, {"silhouette": 0.0, "score": 0.0}

    if smoothing_window > 1:
        work = _smooth_embeddings(embeddings, window_size=smoothing_window)
    else:
        work = embeddings

    best_score = -1.0
    best_k = min_k
    best_model = None
    print(f"Clustering {n_samples} instances...")
    # Silhouette needs 2 <= K <= N - 1, but always try at least min_k.
    upper_k = max(min(max_k, n_samples - 1), min_k)
    for k in range(min_k, upper_k + 1):
        kmeans = KMeans(n_clusters=k, random_state=random_state, n_init=10)
        labels = kmeans.fit_predict(work)
        try:
            score = silhouette_score(work, labels)
        except ValueError:
            # Silhouette is undefined for this labelling (e.g. fewer than 2
            # distinct clusters, or one cluster per sample).
            continue
        if score > best_score:
            best_score, best_k, best_model = score, k, kmeans

    if best_model is None:
        best_model = KMeans(n_clusters=min_k, random_state=random_state, n_init=10)
        best_model.fit(work)
        best_k = min_k

    labels = best_model.labels_
    if label_smoothing_window > 1:
        labels = _smooth_labels(labels, window_size=label_smoothing_window)
    return labels, best_k, {"silhouette": best_score, "score": best_score}
def _analyze_clusters(
    cluster_labels: np.ndarray,
    local_stars: np.ndarray,
    note_density: np.ndarray,
    avg_attention: Optional[np.ndarray] = None,
) -> list[dict]:
    """Summarize each cluster for display.

    Returns:
        One dict per distinct label, ordered by ``"Cluster ID"`` (which keeps
        table rows in the same order as the plot legend), with keys:
        ``"Cluster ID"``, ``"Count"``, ``"Avg Stars"`` (2 dp),
        ``"Avg Density"`` (2 dp) and ``"Avg Attention"`` (4 dp; 0.0 when
        ``avg_attention`` is None).
    """

    def rounded(value, digits: int) -> float:
        return float(f"{value:.{digits}f}")

    def summarize(label) -> dict:
        members = cluster_labels == label
        member_count = members.sum()
        stars = local_stars[members].mean() if len(local_stars) > 0 else 0
        density = note_density[members].mean() if len(note_density) > 0 else 0
        attention = avg_attention[members].mean() if avg_attention is not None else 0
        return {
            "Cluster ID": int(label),
            "Count": int(member_count),
            "Avg Stars": rounded(stars, 2),
            "Avg Density": rounded(density, 2),
            "Avg Attention": rounded(attention, 4),
        }

    summaries = [summarize(label) for label in np.unique(cluster_labels)]
    return sorted(summaries, key=lambda row: row["Cluster ID"])
def _plot_clusters(
    times: list[tuple[float, float]],
    cluster_labels: np.ndarray,
    local_stars: np.ndarray,
    title: str,
):
    """Scatter local stars over time, one colour per cluster label.

    Points are ordered by window midpoint. Colours come from ``tab10`` (or
    ``tab20`` when there are more than ten clusters), assigned in ascending
    label order. A faint grey line joins all points in time order, and the
    legend sits outside the axes on the right.
    """
    starts = np.array([span[0] for span in times], dtype=np.float64)
    ends = np.array([span[1] for span in times], dtype=np.float64)
    centers = (starts + ends) / 2.0
    by_time = np.argsort(centers)
    x = centers[by_time]
    y = local_stars[by_time]
    lab = cluster_labels[by_time]

    distinct = np.unique(lab)
    palette = plt.get_cmap("tab20" if len(distinct) > 10 else "tab10")

    fig, ax = plt.subplots(figsize=(10, 3.5))
    for color_idx, cluster in enumerate(distinct):
        members = lab == cluster
        ax.scatter(
            x[members],
            y[members],
            color=palette(color_idx),
            label=f"Cluster {cluster}",
            s=20,
            alpha=0.8,
        )
    # Faint line to show temporal connectivity across clusters.
    ax.plot(x, y, color="gray", alpha=0.2, linewidth=1)

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Local Stars")
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    return fig
def _detect_segments(
    local_stars: np.ndarray,
    times: list[tuple[float, float]],
    min_segment_size: int = 3,
    penalty_scale: float = 1.0,
) -> list[dict]:
    """Split the per-window score curve into piecewise-constant segments.

    Windows may not arrive in temporal order (e.g. several window sizes are
    concatenated), so they are sorted by midpoint first. Change points come
    from ruptures' PELT with an ``l2`` cost, ``min_size=min_segment_size`` and
    penalty ``var(scores) * penalty_scale``. Each sorted window owns the time
    between the midpoints to its neighbours; the outer cells extend to the
    earliest window start and the latest window end.

    Args:
        local_stars: Per-window scores (1-D), aligned with ``times``.
        times: (start, end) seconds per window.
        min_segment_size: Minimum number of windows per segment.
        penalty_scale: Multiplier on the signal variance used as PELT penalty.

    Returns:
        One dict per segment with ``start_time``, ``end_time``, ``avg_stars``
        (mean input value over the segment) and ``n_windows``. An empty input
        gives ``[]``. Fewer than ``2 * min_segment_size`` windows, or a
        constant curve, give a single segment covering everything.
    """
    n = len(local_stars)
    if n == 0:
        return []

    starts = np.array([t0 for t0, _ in times], dtype=np.float64)
    ends = np.array([t1 for _, t1 in times], dtype=np.float64)
    # The first/last entries of `times` are not necessarily the earliest start
    # or latest end once window sizes are mixed or subsampled.
    song_start = float(starts.min())
    song_end = float(ends.max())

    def whole_chart() -> list[dict]:
        return [
            {
                "start_time": song_start,
                "end_time": song_end,
                "avg_stars": float(local_stars.mean()),
                "n_windows": n,
            }
        ]

    if n < min_segment_size * 2:
        return whole_chart()

    # Sort by midpoint time (critical for temporal coherence).
    mids = (starts + ends) / 2.0
    order = np.argsort(mids)
    mids_sorted = mids[order]
    stars_sorted = local_stars[order]

    variance = float(np.var(stars_sorted))
    if variance == 0.0:
        # Constant curve: there is nothing to split.
        return whole_chart()

    # Cell boundaries: 1-D Voronoi over the sorted midpoints.
    inner = ((mids_sorted[:-1] + mids_sorted[1:]) / 2.0).tolist()
    cell_bounds = [song_start, *inner, song_end]

    signal = stars_sorted.reshape(-1, 1)
    algo = rpt.Pelt(model="l2", min_size=min_segment_size).fit(signal)
    change_points = algo.predict(pen=variance * penalty_scale)

    segments = []
    prev_idx = 0
    for cp in change_points:
        segments.append(
            {
                "start_time": float(cell_bounds[prev_idx]),
                "end_time": float(cell_bounds[cp]),
                "avg_stars": float(stars_sorted[prev_idx:cp].mean()),
                "n_windows": cp - prev_idx,
            }
        )
        prev_idx = cp
    return segments
def _plot_segments(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    segments: list[dict],
    title: str,
):
    """Draw detected segments as coloured bands behind the raw per-window curve.

    Each segment gets a translucent span plus a thick horizontal line at its
    ``avg_stars``. Colour runs from green (lowest segment average) to red
    (highest) on ``RdYlGn_r``. Segments longer than 4 s are annotated with
    their value, and dashed vertical lines mark each boundary after the first
    segment. The y-axis starts at 0 and ends 2 above the highest segment
    average (or at 12 when there are no segments).
    """
    starts = np.array([span[0] for span in times], dtype=np.float64)
    ends = np.array([span[1] for span in times], dtype=np.float64)
    centers = (starts + ends) / 2.0
    by_time = np.argsort(centers)
    x = centers[by_time]
    y = local_stars[by_time]

    palette = plt.get_cmap("RdYlGn_r")
    fig, ax = plt.subplots(figsize=(12, 4))

    levels = [seg["avg_stars"] for seg in segments]
    hi = max(levels) if segments else 10
    lo = min(levels) if segments else 0
    spread = max(hi - lo, 1)

    for seg in segments:
        left, right, level = seg["start_time"], seg["end_time"], seg["avg_stars"]
        shade = palette((level - lo) / spread)
        ax.axvspan(left, right, alpha=0.3, color=shade, linewidth=0)
        ax.hlines(
            y=level,
            xmin=left,
            xmax=right,
            colors=shade,
            linewidth=3,
            alpha=0.9,
        )
        if right - left > 4:  # only label segments wider than 4 seconds
            ax.text(
                (left + right) / 2,
                level + 0.02,
                f"{level:.1f}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                color="black",
                alpha=0.8,
            )

    # Raw per-window values on top of the bands.
    ax.plot(x, y, color="gray", alpha=0.4, linewidth=1)
    for seg in segments[1:]:
        ax.axvline(
            x=seg["start_time"], color="black", linewidth=1, linestyle="--", alpha=0.5
        )

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Raw Score")
    ax.set_title(title)
    ax.set_ylim(bottom=0, top=hi + 2)
    ax.grid(True, alpha=0.15, axis="y")
    fig.tight_layout()
    return fig
def _plot_attention_concentration(
    avg_attention: np.ndarray,
    title: str,
):
    """Plot cumulative attention mass against the number of top instances.

    Negative weights are clipped to 0, and the rest are renormalized to sum to
    1 (skipped when the total is 0). The weights are then sorted from largest
    to smallest and accumulated, showing how few instances carry most of the
    attention.
    """
    weights = np.clip(avg_attention.astype(np.float64), 0.0, None)
    total = weights.sum()
    if total > 0:
        weights = weights / total
    ranked = np.sort(weights)[::-1]
    mass = np.cumsum(ranked)
    top_k = np.arange(1, len(ranked) + 1)

    fig, ax = plt.subplots(figsize=(10, 3.2))
    ax.plot(top_k, mass, linewidth=2)
    ax.set_xlabel("Top-k instances (sorted by attention)")
    ax.set_ylabel("Cumulative attention mass")
    ax.set_ylim(0, 1.02)
    ax.set_title(title)
    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    return fig
| def run_inference( | |
| tja_file, | |
| tja_text: str, | |
| course_name: str, | |
| checkpoint_path: str, | |
| device: str, | |
| window_measures_text: str, | |
| hop_measures: int, | |
| max_instances: int, | |
| ): | |
| if tja_file: | |
| with open(tja_file, "r", encoding="utf-8", errors="ignore") as f: | |
| tja_text = f.read() | |
| parsed = parse_tja(tja_text) | |
| if not parsed.courses: | |
| raise gr.Error("No COURSE found and no chart parsed.") | |
| if course_name not in parsed.courses: | |
| # Fallback to first | |
| course_name = next(iter(parsed.courses.keys())) | |
| course = parsed.courses[course_name] | |
| try: | |
| window_measures = [ | |
| int(x.strip()) for x in window_measures_text.split(",") if x.strip() | |
| ] | |
| except ValueError: | |
| raise gr.Error( | |
| "window_measures must be a comma-separated list of integers, e.g. 2,4" | |
| ) | |
| if not window_measures: | |
| window_measures = [2, 4] | |
| device = _resolve_device(device) | |
| model = _load_model(checkpoint_path, device=device) | |
| max_tokens = int(getattr(model.config, "max_seq_len", 128)) | |
| instances, masks, counts, times, token_counts = _build_instances_from_segments( | |
| course.segments, | |
| max_tokens_per_instance=max_tokens, | |
| window_measures=window_measures, | |
| hop_measures=int(hop_measures), | |
| max_instances_per_chart=int(max_instances), | |
| ) | |
| instances = instances.to(torch.device(device)) | |
| masks = masks.to(torch.device(device)) | |
| counts = counts.to(torch.device(device)) | |
| difficulty_hint = None | |
| if course.difficulty_hint is not None: | |
| mapping = {"easy": 0, "normal": 1, "hard": 2, "oni": 3, "ura": 4} | |
| difficulty_hint = torch.tensor( | |
| [mapping[course.difficulty_hint]], device=torch.device(device) | |
| ) | |
| with torch.no_grad(): | |
| out = model.forward( | |
| instances, | |
| masks, | |
| counts, | |
| difficulty_hint=difficulty_hint, | |
| return_attention=True, | |
| ) | |
| # Scalars | |
| difficulty_names = ["easy", "normal", "hard", "oni", "ura"] | |
| pred_class_id = int(out.difficulty_logits.argmax(dim=-1).item()) | |
| pred_class = difficulty_names[pred_class_id] | |
| raw_score = float(out.raw_score.item()) | |
| raw_star = float(out.raw_star.item()) | |
| display_star = float(out.display_star.item()) | |
| # Attention details | |
| attn = out.attention_info | |
| avg_attn = attn.get("average_attention") | |
| branch_attn = attn.get("branch_attentions") | |
| topk_mask = attn.get("topk_mask") | |
| # Local Difficulty Estimation (Probe) | |
| # Use the predicted class ID if no hint was provided | |
| calib_diff_id = difficulty_hint | |
| if calib_diff_id is None: | |
| calib_diff_id = out.difficulty_logits.argmax(dim=-1, keepdim=True) # [1, 1] | |
| local_raw, local_stars = model.get_instance_scores( | |
| out.instance_embeddings, difficulty_class_id=calib_diff_id.view(-1) | |
| ) | |
| avg_attn_np = ( | |
| avg_attn[0, : counts.item()].detach().cpu().numpy() | |
| if avg_attn is not None | |
| else None | |
| ) | |
| topk_np = ( | |
| topk_mask[0, : counts.item()].detach().cpu().numpy() | |
| if topk_mask is not None | |
| else None | |
| ) | |
| branch_np = ( | |
| branch_attn[0, :, : counts.item()].detach().cpu().numpy() | |
| if branch_attn is not None | |
| else None | |
| ) | |
| local_stars_np = local_stars[0, : counts.item()].detach().cpu().numpy() | |
| local_raw_np = local_raw[0, : counts.item()].detach().cpu().numpy() | |
| # Plots | |
| fig_attn = None | |
| fig_heat = None | |
| fig_density = None | |
| fig_conc = None | |
| if avg_attn_np is not None: | |
| fig_attn = _plot_attention( | |
| times, avg_attn_np, topk_np, title="MIL average attention over time" | |
| ) | |
| if avg_attn_np is not None: | |
| fig_density = _plot_density_and_attention( | |
| times, | |
| token_counts, | |
| avg_attn_np, | |
| topk_np, | |
| title="Token density vs attention (time-sorted)", | |
| ) | |
| fig_conc = _plot_attention_concentration( | |
| avg_attn_np, | |
| title="Attention concentration (how many windows dominate)", | |
| ) | |
| fig_local_diff = None | |
| if local_stars_np is not None: | |
| fig_local_diff = _plot_local_difficulty( | |
| times, | |
| local_stars_np, | |
| token_counts, | |
| title=f"Estimated Local Difficulty Curve (Assuming {pred_class} calibration)", | |
| ) | |
| # Segment Detection (Piecewise Constant Change Point Detection) | |
| fig_segments = None | |
| segment_table_df = None | |
| if local_raw_np is not None and len(times) > 0: | |
| segments = _detect_segments( | |
| local_raw_np, # Use raw score instead of stars | |
| times, | |
| min_segment_size=3, | |
| penalty_scale=0.5, | |
| ) | |
| # Create table rows | |
| seg_rows = [] | |
| for i, seg in enumerate(segments): | |
| seg_rows.append( | |
| [ | |
| i + 1, | |
| f"{seg['start_time']:.1f}", | |
| f"{seg['end_time']:.1f}", | |
| f"{seg['end_time'] - seg['start_time']:.1f}", | |
| f"{seg['avg_stars']:.1f}", # This is now avg_raw | |
| seg["n_windows"], | |
| ] | |
| ) | |
| segment_table_df = seg_rows | |
| fig_segments = _plot_segments( | |
| times, | |
| local_raw_np, # Use raw score | |
| segments, | |
| title=f"Chart Structure: {len(segments)} Segments Detected", | |
| ) | |
| # Meta/details | |
| if branch_np is not None: | |
| mids = np.array([(a + b) / 2.0 for a, b in times], dtype=np.float64) | |
| order = np.argsort(mids) | |
| branch_sorted = branch_np[:, order] | |
| fig_heat = _plot_branch_heatmap( | |
| branch_sorted, title="MIL attention (branches x instances)" | |
| ) | |
| # Add a few time tick labels | |
| ax = fig_heat.axes[0] | |
| if len(order) > 1: | |
| n_ticks = 6 | |
| tick_pos = np.linspace(0, len(order) - 1, n_ticks, dtype=int) | |
| tick_labels = [f"{mids[order[p]]:.0f}s" for p in tick_pos] | |
| ax.set_xticks(tick_pos) | |
| ax.set_xticklabels(tick_labels) | |
| # Table | |
| rows = [] | |
| for i, (t0, t1) in enumerate(times): | |
| rows.append( | |
| [ | |
| i, | |
| float(t0), | |
| float(t1), | |
| float((t0 + t1) / 2.0), | |
| int(token_counts[i]) if i < len(token_counts) else None, | |
| float(avg_attn_np[i]) if avg_attn_np is not None else None, | |
| int(topk_np[i]) if topk_np is not None else None, | |
| float(local_stars_np[i]) if i < len(local_stars_np) else None, | |
| ] | |
| ) | |
| # More intuitive summary: show top attention windows | |
| top_md = "" | |
| if avg_attn_np is not None: | |
| t0 = np.array([a for a, _ in times], dtype=np.float64) | |
| t1 = np.array([b for _, b in times], dtype=np.float64) | |
| mids = (t0 + t1) / 2.0 | |
| durations = np.maximum(t1 - t0, 1e-6) | |
| token_counts_np = np.array(token_counts[: len(times)], dtype=np.float64) | |
| density = token_counts_np / durations | |
| top_n = min(8, len(avg_attn_np)) | |
| top_idx = np.argsort(avg_attn_np)[::-1][:top_n] | |
| lines = ["### Top segments (by attention)"] | |
| for rank, idx in enumerate(top_idx, start=1): | |
| is_topk = int(topk_np[idx]) if topk_np is not None else 0 | |
| lines.append( | |
| f"{rank}. `[{t0[idx]:.1f}s - {t1[idx]:.1f}s]` " | |
| f"attn={avg_attn_np[idx]:.4f}, dens={density[idx]:.1f} tok/s, topk={is_topk}" | |
| ) | |
| top_md = "\n".join(lines) | |
| # Meta/details | |
| meta_out = { | |
| "TITLE": parsed.meta.get("TITLE"), | |
| "BPM": parsed.meta.get("BPM"), | |
| "OFFSET": parsed.meta.get("OFFSET"), | |
| "COURSE": course.name, | |
| "LEVEL": course.level, | |
| "difficulty_hint": course.difficulty_hint, | |
| "n_instances": int(counts.item()), | |
| "max_tokens_per_instance": int(max_tokens), | |
| "window_measures": window_measures, | |
| "hop_measures": int(hop_measures), | |
| "attention_entropy": ( | |
| float(attn.get("entropy")[0].item()) | |
| if attn.get("entropy") is not None | |
| else None | |
| ), | |
| "attention_effective_n": ( | |
| float(attn.get("effective_n")[0].item()) | |
| if attn.get("effective_n") is not None | |
| else None | |
| ), | |
| "attention_top5_mass": ( | |
| float(attn.get("top5_mass")[0].item()) | |
| if attn.get("top5_mass") is not None | |
| else None | |
| ), | |
| } | |
| summary_md = ( | |
| f"### Prediction\n" | |
| f"- predicted difficulty: `{pred_class}`\n" | |
| f"- raw_score: `{raw_score:.4f}`\n" | |
| f"- raw_star: `{raw_star:.4f}`\n" | |
| f"- display_star: `{display_star:.4f}`\n" | |
| ) | |
| return ( | |
| summary_md, | |
| meta_out, | |
| fig_attn, | |
| fig_density, | |
| fig_heat, | |
| fig_conc, | |
| top_md, | |
| rows, | |
| fig_local_diff, | |
| fig_segments, | |
| segment_table_df, | |
| ) | |
def _update_course_dropdown(tja_file, tja_text: str):
    """Rebuild the COURSE dropdown from whichever TJA input is active.

    An uploaded file takes precedence over pasted text. This is a best-effort
    UI callback: any failure -- the upload cannot be opened/read, or
    ``parse_tja`` rejects the (possibly half-typed) text -- yields an empty
    dropdown rather than an error, so the user can keep editing.

    Args:
        tja_file: Value of the ``gr.File`` component. Either a path
            (``str`` / ``os.PathLike``) or, presumably for older Gradio
            versions, a file wrapper exposing the path via ``.name`` --
            confirm against the Gradio version in use. Falsy when nothing
            is uploaded.
        tja_text: Contents of the "Paste" textbox.

    Returns:
        A ``gr.Dropdown`` whose choices are the course names of the parsed
        chart, with the first course selected (``None`` if there are none).
    """
    try:
        if tja_file:
            if isinstance(tja_file, (str, os.PathLike)):
                path = tja_file
            else:
                # Non-path values: fall back to the wrapper's ``.name``.
                path = getattr(tja_file, "name", tja_file)
            # Reading happens inside the try so an unreadable upload degrades
            # the same way a parse failure does. Undecodable bytes are dropped
            # (errors="ignore") rather than raising.
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                tja_text = f.read()
        parsed = parse_tja(tja_text)
    except Exception:
        # Deliberately broad: this only drives a selector, so any read/parse
        # problem results in an empty selector instead of a UI error.
        return gr.Dropdown(choices=[], value=None)
    choices = list(parsed.courses.keys())
    return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
def build_app() -> gr.Blocks:
    """Assemble (but do not launch) the Gradio UI for chart inference.

    Layout, top to bottom:
      * an input column (Upload/Paste tabs, COURSE selector, run button)
        next to an options column (checkpoint, device, advanced sliders);
      * a summary row (prediction markdown, top-attention markdown,
        metadata JSON);
      * a tab strip of result plots and tables.

    Changing either input re-populates COURSE via
    ``_update_course_dropdown``. The run button calls ``run_inference`` and
    routes its return values, in order, to ``inference_outputs``.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    checkpoint_choices = _discover_checkpoints()
    # Preselect the last discovered checkpoint, when there is one.
    default_checkpoint = checkpoint_choices[-1] if checkpoint_choices else None

    segment_headers = ["#", "Start (s)", "End (s)", "Duration", "Avg Raw", "Windows"]
    segment_types = ["number", "str", "str", "str", "str", "number"]
    # One column per field of the per-window rows returned by run_inference.
    raw_headers = [
        "id",
        "start",
        "end",
        "mid",
        "tokens",
        "attention",
        "is_topk",
        "local_stars",
    ]
    # (tab title, plot label) for every tab that holds a single plot, in the
    # order the tabs appear between "Chart Structure" and "Raw Data".
    single_plot_tabs = (
        ("Local Difficulty", "Local Difficulty Curve"),
        ("Attention & Density", "Density vs Attention"),
        ("Attention Details", "Raw Attention"),
        ("Heatmap", "Branch Heatmap"),
        ("Concentration", "Concentration"),
    )

    with gr.Blocks(title="TaikoChartEstimator Inference") as demo:
        gr.Markdown("# TaikoChartEstimator - Inference")

        with gr.Row():
            # Input side: upload or paste a chart, pick a course, run.
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Upload"):
                        tja_file = gr.File(label="Upload TJA file")
                    with gr.TabItem("Paste"):
                        tja_text = gr.Textbox(label="Paste TJA content", lines=12)
                course = gr.Dropdown(label="COURSE", choices=[], value=None)
                btn = gr.Button("Run Inference", variant="primary", size="lg")

            # Options side.
            with gr.Column(scale=1):
                gr.Markdown("### Options")
                checkpoint = gr.Dropdown(
                    label="Checkpoint",
                    choices=checkpoint_choices,
                    value=default_checkpoint,
                    allow_custom_value=True,
                )
                device = gr.Dropdown(
                    label="Device", choices=["cpu", "mps", "cuda"], value="cpu"
                )
                with gr.Accordion("Advanced", open=False):
                    window_measures = gr.Textbox(
                        label="window_measures (comma-separated)", value="2,4"
                    )
                    hop_measures = gr.Slider(
                        label="hop_measures", minimum=1, maximum=8, value=2, step=1
                    )
                    max_instances = gr.Slider(
                        label="max_instances",
                        minimum=1,
                        maximum=512,
                        value=128,
                        step=1,
                    )

        with gr.Row():
            with gr.Column(scale=1):
                summary = gr.Markdown()
                top_segments = gr.Markdown()
            with gr.Column(scale=1):
                meta_json = gr.JSON(label="Metadata")

        with gr.Tabs():
            with gr.TabItem("Chart Structure"):
                gr.Markdown("### Automatic Segment Detection")
                gr.Markdown(
                    "Detects distinct sections based on difficulty changes (Piecewise Constant Model)."
                )
                plot_segments = gr.Plot(label="Detected Segments")
                segment_table = gr.Dataframe(
                    headers=segment_headers,
                    datatype=segment_types,
                    label="Segment Details",
                )
            plots = {}
            for tab_title, plot_label in single_plot_tabs:
                with gr.TabItem(tab_title):
                    plots[tab_title] = gr.Plot(label=plot_label)
            with gr.TabItem("Raw Data"):
                df = gr.Dataframe(
                    headers=raw_headers,
                    datatype=["number"] * len(raw_headers),
                )

        # Re-parse and refresh COURSE choices whenever either input changes.
        for source in (tja_file, tja_text):
            source.change(
                _update_course_dropdown,
                inputs=[tja_file, tja_text],
                outputs=[course],
            )

        inference_inputs = [
            tja_file,
            tja_text,
            course,
            checkpoint,
            device,
            window_measures,
            hop_measures,
            max_instances,
        ]
        # Order must match the tuple returned by run_inference.
        inference_outputs = [
            summary,
            meta_json,
            plots["Attention Details"],
            plots["Attention & Density"],
            plots["Heatmap"],
            plots["Concentration"],
            top_segments,
            df,
            plots["Local Difficulty"],
            plot_segments,
            segment_table,
        ]
        btn.click(run_inference, inputs=inference_inputs, outputs=inference_outputs)
    return demo
if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start serving it.
    build_app().launch()