Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +578 -53
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -6,7 +6,10 @@ from typing import Any, Optional
|
|
| 6 |
import gradio as gr
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import numpy as np
|
|
|
|
| 9 |
import torch
|
|
|
|
|
|
|
| 10 |
|
| 11 |
from TaikoChartEstimator.data.tokenizer import EventTokenizer
|
| 12 |
from TaikoChartEstimator.model.model import TaikoChartEstimator
|
|
@@ -499,6 +502,445 @@ def _plot_density_and_attention(
|
|
| 499 |
return fig
|
| 500 |
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
def _plot_attention_concentration(
|
| 503 |
avg_attention: np.ndarray,
|
| 504 |
title: str,
|
|
@@ -603,6 +1045,16 @@ def run_inference(
|
|
| 603 |
branch_attn = attn.get("branch_attentions")
|
| 604 |
topk_mask = attn.get("topk_mask")
|
| 605 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
avg_attn_np = (
|
| 607 |
avg_attn[0, : counts.item()].detach().cpu().numpy()
|
| 608 |
if avg_attn is not None
|
|
@@ -618,6 +1070,8 @@ def run_inference(
|
|
| 618 |
if branch_attn is not None
|
| 619 |
else None
|
| 620 |
)
|
|
|
|
|
|
|
| 621 |
|
| 622 |
# Plots
|
| 623 |
fig_attn = None
|
|
@@ -641,7 +1095,50 @@ def run_inference(
|
|
| 641 |
title="Attention concentration (how many windows dominate)",
|
| 642 |
)
|
| 643 |
|
| 644 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
if branch_np is not None:
|
| 646 |
mids = np.array([(a + b) / 2.0 for a, b in times], dtype=np.float64)
|
| 647 |
order = np.argsort(mids)
|
|
@@ -670,6 +1167,7 @@ def run_inference(
|
|
| 670 |
int(token_counts[i]) if i < len(token_counts) else None,
|
| 671 |
float(avg_attn_np[i]) if avg_attn_np is not None else None,
|
| 672 |
int(topk_np[i]) if topk_np is not None else None,
|
|
|
|
| 673 |
]
|
| 674 |
)
|
| 675 |
|
|
@@ -735,6 +1233,9 @@ def run_inference(
|
|
| 735 |
fig_conc,
|
| 736 |
top_md,
|
| 737 |
rows,
|
|
|
|
|
|
|
|
|
|
| 738 |
)
|
| 739 |
|
| 740 |
|
|
@@ -756,69 +1257,91 @@ def build_app() -> gr.Blocks:
|
|
| 756 |
|
| 757 |
with gr.Blocks(title="TaikoChartEstimator Inference") as demo:
|
| 758 |
gr.Markdown("# TaikoChartEstimator - Inference")
|
| 759 |
-
gr.Markdown(
|
| 760 |
-
"""
|
| 761 |
-
## How to Read Visualizations
|
| 762 |
-
|
| 763 |
-
- The model splits the chart into multiple **windows (instances)** and aggregates them using MIL (Multiple Instance Learning) for a prediction.
|
| 764 |
-
- `Avg attention` is the importance weight of this window for the final judgment; it is typically normalized by softmax within a single chart, so the values are usually small.
|
| 765 |
-
- `Top-k` is another Top-K pooling branch that selects windows that "look most like peak difficulty points"; they do not necessarily overlap perfectly with attention peaks.
|
| 766 |
-
|
| 767 |
-
Recommended combinations:
|
| 768 |
-
- `Token density vs attention`: Check if high-density segments are simultaneously emphasized.
|
| 769 |
-
- `Attention concentration`: Check if the model relies on only a few windows (closer to 1 means more concentrated).
|
| 770 |
-
"""
|
| 771 |
-
)
|
| 772 |
|
| 773 |
with gr.Row():
|
| 774 |
-
with
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
|
|
|
|
|
|
| 779 |
|
| 780 |
course = gr.Dropdown(label="COURSE", choices=[], value=None)
|
|
|
|
| 781 |
|
|
|
|
|
|
|
|
|
|
| 782 |
checkpoint = gr.Dropdown(
|
| 783 |
label="Checkpoint",
|
| 784 |
choices=checkpoints,
|
| 785 |
value=checkpoints[-1] if checkpoints else None,
|
| 786 |
allow_custom_value=True,
|
| 787 |
)
|
| 788 |
-
|
| 789 |
device = gr.Dropdown(
|
| 790 |
label="Device", choices=["cpu", "mps", "cuda"], value="cpu"
|
| 791 |
)
|
| 792 |
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
run_btn = gr.Button("Run inference", variant="primary")
|
| 804 |
|
| 805 |
-
|
|
|
|
| 806 |
summary = gr.Markdown()
|
| 807 |
-
meta_json = gr.JSON(label="Details")
|
| 808 |
-
attn_plot = gr.Plot(label="Attention (time-sorted)")
|
| 809 |
-
density_plot = gr.Plot(label="Token density vs attention")
|
| 810 |
-
heat_plot = gr.Plot(label="Branch attention heatmap")
|
| 811 |
-
conc_plot = gr.Plot(label="Attention concentration")
|
| 812 |
top_segments = gr.Markdown()
|
| 813 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
headers=[
|
| 815 |
-
"
|
| 816 |
-
"
|
| 817 |
-
"
|
| 818 |
-
"
|
| 819 |
-
"
|
| 820 |
-
"
|
| 821 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 822 |
],
|
| 823 |
datatype=[
|
| 824 |
"number",
|
|
@@ -828,9 +1351,8 @@ Recommended combinations:
|
|
| 828 |
"number",
|
| 829 |
"number",
|
| 830 |
"number",
|
|
|
|
| 831 |
],
|
| 832 |
-
label="Per-instance details",
|
| 833 |
-
wrap=True,
|
| 834 |
)
|
| 835 |
|
| 836 |
# Auto-refresh COURSE choices when input changes
|
|
@@ -841,7 +1363,7 @@ Recommended combinations:
|
|
| 841 |
_update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
|
| 842 |
)
|
| 843 |
|
| 844 |
-
|
| 845 |
run_inference,
|
| 846 |
inputs=[
|
| 847 |
tja_file,
|
|
@@ -856,12 +1378,15 @@ Recommended combinations:
|
|
| 856 |
outputs=[
|
| 857 |
summary,
|
| 858 |
meta_json,
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
top_segments,
|
| 864 |
-
|
|
|
|
|
|
|
|
|
|
| 865 |
],
|
| 866 |
)
|
| 867 |
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import numpy as np
|
| 9 |
+
import ruptures as rpt
|
| 10 |
import torch
|
| 11 |
+
from sklearn.cluster import KMeans
|
| 12 |
+
from sklearn.metrics import silhouette_score
|
| 13 |
|
| 14 |
from TaikoChartEstimator.data.tokenizer import EventTokenizer
|
| 15 |
from TaikoChartEstimator.model.model import TaikoChartEstimator
|
|
|
|
| 502 |
return fig
|
| 503 |
|
| 504 |
|
| 505 |
+
def _plot_local_difficulty(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    token_counts: list[int],
    title: str,
):
    """Plot the estimated local difficulty (star rating) curve over time.

    Draws the raw per-window stars (faint), an EMA-smoothed curve (bold),
    and the note density on a secondary axis for context.

    Args:
        times: Per-window (start, end) timestamps in seconds.
        local_stars: Per-window estimated star rating, aligned with ``times``.
        token_counts: Per-window token counts; extra entries are ignored and
            missing ones are treated as 0 so a length mismatch cannot crash.
        title: Plot title.

    Returns:
        The matplotlib Figure.
    """
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0
    durations = np.maximum(t1 - t0, 1e-6)  # guard against zero-length windows

    # Align token counts with the window list: truncate extras, pad missing
    # entries with 0 (the original sliced only one direction and could raise
    # a broadcast error when token_counts was shorter than times).
    token_counts_np = np.zeros(len(times), dtype=np.float64)
    usable = min(len(token_counts), len(times))
    token_counts_np[:usable] = np.asarray(token_counts[:usable], dtype=np.float64)
    density = token_counts_np / durations

    order = np.argsort(mids)
    mids_s = mids[order]
    stars_s = local_stars[order]
    dens_s = density[order]

    # EMA smoothing. alpha = 2 / (span + 1); alpha = 0.3 corresponds to a
    # span of roughly 5.7 windows.
    alpha = 0.3
    if len(stars_s) > 0:
        stars_smooth = np.zeros_like(stars_s)
        stars_smooth[0] = stars_s[0]
        for i in range(1, len(stars_s)):
            stars_smooth[i] = alpha * stars_s[i] + (1 - alpha) * stars_smooth[i - 1]
    else:
        stars_smooth = stars_s

    fig, ax1 = plt.subplots(figsize=(10, 3.5))

    # Difficulty curve (primary axis).
    color = "tab:red"
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Estimated Local Stars", color=color)

    # Raw series faint, smoothed series bold.
    ax1.plot(mids_s, stars_s, color=color, linewidth=1, alpha=0.3, label="Raw")
    ax1.plot(mids_s, stars_smooth, color=color, linewidth=2.5, label="Smoothed (EMA)")

    ax1.tick_params(axis="y", labelcolor=color)
    ax1.grid(True, alpha=0.25)

    # Fill area under the smoothed curve for emphasis.
    ax1.fill_between(mids_s, stars_smooth, alpha=0.1, color=color)

    # Note density for context on a secondary axis.
    ax2 = ax1.twinx()
    color2 = "tab:blue"
    ax2.set_ylabel("Density (notes/s)", color=color2)
    ax2.plot(
        mids_s,
        dens_s,
        color=color2,
        linewidth=1,
        linestyle="--",
        alpha=0.5,
        label="Note Density",
    )
    ax2.tick_params(axis="y", labelcolor=color2)

    ax1.set_title(title)

    # Merge the legends of both axes into a single box.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right")

    fig.tight_layout()
    return fig
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def _smooth_embeddings(embeddings: np.ndarray, window_size: int = 3) -> np.ndarray:
|
| 580 |
+
"""Apply temporal smoothing (moving average) to embeddings."""
|
| 581 |
+
if len(embeddings) < window_size:
|
| 582 |
+
return embeddings
|
| 583 |
+
|
| 584 |
+
# Kernel for simple moving average
|
| 585 |
+
kernel = np.ones(window_size) / window_size
|
| 586 |
+
|
| 587 |
+
# Apply to each dimension independenty
|
| 588 |
+
# We can use scipy.ndimage.convolve1d or simplified numpy for dependency-free
|
| 589 |
+
smoothed = np.zeros_like(embeddings)
|
| 590 |
+
for dim in range(embeddings.shape[1]):
|
| 591 |
+
# Padding: 'edge' mode equivalent
|
| 592 |
+
x = embeddings[:, dim]
|
| 593 |
+
pad_width = window_size // 2
|
| 594 |
+
padded = np.pad(x, pad_width, mode="edge")
|
| 595 |
+
|
| 596 |
+
# Convolve
|
| 597 |
+
s = np.convolve(padded, kernel, mode="valid")
|
| 598 |
+
|
| 599 |
+
# Handle shape mismatch due to even/odd window
|
| 600 |
+
if len(s) > len(x):
|
| 601 |
+
s = s[: len(x)]
|
| 602 |
+
elif len(s) < len(x):
|
| 603 |
+
# Should not happen with padded='edge' widely enough but just in case
|
| 604 |
+
s = np.pad(s, (0, len(x) - len(s)), mode="edge")
|
| 605 |
+
|
| 606 |
+
smoothed[:, dim] = s
|
| 607 |
+
|
| 608 |
+
return smoothed
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def _smooth_labels(labels: np.ndarray, window_size: int = 3) -> np.ndarray:
|
| 612 |
+
"""Apply mode filter to labels to enforce temporal continuity."""
|
| 613 |
+
if len(labels) < window_size:
|
| 614 |
+
return labels
|
| 615 |
+
|
| 616 |
+
n = len(labels)
|
| 617 |
+
smoothed = labels.copy()
|
| 618 |
+
pad = window_size // 2
|
| 619 |
+
|
| 620 |
+
# Simple sliding window mode
|
| 621 |
+
for i in range(n):
|
| 622 |
+
start = max(0, i - pad)
|
| 623 |
+
end = min(n, i + pad + 1)
|
| 624 |
+
window = labels[start:end]
|
| 625 |
+
|
| 626 |
+
# Find mode
|
| 627 |
+
counts = np.bincount(window)
|
| 628 |
+
smoothed[i] = np.argmax(counts)
|
| 629 |
+
|
| 630 |
+
return smoothed
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def _perform_clustering(
|
| 634 |
+
embeddings: np.ndarray,
|
| 635 |
+
min_k: int = 3,
|
| 636 |
+
max_k: int = 8,
|
| 637 |
+
smoothing_window: int = 3,
|
| 638 |
+
label_smoothing_window: int = 3,
|
| 639 |
+
random_state: int = 42,
|
| 640 |
+
) -> tuple[np.ndarray, int, dict]:
|
| 641 |
+
"""
|
| 642 |
+
Perform K-Means clustering with automatic K selection using Silhouette Score.
|
| 643 |
+
Applying temporal smoothing to stabilize clusters.
|
| 644 |
+
|
| 645 |
+
Args:
|
| 646 |
+
embeddings: [N, D] data points
|
| 647 |
+
min_k: Minimum number of clusters
|
| 648 |
+
max_k: Maximum number of clusters
|
| 649 |
+
|
| 650 |
+
Returns:
|
| 651 |
+
labels: [N] cluster labels
|
| 652 |
+
best_k: Selected number of clusters
|
| 653 |
+
stats: Info about clustering quality
|
| 654 |
+
"""
|
| 655 |
+
# Simply if N is too small
|
| 656 |
+
N = embeddings.shape[0]
|
| 657 |
+
if N < min_k:
|
| 658 |
+
return np.zeros(N, dtype=int), 1, {"score": 0.0}
|
| 659 |
+
|
| 660 |
+
# 1. Temporal Smoothing
|
| 661 |
+
if smoothing_window > 1:
|
| 662 |
+
# print(f"Smoothing embeddings with window={smoothing_window}")
|
| 663 |
+
work_embeddings = _smooth_embeddings(embeddings, window_size=smoothing_window)
|
| 664 |
+
else:
|
| 665 |
+
work_embeddings = embeddings
|
| 666 |
+
|
| 667 |
+
best_score = -1.0
|
| 668 |
+
best_k = min_k
|
| 669 |
+
best_model = None
|
| 670 |
+
|
| 671 |
+
print(f"Clustering {N} instances...")
|
| 672 |
+
|
| 673 |
+
effective_max_k = min(max_k, N - 1)
|
| 674 |
+
if effective_max_k < min_k:
|
| 675 |
+
effective_max_k = min_k
|
| 676 |
+
|
| 677 |
+
for k in range(min_k, effective_max_k + 1):
|
| 678 |
+
kmeans = KMeans(n_clusters=k, random_state=random_state, n_init=10)
|
| 679 |
+
labels = kmeans.fit_predict(work_embeddings)
|
| 680 |
+
try:
|
| 681 |
+
score = silhouette_score(work_embeddings, labels)
|
| 682 |
+
# print(f"K={k}, Silhouette={score:.4f}")
|
| 683 |
+
if score > best_score:
|
| 684 |
+
best_score = score
|
| 685 |
+
best_k = k
|
| 686 |
+
best_model = kmeans
|
| 687 |
+
except Exception:
|
| 688 |
+
pass
|
| 689 |
+
|
| 690 |
+
if best_model is None:
|
| 691 |
+
# Fallback
|
| 692 |
+
kmeans = KMeans(n_clusters=min_k, random_state=random_state, n_init=10)
|
| 693 |
+
kmeans.fit(work_embeddings)
|
| 694 |
+
best_model = kmeans
|
| 695 |
+
best_k = min_k
|
| 696 |
+
|
| 697 |
+
labels = best_model.labels_
|
| 698 |
+
|
| 699 |
+
# 2. Label Smoothing (Post-processing)
|
| 700 |
+
if label_smoothing_window > 1:
|
| 701 |
+
labels = _smooth_labels(labels, window_size=label_smoothing_window)
|
| 702 |
+
|
| 703 |
+
return labels, best_k, {"silhouette": best_score}
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def _analyze_clusters(
|
| 707 |
+
cluster_labels: np.ndarray,
|
| 708 |
+
local_stars: np.ndarray,
|
| 709 |
+
note_density: np.ndarray,
|
| 710 |
+
avg_attention: Optional[np.ndarray] = None,
|
| 711 |
+
) -> list[dict]:
|
| 712 |
+
"""
|
| 713 |
+
Analyze properties of each cluster to create a profile.
|
| 714 |
+
|
| 715 |
+
Returns list of dicts: [{id, count, avg_stars, avg_density, avg_attn, desc}]
|
| 716 |
+
"""
|
| 717 |
+
unique_labels = np.unique(cluster_labels)
|
| 718 |
+
profiles = []
|
| 719 |
+
|
| 720 |
+
for label in unique_labels:
|
| 721 |
+
mask = cluster_labels == label
|
| 722 |
+
count = mask.sum()
|
| 723 |
+
|
| 724 |
+
avg_s = local_stars[mask].mean() if len(local_stars) > 0 else 0
|
| 725 |
+
avg_d = note_density[mask].mean() if len(note_density) > 0 else 0
|
| 726 |
+
avg_a = avg_attention[mask].mean() if avg_attention is not None else 0
|
| 727 |
+
|
| 728 |
+
profiles.append(
|
| 729 |
+
{
|
| 730 |
+
"Cluster ID": int(label),
|
| 731 |
+
"Count": int(count),
|
| 732 |
+
"Avg Stars": float(f"{avg_s:.2f}"),
|
| 733 |
+
"Avg Density": float(f"{avg_d:.2f}"),
|
| 734 |
+
"Avg Attention": float(f"{avg_a:.4f}"),
|
| 735 |
+
}
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
# Sort by Avg Stars to make it intuitive (Cluster 0 = Easiest or Hardest?)
|
| 739 |
+
# Let's keep ID but maybe we can add a rank?
|
| 740 |
+
# Sorting purely by ID is safer for consistency with plot colors.
|
| 741 |
+
profiles.sort(key=lambda x: x["Cluster ID"])
|
| 742 |
+
return profiles
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def _plot_clusters(
    times: list[tuple[float, float]],
    cluster_labels: np.ndarray,
    local_stars: np.ndarray,
    title: str,
):
    """Plot the local-difficulty timeline as a scatter colored by cluster id.

    Windows are sorted by their midpoint time; each cluster gets its own
    color/legend entry, with a faint gray line showing temporal connectivity.
    """
    starts = np.array([a for a, _ in times], dtype=np.float64)
    ends = np.array([b for _, b in times], dtype=np.float64)
    midpoints = (starts + ends) / 2.0

    # Time-sort everything so the connecting line is drawn left to right.
    order = np.argsort(midpoints)
    xs = midpoints[order]
    ys = local_stars[order]
    ids = cluster_labels[order]

    distinct = np.unique(ids)
    # Distinct qualitative colormap; tab20 when more than 10 clusters.
    cmap = plt.get_cmap("tab10" if len(distinct) <= 10 else "tab20")

    fig, ax = plt.subplots(figsize=(10, 3.5))

    # One scatter layer per cluster so each gets a stable color and legend.
    for color_idx, cluster_id in enumerate(distinct):
        selected = ids == cluster_id
        ax.scatter(
            xs[selected],
            ys[selected],
            color=cmap(color_idx),
            label=f"Cluster {cluster_id}",
            s=20,
            alpha=0.8,
        )

    # Faint line showing temporal connectivity between windows.
    ax.plot(xs, ys, color="gray", alpha=0.2, linewidth=1)

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Local Stars")
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
    ax.grid(True, alpha=0.25)

    fig.tight_layout()
    return fig
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
def _detect_segments(
|
| 799 |
+
local_stars: np.ndarray,
|
| 800 |
+
times: list[tuple[float, float]],
|
| 801 |
+
min_segment_size: int = 3,
|
| 802 |
+
penalty_scale: float = 1.0,
|
| 803 |
+
) -> list[dict]:
|
| 804 |
+
"""
|
| 805 |
+
Detect segments using Change Point Detection.
|
| 806 |
+
|
| 807 |
+
IMPORTANT: Windows may not be in temporal order (e.g., mixed window sizes).
|
| 808 |
+
We sort by midpoint time first to ensure temporal coherence.
|
| 809 |
+
"""
|
| 810 |
+
n = len(local_stars)
|
| 811 |
+
if n < min_segment_size * 2:
|
| 812 |
+
return [
|
| 813 |
+
{
|
| 814 |
+
"start_time": times[0][0],
|
| 815 |
+
"end_time": times[-1][1],
|
| 816 |
+
"avg_stars": float(local_stars.mean()),
|
| 817 |
+
"n_windows": n,
|
| 818 |
+
}
|
| 819 |
+
]
|
| 820 |
+
|
| 821 |
+
# Calculate window midpoints
|
| 822 |
+
mids = np.array([(t0 + t1) / 2 for t0, t1 in times])
|
| 823 |
+
|
| 824 |
+
# SORT by midpoint time (critical for temporal coherence!)
|
| 825 |
+
order = np.argsort(mids)
|
| 826 |
+
mids_sorted = mids[order]
|
| 827 |
+
stars_sorted = local_stars[order]
|
| 828 |
+
times_sorted = [times[i] for i in order]
|
| 829 |
+
|
| 830 |
+
# Build cell boundaries (1D Voronoi on SORTED windows)
|
| 831 |
+
cell_bounds = [times_sorted[0][0]] # Song start
|
| 832 |
+
for i in range(len(mids_sorted) - 1):
|
| 833 |
+
cell_bounds.append((mids_sorted[i] + mids_sorted[i + 1]) / 2)
|
| 834 |
+
cell_bounds.append(times_sorted[-1][1]) # Song end
|
| 835 |
+
|
| 836 |
+
# Ruptures detection (on SORTED data)
|
| 837 |
+
signal = stars_sorted.reshape(-1, 1)
|
| 838 |
+
penalty = np.var(stars_sorted) * penalty_scale
|
| 839 |
+
algo = rpt.Pelt(model="l2", min_size=min_segment_size).fit(signal)
|
| 840 |
+
change_points = algo.predict(pen=penalty)
|
| 841 |
+
|
| 842 |
+
# Build segments
|
| 843 |
+
segments = []
|
| 844 |
+
prev_idx = 0
|
| 845 |
+
|
| 846 |
+
for cp in change_points:
|
| 847 |
+
seg_stars = stars_sorted[prev_idx:cp]
|
| 848 |
+
|
| 849 |
+
start_t = cell_bounds[prev_idx]
|
| 850 |
+
end_t = cell_bounds[cp]
|
| 851 |
+
|
| 852 |
+
segments.append(
|
| 853 |
+
{
|
| 854 |
+
"start_time": float(start_t),
|
| 855 |
+
"end_time": float(end_t),
|
| 856 |
+
"avg_stars": float(seg_stars.mean()),
|
| 857 |
+
"n_windows": cp - prev_idx,
|
| 858 |
+
}
|
| 859 |
+
)
|
| 860 |
+
prev_idx = cp
|
| 861 |
+
|
| 862 |
+
return segments
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def _plot_segments(
    times: list[tuple[float, float]],
    local_stars: np.ndarray,
    segments: list[dict],
    title: str,
):
    """
    Plot local difficulty with segment backgrounds (non-overlapping).

    Each detected segment is drawn as a colored background span
    (red = hard, green = easy) with a horizontal line at its average score;
    the raw time-sorted difficulty curve is overlaid, and segment boundaries
    get dashed vertical lines.

    Args:
        times: Per-window (start, end) timestamps in seconds.
        local_stars: Per-window scores, aligned with ``times``.
        segments: Output of ``_detect_segments`` — dicts with "start_time",
            "end_time", "avg_stars".
        title: Plot title.

    Returns:
        The matplotlib Figure.
    """
    t0 = np.array([a for a, _ in times], dtype=np.float64)
    t1 = np.array([b for _, b in times], dtype=np.float64)
    mids = (t0 + t1) / 2.0

    # Sort windows by midpoint so the raw curve is drawn left to right.
    order = np.argsort(mids)
    mids_s = mids[order]
    stars_s = local_stars[order]

    # Colormap: Red=Hard, Green=Easy
    cmap = plt.get_cmap("RdYlGn_r")

    fig, ax = plt.subplots(figsize=(12, 4))

    # Normalize colors over the observed range of segment averages.
    max_star = max(s["avg_stars"] for s in segments) if segments else 10
    min_star = min(s["avg_stars"] for s in segments) if segments else 0
    star_range = max(max_star - min_star, 1)  # avoid zero division on flat charts

    # Draw segment backgrounds (should NOT overlap now)
    for seg in segments:
        color = cmap((seg["avg_stars"] - min_star) / star_range)
        ax.axvspan(
            seg["start_time"], seg["end_time"], alpha=0.3, color=color, linewidth=0
        )

        # Horizontal line at segment average
        ax.hlines(
            y=seg["avg_stars"],
            xmin=seg["start_time"],
            xmax=seg["end_time"],
            colors=color,
            linewidth=3,
            alpha=0.9,
        )

        # Label (only if segment is wide enough)
        duration = seg["end_time"] - seg["start_time"]
        if duration > 4:  # Only label if > 4 seconds
            mid_x = (seg["start_time"] + seg["end_time"]) / 2
            ax.text(
                mid_x,
                seg["avg_stars"] + 0.02,
                f"{seg['avg_stars']:.1f}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                color="black",
                alpha=0.8,
            )

    # Raw data on top
    ax.plot(mids_s, stars_s, color="gray", alpha=0.4, linewidth=1)

    # Boundary lines
    for seg in segments[1:]:
        ax.axvline(
            x=seg["start_time"], color="black", linewidth=1, linestyle="--", alpha=0.5
        )

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Raw Score")
    ax.set_title(title)
    ax.set_ylim(bottom=0, top=max_star + 2)
    ax.grid(True, alpha=0.15, axis="y")

    fig.tight_layout()
    return fig
|
| 942 |
+
|
| 943 |
+
|
| 944 |
def _plot_attention_concentration(
|
| 945 |
avg_attention: np.ndarray,
|
| 946 |
title: str,
|
|
|
|
| 1045 |
branch_attn = attn.get("branch_attentions")
|
| 1046 |
topk_mask = attn.get("topk_mask")
|
| 1047 |
|
| 1048 |
+
# Local Difficulty Estimation (Probe)
|
| 1049 |
+
# Use the predicted class ID if no hint was provided
|
| 1050 |
+
calib_diff_id = difficulty_hint
|
| 1051 |
+
if calib_diff_id is None:
|
| 1052 |
+
calib_diff_id = out.difficulty_logits.argmax(dim=-1, keepdim=True) # [1, 1]
|
| 1053 |
+
|
| 1054 |
+
local_raw, local_stars = model.get_instance_scores(
|
| 1055 |
+
out.instance_embeddings, difficulty_class_id=calib_diff_id.view(-1)
|
| 1056 |
+
)
|
| 1057 |
+
|
| 1058 |
avg_attn_np = (
|
| 1059 |
avg_attn[0, : counts.item()].detach().cpu().numpy()
|
| 1060 |
if avg_attn is not None
|
|
|
|
| 1070 |
if branch_attn is not None
|
| 1071 |
else None
|
| 1072 |
)
|
| 1073 |
+
local_stars_np = local_stars[0, : counts.item()].detach().cpu().numpy()
|
| 1074 |
+
local_raw_np = local_raw[0, : counts.item()].detach().cpu().numpy()
|
| 1075 |
|
| 1076 |
# Plots
|
| 1077 |
fig_attn = None
|
|
|
|
| 1095 |
title="Attention concentration (how many windows dominate)",
|
| 1096 |
)
|
| 1097 |
|
| 1098 |
+
fig_local_diff = None
|
| 1099 |
+
if local_stars_np is not None:
|
| 1100 |
+
fig_local_diff = _plot_local_difficulty(
|
| 1101 |
+
times,
|
| 1102 |
+
local_stars_np,
|
| 1103 |
+
token_counts,
|
| 1104 |
+
title=f"Estimated Local Difficulty Curve (Assuming {pred_class} calibration)",
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
# Segment Detection (Piecewise Constant Change Point Detection)
|
| 1108 |
+
fig_segments = None
|
| 1109 |
+
segment_table_df = None
|
| 1110 |
+
|
| 1111 |
+
if local_raw_np is not None and len(times) > 0:
|
| 1112 |
+
segments = _detect_segments(
|
| 1113 |
+
local_raw_np, # Use raw score instead of stars
|
| 1114 |
+
times,
|
| 1115 |
+
min_segment_size=3,
|
| 1116 |
+
penalty_scale=0.5,
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
# Create table rows
|
| 1120 |
+
seg_rows = []
|
| 1121 |
+
for i, seg in enumerate(segments):
|
| 1122 |
+
seg_rows.append(
|
| 1123 |
+
[
|
| 1124 |
+
i + 1,
|
| 1125 |
+
f"{seg['start_time']:.1f}",
|
| 1126 |
+
f"{seg['end_time']:.1f}",
|
| 1127 |
+
f"{seg['end_time'] - seg['start_time']:.1f}",
|
| 1128 |
+
f"{seg['avg_stars']:.1f}", # This is now avg_raw
|
| 1129 |
+
seg["n_windows"],
|
| 1130 |
+
]
|
| 1131 |
+
)
|
| 1132 |
+
segment_table_df = seg_rows
|
| 1133 |
+
|
| 1134 |
+
fig_segments = _plot_segments(
|
| 1135 |
+
times,
|
| 1136 |
+
local_raw_np, # Use raw score
|
| 1137 |
+
segments,
|
| 1138 |
+
title=f"Chart Structure: {len(segments)} Segments Detected",
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
# Meta/details
|
| 1142 |
if branch_np is not None:
|
| 1143 |
mids = np.array([(a + b) / 2.0 for a, b in times], dtype=np.float64)
|
| 1144 |
order = np.argsort(mids)
|
|
|
|
| 1167 |
int(token_counts[i]) if i < len(token_counts) else None,
|
| 1168 |
float(avg_attn_np[i]) if avg_attn_np is not None else None,
|
| 1169 |
int(topk_np[i]) if topk_np is not None else None,
|
| 1170 |
+
float(local_stars_np[i]) if i < len(local_stars_np) else None,
|
| 1171 |
]
|
| 1172 |
)
|
| 1173 |
|
|
|
|
| 1233 |
fig_conc,
|
| 1234 |
top_md,
|
| 1235 |
rows,
|
| 1236 |
+
fig_local_diff,
|
| 1237 |
+
fig_segments,
|
| 1238 |
+
segment_table_df,
|
| 1239 |
)
|
| 1240 |
|
| 1241 |
|
|
|
|
| 1257 |
|
| 1258 |
with gr.Blocks(title="TaikoChartEstimator Inference") as demo:
|
| 1259 |
gr.Markdown("# TaikoChartEstimator - Inference")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1260 |
|
| 1261 |
with gr.Row():
|
| 1262 |
+
# Left: Input (Upload/Paste with tabs)
|
| 1263 |
+
with gr.Column(scale=2):
|
| 1264 |
+
with gr.Tabs():
|
| 1265 |
+
with gr.TabItem("Upload"):
|
| 1266 |
+
tja_file = gr.File(label="Upload TJA file")
|
| 1267 |
+
with gr.TabItem("Paste"):
|
| 1268 |
+
tja_text = gr.Textbox(label="Paste TJA content", lines=12)
|
| 1269 |
|
| 1270 |
course = gr.Dropdown(label="COURSE", choices=[], value=None)
|
| 1271 |
+
btn = gr.Button("Run Inference", variant="primary", size="lg")
|
| 1272 |
|
| 1273 |
+
# Right: Options
|
| 1274 |
+
with gr.Column(scale=1):
|
| 1275 |
+
gr.Markdown("### Options")
|
| 1276 |
checkpoint = gr.Dropdown(
|
| 1277 |
label="Checkpoint",
|
| 1278 |
choices=checkpoints,
|
| 1279 |
value=checkpoints[-1] if checkpoints else None,
|
| 1280 |
allow_custom_value=True,
|
| 1281 |
)
|
|
|
|
| 1282 |
device = gr.Dropdown(
|
| 1283 |
label="Device", choices=["cpu", "mps", "cuda"], value="cpu"
|
| 1284 |
)
|
| 1285 |
|
| 1286 |
+
with gr.Accordion("Advanced", open=False):
|
| 1287 |
+
window_measures = gr.Textbox(
|
| 1288 |
+
label="window_measures (comma-separated)", value="2,4"
|
| 1289 |
+
)
|
| 1290 |
+
hop_measures = gr.Slider(
|
| 1291 |
+
label="hop_measures", minimum=1, maximum=8, value=2, step=1
|
| 1292 |
+
)
|
| 1293 |
+
max_instances = gr.Slider(
|
| 1294 |
+
label="max_instances", minimum=1, maximum=512, value=128, step=1
|
| 1295 |
+
)
|
|
|
|
| 1296 |
|
| 1297 |
+
with gr.Row():
|
| 1298 |
+
with gr.Column(scale=1):
|
| 1299 |
summary = gr.Markdown()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1300 |
top_segments = gr.Markdown()
|
| 1301 |
+
with gr.Column(scale=1):
|
| 1302 |
+
meta_json = gr.JSON(label="Metadata")
|
| 1303 |
+
|
| 1304 |
+
with gr.Tabs():
|
| 1305 |
+
with gr.TabItem("Chart Structure"):
|
| 1306 |
+
gr.Markdown("### Automatic Segment Detection")
|
| 1307 |
+
gr.Markdown(
|
| 1308 |
+
"Detects distinct sections based on difficulty changes (Piecewise Constant Model)."
|
| 1309 |
+
)
|
| 1310 |
+
plot_segments = gr.Plot(label="Detected Segments")
|
| 1311 |
+
segment_table = gr.Dataframe(
|
| 1312 |
headers=[
|
| 1313 |
+
"#",
|
| 1314 |
+
"Start (s)",
|
| 1315 |
+
"End (s)",
|
| 1316 |
+
"Duration",
|
| 1317 |
+
"Avg Raw",
|
| 1318 |
+
"Windows",
|
| 1319 |
+
],
|
| 1320 |
+
datatype=["number", "str", "str", "str", "str", "number"],
|
| 1321 |
+
label="Segment Details",
|
| 1322 |
+
)
|
| 1323 |
+
with gr.TabItem("Local Difficulty"):
|
| 1324 |
+
plot_local_diff = gr.Plot(label="Local Difficulty Curve")
|
| 1325 |
+
with gr.TabItem("Attention & Density"):
|
| 1326 |
+
plot_density = gr.Plot(label="Density vs Attention")
|
| 1327 |
+
with gr.TabItem("Attention Details"):
|
| 1328 |
+
plot_attn = gr.Plot(label="Raw Attention")
|
| 1329 |
+
with gr.TabItem("Heatmap"):
|
| 1330 |
+
plot_heat = gr.Plot(label="Branch Heatmap")
|
| 1331 |
+
with gr.TabItem("Concentration"):
|
| 1332 |
+
plot_conc = gr.Plot(label="Concentration")
|
| 1333 |
+
with gr.TabItem("Raw Data"):
|
| 1334 |
+
# headers needs to match rows
|
| 1335 |
+
df = gr.Dataframe(
|
| 1336 |
+
headers=[
|
| 1337 |
+
"id",
|
| 1338 |
+
"start",
|
| 1339 |
+
"end",
|
| 1340 |
+
"mid",
|
| 1341 |
+
"tokens",
|
| 1342 |
+
"attention",
|
| 1343 |
+
"is_topk",
|
| 1344 |
+
"local_stars",
|
| 1345 |
],
|
| 1346 |
datatype=[
|
| 1347 |
"number",
|
|
|
|
| 1351 |
"number",
|
| 1352 |
"number",
|
| 1353 |
"number",
|
| 1354 |
+
"number",
|
| 1355 |
],
|
|
|
|
|
|
|
| 1356 |
)
|
| 1357 |
|
| 1358 |
# Auto-refresh COURSE choices when input changes
|
|
|
|
| 1363 |
_update_course_dropdown, inputs=[tja_file, tja_text], outputs=[course]
|
| 1364 |
)
|
| 1365 |
|
| 1366 |
+
btn.click(
|
| 1367 |
run_inference,
|
| 1368 |
inputs=[
|
| 1369 |
tja_file,
|
|
|
|
| 1378 |
outputs=[
|
| 1379 |
summary,
|
| 1380 |
meta_json,
|
| 1381 |
+
plot_attn,
|
| 1382 |
+
plot_density,
|
| 1383 |
+
plot_heat,
|
| 1384 |
+
plot_conc,
|
| 1385 |
top_segments,
|
| 1386 |
+
df,
|
| 1387 |
+
plot_local_diff,
|
| 1388 |
+
plot_segments,
|
| 1389 |
+
segment_table,
|
| 1390 |
],
|
| 1391 |
)
|
| 1392 |
|
requirements.txt
CHANGED
|
@@ -64,6 +64,7 @@ pytz==2025.2
|
|
| 64 |
pyyaml==6.0.3
|
| 65 |
requests==2.32.5
|
| 66 |
rich==14.2.0
|
|
|
|
| 67 |
safehttpx==0.1.7
|
| 68 |
safetensors==0.7.0
|
| 69 |
scikit-learn==1.8.0
|
|
|
|
| 64 |
pyyaml==6.0.3
|
| 65 |
requests==2.32.5
|
| 66 |
rich==14.2.0
|
| 67 |
+
ruptures==1.1.10
|
| 68 |
safehttpx==0.1.7
|
| 69 |
safetensors==0.7.0
|
| 70 |
scikit-learn==1.8.0
|