JacobLinCool commited on
Commit
f7793f8
·
verified ·
1 Parent(s): 2789e13

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +578 -53
  2. requirements.txt +1 -0
app.py CHANGED
@@ -6,7 +6,10 @@ from typing import Any, Optional
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
  import numpy as np
 
9
  import torch
 
 
10
 
11
  from TaikoChartEstimator.data.tokenizer import EventTokenizer
12
  from TaikoChartEstimator.model.model import TaikoChartEstimator
@@ -499,6 +502,445 @@ def _plot_density_and_attention(
499
  return fig
500
 
501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  def _plot_attention_concentration(
503
  avg_attention: np.ndarray,
504
  title: str,
@@ -603,6 +1045,16 @@ def run_inference(
603
  branch_attn = attn.get("branch_attentions")
604
  topk_mask = attn.get("topk_mask")
605
 
 
 
 
 
 
 
 
 
 
 
606
  avg_attn_np = (
607
  avg_attn[0, : counts.item()].detach().cpu().numpy()
608
  if avg_attn is not None
@@ -618,6 +1070,8 @@ def run_inference(
618
  if branch_attn is not None
619
  else None
620
  )
 
 
621
 
622
  # Plots
623
  fig_attn = None
@@ -641,7 +1095,50 @@ def run_inference(
641
  title="Attention concentration (how many windows dominate)",
642
  )
643
 
644
- # Heatmap: sort instances by time for interpretability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  if branch_np is not None:
646
  mids = np.array([(a + b) / 2.0 for a, b in times], dtype=np.float64)
647
  order = np.argsort(mids)
@@ -670,6 +1167,7 @@ def run_inference(
670
  int(token_counts[i]) if i < len(token_counts) else None,
671
  float(avg_attn_np[i]) if avg_attn_np is not None else None,
672
  int(topk_np[i]) if topk_np is not None else None,
 
673
  ]
674
  )
675
 
@@ -735,6 +1233,9 @@ def run_inference(
735
  fig_conc,
736
  top_md,
737
  rows,
 
 
 
738
  )
739
 
740
 
@@ -756,69 +1257,91 @@ def build_app() -> gr.Blocks:
756
 
757
  with gr.Blocks(title="TaikoChartEstimator Inference") as demo:
758
  gr.Markdown("# TaikoChartEstimator - Inference")
759
- gr.Markdown(
760
- """
761
- ## How to Read Visualizations
762
-
763
- - The model splits the chart into multiple **windows (instances)** and aggregates them using MIL (Multiple Instance Learning) for a prediction.
764
- - `Avg attention` is the importance weight of this window for the final judgment; it is typically normalized by softmax within a single chart, so the values are usually small.
765
- - `Top-k` is another Top-K pooling branch that selects windows that "look most like peak difficulty points"; they do not necessarily overlap perfectly with attention peaks.
766
-
767
- Recommended combinations:
768
- - `Token density vs attention`: Check if high-density segments are simultaneously emphasized.
769
- - `Attention concentration`: Check if the model relies on only a few windows (closer to 1 means more concentrated).
770
- """
771
- )
772
 
773
  with gr.Row():
774
- with gr.Column(scale=1):
775
- tja_file = gr.File(
776
- label="Upload .tja", file_types=[".tja"], type="filepath"
777
- )
778
- tja_text = gr.Textbox(label="Or paste TJA content", lines=16)
 
 
779
 
780
  course = gr.Dropdown(label="COURSE", choices=[], value=None)
 
781
 
 
 
 
782
  checkpoint = gr.Dropdown(
783
  label="Checkpoint",
784
  choices=checkpoints,
785
  value=checkpoints[-1] if checkpoints else None,
786
  allow_custom_value=True,
787
  )
788
-
789
  device = gr.Dropdown(
790
  label="Device", choices=["cpu", "mps", "cuda"], value="cpu"
791
  )
792
 
793
- window_measures = gr.Textbox(
794
- label="window_measures (comma-separated)", value="2,4"
795
- )
796
- hop_measures = gr.Slider(
797
- label="hop_measures", minimum=1, maximum=8, value=2, step=1
798
- )
799
- max_instances = gr.Slider(
800
- label="max_instances", minimum=8, maximum=256, value=64, step=1
801
- )
802
-
803
- run_btn = gr.Button("Run inference", variant="primary")
804
 
805
- with gr.Column(scale=2):
 
806
  summary = gr.Markdown()
807
- meta_json = gr.JSON(label="Details")
808
- attn_plot = gr.Plot(label="Attention (time-sorted)")
809
- density_plot = gr.Plot(label="Token density vs attention")
810
- heat_plot = gr.Plot(label="Branch attention heatmap")
811
- conc_plot = gr.Plot(label="Attention concentration")
812
  top_segments = gr.Markdown()
813
- table = gr.Dataframe(
 
 
 
 
 
 
 
 
 
 
814
  headers=[
815
- "instance_idx",
816
- "t_start",
817
- "t_end",
818
- "t_mid",
819
- "token_count",
820
- "avg_attention",
821
- "topk_selected",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
822
  ],
823
  datatype=[
824
  "number",
@@ -828,9 +1351,8 @@ Recommended combinations:
828
  "number",
829
  "number",
830
  "number",
 
831
  ],
832
- label="Per-instance details",
833
- wrap=True,
834
  )
835
 
836
  # Auto-refresh COURSE choices when input changes
@@ -841,7 +1363,7 @@ Recommended combinations:
841
  _update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
842
  )
843
 
844
- run_btn.click(
845
  run_inference,
846
  inputs=[
847
  tja_file,
@@ -856,12 +1378,15 @@ Recommended combinations:
856
  outputs=[
857
  summary,
858
  meta_json,
859
- attn_plot,
860
- density_plot,
861
- heat_plot,
862
- conc_plot,
863
  top_segments,
864
- table,
 
 
 
865
  ],
866
  )
867
 
 
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
+ import ruptures as rpt
10
  import torch
11
+ from sklearn.cluster import KMeans
12
+ from sklearn.metrics import silhouette_score
13
 
14
  from TaikoChartEstimator.data.tokenizer import EventTokenizer
15
  from TaikoChartEstimator.model.model import TaikoChartEstimator
 
502
  return fig
503
 
504
 
505
def _plot_local_difficulty(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    token_counts: list[int],
    title: str,
):
    """Plot the estimated local difficulty (star rating) over time.

    Draws the raw per-window stars faintly, an EMA-smoothed curve on top,
    and the note density on a secondary axis for context.
    """
    starts = np.fromiter((a for a, _ in times), dtype=np.float64)
    ends = np.fromiter((b for _, b in times), dtype=np.float64)
    midpoints = (starts + ends) / 2.0
    spans = np.maximum(ends - starts, 1e-6)  # guard against zero-length windows
    counts_arr = np.array(token_counts[: len(times)], dtype=np.float64)
    note_density = counts_arr / spans

    # Windows may arrive unordered (mixed window sizes) -> sort by midpoint.
    by_time = np.argsort(midpoints)
    xs = midpoints[by_time]
    raw = local_stars[by_time]
    dens = note_density[by_time]

    # Exponential moving average (alpha = 0.3) to smooth the raw curve.
    alpha = 0.3
    if len(raw) > 0:
        ema = np.zeros_like(raw)
        ema[0] = raw[0]
        for i in range(1, len(raw)):
            ema[i] = alpha * raw[i] + (1 - alpha) * ema[i - 1]
    else:
        ema = raw

    fig, ax1 = plt.subplots(figsize=(10, 3.5))

    # Primary axis: difficulty curve.
    star_color = "tab:red"
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Estimated Local Stars", color=star_color)
    ax1.plot(xs, raw, color=star_color, linewidth=1, alpha=0.3, label="Raw")
    ax1.plot(xs, ema, color=star_color, linewidth=2.5, label="Smoothed (EMA)")
    ax1.tick_params(axis="y", labelcolor=star_color)
    ax1.grid(True, alpha=0.25)
    ax1.fill_between(xs, ema, alpha=0.1, color=star_color)

    # Secondary axis: note density for context.
    ax2 = ax1.twinx()
    density_color = "tab:blue"
    ax2.set_ylabel("Density (notes/s)", color=density_color)
    ax2.plot(
        xs,
        dens,
        color=density_color,
        linewidth=1,
        linestyle="--",
        alpha=0.5,
        label="Note Density",
    )
    ax2.tick_params(axis="y", labelcolor=density_color)

    ax1.set_title(title)

    # Merge the legends of both axes into one box.
    handles_a, labels_a = ax1.get_legend_handles_labels()
    handles_b, labels_b = ax2.get_legend_handles_labels()
    ax1.legend(handles_a + handles_b, labels_a + labels_b, loc="upper right")

    fig.tight_layout()
    return fig
577
+
578
+
579
+ def _smooth_embeddings(embeddings: np.ndarray, window_size: int = 3) -> np.ndarray:
580
+ """Apply temporal smoothing (moving average) to embeddings."""
581
+ if len(embeddings) < window_size:
582
+ return embeddings
583
+
584
+ # Kernel for simple moving average
585
+ kernel = np.ones(window_size) / window_size
586
+
587
+ # Apply to each dimension independenty
588
+ # We can use scipy.ndimage.convolve1d or simplified numpy for dependency-free
589
+ smoothed = np.zeros_like(embeddings)
590
+ for dim in range(embeddings.shape[1]):
591
+ # Padding: 'edge' mode equivalent
592
+ x = embeddings[:, dim]
593
+ pad_width = window_size // 2
594
+ padded = np.pad(x, pad_width, mode="edge")
595
+
596
+ # Convolve
597
+ s = np.convolve(padded, kernel, mode="valid")
598
+
599
+ # Handle shape mismatch due to even/odd window
600
+ if len(s) > len(x):
601
+ s = s[: len(x)]
602
+ elif len(s) < len(x):
603
+ # Should not happen with padded='edge' widely enough but just in case
604
+ s = np.pad(s, (0, len(x) - len(s)), mode="edge")
605
+
606
+ smoothed[:, dim] = s
607
+
608
+ return smoothed
609
+
610
+
611
+ def _smooth_labels(labels: np.ndarray, window_size: int = 3) -> np.ndarray:
612
+ """Apply mode filter to labels to enforce temporal continuity."""
613
+ if len(labels) < window_size:
614
+ return labels
615
+
616
+ n = len(labels)
617
+ smoothed = labels.copy()
618
+ pad = window_size // 2
619
+
620
+ # Simple sliding window mode
621
+ for i in range(n):
622
+ start = max(0, i - pad)
623
+ end = min(n, i + pad + 1)
624
+ window = labels[start:end]
625
+
626
+ # Find mode
627
+ counts = np.bincount(window)
628
+ smoothed[i] = np.argmax(counts)
629
+
630
+ return smoothed
631
+
632
+
633
def _perform_clustering(
    embeddings: np.ndarray,
    min_k: int = 3,
    max_k: int = 8,
    smoothing_window: int = 3,
    label_smoothing_window: int = 3,
    random_state: int = 42,
) -> tuple[np.ndarray, int, dict]:
    """
    Perform K-Means clustering with automatic K selection using Silhouette Score.

    Temporal smoothing is applied to the embeddings before clustering, and a
    mode filter is applied to the labels afterwards, to stabilize clusters.

    Args:
        embeddings: [N, D] data points, assumed time-ordered.
        min_k: Minimum number of clusters.
        max_k: Maximum number of clusters.
        smoothing_window: Moving-average window for embeddings (<= 1 disables).
        label_smoothing_window: Mode-filter window for labels (<= 1 disables).
        random_state: Seed forwarded to KMeans for reproducibility.

    Returns:
        labels: [N] cluster labels
        best_k: Selected number of clusters
        stats: Info about clustering quality
    """
    # Degenerate case: fewer points than the smallest K -> one trivial cluster.
    # Keep the legacy "score" key for compatibility alongside "silhouette".
    N = embeddings.shape[0]
    if N < min_k:
        return np.zeros(N, dtype=int), 1, {"score": 0.0, "silhouette": 0.0}

    # 1. Temporal smoothing of the embedding sequence.
    if smoothing_window > 1:
        work_embeddings = _smooth_embeddings(embeddings, window_size=smoothing_window)
    else:
        work_embeddings = embeddings

    best_score = -1.0
    best_k = min_k
    best_model = None

    # silhouette_score needs 2 <= n_labels <= n_samples - 1; clamp the sweep.
    effective_max_k = min(max_k, N - 1)
    if effective_max_k < min_k:
        effective_max_k = min_k

    for k in range(min_k, effective_max_k + 1):
        kmeans = KMeans(n_clusters=k, random_state=random_state, n_init=10)
        labels = kmeans.fit_predict(work_embeddings)
        try:
            score = silhouette_score(work_embeddings, labels)
        except ValueError:
            # silhouette_score rejects degenerate labelings (e.g. k == N);
            # skip this K rather than aborting the whole sweep.
            continue
        if score > best_score:
            best_score = score
            best_k = k
            best_model = kmeans

    if best_model is None:
        # Fallback: every K failed scoring; fit the smallest K unconditionally.
        kmeans = KMeans(n_clusters=min_k, random_state=random_state, n_init=10)
        kmeans.fit(work_embeddings)
        best_model = kmeans
        best_k = min_k

    labels = best_model.labels_

    # 2. Label smoothing (post-processing) for temporal continuity.
    if label_smoothing_window > 1:
        labels = _smooth_labels(labels, window_size=label_smoothing_window)

    return labels, best_k, {"silhouette": best_score}
704
+
705
+
706
+ def _analyze_clusters(
707
+ cluster_labels: np.ndarray,
708
+ local_stars: np.ndarray,
709
+ note_density: np.ndarray,
710
+ avg_attention: Optional[np.ndarray] = None,
711
+ ) -> list[dict]:
712
+ """
713
+ Analyze properties of each cluster to create a profile.
714
+
715
+ Returns list of dicts: [{id, count, avg_stars, avg_density, avg_attn, desc}]
716
+ """
717
+ unique_labels = np.unique(cluster_labels)
718
+ profiles = []
719
+
720
+ for label in unique_labels:
721
+ mask = cluster_labels == label
722
+ count = mask.sum()
723
+
724
+ avg_s = local_stars[mask].mean() if len(local_stars) > 0 else 0
725
+ avg_d = note_density[mask].mean() if len(note_density) > 0 else 0
726
+ avg_a = avg_attention[mask].mean() if avg_attention is not None else 0
727
+
728
+ profiles.append(
729
+ {
730
+ "Cluster ID": int(label),
731
+ "Count": int(count),
732
+ "Avg Stars": float(f"{avg_s:.2f}"),
733
+ "Avg Density": float(f"{avg_d:.2f}"),
734
+ "Avg Attention": float(f"{avg_a:.4f}"),
735
+ }
736
+ )
737
+
738
+ # Sort by Avg Stars to make it intuitive (Cluster 0 = Easiest or Hardest?)
739
+ # Let's keep ID but maybe we can add a rank?
740
+ # Sorting purely by ID is safer for consistency with plot colors.
741
+ profiles.sort(key=lambda x: x["Cluster ID"])
742
+ return profiles
743
+
744
+
745
def _plot_clusters(
    times: list[tuple[float, float]],
    cluster_labels: np.ndarray,
    local_stars: np.ndarray,
    title: str,
):
    """Plot the difficulty timeline as a scatter colored by cluster ID."""
    starts = np.fromiter((a for a, _ in times), dtype=np.float64)
    ends = np.fromiter((b for _, b in times), dtype=np.float64)
    midpoints = (starts + ends) / 2.0

    # Sort windows by midpoint time.
    by_time = np.argsort(midpoints)
    xs = midpoints[by_time]
    ys = local_stars[by_time]
    ids = cluster_labels[by_time]

    cluster_ids = np.unique(ids)
    # Distinct qualitative palette; fall back to tab20 past 10 clusters.
    palette = plt.get_cmap("tab10" if len(cluster_ids) <= 10 else "tab20")

    fig, ax = plt.subplots(figsize=(10, 3.5))

    # One scatter series per cluster: each gets its own color + legend entry.
    for idx, cid in enumerate(cluster_ids):
        members = ids == cid
        ax.scatter(
            xs[members],
            ys[members],
            color=palette(idx),
            label=f"Cluster {cid}",
            s=20,
            alpha=0.8,
        )

    # Faint connecting line to show temporal continuity.
    ax.plot(xs, ys, color="gray", alpha=0.2, linewidth=1)

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Local Stars")
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
    ax.grid(True, alpha=0.25)

    fig.tight_layout()
    return fig
796
+
797
+
798
+ def _detect_segments(
799
+ local_stars: np.ndarray,
800
+ times: list[tuple[float, float]],
801
+ min_segment_size: int = 3,
802
+ penalty_scale: float = 1.0,
803
+ ) -> list[dict]:
804
+ """
805
+ Detect segments using Change Point Detection.
806
+
807
+ IMPORTANT: Windows may not be in temporal order (e.g., mixed window sizes).
808
+ We sort by midpoint time first to ensure temporal coherence.
809
+ """
810
+ n = len(local_stars)
811
+ if n < min_segment_size * 2:
812
+ return [
813
+ {
814
+ "start_time": times[0][0],
815
+ "end_time": times[-1][1],
816
+ "avg_stars": float(local_stars.mean()),
817
+ "n_windows": n,
818
+ }
819
+ ]
820
+
821
+ # Calculate window midpoints
822
+ mids = np.array([(t0 + t1) / 2 for t0, t1 in times])
823
+
824
+ # SORT by midpoint time (critical for temporal coherence!)
825
+ order = np.argsort(mids)
826
+ mids_sorted = mids[order]
827
+ stars_sorted = local_stars[order]
828
+ times_sorted = [times[i] for i in order]
829
+
830
+ # Build cell boundaries (1D Voronoi on SORTED windows)
831
+ cell_bounds = [times_sorted[0][0]] # Song start
832
+ for i in range(len(mids_sorted) - 1):
833
+ cell_bounds.append((mids_sorted[i] + mids_sorted[i + 1]) / 2)
834
+ cell_bounds.append(times_sorted[-1][1]) # Song end
835
+
836
+ # Ruptures detection (on SORTED data)
837
+ signal = stars_sorted.reshape(-1, 1)
838
+ penalty = np.var(stars_sorted) * penalty_scale
839
+ algo = rpt.Pelt(model="l2", min_size=min_segment_size).fit(signal)
840
+ change_points = algo.predict(pen=penalty)
841
+
842
+ # Build segments
843
+ segments = []
844
+ prev_idx = 0
845
+
846
+ for cp in change_points:
847
+ seg_stars = stars_sorted[prev_idx:cp]
848
+
849
+ start_t = cell_bounds[prev_idx]
850
+ end_t = cell_bounds[cp]
851
+
852
+ segments.append(
853
+ {
854
+ "start_time": float(start_t),
855
+ "end_time": float(end_t),
856
+ "avg_stars": float(seg_stars.mean()),
857
+ "n_windows": cp - prev_idx,
858
+ }
859
+ )
860
+ prev_idx = cp
861
+
862
+ return segments
863
+
864
+
865
def _plot_segments(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    segments: list[dict],
    title: str,
):
    """
    Plot local difficulty with segment backgrounds (non-overlapping).
    """
    starts = np.fromiter((a for a, _ in times), dtype=np.float64)
    ends = np.fromiter((b for _, b in times), dtype=np.float64)
    midpoints = (starts + ends) / 2.0

    by_time = np.argsort(midpoints)
    xs = midpoints[by_time]
    ys = local_stars[by_time]

    # Colormap: Red = hard, Green = easy.
    cmap = plt.get_cmap("RdYlGn_r")

    fig, ax = plt.subplots(figsize=(12, 4))

    # Normalize segment colors to the observed star range.
    if segments:
        hi = max(s["avg_stars"] for s in segments)
        lo = min(s["avg_stars"] for s in segments)
    else:
        hi, lo = 10, 0
    span = max(hi - lo, 1)

    for seg in segments:
        seg_color = cmap((seg["avg_stars"] - lo) / span)

        # Background tint for the segment (segments do not overlap).
        ax.axvspan(
            seg["start_time"], seg["end_time"], alpha=0.3, color=seg_color, linewidth=0
        )

        # Horizontal marker at the segment's average difficulty.
        ax.hlines(
            y=seg["avg_stars"],
            xmin=seg["start_time"],
            xmax=seg["end_time"],
            colors=seg_color,
            linewidth=3,
            alpha=0.9,
        )

        # Annotate only segments wide enough to fit a label (> 4 seconds).
        if seg["end_time"] - seg["start_time"] > 4:
            ax.text(
                (seg["start_time"] + seg["end_time"]) / 2,
                seg["avg_stars"] + 0.02,
                f"{seg['avg_stars']:.1f}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                color="black",
                alpha=0.8,
            )

    # Raw per-window scores drawn on top of the backgrounds.
    ax.plot(xs, ys, color="gray", alpha=0.4, linewidth=1)

    # Dashed boundary line at each internal segment edge.
    for seg in segments[1:]:
        ax.axvline(
            x=seg["start_time"], color="black", linewidth=1, linestyle="--", alpha=0.5
        )

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Raw Score")
    ax.set_title(title)
    ax.set_ylim(bottom=0, top=hi + 2)
    ax.grid(True, alpha=0.15, axis="y")

    fig.tight_layout()
    return fig
942
+
943
+
944
  def _plot_attention_concentration(
945
  avg_attention: np.ndarray,
946
  title: str,
 
1045
  branch_attn = attn.get("branch_attentions")
1046
  topk_mask = attn.get("topk_mask")
1047
 
1048
+ # Local Difficulty Estimation (Probe)
1049
+ # Use the predicted class ID if no hint was provided
1050
+ calib_diff_id = difficulty_hint
1051
+ if calib_diff_id is None:
1052
+ calib_diff_id = out.difficulty_logits.argmax(dim=-1, keepdim=True) # [1, 1]
1053
+
1054
+ local_raw, local_stars = model.get_instance_scores(
1055
+ out.instance_embeddings, difficulty_class_id=calib_diff_id.view(-1)
1056
+ )
1057
+
1058
  avg_attn_np = (
1059
  avg_attn[0, : counts.item()].detach().cpu().numpy()
1060
  if avg_attn is not None
 
1070
  if branch_attn is not None
1071
  else None
1072
  )
1073
+ local_stars_np = local_stars[0, : counts.item()].detach().cpu().numpy()
1074
+ local_raw_np = local_raw[0, : counts.item()].detach().cpu().numpy()
1075
 
1076
  # Plots
1077
  fig_attn = None
 
1095
  title="Attention concentration (how many windows dominate)",
1096
  )
1097
 
1098
+ fig_local_diff = None
1099
+ if local_stars_np is not None:
1100
+ fig_local_diff = _plot_local_difficulty(
1101
+ times,
1102
+ local_stars_np,
1103
+ token_counts,
1104
+ title=f"Estimated Local Difficulty Curve (Assuming {pred_class} calibration)",
1105
+ )
1106
+
1107
+ # Segment Detection (Piecewise Constant Change Point Detection)
1108
+ fig_segments = None
1109
+ segment_table_df = None
1110
+
1111
+ if local_raw_np is not None and len(times) > 0:
1112
+ segments = _detect_segments(
1113
+ local_raw_np, # Use raw score instead of stars
1114
+ times,
1115
+ min_segment_size=3,
1116
+ penalty_scale=0.5,
1117
+ )
1118
+
1119
+ # Create table rows
1120
+ seg_rows = []
1121
+ for i, seg in enumerate(segments):
1122
+ seg_rows.append(
1123
+ [
1124
+ i + 1,
1125
+ f"{seg['start_time']:.1f}",
1126
+ f"{seg['end_time']:.1f}",
1127
+ f"{seg['end_time'] - seg['start_time']:.1f}",
1128
+ f"{seg['avg_stars']:.1f}", # This is now avg_raw
1129
+ seg["n_windows"],
1130
+ ]
1131
+ )
1132
+ segment_table_df = seg_rows
1133
+
1134
+ fig_segments = _plot_segments(
1135
+ times,
1136
+ local_raw_np, # Use raw score
1137
+ segments,
1138
+ title=f"Chart Structure: {len(segments)} Segments Detected",
1139
+ )
1140
+
1141
+ # Meta/details
1142
  if branch_np is not None:
1143
  mids = np.array([(a + b) / 2.0 for a, b in times], dtype=np.float64)
1144
  order = np.argsort(mids)
 
1167
  int(token_counts[i]) if i < len(token_counts) else None,
1168
  float(avg_attn_np[i]) if avg_attn_np is not None else None,
1169
  int(topk_np[i]) if topk_np is not None else None,
1170
+ float(local_stars_np[i]) if i < len(local_stars_np) else None,
1171
  ]
1172
  )
1173
 
 
1233
  fig_conc,
1234
  top_md,
1235
  rows,
1236
+ fig_local_diff,
1237
+ fig_segments,
1238
+ segment_table_df,
1239
  )
