JacobLinCool commited on
Commit
e964427
·
verified ·
1 Parent(s): dce0030

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +873 -0
  2. requirements.txt +94 -0
app.py ADDED
@@ -0,0 +1,873 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from dataclasses import dataclass
4
+ from typing import Any, Optional
5
+
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ from TaikoChartEstimator.data.tokenizer import EventTokenizer
12
+ from TaikoChartEstimator.model.model import TaikoChartEstimator
13
+
14
+
15
+ @dataclass
16
+ class ParsedCourse:
17
+ name: str
18
+ level: Optional[int]
19
+ segments: list[dict]
20
+ difficulty_hint: Optional[str]
21
+
22
+
23
+ @dataclass
24
+ class ParsedTJA:
25
+ meta: dict[str, Any]
26
+ courses: dict[str, ParsedCourse]
27
+
28
+
29
+ NOTE_DIGIT_TO_TYPE = {
30
+ "1": "Don",
31
+ "2": "Ka",
32
+ "3": "DonBig",
33
+ "4": "KaBig",
34
+ "5": "Roll",
35
+ "6": "RollBig",
36
+ "7": "Balloon",
37
+ "8": "EndOf",
38
+ "9": "BalloonAlt",
39
+ }
40
+
41
+
42
+ def _strip_comment(line: str) -> str:
43
+ if "//" in line:
44
+ line = line.split("//", 1)[0]
45
+ return line.strip()
46
+
47
+
48
+ def parse_tja(text: str) -> ParsedTJA:
49
+ """Parse a (single-song) TJA into dataset-like `segments` per course.
50
+
51
+ Supported (best-effort): COURSE/LEVEL, BPM, OFFSET, #START/#END,
52
+ #BPMCHANGE, #MEASURE, #SCROLL, #DELAY, #GOGOSTART/#GOGOEND.
53
+
54
+ Branching commands are ignored.
55
+ """
56
+
57
+ if not text or not text.strip():
58
+ raise ValueError("Empty TJA input")
59
+
60
+ text = text.replace("\ufeff", "")
61
+ lines = [_strip_comment(l) for l in text.replace("\r\n", "\n").split("\n")]
62
+ lines = [l for l in lines if l]
63
+
64
+ meta: dict[str, Any] = {}
65
+ courses: dict[str, dict[str, Any]] = {}
66
+
67
+ current_course: Optional[dict[str, Any]] = None
68
+ in_chart = False
69
+
70
+ bpm = 120.0
71
+ offset = 0.0
72
+ measure_num = 4
73
+ measure_den = 4
74
+ scroll = 1.0
75
+ gogo = False
76
+
77
+ current_time = 0.0
78
+ measure_start_time = 0.0
79
+ measure_digits: list[str] = []
80
+
81
+ def beats_per_measure() -> float:
82
+ # TJA: #MEASURE a/b means measure length = 4 * a / b quarter-note beats
83
+ return 4.0 * float(measure_num) / float(measure_den)
84
+
85
+ def measure_duration_sec(local_bpm: float) -> float:
86
+ return beats_per_measure() * 60.0 / max(local_bpm, 1e-6)
87
+
88
+ def flush_measure_if_any() -> None:
89
+ nonlocal current_time, measure_start_time, measure_digits
90
+ if current_course is None:
91
+ return
92
+ digits = "".join(measure_digits).strip()
93
+ if not digits:
94
+ return
95
+
96
+ dur = measure_duration_sec(bpm)
97
+ step = dur / max(len(digits), 1)
98
+ notes: list[dict] = []
99
+ for i, ch in enumerate(digits):
100
+ if ch == "0":
101
+ continue
102
+ note_type = NOTE_DIGIT_TO_TYPE.get(ch)
103
+ if not note_type:
104
+ continue
105
+ t = measure_start_time + i * step
106
+ notes.append(
107
+ {
108
+ "note_type": note_type,
109
+ "timestamp": float(t),
110
+ "bpm": float(bpm),
111
+ "scroll": float(scroll),
112
+ "gogo": bool(gogo),
113
+ }
114
+ )
115
+
116
+ current_course["segments"].append(
117
+ {
118
+ "timestamp": float(measure_start_time),
119
+ "measure_num": int(measure_num),
120
+ "measure_den": int(measure_den),
121
+ "notes": notes,
122
+ }
123
+ )
124
+
125
+ # Advance time by exactly one measure
126
+ current_time = measure_start_time + dur
127
+ measure_start_time = current_time
128
+ measure_digits = []
129
+
130
+ def finalize_long_note_durations() -> None:
131
+ if current_course is None:
132
+ return
133
+ # Flatten notes
134
+ flat: list[dict] = []
135
+ for seg in current_course["segments"]:
136
+ for n in seg.get("notes", []):
137
+ flat.append(n)
138
+ flat.sort(key=lambda n: n.get("timestamp", 0.0))
139
+
140
+ open_idx: list[int] = []
141
+ for i, n in enumerate(flat):
142
+ nt = n.get("note_type")
143
+ if nt in {"Roll", "RollBig", "Balloon", "BalloonAlt"}:
144
+ open_idx.append(i)
145
+ elif nt == "EndOf" and open_idx:
146
+ start_i = open_idx.pop()
147
+ start = flat[start_i]
148
+ start_bpm = float(start.get("bpm", 120.0))
149
+ dt = float(n.get("timestamp", 0.0)) - float(start.get("timestamp", 0.0))
150
+ dur_beats = max(0.0, dt * start_bpm / 60.0)
151
+ start["delay"] = float(dur_beats)
152
+
153
+ def ensure_course(name: str) -> dict[str, Any]:
154
+ nonlocal courses
155
+ if name not in courses:
156
+ courses[name] = {
157
+ "name": name,
158
+ "level": None,
159
+ "segments": [],
160
+ "difficulty_hint": None,
161
+ }
162
+ return courses[name]
163
+
164
+ for raw in lines:
165
+ line = raw.strip()
166
+
167
+ if not in_chart and ":" in line and not line.startswith("#"):
168
+ k, v = [p.strip() for p in line.split(":", 1)]
169
+ ku = k.upper()
170
+ meta[ku] = v
171
+ if ku == "BPM":
172
+ try:
173
+ bpm = float(v)
174
+ except ValueError:
175
+ pass
176
+ elif ku == "OFFSET":
177
+ try:
178
+ offset = float(v)
179
+ except ValueError:
180
+ pass
181
+ elif ku == "COURSE":
182
+ current_course = ensure_course(v)
183
+ # Reset per-course chart state
184
+ in_chart = False
185
+ elif ku == "LEVEL" and current_course is not None:
186
+ try:
187
+ current_course["level"] = int(float(v))
188
+ except ValueError:
189
+ current_course["level"] = None
190
+ continue
191
+
192
+ if line.startswith("#START"):
193
+ if current_course is None:
194
+ current_course = ensure_course("(default)")
195
+ # Reset chart state at start
196
+ in_chart = True
197
+ bpm = float(meta.get("BPM", bpm) or bpm)
198
+ try:
199
+ offset = float(meta.get("OFFSET", offset) or offset)
200
+ except ValueError:
201
+ offset = offset
202
+ measure_num, measure_den = 4, 4
203
+ scroll = 1.0
204
+ gogo = False
205
+ current_time = 0.0
206
+ measure_start_time = 0.0
207
+ measure_digits = []
208
+ # Apply offset as a global shift (best-effort)
209
+ current_time += float(offset)
210
+ measure_start_time = current_time
211
+ continue
212
+
213
+ if not in_chart:
214
+ continue
215
+
216
+ if line.startswith("#END"):
217
+ flush_measure_if_any()
218
+ finalize_long_note_durations()
219
+ in_chart = False
220
+ continue
221
+
222
+ if line.startswith("#"):
223
+ cmd = line[1:].strip()
224
+ cmd_u = cmd.upper()
225
+ if cmd_u.startswith("BPMCHANGE"):
226
+ flush_measure_if_any()
227
+ try:
228
+ bpm = float(cmd.split(maxsplit=1)[1])
229
+ except Exception:
230
+ pass
231
+ elif cmd_u.startswith("MEASURE"):
232
+ flush_measure_if_any()
233
+ try:
234
+ frac = cmd.split(maxsplit=1)[1].strip()
235
+ a, b = frac.split("/", 1)
236
+ measure_num = int(a)
237
+ measure_den = int(b)
238
+ except Exception:
239
+ pass
240
+ elif cmd_u.startswith("SCROLL"):
241
+ flush_measure_if_any()
242
+ try:
243
+ scroll = float(cmd.split(maxsplit=1)[1])
244
+ except Exception:
245
+ pass
246
+ elif cmd_u.startswith("DELAY"):
247
+ flush_measure_if_any()
248
+ try:
249
+ current_time += float(cmd.split(maxsplit=1)[1])
250
+ except Exception:
251
+ pass
252
+ measure_start_time = current_time
253
+ elif cmd_u.startswith("GOGOSTART"):
254
+ flush_measure_if_any()
255
+ gogo = True
256
+ elif cmd_u.startswith("GOGOEND"):
257
+ flush_measure_if_any()
258
+ gogo = False
259
+ else:
260
+ # Ignore other commands (branching etc.)
261
+ pass
262
+ continue
263
+
264
+ # Note data: may contain multiple commas
265
+ for ch in line:
266
+ if ch.isdigit():
267
+ measure_digits.append(ch)
268
+ elif ch == ",":
269
+ flush_measure_if_any()
270
+
271
+ # Build ParsedTJA
272
+ parsed_courses: dict[str, ParsedCourse] = {}
273
+ difficulty_map = {
274
+ "0": "easy",
275
+ "easy": "easy",
276
+ "1": "normal",
277
+ "normal": "normal",
278
+ "2": "hard",
279
+ "hard": "hard",
280
+ "3": "oni",
281
+ "oni": "oni",
282
+ "4": "oni",
283
+ "ura": "oni",
284
+ "edit": "oni",
285
+ }
286
+ for name, c in courses.items():
287
+ name_l = name.strip().lower()
288
+ hint = difficulty_map.get(name_l)
289
+ parsed_courses[name] = ParsedCourse(
290
+ name=name,
291
+ level=c.get("level"),
292
+ segments=c.get("segments", []),
293
+ difficulty_hint=hint,
294
+ )
295
+
296
+ return ParsedTJA(meta=meta, courses=parsed_courses)
297
+
298
+
299
+ def _discover_checkpoints() -> list[str]:
300
+ # Prefer local trained outputs
301
+ paths = []
302
+ for p in glob.glob("outputs/*/pretrained/*"):
303
+ if os.path.isdir(p) and os.path.exists(os.path.join(p, "config.json")):
304
+ paths.append(p)
305
+ # Also accept HF / user-provided paths via manual input
306
+ if not paths:
307
+ return ["JacobLinCool/TaikoChartEstimator-20251228"]
308
+ return sorted(paths)
309
+
310
+
311
+ _MODEL_CACHE: dict[str, TaikoChartEstimator] = {}
312
+
313
+
314
+ def _resolve_device(device: str) -> str:
315
+ device = (device or "cpu").lower()
316
+ if device == "cuda" and torch.cuda.is_available():
317
+ return "cuda"
318
+ if (
319
+ device == "mps"
320
+ and hasattr(torch.backends, "mps")
321
+ and torch.backends.mps.is_available()
322
+ ):
323
+ return "mps"
324
+ return "cpu"
325
+
326
+
327
+ def _load_model(checkpoint_path: str, device: str) -> TaikoChartEstimator:
328
+ device = _resolve_device(device)
329
+ key = f"{checkpoint_path}::{device}"
330
+ if key in _MODEL_CACHE:
331
+ return _MODEL_CACHE[key]
332
+
333
+ model = TaikoChartEstimator.from_pretrained(checkpoint_path)
334
+ model.eval()
335
+ model.to(torch.device(device))
336
+ _MODEL_CACHE[key] = model
337
+ return model
338
+
339
+
340
+ def _build_instances_from_segments(
341
+ segments: list[dict],
342
+ max_tokens_per_instance: int,
343
+ window_measures: list[int],
344
+ hop_measures: int,
345
+ max_instances_per_chart: int,
346
+ ) -> tuple[
347
+ torch.Tensor, torch.Tensor, torch.Tensor, list[tuple[float, float]], list[int]
348
+ ]:
349
+ tokenizer = EventTokenizer()
350
+ tokens = tokenizer.tokenize_chart(segments)
351
+
352
+ all_instances: list[torch.Tensor] = []
353
+ all_masks: list[torch.Tensor] = []
354
+ all_times: list[tuple[float, float]] = []
355
+ all_token_counts: list[int] = []
356
+
357
+ for window_size in window_measures:
358
+ windows = tokenizer.create_windows(
359
+ tokens, window_measures=window_size, hop_measures=hop_measures
360
+ )
361
+ for window_tokens in windows:
362
+ if not window_tokens:
363
+ continue
364
+ tensor, mask = tokenizer.tokens_to_tensor(
365
+ window_tokens, max_length=max_tokens_per_instance
366
+ )
367
+ all_token_counts.append(int(mask.sum().item()))
368
+ tensor, mask = tokenizer.pad_sequence(tensor, mask, max_tokens_per_instance)
369
+ all_instances.append(tensor)
370
+ all_masks.append(mask)
371
+ all_times.append(
372
+ (float(window_tokens[0].timestamp), float(window_tokens[-1].timestamp))
373
+ )
374
+
375
+ if not all_instances:
376
+ raise ValueError("No note events parsed (empty chart or unsupported format)")
377
+
378
+ if len(all_instances) > max_instances_per_chart:
379
+ idx = np.linspace(
380
+ 0, len(all_instances) - 1, max_instances_per_chart, dtype=int
381
+ ).tolist()
382
+ all_instances = [all_instances[i] for i in idx]
383
+ all_masks = [all_masks[i] for i in idx]
384
+ all_times = [all_times[i] for i in idx]
385
+ all_token_counts = [all_token_counts[i] for i in idx]
386
+
387
+ instances = torch.stack(all_instances).unsqueeze(0) # [1, N, L, 6]
388
+ masks = torch.stack(all_masks).unsqueeze(0) # [1, N, L]
389
+ counts = torch.tensor([len(all_instances)], dtype=torch.long) # [1]
390
+ return instances, masks, counts, all_times, all_token_counts
391
+
392
+
393
+ def _plot_attention(
394
+ times: list[tuple[float, float]],
395
+ avg_attention: np.ndarray,
396
+ topk_mask: Optional[np.ndarray],
397
+ title: str,
398
+ ):
399
+ # Sort by time to avoid misleading zig-zag lines when windows are generated in mixed order.
400
+ t0 = np.array([a for a, _ in times], dtype=np.float64)
401
+ t1 = np.array([b for _, b in times], dtype=np.float64)
402
+ mids = (t0 + t1) / 2.0
403
+ order = np.argsort(mids)
404
+
405
+ mids_s = mids[order]
406
+ attn_s = avg_attention[order]
407
+ topk_s = topk_mask[order] if topk_mask is not None else None
408
+
409
+ fig, ax = plt.subplots(figsize=(10, 3.2))
410
+ ax.scatter(mids_s, attn_s, s=14, alpha=0.8, label="Instance")
411
+ ax.plot(mids_s, attn_s, linewidth=1.5, alpha=0.6)
412
+
413
+ if topk_s is not None:
414
+ sel = topk_s.astype(bool)
415
+ ax.scatter(
416
+ mids_s[sel],
417
+ attn_s[sel],
418
+ s=40,
419
+ marker="o",
420
+ edgecolors="black",
421
+ linewidths=0.4,
422
+ label="Top-k",
423
+ )
424
+
425
+ ax.set_xlabel("Time (s)")
426
+ ax.set_ylabel("Avg attention (weight)")
427
+ ax.set_title(title)
428
+ ax.grid(True, alpha=0.25)
429
+ ax.legend(loc="best")
430
+ fig.tight_layout()
431
+ return fig
432
+
433
+
434
+ def _plot_branch_heatmap(branch_attn: np.ndarray, title: str):
435
+ # branch_attn: [n_branches, n_instances]
436
+ fig, ax = plt.subplots(figsize=(10, 3.2))
437
+ im = ax.imshow(branch_attn, aspect="auto", interpolation="nearest")
438
+ ax.set_title(title)
439
+ ax.set_xlabel("Instance (time-sorted)")
440
+ ax.set_ylabel("Branch")
441
+ cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.04)
442
+ cbar.set_label("Attention weight")
443
+ fig.tight_layout()
444
+ return fig
445
+
446
+
447
+ def _plot_density_and_attention(
448
+ times: list[tuple[float, float]],
449
+ token_counts: list[int],
450
+ avg_attention: np.ndarray,
451
+ topk_mask: Optional[np.ndarray],
452
+ title: str,
453
+ ):
454
+ t0 = np.array([a for a, _ in times], dtype=np.float64)
455
+ t1 = np.array([b for _, b in times], dtype=np.float64)
456
+ mids = (t0 + t1) / 2.0
457
+ durations = np.maximum(t1 - t0, 1e-6)
458
+ token_counts_np = np.array(token_counts[: len(times)], dtype=np.float64)
459
+ density = token_counts_np / durations
460
+ order = np.argsort(mids)
461
+
462
+ mids_s = mids[order]
463
+ dens_s = density[order]
464
+ attn_s = avg_attention[order]
465
+ topk_s = topk_mask[order] if topk_mask is not None else None
466
+
467
+ fig, ax1 = plt.subplots(figsize=(10, 3.2))
468
+ ax1.plot(mids_s, dens_s, linewidth=1.8, color="tab:blue", label="Token density")
469
+ ax1.set_xlabel("Time (s)")
470
+ ax1.set_ylabel("Tokens / sec", color="tab:blue")
471
+ ax1.tick_params(axis="y", labelcolor="tab:blue")
472
+ ax1.grid(True, alpha=0.25)
473
+
474
+ ax2 = ax1.twinx()
475
+ ax2.scatter(
476
+ mids_s, attn_s, s=14, color="tab:orange", alpha=0.75, label="Avg attention"
477
+ )
478
+ if topk_s is not None:
479
+ sel = topk_s.astype(bool)
480
+ ax2.scatter(
481
+ mids_s[sel],
482
+ attn_s[sel],
483
+ s=40,
484
+ marker="o",
485
+ edgecolors="black",
486
+ linewidths=0.4,
487
+ color="tab:orange",
488
+ label="Top-k attention",
489
+ )
490
+ ax2.set_ylabel("Avg attention", color="tab:orange")
491
+ ax2.tick_params(axis="y", labelcolor="tab:orange")
492
+
493
+ ax1.set_title(title)
494
+ # Merge legends
495
+ h1, l1 = ax1.get_legend_handles_labels()
496
+ h2, l2 = ax2.get_legend_handles_labels()
497
+ ax1.legend(h1 + h2, l1 + l2, loc="best")
498
+ fig.tight_layout()
499
+ return fig
500
+
501
+
502
+ def _plot_attention_concentration(
503
+ avg_attention: np.ndarray,
504
+ title: str,
505
+ ):
506
+ # Cumulative mass of attention sorted by weight (how concentrated the model is)
507
+ attn = np.clip(avg_attention.astype(np.float64), 0.0, None)
508
+ if attn.sum() > 0:
509
+ attn = attn / attn.sum()
510
+ attn_sorted = np.sort(attn)[::-1]
511
+ cum = np.cumsum(attn_sorted)
512
+ k = np.arange(1, len(attn_sorted) + 1)
513
+
514
+ fig, ax = plt.subplots(figsize=(10, 3.2))
515
+ ax.plot(k, cum, linewidth=2)
516
+ ax.set_xlabel("Top-k instances (sorted by attention)")
517
+ ax.set_ylabel("Cumulative attention mass")
518
+ ax.set_ylim(0, 1.02)
519
+ ax.set_title(title)
520
+ ax.grid(True, alpha=0.25)
521
+ fig.tight_layout()
522
+ return fig
523
+
524
+
525
+ def run_inference(
526
+ tja_file,
527
+ tja_text: str,
528
+ course_name: str,
529
+ checkpoint_path: str,
530
+ device: str,
531
+ window_measures_text: str,
532
+ hop_measures: int,
533
+ max_instances: int,
534
+ ):
535
+ if tja_file:
536
+ with open(tja_file, "r", encoding="utf-8", errors="ignore") as f:
537
+ tja_text = f.read()
538
+
539
+ parsed = parse_tja(tja_text)
540
+ if not parsed.courses:
541
+ raise gr.Error("No COURSE found and no chart parsed.")
542
+
543
+ if course_name not in parsed.courses:
544
+ # Fallback to first
545
+ course_name = next(iter(parsed.courses.keys()))
546
+
547
+ course = parsed.courses[course_name]
548
+
549
+ try:
550
+ window_measures = [
551
+ int(x.strip()) for x in window_measures_text.split(",") if x.strip()
552
+ ]
553
+ except ValueError:
554
+ raise gr.Error(
555
+ "window_measures must be a comma-separated list of integers, e.g. 2,4"
556
+ )
557
+ if not window_measures:
558
+ window_measures = [2, 4]
559
+
560
+ device = _resolve_device(device)
561
+ model = _load_model(checkpoint_path, device=device)
562
+ max_tokens = int(getattr(model.config, "max_seq_len", 128))
563
+
564
+ instances, masks, counts, times, token_counts = _build_instances_from_segments(
565
+ course.segments,
566
+ max_tokens_per_instance=max_tokens,
567
+ window_measures=window_measures,
568
+ hop_measures=int(hop_measures),
569
+ max_instances_per_chart=int(max_instances),
570
+ )
571
+
572
+ instances = instances.to(torch.device(device))
573
+ masks = masks.to(torch.device(device))
574
+ counts = counts.to(torch.device(device))
575
+
576
+ difficulty_hint = None
577
+ if course.difficulty_hint is not None:
578
+ mapping = {"easy": 0, "normal": 1, "hard": 2, "oni": 3, "ura": 4}
579
+ difficulty_hint = torch.tensor(
580
+ [mapping[course.difficulty_hint]], device=torch.device(device)
581
+ )
582
+
583
+ with torch.no_grad():
584
+ out = model.forward(
585
+ instances,
586
+ masks,
587
+ counts,
588
+ difficulty_hint=difficulty_hint,
589
+ return_attention=True,
590
+ )
591
+
592
+ # Scalars
593
+ difficulty_names = ["easy", "normal", "hard", "oni", "ura"]
594
+ pred_class_id = int(out.difficulty_logits.argmax(dim=-1).item())
595
+ pred_class = difficulty_names[pred_class_id]
596
+ raw_score = float(out.raw_score.item())
597
+ raw_star = float(out.raw_star.item())
598
+ display_star = float(out.display_star.item())
599
+
600
+ # Attention details
601
+ attn = out.attention_info
602
+ avg_attn = attn.get("average_attention")
603
+ branch_attn = attn.get("branch_attentions")
604
+ topk_mask = attn.get("topk_mask")
605
+
606
+ avg_attn_np = (
607
+ avg_attn[0, : counts.item()].detach().cpu().numpy()
608
+ if avg_attn is not None
609
+ else None
610
+ )
611
+ topk_np = (
612
+ topk_mask[0, : counts.item()].detach().cpu().numpy()
613
+ if topk_mask is not None
614
+ else None
615
+ )
616
+ branch_np = (
617
+ branch_attn[0, :, : counts.item()].detach().cpu().numpy()
618
+ if branch_attn is not None
619
+ else None
620
+ )
621
+
622
+ # Plots
623
+ fig_attn = None
624
+ fig_heat = None
625
+ fig_density = None
626
+ fig_conc = None
627
+ if avg_attn_np is not None:
628
+ fig_attn = _plot_attention(
629
+ times, avg_attn_np, topk_np, title="MIL average attention over time"
630
+ )
631
+ if avg_attn_np is not None:
632
+ fig_density = _plot_density_and_attention(
633
+ times,
634
+ token_counts,
635
+ avg_attn_np,
636
+ topk_np,
637
+ title="Token density vs attention (time-sorted)",
638
+ )
639
+ fig_conc = _plot_attention_concentration(
640
+ avg_attn_np,
641
+ title="Attention concentration (how many windows dominate)",
642
+ )
643
+
644
+ # Heatmap: sort instances by time for interpretability
645
+ if branch_np is not None:
646
+ mids = np.array([(a + b) / 2.0 for a, b in times], dtype=np.float64)
647
+ order = np.argsort(mids)
648
+ branch_sorted = branch_np[:, order]
649
+ fig_heat = _plot_branch_heatmap(
650
+ branch_sorted, title="MIL attention (branches x instances)"
651
+ )
652
+ # Add a few time tick labels
653
+ ax = fig_heat.axes[0]
654
+ if len(order) > 1:
655
+ n_ticks = 6
656
+ tick_pos = np.linspace(0, len(order) - 1, n_ticks, dtype=int)
657
+ tick_labels = [f"{mids[order[p]]:.0f}s" for p in tick_pos]
658
+ ax.set_xticks(tick_pos)
659
+ ax.set_xticklabels(tick_labels)
660
+
661
+ # Table
662
+ rows = []
663
+ for i, (t0, t1) in enumerate(times):
664
+ rows.append(
665
+ [
666
+ i,
667
+ float(t0),
668
+ float(t1),
669
+ float((t0 + t1) / 2.0),
670
+ int(token_counts[i]) if i < len(token_counts) else None,
671
+ float(avg_attn_np[i]) if avg_attn_np is not None else None,
672
+ int(topk_np[i]) if topk_np is not None else None,
673
+ ]
674
+ )
675
+
676
+ # More intuitive summary: show top attention windows
677
+ top_md = ""
678
+ if avg_attn_np is not None:
679
+ t0 = np.array([a for a, _ in times], dtype=np.float64)
680
+ t1 = np.array([b for _, b in times], dtype=np.float64)
681
+ mids = (t0 + t1) / 2.0
682
+ durations = np.maximum(t1 - t0, 1e-6)
683
+ token_counts_np = np.array(token_counts[: len(times)], dtype=np.float64)
684
+ density = token_counts_np / durations
685
+
686
+ top_n = min(8, len(avg_attn_np))
687
+ top_idx = np.argsort(avg_attn_np)[::-1][:top_n]
688
+
689
+ lines = ["### Top segments (by attention)"]
690
+ for rank, idx in enumerate(top_idx, start=1):
691
+ is_topk = int(topk_np[idx]) if topk_np is not None else 0
692
+ lines.append(
693
+ f"{rank}. `[{t0[idx]:.1f}s - {t1[idx]:.1f}s]` "
694
+ f"attn={avg_attn_np[idx]:.4f}, dens={density[idx]:.1f} tok/s, topk={is_topk}"
695
+ )
696
+ top_md = "\n".join(lines)
697
+
698
+ # Meta/details
699
+ meta_out = {
700
+ "TITLE": parsed.meta.get("TITLE"),
701
+ "BPM": parsed.meta.get("BPM"),
702
+ "OFFSET": parsed.meta.get("OFFSET"),
703
+ "COURSE": course.name,
704
+ "LEVEL": course.level,
705
+ "difficulty_hint": course.difficulty_hint,
706
+ "n_instances": int(counts.item()),
707
+ "max_tokens_per_instance": int(max_tokens),
708
+ "window_measures": window_measures,
709
+ "hop_measures": int(hop_measures),
710
+ "attention_entropy": float(attn.get("entropy")[0].item())
711
+ if attn.get("entropy") is not None
712
+ else None,
713
+ "attention_effective_n": float(attn.get("effective_n")[0].item())
714
+ if attn.get("effective_n") is not None
715
+ else None,
716
+ "attention_top5_mass": float(attn.get("top5_mass")[0].item())
717
+ if attn.get("top5_mass") is not None
718
+ else None,
719
+ }
720
+
721
+ summary_md = (
722
+ f"### Prediction\n"
723
+ f"- predicted difficulty: `{pred_class}`\n"
724
+ f"- raw_score: `{raw_score:.4f}`\n"
725
+ f"- raw_star: `{raw_star:.4f}`\n"
726
+ f"- display_star: `{display_star:.4f}`\n"
727
+ )
728
+
729
+ return (
730
+ summary_md,
731
+ meta_out,
732
+ fig_attn,
733
+ fig_density,
734
+ fig_heat,
735
+ fig_conc,
736
+ top_md,
737
+ rows,
738
+ )
739
+
740
+
741
+ def _update_course_dropdown(tja_file, tja_text: str):
742
+ if tja_file:
743
+ with open(tja_file, "r", encoding="utf-8", errors="ignore") as f:
744
+ tja_text = f.read()
745
+ try:
746
+ parsed = parse_tja(tja_text)
747
+ choices = list(parsed.courses.keys())
748
+ value = choices[0] if choices else None
749
+ return gr.Dropdown(choices=choices, value=value)
750
+ except Exception:
751
+ return gr.Dropdown(choices=[], value=None)
752
+
753
+
754
+ def build_app() -> gr.Blocks:
755
+ checkpoints = _discover_checkpoints()
756
+
757
+ with gr.Blocks(title="TaikoChartEstimator Inference") as demo:
758
+ gr.Markdown("# TaikoChartEstimator - Inference")
759
+ gr.Markdown(
760
+ """
761
+ ## How to Read Visualizations
762
+
763
+ - The model splits the chart into multiple **windows (instances)** and aggregates them using MIL (Multiple Instance Learning) for a prediction.
764
+ - `Avg attention` is the importance weight of this window for the final judgment; it is typically normalized by softmax within a single chart, so the values are usually small.
765
+ - `Top-k` is another Top-K pooling branch that selects windows that "look most like peak difficulty points"; they do not necessarily overlap perfectly with attention peaks.
766
+
767
+ Recommended combinations:
768
+ - `Token density vs attention`: Check if high-density segments are simultaneously emphasized.
769
+ - `Attention concentration`: Check if the model relies on only a few windows (closer to 1 means more concentrated).
770
+ """
771
+ )
772
+
773
+ with gr.Row():
774
+ with gr.Column(scale=1):
775
+ tja_file = gr.File(
776
+ label="Upload .tja", file_types=[".tja"], type="filepath"
777
+ )
778
+ tja_text = gr.Textbox(label="Or paste TJA content", lines=16)
779
+
780
+ course = gr.Dropdown(label="COURSE", choices=[], value=None)
781
+
782
+ checkpoint = gr.Dropdown(
783
+ label="Checkpoint",
784
+ choices=checkpoints,
785
+ value=checkpoints[-1] if checkpoints else None,
786
+ allow_custom_value=True,
787
+ )
788
+
789
+ device = gr.Dropdown(
790
+ label="Device", choices=["cpu", "mps", "cuda"], value="cpu"
791
+ )
792
+
793
+ window_measures = gr.Textbox(
794
+ label="window_measures (comma-separated)", value="2,4"
795
+ )
796
+ hop_measures = gr.Slider(
797
+ label="hop_measures", minimum=1, maximum=8, value=2, step=1
798
+ )
799
+ max_instances = gr.Slider(
800
+ label="max_instances", minimum=8, maximum=256, value=64, step=1
801
+ )
802
+
803
+ run_btn = gr.Button("Run inference", variant="primary")
804
+
805
+ with gr.Column(scale=2):
806
+ summary = gr.Markdown()
807
+ meta_json = gr.JSON(label="Details")
808
+ attn_plot = gr.Plot(label="Attention (time-sorted)")
809
+ density_plot = gr.Plot(label="Token density vs attention")
810
+ heat_plot = gr.Plot(label="Branch attention heatmap")
811
+ conc_plot = gr.Plot(label="Attention concentration")
812
+ top_segments = gr.Markdown()
813
+ table = gr.Dataframe(
814
+ headers=[
815
+ "instance_idx",
816
+ "t_start",
817
+ "t_end",
818
+ "t_mid",
819
+ "token_count",
820
+ "avg_attention",
821
+ "topk_selected",
822
+ ],
823
+ datatype=[
824
+ "number",
825
+ "number",
826
+ "number",
827
+ "number",
828
+ "number",
829
+ "number",
830
+ "number",
831
+ ],
832
+ label="Per-instance details",
833
+ wrap=True,
834
+ )
835
+
836
+ # Auto-refresh COURSE choices when input changes
837
+ tja_file.change(
838
+ _update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
839
+ )
840
+ tja_text.change(
841
+ _update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
842
+ )
843
+
844
+ run_btn.click(
845
+ run_inference,
846
+ inputs=[
847
+ tja_file,
848
+ tja_text,
849
+ course,
850
+ checkpoint,
851
+ device,
852
+ window_measures,
853
+ hop_measures,
854
+ max_instances,
855
+ ],
856
+ outputs=[
857
+ summary,
858
+ meta_json,
859
+ attn_plot,
860
+ density_plot,
861
+ heat_plot,
862
+ conc_plot,
863
+ top_segments,
864
+ table,
865
+ ],
866
+ )
867
+
868
+ return demo
869
+
870
+
871
+ if __name__ == "__main__":
872
+ app = build_app()
873
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.1
2
+ accelerate==1.12.0
3
+ aiofiles==24.1.0
4
+ aiohappyeyeballs==2.6.1
5
+ aiohttp==3.13.2
6
+ aiosignal==1.4.0
7
+ annotated-doc==0.0.4
8
+ annotated-types==0.7.0
9
+ anyio==4.12.0
10
+ attrs==25.4.0
11
+ brotli==1.2.0
12
+ certifi==2025.11.12
13
+ charset-normalizer==3.4.4
14
+ click==8.3.1
15
+ contourpy==1.3.3
16
+ cycler==0.12.1
17
+ datasets==4.4.2
18
+ dill==0.4.0
19
+ fastapi==0.128.0
20
+ ffmpy==1.0.0
21
+ filelock==3.20.1
22
+ fonttools==4.61.1
23
+ frozenlist==1.8.0
24
+ fsspec==2025.10.0
25
+ gradio==6.2.0
26
+ gradio-client==2.0.2
27
+ groovy==0.1.2
28
+ grpcio==1.76.0
29
+ h11==0.16.0
30
+ hf-xet==1.2.0
31
+ httpcore==1.0.9
32
+ httpx==0.28.1
33
+ huggingface-hub==1.2.3
34
+ idna==3.11
35
+ jinja2==3.1.6
36
+ joblib==1.5.3
37
+ kiwisolver==1.4.9
38
+ markdown==3.10
39
+ markdown-it-py==4.0.0
40
+ markupsafe==3.0.3
41
+ matplotlib==3.10.8
42
+ mdurl==0.1.2
43
+ mpmath==1.3.0
44
+ multidict==6.7.0
45
+ multiprocess==0.70.18
46
+ networkx==3.6.1
47
+ numpy==2.4.0
48
+ orjson==3.11.5
49
+ packaging==25.0
50
+ pandas==2.3.3
51
+ pillow==12.0.0
52
+ propcache==0.4.1
53
+ protobuf==6.33.2
54
+ psutil==7.2.0
55
+ pyarrow==22.0.0
56
+ pydantic==2.12.5
57
+ pydantic-core==2.41.5
58
+ pydub==0.25.1
59
+ pygments==2.19.2
60
+ pyparsing==3.3.1
61
+ python-dateutil==2.9.0.post0
62
+ python-multipart==0.0.21
63
+ pytz==2025.2
64
+ pyyaml==6.0.3
65
+ requests==2.32.5
66
+ rich==14.2.0
67
+ safehttpx==0.1.7
68
+ safetensors==0.7.0
69
+ scikit-learn==1.8.0
70
+ scipy==1.16.3
71
+ semantic-version==2.10.0
72
+ setuptools==80.9.0
73
+ shellingham==1.5.4
74
+ six==1.17.0
75
+ starlette==0.50.0
76
+ sympy==1.14.0
77
+ tensorboard==2.20.0
78
+ tensorboard-data-server==0.7.2
79
+ threadpoolctl==3.6.0
80
+ tomlkit==0.13.3
81
+ torch==2.9.1
82
+ torchaudio==2.9.1
83
+ torchcodec==0.9.1
84
+ tqdm==4.67.1
85
+ typer==0.21.0
86
+ typer-slim==0.21.0
87
+ typing-extensions==4.15.0
88
+ typing-inspection==0.4.2
89
+ tzdata==2025.3
90
+ urllib3==2.6.2
91
+ uvicorn==0.40.0
92
+ werkzeug==3.1.4
93
+ xxhash==3.6.0
94
+ yarl==1.22.0