Spaces:
Running
Running
Rik Hoffbauer committed on
Commit ·
6362e08
1
Parent(s): 94f6c04
Implement waveform cue editor and feedback-to-learning path
Browse files- app.py +135 -3
- cue_editor.py +246 -0
- cue_learning.py +44 -0
app.py
CHANGED
|
@@ -1326,6 +1326,7 @@ class AppState:
|
|
| 1326 |
self.rendered_set = None
|
| 1327 |
self.benchmarks = []
|
| 1328 |
self.set_order_metadata = {}
|
|
|
|
| 1329 |
|
| 1330 |
app_state = AppState()
|
| 1331 |
|
|
@@ -1492,6 +1493,8 @@ def render_full_set(max_iter, progress=gr.Progress()):
|
|
| 1492 |
|
| 1493 |
progress(0.05, desc="Compiling full-set AutomationIR...")
|
| 1494 |
from automation_set_renderer import render_set_with_automation_ir
|
|
|
|
|
|
|
| 1495 |
|
| 1496 |
def progress_cb(_p, _msg):
|
| 1497 |
# The AutomationIR renderer is deterministic and currently not chunk-progressive.
|
|
@@ -1503,8 +1506,10 @@ def render_full_set(max_iter, progress=gr.Progress()):
|
|
| 1503 |
app_state.transitions,
|
| 1504 |
load_audio_segment=load_audio_segment,
|
| 1505 |
time_stretch_audio=time_stretch_audio,
|
|
|
|
| 1506 |
sr=44100,
|
| 1507 |
)
|
|
|
|
| 1508 |
app_state.rendered_set = set_audio
|
| 1509 |
|
| 1510 |
progress(0.82, desc="Running diagnostics...")
|
|
@@ -1523,7 +1528,8 @@ def render_full_set(max_iter, progress=gr.Progress()):
|
|
| 1523 |
summary += f"- **Total duration:** {set_info['total_duration']:.1f}s ({set_info['total_duration']/60:.1f} min)\n"
|
| 1524 |
summary += f"- **Tracks:** {len(set_info['tracks'])}\n"
|
| 1525 |
summary += f"- **Transitions:** {len(set_info.get('transitions', []))}\n"
|
| 1526 |
-
summary += f"- **AutomationIR:** {set_info['automation_ir']['clips']} clips, {set_info['automation_ir']['lanes']} lanes\n
|
|
|
|
| 1527 |
|
| 1528 |
summary += "## Tracklist\n"
|
| 1529 |
for i, t in enumerate(set_info["tracks"]):
|
|
@@ -1558,6 +1564,8 @@ def render_single_transition(transition_idx, candidate_rank=0, progress=gr.Progr
|
|
| 1558 |
|
| 1559 |
progress(0.15, desc="Compiling automation IR...")
|
| 1560 |
from automation_ir import render_transition_candidate
|
|
|
|
|
|
|
| 1561 |
audio, ir, candidate = render_transition_candidate(
|
| 1562 |
trans,
|
| 1563 |
track_a,
|
|
@@ -1565,8 +1573,10 @@ def render_single_transition(transition_idx, candidate_rank=0, progress=gr.Progr
|
|
| 1565 |
candidate_rank=rank,
|
| 1566 |
load_audio_segment=load_audio_segment,
|
| 1567 |
time_stretch_audio=time_stretch_audio,
|
|
|
|
| 1568 |
sr=44100,
|
| 1569 |
)
|
|
|
|
| 1570 |
|
| 1571 |
progress(0.85, desc="Saving preview...")
|
| 1572 |
tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
|
|
@@ -1579,6 +1589,10 @@ def render_single_transition(transition_idx, candidate_rank=0, progress=gr.Progr
|
|
| 1579 |
else:
|
| 1580 |
edge_score = trans.score_breakdown.get('overall')
|
| 1581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1582 |
info = (
|
| 1583 |
f"**Transition {idx+1}:** {track_a.filename} → {track_b.filename}\n"
|
| 1584 |
f"**Candidate:** {cue_source}\n"
|
|
@@ -1588,6 +1602,8 @@ def render_single_transition(transition_idx, candidate_rank=0, progress=gr.Progr
|
|
| 1588 |
f"B in {ir.metadata['mix_in_point']:.2f}s, B drop {ir.metadata['b_drop']:.2f}s\n"
|
| 1589 |
f"**Duration:** {ir.metadata['duration_seconds']:.2f}s; score={edge_score if edge_score is not None else 'n/a'}\n"
|
| 1590 |
f"**Preview file duration:** {audio.shape[-1] / 44100:.1f}s\n\n"
|
|
|
|
|
|
|
| 1591 |
f"```json\n{json.dumps(ir.to_dict(), indent=2)[:6000]}\n```"
|
| 1592 |
)
|
| 1593 |
return tmp.name, info
|
|
@@ -1624,6 +1640,13 @@ def apply_manual_transition_edit(transition_idx, mix_out_point, mix_in_point, du
|
|
| 1624 |
assumptions.append("manual cue/timing override applied; preview this candidate before full render")
|
| 1625 |
trans.assumptions = assumptions
|
| 1626 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1627 |
return (
|
| 1628 |
f"✅ Updated transition {idx+1}\n\n"
|
| 1629 |
f"- Type: `{trans.transition_type}`\n"
|
|
@@ -1634,6 +1657,100 @@ def apply_manual_transition_edit(transition_idx, mix_out_point, mix_in_point, du
|
|
| 1634 |
)
|
| 1635 |
|
| 1636 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1637 |
def save_transition_rating(transition_idx, candidate_rank, rating, accepted, notes):
|
| 1638 |
"""Persist a human listening rating for a transition candidate."""
|
| 1639 |
if not app_state.transitions:
|
|
@@ -1736,15 +1853,28 @@ def build_ui():
|
|
| 1736 |
preview_info = gr.Markdown()
|
| 1737 |
preview_btn.click(render_single_transition, [trans_idx_input, candidate_rank_input], [preview_audio, preview_info])
|
| 1738 |
|
| 1739 |
-
gr.Markdown("###
|
| 1740 |
with gr.Row():
|
| 1741 |
manual_idx = gr.Number(value=1, label="Transition", minimum=1, precision=0)
|
| 1742 |
manual_type = gr.Dropdown(choices=list(TRANSITION_TYPES.keys()), value="eq_crossfade", label="Transition type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1743 |
with gr.Row():
|
| 1744 |
manual_mix_out = gr.Number(value=0, label="A mix-out seconds", precision=2)
|
| 1745 |
manual_mix_in = gr.Number(value=0, label="B mix-in seconds", precision=2)
|
| 1746 |
manual_beats = gr.Number(value=64, label="Duration beats", minimum=1, precision=0)
|
| 1747 |
-
manual_apply = gr.Button("Apply
|
| 1748 |
manual_output = gr.Markdown()
|
| 1749 |
manual_apply.click(apply_manual_transition_edit, [manual_idx, manual_mix_out, manual_mix_in, manual_beats, manual_type], [manual_output])
|
| 1750 |
|
|
@@ -1756,9 +1886,11 @@ def build_ui():
|
|
| 1756 |
with gr.Row():
|
| 1757 |
save_rating_btn = gr.Button("Save rating")
|
| 1758 |
show_ratings_btn = gr.Button("Show rating summary")
|
|
|
|
| 1759 |
rating_output = gr.Markdown()
|
| 1760 |
save_rating_btn.click(save_transition_rating, [trans_idx_input, candidate_rank_input, rating, accepted, notes], [rating_output])
|
| 1761 |
show_ratings_btn.click(show_listening_benchmarks, [], [rating_output])
|
|
|
|
| 1762 |
|
| 1763 |
# ──── TAB 4: RENDER FULL SET ────
|
| 1764 |
with gr.Tab("4️⃣ Render Full Set"):
|
|
|
|
| 1326 |
self.rendered_set = None
|
| 1327 |
self.benchmarks = []
|
| 1328 |
self.set_order_metadata = {}
|
| 1329 |
+
self.last_stem_diagnostics = {}
|
| 1330 |
|
| 1331 |
app_state = AppState()
|
| 1332 |
|
|
|
|
| 1493 |
|
| 1494 |
progress(0.05, desc="Compiling full-set AutomationIR...")
|
| 1495 |
from automation_set_renderer import render_set_with_automation_ir
|
| 1496 |
+
from stem_provider import StemProvider
|
| 1497 |
+
stem_provider = StemProvider()
|
| 1498 |
|
| 1499 |
def progress_cb(_p, _msg):
|
| 1500 |
# The AutomationIR renderer is deterministic and currently not chunk-progressive.
|
|
|
|
| 1506 |
app_state.transitions,
|
| 1507 |
load_audio_segment=load_audio_segment,
|
| 1508 |
time_stretch_audio=time_stretch_audio,
|
| 1509 |
+
stem_resolver=stem_provider.resolver(),
|
| 1510 |
sr=44100,
|
| 1511 |
)
|
| 1512 |
+
app_state.last_stem_diagnostics = dict(stem_provider.diagnostics)
|
| 1513 |
app_state.rendered_set = set_audio
|
| 1514 |
|
| 1515 |
progress(0.82, desc="Running diagnostics...")
|
|
|
|
| 1528 |
summary += f"- **Total duration:** {set_info['total_duration']:.1f}s ({set_info['total_duration']/60:.1f} min)\n"
|
| 1529 |
summary += f"- **Tracks:** {len(set_info['tracks'])}\n"
|
| 1530 |
summary += f"- **Transitions:** {len(set_info.get('transitions', []))}\n"
|
| 1531 |
+
summary += f"- **AutomationIR:** {set_info['automation_ir']['clips']} clips, {set_info['automation_ir']['lanes']} lanes\n"
|
| 1532 |
+
summary += f"- **Stem lane method:** `{set_info['automation_ir'].get('component_lane_method', 'n/a')}`\n\n"
|
| 1533 |
|
| 1534 |
summary += "## Tracklist\n"
|
| 1535 |
for i, t in enumerate(set_info["tracks"]):
|
|
|
|
| 1564 |
|
| 1565 |
progress(0.15, desc="Compiling automation IR...")
|
| 1566 |
from automation_ir import render_transition_candidate
|
| 1567 |
+
from stem_provider import StemProvider
|
| 1568 |
+
stem_provider = StemProvider()
|
| 1569 |
audio, ir, candidate = render_transition_candidate(
|
| 1570 |
trans,
|
| 1571 |
track_a,
|
|
|
|
| 1573 |
candidate_rank=rank,
|
| 1574 |
load_audio_segment=load_audio_segment,
|
| 1575 |
time_stretch_audio=time_stretch_audio,
|
| 1576 |
+
stem_resolver=stem_provider.resolver(),
|
| 1577 |
sr=44100,
|
| 1578 |
)
|
| 1579 |
+
app_state.last_stem_diagnostics = dict(stem_provider.diagnostics)
|
| 1580 |
|
| 1581 |
progress(0.85, desc="Saving preview...")
|
| 1582 |
tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
|
|
|
|
| 1589 |
else:
|
| 1590 |
edge_score = trans.score_breakdown.get('overall')
|
| 1591 |
|
| 1592 |
+
from transition_diagnostics import diagnose_transition_audio, format_transition_diagnostics
|
| 1593 |
+
diag = diagnose_transition_audio(audio, sr=44100, anchor_seconds=ir.anchor_seconds)
|
| 1594 |
+
stem_diag = json.dumps(app_state.last_stem_diagnostics, indent=2)[:2500] if app_state.last_stem_diagnostics else "{}"
|
| 1595 |
+
|
| 1596 |
info = (
|
| 1597 |
f"**Transition {idx+1}:** {track_a.filename} → {track_b.filename}\n"
|
| 1598 |
f"**Candidate:** {cue_source}\n"
|
|
|
|
| 1602 |
f"B in {ir.metadata['mix_in_point']:.2f}s, B drop {ir.metadata['b_drop']:.2f}s\n"
|
| 1603 |
f"**Duration:** {ir.metadata['duration_seconds']:.2f}s; score={edge_score if edge_score is not None else 'n/a'}\n"
|
| 1604 |
f"**Preview file duration:** {audio.shape[-1] / 44100:.1f}s\n\n"
|
| 1605 |
+
f"{format_transition_diagnostics(diag)}\n\n"
|
| 1606 |
+
f"### Stem provider diagnostics\n```json\n{stem_diag}\n```\n\n"
|
| 1607 |
f"```json\n{json.dumps(ir.to_dict(), indent=2)[:6000]}\n```"
|
| 1608 |
)
|
| 1609 |
return tmp.name, info
|
|
|
|
| 1640 |
assumptions.append("manual cue/timing override applied; preview this candidate before full render")
|
| 1641 |
trans.assumptions = assumptions
|
| 1642 |
|
| 1643 |
+
try:
|
| 1644 |
+
from cue_learning import append_training_example
|
| 1645 |
+
for cue in trans.selected_cues.values():
|
| 1646 |
+
append_training_example("data/manual-cue-edits.jsonl", cue, duration=track_b.duration, label=1, source="numeric_manual_editor")
|
| 1647 |
+
except Exception as exc:
|
| 1648 |
+
logger.warning(f"Could not append manual cue training examples: {exc}")
|
| 1649 |
+
|
| 1650 |
return (
|
| 1651 |
f"✅ Updated transition {idx+1}\n\n"
|
| 1652 |
f"- Type: `{trans.transition_type}`\n"
|
|
|
|
| 1657 |
)
|
| 1658 |
|
| 1659 |
|
| 1660 |
+
def load_waveform_cue_editor(transition_idx):
    """Load the waveform image and cue dropdown choices for one transition.

    Args:
        transition_idx: 1-based transition number coming from the UI. A value
            that cannot be converted to an integer (for example ``None`` from
            an empty field) is reported as an invalid index instead of raising.

    Returns:
        A 5-tuple ``(image_path, summary_markdown, a_update, b_in_update,
        b_drop_update)``. On failure the image is ``None``, the summary holds a
        warning message, and the three dropdown updates carry empty choices.
    """
    def _failure(message):
        # Same arity as the success tuple so every bound output gets a value.
        return None, message, gr.update(choices=[]), gr.update(choices=[]), gr.update(choices=[])

    if not app_state.transitions:
        return _failure("⚠️ Generate a set plan first")

    try:
        idx = int(transition_idx) - 1
    except (TypeError, ValueError, OverflowError):
        # None / NaN / inf cannot be an index; fall through to the range warning.
        idx = -1
    if idx < 0 or idx >= len(app_state.transitions):
        return _failure(f"⚠️ Invalid transition index. Choose 1-{len(app_state.transitions)}")

    trans = app_state.transitions[idx]
    track_a = app_state.analyses[trans.track_a_idx]
    track_b = app_state.analyses[trans.track_b_idx]

    from cue_editor import render_transition_cue_editor, choices_for_transition

    image_path, summary = render_transition_cue_editor(track_a, track_b, trans)
    choices = choices_for_transition(track_a, track_b, trans)
    return (
        image_path,
        summary,
        gr.update(choices=choices["a_choices"], value=choices["a_default"]),
        gr.update(choices=choices["b_in_choices"], value=choices["b_in_default"]),
        gr.update(choices=choices["b_drop_choices"], value=choices["b_drop_default"]),
    )
|
| 1680 |
+
|
| 1681 |
+
|
| 1682 |
+
def apply_waveform_cue_choices(transition_idx, a_choice, b_in_choice, b_drop_choice, transition_type):
    """Apply cue choices from the waveform editor to the selected transition.

    Args:
        transition_idx: 1-based transition number coming from the UI. A value
            that cannot be converted to an integer (for example ``None``) is
            reported as an invalid index instead of raising.
        a_choice: Encoded cue-choice string for track A's mix-out, or ``None``.
        b_in_choice: Encoded cue-choice string for track B's mix-in, or ``None``.
        b_drop_choice: Encoded cue-choice string for track B's drop, or ``None``.
        transition_type: Requested transition type; ignored unless it is a key
            of ``TRANSITION_TYPES``.

    Returns:
        A Markdown status string (a warning, or a summary of the updated plan).
    """
    if not app_state.transitions:
        return "⚠️ Generate a set plan first"

    try:
        idx = int(transition_idx) - 1
    except (TypeError, ValueError, OverflowError):
        # None / NaN / inf cannot be an index; fall through to the range warning.
        idx = -1
    if idx < 0 or idx >= len(app_state.transitions):
        return f"⚠️ Invalid transition index. Choose 1-{len(app_state.transitions)}"

    trans = app_state.transitions[idx]
    track_a = app_state.analyses[trans.track_a_idx]
    track_b = app_state.analyses[trans.track_b_idx]

    from cue_editor import apply_choices_to_plan

    mix_out, mix_in, duration, selected = apply_choices_to_plan(
        trans,
        a_choice=a_choice,
        b_in_choice=b_in_choice,
        b_drop_choice=b_drop_choice,
        transition_type=transition_type if transition_type in TRANSITION_TYPES else None,
    )

    # Clamp cue times into each track's own length before storing them.
    trans.mix_out_point = round(max(0.0, min(mix_out, track_a.duration)), 3)
    trans.mix_in_point = round(max(0.0, min(mix_in, track_b.duration)), 3)
    trans.duration_seconds = round(max(0.25, duration), 3)
    # Beat count is derived from track B's tempo, floored at 60 BPM.
    trans.duration_beats = max(1, round(trans.duration_seconds * max(track_b.bpm, 60.0) / 60.0))
    trans.needs_stems = trans.transition_type in ("bass_swap", "acapella_over_instrumental", "drums_first", "double_drop")
    trans.selected_cues = selected

    confs = [float(c.get("confidence", 0.0) or 0.0) for c in selected.values()]
    trans.cue_confidence = round(sum(confs) / len(confs), 3) if confs else 1.0
    trans.score_breakdown = {**dict(trans.score_breakdown), "waveform_editor_override": 1.0, "cue_confidence": trans.cue_confidence}

    # Replace any earlier waveform note so the list does not grow on every apply.
    assumptions = [a for a in trans.assumptions if "waveform" not in a.lower()]
    assumptions.append("waveform cue editor override applied; audition before full render")
    trans.assumptions = assumptions

    # Persist positive cue examples as manual supervision. The user can later
    # train a cue model from this file or merge it with listening ratings.
    try:
        from cue_learning import append_training_example

        for role, cue in selected.items():
            # "a_out" is a position on track A; the other roles are on track B.
            # NOTE(review): `duration` is presumably the owning track's length --
            # confirm against cue_learning.append_training_example.
            owner = track_a if role == "a_out" else track_b
            append_training_example("data/manual-cue-edits.jsonl", cue, duration=owner.duration, label=1, source="waveform_editor")
    except Exception as exc:
        # Best effort: a failed training-log write must not undo the edit.
        logger.warning(f"Could not append manual cue training examples: {exc}")

    return (
        f"✅ Applied waveform cue edit to transition {idx+1}\n\n"
        f"- Type: `{trans.transition_type}`\n"
        f"- A mix-out: {trans.mix_out_point:.2f}s\n"
        f"- B mix-in: {trans.mix_in_point:.2f}s\n"
        f"- B drop: {trans.selected_cues['b_drop']['time']:.2f}s\n"
        f"- Duration: {trans.duration_beats} beats / {trans.duration_seconds:.2f}s\n"
        f"- Cue confidence: {trans.cue_confidence:.0%}"
    )
|
| 1731 |
+
|
| 1732 |
+
|
| 1733 |
+
def train_cue_model_from_feedback():
    """Fit and save the cue scorer from manual cue edits plus listening ratings.

    Returns:
        A Markdown status string: a warning when no training examples exist,
        otherwise a summary with example counts and the saved model path.
    """
    from cue_learning import load_training_examples, examples_from_rating_rows, fit_logistic_model, save_model
    from listening_benchmarks import load_ratings

    manual_examples = load_training_examples("data/manual-cue-edits.jsonl")
    from_ratings = examples_from_rating_rows(load_ratings())
    combined = manual_examples + from_ratings

    # Nothing to learn from yet: report it and leave any existing model alone.
    if not combined:
        return "⚠️ No manual cue edits or decisive listening ratings available for training yet."

    saved_to = save_model(fit_logistic_model(combined))
    return (
        "✅ Cue model trained\n\n"
        f"- Examples: {len(combined)}\n"
        f"- Manual cue edits: {len(manual_examples)}\n"
        f"- Rating-derived examples: {len(from_ratings)}\n"
        f"- Output: `{saved_to}`\n\n"
        "New analyses will blend this learned probability into cue confidence."
    )
|
| 1752 |
+
|
| 1753 |
+
|
| 1754 |
def save_transition_rating(transition_idx, candidate_rank, rating, accepted, notes):
|
| 1755 |
"""Persist a human listening rating for a transition candidate."""
|
| 1756 |
if not app_state.transitions:
|
|
|
|
| 1853 |
preview_info = gr.Markdown()
|
| 1854 |
preview_btn.click(render_single_transition, [trans_idx_input, candidate_rank_input], [preview_audio, preview_info])
|
| 1855 |
|
| 1856 |
+
gr.Markdown("### Waveform-backed cue editor")
|
| 1857 |
with gr.Row():
|
| 1858 |
manual_idx = gr.Number(value=1, label="Transition", minimum=1, precision=0)
|
| 1859 |
manual_type = gr.Dropdown(choices=list(TRANSITION_TYPES.keys()), value="eq_crossfade", label="Transition type")
|
| 1860 |
+
load_editor_btn = gr.Button("Load waveform editor")
|
| 1861 |
+
cue_waveform = gr.Image(label="A/B waveform with cue markers", type="filepath", interactive=False)
|
| 1862 |
+
cue_editor_summary = gr.Markdown()
|
| 1863 |
+
with gr.Row():
|
| 1864 |
+
a_cue_choice = gr.Dropdown(choices=[], label="A mix-out cue")
|
| 1865 |
+
b_in_choice = gr.Dropdown(choices=[], label="B mix-in cue")
|
| 1866 |
+
b_drop_choice = gr.Dropdown(choices=[], label="B drop cue")
|
| 1867 |
+
apply_waveform_btn = gr.Button("Apply waveform cue choices")
|
| 1868 |
+
waveform_output = gr.Markdown()
|
| 1869 |
+
load_editor_btn.click(load_waveform_cue_editor, [manual_idx], [cue_waveform, cue_editor_summary, a_cue_choice, b_in_choice, b_drop_choice])
|
| 1870 |
+
apply_waveform_btn.click(apply_waveform_cue_choices, [manual_idx, a_cue_choice, b_in_choice, b_drop_choice, manual_type], [waveform_output])
|
| 1871 |
+
|
| 1872 |
+
gr.Markdown("### Numeric fallback editor")
|
| 1873 |
with gr.Row():
|
| 1874 |
manual_mix_out = gr.Number(value=0, label="A mix-out seconds", precision=2)
|
| 1875 |
manual_mix_in = gr.Number(value=0, label="B mix-in seconds", precision=2)
|
| 1876 |
manual_beats = gr.Number(value=64, label="Duration beats", minimum=1, precision=0)
|
| 1877 |
+
manual_apply = gr.Button("Apply numeric override")
|
| 1878 |
manual_output = gr.Markdown()
|
| 1879 |
manual_apply.click(apply_manual_transition_edit, [manual_idx, manual_mix_out, manual_mix_in, manual_beats, manual_type], [manual_output])
|
| 1880 |
|
|
|
|
| 1886 |
with gr.Row():
|
| 1887 |
save_rating_btn = gr.Button("Save rating")
|
| 1888 |
show_ratings_btn = gr.Button("Show rating summary")
|
| 1889 |
+
train_cue_btn = gr.Button("Train cue model from feedback")
|
| 1890 |
rating_output = gr.Markdown()
|
| 1891 |
save_rating_btn.click(save_transition_rating, [trans_idx_input, candidate_rank_input, rating, accepted, notes], [rating_output])
|
| 1892 |
show_ratings_btn.click(show_listening_benchmarks, [], [rating_output])
|
| 1893 |
+
train_cue_btn.click(train_cue_model_from_feedback, [], [rating_output])
|
| 1894 |
|
| 1895 |
# ──── TAB 4: RENDER FULL SET ────
|
| 1896 |
with gr.Tab("4️⃣ Render Full Set"):
|
cue_editor.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Waveform-backed manual cue editing helpers.
|
| 2 |
+
|
| 3 |
+
The previous manual editor only exposed numeric inputs. This module adds a
|
| 4 |
+
visual, audio-derived cue editor: it renders waveform overviews for the two
|
| 5 |
+
tracks in a transition, overlays selected cue positions and alternative cue
|
| 6 |
+
candidates, and returns stable cue-choice strings that can be applied back to
|
| 7 |
+
TransitionPlan objects.
|
| 8 |
+
|
| 9 |
+
The UI remains deliberately simple because Gradio event/click APIs vary across
|
| 10 |
+
versions. The backend is still real: the waveform image is computed from the
|
| 11 |
+
actual audio files, candidate lists are built from analysis cue objects, and
|
| 12 |
+
manual edits become explicit cue overrides that can later be exported as
|
| 13 |
+
training examples.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Iterable, Mapping
|
| 21 |
+
import hashlib
|
| 22 |
+
import math
|
| 23 |
+
import tempfile
|
| 24 |
+
|
| 25 |
+
import librosa
|
| 26 |
+
import matplotlib
|
| 27 |
+
matplotlib.use("Agg")
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass(frozen=True)
class CueChoice:
    """Immutable cue candidate that can be encoded as a dropdown value."""

    role: str
    time: float
    label: str
    confidence: float
    source: str = ""

    @property
    def value(self) -> str:
        """Pipe-delimited encoding: ``role|time|confidence|label|source``."""
        fields = (
            f"{self.role}",
            f"{self.time:.3f}",
            f"{self.confidence:.3f}",
            # "|" is the field separator, so it is swapped out of free text.
            self.label.replace("|", "/"),
            self.source.replace("|", "/"),
        )
        return "|".join(fields)

    @property
    def display(self) -> str:
        """Human-readable label shown for this choice."""
        suffix = ""
        if self.source:
            suffix = f" · {self.source}"
        return f"{self.role} @ {self.time:.2f}s · {self.confidence:.0%} · {self.label}{suffix}"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def parse_cue_choice(value: str | None) -> CueChoice | None:
    """Decode a pipe-delimited cue-choice string into a ``CueChoice``.

    Returns ``None`` for an empty value, for fewer than four fields, or when
    the time or confidence field is not a valid float. The fifth field
    (source) is optional and defaults to an empty string.
    """
    if not value:
        return None

    # At most five fields; any further "|" stays inside the source field.
    fields = str(value).split("|", 4)
    if len(fields) < 4:
        return None
    if len(fields) == 4:
        fields.append("")
    role, raw_time, raw_confidence, label, source = fields

    try:
        time_value = float(raw_time)
        confidence_value = float(raw_confidence)
    except ValueError:
        return None
    return CueChoice(role=role, time=time_value, confidence=confidence_value, label=label, source=source)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _cue_source(cue: Mapping[str, Any]) -> str:
    """Return the cue's ``evidence["source"]`` as text, or ``""`` if unavailable."""
    evidence = cue.get("evidence", {})
    if not isinstance(evidence, Mapping):
        return ""
    return str(evidence.get("source", ""))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def cue_choices(track: Any, role: str, *, limit: int = 12) -> list[tuple[str, str]]:
    """Return Gradio-compatible `(label, value)` cue choices for a role.

    Cues are read from ``track.cue_points``, filtered to the cue kinds accepted
    for ``role``, ordered by descending confidence then ascending time, and
    truncated to ``limit`` entries.
    """
    accepted_kinds = {
        "a_out": {"mix_out", "loopable", "drop"},
        "b_in": {"mix_in", "loopable"},
        "b_drop": {"first_drop", "drop"},
    }.get(role, {role})

    def _kind(cue):
        # Cue dicts may name their kind under either "kind" or "type".
        return str(cue.get("kind", cue.get("type", "")))

    def _rank(cue):
        return (-float(cue.get("confidence", 0.0) or 0.0), float(cue.get("time", 0.0) or 0.0))

    ranked = sorted(
        (cue for cue in getattr(track, "cue_points", []) if _kind(cue) in accepted_kinds),
        key=_rank,
    )

    pairs: list[tuple[str, str]] = []
    for cue in ranked[:limit]:
        entry = CueChoice(
            role=role,
            time=float(cue.get("time", 0.0) or 0.0),
            label=str(cue.get("label", cue.get("kind", "cue"))),
            confidence=float(cue.get("confidence", 0.0) or 0.0),
            source=_cue_source(cue),
        )
        pairs.append((entry.display, entry.value))
    return pairs
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def default_choice(track: Any, role: str, time_s: float) -> str | None:
    """Return the value of the listed cue choice closest in time to ``time_s``.

    Only the choices produced by ``cue_choices(track, role)`` are considered.
    Returns ``None`` when there are no choices or none can be parsed back. On a
    tie the earlier-listed choice wins.
    """
    best_value = None
    best_gap = None
    for _label, encoded in cue_choices(track, role):
        cue = parse_cue_choice(encoded)
        if cue is None:
            continue
        gap = abs(cue.time - float(time_s))
        # Strict "<" keeps the first of several equally close candidates.
        if best_gap is None or gap < best_gap:
            best_value, best_gap = encoded, gap
    return best_value
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _load_preview(path: str, *, max_duration: float = 300.0, sr: int = 12000) -> tuple[np.ndarray, int]:
    """Load a mono, peak-normalised preview of an audio file.

    Args:
        path: Audio file to load.
        max_duration: Maximum number of seconds to read.
        sr: Sample rate requested from the loader.

    Returns:
        ``(samples, sample_rate)`` with ``samples`` as float32. If loading
        fails or yields no samples, ``sr`` zeros are returned instead so the
        caller can still draw a flat line.
    """
    try:
        y, got_sr = librosa.load(path, sr=sr, mono=True, duration=max_duration)
    except Exception as exc:
        # A missing/corrupt file should not kill the editor. Return a visible
        # flatline, but log the cause so the failure can be diagnosed.
        import logging

        logging.getLogger(__name__).warning("Could not load waveform preview for %r: %s", path, exc)
        got_sr = sr
        y = np.zeros(sr, dtype=np.float32)
    if y.size == 0:
        y = np.zeros(sr, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > 0:
        # Normalise to a peak of 1.0; silent input is left untouched.
        y = y / peak
    return y, got_sr
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _amplitude_envelope(y: np.ndarray, sr: int, *, bins: int = 1800) -> tuple[np.ndarray, np.ndarray]:
    """Reduce a signal to a per-bin peak-amplitude envelope for plotting.

    Args:
        y: One-dimensional sample array.
        sr: Sample rate of ``y``, used to convert sample counts to seconds.
        bins: Requested number of envelope bins. At least 64 bins are used,
            but never more bins than there are samples.

    Returns:
        ``(times, env)`` of equal length: ``times`` runs linearly from 0 to
        ``len(y) / sr`` seconds and ``env`` is a float32 array holding the
        maximum absolute sample value of each bin. Empty input yields
        ``([0.0], [0.0])``.
    """
    n = len(y)
    if n == 0:
        return np.array([0.0]), np.array([0.0])

    # Capping at n keeps every bin non-empty; a 64-bin floor on a shorter clip
    # would produce empty segments that show up as zero-amplitude gaps.
    bins = min(n, max(64, bins))
    edges = np.linspace(0, n, bins + 1, dtype=int)
    magnitude = np.abs(np.asarray(y))
    env = np.zeros(bins, dtype=np.float32)
    for i in range(bins):
        seg = magnitude[edges[i]:edges[i + 1]]
        if seg.size:
            env[i] = seg.max()
    times = np.linspace(0.0, n / sr, bins)
    return times, env
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _draw_track(ax: Any, track: Any, *, selected: dict[str, float], title: str) -> None:
    """Draw one track's waveform envelope with segment, cue and anchor overlays.

    Args:
        ax: Matplotlib axes to draw on.
        track: Object read via ``getattr`` for ``path``, ``duration``,
            ``segments`` and ``cue_points``.
        selected: Mapping of marker name to time in seconds; each entry is
            drawn as a bold black vertical line with a rotated label.
        title: Left-aligned axes title.
    """
    y, sr = _load_preview(getattr(track, "path", ""), max_duration=float(getattr(track, "duration", 300.0) or 300.0))
    times, env = _amplitude_envelope(y, sr)
    ax.fill_between(times, -env, env, alpha=0.35, linewidth=0)
    ax.plot(times, env, linewidth=0.35)
    ax.plot(times, -env, linewidth=0.35)

    waveform_end = float(times[-1]) if len(times) else 0.0
    duration = float(getattr(track, "duration", None) or 0.0)
    if duration <= 0.0:
        # A missing, None or zero duration falls back to the loaded waveform
        # length, so the axis and cue filter below do not collapse to zero.
        duration = waveform_end
    ax.set_xlim(0, max(1.0, min(duration, waveform_end) if len(times) else duration))
    ax.set_ylim(-1.05, 1.05)
    ax.set_yticks([])
    ax.set_title(title, loc="left", fontsize=10)
    ax.set_xlabel("seconds")

    # Segment spans give the user context beyond the raw waveform.
    for seg in (getattr(track, "segments", None) or [])[:40]:
        if not isinstance(seg, Mapping):
            continue
        start = float(seg.get("start", 0.0) or 0.0)
        end = float(seg.get("end", start) or start)
        label = str(seg.get("label", "section"))
        if end <= start:
            continue
        ax.axvspan(start, end, alpha=0.04)
        if end - start > 5:
            # Only label spans wide enough for the text to be legible.
            ax.text(start + 0.15, 0.82, label, fontsize=7, alpha=0.65)

    cue_palette = {
        "mix_in": (0.2, 0.7, 0.2),
        "mix_out": (0.8, 0.25, 0.2),
        "first_drop": (0.55, 0.2, 0.8),
        "drop": (0.55, 0.2, 0.8),
        "loopable": (0.2, 0.45, 0.85),
    }
    for cue in (getattr(track, "cue_points", None) or [])[:60]:
        # Same guard as the segment loop: skip entries that are not dict-like.
        if not isinstance(cue, Mapping):
            continue
        kind = str(cue.get("kind", cue.get("type", "cue")))
        t = float(cue.get("time", 0.0) or 0.0)
        if t < 0 or t > duration:
            continue
        conf = float(cue.get("confidence", 0.0) or 0.0)
        color = cue_palette.get(kind, (0.3, 0.3, 0.3))
        # Line opacity scales with confidence, clamped to [0.12, 0.55].
        ax.axvline(t, color=color, alpha=max(0.12, min(0.55, conf * 0.55)), linewidth=0.8)

    for name, t in selected.items():
        ax.axvline(float(t), color="black", linewidth=2.0, alpha=0.95)
        ax.text(float(t), -0.92, name, rotation=90, va="bottom", ha="right", fontsize=8, fontweight="bold")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def render_transition_cue_editor(track_a: Any, track_b: Any, plan: Any, *, output_dir: str | Path | None = None) -> tuple[str, str]:
    """Render a two-track waveform/cue overview and return `(png_path, markdown)`.

    Args:
        track_a: Outgoing track; drawn in the top panel with the "A OUT" marker.
        track_b: Incoming track; drawn in the bottom panel with "B IN" and
            "B DROP" markers.
        plan: Transition-plan-like object read via ``getattr`` for
            ``selected_cues``, ``mix_out_point``, ``mix_in_point``,
            ``duration_seconds``, ``duration_beats`` and ``transition_type``.
        output_dir: Directory for the PNG; defaults to the system temp dir and
            is created if missing.

    Returns:
        The PNG path as a string and a Markdown summary of the anchors.
    """
    output_dir = Path(output_dir or tempfile.gettempdir())
    output_dir.mkdir(parents=True, exist_ok=True)
    # The file name is derived from track paths, timing and type, so the same
    # combination maps to the same file, which is rewritten on every call.
    fingerprint = hashlib.sha1(
        f"{getattr(track_a, 'path', '')}|{getattr(track_b, 'path', '')}|{getattr(plan, 'mix_out_point', 0)}|{getattr(plan, 'mix_in_point', 0)}|{getattr(plan, 'duration_seconds', 0)}|{getattr(plan, 'transition_type', '')}".encode()
    ).hexdigest()[:12]
    out = output_dir / f"ai-dj-cue-editor-{fingerprint}.png"

    # Prefer explicitly selected cue times; fall back to the plan's own fields.
    selected = getattr(plan, "selected_cues", {}) or {}
    a_out = float(selected.get("a_out", {}).get("time", getattr(plan, "mix_out_point", 0.0)))
    b_in = float(selected.get("b_in", {}).get("time", getattr(plan, "mix_in_point", 0.0)))
    b_drop = float(selected.get("b_drop", {}).get("time", b_in + getattr(plan, "duration_seconds", 0.0)))

    fig, axes = plt.subplots(2, 1, figsize=(15, 5.2), constrained_layout=True)
    try:
        _draw_track(axes[0], track_a, selected={"A OUT": a_out}, title=f"A: {getattr(track_a, 'filename', 'track A')}")
        _draw_track(axes[1], track_b, selected={"B IN": b_in, "B DROP": b_drop}, title=f"B: {getattr(track_b, 'filename', 'track B')}")
        fig.suptitle(f"Transition cue editor · {getattr(plan, 'transition_type', 'transition')} · {getattr(plan, 'duration_beats', '?')} beats", fontsize=12)
        fig.savefig(out, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)

    summary = [
        "### Waveform cue editor",
        "The black markers are the currently selected transition anchors. Thin colored lines are ranked cue candidates from analysis.",
        f"- A mix-out: **{a_out:.2f}s**",
        f"- B mix-in: **{b_in:.2f}s**",
        f"- B drop: **{b_drop:.2f}s**",
        f"- Transition type: `{getattr(plan, 'transition_type', 'unknown')}`",
        f"- Duration: **{float(getattr(plan, 'duration_seconds', 0.0)):.2f}s** / **{int(getattr(plan, 'duration_beats', 0))} beats**",
    ]
    return str(out), "\n".join(summary)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def choices_for_transition(track_a: Any, track_b: Any, plan: Any) -> dict[str, Any]:
    """Return choice lists and defaults for the UI/backend tests.

    The result has the keys ``a_choices``, ``b_in_choices``, ``b_drop_choices``
    (lists of ``(label, value)`` pairs) and ``a_default``, ``b_in_default``,
    ``b_drop_default`` (the choice value nearest the plan's current timing, or
    ``None`` when no choice exists). Missing or ``None`` timing attributes on
    ``plan`` are treated as 0.0 for all three defaults.
    """
    mix_out = float(getattr(plan, "mix_out_point", 0.0) or 0.0)
    mix_in = float(getattr(plan, "mix_in_point", 0.0) or 0.0)
    # Guarded like the two values above so a None duration cannot raise here.
    duration = float(getattr(plan, "duration_seconds", 0.0) or 0.0)
    return {
        "a_choices": cue_choices(track_a, "a_out"),
        "b_in_choices": cue_choices(track_b, "b_in"),
        "b_drop_choices": cue_choices(track_b, "b_drop"),
        "a_default": default_choice(track_a, "a_out", mix_out),
        "b_in_default": default_choice(track_b, "b_in", mix_in),
        # The drop is expected where the transition ends on track B.
        "b_drop_default": default_choice(track_b, "b_drop", mix_in + duration),
    }
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _selected_cue_entry(kind: str, cue: Any, time: float) -> dict[str, Any]:
    """Build one ``selected_cues`` record for a parsed cue or a manual value."""
    return {
        "kind": kind,
        "label": cue.label if cue else "manual waveform value",
        "time": round(time, 3),
        "confidence": cue.confidence if cue else 1.0,
        "evidence": {"source": cue.source if cue else "waveform_editor"},
    }


def apply_choices_to_plan(plan: Any, *, a_choice: str | None, b_in_choice: str | None, b_drop_choice: str | None, transition_type: str | None = None) -> tuple[float, float, float, dict[str, Any]]:
    """Apply cue-choice strings to a TransitionPlan-like object.

    Returns `(mix_out, mix_in, duration_seconds, selected_cues)` so callers can
    update additional derived fields such as beat count.

    Only ``transition_type`` is written to ``plan`` here (and only when a
    truthy value is given); the timing values are returned, not assigned.

    Args:
        plan: Object whose ``mix_out_point``, ``mix_in_point`` and
            ``duration_seconds`` attributes provide fallbacks when a choice
            does not parse. Missing or ``None`` attributes count as ``0.0``.
        a_choice: Choice string for the A mix-out cue, or ``None``.
        b_in_choice: Choice string for the B mix-in cue, or ``None``.
        b_drop_choice: Choice string for the B drop cue, or ``None``.
        transition_type: Optional new transition type to set on ``plan``.
    """
    a = parse_cue_choice(a_choice)
    b = parse_cue_choice(b_in_choice)
    d = parse_cue_choice(b_drop_choice)

    # `or 0.0` guards against plan attributes that exist but are None, which
    # would otherwise make float() / the addition below raise TypeError.
    mix_out = float(a.time if a else getattr(plan, "mix_out_point", 0.0) or 0.0)
    mix_in = float(b.time if b else getattr(plan, "mix_in_point", 0.0) or 0.0)
    planned_duration = float(getattr(plan, "duration_seconds", 0.0) or 0.0)
    drop = float(d.time) if d else mix_in + planned_duration

    # Duration is the span from mix-in to drop, floored at 0.25 s so a drop
    # at or before the mix-in point never yields a zero/negative length.
    duration = max(0.25, drop - mix_in)

    if transition_type:
        setattr(plan, "transition_type", transition_type)

    selected = {
        "a_out": _selected_cue_entry("mix_out", a, mix_out),
        "b_in": _selected_cue_entry("mix_in", b, mix_in),
        "b_drop": _selected_cue_entry("drop", d, drop),
    }
    return mix_out, mix_in, duration, selected
|
cue_learning.py
CHANGED
|
@@ -162,3 +162,47 @@ def train_from_jsonl(path: str | Path, *, output_path: str | Path = DEFAULT_MODE
|
|
| 162 |
model = fit_logistic_model(load_training_examples(path))
|
| 163 |
save_model(model, output_path)
|
| 164 |
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
model = fit_logistic_model(load_training_examples(path))
|
| 163 |
save_model(model, output_path)
|
| 164 |
return model
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def examples_from_rating_rows(rows: list[Mapping[str, Any]]) -> list[dict[str, Any]]:
    """Convert listening ratings into cue-training examples.

    Positive labels come from accepted candidates or ratings >= 4. Negative
    labels come from explicitly rejected low-rated candidates. Neutral ratings
    are ignored to avoid teaching ambiguous preferences.

    A row that carries no rating at all (missing, ``None`` or empty string)
    and is not accepted is treated as ambiguous and skipped rather than being
    counted as a rating of 0.

    Args:
        rows: Rating rows. Each row may contain ``rating``, ``accepted`` and a
            ``transition`` mapping with ``selected_cues`` and
            ``duration_seconds``.

    Returns:
        One example dict per mapping-valued selected cue, with keys ``cue``,
        ``duration``, ``label``, ``source``, ``rating`` and ``accepted``.
    """
    examples: list[dict[str, Any]] = []
    for row in rows:
        raw_rating = row.get("rating")
        has_rating = raw_rating is not None and raw_rating != ""
        rating = float(raw_rating or 0.0)
        accepted = bool(row.get("accepted", False))

        if accepted or rating >= 4.0:
            label = 1
        elif has_rating and rating <= 2.0:
            # Only an explicit low rating is a negative signal; an absent
            # rating must not masquerade as 0.
            label = 0
        else:
            continue

        transition = row.get("transition", {})
        if not isinstance(transition, Mapping):
            transition = {}
        cues = transition.get("selected_cues", {})
        if not isinstance(cues, Mapping):
            cues = {}
        duration = float(transition.get("duration_seconds", 1.0) or 1.0)

        for role, cue in cues.items():
            if not isinstance(cue, Mapping):
                continue
            # Copy so the caller's cue mapping is never mutated.
            enriched = dict(cue)
            enriched.setdefault("kind", "mix_out" if role == "a_out" else "mix_in" if role == "b_in" else "drop")
            examples.append({"cue": enriched, "duration": duration, "label": label, "source": "listening_rating", "rating": rating, "accepted": accepted})
    return examples
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def train_from_listening_ratings(ratings_path: str | Path = "data/listening-ratings.jsonl", *, output_path: str | Path = DEFAULT_MODEL_PATH) -> CuePointModel:
    """Fit a cue-point model from stored listening ratings and save it.

    Ratings are loaded with ``listening_benchmarks.load_ratings``, converted
    with ``examples_from_rating_rows``, fitted with ``fit_logistic_model`` and
    written to ``output_path`` via ``save_model``.

    Args:
        ratings_path: Location of the ratings file handed to ``load_ratings``.
        output_path: Destination handed to ``save_model``.

    Returns:
        The fitted model.
    """
    # Imported at call time, as in the original; the reason is not visible
    # from this file (possibly to avoid an import cycle or a hard dependency).
    from listening_benchmarks import load_ratings

    rating_rows = load_ratings(ratings_path)
    trained = fit_logistic_model(examples_from_rating_rows(rating_rows))
    save_model(trained, output_path)
    return trained
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def append_training_example(path: str | Path, cue: Mapping[str, Any], *, duration: float, label: int, source: str) -> None:
    """Append one training example as a JSON line to ``path``.

    Missing parent directories are created. The record has the keys ``cue``,
    ``duration``, ``label`` and ``source``; ``duration`` is coerced to float
    and ``label`` to int before serialisation.

    Args:
        path: JSONL file to append to.
        cue: Cue mapping; a shallow copy is stored.
        duration: Duration value recorded with the example.
        label: Training label recorded with the example.
        source: Free-form origin tag for the example.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    record = {
        "cue": dict(cue),
        "duration": float(duration),
        "label": int(label),
        "source": source,
    }
    serialized = json.dumps(record, ensure_ascii=False)

    with target.open("a", encoding="utf-8") as handle:
        handle.write(serialized + "\n")
|