1240
 
1241
 
 
1257
 
1258
  with gr.Blocks(title="TaikoChartEstimator Inference") as demo:
1259
  gr.Markdown("# TaikoChartEstimator - Inference")
 
 
 
 
 
 
 
 
 
 
 
 
 
1260
 
1261
  with gr.Row():
1262
+ # Left: Input (Upload/Paste with tabs)
1263
+ with gr.Column(scale=2):
1264
+ with gr.Tabs():
1265
+ with gr.TabItem("Upload"):
1266
+ tja_file = gr.File(label="Upload TJA file")
1267
+ with gr.TabItem("Paste"):
1268
+ tja_text = gr.Textbox(label="Paste TJA content", lines=12)
1269
 
1270
  course = gr.Dropdown(label="COURSE", choices=[], value=None)
1271
+ btn = gr.Button("Run Inference", variant="primary", size="lg")
1272
 
1273
+ # Right: Options
1274
+ with gr.Column(scale=1):
1275
+ gr.Markdown("### Options")
1276
  checkpoint = gr.Dropdown(
1277
  label="Checkpoint",
1278
  choices=checkpoints,
1279
  value=checkpoints[-1] if checkpoints else None,
1280
  allow_custom_value=True,
1281
  )
 
1282
  device = gr.Dropdown(
1283
  label="Device", choices=["cpu", "mps", "cuda"], value="cpu"
1284
  )
1285
 
1286
+ with gr.Accordion("Advanced", open=False):
1287
+ window_measures = gr.Textbox(
1288
+ label="window_measures (comma-separated)", value="2,4"
1289
+ )
1290
+ hop_measures = gr.Slider(
1291
+ label="hop_measures", minimum=1, maximum=8, value=2, step=1
1292
+ )
1293
+ max_instances = gr.Slider(
1294
+ label="max_instances", minimum=1, maximum=512, value=128, step=1
1295
+ )
 
1296
 
1297
+ with gr.Row():
1298
+ with gr.Column(scale=1):
1299
  summary = gr.Markdown()
 
 
 
 
 
1300
  top_segments = gr.Markdown()
1301
+ with gr.Column(scale=1):
1302
+ meta_json = gr.JSON(label="Metadata")
1303
+
1304
+ with gr.Tabs():
1305
+ with gr.TabItem("Chart Structure"):
1306
+ gr.Markdown("### Automatic Segment Detection")
1307
+ gr.Markdown(
1308
+ "Detects distinct sections based on difficulty changes (Piecewise Constant Model)."
1309
+ )
1310
+ plot_segments = gr.Plot(label="Detected Segments")
1311
+ segment_table = gr.Dataframe(
1312
  headers=[
1313
+ "#",
1314
+ "Start (s)",
1315
+ "End (s)",
1316
+ "Duration",
1317
+ "Avg Raw",
1318
+ "Windows",
1319
+ ],
1320
+ datatype=["number", "str", "str", "str", "str", "number"],
1321
+ label="Segment Details",
1322
+ )
1323
+ with gr.TabItem("Local Difficulty"):
1324
+ plot_local_diff = gr.Plot(label="Local Difficulty Curve")
1325
+ with gr.TabItem("Attention & Density"):
1326
+ plot_density = gr.Plot(label="Density vs Attention")
1327
+ with gr.TabItem("Attention Details"):
1328
+ plot_attn = gr.Plot(label="Raw Attention")
1329
+ with gr.TabItem("Heatmap"):
1330
+ plot_heat = gr.Plot(label="Branch Heatmap")
1331
+ with gr.TabItem("Concentration"):
1332
+ plot_conc = gr.Plot(label="Concentration")
1333
+ with gr.TabItem("Raw Data"):
1334
+ # headers needs to match rows
1335
+ df = gr.Dataframe(
1336
+ headers=[
1337
+ "id",
1338
+ "start",
1339
+ "end",
1340
+ "mid",
1341
+ "tokens",
1342
+ "attention",
1343
+ "is_topk",
1344
+ "local_stars",
1345
  ],
1346
  datatype=[
1347
  "number",
 
1351
  "number",
1352
  "number",
1353
  "number",
1354
+ "number",
1355
  ],
 
 
1356
  )
1357
 
1358
  # Auto-refresh COURSE choices when input changes
 
1363
  _update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
1364
  )
1365
 
1366
+ btn.click(
1367
  run_inference,
1368
  inputs=[
1369
  tja_file,
 
1378
  outputs=[
1379
  summary,
1380
  meta_json,
1381
+ plot_attn,
1382
+ plot_density,
1383
+ plot_heat,
1384
+ plot_conc,
1385
  top_segments,
1386
+ df,
1387
+ plot_local_diff,
1388
+ plot_segments,
1389
+ segment_table,
1390
  ],
1391
  )
1392
 
requirements.txt CHANGED
@@ -64,6 +64,7 @@ pytz==2025.2
64
  pyyaml==6.0.3
65
  requests==2.32.5
66
  rich==14.2.0
 
67
  safehttpx==0.1.7
68
  safetensors==0.7.0
69
  scikit-learn==1.8.0
 
64
  pyyaml==6.0.3
65
  requests==2.32.5
66
  rich==14.2.0
67
+ ruptures==1.1.10
68
  safehttpx==0.1.7
69
  safetensors==0.7.0
70
  scikit-learn==1.8.0