Spaces:
Running
Running
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. Β See raw diff
- .gitattributes +2 -0
- Dockerfile +4 -2
- README.md +4 -3
- app/backend/app.py +0 -0
- app/backend/data/auto_annotator.py +157 -34
- app/backend/data/pre_encoder.py +354 -0
- app/backend/data/projects.py +1023 -0
- app/backend/data/slicing.py +183 -0
- app/core/audio/midi_input.py +172 -0
- app/core/config.py +16 -86
- app/core/generation/audio_generator.py +490 -473
- app/core/generation/audio_post_process.py +713 -44
- app/core/model_manager.py +628 -437
- app/core/training/hyperparam_suggester.py +299 -141
- app/core/training/sa3_lora_runner.py +331 -0
- app/core/training/sa3_trainer.py +839 -0
- app/frontend/index.html +29 -6
- app/frontend/logs/fragmenta_20260525.log +8 -0
- app/frontend/package.json +2 -2
- app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf +3 -0
- app/frontend/public/InterTight-VariableFont_wght.ttf +3 -0
- app/frontend/public/fragmenta_background.png +2 -2
- app/frontend/public/interface.png +2 -2
- app/frontend/src/App.js +0 -0
- app/frontend/src/api.js +1 -0
- app/frontend/src/components/AboutDialog.js +130 -0
- app/frontend/src/components/AudioWaveform.js +258 -0
- app/frontend/src/components/ChannelFragmentHistory.js +217 -0
- app/frontend/src/components/CheckpointManagerWindow.js +243 -0
- app/frontend/src/components/CheckpointRow.js +270 -0
- app/frontend/src/components/DatasetPrep.js +1823 -0
- app/frontend/src/components/EditPanel.js +597 -0
- app/frontend/src/components/GeneratedFragmentsWindow.js +420 -70
- app/frontend/src/components/GenerationWaveform.js +217 -0
- app/frontend/src/components/InfoView.js +91 -0
- app/frontend/src/components/LoraStack.js +252 -0
- app/frontend/src/components/LossChart.js +27 -11
- app/frontend/src/components/MidiConfigMenu.js +118 -46
- app/frontend/src/components/MidiContext.js +38 -48
- app/frontend/src/components/PerformanceChannel.js +618 -239
- app/frontend/src/components/PerformancePanel.js +0 -0
- app/frontend/src/components/StorageDrilldown.js +84 -0
- app/frontend/src/components/Tooltip.js +35 -0
- app/frontend/src/components/TrainingMonitor.js +76 -35
- app/frontend/src/components/WelcomePage.js +22 -33
- app/frontend/src/components/usePerformanceSession.js +37 -7
- app/frontend/src/theme.js +0 -0
- app/frontend/src/tooltips.js +134 -0
- app/frontend/src/utils/cueAudio.js +29 -6
- app/frontend/src/utils/fragmentDrag.js +25 -0
.gitattributes
CHANGED
|
@@ -47,3 +47,5 @@ utils/vendor/wheels/antlr4_python3_runtime-4.9.3-py3-none-any.whl filter=lfs dif
|
|
| 47 |
vendor/stable-audio-tools/demo_cfg_3_00000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 48 |
vendor/stable-audio-tools/demo_cfg_6_00000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 49 |
vendor/stable-audio-tools/demo_cfg_9_00000001.wav filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 47 |
vendor/stable-audio-tools/demo_cfg_3_00000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 48 |
vendor/stable-audio-tools/demo_cfg_6_00000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 49 |
vendor/stable-audio-tools/demo_cfg_9_00000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
app/frontend/public/InterTight-VariableFont_wght.ttf filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
CHANGED
|
@@ -60,8 +60,9 @@ RUN grep -ivE 'flash-attn|extra-index-url|pycairo|pygobject|pywebview' requireme
|
|
| 60 |
COPY . .
|
| 61 |
COPY --from=frontend-builder /build/frontend/build ./app/frontend/build
|
| 62 |
|
| 63 |
-
# Install
|
| 64 |
-
|
|
|
|
| 65 |
|
| 66 |
# Create writable directories
|
| 67 |
RUN mkdir -p /app/models/pretrained \
|
|
@@ -105,6 +106,7 @@ ENV FLASK_HOST=0.0.0.0
|
|
| 105 |
ENV FLASK_PORT=7860
|
| 106 |
ENV FRAGMENTA_LOG_LEVEL=INFO
|
| 107 |
ENV FRAGMENTA_DOCKER=1
|
|
|
|
| 108 |
ENV FRAGMENTA_USE_CUSTOM_MODELS=true
|
| 109 |
ENV HOME=/home/user
|
| 110 |
ENV PATH="/home/user/.local/bin:${PATH}"
|
|
|
|
| 60 |
COPY . .
|
| 61 |
COPY --from=frontend-builder /build/frontend/build ./app/frontend/build
|
| 62 |
|
| 63 |
+
# Install vendored Stable Audio 3 in-tree (--no-deps: runtime deps come from
|
| 64 |
+
# requirements.txt). Makes `import stable_audio_3` resolve.
|
| 65 |
+
RUN pip install --no-cache-dir --root-user-action=ignore --no-deps -e ./vendor/stable-audio-3/
|
| 66 |
|
| 67 |
# Create writable directories
|
| 68 |
RUN mkdir -p /app/models/pretrained \
|
|
|
|
| 106 |
ENV FLASK_PORT=7860
|
| 107 |
ENV FRAGMENTA_LOG_LEVEL=INFO
|
| 108 |
ENV FRAGMENTA_DOCKER=1
|
| 109 |
+
ENV PYTHONPATH=/app/vendor/stable-audio-3
|
| 110 |
ENV FRAGMENTA_USE_CUSTOM_MODELS=true
|
| 111 |
ENV HOME=/home/user
|
| 112 |
ENV PATH="/home/user/.local/bin:${PATH}"
|
README.md
CHANGED
|
@@ -17,9 +17,10 @@ Generate and fine-tune audio from text prompts using Stable Audio Open.
|
|
| 17 |
|
| 18 |
## Getting Started
|
| 19 |
|
| 20 |
-
1.
|
| 21 |
-
|
| 22 |
-
- `
|
|
|
|
| 23 |
2. The Space will auto-rebuild after the upload.
|
| 24 |
3. Use the **Data Processing** tab to upload audio + prompts.
|
| 25 |
4. Use the **Training** tab to fine-tune.
|
|
|
|
| 17 |
|
| 18 |
## Getting Started
|
| 19 |
|
| 20 |
+
1. Download an SA3 checkpoint via the in-app Checkpoint Manager, or place one
|
| 21 |
+
under `models/pretrained/sa3/hub/` in the Space Files tab.
|
| 22 |
+
- `sa3-small-music` (recommended for CPU Spaces)
|
| 23 |
+
- `sa3-medium` (recommended for GPU Spaces with Flash Attention 2)
|
| 24 |
2. The Space will auto-rebuild after the upload.
|
| 25 |
3. Use the **Data Processing** tab to upload audio + prompts.
|
| 26 |
4. Use the **Training** tab to fine-tune.
|
app/backend/app.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/backend/data/auto_annotator.py
CHANGED
|
@@ -22,6 +22,11 @@ AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".m4a", ".ogg", ".aac")
|
|
| 22 |
CLAP_CKPT_FILENAME = "music_audioset_epoch_15_esc_90.14.pt"
|
| 23 |
CLAP_REPO = "lukewys/laion_clap"
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
KEY_NAMES_SHARP = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
|
| 26 |
KEY_NAMES_FLAT = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"]
|
| 27 |
|
|
@@ -44,9 +49,18 @@ def _iter_audio_files(folder: Path) -> List[Path]:
|
|
| 44 |
|
| 45 |
def _estimate_tempo(y, sr) -> Optional[int]:
|
| 46 |
import librosa
|
|
|
|
| 47 |
try:
|
| 48 |
tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
if bpm <= 0:
|
| 51 |
return None
|
| 52 |
return int(round(bpm))
|
|
@@ -178,12 +192,50 @@ class _ClapTagger:
|
|
| 178 |
f"CLAP checkpoint not found at {self.ckpt_path}. "
|
| 179 |
"Download it first via /api/bulk-annotate/download-clap."
|
| 180 |
)
|
| 181 |
-
import laion_clap
|
| 182 |
-
import torch
|
| 183 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
# torch >= 2.6 flipped torch.load(weights_only=True) and newer
|
| 189 |
# transformers dropped the roberta position_ids buffer, so
|
|
@@ -268,44 +320,97 @@ def clap_checkpoint_path(models_pretrained_dir: Path) -> Path:
|
|
| 268 |
return models_pretrained_dir / "clap" / CLAP_CKPT_FILENAME
|
| 269 |
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
def clap_checkpoint_available(models_pretrained_dir: Path) -> bool:
|
| 272 |
return clap_checkpoint_path(models_pretrained_dir).exists()
|
| 273 |
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
def download_clap_checkpoint(
|
| 276 |
models_pretrained_dir: Path,
|
| 277 |
progress_cb: Optional[Callable[[str], None]] = None,
|
|
|
|
| 278 |
) -> Path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
target = clap_checkpoint_path(models_pretrained_dir)
|
| 280 |
target.parent.mkdir(parents=True, exist_ok=True)
|
| 281 |
-
|
| 282 |
-
|
| 283 |
|
| 284 |
-
from huggingface_hub import hf_hub_download
|
| 285 |
import os
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
-
# Use custom CLAP from fragmenta-models on HF Spaces
|
| 291 |
-
use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true'
|
| 292 |
-
if use_custom_repo:
|
| 293 |
-
repo_id = "MazCodes/fragmenta-models"
|
| 294 |
-
else:
|
| 295 |
-
repo_id = CLAP_REPO
|
| 296 |
-
|
| 297 |
-
downloaded = hf_hub_download(
|
| 298 |
-
repo_id=repo_id,
|
| 299 |
-
filename=CLAP_CKPT_FILENAME,
|
| 300 |
-
local_dir=str(target.parent),
|
| 301 |
-
)
|
| 302 |
-
downloaded_path = Path(downloaded)
|
| 303 |
-
if downloaded_path != target:
|
| 304 |
-
try:
|
| 305 |
-
downloaded_path.replace(target)
|
| 306 |
-
except OSError:
|
| 307 |
-
import shutil
|
| 308 |
-
shutil.copy2(downloaded_path, target)
|
| 309 |
return target
|
| 310 |
|
| 311 |
|
|
@@ -328,8 +433,10 @@ def annotate_file(
|
|
| 328 |
label_sets: Dict[str, List[str]],
|
| 329 |
sr: int = 22050,
|
| 330 |
max_seconds: float = 60.0,
|
|
|
|
| 331 |
) -> Dict[str, Any]:
|
| 332 |
import librosa
|
|
|
|
| 333 |
|
| 334 |
parts: Dict[str, Any] = {}
|
| 335 |
try:
|
|
@@ -343,10 +450,19 @@ def annotate_file(
|
|
| 343 |
"error": f"load failed: {exc}",
|
| 344 |
}
|
| 345 |
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
if tier == "rich" and clap_tagger is not None:
|
| 352 |
try:
|
|
@@ -355,7 +471,14 @@ def annotate_file(
|
|
| 355 |
except Exception as exc:
|
| 356 |
logger.warning("CLAP tagging failed for %s: %s", audio_path.name, exc)
|
| 357 |
|
| 358 |
-
prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
return {
|
| 360 |
"file_name": audio_path.name,
|
| 361 |
"prompt": prompt,
|
|
|
|
| 22 |
CLAP_CKPT_FILENAME = "music_audioset_epoch_15_esc_90.14.pt"
|
| 23 |
CLAP_REPO = "lukewys/laion_clap"
|
| 24 |
|
| 25 |
+
# Text-side dependencies laion_clap pulls from HF on construction.
|
| 26 |
+
# We stage these into models/pretrained/clap/hub/ so the rich tier is
|
| 27 |
+
# fully offline after a single download and nothing leaks to ~/.cache.
|
| 28 |
+
CLAP_TEXT_DEPS = ("roberta-base", "bert-base-uncased", "facebook/bart-base")
|
| 29 |
+
|
| 30 |
KEY_NAMES_SHARP = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
|
| 31 |
KEY_NAMES_FLAT = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"]
|
| 32 |
|
|
|
|
| 49 |
|
| 50 |
def _estimate_tempo(y, sr) -> Optional[int]:
|
| 51 |
import librosa
|
| 52 |
+
import numpy as np
|
| 53 |
try:
|
| 54 |
tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
|
| 55 |
+
# librosa 0.10+ returns tempo as np.ndarray (shape (1,) typically).
|
| 56 |
+
# numpy 2.x removed implicit float() conversion of N-d arrays β
|
| 57 |
+
# `float(np.array([120.]))` now raises TypeError instead of returning
|
| 58 |
+
# 120.0 like numpy 1.x did. Normalize via .flat[0] which handles
|
| 59 |
+
# scalar, 0-d, 1-d, and N-d uniformly.
|
| 60 |
+
arr = np.atleast_1d(np.asarray(tempo))
|
| 61 |
+
if arr.size == 0:
|
| 62 |
+
return None
|
| 63 |
+
bpm = float(arr.flat[0])
|
| 64 |
if bpm <= 0:
|
| 65 |
return None
|
| 66 |
return int(round(bpm))
|
|
|
|
| 192 |
f"CLAP checkpoint not found at {self.ckpt_path}. "
|
| 193 |
"Download it first via /api/bulk-annotate/download-clap."
|
| 194 |
)
|
|
|
|
|
|
|
| 195 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 196 |
|
| 197 |
+
# Point HF resolution at our project-local cache and disable the
|
| 198 |
+
# HEAD-revalidation traffic. After download_clap_checkpoint() has
|
| 199 |
+
# staged the text deps under <pretrained>/clap/hub/, CLAP_Module
|
| 200 |
+
# loads them offline with zero HF hub requests.
|
| 201 |
+
#
|
| 202 |
+
# Two reasons env vars alone aren't enough:
|
| 203 |
+
# 1. huggingface_hub.constants.HF_HUB_OFFLINE is captured at
|
| 204 |
+
# module-import time (constants.py:185). model_manager.py
|
| 205 |
+
# imports huggingface_hub at app startup, so the constant is
|
| 206 |
+
# already False by the time we set the env var here.
|
| 207 |
+
# transformers.utils.hub.is_offline_mode reads that same
|
| 208 |
+
# constant β patching the attribute makes both libraries see
|
| 209 |
+
# offline mode.
|
| 210 |
+
# 2. laion_clap/training/data.py:44-46 runs three from_pretrained
|
| 211 |
+
# calls at MODULE LEVEL β those fire the first time we do
|
| 212 |
+
# `import laion_clap` and predate any patch we do after the
|
| 213 |
+
# import. So we patch BEFORE the import, not after.
|
| 214 |
+
hub_dir = self.ckpt_path.parent / "hub"
|
| 215 |
+
env_keys = ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE", "TRANSFORMERS_CACHE",
|
| 216 |
+
"HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE")
|
| 217 |
+
prev_env = {k: os.environ.get(k) for k in env_keys}
|
| 218 |
+
os.environ["HF_HUB_CACHE"] = str(hub_dir)
|
| 219 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
|
| 220 |
+
os.environ["TRANSFORMERS_CACHE"] = str(hub_dir)
|
| 221 |
+
os.environ["HF_HUB_OFFLINE"] = "1"
|
| 222 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
| 223 |
+
|
| 224 |
+
import huggingface_hub.constants as _hhc
|
| 225 |
+
prev_offline_attr = _hhc.HF_HUB_OFFLINE
|
| 226 |
+
_hhc.HF_HUB_OFFLINE = True
|
| 227 |
+
try:
|
| 228 |
+
import laion_clap # noqa: E402 β must follow the offline patch
|
| 229 |
+
import torch
|
| 230 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 231 |
+
model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device=device)
|
| 232 |
+
finally:
|
| 233 |
+
_hhc.HF_HUB_OFFLINE = prev_offline_attr
|
| 234 |
+
for k, v in prev_env.items():
|
| 235 |
+
if v is None:
|
| 236 |
+
os.environ.pop(k, None)
|
| 237 |
+
else:
|
| 238 |
+
os.environ[k] = v
|
| 239 |
|
| 240 |
# torch >= 2.6 flipped torch.load(weights_only=True) and newer
|
| 241 |
# transformers dropped the roberta position_ids buffer, so
|
|
|
|
| 320 |
return models_pretrained_dir / "clap" / CLAP_CKPT_FILENAME
|
| 321 |
|
| 322 |
|
| 323 |
+
def clap_hub_dir(models_pretrained_dir: Path) -> Path:
|
| 324 |
+
"""HF cache for laion_clap's text-side deps. Sibling of the .pt."""
|
| 325 |
+
return models_pretrained_dir / "clap" / "hub"
|
| 326 |
+
|
| 327 |
+
|
| 328 |
def clap_checkpoint_available(models_pretrained_dir: Path) -> bool:
|
| 329 |
return clap_checkpoint_path(models_pretrained_dir).exists()
|
| 330 |
|
| 331 |
|
| 332 |
+
def _text_dep_snapshot_present(hub_dir: Path, repo_id: str) -> bool:
|
| 333 |
+
safe = "models--" + repo_id.replace("/", "--")
|
| 334 |
+
snap_root = hub_dir / safe / "snapshots"
|
| 335 |
+
if not snap_root.exists():
|
| 336 |
+
return False
|
| 337 |
+
return any(snap_root.iterdir())
|
| 338 |
+
|
| 339 |
+
|
| 340 |
def download_clap_checkpoint(
|
| 341 |
models_pretrained_dir: Path,
|
| 342 |
progress_cb: Optional[Callable[[str], None]] = None,
|
| 343 |
+
phase_cb: Optional[Callable[[int, int, str], None]] = None,
|
| 344 |
) -> Path:
|
| 345 |
+
"""Download the CLAP audio .pt plus laion_clap's text-side HF snapshots.
|
| 346 |
+
|
| 347 |
+
Four sequential phases β emit a phase update (current, total, label) at the
|
| 348 |
+
start of each so a multi-phase progress UI can show real context. Skips
|
| 349 |
+
phases whose artifacts are already on disk.
|
| 350 |
+
|
| 351 |
+
`progress_cb` (str-only) is kept for the bulk-annotate API.
|
| 352 |
+
`phase_cb` (current, total, label) is the structured channel.
|
| 353 |
+
"""
|
| 354 |
target = clap_checkpoint_path(models_pretrained_dir)
|
| 355 |
target.parent.mkdir(parents=True, exist_ok=True)
|
| 356 |
+
hub_dir = clap_hub_dir(models_pretrained_dir)
|
| 357 |
+
hub_dir.mkdir(parents=True, exist_ok=True)
|
| 358 |
|
| 359 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 360 |
import os
|
| 361 |
|
| 362 |
+
total_phases = 1 + len(CLAP_TEXT_DEPS)
|
| 363 |
+
|
| 364 |
+
def _emit(phase_index: int, label: str) -> None:
|
| 365 |
+
if phase_cb:
|
| 366 |
+
phase_cb(phase_index, total_phases, label)
|
| 367 |
+
if progress_cb:
|
| 368 |
+
progress_cb(f"[{phase_index}/{total_phases}] {label}")
|
| 369 |
+
|
| 370 |
+
if not target.exists():
|
| 371 |
+
_emit(1, "CLAP audio model (~2.35 GB)")
|
| 372 |
+
|
| 373 |
+
# Use custom CLAP from fragmenta-models on HF Spaces
|
| 374 |
+
use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true'
|
| 375 |
+
if use_custom_repo:
|
| 376 |
+
repo_id = "MazCodes/fragmenta-models"
|
| 377 |
+
else:
|
| 378 |
+
repo_id = CLAP_REPO
|
| 379 |
+
|
| 380 |
+
downloaded = hf_hub_download(
|
| 381 |
+
repo_id=repo_id,
|
| 382 |
+
filename=CLAP_CKPT_FILENAME,
|
| 383 |
+
local_dir=str(target.parent),
|
| 384 |
+
)
|
| 385 |
+
downloaded_path = Path(downloaded)
|
| 386 |
+
if downloaded_path != target:
|
| 387 |
+
try:
|
| 388 |
+
downloaded_path.replace(target)
|
| 389 |
+
except OSError:
|
| 390 |
+
import shutil
|
| 391 |
+
shutil.copy2(downloaded_path, target)
|
| 392 |
+
|
| 393 |
+
# laion_clap's CLAP_Module(...) constructor instantiates a Roberta text
|
| 394 |
+
# branch plus bert/bart tokenizers at import time. Pre-stage them into
|
| 395 |
+
# our own cache so the rich tier is fully offline after this step.
|
| 396 |
+
# safetensors only β pytorch_model.bin is a redundant copy.
|
| 397 |
+
for i, repo_id in enumerate(CLAP_TEXT_DEPS, start=2):
|
| 398 |
+
if _text_dep_snapshot_present(hub_dir, repo_id):
|
| 399 |
+
continue
|
| 400 |
+
_emit(i, f"Text encoder: {repo_id}")
|
| 401 |
+
snapshot_download(
|
| 402 |
+
repo_id=repo_id,
|
| 403 |
+
cache_dir=str(hub_dir),
|
| 404 |
+
allow_patterns=[
|
| 405 |
+
"config.json",
|
| 406 |
+
"tokenizer*",
|
| 407 |
+
"vocab*",
|
| 408 |
+
"merges.txt",
|
| 409 |
+
"special_tokens_map.json",
|
| 410 |
+
"model.safetensors",
|
| 411 |
+
],
|
| 412 |
+
)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
return target
|
| 415 |
|
| 416 |
|
|
|
|
| 433 |
label_sets: Dict[str, List[str]],
|
| 434 |
sr: int = 22050,
|
| 435 |
max_seconds: float = 60.0,
|
| 436 |
+
prompt_template: Optional[str] = None,
|
| 437 |
) -> Dict[str, Any]:
|
| 438 |
import librosa
|
| 439 |
+
import warnings
|
| 440 |
|
| 441 |
parts: Dict[str, Any] = {}
|
| 442 |
try:
|
|
|
|
| 450 |
"error": f"load failed: {exc}",
|
| 451 |
}
|
| 452 |
|
| 453 |
+
# Silent / harmonically flat clips trip librosa's "Trying to estimate
|
| 454 |
+
# tuning from empty frequency set" warning during chroma extraction.
|
| 455 |
+
# The warning is benign β the analysis returns sensible defaults β but
|
| 456 |
+
# it spams stderr on every silent file, so we mute it here.
|
| 457 |
+
with warnings.catch_warnings():
|
| 458 |
+
warnings.filterwarnings(
|
| 459 |
+
"ignore",
|
| 460 |
+
message="Trying to estimate tuning from empty frequency set",
|
| 461 |
+
)
|
| 462 |
+
parts["bpm"] = _estimate_tempo(y, loaded_sr)
|
| 463 |
+
parts["key"] = _estimate_key(y, loaded_sr)
|
| 464 |
+
parts["brightness"] = _estimate_brightness(y, loaded_sr)
|
| 465 |
+
parts["character"] = _estimate_character(y, loaded_sr)
|
| 466 |
|
| 467 |
if tier == "rich" and clap_tagger is not None:
|
| 468 |
try:
|
|
|
|
| 471 |
except Exception as exc:
|
| 472 |
logger.warning("CLAP tagging failed for %s: %s", audio_path.name, exc)
|
| 473 |
|
| 474 |
+
# Template-driven prompt assembly. Falls back to the legacy descriptive
|
| 475 |
+
# prose if no template is supplied (call sites that haven't been
|
| 476 |
+
# threaded with project metadata yet).
|
| 477 |
+
if prompt_template is not None and prompt_template.strip():
|
| 478 |
+
from app.backend.data.projects import apply_template
|
| 479 |
+
prompt = apply_template(prompt_template, parts)
|
| 480 |
+
else:
|
| 481 |
+
prompt = _compose_prompt(parts)
|
| 482 |
return {
|
| 483 |
"file_name": audio_path.name,
|
| 484 |
"prompt": prompt,
|
app/backend/data/pre_encoder.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SA3 pre-encoding job runner β Phase 6.
|
| 2 |
+
|
| 3 |
+
Encodes every audio clip in a Dataset Workbench project into SA3 latents
|
| 4 |
+
once, ahead of training, so the training subprocess can skip the SAME
|
| 5 |
+
autoencoder pass per step. Mirrors the shape of `_project_annotate_jobs`
|
| 6 |
+
in app.py (background thread, per-project state, cooperative cancel).
|
| 7 |
+
|
| 8 |
+
Latents land in `<project>/.latents/` β a hidden subdirectory inside the
|
| 9 |
+
project folder. Disk layout matches SA3's `pre_encode_dataset.py`:
|
| 10 |
+
|
| 11 |
+
<project>/.latents/
|
| 12 |
+
000000000000.npy # latent tensor (shape (256, T_lat))
|
| 13 |
+
000000000000.json # {"prompt": "...", "padding_mask": [...], ...}
|
| 14 |
+
000001000000.npy
|
| 15 |
+
000001000000.json
|
| 16 |
+
...
|
| 17 |
+
silence.npy # padding latent (auto-generated)
|
| 18 |
+
_meta.json # Fragmenta-specific: AE used, source clip count
|
| 19 |
+
|
| 20 |
+
SA3's `train_lora.py --encoded_dir <project>/.latents` consumes this layout
|
| 21 |
+
directly. `SA3Trainer._stage_dataset` auto-detects the directory and feeds
|
| 22 |
+
`--encoded_dir` to the subprocess when latents are present.
|
| 23 |
+
|
| 24 |
+
Cache invalidation lives in projects.py β any project mutation that could
|
| 25 |
+
desync the latents (commit, delete_clip, slice_clip) wipes the directory.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import json
|
| 31 |
+
import os
|
| 32 |
+
import re
|
| 33 |
+
import signal
|
| 34 |
+
import subprocess
|
| 35 |
+
import sys
|
| 36 |
+
import threading
|
| 37 |
+
import time
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
from typing import Any, Dict, Optional
|
| 40 |
+
|
| 41 |
+
from app.backend.data.projects import project_path
|
| 42 |
+
from app.core.config import get_config
|
| 43 |
+
from utils.logger import get_logger
|
| 44 |
+
|
| 45 |
+
logger = get_logger("PreEncoder")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# --- Per-project job registry ----------------------------------------------
|
| 49 |
+
|
| 50 |
+
_pre_encode_jobs: Dict[str, Dict[str, Any]] = {}
|
| 51 |
+
_pre_encode_jobs_lock = threading.Lock()
|
| 52 |
+
_pre_encode_processes: Dict[str, subprocess.Popen] = {}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_pre_encode_job(project_name: str) -> Dict[str, Any]:
|
| 56 |
+
"""Snapshot of the current job state for a project. Always returns a
|
| 57 |
+
well-formed dict so the frontend can render against it without guards."""
|
| 58 |
+
with _pre_encode_jobs_lock:
|
| 59 |
+
job = _pre_encode_jobs.get(project_name)
|
| 60 |
+
if job is None:
|
| 61 |
+
return _idle_job()
|
| 62 |
+
return dict(job)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _idle_job() -> Dict[str, Any]:
|
| 66 |
+
return {
|
| 67 |
+
"state": "idle", # idle | queued | running | complete | failed | cancelled
|
| 68 |
+
"current": 0, # batch index (0-based)
|
| 69 |
+
"total": 0, # total batches (derived from clip count)
|
| 70 |
+
"current_file": "",
|
| 71 |
+
"error": None,
|
| 72 |
+
"started_at": None,
|
| 73 |
+
"finished_at": None,
|
| 74 |
+
"autoencoder": None,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# --- Autoencoder selection -------------------------------------------------
|
| 79 |
+
|
| 80 |
+
# Bind latents to a specific SA3 autoencoder. Latents from same-s only work
|
| 81 |
+
# with small-music / small-sfx DiTs; same-l latents only work with medium.
|
| 82 |
+
# For v1 we default to same-s (covers the most common base) and leave a
|
| 83 |
+
# manifest in .latents/_meta.json that training reads to verify
|
| 84 |
+
# compatibility. If a user trains against medium with same-s latents,
|
| 85 |
+
# SA3Trainer falls back to non-encoded training and logs a warning.
|
| 86 |
+
DEFAULT_AUTOENCODER = "same-s"
|
| 87 |
+
|
| 88 |
+
# Audio length (samples per channel) the dataset pads/crops to before
|
| 89 |
+
# encoding. SA3's pre_encode_dataset.py defaults to ~285s at 44.1 kHz, which
|
| 90 |
+
# covers any training-time --duration up to that limit (and SA3 small caps
|
| 91 |
+
# at 120s anyway). Longer clips in the project will be cropped to this
|
| 92 |
+
# length during encoding β a documented limitation for v1.
|
| 93 |
+
DEFAULT_SAMPLE_SIZE = 12_582_912
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# --- Job lifecycle ---------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def latents_dir(project_name: str) -> Path:
|
| 99 |
+
return project_path(project_name) / ".latents"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def latents_count(project_name: str) -> int:
|
| 103 |
+
d = latents_dir(project_name)
|
| 104 |
+
if not d.exists():
|
| 105 |
+
return 0
|
| 106 |
+
return sum(
|
| 107 |
+
1 for p in d.glob("*.npy")
|
| 108 |
+
if p.name != "silence.npy"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def latents_meta(project_name: str) -> Optional[Dict[str, Any]]:
|
| 113 |
+
"""Read the manifest we drop alongside the .npy files."""
|
| 114 |
+
p = latents_dir(project_name) / "_meta.json"
|
| 115 |
+
if not p.exists():
|
| 116 |
+
return None
|
| 117 |
+
try:
|
| 118 |
+
return json.loads(p.read_text(encoding="utf-8"))
|
| 119 |
+
except Exception:
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def latents_match_base(project_name: str, base_model: str) -> bool:
|
| 124 |
+
"""Whether the cached latents are compatible with the chosen base.
|
| 125 |
+
|
| 126 |
+
same-s β small-music / small-sfx (and their *-base variants).
|
| 127 |
+
same-l β medium (and medium-base).
|
| 128 |
+
"""
|
| 129 |
+
meta = latents_meta(project_name)
|
| 130 |
+
if not meta:
|
| 131 |
+
return False
|
| 132 |
+
ae = meta.get("autoencoder")
|
| 133 |
+
if ae == "same-s":
|
| 134 |
+
return base_model in ("sa3-small-music", "sa3-small-music-base",
|
| 135 |
+
"sa3-small-sfx", "sa3-small-sfx-base")
|
| 136 |
+
if ae == "same-l":
|
| 137 |
+
return base_model in ("sa3-medium", "sa3-medium-base")
|
| 138 |
+
return False
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def cancel_pre_encode(project_name: str) -> bool:
|
| 142 |
+
"""Send a cancel signal to an in-flight job. Returns True if cancelled."""
|
| 143 |
+
with _pre_encode_jobs_lock:
|
| 144 |
+
job = _pre_encode_jobs.get(project_name)
|
| 145 |
+
if not job or job.get("state") not in ("queued", "running"):
|
| 146 |
+
return False
|
| 147 |
+
job["state"] = "cancelled"
|
| 148 |
+
job["cancelled"] = True
|
| 149 |
+
|
| 150 |
+
proc = _pre_encode_processes.get(project_name)
|
| 151 |
+
if proc is not None and proc.poll() is None:
|
| 152 |
+
try:
|
| 153 |
+
proc.send_signal(signal.SIGINT)
|
| 154 |
+
try:
|
| 155 |
+
proc.wait(timeout=5)
|
| 156 |
+
except subprocess.TimeoutExpired:
|
| 157 |
+
proc.terminate()
|
| 158 |
+
try:
|
| 159 |
+
proc.wait(timeout=3)
|
| 160 |
+
except subprocess.TimeoutExpired:
|
| 161 |
+
proc.kill()
|
| 162 |
+
except Exception as exc:
|
| 163 |
+
logger.warning("Failed to signal pre-encode subprocess: %s", exc)
|
| 164 |
+
return True
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def start_pre_encode(
|
| 168 |
+
project_name: str,
|
| 169 |
+
autoencoder: Optional[str] = None,
|
| 170 |
+
sample_size: Optional[int] = None,
|
| 171 |
+
) -> Dict[str, Any]:
|
| 172 |
+
"""Spawn the pre-encode subprocess in a background thread. Returns the
|
| 173 |
+
job state β frontend polls /pre-encode/status thereafter.
|
| 174 |
+
"""
|
| 175 |
+
proj_dir = project_path(project_name)
|
| 176 |
+
if not proj_dir.exists():
|
| 177 |
+
raise FileNotFoundError(f"project not found: {project_name}")
|
| 178 |
+
|
| 179 |
+
ae = autoencoder or DEFAULT_AUTOENCODER
|
| 180 |
+
if ae not in ("same-s", "same-l"):
|
| 181 |
+
raise ValueError(f"autoencoder must be 'same-s' or 'same-l'; got {ae!r}")
|
| 182 |
+
|
| 183 |
+
with _pre_encode_jobs_lock:
|
| 184 |
+
existing = _pre_encode_jobs.get(project_name)
|
| 185 |
+
if existing and existing.get("state") in ("queued", "running"):
|
| 186 |
+
return dict(existing)
|
| 187 |
+
|
| 188 |
+
# Count source clips (sidecars committed) so we know the denominator.
|
| 189 |
+
sidecars = list(proj_dir.glob("*.txt"))
|
| 190 |
+
clip_count = sum(
|
| 191 |
+
1 for p in sidecars
|
| 192 |
+
if p.read_text(encoding="utf-8").strip()
|
| 193 |
+
and p.with_suffix(".wav").exists() # cheap & accurate enough
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
job: Dict[str, Any] = {
|
| 197 |
+
"state": "queued",
|
| 198 |
+
"current": 0,
|
| 199 |
+
"total": clip_count,
|
| 200 |
+
"current_file": "",
|
| 201 |
+
"error": None,
|
| 202 |
+
"started_at": time.time(),
|
| 203 |
+
"finished_at": None,
|
| 204 |
+
"autoencoder": ae,
|
| 205 |
+
"cancelled": False,
|
| 206 |
+
}
|
| 207 |
+
_pre_encode_jobs[project_name] = job
|
| 208 |
+
|
| 209 |
+
thread = threading.Thread(
|
| 210 |
+
target=_run_pre_encode,
|
| 211 |
+
args=(project_name, ae, sample_size or DEFAULT_SAMPLE_SIZE),
|
| 212 |
+
daemon=True,
|
| 213 |
+
name=f"sa3-pre-encode:{project_name}",
|
| 214 |
+
)
|
| 215 |
+
thread.start()
|
| 216 |
+
return get_pre_encode_job(project_name)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# --- Worker ----------------------------------------------------------------
|
| 220 |
+
|
| 221 |
+
def _update_job(project_name: str, **fields: Any) -> None:
|
| 222 |
+
with _pre_encode_jobs_lock:
|
| 223 |
+
job = _pre_encode_jobs.get(project_name)
|
| 224 |
+
if job is None:
|
| 225 |
+
return
|
| 226 |
+
job.update(fields)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _run_pre_encode(project_name: str, ae: str, sample_size: int) -> None:
|
| 230 |
+
"""Background-thread target. Spawns the SA3 pre_encode_dataset.py script,
|
| 231 |
+
streams stdout for progress, writes a _meta.json manifest on success."""
|
| 232 |
+
cfg = get_config()
|
| 233 |
+
proj_dir = project_path(project_name)
|
| 234 |
+
out_dir = latents_dir(project_name)
|
| 235 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 236 |
+
|
| 237 |
+
sa3_vendor = cfg.get_path("stable_audio_3")
|
| 238 |
+
venv_python = sys.executable
|
| 239 |
+
|
| 240 |
+
cmd = [
|
| 241 |
+
venv_python,
|
| 242 |
+
str(sa3_vendor / "scripts" / "pre_encode_dataset.py"),
|
| 243 |
+
"--model", ae,
|
| 244 |
+
"--data_dir", str(proj_dir),
|
| 245 |
+
"--output_path", str(out_dir),
|
| 246 |
+
"--batch_size", "1",
|
| 247 |
+
"--sample_size", str(int(sample_size)),
|
| 248 |
+
]
|
| 249 |
+
|
| 250 |
+
env = os.environ.copy()
|
| 251 |
+
pp = env.get("PYTHONPATH", "")
|
| 252 |
+
env["PYTHONPATH"] = (
|
| 253 |
+
f"{sa3_vendor}{os.pathsep}{pp}" if pp else str(sa3_vendor)
|
| 254 |
+
)
|
| 255 |
+
hub_dir = cfg.get_path("models_pretrained") / "sa3" / "hub"
|
| 256 |
+
env["HF_HUB_CACHE"] = str(hub_dir)
|
| 257 |
+
env["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
|
| 258 |
+
env["TRANSFORMERS_CACHE"] = str(hub_dir)
|
| 259 |
+
env["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 260 |
+
env["HF_HUB_OFFLINE"] = "1"
|
| 261 |
+
env["TRANSFORMERS_OFFLINE"] = "1"
|
| 262 |
+
|
| 263 |
+
_update_job(project_name, state="running")
|
| 264 |
+
logger.info(
|
| 265 |
+
"Pre-encoding started Β· project=%s Β· autoencoder=%s Β· clips=%d Β· sample_size=%d",
|
| 266 |
+
project_name, ae, get_pre_encode_job(project_name)["total"], sample_size,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
batch_pat = re.compile(r"Processing batch (\d+)")
|
| 270 |
+
process: Optional[subprocess.Popen] = None
|
| 271 |
+
try:
|
| 272 |
+
process = subprocess.Popen(
|
| 273 |
+
cmd,
|
| 274 |
+
cwd=str(cfg.project_root),
|
| 275 |
+
env=env,
|
| 276 |
+
stdout=subprocess.PIPE,
|
| 277 |
+
stderr=subprocess.STDOUT,
|
| 278 |
+
text=True,
|
| 279 |
+
bufsize=1,
|
| 280 |
+
)
|
| 281 |
+
_pre_encode_processes[project_name] = process
|
| 282 |
+
|
| 283 |
+
if process.stdout is not None:
|
| 284 |
+
for line in process.stdout:
|
| 285 |
+
line = line.rstrip()
|
| 286 |
+
m = batch_pat.search(line)
|
| 287 |
+
if m:
|
| 288 |
+
# Subprocess prints "Processing batch N" once per batch
|
| 289 |
+
# (and batch_size=1 β one batch per clip). N starts at 0.
|
| 290 |
+
_update_job(project_name, current=int(m.group(1)) + 1)
|
| 291 |
+
|
| 292 |
+
rc = process.wait() if process else 1
|
| 293 |
+
|
| 294 |
+
# Check whether we got cancelled mid-flight.
|
| 295 |
+
snapshot = get_pre_encode_job(project_name)
|
| 296 |
+
if snapshot.get("cancelled"):
|
| 297 |
+
_update_job(
|
| 298 |
+
project_name,
|
| 299 |
+
state="cancelled",
|
| 300 |
+
finished_at=time.time(),
|
| 301 |
+
)
|
| 302 |
+
logger.info("Pre-encoding cancelled Β· project=%s", project_name)
|
| 303 |
+
return
|
| 304 |
+
|
| 305 |
+
if rc != 0:
|
| 306 |
+
_update_job(
|
| 307 |
+
project_name,
|
| 308 |
+
state="failed",
|
| 309 |
+
error=f"pre_encode_dataset.py exited with code {rc}",
|
| 310 |
+
finished_at=time.time(),
|
| 311 |
+
)
|
| 312 |
+
logger.error(
|
| 313 |
+
"Pre-encoding failed (exit %s) Β· project=%s",
|
| 314 |
+
rc, project_name,
|
| 315 |
+
)
|
| 316 |
+
return
|
| 317 |
+
|
| 318 |
+
# Success β write manifest so SA3Trainer can verify AE compatibility.
|
| 319 |
+
manifest = {
|
| 320 |
+
"autoencoder": ae,
|
| 321 |
+
"sample_size": sample_size,
|
| 322 |
+
"created_at": time.time(),
|
| 323 |
+
"source_clip_count": snapshot.get("total", 0),
|
| 324 |
+
"encoded_count": latents_count(project_name),
|
| 325 |
+
}
|
| 326 |
+
try:
|
| 327 |
+
(out_dir / "_meta.json").write_text(
|
| 328 |
+
json.dumps(manifest, indent=2), encoding="utf-8",
|
| 329 |
+
)
|
| 330 |
+
except Exception as exc:
|
| 331 |
+
logger.warning("Failed to write latents manifest: %s", exc)
|
| 332 |
+
|
| 333 |
+
_update_job(
|
| 334 |
+
project_name,
|
| 335 |
+
state="complete",
|
| 336 |
+
current=manifest["encoded_count"],
|
| 337 |
+
total=manifest["encoded_count"] or snapshot.get("total", 0),
|
| 338 |
+
finished_at=time.time(),
|
| 339 |
+
)
|
| 340 |
+
logger.info(
|
| 341 |
+
"Pre-encoding complete Β· project=%s Β· %d latent(s) Β· ae=%s",
|
| 342 |
+
project_name, manifest["encoded_count"], ae,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
except Exception as exc:
|
| 346 |
+
_update_job(
|
| 347 |
+
project_name,
|
| 348 |
+
state="failed",
|
| 349 |
+
error=str(exc),
|
| 350 |
+
finished_at=time.time(),
|
| 351 |
+
)
|
| 352 |
+
logger.exception("Pre-encoding crashed for project=%s", project_name)
|
| 353 |
+
finally:
|
| 354 |
+
_pre_encode_processes.pop(project_name, None)
|
app/backend/data/projects.py
ADDED
|
@@ -0,0 +1,1023 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""On-disk projects + buffered in-memory editing for SA3 sidecar datasets.
|
| 2 |
+
|
| 3 |
+
A *project* is a folder under `<user_data_dir>/projects/<name>/` (or wherever
|
| 4 |
+
`FRAGMENTA_PROJECTS_DIR` points) holding audio + `.txt` sidecar pairs plus a
|
| 5 |
+
hidden `.project.json` with Fragmenta metadata. The on-disk folder is the
|
| 6 |
+
**committed** dataset β what training reads, what survives across app
|
| 7 |
+
restarts.
|
| 8 |
+
|
| 9 |
+
The UI works against an **in-memory session** per loaded project. Prompt
|
| 10 |
+
edits, auto-annotate output, and just-ingested audio all live in memory
|
| 11 |
+
until the user explicitly persists them via:
|
| 12 |
+
|
| 13 |
+
Save β write `.draft.json` (transient, hidden). Survives app restart
|
| 14 |
+
but is not the SA3 deliverable.
|
| 15 |
+
Commit β flush prompts to `.txt` sidecars, mark current audio as
|
| 16 |
+
committed in `.project.json`, delete `.draft.json`.
|
| 17 |
+
Discard β drop the in-memory session, delete `.draft.json`, remove any
|
| 18 |
+
audio files added since the last commit.
|
| 19 |
+
|
| 20 |
+
See DATASET_PREP_REDESIGN.md for the full design and rationale.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import json
|
| 26 |
+
import logging
|
| 27 |
+
import os
|
| 28 |
+
import re
|
| 29 |
+
import shutil
|
| 30 |
+
import threading
|
| 31 |
+
import time
|
| 32 |
+
from dataclasses import dataclass, field
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 35 |
+
|
| 36 |
+
from app.backend.data.auto_annotator import AUDIO_EXTENSIONS, _iter_audio_files
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
PROJECT_METADATA_FILENAME = ".project.json"
|
| 41 |
+
PROJECT_DRAFT_FILENAME = ".draft.json"
|
| 42 |
+
DEFAULT_INGEST_MODE = "copy" # copy | symlink
|
| 43 |
+
INGEST_MODES = ("copy", "symlink")
|
| 44 |
+
|
| 45 |
+
# SA3's prompting guide (vendor/stable-audio-3/docs/guides/prompting.md)
|
| 46 |
+
# distinguishes three generation modes β music, stems / solo instruments,
|
| 47 |
+
# and audio samples / SFX β each with its own AudioSparx-tag convention.
|
| 48 |
+
# We ship one preset per mode and let the user pick a single id; the rest
|
| 49 |
+
# is opinionated defaults. Each segment is rendered by apply_template's
|
| 50 |
+
# segment-drop semantics, so missing CLAP attributes never leave dangling
|
| 51 |
+
# punctuation.
|
| 52 |
+
PROMPT_TEMPLATE_PRESETS: Dict[str, Dict[str, str]] = {
|
| 53 |
+
"music": {
|
| 54 |
+
"label": "Music",
|
| 55 |
+
"description": "Full instrumental tracks (SA3's `TrackType: Music` convention).",
|
| 56 |
+
"template": (
|
| 57 |
+
"TrackType: Music, VocalType: Instrumental, "
|
| 58 |
+
"Genre: {genre}, Mood: {mood}, Instruments: {instruments}, "
|
| 59 |
+
"BPM: {bpm}, Key: {key}"
|
| 60 |
+
),
|
| 61 |
+
},
|
| 62 |
+
"instrument": {
|
| 63 |
+
"label": "Instrument / Stem",
|
| 64 |
+
"description": "Isolated parts or single-instrument pieces (`TrackType: Instrument`).",
|
| 65 |
+
"template": (
|
| 66 |
+
"TrackType: Instrument, "
|
| 67 |
+
"Instruments: {instruments}, Genre: {genre}, "
|
| 68 |
+
"BPM: {bpm}, Key: {key}, Mood: {mood}"
|
| 69 |
+
),
|
| 70 |
+
},
|
| 71 |
+
"sfx": {
|
| 72 |
+
"label": "Sample / SFX",
|
| 73 |
+
"description": "Sound effects, one-shots, samples (`TrackType: SFX`).",
|
| 74 |
+
"template": "TrackType: SFX, {brightness}, {character}",
|
| 75 |
+
},
|
| 76 |
+
}
|
| 77 |
+
DEFAULT_PROMPT_TEMPLATE_PRESET = "music"
|
| 78 |
+
|
| 79 |
+
# Names must look like reasonable filesystem folders.
|
| 80 |
+
_VALID_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9 _\-.]{0,99}$")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------- Locations -------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_projects_dir() -> Path:
|
| 87 |
+
"""Resolve the projects root.
|
| 88 |
+
|
| 89 |
+
Honors `FRAGMENTA_PROJECTS_DIR` for power users; otherwise sits next to
|
| 90 |
+
`data/` and `models/` under the configured user_data_dir.
|
| 91 |
+
"""
|
| 92 |
+
override = os.environ.get("FRAGMENTA_PROJECTS_DIR")
|
| 93 |
+
if override:
|
| 94 |
+
root = Path(override).expanduser()
|
| 95 |
+
else:
|
| 96 |
+
from app.core.config import get_config
|
| 97 |
+
root = get_config().user_data_dir / "projects"
|
| 98 |
+
root.mkdir(parents=True, exist_ok=True)
|
| 99 |
+
return root
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def project_path(name: str) -> Path:
|
| 103 |
+
return get_projects_dir() / name
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def project_metadata_path(name: str) -> Path:
|
| 107 |
+
return project_path(name) / PROJECT_METADATA_FILENAME
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def project_draft_path(name: str) -> Path:
|
| 111 |
+
return project_path(name) / PROJECT_DRAFT_FILENAME
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ---------- Validation ------------------------------------------------------
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def sanitize_project_name(raw: Any) -> str:
|
| 118 |
+
if not isinstance(raw, str):
|
| 119 |
+
raise ValueError("Project name must be a string.")
|
| 120 |
+
name = raw.strip()
|
| 121 |
+
if not name:
|
| 122 |
+
raise ValueError("Project name cannot be empty.")
|
| 123 |
+
if name in (".", ".."):
|
| 124 |
+
raise ValueError("Invalid project name.")
|
| 125 |
+
if not _VALID_NAME_RE.match(name):
|
| 126 |
+
raise ValueError(
|
| 127 |
+
"Project name must start with a letter or digit and may only "
|
| 128 |
+
"contain letters, digits, spaces, dashes, underscores, and dots."
|
| 129 |
+
)
|
| 130 |
+
return name
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ---------- Disk persistence: committed state -------------------------------
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _default_metadata(name: str) -> Dict[str, Any]:
|
| 137 |
+
now = time.time()
|
| 138 |
+
return {
|
| 139 |
+
"name": name,
|
| 140 |
+
"created_at": now,
|
| 141 |
+
"modified_at": now,
|
| 142 |
+
"committed_at": None,
|
| 143 |
+
"ingest_mode": DEFAULT_INGEST_MODE,
|
| 144 |
+
"prompt_template_preset": DEFAULT_PROMPT_TEMPLATE_PRESET,
|
| 145 |
+
"source_folders": [],
|
| 146 |
+
"committed_files": [], # files written to disk + already committed
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _read_metadata(name: str) -> Dict[str, Any]:
|
| 151 |
+
path = project_metadata_path(name)
|
| 152 |
+
if not path.exists():
|
| 153 |
+
return _default_metadata(name)
|
| 154 |
+
try:
|
| 155 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 156 |
+
data = json.load(f)
|
| 157 |
+
except (OSError, json.JSONDecodeError) as exc:
|
| 158 |
+
logger.warning("Could not read project metadata %s: %s; using defaults.", path, exc)
|
| 159 |
+
return _default_metadata(name)
|
| 160 |
+
defaults = _default_metadata(name)
|
| 161 |
+
for k, v in defaults.items():
|
| 162 |
+
data.setdefault(k, v)
|
| 163 |
+
return data
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _write_metadata(name: str, metadata: Dict[str, Any]) -> None:
|
| 167 |
+
metadata["modified_at"] = time.time()
|
| 168 |
+
path = project_metadata_path(name)
|
| 169 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 170 |
+
tmp = path.with_suffix(path.suffix + ".tmp")
|
| 171 |
+
with open(tmp, "w", encoding="utf-8") as f:
|
| 172 |
+
json.dump(metadata, f, indent=2)
|
| 173 |
+
os.replace(tmp, path)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _sidecar_for(audio_path: Path) -> Path:
|
| 177 |
+
return audio_path.with_suffix(".txt")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _read_sidecar(audio_path: Path) -> str:
|
| 181 |
+
txt = _sidecar_for(audio_path)
|
| 182 |
+
if not txt.exists():
|
| 183 |
+
return ""
|
| 184 |
+
try:
|
| 185 |
+
return txt.read_text(encoding="utf-8").strip()
|
| 186 |
+
except OSError:
|
| 187 |
+
return ""
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _write_sidecar(audio_path: Path, prompt: str) -> None:
|
| 191 |
+
_sidecar_for(audio_path).write_text(prompt or "", encoding="utf-8")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ---------- Disk persistence: draft state -----------------------------------
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _read_draft(name: str) -> Optional[Dict[str, Any]]:
|
| 198 |
+
path = project_draft_path(name)
|
| 199 |
+
if not path.exists():
|
| 200 |
+
return None
|
| 201 |
+
try:
|
| 202 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 203 |
+
return json.load(f)
|
| 204 |
+
except (OSError, json.JSONDecodeError) as exc:
|
| 205 |
+
logger.warning("Could not read draft %s: %s; treating as no draft.", path, exc)
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _write_draft(name: str, draft: Dict[str, Any]) -> None:
|
| 210 |
+
draft["saved_at"] = time.time()
|
| 211 |
+
path = project_draft_path(name)
|
| 212 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 213 |
+
tmp = path.with_suffix(path.suffix + ".tmp")
|
| 214 |
+
with open(tmp, "w", encoding="utf-8") as f:
|
| 215 |
+
json.dump(draft, f, indent=2)
|
| 216 |
+
os.replace(tmp, path)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _delete_draft(name: str) -> None:
|
| 220 |
+
path = project_draft_path(name)
|
| 221 |
+
if path.exists():
|
| 222 |
+
path.unlink()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# ---------- In-memory session ----------------------------------------------
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@dataclass
|
| 229 |
+
class ClipState:
|
| 230 |
+
"""One clip in an active project session.
|
| 231 |
+
|
| 232 |
+
`prompt` is the live in-memory value (what the UI shows). `committed_prompt`
|
| 233 |
+
is what's on disk in the sidecar β used to compute dirtiness.
|
| 234 |
+
|
| 235 |
+
`parent` is the original clip's file_name if this clip was produced by a
|
| 236 |
+
slice operation in the current session. In-memory only; not persisted
|
| 237 |
+
across restart (yet). Future merge-back will need disk-level lineage.
|
| 238 |
+
"""
|
| 239 |
+
file_name: str
|
| 240 |
+
path: str
|
| 241 |
+
prompt: str = ""
|
| 242 |
+
committed_prompt: str = ""
|
| 243 |
+
committed: bool = True # False if audio was added since last commit
|
| 244 |
+
parent: Optional[str] = None
|
| 245 |
+
|
| 246 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 247 |
+
return {
|
| 248 |
+
"file_name": self.file_name,
|
| 249 |
+
"path": self.path,
|
| 250 |
+
"prompt": self.prompt,
|
| 251 |
+
"committed_prompt": self.committed_prompt,
|
| 252 |
+
"committed": self.committed,
|
| 253 |
+
"dirty": self.prompt != self.committed_prompt,
|
| 254 |
+
"parent": self.parent,
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@dataclass
|
| 259 |
+
class ProjectSession:
|
| 260 |
+
"""In-memory view of a project. One per loaded project name.
|
| 261 |
+
|
| 262 |
+
Loading happens lazily on first GET. The session stays alive until
|
| 263 |
+
the user discards, commits, or the process exits.
|
| 264 |
+
"""
|
| 265 |
+
name: str
|
| 266 |
+
clips: Dict[str, ClipState] = field(default_factory=dict) # by file_name
|
| 267 |
+
saved_at: Optional[float] = None # last time .draft.json was written
|
| 268 |
+
last_save_snapshot: Dict[str, str] = field(default_factory=dict)
|
| 269 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 270 |
+
cancel_event: threading.Event = field(default_factory=threading.Event)
|
| 271 |
+
lock: threading.Lock = field(default_factory=threading.Lock)
|
| 272 |
+
# file_name -> (peaks, duration). Lazily filled by get_or_compute_peaks.
|
| 273 |
+
# Cleared on Discard. Survives an annotate; safe to recompute on miss.
|
| 274 |
+
peaks_cache: Dict[str, Tuple[List[float], float]] = field(default_factory=dict)
|
| 275 |
+
# file_name -> duration_sec. Same lifecycle, but populated cheaply via
|
| 276 |
+
# soundfile.info() instead of waiting for a peaks fetch.
|
| 277 |
+
duration_cache: Dict[str, float] = field(default_factory=dict)
|
| 278 |
+
|
| 279 |
+
def _draft_snapshot(self) -> Dict[str, str]:
|
| 280 |
+
"""Map file_name -> prompt, only for clips whose prompt differs from
|
| 281 |
+
the committed sidecar. Used both to decide if a Save is needed and
|
| 282 |
+
to compute the on-disk draft contents."""
|
| 283 |
+
return {c.file_name: c.prompt for c in self.clips.values() if c.prompt != c.committed_prompt}
|
| 284 |
+
|
| 285 |
+
def has_dirty_prompts(self) -> bool:
|
| 286 |
+
return any(c.prompt != c.committed_prompt for c in self.clips.values())
|
| 287 |
+
|
| 288 |
+
def has_uncommitted_files(self) -> bool:
|
| 289 |
+
return any(not c.committed for c in self.clips.values())
|
| 290 |
+
|
| 291 |
+
def has_unsaved_changes(self) -> bool:
|
| 292 |
+
"""True if the in-memory state differs from the saved draft."""
|
| 293 |
+
return self._draft_snapshot() != self.last_save_snapshot
|
| 294 |
+
|
| 295 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 296 |
+
ordered = sorted(self.clips.values(), key=lambda c: c.file_name)
|
| 297 |
+
# Phase 6 β pre-encoded latents state. The latents live inside the
|
| 298 |
+
# project at .latents/. Surface presence + count for the UI, plus
|
| 299 |
+
# the per-project "don't ask again" flag for the post-commit dialog.
|
| 300 |
+
proj_path = project_path(self.name)
|
| 301 |
+
latents_dir = proj_path / ".latents"
|
| 302 |
+
latents_npy = (
|
| 303 |
+
[p for p in latents_dir.glob("*.npy") if p.name != "silence.npy"]
|
| 304 |
+
if latents_dir.exists() else []
|
| 305 |
+
)
|
| 306 |
+
return {
|
| 307 |
+
"name": self.name,
|
| 308 |
+
"created_at": self.metadata.get("created_at"),
|
| 309 |
+
"modified_at": self.metadata.get("modified_at"),
|
| 310 |
+
"committed_at": self.metadata.get("committed_at"),
|
| 311 |
+
"ingest_mode": self.metadata.get("ingest_mode", DEFAULT_INGEST_MODE),
|
| 312 |
+
"prompt_template_preset": (
|
| 313 |
+
self.metadata.get("prompt_template_preset") or DEFAULT_PROMPT_TEMPLATE_PRESET
|
| 314 |
+
),
|
| 315 |
+
"prompt_template_presets": [
|
| 316 |
+
{"id": k, "label": v["label"], "description": v["description"], "template": v["template"]}
|
| 317 |
+
for k, v in PROMPT_TEMPLATE_PRESETS.items()
|
| 318 |
+
],
|
| 319 |
+
"source_folders": list(self.metadata.get("source_folders", [])),
|
| 320 |
+
"saved_at": self.saved_at,
|
| 321 |
+
"dirty": self.has_dirty_prompts() or self.has_uncommitted_files(),
|
| 322 |
+
"has_unsaved_changes": self.has_unsaved_changes(),
|
| 323 |
+
"uncommitted_files": [c.file_name for c in ordered if not c.committed],
|
| 324 |
+
"clips": [c.to_dict() for c in ordered],
|
| 325 |
+
"clip_count": len(self.clips),
|
| 326 |
+
"latents_present": bool(latents_npy),
|
| 327 |
+
"latents_count": len(latents_npy),
|
| 328 |
+
"suppress_pre_encode_prompt": bool(self.metadata.get("suppress_pre_encode_prompt")),
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
# Registry of active sessions keyed by project name.
|
| 333 |
+
_sessions: Dict[str, ProjectSession] = {}
|
| 334 |
+
_sessions_lock = threading.Lock()
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def _get_or_load_session(name: str) -> ProjectSession:
|
| 338 |
+
"""Return the active session for `name`, loading from disk if needed."""
|
| 339 |
+
with _sessions_lock:
|
| 340 |
+
existing = _sessions.get(name)
|
| 341 |
+
if existing is not None:
|
| 342 |
+
return existing
|
| 343 |
+
|
| 344 |
+
# Validate folder exists.
|
| 345 |
+
path = project_path(name)
|
| 346 |
+
if not path.exists() or not path.is_dir():
|
| 347 |
+
raise FileNotFoundError(f"Project not found: {name}")
|
| 348 |
+
|
| 349 |
+
metadata = _read_metadata(name)
|
| 350 |
+
committed_files = set(metadata.get("committed_files") or [])
|
| 351 |
+
|
| 352 |
+
# Build clip states from the disk layout. `committed_prompt` is whatever's
|
| 353 |
+
# in the .txt sidecar today.
|
| 354 |
+
clips: Dict[str, ClipState] = {}
|
| 355 |
+
for audio_path in sorted(path.iterdir()):
|
| 356 |
+
if not audio_path.is_file():
|
| 357 |
+
continue
|
| 358 |
+
if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
|
| 359 |
+
continue
|
| 360 |
+
committed_prompt = _read_sidecar(audio_path)
|
| 361 |
+
is_committed = audio_path.name in committed_files
|
| 362 |
+
clips[audio_path.name] = ClipState(
|
| 363 |
+
file_name=audio_path.name,
|
| 364 |
+
path=str(audio_path),
|
| 365 |
+
prompt=committed_prompt,
|
| 366 |
+
committed_prompt=committed_prompt,
|
| 367 |
+
committed=is_committed,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
session = ProjectSession(name=name, clips=clips, metadata=metadata)
|
| 371 |
+
|
| 372 |
+
# Overlay any draft prompts on top of committed values.
|
| 373 |
+
draft = _read_draft(name)
|
| 374 |
+
if draft:
|
| 375 |
+
for file_name, prompt in (draft.get("prompts") or {}).items():
|
| 376 |
+
clip = session.clips.get(file_name)
|
| 377 |
+
if clip is not None:
|
| 378 |
+
clip.prompt = prompt
|
| 379 |
+
session.saved_at = draft.get("saved_at")
|
| 380 |
+
session.last_save_snapshot = dict(draft.get("prompts") or {})
|
| 381 |
+
|
| 382 |
+
with _sessions_lock:
|
| 383 |
+
# Race: another thread may have loaded concurrently. Use whichever
|
| 384 |
+
# got in first.
|
| 385 |
+
existing = _sessions.get(name)
|
| 386 |
+
if existing is not None:
|
| 387 |
+
return existing
|
| 388 |
+
_sessions[name] = session
|
| 389 |
+
return session
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _drop_session(name: str) -> None:
|
| 393 |
+
with _sessions_lock:
|
| 394 |
+
_sessions.pop(name, None)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# ---------- CRUD ------------------------------------------------------------
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def list_projects() -> List[Dict[str, Any]]:
|
| 401 |
+
root = get_projects_dir()
|
| 402 |
+
out: List[Dict[str, Any]] = []
|
| 403 |
+
for entry in sorted(root.iterdir()):
|
| 404 |
+
if not entry.is_dir() or entry.name.startswith("."):
|
| 405 |
+
continue
|
| 406 |
+
try:
|
| 407 |
+
meta = _read_metadata(entry.name)
|
| 408 |
+
except Exception as exc:
|
| 409 |
+
logger.warning("Skipping project %s: %s", entry.name, exc)
|
| 410 |
+
continue
|
| 411 |
+
clip_count = sum(
|
| 412 |
+
1 for f in entry.iterdir()
|
| 413 |
+
if f.is_file() and f.suffix.lower() in AUDIO_EXTENSIONS
|
| 414 |
+
)
|
| 415 |
+
has_draft = project_draft_path(entry.name).exists()
|
| 416 |
+
out.append({
|
| 417 |
+
"name": entry.name,
|
| 418 |
+
"created_at": meta.get("created_at"),
|
| 419 |
+
"modified_at": meta.get("modified_at"),
|
| 420 |
+
"committed_at": meta.get("committed_at"),
|
| 421 |
+
"clip_count": clip_count,
|
| 422 |
+
"has_draft": has_draft,
|
| 423 |
+
})
|
| 424 |
+
return out
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def create_project(name: str) -> Dict[str, Any]:
|
| 428 |
+
name = sanitize_project_name(name)
|
| 429 |
+
path = project_path(name)
|
| 430 |
+
if path.exists():
|
| 431 |
+
raise FileExistsError(f"Project '{name}' already exists.")
|
| 432 |
+
path.mkdir(parents=True)
|
| 433 |
+
metadata = _default_metadata(name)
|
| 434 |
+
_write_metadata(name, metadata)
|
| 435 |
+
return get_project(name)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def get_project(name: str) -> Dict[str, Any]:
|
| 439 |
+
session = _get_or_load_session(name)
|
| 440 |
+
with session.lock:
|
| 441 |
+
return session.to_dict()
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _stage_file(src: Path, dst: Path, mode: str) -> str:
|
| 445 |
+
"""Place `src` at `dst` using the requested ingest mode."""
|
| 446 |
+
if dst.exists() or dst.is_symlink():
|
| 447 |
+
return "skipped"
|
| 448 |
+
if mode == "symlink":
|
| 449 |
+
try:
|
| 450 |
+
dst.symlink_to(src.resolve())
|
| 451 |
+
return "symlinked"
|
| 452 |
+
except OSError as exc:
|
| 453 |
+
logger.warning("Symlink failed for %s -> %s: %s; falling back to copy.", src, dst, exc)
|
| 454 |
+
shutil.copy2(src, dst)
|
| 455 |
+
return "copied"
|
| 456 |
+
else:
|
| 457 |
+
shutil.copy2(src, dst)
|
| 458 |
+
return "copied"
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def ingest_folder(name: str, source_folder: Path, mode: str) -> Dict[str, Any]:
|
| 462 |
+
"""Add every audio file under `source_folder` to project `name`.
|
| 463 |
+
|
| 464 |
+
Audio is written to disk immediately (we don't buffer gigabytes). The
|
| 465 |
+
new files are flagged as uncommitted in the session so a later Discard
|
| 466 |
+
can remove them.
|
| 467 |
+
"""
|
| 468 |
+
if mode not in INGEST_MODES:
|
| 469 |
+
raise ValueError(f"Invalid ingest mode: {mode}")
|
| 470 |
+
if not source_folder.exists() or not source_folder.is_dir():
|
| 471 |
+
raise FileNotFoundError(f"Source folder not found: {source_folder}")
|
| 472 |
+
|
| 473 |
+
session = _get_or_load_session(name)
|
| 474 |
+
proj_path = project_path(name)
|
| 475 |
+
|
| 476 |
+
files = _iter_audio_files(source_folder)
|
| 477 |
+
if not files:
|
| 478 |
+
raise ValueError(f"No audio files found in {source_folder}")
|
| 479 |
+
|
| 480 |
+
copied = 0
|
| 481 |
+
symlinked = 0
|
| 482 |
+
skipped = 0
|
| 483 |
+
with session.lock:
|
| 484 |
+
for src in files:
|
| 485 |
+
dst = proj_path / src.name
|
| 486 |
+
tag = _stage_file(src, dst, mode)
|
| 487 |
+
if tag == "copied":
|
| 488 |
+
copied += 1
|
| 489 |
+
elif tag == "symlinked":
|
| 490 |
+
symlinked += 1
|
| 491 |
+
else:
|
| 492 |
+
skipped += 1
|
| 493 |
+
if tag != "skipped" and src.name not in session.clips:
|
| 494 |
+
# Newly added file β uncommitted.
|
| 495 |
+
session.clips[src.name] = ClipState(
|
| 496 |
+
file_name=src.name,
|
| 497 |
+
path=str(dst),
|
| 498 |
+
prompt="",
|
| 499 |
+
committed_prompt="",
|
| 500 |
+
committed=False,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
session.metadata["ingest_mode"] = mode
|
| 504 |
+
src_abs = str(source_folder.resolve())
|
| 505 |
+
if src_abs not in session.metadata.setdefault("source_folders", []):
|
| 506 |
+
session.metadata["source_folders"].append(src_abs)
|
| 507 |
+
|
| 508 |
+
return {
|
| 509 |
+
"copied": copied,
|
| 510 |
+
"symlinked": symlinked,
|
| 511 |
+
"skipped": skipped,
|
| 512 |
+
"added": copied + symlinked,
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def update_clip_prompt(name: str, file_name: str, prompt: str) -> Dict[str, Any]:
|
| 517 |
+
"""In-memory only. Disk is not touched until Save or Commit."""
|
| 518 |
+
session = _get_or_load_session(name)
|
| 519 |
+
with session.lock:
|
| 520 |
+
clip = session.clips.get(file_name)
|
| 521 |
+
if clip is None:
|
| 522 |
+
raise FileNotFoundError(f"Clip not found in project '{name}': {file_name}")
|
| 523 |
+
clip.prompt = prompt or ""
|
| 524 |
+
return clip.to_dict()
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def delete_clip(name: str, file_name: str) -> None:
|
| 528 |
+
"""Remove a clip immediately (audio + sidecar + session entry).
|
| 529 |
+
|
| 530 |
+
Treated like ingest: the disk change happens now, since carrying a
|
| 531 |
+
pending-deletion in memory complicates everything for no real win.
|
| 532 |
+
Discard cannot recover deleted files.
|
| 533 |
+
"""
|
| 534 |
+
session = _get_or_load_session(name)
|
| 535 |
+
proj_path = project_path(name)
|
| 536 |
+
with session.lock:
|
| 537 |
+
audio_path = proj_path / file_name
|
| 538 |
+
txt_path = _sidecar_for(audio_path)
|
| 539 |
+
if audio_path.exists():
|
| 540 |
+
audio_path.unlink()
|
| 541 |
+
if txt_path.exists():
|
| 542 |
+
txt_path.unlink()
|
| 543 |
+
session.clips.pop(file_name, None)
|
| 544 |
+
# Evict any cached peaks for this file (regardless of N).
|
| 545 |
+
for key in list(session.peaks_cache):
|
| 546 |
+
if key.startswith(f"{file_name}:"):
|
| 547 |
+
del session.peaks_cache[key]
|
| 548 |
+
session.duration_cache.pop(file_name, None)
|
| 549 |
+
committed = session.metadata.get("committed_files") or []
|
| 550 |
+
if file_name in committed:
|
| 551 |
+
session.metadata["committed_files"] = [f for f in committed if f != file_name]
|
| 552 |
+
# Invalidate latents β outside the lock so we don't block under FS I/O.
|
| 553 |
+
_invalidate_latents(name)
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# ---------- Save / Commit / Discard -----------------------------------------
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def save_project(name: str) -> Dict[str, Any]:
|
| 560 |
+
"""Persist the current in-memory prompt diffs as a hidden draft."""
|
| 561 |
+
session = _get_or_load_session(name)
|
| 562 |
+
with session.lock:
|
| 563 |
+
snapshot = session._draft_snapshot()
|
| 564 |
+
draft = {
|
| 565 |
+
"prompts": snapshot,
|
| 566 |
+
"uncommitted_files": [c.file_name for c in session.clips.values() if not c.committed],
|
| 567 |
+
}
|
| 568 |
+
_write_draft(name, draft)
|
| 569 |
+
session.saved_at = time.time()
|
| 570 |
+
session.last_save_snapshot = dict(snapshot)
|
| 571 |
+
return session.to_dict()
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def _invalidate_latents(name: str) -> None:
|
| 575 |
+
"""Phase 6 β wipe any pre-encoded latents for this project.
|
| 576 |
+
|
| 577 |
+
Latents are bound to specific source-clip content; any mutation that
|
| 578 |
+
changes the source set (commit, delete_clip, slice_clip) renders them
|
| 579 |
+
misaligned. v1 strategy is wipe-and-recompute; per-clip invalidation
|
| 580 |
+
is a follow-up (not worth the complexity for the speed-up we get).
|
| 581 |
+
"""
|
| 582 |
+
latents_dir = project_path(name) / ".latents"
|
| 583 |
+
if latents_dir.exists():
|
| 584 |
+
shutil.rmtree(latents_dir, ignore_errors=True)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def update_pre_encode_suppression(name: str, suppress: bool) -> Dict[str, Any]:
|
| 588 |
+
"""Persist the 'Don't ask again' choice from the post-commit dialog.
|
| 589 |
+
|
| 590 |
+
Stored on .project.json so it survives restart. The Training-tab
|
| 591 |
+
fallback button is always available regardless of this flag.
|
| 592 |
+
"""
|
| 593 |
+
session = _get_or_load_session(name)
|
| 594 |
+
with session.lock:
|
| 595 |
+
session.metadata["suppress_pre_encode_prompt"] = bool(suppress)
|
| 596 |
+
_write_metadata(name, session.metadata)
|
| 597 |
+
return session.to_dict()
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def commit_project(name: str) -> Dict[str, Any]:
|
| 601 |
+
"""Flush in-memory state to disk as the canonical SA3 dataset.
|
| 602 |
+
|
| 603 |
+
Overwrites existing sidecars. Marks all current audio as committed.
|
| 604 |
+
Deletes any draft. Wipes any pre-encoded latents β re-encode is
|
| 605 |
+
explicit via the post-commit dialog or the Training-tab button.
|
| 606 |
+
"""
|
| 607 |
+
_invalidate_latents(name)
|
| 608 |
+
session = _get_or_load_session(name)
|
| 609 |
+
proj_path = project_path(name)
|
| 610 |
+
with session.lock:
|
| 611 |
+
# Write a sidecar for every clip, even if the prompt didn't change.
|
| 612 |
+
# This guarantees the on-disk state is exactly the in-memory state
|
| 613 |
+
# after Commit, no surprises.
|
| 614 |
+
for clip in session.clips.values():
|
| 615 |
+
audio_path = proj_path / clip.file_name
|
| 616 |
+
_write_sidecar(audio_path, clip.prompt)
|
| 617 |
+
clip.committed_prompt = clip.prompt
|
| 618 |
+
clip.committed = True
|
| 619 |
+
|
| 620 |
+
session.metadata["committed_files"] = sorted(session.clips.keys())
|
| 621 |
+
session.metadata["committed_at"] = time.time()
|
| 622 |
+
_write_metadata(name, session.metadata)
|
| 623 |
+
_delete_draft(name)
|
| 624 |
+
session.saved_at = None
|
| 625 |
+
session.last_save_snapshot = {}
|
| 626 |
+
return session.to_dict()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def delete_project(name: str) -> None:
|
| 630 |
+
"""Permanently remove a project β folder, sidecars, drafts, session.
|
| 631 |
+
|
| 632 |
+
Destructive: there is no recovery path. Caller should confirm with
|
| 633 |
+
the user before invoking.
|
| 634 |
+
"""
|
| 635 |
+
proj_path = project_path(name)
|
| 636 |
+
if not proj_path.exists():
|
| 637 |
+
raise FileNotFoundError(f"Project not found: {name}")
|
| 638 |
+
|
| 639 |
+
# Cancel any in-flight annotate first, drop the session, then nuke
|
| 640 |
+
# the folder. Order matters: if we rm the folder while another
|
| 641 |
+
# thread is writing to it (e.g. annotate writing prompts to memory
|
| 642 |
+
# is fine, but the audio-stream endpoint could be holding a file
|
| 643 |
+
# handle), at least the session is gone so no fresh writes happen.
|
| 644 |
+
with _sessions_lock:
|
| 645 |
+
existing = _sessions.pop(name, None)
|
| 646 |
+
if existing is not None:
|
| 647 |
+
existing.cancel_event.set()
|
| 648 |
+
shutil.rmtree(proj_path, ignore_errors=True)
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def discard_project(name: str) -> Dict[str, Any]:
|
| 652 |
+
"""Throw away all uncommitted work.
|
| 653 |
+
|
| 654 |
+
- Delete the draft.
|
| 655 |
+
- Delete audio files added since the last commit (and their sidecars).
|
| 656 |
+
- Drop the in-memory session so the next GET rebuilds from disk.
|
| 657 |
+
"""
|
| 658 |
+
session = _get_or_load_session(name)
|
| 659 |
+
proj_path = project_path(name)
|
| 660 |
+
with session.lock:
|
| 661 |
+
# Cancel any in-flight annotate before we tear state apart.
|
| 662 |
+
session.cancel_event.set()
|
| 663 |
+
|
| 664 |
+
uncommitted = [c.file_name for c in session.clips.values() if not c.committed]
|
| 665 |
+
for file_name in uncommitted:
|
| 666 |
+
audio_path = proj_path / file_name
|
| 667 |
+
txt_path = _sidecar_for(audio_path)
|
| 668 |
+
if audio_path.exists():
|
| 669 |
+
audio_path.unlink()
|
| 670 |
+
if txt_path.exists():
|
| 671 |
+
txt_path.unlink()
|
| 672 |
+
|
| 673 |
+
_delete_draft(name)
|
| 674 |
+
|
| 675 |
+
_drop_session(name)
|
| 676 |
+
return get_project(name)
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
# ---------- Annotate cancellation handle ------------------------------------
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def get_session_handle(name: str) -> ProjectSession:
|
| 683 |
+
"""Used by the annotate endpoint to share a cancel handle + clip dict."""
|
| 684 |
+
return _get_or_load_session(name)
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def reset_cancel(session: ProjectSession) -> None:
|
| 688 |
+
session.cancel_event.clear()
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
# ---------- Prompt template -------------------------------------------------
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
_TEMPLATE_VAR_RE = re.compile(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}")
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def _render_value(name: str, raw: Any) -> str:
|
| 698 |
+
"""Stringify one variable value. Lists get joined; falsy is empty."""
|
| 699 |
+
if raw is None:
|
| 700 |
+
return ""
|
| 701 |
+
if isinstance(raw, (list, tuple)):
|
| 702 |
+
parts = [str(x).strip() for x in raw if str(x).strip()]
|
| 703 |
+
return ", ".join(parts)
|
| 704 |
+
text = str(raw).strip()
|
| 705 |
+
return text
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def apply_template(template: str, attributes: Dict[str, Any]) -> str:
|
| 709 |
+
"""Segment-based templating with graceful missing-value handling.
|
| 710 |
+
|
| 711 |
+
The template is split on ',' (segments). For each segment, every
|
| 712 |
+
{var} placeholder is resolved against `attributes`. If any placeholder
|
| 713 |
+
in the segment resolves to empty/missing, the whole segment is dropped
|
| 714 |
+
β so a missing key/BPM/whatever doesn't leave dangling punctuation.
|
| 715 |
+
|
| 716 |
+
Segments without any placeholders (e.g. "TrackType: Music") always
|
| 717 |
+
appear.
|
| 718 |
+
"""
|
| 719 |
+
if not template:
|
| 720 |
+
return ""
|
| 721 |
+
out_segments: List[str] = []
|
| 722 |
+
for raw_segment in template.split(","):
|
| 723 |
+
segment = raw_segment.strip()
|
| 724 |
+
if not segment:
|
| 725 |
+
continue
|
| 726 |
+
var_names = _TEMPLATE_VAR_RE.findall(segment)
|
| 727 |
+
if var_names:
|
| 728 |
+
resolved = {n: _render_value(n, attributes.get(n)) for n in var_names}
|
| 729 |
+
if any(not v for v in resolved.values()):
|
| 730 |
+
continue # drop the segment β one of its vars is missing
|
| 731 |
+
segment = _TEMPLATE_VAR_RE.sub(
|
| 732 |
+
lambda m: resolved[m.group(1)],
|
| 733 |
+
segment,
|
| 734 |
+
)
|
| 735 |
+
out_segments.append(segment)
|
| 736 |
+
return ", ".join(out_segments)
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def resolve_prompt_template(session: "ProjectSession") -> str:
|
| 740 |
+
"""Return the active template string for the project's selected preset.
|
| 741 |
+
|
| 742 |
+
Falls back to the music default if the stored preset id is unknown
|
| 743 |
+
(e.g. someone hand-edited .project.json to a bad value).
|
| 744 |
+
"""
|
| 745 |
+
preset_id = (session.metadata.get("prompt_template_preset")
|
| 746 |
+
or DEFAULT_PROMPT_TEMPLATE_PRESET)
|
| 747 |
+
preset = PROMPT_TEMPLATE_PRESETS.get(preset_id)
|
| 748 |
+
if preset is None:
|
| 749 |
+
preset = PROMPT_TEMPLATE_PRESETS[DEFAULT_PROMPT_TEMPLATE_PRESET]
|
| 750 |
+
return preset["template"]
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def update_project_template_preset(name: str, preset_id: str) -> Dict[str, Any]:
|
| 754 |
+
"""Persist the user-selected preset id and return updated project state."""
|
| 755 |
+
if not isinstance(preset_id, str) or preset_id not in PROMPT_TEMPLATE_PRESETS:
|
| 756 |
+
valid = ", ".join(PROMPT_TEMPLATE_PRESETS.keys())
|
| 757 |
+
raise ValueError(f"Unknown preset id: {preset_id!r}. Valid: {valid}")
|
| 758 |
+
session = _get_or_load_session(name)
|
| 759 |
+
with session.lock:
|
| 760 |
+
session.metadata["prompt_template_preset"] = preset_id
|
| 761 |
+
# Drop the legacy free-form field so we stop carrying two parallel
|
| 762 |
+
# ways to configure annotation shape.
|
| 763 |
+
session.metadata.pop("prompt_template", None)
|
| 764 |
+
_write_metadata(name, session.metadata)
|
| 765 |
+
return get_project(name)
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
# ---------- Waveform peaks --------------------------------------------------
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def _compute_peaks(audio_path: Path, n: int) -> Tuple[List[float], float]:
|
| 772 |
+
"""Return N normalized peak amplitudes + duration in seconds.
|
| 773 |
+
|
| 774 |
+
Reads N short blocks at evenly spaced offsets via soundfile.seek instead
|
| 775 |
+
of decoding the whole file. ~40x faster than librosa.load on a typical
|
| 776 |
+
30s clip; bounded I/O regardless of file length (a 5-minute clip costs
|
| 777 |
+
the same as a 30s one).
|
| 778 |
+
|
| 779 |
+
Falls back to a librosa-based decode for formats soundfile can't open
|
| 780 |
+
on this build (typically m4a/aac without ffmpeg-libsndfile).
|
| 781 |
+
"""
|
| 782 |
+
import numpy as np
|
| 783 |
+
try:
|
| 784 |
+
import soundfile as sf
|
| 785 |
+
with sf.SoundFile(str(audio_path)) as src:
|
| 786 |
+
total = src.frames
|
| 787 |
+
sr = src.samplerate
|
| 788 |
+
if total == 0:
|
| 789 |
+
return ([0.0] * n, 0.0)
|
| 790 |
+
duration = float(total / sr)
|
| 791 |
+
# ~6 buckets-worth of samples per probe gives stable peaks without
|
| 792 |
+
# devolving into "read the whole file."
|
| 793 |
+
block = max(256, total // (n * 6))
|
| 794 |
+
peaks = np.zeros(n, dtype="float32")
|
| 795 |
+
for i in range(n):
|
| 796 |
+
center = int((i + 0.5) * total / n)
|
| 797 |
+
start = max(0, center - block // 2)
|
| 798 |
+
src.seek(start)
|
| 799 |
+
data = src.read(block, dtype="float32", always_2d=False)
|
| 800 |
+
if data.ndim > 1:
|
| 801 |
+
data = data.max(axis=1)
|
| 802 |
+
if len(data):
|
| 803 |
+
peaks[i] = float(np.abs(data).max())
|
| 804 |
+
max_peak = float(peaks.max())
|
| 805 |
+
if max_peak > 0:
|
| 806 |
+
peaks = peaks / max_peak
|
| 807 |
+
return (peaks.tolist(), duration)
|
| 808 |
+
except Exception as exc:
|
| 809 |
+
logger.debug("soundfile peak path failed for %s (%s); falling back to librosa", audio_path.name, exc)
|
| 810 |
+
|
| 811 |
+
# Fallback: librosa.load handles every codec we register, at the cost of
|
| 812 |
+
# a full-file decode + resample. Slower but bulletproof.
|
| 813 |
+
import librosa
|
| 814 |
+
y, sr = librosa.load(str(audio_path), sr=8000, mono=True)
|
| 815 |
+
if len(y) == 0:
|
| 816 |
+
return ([0.0] * n, 0.0)
|
| 817 |
+
duration = float(len(y) / sr)
|
| 818 |
+
chunks = np.array_split(y, n)
|
| 819 |
+
peaks = np.array([float(np.abs(c).max()) if len(c) else 0.0 for c in chunks])
|
| 820 |
+
max_peak = peaks.max()
|
| 821 |
+
if max_peak > 0:
|
| 822 |
+
peaks = peaks / max_peak
|
| 823 |
+
return (peaks.tolist(), duration)
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def get_or_compute_peaks(
|
| 827 |
+
session: ProjectSession,
|
| 828 |
+
file_name: str,
|
| 829 |
+
audio_path: Path,
|
| 830 |
+
n: int = 200,
|
| 831 |
+
) -> Tuple[List[float], float]:
|
| 832 |
+
"""Memoized per-session peak computation. Cache key is `file_name:N`."""
|
| 833 |
+
cache_key = f"{file_name}:{n}"
|
| 834 |
+
cached = session.peaks_cache.get(cache_key)
|
| 835 |
+
if cached is not None:
|
| 836 |
+
return cached
|
| 837 |
+
result = _compute_peaks(audio_path, n)
|
| 838 |
+
session.peaks_cache[cache_key] = result
|
| 839 |
+
return result
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
# ---------- Health checks ---------------------------------------------------
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
def _clip_duration_sec(audio_path: Path) -> Optional[float]:
|
| 846 |
+
"""Cheap duration probe via soundfile.info() β header read, no decode."""
|
| 847 |
+
try:
|
| 848 |
+
import soundfile as sf
|
| 849 |
+
info = sf.info(str(audio_path))
|
| 850 |
+
if info.samplerate <= 0:
|
| 851 |
+
return None
|
| 852 |
+
return float(info.frames / info.samplerate)
|
| 853 |
+
except Exception:
|
| 854 |
+
return None
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
def compute_health(
|
| 858 |
+
name: str,
|
| 859 |
+
short_threshold_sec: float = 1.0,
|
| 860 |
+
) -> Dict[str, Any]:
|
| 861 |
+
"""Per-clip checks that surface dataset problems before training.
|
| 862 |
+
|
| 863 |
+
Note: we don't flag "too long" clips. The SA3 dataloader handles them
|
| 864 |
+
via random-crop per __getitem__ β long files just get sampled at
|
| 865 |
+
different windows across epochs. Slicing remains useful for annotation
|
| 866 |
+
granularity and CLAP's 10s window, but it's not a correctness issue.
|
| 867 |
+
|
| 868 |
+
We also don't flag mixed sample rates or loudness: SA3 resamples every
|
| 869 |
+
file to its model rate (T.Resample in its dataset loader) and Fragmenta
|
| 870 |
+
enables SA3's built-in -16 LUFS VolumeNorm at train/pre-encode time, so
|
| 871 |
+
both are handled automatically downstream.
|
| 872 |
+
|
| 873 |
+
short_threshold_sec defaults to 1s β clips below this end up mostly
|
| 874 |
+
silence-padded into the training window.
|
| 875 |
+
"""
|
| 876 |
+
from collections import defaultdict
|
| 877 |
+
|
| 878 |
+
# Single source of truth for what SA3's loader actually accepts. Fragmenta
|
| 879 |
+
# ingest accepts a wider set (.m4a, .aac) β those files would be silently
|
| 880 |
+
# skipped at train time, so we surface them here.
|
| 881 |
+
from app.core.training.sa3_lora_runner import SA3_AUDIO_EXTENSIONS
|
| 882 |
+
|
| 883 |
+
session = _get_or_load_session(name)
|
| 884 |
+
with session.lock:
|
| 885 |
+
clips = list(session.clips.values())
|
| 886 |
+
|
| 887 |
+
empty_prompts: List[str] = []
|
| 888 |
+
too_short: List[str] = []
|
| 889 |
+
unsupported_format: List[str] = []
|
| 890 |
+
prompt_groups: Dict[str, List[str]] = defaultdict(list)
|
| 891 |
+
|
| 892 |
+
for c in clips:
|
| 893 |
+
if not (c.prompt or "").strip():
|
| 894 |
+
empty_prompts.append(c.file_name)
|
| 895 |
+
else:
|
| 896 |
+
prompt_groups[c.prompt.strip().lower()].append(c.file_name)
|
| 897 |
+
|
| 898 |
+
ext = Path(c.file_name).suffix.lower()
|
| 899 |
+
if ext not in SA3_AUDIO_EXTENSIONS:
|
| 900 |
+
unsupported_format.append(c.file_name)
|
| 901 |
+
|
| 902 |
+
# Duration (header-only, ~free) β only used for the too-short check now.
|
| 903 |
+
dur = session.duration_cache.get(c.file_name)
|
| 904 |
+
if dur is None:
|
| 905 |
+
dur = _clip_duration_sec(Path(c.path))
|
| 906 |
+
if dur is not None:
|
| 907 |
+
session.duration_cache[c.file_name] = dur
|
| 908 |
+
if dur is not None and dur < short_threshold_sec:
|
| 909 |
+
too_short.append(c.file_name)
|
| 910 |
+
|
| 911 |
+
# --- Duplicate annotations: any non-empty prompt shared by 2+ clips.
|
| 912 |
+
dup_groups = [files for files in prompt_groups.values() if len(files) > 1]
|
| 913 |
+
dup_files = sorted({f for group in dup_groups for f in group})
|
| 914 |
+
|
| 915 |
+
empty_prompts.sort()
|
| 916 |
+
too_short.sort()
|
| 917 |
+
unsupported_format.sort()
|
| 918 |
+
|
| 919 |
+
return {
|
| 920 |
+
"total_clips": len(clips),
|
| 921 |
+
"empty_prompts": {"count": len(empty_prompts), "files": empty_prompts},
|
| 922 |
+
"too_short": {
|
| 923 |
+
"count": len(too_short),
|
| 924 |
+
"threshold_sec": short_threshold_sec,
|
| 925 |
+
"files": too_short,
|
| 926 |
+
},
|
| 927 |
+
"unsupported_format": {
|
| 928 |
+
"count": len(unsupported_format),
|
| 929 |
+
"accepted": sorted(SA3_AUDIO_EXTENSIONS),
|
| 930 |
+
"files": unsupported_format,
|
| 931 |
+
},
|
| 932 |
+
"duplicate_annotations": {
|
| 933 |
+
"count": len(dup_files),
|
| 934 |
+
"group_count": len(dup_groups),
|
| 935 |
+
"files": dup_files,
|
| 936 |
+
},
|
| 937 |
+
}
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
# ---------- Slicing ---------------------------------------------------------
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
def slice_clip(
|
| 944 |
+
name: str,
|
| 945 |
+
file_name: str,
|
| 946 |
+
target_sec: float,
|
| 947 |
+
overlap_sec: float,
|
| 948 |
+
strategy: str,
|
| 949 |
+
) -> Dict[str, Any]:
|
| 950 |
+
"""Split one clip into N children. Disk-level β happens immediately.
|
| 951 |
+
|
| 952 |
+
The parent file (and its sidecar) is deleted. Each child:
|
| 953 |
+
- lives in the project folder as `<stem>__NNN.wav`
|
| 954 |
+
- inherits the parent's in-memory prompt verbatim
|
| 955 |
+
- is uncommitted (so Discard rolls it back)
|
| 956 |
+
- keeps `parent=<parent_file_name>` in its session state
|
| 957 |
+
|
| 958 |
+
Discard cannot recover the parent file from children β same rule as
|
| 959 |
+
delete_clip. Commit makes the slice permanent.
|
| 960 |
+
"""
|
| 961 |
+
from app.backend.data.slicing import plan_slices, write_slices
|
| 962 |
+
|
| 963 |
+
session = _get_or_load_session(name)
|
| 964 |
+
proj_path = project_path(name)
|
| 965 |
+
audio_path = proj_path / file_name
|
| 966 |
+
|
| 967 |
+
if not audio_path.exists():
|
| 968 |
+
raise FileNotFoundError(f"Clip not on disk: {file_name}")
|
| 969 |
+
|
| 970 |
+
plans = plan_slices(audio_path, target_sec, overlap_sec, strategy)
|
| 971 |
+
if len(plans) <= 1:
|
| 972 |
+
raise ValueError(
|
| 973 |
+
f"{file_name} is shorter than the target duration "
|
| 974 |
+
f"({target_sec:.1f}s); nothing to slice."
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
stem = audio_path.stem
|
| 978 |
+
children = write_slices(audio_path, plans, proj_path, stem)
|
| 979 |
+
if not children:
|
| 980 |
+
raise RuntimeError("Slice produced no children β check the audio file.")
|
| 981 |
+
|
| 982 |
+
with session.lock:
|
| 983 |
+
parent_clip = session.clips.get(file_name)
|
| 984 |
+
inherited_prompt = parent_clip.prompt if parent_clip else ""
|
| 985 |
+
|
| 986 |
+
# Remove the parent from session + disk.
|
| 987 |
+
session.clips.pop(file_name, None)
|
| 988 |
+
for key in list(session.peaks_cache):
|
| 989 |
+
if key.startswith(f"{file_name}:"):
|
| 990 |
+
del session.peaks_cache[key]
|
| 991 |
+
session.duration_cache.pop(file_name, None)
|
| 992 |
+
sidecar = _sidecar_for(audio_path)
|
| 993 |
+
if audio_path.exists():
|
| 994 |
+
audio_path.unlink()
|
| 995 |
+
if sidecar.exists():
|
| 996 |
+
sidecar.unlink()
|
| 997 |
+
committed = session.metadata.get("committed_files") or []
|
| 998 |
+
if file_name in committed:
|
| 999 |
+
session.metadata["committed_files"] = [f for f in committed if f != file_name]
|
| 1000 |
+
|
| 1001 |
+
# Register children as uncommitted clips with parent linkage.
|
| 1002 |
+
for child_path in children:
|
| 1003 |
+
session.clips[child_path.name] = ClipState(
|
| 1004 |
+
file_name=child_path.name,
|
| 1005 |
+
path=str(child_path),
|
| 1006 |
+
prompt=inherited_prompt,
|
| 1007 |
+
committed_prompt="",
|
| 1008 |
+
committed=False,
|
| 1009 |
+
parent=file_name,
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
# Slicing replaces the parent's audio with N children β any cached
|
| 1013 |
+
# latents reference the deleted parent and are now misaligned.
|
| 1014 |
+
_invalidate_latents(name)
|
| 1015 |
+
|
| 1016 |
+
return {
|
| 1017 |
+
"parent": file_name,
|
| 1018 |
+
"children": [
|
| 1019 |
+
{"file_name": p.name, "start_sec": pl.start_sec, "end_sec": pl.end_sec}
|
| 1020 |
+
for p, pl in zip(children, plans)
|
| 1021 |
+
],
|
| 1022 |
+
"project": get_project(name),
|
| 1023 |
+
}
|
app/backend/data/slicing.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio slicing for the Dataset Workbench.
|
| 2 |
+
|
| 3 |
+
Splits one audio file into N children. Three strategies:
|
| 4 |
+
|
| 5 |
+
hard β uniform cuts every `target_duration` seconds.
|
| 6 |
+
transient β uniform anchor points, each snapped to the nearest onset
|
| 7 |
+
(librosa.onset.onset_detect).
|
| 8 |
+
silence β uniform anchor points, each snapped to the nearest low-RMS
|
| 9 |
+
window (cleanest splice between phrases).
|
| 10 |
+
|
| 11 |
+
All three honor `overlap_sec`, applied as a head-overlap on every child
|
| 12 |
+
after the first: child i starts at (end of child i-1) - overlap_sec.
|
| 13 |
+
|
| 14 |
+
Writes WAV regardless of source format (lossless, no codec deps). Parent
|
| 15 |
+
prompt is inherited verbatim; the user edits children individually after.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import List, Literal, Tuple
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
SliceStrategy = Literal["hard", "transient", "silence"]
|
| 28 |
+
VALID_STRATEGIES = ("hard", "transient", "silence")
|
| 29 |
+
|
| 30 |
+
# How far a snap is allowed to move from the uniform anchor. Beyond this we
|
| 31 |
+
# just take the anchor β better a tidy cut than a wildly off-target chunk.
|
| 32 |
+
SNAP_WINDOW_FRAC = 0.35
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class SlicePlan:
|
| 37 |
+
"""One child's location inside the parent. Times are in seconds."""
|
| 38 |
+
index: int # 1-based
|
| 39 |
+
start_sec: float
|
| 40 |
+
end_sec: float
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _uniform_anchors(duration_sec: float, target_sec: float, overlap_sec: float) -> List[Tuple[float, float]]:
|
| 44 |
+
"""Return [(start, end), ...] for uniform cuts, before any snapping."""
|
| 45 |
+
if target_sec <= 0:
|
| 46 |
+
raise ValueError("target_duration must be positive")
|
| 47 |
+
if overlap_sec < 0 or overlap_sec >= target_sec:
|
| 48 |
+
raise ValueError("overlap_sec must be >= 0 and < target_duration")
|
| 49 |
+
step = target_sec - overlap_sec
|
| 50 |
+
anchors: List[Tuple[float, float]] = []
|
| 51 |
+
start = 0.0
|
| 52 |
+
while start < duration_sec - 0.05: # don't emit a sub-50ms tail
|
| 53 |
+
end = min(start + target_sec, duration_sec)
|
| 54 |
+
anchors.append((start, end))
|
| 55 |
+
if end >= duration_sec:
|
| 56 |
+
break
|
| 57 |
+
start += step
|
| 58 |
+
return anchors
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _snap_to_onsets(anchors: List[Tuple[float, float]], y, sr: int, target_sec: float) -> List[Tuple[float, float]]:
|
| 62 |
+
"""Snap each cut boundary to the nearest detected onset within a window."""
|
| 63 |
+
import librosa
|
| 64 |
+
import numpy as np
|
| 65 |
+
onsets = librosa.onset.onset_detect(y=y, sr=sr, units="time", backtrack=True)
|
| 66 |
+
if len(onsets) == 0:
|
| 67 |
+
return anchors
|
| 68 |
+
snap_window = target_sec * SNAP_WINDOW_FRAC
|
| 69 |
+
out: List[Tuple[float, float]] = []
|
| 70 |
+
for i, (s, e) in enumerate(anchors):
|
| 71 |
+
if i > 0:
|
| 72 |
+
# Snap the start (= previous end) to nearest onset within window.
|
| 73 |
+
candidates = onsets[(onsets >= s - snap_window) & (onsets <= s + snap_window)]
|
| 74 |
+
if len(candidates):
|
| 75 |
+
s = float(min(candidates, key=lambda t: abs(t - s)))
|
| 76 |
+
out.append((s, e))
|
| 77 |
+
# Stitch ends to match next start so no gap/overlap drift creeps in.
|
| 78 |
+
for i in range(len(out) - 1):
|
| 79 |
+
s, _ = out[i]
|
| 80 |
+
next_s, _ = out[i + 1]
|
| 81 |
+
out[i] = (s, next_s + (target_sec * 0.0)) # next_s alone β overlap is in next_s already from caller
|
| 82 |
+
return out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _snap_to_silence(anchors: List[Tuple[float, float]], y, sr: int, target_sec: float) -> List[Tuple[float, float]]:
|
| 86 |
+
"""Snap each cut boundary to the lowest-RMS frame within a window."""
|
| 87 |
+
import librosa
|
| 88 |
+
import numpy as np
|
| 89 |
+
# Frame-level RMS at ~20ms hop.
|
| 90 |
+
hop = max(1, sr // 50)
|
| 91 |
+
rms = librosa.feature.rms(y=y, frame_length=hop * 2, hop_length=hop)[0]
|
| 92 |
+
if len(rms) == 0:
|
| 93 |
+
return anchors
|
| 94 |
+
frame_times = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop)
|
| 95 |
+
snap_window = target_sec * SNAP_WINDOW_FRAC
|
| 96 |
+
out: List[Tuple[float, float]] = []
|
| 97 |
+
for i, (s, e) in enumerate(anchors):
|
| 98 |
+
if i > 0:
|
| 99 |
+
mask = (frame_times >= s - snap_window) & (frame_times <= s + snap_window)
|
| 100 |
+
if mask.any():
|
| 101 |
+
local_idx = int(np.argmin(rms[mask]))
|
| 102 |
+
# Map masked-index back to absolute time.
|
| 103 |
+
masked_times = frame_times[mask]
|
| 104 |
+
s = float(masked_times[local_idx])
|
| 105 |
+
out.append((s, e))
|
| 106 |
+
return out
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def plan_slices(
|
| 110 |
+
audio_path: Path,
|
| 111 |
+
target_sec: float,
|
| 112 |
+
overlap_sec: float,
|
| 113 |
+
strategy: SliceStrategy,
|
| 114 |
+
) -> List[SlicePlan]:
|
| 115 |
+
"""Compute the (start, end) for each child without writing anything yet."""
|
| 116 |
+
if strategy not in VALID_STRATEGIES:
|
| 117 |
+
raise ValueError(f"Unknown strategy: {strategy}")
|
| 118 |
+
import librosa
|
| 119 |
+
# Use mono for boundary detection only; final write uses the original.
|
| 120 |
+
y, sr = librosa.load(str(audio_path), sr=22050, mono=True)
|
| 121 |
+
duration = float(len(y) / sr) if len(y) else 0.0
|
| 122 |
+
if duration <= 0:
|
| 123 |
+
raise ValueError(f"{audio_path.name} has zero duration")
|
| 124 |
+
if duration < target_sec:
|
| 125 |
+
# Single child = the whole file. Skip the slice loop entirely.
|
| 126 |
+
return [SlicePlan(index=1, start_sec=0.0, end_sec=duration)]
|
| 127 |
+
|
| 128 |
+
anchors = _uniform_anchors(duration, target_sec, overlap_sec)
|
| 129 |
+
if strategy == "transient":
|
| 130 |
+
anchors = _snap_to_onsets(anchors, y, sr, target_sec)
|
| 131 |
+
elif strategy == "silence":
|
| 132 |
+
anchors = _snap_to_silence(anchors, y, sr, target_sec)
|
| 133 |
+
|
| 134 |
+
return [
|
| 135 |
+
SlicePlan(index=i + 1, start_sec=s, end_sec=e)
|
| 136 |
+
for i, (s, e) in enumerate(anchors)
|
| 137 |
+
]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def write_slices(
|
| 141 |
+
audio_path: Path,
|
| 142 |
+
plans: List[SlicePlan],
|
| 143 |
+
out_dir: Path,
|
| 144 |
+
stem: str,
|
| 145 |
+
) -> List[Path]:
|
| 146 |
+
"""Write children as `<stem>__001.wav`, `<stem>__002.wav`, ... in `out_dir`.
|
| 147 |
+
|
| 148 |
+
Uses soundfile for lossless WAV write at the source's native sample rate.
|
| 149 |
+
Skips names that already exist on disk to avoid clobbering.
|
| 150 |
+
"""
|
| 151 |
+
import soundfile as sf
|
| 152 |
+
import numpy as np
|
| 153 |
+
|
| 154 |
+
info = sf.info(str(audio_path))
|
| 155 |
+
sr = info.samplerate
|
| 156 |
+
total_frames = info.frames
|
| 157 |
+
written: List[Path] = []
|
| 158 |
+
width = max(3, len(str(len(plans))))
|
| 159 |
+
|
| 160 |
+
with sf.SoundFile(str(audio_path)) as src:
|
| 161 |
+
for plan in plans:
|
| 162 |
+
start_frame = max(0, int(plan.start_sec * sr))
|
| 163 |
+
end_frame = min(total_frames, int(plan.end_sec * sr))
|
| 164 |
+
if end_frame <= start_frame:
|
| 165 |
+
logger.warning("Skipping empty slice %s [%.2f-%.2f]", plan.index, plan.start_sec, plan.end_sec)
|
| 166 |
+
continue
|
| 167 |
+
src.seek(start_frame)
|
| 168 |
+
data = src.read(end_frame - start_frame, dtype="float32", always_2d=True)
|
| 169 |
+
|
| 170 |
+
child_name = f"{stem}__{plan.index:0{width}d}.wav"
|
| 171 |
+
child_path = out_dir / child_name
|
| 172 |
+
if child_path.exists():
|
| 173 |
+
# Don't silently overwrite; bump the suffix until free.
|
| 174 |
+
k = 2
|
| 175 |
+
while True:
|
| 176 |
+
candidate = out_dir / f"{stem}__{plan.index:0{width}d}_{k}.wav"
|
| 177 |
+
if not candidate.exists():
|
| 178 |
+
child_path = candidate
|
| 179 |
+
break
|
| 180 |
+
k += 1
|
| 181 |
+
sf.write(str(child_path), data, sr, subtype="PCM_16")
|
| 182 |
+
written.append(child_path)
|
| 183 |
+
return written
|
app/core/audio/midi_input.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Native MIDI input.
|
| 2 |
+
|
| 3 |
+
Reads hardware MIDI via python-rtmidi (CoreMIDI on macOS, WinMM on Windows,
|
| 4 |
+
ALSA on Linux) so MIDI works regardless of the web engine the OS gives us β
|
| 5 |
+
WKWebView has no Web MIDI, WebView2's is flaky. Same pattern as the native
|
| 6 |
+
Ableton Link binding in link_sync.py: wrap an optional native lib and no-op
|
| 7 |
+
gracefully if it isn't importable.
|
| 8 |
+
|
| 9 |
+
The backend owns the *transport* only: it enumerates input ports, opens one,
|
| 10 |
+
and broadcasts incoming messages to subscribers (drained by the SSE endpoint
|
| 11 |
+
in app.py). All mapping / learn / takeover logic stays in the frontend
|
| 12 |
+
MidiContext β it just consumes these events instead of Web MIDI.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import ctypes
|
| 17 |
+
import glob
|
| 18 |
+
import os
|
| 19 |
+
import queue
|
| 20 |
+
import sys
|
| 21 |
+
import threading
|
| 22 |
+
from typing import Any, Dict, List, Optional
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _preload_bundled_jack() -> None:
|
| 26 |
+
"""Work around a broken RPATH in python-rtmidi's manylinux wheel.
|
| 27 |
+
|
| 28 |
+
The wheel bundles libjack as `python_rtmidi/libjack-<hash>.so.*`, but the
|
| 29 |
+
`_rtmidi` extension's RPATH points at a directory that doesn't exist
|
| 30 |
+
(`$ORIGIN/../python_rtmidi.` β note the stray trailing dot), so the loader
|
| 31 |
+
can't find it and `import rtmidi` dies with
|
| 32 |
+
`ImportError: libjack-<hash>.so...: cannot open shared object file`.
|
| 33 |
+
|
| 34 |
+
The bundled lib's soname matches the extension's DT_NEEDED exactly, so
|
| 35 |
+
dlopen'ing it with RTLD_GLOBAL first lets the loader satisfy the dependency
|
| 36 |
+
from the already-loaded object. Doing it here (rather than patching the
|
| 37 |
+
venv) survives a pip reinstall and needs no patchelf/root. Linux-only; a
|
| 38 |
+
no-op everywhere the glob finds nothing.
|
| 39 |
+
"""
|
| 40 |
+
if not sys.platform.startswith("linux"):
|
| 41 |
+
return
|
| 42 |
+
for base in sys.path:
|
| 43 |
+
if not base or not os.path.isdir(base):
|
| 44 |
+
continue
|
| 45 |
+
for lib in glob.glob(os.path.join(base, "python_rtmidi*", "libjack-*.so*")):
|
| 46 |
+
try:
|
| 47 |
+
ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL)
|
| 48 |
+
except OSError:
|
| 49 |
+
pass
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
import rtmidi # python-rtmidi
|
| 54 |
+
_RTMIDI_OK = True
|
| 55 |
+
except Exception: # pragma: no cover - import guard
|
| 56 |
+
# Most likely the bundled-libjack RPATH bug β preload it and retry once.
|
| 57 |
+
try:
|
| 58 |
+
_preload_bundled_jack()
|
| 59 |
+
import rtmidi
|
| 60 |
+
_RTMIDI_OK = True
|
| 61 |
+
except Exception:
|
| 62 |
+
rtmidi = None
|
| 63 |
+
_RTMIDI_OK = False
|
| 64 |
+
|
| 65 |
+
_lock = threading.Lock()
|
| 66 |
+
_midi_in: Any = None # the open rtmidi.MidiIn, or None
|
| 67 |
+
_current_port: Optional[str] = None # name of the open port, or None
|
| 68 |
+
_subscribers: List["queue.Queue"] = []
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def is_available() -> bool:
|
| 72 |
+
"""True if the native MIDI backend is importable."""
|
| 73 |
+
return _RTMIDI_OK
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def list_inputs() -> List[Dict[str, Any]]:
|
| 77 |
+
"""Enumerate input ports. `id` is the port name (stable across index
|
| 78 |
+
shuffles); `index` is its current rtmidi index."""
|
| 79 |
+
if not _RTMIDI_OK:
|
| 80 |
+
return []
|
| 81 |
+
mi = rtmidi.MidiIn()
|
| 82 |
+
try:
|
| 83 |
+
names = mi.get_ports()
|
| 84 |
+
finally:
|
| 85 |
+
mi.delete()
|
| 86 |
+
return [{"id": name, "name": name, "index": i} for i, name in enumerate(names)]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def current_port() -> Optional[str]:
|
| 90 |
+
with _lock:
|
| 91 |
+
return _current_port
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _on_message(event, _data=None) -> None:
|
| 95 |
+
"""rtmidi callback (runs on its own thread). `event` is (message, delta).
|
| 96 |
+
Broadcast the raw status/data bytes so the frontend can reuse its existing
|
| 97 |
+
Web-MIDI-shaped dispatcher unchanged."""
|
| 98 |
+
message, _delta = event
|
| 99 |
+
payload = {"data": list(message)}
|
| 100 |
+
with _lock:
|
| 101 |
+
subs = list(_subscribers)
|
| 102 |
+
for q in subs:
|
| 103 |
+
try:
|
| 104 |
+
q.put_nowait(payload)
|
| 105 |
+
except queue.Full:
|
| 106 |
+
pass # slow consumer β drop rather than block the MIDI thread
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def close_input() -> None:
|
| 110 |
+
global _midi_in, _current_port
|
| 111 |
+
with _lock:
|
| 112 |
+
mi = _midi_in
|
| 113 |
+
_midi_in = None
|
| 114 |
+
_current_port = None
|
| 115 |
+
if mi is not None:
|
| 116 |
+
try:
|
| 117 |
+
mi.cancel_callback()
|
| 118 |
+
except Exception:
|
| 119 |
+
pass
|
| 120 |
+
try:
|
| 121 |
+
mi.close_port()
|
| 122 |
+
except Exception:
|
| 123 |
+
pass
|
| 124 |
+
try:
|
| 125 |
+
mi.delete()
|
| 126 |
+
except Exception:
|
| 127 |
+
pass
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def open_input(port_id: Optional[str]) -> bool:
|
| 131 |
+
"""Open the input port whose name == port_id. A falsy port_id just closes
|
| 132 |
+
the current port. Returns True on success (or on a pure close)."""
|
| 133 |
+
if not _RTMIDI_OK:
|
| 134 |
+
return False
|
| 135 |
+
close_input()
|
| 136 |
+
if not port_id:
|
| 137 |
+
return True
|
| 138 |
+
|
| 139 |
+
mi = rtmidi.MidiIn()
|
| 140 |
+
idx = None
|
| 141 |
+
for i, name in enumerate(mi.get_ports()):
|
| 142 |
+
if name == port_id:
|
| 143 |
+
idx = i
|
| 144 |
+
break
|
| 145 |
+
if idx is None:
|
| 146 |
+
mi.delete()
|
| 147 |
+
return False
|
| 148 |
+
|
| 149 |
+
mi.open_port(idx)
|
| 150 |
+
# Drop sysex / timing-clock / active-sensing so the stream stays to the
|
| 151 |
+
# control messages the mapper cares about (CC + notes).
|
| 152 |
+
mi.ignore_types(sysex=True, timing=True, active_sense=True)
|
| 153 |
+
mi.set_callback(_on_message)
|
| 154 |
+
|
| 155 |
+
global _midi_in, _current_port
|
| 156 |
+
with _lock:
|
| 157 |
+
_midi_in = mi
|
| 158 |
+
_current_port = port_id
|
| 159 |
+
return True
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def subscribe() -> "queue.Queue":
|
| 163 |
+
q: "queue.Queue" = queue.Queue(maxsize=512)
|
| 164 |
+
with _lock:
|
| 165 |
+
_subscribers.append(q)
|
| 166 |
+
return q
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def unsubscribe(q: "queue.Queue") -> None:
|
| 170 |
+
with _lock:
|
| 171 |
+
if q in _subscribers:
|
| 172 |
+
_subscribers.remove(q)
|
app/core/config.py
CHANGED
|
@@ -4,6 +4,7 @@ from pathlib import Path
|
|
| 4 |
from typing import Dict, Any, Optional
|
| 5 |
import json
|
| 6 |
|
|
|
|
| 7 |
class ProjectConfig:
|
| 8 |
|
| 9 |
def __init__(self, project_root: Optional[Path] = None) -> None:
|
|
@@ -18,11 +19,11 @@ class ProjectConfig:
|
|
| 18 |
self.user_data_dir = Path.home() / "Library" / "Application Support" / "FragmentaDesktop"
|
| 19 |
else:
|
| 20 |
self.user_data_dir = Path.home() / ".local" / "share" / "FragmentaDesktop"
|
| 21 |
-
|
| 22 |
self.user_data_dir.mkdir(parents=True, exist_ok=True)
|
| 23 |
print(f"Running in frozen mode. Project root: {self.project_root}")
|
| 24 |
print(f"User data directory: {self.user_data_dir}")
|
| 25 |
-
|
| 26 |
else:
|
| 27 |
self.frozen = False
|
| 28 |
if project_root is None:
|
|
@@ -37,123 +38,52 @@ class ProjectConfig:
|
|
| 37 |
break
|
| 38 |
else:
|
| 39 |
project_root = config_file_dir
|
| 40 |
-
|
| 41 |
self.project_root: Path = Path(project_root).resolve()
|
| 42 |
self.user_data_dir = self.project_root
|
| 43 |
|
| 44 |
fine_tuned_override = os.environ.get("FRAGMENTA_FINE_TUNED_DIR")
|
| 45 |
fine_tuned_dir = Path(fine_tuned_override) if fine_tuned_override else self.user_data_dir / "models" / "fine_tuned"
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
self.paths: Dict[str, Path] = {
|
| 51 |
"models": self.user_data_dir / "models",
|
| 52 |
"models_config": self.user_data_dir / "models" / "config",
|
| 53 |
"models_pretrained": self.user_data_dir / "models" / "pretrained",
|
| 54 |
"models_fine_tuned": fine_tuned_dir,
|
| 55 |
-
"
|
| 56 |
"logs": self.user_data_dir / "logs",
|
| 57 |
"output": self.user_data_dir / "output",
|
| 58 |
|
| 59 |
"application": self.project_root,
|
| 60 |
"backend": self.project_root / "app" / "backend",
|
| 61 |
"frontend": self.project_root / "app" / "frontend",
|
| 62 |
-
"
|
| 63 |
-
"loraw_vendor": self.project_root / "vendor" / "loraw_vendor",
|
| 64 |
"venv": self.project_root / "venv",
|
| 65 |
}
|
| 66 |
|
| 67 |
self._ensure_directories()
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
|
| 71 |
def _ensure_directories(self) -> None:
|
| 72 |
|
| 73 |
for path_name, path in self.paths.items():
|
| 74 |
-
if path_name.endswith(('_fine_tuned', '
|
| 75 |
path.mkdir(parents=True, exist_ok=True)
|
| 76 |
|
| 77 |
-
def _load_model_configs(self) -> Dict[str, Dict[str, str]]:
|
| 78 |
-
|
| 79 |
-
return {
|
| 80 |
-
"stable-audio-open-1.0": {
|
| 81 |
-
"config": str(self.paths["models_config"] / "model_config.json"),
|
| 82 |
-
"ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-model.safetensors")
|
| 83 |
-
},
|
| 84 |
-
"stable-audio-open-small": {
|
| 85 |
-
"config": str(self.paths["models_config"] / "model_config_small.json"),
|
| 86 |
-
"ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-small-model.safetensors")
|
| 87 |
-
},
|
| 88 |
-
"custom": {
|
| 89 |
-
"config": str(self.paths["models_config"] / "model_config_small.json"),
|
| 90 |
-
"ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-small-model.safetensors")
|
| 91 |
-
}
|
| 92 |
-
}
|
| 93 |
-
|
| 94 |
def get_path(self, path_name: str) -> Path:
|
| 95 |
if path_name not in self.paths:
|
| 96 |
raise ValueError(f"Unknown path name: {path_name}")
|
| 97 |
return self.paths[path_name]
|
| 98 |
|
| 99 |
-
def get_model_config(self, model_name: str) -> Dict[str, str]:
|
| 100 |
-
if model_name not in self.model_configs:
|
| 101 |
-
raise ValueError(f"Unknown model: {model_name}")
|
| 102 |
-
return self.model_configs[model_name]
|
| 103 |
-
|
| 104 |
-
def get_dataset_config_path(self) -> str:
|
| 105 |
-
return str(self.paths["models_config"] / "dataset-config.json")
|
| 106 |
-
|
| 107 |
-
def get_custom_metadata_path(self) -> str:
|
| 108 |
-
return str(self.project_root / "vendor" / "stable-audio-tools" / "custom_metadata.py")
|
| 109 |
-
|
| 110 |
-
def get_metadata_json_path(self) -> str:
|
| 111 |
-
return str(self.paths["data"] / "metadata.json")
|
| 112 |
-
|
| 113 |
-
def update_dataset_config(self) -> None:
|
| 114 |
-
from app.backend.data.simple_audio_processor import SimpleAudioProcessor
|
| 115 |
-
|
| 116 |
-
try:
|
| 117 |
-
processor = SimpleAudioProcessor(
|
| 118 |
-
model_config_path=self.paths["models_config"] / "model_config.json"
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
result = processor.create_dataset_config(
|
| 122 |
-
input_dir=self.paths["data"],
|
| 123 |
-
output_dir=self.paths["data"]
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
target_config = self.paths["models_config"] / "dataset-config.json"
|
| 127 |
-
with open(target_config, 'w') as f:
|
| 128 |
-
json.dump(result["dataset_config"], f, indent=4)
|
| 129 |
-
|
| 130 |
-
print(f"Updated dataset config: {target_config}")
|
| 131 |
-
print(f"Points to {result['file_count']} original audio files")
|
| 132 |
-
print(f"Sample size: {result['sample_size']} samples ({result['sample_size']/result['sample_rate']:.1f}s)")
|
| 133 |
-
print(f"Random cropping during training (correct!)")
|
| 134 |
-
|
| 135 |
-
except Exception as e:
|
| 136 |
-
print(f"Failed to update dataset config: {e}")
|
| 137 |
-
print("Falling back to basic dataset config...")
|
| 138 |
-
|
| 139 |
-
dataset_config: Dict[str, Any] = {
|
| 140 |
-
"dataset_type": "audio_dir",
|
| 141 |
-
"datasets": [
|
| 142 |
-
{
|
| 143 |
-
"id": "fine_tune_data",
|
| 144 |
-
"path": str(self.paths["data"]),
|
| 145 |
-
"custom_metadata_module": "custom_metadata"
|
| 146 |
-
}
|
| 147 |
-
],
|
| 148 |
-
"random_crop": True
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
config_path = self.paths["models_config"] / "dataset-config.json"
|
| 152 |
-
with open(config_path, 'w') as f:
|
| 153 |
-
json.dump(dataset_config, f, indent=4)
|
| 154 |
-
|
| 155 |
-
print(f"Updated fallback dataset config: {config_path}")
|
| 156 |
-
|
| 157 |
def to_dict(self) -> Dict[str, Any]:
|
| 158 |
return {
|
| 159 |
"project_root": str(self.project_root),
|
|
|
|
| 4 |
from typing import Dict, Any, Optional
|
| 5 |
import json
|
| 6 |
|
| 7 |
+
|
| 8 |
class ProjectConfig:
|
| 9 |
|
| 10 |
def __init__(self, project_root: Optional[Path] = None) -> None:
|
|
|
|
| 19 |
self.user_data_dir = Path.home() / "Library" / "Application Support" / "FragmentaDesktop"
|
| 20 |
else:
|
| 21 |
self.user_data_dir = Path.home() / ".local" / "share" / "FragmentaDesktop"
|
| 22 |
+
|
| 23 |
self.user_data_dir.mkdir(parents=True, exist_ok=True)
|
| 24 |
print(f"Running in frozen mode. Project root: {self.project_root}")
|
| 25 |
print(f"User data directory: {self.user_data_dir}")
|
| 26 |
+
|
| 27 |
else:
|
| 28 |
self.frozen = False
|
| 29 |
if project_root is None:
|
|
|
|
| 38 |
break
|
| 39 |
else:
|
| 40 |
project_root = config_file_dir
|
| 41 |
+
|
| 42 |
self.project_root: Path = Path(project_root).resolve()
|
| 43 |
self.user_data_dir = self.project_root
|
| 44 |
|
| 45 |
fine_tuned_override = os.environ.get("FRAGMENTA_FINE_TUNED_DIR")
|
| 46 |
fine_tuned_dir = Path(fine_tuned_override) if fine_tuned_override else self.user_data_dir / "models" / "fine_tuned"
|
| 47 |
|
| 48 |
+
# Scratch area for browser folder uploads (/api/upload-folder). The
|
| 49 |
+
# SA2-era "data" dataset directory is gone in 0.2.0 β datasets are now
|
| 50 |
+
# Dataset Workbench projects under projects/.
|
| 51 |
+
uploads_override = os.environ.get("FRAGMENTA_UPLOADS_DIR")
|
| 52 |
+
uploads_dir = Path(uploads_override) if uploads_override else self.user_data_dir / "uploads"
|
| 53 |
|
| 54 |
self.paths: Dict[str, Path] = {
|
| 55 |
"models": self.user_data_dir / "models",
|
| 56 |
"models_config": self.user_data_dir / "models" / "config",
|
| 57 |
"models_pretrained": self.user_data_dir / "models" / "pretrained",
|
| 58 |
"models_fine_tuned": fine_tuned_dir,
|
| 59 |
+
"uploads": uploads_dir,
|
| 60 |
"logs": self.user_data_dir / "logs",
|
| 61 |
"output": self.user_data_dir / "output",
|
| 62 |
|
| 63 |
"application": self.project_root,
|
| 64 |
"backend": self.project_root / "app" / "backend",
|
| 65 |
"frontend": self.project_root / "app" / "frontend",
|
| 66 |
+
"stable_audio_3": self.project_root / "vendor" / "stable-audio-3",
|
|
|
|
| 67 |
"venv": self.project_root / "venv",
|
| 68 |
}
|
| 69 |
|
| 70 |
self._ensure_directories()
|
| 71 |
+
# The SA3 catalog lives in app/core/model_manager.py. This dict stays
|
| 72 |
+
# empty; it's retained only because to_dict()/print_paths() and the
|
| 73 |
+
# config validator still reference it.
|
| 74 |
+
self.model_configs: Dict[str, Dict[str, str]] = {}
|
| 75 |
|
| 76 |
def _ensure_directories(self) -> None:
|
| 77 |
|
| 78 |
for path_name, path in self.paths.items():
|
| 79 |
+
if path_name.endswith(('_fine_tuned', 'uploads')):
|
| 80 |
path.mkdir(parents=True, exist_ok=True)
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def get_path(self, path_name: str) -> Path:
|
| 83 |
if path_name not in self.paths:
|
| 84 |
raise ValueError(f"Unknown path name: {path_name}")
|
| 85 |
return self.paths[path_name]
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def to_dict(self) -> Dict[str, Any]:
|
| 88 |
return {
|
| 89 |
"project_root": str(self.project_root),
|
app/core/generation/audio_generator.py
CHANGED
|
@@ -1,519 +1,536 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import re
|
| 8 |
import sys
|
| 9 |
import threading
|
| 10 |
import time
|
| 11 |
import warnings
|
| 12 |
-
from
|
| 13 |
-
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
def _slugify_prompt(text: str, max_len: int = 40) -> str:
|
| 21 |
-
s = re.sub(r'[^a-zA-Z0-9]+', '_', text.strip().lower())
|
| 22 |
-
s = re.sub(r'_+', '_', s).strip('_')
|
| 23 |
-
return s[:max_len] or 'untitled'
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
# LoRAW lives at <project>/vendor/loraw_vendor; expose its `loraw` package for inference.
|
| 28 |
-
sys.path.append(
|
| 29 |
-
str(Path(__file__).parent.parent.parent.parent / "vendor" / "loraw_vendor"))
|
| 30 |
|
| 31 |
|
| 32 |
-
|
| 33 |
-
"
|
| 34 |
-
|
| 35 |
-
category=UserWarning,
|
| 36 |
-
)
|
| 37 |
|
| 38 |
-
from stable_audio_tools.models.utils import load_ckpt_state_dict
|
| 39 |
-
from stable_audio_tools.inference.generation import generate_diffusion_cond
|
| 40 |
-
from stable_audio_tools.models import create_model_from_config
|
| 41 |
-
from loraw.network import create_lora_from_config
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
class AudioGenerator:
|
| 47 |
-
|
| 48 |
-
self.config = config
|
| 49 |
-
self.model = None
|
| 50 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 51 |
-
self.current_model_name = None
|
| 52 |
-
self.current_model_path = None
|
| 53 |
-
self.current_model_key = None
|
| 54 |
-
self.is_distilled_small = False
|
| 55 |
-
self.is_fine_tuned = False
|
| 56 |
-
# LoRA state. `lora` holds the LoRAWrapper instance when one is active;
|
| 57 |
-
# `_active_lora_path` / `_active_lora_multiplier` are used (along with
|
| 58 |
-
# the base-model identifier) in `current_model_key` so the cache
|
| 59 |
-
# invalidates whenever the LoRA selection changes β forcing a fresh
|
| 60 |
-
# base reload because LoRAW's `activate()` is not reversible in-place.
|
| 61 |
-
self.lora = None
|
| 62 |
-
self._active_lora_path = None
|
| 63 |
-
self._active_lora_multiplier = 1.0
|
| 64 |
-
self._stop_event = threading.Event()
|
| 65 |
-
logger.info(f"Using device: {self.device}")
|
| 66 |
-
|
| 67 |
-
def _apply_lora(self, lora_path: str, lora_config: Dict[str, Any], multiplier: float = 1.0):
|
| 68 |
-
"""Wrap the currently-loaded base model with a LoRA from LoRAW.
|
| 69 |
-
|
| 70 |
-
Caller is responsible for ensuring the base model is fresh (no prior
|
| 71 |
-
LoRA injected) β typically by routing through `generate_audio`'s cache
|
| 72 |
-
invalidation, which reloads the base when the LoRA selection changes.
|
| 73 |
-
"""
|
| 74 |
-
if self.model is None:
|
| 75 |
-
raise RuntimeError("Base model must be loaded before applying a LoRA")
|
| 76 |
-
# torch.compile wraps in OptimizedModule, which prefixes named_modules()
|
| 77 |
-
# with `_orig_mod/`. LoRAW's saved state has no such prefix (training
|
| 78 |
-
# didn't compile). Operate on the underlying module so scan_model keys
|
| 79 |
-
# match the checkpoint exactly. The compiled wrapper still dispatches
|
| 80 |
-
# forward through this same module, so the LoRA stays active.
|
| 81 |
-
target = getattr(self.model, "_orig_mod", self.model)
|
| 82 |
-
full_config = {
|
| 83 |
-
"model_type": getattr(target, "model_type", "diffusion_cond"),
|
| 84 |
-
"lora": lora_config,
|
| 85 |
-
}
|
| 86 |
-
self.lora = create_lora_from_config(full_config, target)
|
| 87 |
-
state = torch.load(lora_path, map_location=self.device)
|
| 88 |
-
self.lora.load_weights(state, multiplier=multiplier)
|
| 89 |
-
self.lora.activate()
|
| 90 |
-
self._active_lora_path = lora_path
|
| 91 |
-
self._active_lora_multiplier = multiplier
|
| 92 |
-
logger.info(f"LoRA applied: {Path(lora_path).name} (multiplier={multiplier})")
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
def request_stop(self) -> bool:
|
| 95 |
-
|
| 96 |
-
already_set = self._stop_event.is_set()
|
| 97 |
-
self._stop_event.set()
|
| 98 |
-
return not already_set
|
| 99 |
-
|
| 100 |
-
def load_local_base_model(self, model_name: str = "stable-audio-open-small") -> bool:
|
| 101 |
-
try:
|
| 102 |
-
logger.info(f"Loading local base model: {model_name}")
|
| 103 |
-
|
| 104 |
-
self.current_model_name = model_name
|
| 105 |
-
|
| 106 |
-
from stable_audio_tools.models.factory import create_model_from_config
|
| 107 |
-
from stable_audio_tools.models.utils import load_ckpt_state_dict
|
| 108 |
-
if "small" in model_name:
|
| 109 |
-
config_file = "model_config_small.json"
|
| 110 |
-
else:
|
| 111 |
-
config_file = "model_config.json"
|
| 112 |
-
self.is_distilled_small = "small" in model_name.lower()
|
| 113 |
-
self.is_fine_tuned = False
|
| 114 |
-
|
| 115 |
-
config_path = Path(__file__).parent.parent.parent.parent / "models" / "config" / config_file
|
| 116 |
-
logger.info(f"Using config file: {config_path}")
|
| 117 |
-
|
| 118 |
-
with open(config_path, 'r') as f:
|
| 119 |
-
import json
|
| 120 |
-
model_config = json.load(f)
|
| 121 |
-
|
| 122 |
-
self.model = create_model_from_config(model_config)
|
| 123 |
-
if model_name == 'stable-audio-open-small':
|
| 124 |
-
model_file_name = 'stable-audio-open-small-model.safetensors'
|
| 125 |
-
elif model_name == 'stable-audio-open-1.0':
|
| 126 |
-
model_file_name = 'stable-audio-open-model.safetensors'
|
| 127 |
-
else:
|
| 128 |
-
model_file_name = f"{model_name}-model.safetensors"
|
| 129 |
-
|
| 130 |
-
model_file = Path(__file__).parent.parent.parent.parent / "models" / "pretrained" / model_file_name
|
| 131 |
-
self.current_model_path = str(model_file)
|
| 132 |
-
logger.info(f"Loading weights from: {model_file}")
|
| 133 |
-
|
| 134 |
-
if not model_file.exists():
|
| 135 |
-
raise FileNotFoundError(f"Local model file not found: {model_file}")
|
| 136 |
-
|
| 137 |
-
state_dict = load_ckpt_state_dict(str(model_file))
|
| 138 |
-
self.model.load_state_dict(state_dict, strict=False)
|
| 139 |
-
|
| 140 |
-
self.model = self.model.to(self.device)
|
| 141 |
-
self.model.eval()
|
| 142 |
-
self.model.requires_grad_(False)
|
| 143 |
-
if self.device.startswith("cuda"):
|
| 144 |
-
self.model = torch.compile(self.model, mode="reduce-overhead")
|
| 145 |
-
|
| 146 |
-
logger.info("Local base model loaded successfully")
|
| 147 |
-
return True
|
| 148 |
-
|
| 149 |
-
except Exception as e:
|
| 150 |
-
logger.error(f"Failed to load local base model: {e}")
|
| 151 |
return False
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
if
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
)
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
import json
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
self.
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
|
| 215 |
-
return True
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
|
|
|
| 219 |
return False
|
|
|
|
|
|
|
|
|
|
| 220 |
|
|
|
|
| 221 |
def generate_audio(
|
| 222 |
self,
|
| 223 |
prompt: str,
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
config_file: Optional[str] = None,
|
| 227 |
duration: float = 10.0,
|
| 228 |
-
|
| 229 |
-
|
| 230 |
seed: int = -1,
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
) -> Path:
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
#
|
| 246 |
-
#
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
else
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
# 2. Fine-tuned small β distillation destroyed by SFT but the
|
| 311 |
-
# objective is still rectified-flow, so the sampler name must
|
| 312 |
-
# come from the rectified-flow family (euler|rk4|dpmpp|pingpong),
|
| 313 |
-
# NOT from the v-diffusion family. Use external CFG.
|
| 314 |
-
# 3. Large model β standard v-diffusion, accepts dpmpp-3m-sde.
|
| 315 |
-
use_distilled_recipe = self.is_distilled_small and not self.is_fine_tuned
|
| 316 |
-
if use_distilled_recipe:
|
| 317 |
-
effective_sampler = "pingpong"
|
| 318 |
-
effective_steps = 8
|
| 319 |
-
effective_cfg = 1.0
|
| 320 |
-
sigma_kwargs = {}
|
| 321 |
-
elif self.is_distilled_small:
|
| 322 |
-
effective_sampler = "dpmpp"
|
| 323 |
-
effective_steps = steps
|
| 324 |
-
effective_cfg = cfg_scale
|
| 325 |
-
sigma_kwargs = {"sigma_max": 1.0}
|
| 326 |
-
else:
|
| 327 |
-
effective_sampler = "dpmpp-3m-sde"
|
| 328 |
-
effective_steps = steps
|
| 329 |
-
effective_cfg = cfg_scale
|
| 330 |
-
sigma_kwargs = {"sigma_min": 0.03, "sigma_max": 1000}
|
| 331 |
-
|
| 332 |
-
print(f"Generating audio for prompt: '{prompt}'")
|
| 333 |
-
recipe_note = ""
|
| 334 |
-
if use_distilled_recipe:
|
| 335 |
-
recipe_note = " (distilled small overrides applied)"
|
| 336 |
-
elif self.is_fine_tuned and self.is_distilled_small:
|
| 337 |
-
recipe_note = " (fine-tuned small: rectified-flow dpmpp + external CFG)"
|
| 338 |
-
print(
|
| 339 |
-
f"Duration: {duration}s, CFG scale: {effective_cfg}, "
|
| 340 |
-
f"Steps: {effective_steps}, Sampler: {effective_sampler}"
|
| 341 |
-
+ recipe_note
|
| 342 |
-
)
|
| 343 |
-
requested_sample_size = int(duration * self.model.sample_rate)
|
| 344 |
-
max_sample_size = None
|
| 345 |
-
try:
|
| 346 |
-
max_sample_size = self.model.sample_size
|
| 347 |
-
except AttributeError:
|
| 348 |
-
if hasattr(self.model, 'model') and hasattr(self.model.model, 'sample_size'):
|
| 349 |
-
max_sample_size = self.model.model.sample_size
|
| 350 |
-
else:
|
| 351 |
-
config_path = Path(__file__).parent.parent.parent.parent / "models" / "config"
|
| 352 |
-
if hasattr(self, 'current_model_name') and self.current_model_name:
|
| 353 |
-
if 'small' in self.current_model_name:
|
| 354 |
-
config_file = config_path / "model_config_small.json"
|
| 355 |
-
else:
|
| 356 |
-
config_file = config_path / "model_config.json"
|
| 357 |
-
else:
|
| 358 |
-
if hasattr(self, 'current_model_path') and self.current_model_path:
|
| 359 |
-
model_file = Path(self.current_model_path)
|
| 360 |
-
if model_file.exists():
|
| 361 |
-
file_size_gb = model_file.stat().st_size / (1024**3)
|
| 362 |
-
if file_size_gb < 2.0:
|
| 363 |
-
config_file = config_path / "model_config_small.json"
|
| 364 |
-
else:
|
| 365 |
-
config_file = config_path / "model_config.json"
|
| 366 |
-
else:
|
| 367 |
-
config_file = config_path / "model_config_small.json"
|
| 368 |
-
else:
|
| 369 |
-
config_file = config_path / "model_config_small.json"
|
| 370 |
-
|
| 371 |
-
if config_file.exists():
|
| 372 |
-
with open(config_file, 'r') as f:
|
| 373 |
-
import json
|
| 374 |
-
config_data = json.load(f)
|
| 375 |
-
max_sample_size = config_data.get('sample_size', 44100 * 10)
|
| 376 |
-
else:
|
| 377 |
-
max_sample_size = 44100 * 10
|
| 378 |
-
if max_sample_size and requested_sample_size > max_sample_size:
|
| 379 |
-
print(f"Requested duration {duration}s exceeds model maximum. Truncating.")
|
| 380 |
-
requested_sample_size = max_sample_size
|
| 381 |
-
duration = requested_sample_size / self.model.sample_rate
|
| 382 |
-
|
| 383 |
-
if seed == -1:
|
| 384 |
-
import numpy as np
|
| 385 |
-
seed = np.random.randint(0, 2**32 - 1, dtype=np.int64)
|
| 386 |
-
|
| 387 |
-
print(f"Using seed: {seed}")
|
| 388 |
-
|
| 389 |
-
if loop_mode and max_sample_size:
|
| 390 |
-
song_seconds = max(int(duration),
|
| 391 |
-
int(max_sample_size / self.model.sample_rate))
|
| 392 |
-
else:
|
| 393 |
-
song_seconds = int(duration)
|
| 394 |
-
|
| 395 |
-
conditioning = [{
|
| 396 |
-
"prompt": prompt,
|
| 397 |
-
"seconds_start": 0,
|
| 398 |
-
"seconds_total": song_seconds,
|
| 399 |
-
}]
|
| 400 |
-
|
| 401 |
-
device = next(self.model.parameters()).device
|
| 402 |
-
print(f"Using device: {device}")
|
| 403 |
-
|
| 404 |
-
with warnings.catch_warnings():
|
| 405 |
-
# Known torchsde float-boundary chatter from dpmpp-3m-sde.
|
| 406 |
-
warnings.filterwarnings(
|
| 407 |
-
"ignore",
|
| 408 |
-
message=r"Should have tb<=t1 but got tb=.*",
|
| 409 |
-
category=UserWarning,
|
| 410 |
-
module=r"torchsde\._brownian\.brownian_interval",
|
| 411 |
)
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
category=UserWarning,
|
| 416 |
-
module=r"torchsde\._brownian\.brownian_interval",
|
| 417 |
-
)
|
| 418 |
-
|
| 419 |
-
audio = generate_diffusion_cond(
|
| 420 |
-
model=self.model,
|
| 421 |
-
steps=effective_steps,
|
| 422 |
-
cfg_scale=effective_cfg,
|
| 423 |
-
conditioning=conditioning,
|
| 424 |
-
batch_size=1,
|
| 425 |
-
sample_size=requested_sample_size,
|
| 426 |
-
seed=seed,
|
| 427 |
-
device=str(device),
|
| 428 |
-
sampler_type=effective_sampler,
|
| 429 |
-
callback=_stop_callback,
|
| 430 |
-
**sigma_kwargs,
|
| 431 |
)
|
| 432 |
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
audio = rearrange(audio, "b d n -> d (b n)").to(torch.float32)
|
| 437 |
-
audio = audio / audio.abs().max()
|
| 438 |
-
audio_int16 = (audio.clamp(-1, 1) * 32767).to(torch.int16).cpu()
|
| 439 |
-
|
| 440 |
-
if output_path is None:
|
| 441 |
-
output_dir = Path(__file__).parent.parent.parent.parent / "output"
|
| 442 |
-
output_dir.mkdir(exist_ok=True)
|
| 443 |
-
ts = datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 444 |
-
slug = _slugify_prompt(prompt)
|
| 445 |
-
suffix = f"_{batch_index}" if batch_total > 1 else ""
|
| 446 |
-
output_path = output_dir / f"fragmenta_{ts}_{slug}{suffix}.wav"
|
| 447 |
-
|
| 448 |
-
self.save_audio(audio_int16, output_path, self.model.sample_rate)
|
| 449 |
-
|
| 450 |
-
print(f"AUDIO GENERATOR: Generation complete")
|
| 451 |
-
print(f" - Output file: {output_path}")
|
| 452 |
-
print(f" - Output file size: {output_path.stat().st_size} bytes")
|
| 453 |
-
|
| 454 |
-
return output_path
|
| 455 |
-
|
| 456 |
except GenerationStopped:
|
| 457 |
-
|
| 458 |
raise
|
| 459 |
-
except Exception as
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
traceback.print_exc()
|
| 463 |
raise
|
| 464 |
-
finally:
|
| 465 |
-
self._stop_event.clear()
|
| 466 |
-
|
| 467 |
-
def generate_batch(
|
| 468 |
-
self,
|
| 469 |
-
prompts: List[str],
|
| 470 |
-
duration: float = 10.0,
|
| 471 |
-
cfg_scale: float = 6.0,
|
| 472 |
-
steps: int = 250,
|
| 473 |
-
seed: int = -1,
|
| 474 |
-
output_dir: Optional[Path] = None
|
| 475 |
-
) -> List[Path]:
|
| 476 |
-
results = []
|
| 477 |
-
|
| 478 |
-
for i, prompt in enumerate(prompts):
|
| 479 |
-
print(f"Generating audio {i+1}/{len(prompts)}")
|
| 480 |
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
except Exception as e:
|
| 499 |
-
print(f"Failed to generate audio for prompt {i+1}: {e}")
|
| 500 |
-
results.append(None)
|
| 501 |
-
|
| 502 |
-
return results
|
| 503 |
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SA3 inference engine.
|
| 2 |
+
|
| 3 |
+
Thin wrapper around stable_audio_3.StableAudioModel.from_pretrained() that
|
| 4 |
+
caches the loaded model between requests (eviction on model_id change),
|
| 5 |
+
auto-detects the device, and writes 44.1 kHz stereo int16 WAV.
|
| 6 |
+
|
| 7 |
+
Cancellation is wired via `request_stop()` for API parity, but SA3's
|
| 8 |
+
generate() doesn't expose a per-step callback yet β the flag is checked
|
| 9 |
+
between calls, not inside them. A finer-grained cancel hook is a Phase
|
| 10 |
+
3.1 follow-up.
|
| 11 |
+
"""
|
| 12 |
+
import os
|
| 13 |
+
import platform
|
| 14 |
import re
|
| 15 |
import sys
|
| 16 |
import threading
|
| 17 |
import time
|
| 18 |
import warnings
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Callable, Dict, Optional, Tuple
|
| 21 |
|
| 22 |
+
import numpy as np
|
| 23 |
+
import soundfile as sf
|
| 24 |
+
import torch
|
| 25 |
|
| 26 |
+
from utils.logger import get_logger
|
| 27 |
+
|
| 28 |
+
logger = get_logger("AudioGenerator")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Live progress from the SA3 sampler. SA3's `model.generate(**sampler_kwargs)`
|
| 32 |
+
# forwards `callback=fn` into the sampler, which fires it per ODE step with
|
| 33 |
+
# `{'i': step_index, ...}`. We mirror that into this dict so the frontend can
|
| 34 |
+
# poll real progress instead of a fake ticker. Reset on each new generation.
|
| 35 |
+
_generation_state: Dict[str, Any] = {
|
| 36 |
+
"is_generating": False,
|
| 37 |
+
# idle | loading | sampling | decoding | complete | failed
|
| 38 |
+
"phase": "idle",
|
| 39 |
+
"step": 0,
|
| 40 |
+
"total_steps": 0,
|
| 41 |
+
"progress": 0, # 0-100, derived
|
| 42 |
+
"batch_index": 0,
|
| 43 |
+
"batch_total": 0,
|
| 44 |
+
"started_at": None,
|
| 45 |
+
"ended_at": None,
|
| 46 |
+
"error": None,
|
| 47 |
+
}
|
| 48 |
+
_generation_state_lock = threading.Lock()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_generation_progress() -> Dict[str, Any]:
|
| 52 |
+
"""Snapshot of the current generation's live progress. Cheap to call."""
|
| 53 |
+
with _generation_state_lock:
|
| 54 |
+
return dict(_generation_state)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _set_progress(**kwargs: Any) -> None:
|
| 58 |
+
"""Merge fields into _generation_state under the lock. Recomputes
|
| 59 |
+
`progress` automatically when step/total_steps land in the same update."""
|
| 60 |
+
with _generation_state_lock:
|
| 61 |
+
_generation_state.update(kwargs)
|
| 62 |
+
total = int(_generation_state.get("total_steps") or 0)
|
| 63 |
+
step = int(_generation_state.get("step") or 0)
|
| 64 |
+
_generation_state["progress"] = (
|
| 65 |
+
int(round(100 * step / total)) if total > 0 else 0
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _reset_progress() -> None:
|
| 70 |
+
with _generation_state_lock:
|
| 71 |
+
_generation_state.update({
|
| 72 |
+
"is_generating": False, "phase": "idle",
|
| 73 |
+
"step": 0, "total_steps": 0, "progress": 0,
|
| 74 |
+
"batch_index": 0, "batch_total": 0,
|
| 75 |
+
"started_at": None, "ended_at": None, "error": None,
|
| 76 |
+
})
|
| 77 |
+
|
| 78 |
+
# Vendored SA3 lives at <repo>/vendor/stable-audio-3 β put it on sys.path so
|
| 79 |
+
# `import stable_audio_3` resolves without a global pip install.
|
| 80 |
+
_SA3_VENDOR = Path(__file__).resolve().parents[3] / "vendor" / "stable-audio-3"
|
| 81 |
+
if str(_SA3_VENDOR) not in sys.path:
|
| 82 |
+
sys.path.insert(0, str(_SA3_VENDOR))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# model_id -> (sa3_name passed to StableAudioModel.from_pretrained,
|
| 86 |
+
# "user-visible or base" tag, max duration seconds).
|
| 87 |
+
# Kept in sync manually with _SA3_CATALOG in app/core/model_manager.py.
|
| 88 |
+
_MODEL_INFO: Dict[str, Tuple[str, str, int]] = {
|
| 89 |
+
"sa3-small-music": ("small-music", "post", 120),
|
| 90 |
+
"sa3-small-sfx": ("small-sfx", "post", 120),
|
| 91 |
+
"sa3-medium": ("medium", "post", 380),
|
| 92 |
+
"sa3-small-music-base": ("small-music-base", "base", 120),
|
| 93 |
+
"sa3-small-sfx-base": ("small-sfx-base", "base", 120),
|
| 94 |
+
"sa3-medium-base": ("medium-base", "base", 380),
|
| 95 |
+
}
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
class GenerationStopped(Exception):
|
| 99 |
+
"""Raised when an in-flight generation is interrupted by a stop request."""
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
+
def _slugify(text: str, max_len: int = 40) -> str:
|
| 103 |
+
s = re.sub(r"[^a-zA-Z0-9_-]+", "_", text or "")
|
| 104 |
+
return s[:max_len].strip("_").lower() or "audio"
|
|
|
|
|
|
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
def _autodetect_device() -> str:
|
| 108 |
+
"""cuda β mps β cpu, with FRAGMENTA_FORCE_DEVICE override."""
|
| 109 |
+
override = os.environ.get("FRAGMENTA_FORCE_DEVICE")
|
| 110 |
+
if override:
|
| 111 |
+
return override
|
| 112 |
+
if torch.cuda.is_available():
|
| 113 |
+
return "cuda"
|
| 114 |
+
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
| 115 |
+
return "mps"
|
| 116 |
+
return "cpu"
|
| 117 |
|
| 118 |
|
| 119 |
class AudioGenerator:
|
| 120 |
+
"""One-model warm cache. Reload only when model_id changes."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
def __init__(self, config: Any) -> None:
|
| 123 |
+
self.config = config
|
| 124 |
+
self.model: Any = None
|
| 125 |
+
self._model_id: Optional[str] = None
|
| 126 |
+
self._device: Optional[str] = None
|
| 127 |
+
self._stop_requested: bool = False
|
| 128 |
+
# Tracks LoRAs currently injected into self.model. List of
|
| 129 |
+
# {"path": str, "strength": float}. Empty when no LoRAs are active.
|
| 130 |
+
self._loaded_loras: list = []
|
| 131 |
+
|
| 132 |
+
# --- cooperative cancel ---------------------------------------------------
|
| 133 |
def request_stop(self) -> bool:
|
| 134 |
+
if self._stop_requested:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
return False
|
| 136 |
+
self._stop_requested = True
|
| 137 |
+
return True
|
| 138 |
|
| 139 |
+
# --- model load -----------------------------------------------------------
|
| 140 |
+
def _ensure_model(
|
| 141 |
+
self,
|
| 142 |
+
model_id: str,
|
| 143 |
+
device: Optional[str] = None,
|
| 144 |
+
half: bool = True,
|
| 145 |
+
) -> None:
|
| 146 |
+
if model_id not in _MODEL_INFO:
|
| 147 |
+
raise ValueError(f"Unknown SA3 model_id: {model_id}")
|
| 148 |
+
sa3_name, _kind, _max_dur = _MODEL_INFO[model_id]
|
| 149 |
+
|
| 150 |
+
if model_id in ("sa3-medium", "sa3-medium-base"):
|
| 151 |
+
# Medium normally requires Flash Attention 2 for its long-form (up
|
| 152 |
+
# to 380s) sliding-window attention. FRAGMENTA_MEDIUM_NO_FLASH=1 is
|
| 153 |
+
# the Path-B validation switch: it lets medium load WITHOUT
|
| 154 |
+
# flash_attn and fall back to PyTorch-native attention
|
| 155 |
+
# (flex_attention -> chunked-halo SDPA -> masked SDPA; see
|
| 156 |
+
# transformer.apply_attn). Output is math-equivalent, but VRAM is
|
| 157 |
+
# higher and sampling slower at long durations. Off by default, so
|
| 158 |
+
# the shipped behaviour is unchanged until the fallback is validated.
|
| 159 |
+
allow_no_flash = os.environ.get("FRAGMENTA_MEDIUM_NO_FLASH") == "1"
|
| 160 |
+
try:
|
| 161 |
+
import flash_attn # noqa: F401
|
| 162 |
+
have_flash = True
|
| 163 |
+
except ImportError as err:
|
| 164 |
+
have_flash = False
|
| 165 |
+
_flash_err = err
|
| 166 |
+
|
| 167 |
+
if not have_flash and not allow_no_flash:
|
| 168 |
+
if platform.system() == "Windows":
|
| 169 |
+
raise RuntimeError(
|
| 170 |
+
"sa3-medium requires Flash Attention 2, which doesn't "
|
| 171 |
+
"have Windows wheels. Use sa3-small-music / sa3-small-sfx, "
|
| 172 |
+
"run Fragmenta via Docker on WSL2, or set "
|
| 173 |
+
"FRAGMENTA_MEDIUM_NO_FLASH=1 to run on the (slower, "
|
| 174 |
+
"higher-memory) PyTorch attention fallback."
|
| 175 |
+
) from _flash_err
|
| 176 |
+
raise RuntimeError(
|
| 177 |
+
"sa3-medium needs Flash Attention 2 (flash_attn) but the "
|
| 178 |
+
f"current install is unusable: {_flash_err}.\n"
|
| 179 |
+
"Pick the wheel matching your torch+ABI+Python+CUDA from\n"
|
| 180 |
+
" https://github.com/Dao-AILab/flash-attention/releases\n"
|
| 181 |
+
"and install with `pip install --no-deps <wheel-url>`. "
|
| 182 |
+
"See the note next to flash-attn in requirements.txt for an example.\n"
|
| 183 |
+
"Or set FRAGMENTA_MEDIUM_NO_FLASH=1 to use the PyTorch "
|
| 184 |
+
"attention fallback."
|
| 185 |
+
) from _flash_err
|
| 186 |
+
|
| 187 |
+
if not have_flash:
|
| 188 |
+
logger.warning(
|
| 189 |
+
"sa3-medium loading WITHOUT Flash Attention 2 "
|
| 190 |
+
"(FRAGMENTA_MEDIUM_NO_FLASH=1). Using the PyTorch-native "
|
| 191 |
+
"attention fallback β expect higher VRAM and slower sampling "
|
| 192 |
+
"at long durations. Validate memory headroom before "
|
| 193 |
+
"generating long-form (up to 380s) clips."
|
| 194 |
)
|
| 195 |
|
| 196 |
+
device = device or _autodetect_device()
|
| 197 |
+
if (
|
| 198 |
+
self.model is not None
|
| 199 |
+
and self._model_id == model_id
|
| 200 |
+
and self._device == device
|
| 201 |
+
):
|
| 202 |
+
return # warm cache hit
|
| 203 |
+
|
| 204 |
+
if self.model is not None:
|
| 205 |
+
del self.model
|
| 206 |
+
self.model = None
|
| 207 |
+
if torch.cuda.is_available():
|
| 208 |
+
torch.cuda.empty_cache()
|
| 209 |
+
|
| 210 |
+
# Two layouts to support during the unification transition:
|
| 211 |
+
# 1. Canonical (post-Phase 5c): HF cache layout rooted at
|
| 212 |
+
# <app>/models/pretrained/sa3/hub/. model_manager sets
|
| 213 |
+
# HF_HUB_CACHE to that path, so StableAudioModel.from_pretrained
|
| 214 |
+
# finds files there without going to ~/.cache/huggingface.
|
| 215 |
+
# 2. Legacy: <app>/models/pretrained/sa3/<model_id>/ flat layout
|
| 216 |
+
# from earlier downloads. We fall back to direct load so
|
| 217 |
+
# pre-existing users don't have to re-download.
|
| 218 |
+
#
|
| 219 |
+
# Defense-in-depth: re-force the HF cache vars here too. model_manager
|
| 220 |
+
# sets them at construction, but if generation is reached via an
|
| 221 |
+
# alternate code path or the env was clobbered later, we still
|
| 222 |
+
# guarantee resolution into <pretrained>/sa3/hub/.
|
| 223 |
+
hub_dir = self.config.get_path("models_pretrained") / "sa3" / "hub"
|
| 224 |
+
hf_env_keys = ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE",
|
| 225 |
+
"TRANSFORMERS_CACHE", "HF_HUB_OFFLINE")
|
| 226 |
+
prev_env = {k: os.environ.get(k) for k in hf_env_keys}
|
| 227 |
+
os.environ["HF_HUB_CACHE"] = str(hub_dir)
|
| 228 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
|
| 229 |
+
os.environ["TRANSFORMERS_CACHE"] = str(hub_dir)
|
| 230 |
+
os.environ["HF_HUB_OFFLINE"] = "1"
|
| 231 |
+
# huggingface_hub captures HF_HUB_CACHE and HF_HUB_OFFLINE as
|
| 232 |
+
# module-level constants AT IMPORT TIME. The Flask backend imports
|
| 233 |
+
# huggingface_hub (transitively, via model_manager.py) before we ever
|
| 234 |
+
# set these env vars, so the constants point at ~/.cache/huggingface/
|
| 235 |
+
# and offline=False. Setting os.environ now has no effect on already-
|
| 236 |
+
# captured constants. We have to monkey-patch them directly.
|
| 237 |
+
# Same trick we used for the CLAP loader.
|
| 238 |
+
prev_hub_constants = {}
|
| 239 |
+
try:
|
| 240 |
+
import huggingface_hub.constants as _hf_const
|
| 241 |
+
prev_hub_constants = {
|
| 242 |
+
"HF_HUB_CACHE": _hf_const.HF_HUB_CACHE,
|
| 243 |
+
"HF_HUB_OFFLINE": _hf_const.HF_HUB_OFFLINE,
|
| 244 |
+
}
|
| 245 |
+
_hf_const.HF_HUB_CACHE = str(hub_dir)
|
| 246 |
+
_hf_const.HF_HUB_OFFLINE = True
|
| 247 |
+
except Exception:
|
| 248 |
+
_hf_const = None
|
| 249 |
+
try:
|
| 250 |
+
try:
|
| 251 |
+
from stable_audio_3 import StableAudioModel
|
| 252 |
+
with warnings.catch_warnings():
|
| 253 |
+
warnings.simplefilter("ignore")
|
| 254 |
+
self.model = StableAudioModel.from_pretrained(
|
| 255 |
+
sa3_name, device=device, model_half=half,
|
| 256 |
+
)
|
| 257 |
+
except (FileNotFoundError, OSError) as primary_err:
|
| 258 |
+
# HF cache miss β fall back to flat layout.
|
| 259 |
+
legacy_dir = self.config.get_path("models_pretrained") / "sa3" / model_id
|
| 260 |
+
config_path = legacy_dir / "model_config.json"
|
| 261 |
+
ckpt_path = legacy_dir / "model.safetensors"
|
| 262 |
+
if not (config_path.exists() and ckpt_path.exists()):
|
| 263 |
+
raise FileNotFoundError(
|
| 264 |
+
f"Checkpoint '{model_id}' not found in HF cache "
|
| 265 |
+
f"({os.environ.get('HF_HUB_CACHE')}) or legacy flat "
|
| 266 |
+
f"layout ({legacy_dir}). Download it from the "
|
| 267 |
+
f"Checkpoint Manager."
|
| 268 |
+
) from primary_err
|
| 269 |
import json
|
| 270 |
+
with open(config_path) as fh:
|
| 271 |
+
model_config = json.load(fh)
|
| 272 |
+
from stable_audio_3.loading_utils import load_diffusion_cond
|
| 273 |
+
with warnings.catch_warnings():
|
| 274 |
+
warnings.simplefilter("ignore")
|
| 275 |
+
inner = load_diffusion_cond(
|
| 276 |
+
model_config, str(ckpt_path),
|
| 277 |
+
device=device, model_half=half,
|
| 278 |
+
)
|
| 279 |
+
inner.use_lora = False
|
| 280 |
+
inner.lora_names = []
|
| 281 |
+
self.model = StableAudioModel(inner, model_config, device, half)
|
| 282 |
+
finally:
|
| 283 |
+
for k, v in prev_env.items():
|
| 284 |
+
if v is None:
|
| 285 |
+
os.environ.pop(k, None)
|
| 286 |
+
else:
|
| 287 |
+
os.environ[k] = v
|
| 288 |
+
# Restore the patched constants so we don't permanently alter
|
| 289 |
+
# global huggingface_hub state for anything else in-process.
|
| 290 |
+
if _hf_const is not None and prev_hub_constants:
|
| 291 |
+
_hf_const.HF_HUB_CACHE = prev_hub_constants["HF_HUB_CACHE"]
|
| 292 |
+
_hf_const.HF_HUB_OFFLINE = prev_hub_constants["HF_HUB_OFFLINE"]
|
| 293 |
+
self._model_id = model_id
|
| 294 |
+
self._device = device
|
| 295 |
+
|
| 296 |
+
# --- LoRA stack -----------------------------------------------------------
|
| 297 |
+
def _apply_loras(self, loras: list) -> None:
|
| 298 |
+
"""Inject the given LoRA stack into self.model (idempotent).
|
| 299 |
+
|
| 300 |
+
loras: [{"path": str, "strength": float}, ...]
|
| 301 |
+
|
| 302 |
+
Strategy:
|
| 303 |
+
* Same paths in same order β just update strengths in place.
|
| 304 |
+
* Different paths β remove all, load fresh.
|
| 305 |
+
"""
|
| 306 |
+
if self.model is None:
|
| 307 |
+
return
|
| 308 |
+
|
| 309 |
+
new_paths = [l["path"] for l in loras]
|
| 310 |
+
cur_paths = [l["path"] for l in self._loaded_loras]
|
| 311 |
|
| 312 |
+
if new_paths == cur_paths:
|
| 313 |
+
# Path-set unchanged; only strengths may have moved.
|
| 314 |
+
for i, l in enumerate(loras):
|
| 315 |
+
self.model.set_lora_strength(l["strength"], lora_index=i)
|
| 316 |
+
self._loaded_loras = list(loras)
|
| 317 |
+
return
|
| 318 |
|
| 319 |
+
# Path-set changed. Remove any currently loaded, then load the new set.
|
| 320 |
+
if cur_paths:
|
| 321 |
+
try:
|
| 322 |
+
from stable_audio_3.models.lora import remove_lora
|
| 323 |
+
# SA3 applies LoRA to the DiffusionCond's DiT (.model) and
|
| 324 |
+
# conditioner (.conditioner) β mirror StableAudioModel's own
|
| 325 |
+
# set_lora_strength which iterates both submodules.
|
| 326 |
+
# `self.model` is StableAudioModel; `self.model.model` is the
|
| 327 |
+
# inner DiffusionCond.
|
| 328 |
+
#
|
| 329 |
+
# remove_lora() strips *every* LoRA parametrization in one
|
| 330 |
+
# pass. We use it instead of remove_lora_by_index(..., 0) in a
|
| 331 |
+
# loop: removal does NOT renumber the remaining adapters, so
|
| 332 |
+
# repeatedly popping index 0 only ever clears the first LoRA
|
| 333 |
+
# and leaves indices 1..n-1 stranded β stale adapters then
|
| 334 |
+
# contaminate every later generation with a different stack.
|
| 335 |
+
inner = self.model.model
|
| 336 |
+
remove_lora(inner.model)
|
| 337 |
+
remove_lora(inner.conditioner)
|
| 338 |
+
except Exception as exc:
|
| 339 |
+
# If removal fails (e.g. an upstream API change), force a
|
| 340 |
+
# base-model reload so we don't carry stale adapters. KEEP
|
| 341 |
+
# _model_id intact β _ensure_model needs it to know what to
|
| 342 |
+
# reload. (Previous code zeroed it; the reload then raised
|
| 343 |
+
# "Unknown SA3 model_id: None".)
|
| 344 |
+
logger.warning(
|
| 345 |
+
"LoRA removal failed (%s); reloading base model %s",
|
| 346 |
+
exc, self._model_id,
|
| 347 |
+
)
|
| 348 |
+
self.model = None
|
| 349 |
|
| 350 |
+
if self.model is None and self._model_id is not None:
|
| 351 |
+
# Forced full reload (only if remove failed above).
|
| 352 |
+
self._ensure_model(self._model_id, device=self._device, half=True)
|
| 353 |
|
| 354 |
+
if loras:
|
| 355 |
+
with warnings.catch_warnings():
|
| 356 |
+
warnings.simplefilter("ignore")
|
| 357 |
+
self.model.load_lora(new_paths)
|
| 358 |
+
for i, l in enumerate(loras):
|
| 359 |
+
self.model.set_lora_strength(l["strength"], lora_index=i)
|
| 360 |
|
| 361 |
+
self._loaded_loras = list(loras)
|
|
|
|
| 362 |
|
| 363 |
+
def set_lora_strength(self, index: int, strength: float) -> bool:
|
| 364 |
+
"""Live-update one slot's strength. Returns False if index invalid."""
|
| 365 |
+
if not self.model or index < 0 or index >= len(self._loaded_loras):
|
| 366 |
return False
|
| 367 |
+
self.model.set_lora_strength(float(strength), lora_index=index)
|
| 368 |
+
self._loaded_loras[index]["strength"] = float(strength)
|
| 369 |
+
return True
|
| 370 |
|
| 371 |
+
# --- public entry ---------------------------------------------------------
|
| 372 |
def generate_audio(
|
| 373 |
self,
|
| 374 |
prompt: str,
|
| 375 |
+
*,
|
| 376 |
+
model_id: str,
|
|
|
|
| 377 |
duration: float = 10.0,
|
| 378 |
+
steps: Optional[int] = None,
|
| 379 |
+
cfg_scale: Optional[float] = None,
|
| 380 |
seed: int = -1,
|
| 381 |
+
negative_prompt: Optional[str] = None,
|
| 382 |
+
batch_size: int = 1,
|
| 383 |
+
device: Optional[str] = None,
|
| 384 |
+
half: bool = True,
|
| 385 |
+
chunked_decode: Optional[bool] = None,
|
| 386 |
+
loop_mode: bool = False, # bars-mode passthrough
|
| 387 |
+
loras: Optional[list] = None, # [{path, strength}, ...]
|
| 388 |
+
# Phase 7: audio-to-audio + inpainting -----------------------------
|
| 389 |
+
init_audio_path: Optional[str] = None,
|
| 390 |
+
init_noise_level: float = 1.0,
|
| 391 |
+
inpaint_audio_path: Optional[str] = None,
|
| 392 |
+
inpaint_starts: Optional[list] = None, # list[float], seconds
|
| 393 |
+
inpaint_ends: Optional[list] = None,
|
| 394 |
+
# Phase 7: seamless looping ----------------------------------------
|
| 395 |
+
loop_stitch: Optional[str] = None, # "inpaint" | "crossfade" | None
|
| 396 |
+
loop_bars: Optional[int] = None,
|
| 397 |
+
loop_bpm: Optional[float] = None,
|
| 398 |
+
**_ignored_legacy_kwargs: Any,
|
| 399 |
) -> Path:
|
| 400 |
+
self._stop_requested = False
|
| 401 |
+
if self._stop_requested: # honour pre-call stop
|
| 402 |
+
raise GenerationStopped()
|
| 403 |
+
|
| 404 |
+
# `loop_stitch` / `loop_bars` / `loop_bpm` are accepted for API
|
| 405 |
+
# compatibility but ignored β the seamless-loop pipeline was
|
| 406 |
+
# removed because user A/B testing showed it degraded audio
|
| 407 |
+
# quality on every prompt class. We deliver raw model output.
|
| 408 |
+
|
| 409 |
+
_set_progress(
|
| 410 |
+
is_generating=True, phase="loading",
|
| 411 |
+
step=0, total_steps=0, error=None,
|
| 412 |
+
started_at=time.time(), ended_at=None,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
self._ensure_model(model_id, device=device, half=half)
|
| 416 |
+
self._apply_loras(loras or [])
|
| 417 |
+
|
| 418 |
+
init_audio = self._load_audio(init_audio_path) if init_audio_path else None
|
| 419 |
+
inpaint_audio = self._load_audio(inpaint_audio_path) if inpaint_audio_path else None
|
| 420 |
+
|
| 421 |
+
_, kind, max_dur = _MODEL_INFO[model_id]
|
| 422 |
+
is_base = (kind == "base")
|
| 423 |
+
|
| 424 |
+
# Defaults differ by model kind. Post-trained models distilled CFG
|
| 425 |
+
# away; we force cfg=1.0 there even if the caller overrides.
|
| 426 |
+
effective_steps = int(steps) if steps else (50 if is_base else 8)
|
| 427 |
+
effective_cfg = float(cfg_scale) if (cfg_scale is not None and is_base) else (
|
| 428 |
+
7.0 if is_base else 1.0
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
duration = float(min(max(1.0, float(duration)), float(max_dur)))
|
| 432 |
+
|
| 433 |
+
target_samples = int(round(duration * 44100))
|
| 434 |
+
gen_duration = duration
|
| 435 |
+
total_steps_logical = effective_steps
|
| 436 |
+
|
| 437 |
+
if self._stop_requested: # one more check before the heavy call
|
| 438 |
+
raise GenerationStopped()
|
| 439 |
+
|
| 440 |
+
# Sampler callback β fires per ODE step. Also gives us a cheap
|
| 441 |
+
# cancellation hook: raising mid-callback aborts the sampler.
|
| 442 |
+
def _cb(info: Dict[str, Any]) -> None:
|
| 443 |
+
if self._stop_requested:
|
| 444 |
+
raise GenerationStopped()
|
| 445 |
+
i = info.get("i")
|
| 446 |
+
if isinstance(i, int):
|
| 447 |
+
_set_progress(step=min(i + 1, total_steps_logical))
|
| 448 |
+
|
| 449 |
+
_set_progress(phase="sampling", total_steps=int(total_steps_logical), step=0)
|
| 450 |
+
|
| 451 |
+
gen_kwargs = dict(
|
| 452 |
+
prompt=prompt,
|
| 453 |
+
negative_prompt=negative_prompt or None,
|
| 454 |
+
duration=gen_duration,
|
| 455 |
+
steps=effective_steps,
|
| 456 |
+
cfg_scale=effective_cfg,
|
| 457 |
+
seed=int(seed),
|
| 458 |
+
batch_size=int(batch_size),
|
| 459 |
+
chunked_decode=chunked_decode,
|
| 460 |
+
callback=_cb,
|
| 461 |
+
)
|
| 462 |
+
if init_audio is not None:
|
| 463 |
+
gen_kwargs["init_audio"] = init_audio
|
| 464 |
+
gen_kwargs["init_noise_level"] = float(init_noise_level)
|
| 465 |
+
if inpaint_audio is not None:
|
| 466 |
+
gen_kwargs["inpaint_audio"] = inpaint_audio
|
| 467 |
+
if inpaint_starts is not None and len(inpaint_starts) > 0:
|
| 468 |
+
# SA3 accepts a single float or a list for multi-region.
|
| 469 |
+
gen_kwargs["inpaint_mask_start_seconds"] = (
|
| 470 |
+
list(inpaint_starts) if len(inpaint_starts) > 1 else float(inpaint_starts[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
)
|
| 472 |
+
if inpaint_ends is not None and len(inpaint_ends) > 0:
|
| 473 |
+
gen_kwargs["inpaint_mask_end_seconds"] = (
|
| 474 |
+
list(inpaint_ends) if len(inpaint_ends) > 1 else float(inpaint_ends[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
)
|
| 476 |
|
| 477 |
+
try:
|
| 478 |
+
audio = self.model.generate(**gen_kwargs)
|
| 479 |
+
# audio: torch.Tensor[B, channels=2, samples] in [-1, 1] @ 44.1 kHz
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
except GenerationStopped:
|
| 481 |
+
_set_progress(phase="idle", is_generating=False, ended_at=time.time())
|
| 482 |
raise
|
| 483 |
+
except Exception as exc:
|
| 484 |
+
_set_progress(phase="failed", is_generating=False,
|
| 485 |
+
error=str(exc), ended_at=time.time())
|
|
|
|
| 486 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
+
# Seamless-loop processing (quantize, inpaint, crossfade) was
|
| 489 |
+
# removed: the user A/B-compared raw SA3 output against the full
|
| 490 |
+
# pipeline and confirmed the post-processing made every prompt
|
| 491 |
+
# worse β silence-at-start on percussion, smeared transients,
|
| 492 |
+
# off-grid anchoring. We now deliver the raw model output. The
|
| 493 |
+
# `loop_stitch` / `loop_bars` / `loop_bpm` parameters are still
|
| 494 |
+
# accepted from the frontend for API compatibility but are
|
| 495 |
+
# ignored. Performance-Bars looping will have an audible click
|
| 496 |
+
# at the wrap point and multi-channel stacks will not be
|
| 497 |
+
# sample-aligned β both acceptable trade-offs vs. the artifacts
|
| 498 |
+
# the quantizer was introducing.
|
| 499 |
+
_set_progress(phase="decoding", step=total_steps_logical)
|
| 500 |
+
try:
|
| 501 |
+
return self._finalize(audio, prompt=prompt, model_id=model_id)
|
| 502 |
+
finally:
|
| 503 |
+
_set_progress(phase="complete", is_generating=False,
|
| 504 |
+
step=total_steps_logical, ended_at=time.time())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
+
# --- audio loader (a2a + inpaint inputs) ----------------------------------
|
| 507 |
+
@staticmethod
|
| 508 |
+
def _load_audio(path: str):
|
| 509 |
+
"""Load a wav/mp3/flac into the (sample_rate, tensor) tuple SA3 expects.
|
| 510 |
|
| 511 |
+
Returns a stereo float32 tensor of shape (channels, samples). Mono
|
| 512 |
+
inputs are duplicated to stereo (SA3 expects 2 channels); β₯3-channel
|
| 513 |
+
inputs are truncated to the first 2.
|
| 514 |
+
"""
|
| 515 |
+
import torchaudio
|
| 516 |
+
wav, sr = torchaudio.load(str(path)) # (channels, samples), float32
|
| 517 |
+
if wav.shape[0] == 1:
|
| 518 |
+
wav = wav.repeat(2, 1)
|
| 519 |
+
elif wav.shape[0] > 2:
|
| 520 |
+
wav = wav[:2]
|
| 521 |
+
return int(sr), wav
|
| 522 |
+
|
| 523 |
+
# --- output --------------------------------------------------------------
|
| 524 |
+
def _finalize(self, audio: torch.Tensor, *, prompt: str, model_id: str) -> Path:
|
| 525 |
+
audio = audio.detach().clamp_(-1.0, 1.0).cpu()
|
| 526 |
+
if audio.ndim != 3:
|
| 527 |
+
raise RuntimeError(f"Unexpected SA3 output shape {tuple(audio.shape)}")
|
| 528 |
+
first = audio[0] # [C, samples]
|
| 529 |
+
pcm = (first.numpy() * 32767.0).astype(np.int16).T # β [samples, C]
|
| 530 |
+
|
| 531 |
+
out_dir = self.config.get_path("output")
|
| 532 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 533 |
+
ts = time.strftime("%Y%m%d_%H%M%S")
|
| 534 |
+
out_path = out_dir / f"{ts}_{model_id}_{_slugify(prompt)}.wav"
|
| 535 |
+
sf.write(str(out_path), pcm, 44100, subtype="PCM_16")
|
| 536 |
+
return out_path
|
app/core/generation/audio_post_process.py
CHANGED
|
@@ -1,9 +1,40 @@
|
|
| 1 |
"""Beat-align and tempo-conform a generated WAV to a target BPM and bar count.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
"""
|
| 3 |
|
| 4 |
from __future__ import annotations
|
| 5 |
|
| 6 |
import logging
|
|
|
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Optional, Tuple
|
| 9 |
|
|
@@ -14,35 +45,429 @@ import soundfile as sf
|
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
_STRETCH_SAFE_MIN = 0.6
|
| 22 |
_STRETCH_SAFE_MAX = 1.7
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def align_to_grid(
|
| 26 |
input_path: Path,
|
| 27 |
target_bpm: float,
|
| 28 |
target_bars: int,
|
| 29 |
beats_per_bar: int = 4,
|
| 30 |
) -> Path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
audio, sr = sf.read(str(input_path), always_2d=True)
|
| 32 |
audio = audio.astype(np.float32, copy=False)
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
|
| 36 |
|
| 37 |
-
|
| 38 |
-
# from half-time / double-time interpretations of the same grid.
|
| 39 |
-
detected_bpm, first_beat = _detect_grid_anchor(mono, sr, start_bpm=target_bpm)
|
| 40 |
|
|
|
|
| 41 |
head_offset = 0
|
| 42 |
-
if
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
head_offset = _detect_first_onset_sample(mono, sr)
|
| 47 |
if head_offset > 0:
|
| 48 |
logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms (onset fallback)")
|
|
@@ -50,12 +475,26 @@ def align_to_grid(
|
|
| 50 |
if head_offset > 0:
|
| 51 |
audio = audio[head_offset:]
|
| 52 |
mono = mono[head_offset:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
|
|
|
| 54 |
if detected_bpm is not None:
|
| 55 |
-
rate, effective_bpm = _best_stretch_rate(
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
interp_note = (
|
| 60 |
f" (interpreted as {effective_bpm:.2f} BPM, "
|
| 61 |
f"octave={effective_bpm / detected_bpm:.2f}Γ)"
|
|
@@ -66,33 +505,267 @@ def align_to_grid(
|
|
| 66 |
f"align_to_grid: detected {detected_bpm:.2f} BPM{interp_note}, "
|
| 67 |
f"stretched by {rate:.4f} to match target {target_bpm:.2f} BPM"
|
| 68 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
else:
|
| 70 |
logger.info(
|
| 71 |
f"align_to_grid: detected {detected_bpm:.2f} BPM has no safe "
|
| 72 |
-
f"interpretation vs target {target_bpm:.2f}
|
|
|
|
|
|
|
| 73 |
)
|
| 74 |
else:
|
| 75 |
logger.info("align_to_grid: no usable tempo detected; skipping warp")
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
if audio.shape[0] > target_samples:
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
|
| 89 |
sf.write(str(input_path), audio, sr, subtype="PCM_16")
|
| 90 |
return input_path
|
| 91 |
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def _best_stretch_rate(
|
| 94 |
detected_bpm: float,
|
| 95 |
target_bpm: float,
|
|
|
|
|
|
|
|
|
|
| 96 |
) -> Tuple[Optional[float], float]:
|
| 97 |
"""Pick the time-stretch rate that maps detected β target, considering
|
| 98 |
half-time and double-time interpretations of the detected tempo. Returns
|
|
@@ -101,27 +774,20 @@ def _best_stretch_rate(
|
|
| 101 |
nothing safe is available.
|
| 102 |
|
| 103 |
Order of preference:
|
| 104 |
-
1. Detected as-is, if it lands inside
|
| 105 |
2. Octave-corrected (detected Γ 0.5 or Γ 2.0), only when the as-is
|
| 106 |
interpretation is out of range. This is the librosa half-/double-
|
| 107 |
time error recovery path.
|
| 108 |
-
|
| 109 |
-
This biases the algorithm toward honesty: only re-interpret the
|
| 110 |
-
detector's reading when it can't otherwise produce a usable stretch.
|
| 111 |
"""
|
| 112 |
-
# First, try the detector's reading at face value.
|
| 113 |
rate_asis = target_bpm / detected_bpm
|
| 114 |
-
if
|
| 115 |
return rate_asis, detected_bpm
|
| 116 |
|
| 117 |
-
# As-is is out of safe range β almost certainly a librosa octave error.
|
| 118 |
-
# Try the half-time and double-time reinterpretations and pick whichever
|
| 119 |
-
# is closest to a no-op stretch.
|
| 120 |
candidates = []
|
| 121 |
for octave_factor in (0.5, 2.0):
|
| 122 |
interpreted = detected_bpm * octave_factor
|
| 123 |
rate = target_bpm / interpreted
|
| 124 |
-
if
|
| 125 |
candidates.append((abs(rate - 1.0), rate, interpreted))
|
| 126 |
if not candidates:
|
| 127 |
return None, detected_bpm
|
|
@@ -130,6 +796,7 @@ def _best_stretch_rate(
|
|
| 130 |
return best_rate, best_interp
|
| 131 |
|
| 132 |
|
|
|
|
| 133 |
def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
|
| 134 |
"""Return the sample index of the first detected onset, or 0 if none found."""
|
| 135 |
try:
|
|
@@ -147,15 +814,16 @@ def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
|
|
| 147 |
return first
|
| 148 |
|
| 149 |
|
| 150 |
-
|
|
|
|
| 151 |
mono: np.ndarray,
|
| 152 |
sr: int,
|
| 153 |
start_bpm: Optional[float] = None,
|
| 154 |
-
) -> Tuple[Optional[float], Optional[
|
| 155 |
-
"""Run librosa beat tracking with the target tempo as a prior.
|
| 156 |
-
start_bpm reduces (but doesn't
|
| 157 |
-
|
| 158 |
-
still gets wrong."""
|
| 159 |
try:
|
| 160 |
kwargs = {"y": mono, "sr": sr, "units": "samples"}
|
| 161 |
if start_bpm is not None and start_bpm > 0:
|
|
@@ -169,9 +837,10 @@ def _detect_grid_anchor(
|
|
| 169 |
bpm = float(np.atleast_1d(tempo).flatten()[0])
|
| 170 |
if not (40.0 <= bpm <= 240.0):
|
| 171 |
return None, None
|
| 172 |
-
return bpm,
|
| 173 |
|
| 174 |
|
|
|
|
| 175 |
def _time_stretch_multichannel(audio: np.ndarray, rate: float) -> np.ndarray:
|
| 176 |
"""Phase-vocoder time stretch, applied per channel and re-stacked."""
|
| 177 |
stretched = librosa.effects.time_stretch(audio.T, rate=rate)
|
|
|
|
| 1 |
"""Beat-align and tempo-conform a generated WAV to a target BPM and bar count.
|
| 2 |
+
|
| 3 |
+
DEPRECATED β this entire module is superseded by ``app/core/loop_quantizer/``
|
| 4 |
+
(see ``task_1.md`` and ``AUDIT.md`` Β§9 "Scheduled for removal"). The legacy
|
| 5 |
+
``align_to_grid`` / ``align_for_loop`` path and the gated ``_stage_a_v2`` path
|
| 6 |
+
both live here until the new module passes acceptance; once it does, every
|
| 7 |
+
public symbol below is removed and the file deletes itself. Do NOT add new
|
| 8 |
+
callers, do NOT extend the v1 helpers, and prefer adding work directly under
|
| 9 |
+
``app/core/loop_quantizer/`` for any new behaviour.
|
| 10 |
+
|
| 11 |
+
SA3 generates at the exact requested duration via variable-length flow
|
| 12 |
+
matching, so the post-processor's role is **drift correction**, not length
|
| 13 |
+
control: it only nudges the audio when librosa detects that the realised
|
| 14 |
+
tempo has drifted from the target. The tempo-conform gate is intentionally
|
| 15 |
+
tight β `|rate - 1| > 5%` AND `rate in [0.85, 1.15]` β so we never warp
|
| 16 |
+
audibly when SA3 was already close.
|
| 17 |
+
|
| 18 |
+
Pipeline (in order):
|
| 19 |
+
1. Detect tempo + beat grid via librosa (with target BPM as prior).
|
| 20 |
+
2. Head-trim to the first detected beat (or first onset as fallback),
|
| 21 |
+
followed by a 3 ms equal-power fade-in to mask the trim seam.
|
| 22 |
+
3. Tempo-conform via phase-vocoder time-stretch, ONLY when the detected
|
| 23 |
+
tempo drifts >5% from target AND the resulting stretch lies inside
|
| 24 |
+
the safe range [0.85, 1.15]. Outside this window we leave the audio
|
| 25 |
+
alone and let the user re-roll.
|
| 26 |
+
4. End-anchored truncation: snap the cut to the nearest detected beat
|
| 27 |
+
within Β±Β½ beat of the mathematical target sample count, so loops
|
| 28 |
+
don't end mid-note. Followed by an 8 ms equal-power fade-out so the
|
| 29 |
+
loop seam doesn't click.
|
| 30 |
+
5. Zero-pad if the audio came out shorter than the target.
|
| 31 |
"""
|
| 32 |
|
| 33 |
from __future__ import annotations
|
| 34 |
|
| 35 |
import logging
|
| 36 |
+
import os
|
| 37 |
+
import warnings
|
| 38 |
from pathlib import Path
|
| 39 |
from typing import Optional, Tuple
|
| 40 |
|
|
|
|
| 45 |
logger = logging.getLogger(__name__)
|
| 46 |
|
| 47 |
|
| 48 |
+
# DEPRECATED: flag goes away with the v1/v2 split (AUDIT.md Β§9d).
|
| 49 |
+
def beatsync_v2_enabled() -> bool:
|
| 50 |
+
"""Feature gate for the hardened Stage A pipeline (sample-exact length,
|
| 51 |
+
first-transient-to-zero alignment, transient-preserving stretch).
|
| 52 |
+
|
| 53 |
+
Off by default: with the flag unset, every Stage A function takes its
|
| 54 |
+
legacy code path, so Bars-mode output is byte-identical to pre-flag
|
| 55 |
+
builds and Seconds mode (which never enters Stage A at all) is unaffected.
|
| 56 |
+
Enable with ``FRAGMENTA_BEATSYNC_V2=1``.
|
| 57 |
+
"""
|
| 58 |
+
return os.environ.get("FRAGMENTA_BEATSYNC_V2", "0").strip().lower() in (
|
| 59 |
+
"1", "true", "yes", "on",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# DEPRECATED: flag goes away with the v1/v2 split (AUDIT.md Β§9d).
|
| 64 |
+
def _warp_enabled() -> bool:
|
| 65 |
+
"""Per-beat (Ableton 'Beats'-style) warp gate β OFF by default.
|
| 66 |
+
|
| 67 |
+
The warp is only as reliable as librosa's per-beat detection; on real audio
|
| 68 |
+
a single mis-detected beat scrambles the groove. Anchor + exact-crop already
|
| 69 |
+
lands real loops at ~3 ms, so the warp is opt-in for experimentation only.
|
| 70 |
+
Enable with ``FRAGMENTA_BEATSYNC_WARP=1``."""
|
| 71 |
+
return os.environ.get("FRAGMENTA_BEATSYNC_WARP", "0").strip().lower() in (
|
| 72 |
+
"1", "true", "yes", "on",
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Liberal module-default range for `_best_stretch_rate`. Kept wide so any
|
| 77 |
+
# future force-warp caller has room; the bars-mode drift-correction path
|
| 78 |
+
# (`align_to_grid`) overrides with tighter bounds below.
|
| 79 |
_STRETCH_SAFE_MIN = 0.6
|
| 80 |
_STRETCH_SAFE_MAX = 1.7
|
| 81 |
|
| 82 |
+
# Bars-mode drift correction. SA3 hits the requested duration exactly via
|
| 83 |
+
# variable-length generation, so the post-processor only kicks in when the
|
| 84 |
+
# detected tempo of the generated audio drifts from the requested target.
|
| 85 |
+
# Tight gates avoid audible vocoder artifacts when SA3 was already close.
|
| 86 |
+
_BARS_MODE_STRETCH_MIN = 0.85
|
| 87 |
+
_BARS_MODE_STRETCH_MAX = 1.15
|
| 88 |
+
_BARS_MODE_DEADBAND = 0.05
|
| 89 |
+
|
| 90 |
+
# Loop-mode (Phase 7) is stricter β a 5% tempo slack compounds visibly when
|
| 91 |
+
# multiple loop channels run side-by-side, even though loop iteration
|
| 92 |
+
# lengths are sample-exact. 0.5% is below librosa's noise floor for beat
|
| 93 |
+
# detection on rhythmic content, so we won't be acting on noise, but we
|
| 94 |
+
# WILL correct anything detectable that the looser bars-mode would skip.
|
| 95 |
+
_LOOP_MODE_DEADBAND = 0.005
|
| 96 |
+
|
| 97 |
+
# Fade durations applied at trim points. Kept very short β the fade is
|
| 98 |
+
# click-prevention, not a perceptible ramp. Performance Mode loops these
|
| 99 |
+
# clips, and longer fades audibly "duck" the loop seam.
|
| 100 |
+
_HEAD_FADE_SEC = 0.003 # mask click at the trimmed head
|
| 101 |
+
_TAIL_FADE_SEC = 0.003 # mask click at a mid-note truncation; skipped on beats
|
| 102 |
+
|
| 103 |
+
# Trailing-silence detection. SA3 occasionally pads a generation with low-
|
| 104 |
+
# level tail; the post-processor used to keep that and fade over it, which
|
| 105 |
+
# produced perceptible "silence + duck" at the loop point.
|
| 106 |
+
_SILENCE_THRESHOLD_DB = -50.0 # anything below is silence
|
| 107 |
+
_SILENCE_WINDOW_SEC = 0.05 # RMS window granularity
|
| 108 |
+
_SILENCE_TAIL_KEEP_SEC = 0.010 # leave a tiny natural decay
|
| 109 |
+
|
| 110 |
+
# v2 first-transient search: a downbeat lands within the first bar or two of
|
| 111 |
+
# generated content, so we never hunt past this window for the musical "1".
|
| 112 |
+
_V2_TRANSIENT_SEARCH_SEC = 1.5
|
| 113 |
+
_V2_STRONG_RATIO = 0.30 # candidate must reach 30% of peak
|
| 114 |
+
_V2_RISE_RATIO = 0.15 # rising-edge threshold for refinement
|
| 115 |
+
_V2_REFINE_WIN_SEC = 0.03 # +/- window for sample-accurate refine
|
| 116 |
+
|
| 117 |
+
# Grid confidence. librosa's beat tracker emits a tempo for ANY input β on
|
| 118 |
+
# ambient/textural content it is essentially noise (measured: 49-161 BPM on a
|
| 119 |
+
# 120-BPM target, 130+ ms intra-beat drift). Warping toward a wrong detected
|
| 120 |
+
# tempo is worse than not warping, so we only tempo-conform when the detected
|
| 121 |
+
# grid is trustworthy: beats evenly spaced (low interval CV) AND a clear pulse
|
| 122 |
+
# in the onset envelope. Below the threshold we trust the *requested* grid and
|
| 123 |
+
# skip the stretch (still doing the safe, tempo-independent transient@0 + crop).
|
| 124 |
+
# Calibrated on real fixtures: clean drum/bass loops score 0.76-0.88, pure
|
| 125 |
+
# pads 0.00 (no trackable beat), and ambiguous textures 0.44-0.57 β often with
|
| 126 |
+
# a wrong detected tempo. 0.65 sits in that gap. (The safe-range gate in
|
| 127 |
+
# _best_stretch_rate independently rejects octave-wrong tempos like 49/161 BPM.)
|
| 128 |
+
_GRID_CONFIDENCE_MIN = 0.65
|
| 129 |
+
_CV_MAX = 0.20 # interval CV at which regularity -> 0
|
| 130 |
+
|
| 131 |
+
# Beat-synchronous warp (Ableton "Beats"-style). Measured: real drum loops are
|
| 132 |
+
# already coherent to ~3-6 ms, where anchor+exact-crop alone lands single-digit
|
| 133 |
+
# ms β so a global/elastic warp there only adds phase-vocoder jitter for no gain.
|
| 134 |
+
# We therefore warp ONLY when a confident grid still drifts past this threshold,
|
| 135 |
+
# and need enough beats to define segments.
|
| 136 |
+
_WARP_DRIFT_MIN_MS = 15.0
|
| 137 |
+
_WARP_MIN_BEATS = 6
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# === Stage A v2 (FRAGMENTA_BEATSYNC_V2) ====================================
|
| 141 |
+
# DEPRECATED: every symbol in this section is scheduled for relocation into
|
| 142 |
+
# `app/core/loop_quantizer` (see AUDIT.md Β§9c). Port the logic, then delete
|
| 143 |
+
# the originals here. Do NOT add new callers to anything below.
|
| 144 |
+
# A single hardened core shared by both align entry points. It enforces the
|
| 145 |
+
# locked invariants directly instead of relying on librosa's beat[0] for
|
| 146 |
+
# phase and on end-snap/silence-trim for length:
|
| 147 |
+
# * tempo conform with a bounded phase-vocoder stretch (_conform_stretch) β
|
| 148 |
+
# gen-time warp only, no live tracking (decision: v1);
|
| 149 |
+
# * align the first STRONG transient to sample 0 (rotate-free head trim) so
|
| 150 |
+
# two independently-correct clips share a downbeat with zero per-clip code;
|
| 151 |
+
# * crop to the exact target sample count β overgenerate-then-trim, never
|
| 152 |
+
# zero-pad in the common path (pad only as a logged last resort).
|
| 153 |
+
|
| 154 |
+
def _stage_a_v2(
|
| 155 |
+
audio: np.ndarray,
|
| 156 |
+
sr: int,
|
| 157 |
+
*,
|
| 158 |
+
target_samples: int,
|
| 159 |
+
target_bpm: float,
|
| 160 |
+
deadband: float,
|
| 161 |
+
) -> np.ndarray:
|
| 162 |
+
"""Hardened Stage A core. Input/return: float32 ``[T, C]``.
|
| 163 |
+
|
| 164 |
+
Decides per clip how to land it on the grid:
|
| 165 |
+
* low grid confidence -> place as-is (trust the requested grid; no warp,
|
| 166 |
+
no trim β Ableton likewise won't warp a pulse-less texture);
|
| 167 |
+
* confident + non-uniform drift -> beat-synchronous warp (each inter-beat
|
| 168 |
+
segment stretched onto the exact grid, Ableton "Beats" warp);
|
| 169 |
+
* confident + already coherent -> anchor + (optional) whole-loop tempo
|
| 170 |
+
nudge; the measured workhorse path (single-digit ms on real loops).
|
| 171 |
+
Always finishes with: first-strong-transient -> sample 0, then exact crop.
|
| 172 |
+
"""
|
| 173 |
+
mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
|
| 174 |
+
detected_bpm, beats = _detect_grid(mono, sr, start_bpm=target_bpm)
|
| 175 |
+
confidence = _grid_confidence(mono, sr, beats)
|
| 176 |
+
spb = sr * 60.0 / target_bpm
|
| 177 |
+
|
| 178 |
+
trusted = (
|
| 179 |
+
detected_bpm is not None
|
| 180 |
+
and confidence >= _GRID_CONFIDENCE_MIN
|
| 181 |
+
and beats is not None
|
| 182 |
+
and len(beats) >= _WARP_MIN_BEATS
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if not trusted:
|
| 186 |
+
logger.info(
|
| 187 |
+
"stage_a_v2: %s; trusting requested %.2f BPM grid, exact-length only",
|
| 188 |
+
"low grid confidence (%.2f < %.2f)" % (confidence, _GRID_CONFIDENCE_MIN)
|
| 189 |
+
if detected_bpm is not None else "no usable grid",
|
| 190 |
+
target_bpm,
|
| 191 |
+
)
|
| 192 |
+
return _exact_len(audio, target_samples, sr)
|
| 193 |
+
|
| 194 |
+
# --- anchor the musical "1" to sample 0 (INV#4, enables INV#9) --------
|
| 195 |
+
# Anchor to the first TRACKED beat, not the "first loud onset": the tracked
|
| 196 |
+
# beat is the same metrical position across clips, so two loops coincide;
|
| 197 |
+
# "first loud onset" lands on whatever transient happens to be loudest and
|
| 198 |
+
# differs per clip (measured: 200+ ms apart). Refine beats[0] to the exact
|
| 199 |
+
# rising edge for sample accuracy.
|
| 200 |
+
anchor = _refine_to_transient(mono, int(beats[0]), sr)
|
| 201 |
+
if anchor > 0:
|
| 202 |
+
audio = audio[anchor:]
|
| 203 |
+
mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
|
| 204 |
+
beats = np.asarray(beats, dtype=np.int64) - anchor
|
| 205 |
+
beats = beats[beats >= 0]
|
| 206 |
+
|
| 207 |
+
drift = _grid_drift_samples(beats)
|
| 208 |
+
if (_warp_enabled() and drift > _WARP_DRIFT_MIN_MS * sr / 1000.0
|
| 209 |
+
and len(beats) >= 2):
|
| 210 |
+
# OFF BY DEFAULT (FRAGMENTA_BEATSYNC_WARP). Per-beat warp is only as good
|
| 211 |
+
# as librosa's beat detection β when detection is even slightly off it
|
| 212 |
+
# warps the wrong points onto the grid and SCRAMBLES the groove on real
|
| 213 |
+
# audio. Measured gain on clean drift was marginal (it merely halved it
|
| 214 |
+
# and added jitter), while anchor + exact-crop already lands real loops
|
| 215 |
+
# at ~3 ms. So it's opt-in for experiments, not the default path.
|
| 216 |
+
audio = _beat_sync_warp(audio, beats, spb)
|
| 217 |
+
logger.info("stage_a_v2: anchored + beat-sync warp (intra-loop drift "
|
| 218 |
+
"%.1f ms)", drift / sr * 1000)
|
| 219 |
+
else:
|
| 220 |
+
# Already coherent: a single global stretch is sufficient (and cleaner
|
| 221 |
+
# than per-segment warping) when the overall tempo is off; otherwise
|
| 222 |
+
# the anchor + exact crop is all that's needed.
|
| 223 |
+
rate, eff = _best_stretch_rate(
|
| 224 |
+
detected_bpm, target_bpm,
|
| 225 |
+
safe_min=_BARS_MODE_STRETCH_MIN, safe_max=_BARS_MODE_STRETCH_MAX,
|
| 226 |
+
)
|
| 227 |
+
if rate is not None and abs(rate - 1.0) > deadband:
|
| 228 |
+
audio = _conform_stretch(audio, rate, sr)
|
| 229 |
+
logger.info("stage_a_v2: anchored + global tempo conform x%.4f "
|
| 230 |
+
"(detected %.2f -> %.2f)", rate, detected_bpm, target_bpm)
|
| 231 |
+
else:
|
| 232 |
+
logger.info("stage_a_v2: anchored only (low drift, on-tempo)")
|
| 233 |
+
|
| 234 |
+
return _exact_len(audio, target_samples, sr)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _exact_len(audio: np.ndarray, target_samples: int, sr: int) -> np.ndarray:
|
| 238 |
+
"""Crop to exactly target_samples (INV#2/#3). Pads only as a logged last
|
| 239 |
+
resort β the generation overshoots duration so trimming is the norm."""
|
| 240 |
+
if audio.shape[0] >= target_samples:
|
| 241 |
+
return np.ascontiguousarray(audio[:target_samples], dtype=np.float32)
|
| 242 |
+
pad = target_samples - audio.shape[0]
|
| 243 |
+
logger.warning(
|
| 244 |
+
"stage_a_v2: content short by %d samp (%.0f ms) β padding as a last "
|
| 245 |
+
"resort; raise generation headroom or re-roll", pad, pad / sr * 1000,
|
| 246 |
+
)
|
| 247 |
+
return np.ascontiguousarray(
|
| 248 |
+
np.concatenate([audio, np.zeros((pad, audio.shape[1]), np.float32)], 0),
|
| 249 |
+
dtype=np.float32,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _grid_drift_samples(beats: Optional[np.ndarray]) -> float:
|
| 254 |
+
"""Std of detected-beat residuals vs a uniform least-squares grid (samples).
|
| 255 |
+
A coherent loop sits near 0; tempo wobble shows up as a large residual."""
|
| 256 |
+
if beats is None or len(beats) < 4:
|
| 257 |
+
return 0.0
|
| 258 |
+
idx = np.arange(len(beats))
|
| 259 |
+
A = np.vstack([idx, np.ones_like(idx)]).T
|
| 260 |
+
slope, icpt = np.linalg.lstsq(A, beats.astype(float), rcond=None)[0]
|
| 261 |
+
resid = beats.astype(float) - (slope * idx + icpt)
|
| 262 |
+
return float(np.std(resid))
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _refine_to_transient(mono: np.ndarray, approx: int, sr: int,
|
| 266 |
+
win_sec: float = 0.015) -> int:
|
| 267 |
+
"""Snap a frame-resolution beat sample to the exact rising edge of the
|
| 268 |
+
transient AT that beat. librosa picks WHICH transient is the beat (good);
|
| 269 |
+
this gives it sample accuracy (INV#4). The window is deliberately tight
|
| 270 |
+
(~15 ms): wide enough to cover beat-tracker frame jitter, narrow enough not
|
| 271 |
+
to jump to a neighbouring transient (which would desync clips, INV#9)."""
|
| 272 |
+
n = len(mono)
|
| 273 |
+
if n == 0:
|
| 274 |
+
return 0
|
| 275 |
+
approx = int(max(0, min(approx, n - 1)))
|
| 276 |
+
lo = max(0, approx - int(sr * win_sec))
|
| 277 |
+
hi = min(n, approx + int(sr * win_sec))
|
| 278 |
+
if hi - lo < 2:
|
| 279 |
+
return approx
|
| 280 |
+
seg = np.abs(mono[lo:hi])
|
| 281 |
+
pk = float(seg.max())
|
| 282 |
+
if pk <= 1e-6:
|
| 283 |
+
return approx
|
| 284 |
+
above = np.flatnonzero(seg >= _V2_RISE_RATIO * pk)
|
| 285 |
+
return int(lo + above[0]) if len(above) else approx
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _beat_sync_warp(audio: np.ndarray, beats: np.ndarray, spb: float) -> np.ndarray:
|
| 289 |
+
"""Ableton 'Beats'-style warp: stretch each inter-beat segment to exactly
|
| 290 |
+
round(spb) samples. Output starts at the first detected beat and has a
|
| 291 |
+
perfectly uniform grid, so two clips at the same tempo become sample-for-
|
| 292 |
+
sample periodic (INV#9). Phase-vocoder per segment; only invoked when drift
|
| 293 |
+
is high enough to be worth the boundary jitter."""
|
| 294 |
+
beats = np.asarray(beats, dtype=np.int64)
|
| 295 |
+
beats = beats[(beats >= 0) & (beats < audio.shape[0])]
|
| 296 |
+
if len(beats) < 2:
|
| 297 |
+
return audio
|
| 298 |
+
target_spb = int(round(spb))
|
| 299 |
+
segs = []
|
| 300 |
+
for i in range(len(beats) - 1):
|
| 301 |
+
s, e = int(beats[i]), int(beats[i + 1])
|
| 302 |
+
seg = audio[s:e]
|
| 303 |
+
if seg.shape[0] < 16:
|
| 304 |
+
continue
|
| 305 |
+
rate = float(np.clip(seg.shape[0] / spb, 0.5, 2.0))
|
| 306 |
+
w = librosa.effects.time_stretch(seg.T, rate=rate).T
|
| 307 |
+
if w.shape[0] >= target_spb:
|
| 308 |
+
w = w[:target_spb]
|
| 309 |
+
else:
|
| 310 |
+
w = np.concatenate(
|
| 311 |
+
[w, np.zeros((target_spb - w.shape[0], w.shape[1]), np.float32)], 0)
|
| 312 |
+
segs.append(np.ascontiguousarray(w, dtype=np.float32))
|
| 313 |
+
return np.concatenate(segs, 0) if segs else audio
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _grid_confidence(
|
| 317 |
+
mono: np.ndarray, sr: int, beats: Optional[np.ndarray]
|
| 318 |
+
) -> float:
|
| 319 |
+
"""Trustworthiness of the detected beat grid, in [0, 1].
|
| 320 |
+
|
| 321 |
+
Two evidence sources, averaged:
|
| 322 |
+
* regularity β how evenly spaced the detected beats are (1 - interval
|
| 323 |
+
coefficient of variation, clamped); a locked tracker gives near-even
|
| 324 |
+
intervals, ambient content gives erratic ones;
|
| 325 |
+
* pulse clarity β the strongest off-zero peak of the onset-envelope
|
| 326 |
+
autocorrelation relative to lag 0; high when there is a real periodic
|
| 327 |
+
pulse, low for drones/pads.
|
| 328 |
+
"""
|
| 329 |
+
if beats is None or len(beats) < 4:
|
| 330 |
+
return 0.0
|
| 331 |
+
intervals = np.diff(beats.astype(np.float64))
|
| 332 |
+
mean_i = float(np.mean(intervals)) if len(intervals) else 0.0
|
| 333 |
+
if mean_i <= 0:
|
| 334 |
+
return 0.0
|
| 335 |
+
cv = float(np.std(intervals) / mean_i)
|
| 336 |
+
regularity = max(0.0, min(1.0, 1.0 - cv / _CV_MAX))
|
| 337 |
+
|
| 338 |
+
clarity = 0.0
|
| 339 |
+
try:
|
| 340 |
+
oenv = librosa.onset.onset_strength(y=mono, sr=sr)
|
| 341 |
+
oenv = oenv - float(np.mean(oenv))
|
| 342 |
+
ac = librosa.autocorrelate(oenv)
|
| 343 |
+
if len(ac) > 4 and ac[0] > 0:
|
| 344 |
+
clarity = float(np.max(ac[4:]) / ac[0])
|
| 345 |
+
clarity = max(0.0, min(1.0, clarity))
|
| 346 |
+
except Exception as exc:
|
| 347 |
+
logger.warning("grid-confidence clarity failed: %s", exc)
|
| 348 |
+
|
| 349 |
+
return 0.5 * regularity + 0.5 * clarity
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _first_strong_transient(mono: np.ndarray, sr: int) -> int:
|
| 353 |
+
"""Sample index of the first STRONG transient, refined to the rising edge.
|
| 354 |
+
|
| 355 |
+
Two-stage so we neither latch onto low-level noise nor lose sample
|
| 356 |
+
accuracy to librosa's 512-sample hop:
|
| 357 |
+
1. librosa onset candidates; take the first whose local peak reaches
|
| 358 |
+
``_V2_STRONG_RATIO`` of the search-window peak;
|
| 359 |
+
2. refine within a small window to the first sample crossing
|
| 360 |
+
``_V2_RISE_RATIO`` of that local peak β the attack's true start.
|
| 361 |
+
Returns 0 when the clip is silent or no strong transient is found.
|
| 362 |
+
"""
|
| 363 |
+
n = len(mono)
|
| 364 |
+
search = min(n, int(sr * _V2_TRANSIENT_SEARCH_SEC))
|
| 365 |
+
if search <= 0:
|
| 366 |
+
return 0
|
| 367 |
+
peak = float(np.max(np.abs(mono[:search])))
|
| 368 |
+
if peak <= 1e-6:
|
| 369 |
+
return 0
|
| 370 |
+
|
| 371 |
+
try:
|
| 372 |
+
onsets = librosa.onset.onset_detect(
|
| 373 |
+
y=mono, sr=sr, units="samples", backtrack=True
|
| 374 |
+
)
|
| 375 |
+
except Exception as exc:
|
| 376 |
+
logger.warning("v2 onset detection failed: %s", exc)
|
| 377 |
+
onsets = None
|
| 378 |
+
|
| 379 |
+
cand: Optional[int] = None
|
| 380 |
+
if onsets is not None and len(onsets) > 0:
|
| 381 |
+
look = int(sr * 0.05)
|
| 382 |
+
for o in np.asarray(onsets, dtype=np.int64):
|
| 383 |
+
if o >= search:
|
| 384 |
+
break
|
| 385 |
+
lo, hi = int(o), min(n, int(o) + look)
|
| 386 |
+
if float(np.max(np.abs(mono[lo:hi]))) >= _V2_STRONG_RATIO * peak:
|
| 387 |
+
cand = int(o)
|
| 388 |
+
break
|
| 389 |
+
|
| 390 |
+
if cand is None:
|
| 391 |
+
# No qualifying onset β fall back to the first sample that crosses a
|
| 392 |
+
# fraction of the window peak (handles smooth/pad content).
|
| 393 |
+
idx = np.flatnonzero(np.abs(mono[:search]) >= _V2_STRONG_RATIO * peak)
|
| 394 |
+
return int(idx[0]) if len(idx) else 0
|
| 395 |
|
| 396 |
+
win = int(sr * _V2_REFINE_WIN_SEC)
|
| 397 |
+
lo = max(0, cand - win)
|
| 398 |
+
hi = min(n, cand + win)
|
| 399 |
+
local_peak = float(np.max(np.abs(mono[lo:hi]))) or peak
|
| 400 |
+
seg = np.abs(mono[lo:hi])
|
| 401 |
+
above = np.flatnonzero(seg >= _V2_RISE_RATIO * local_peak)
|
| 402 |
+
return int(lo + above[0]) if len(above) else cand
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def _conform_stretch(audio: np.ndarray, rate: float, sr: int) -> np.ndarray:
|
| 406 |
+
"""Tempo-conform time-stretch β the INV#5 "justified equivalent".
|
| 407 |
+
|
| 408 |
+
We use the librosa phase vocoder (no external binary to ship) rather than
|
| 409 |
+
RubberBand's transient mode, justified by three properties that keep
|
| 410 |
+
transient smearing perceptually negligible here:
|
| 411 |
+
|
| 412 |
+
1. Bounded rate. This only runs inside the safe range [0.85, 1.15] β at
|
| 413 |
+
most a 15% stretch β where phase-vocoder transient blur is minor.
|
| 414 |
+
2. Rare path. It fires only on high grid-confidence, off-by->0.5%-tempo
|
| 415 |
+
loops; SA3 usually hits the target at gen-time and skips it entirely.
|
| 416 |
+
3. The perceptually critical transient β the downbeat β is positioned by
|
| 417 |
+
the sample-accurate trim in `_stage_a_v2`, NOT by this stretch, so the
|
| 418 |
+
musical "1" is never vocoded.
|
| 419 |
+
|
| 420 |
+
`sr` is accepted for call-site symmetry (the phase vocoder is rate-only)."""
|
| 421 |
+
if abs(rate - 1.0) < 1e-9:
|
| 422 |
+
return audio
|
| 423 |
+
return _time_stretch_multichannel(audio, rate)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# DEPRECATED: superseded by app/core/loop_quantizer (see task_1.md / AUDIT.md Β§9a).
|
| 427 |
+
# Public entry; emits DeprecationWarning at runtime. Scheduled for removal once
|
| 428 |
+
# the new module passes acceptance.
|
| 429 |
def align_to_grid(
|
| 430 |
input_path: Path,
|
| 431 |
target_bpm: float,
|
| 432 |
target_bars: int,
|
| 433 |
beats_per_bar: int = 4,
|
| 434 |
) -> Path:
|
| 435 |
+
warnings.warn(
|
| 436 |
+
"align_to_grid is deprecated and will be removed once "
|
| 437 |
+
"app/core/loop_quantizer ships (see task_1.md / AUDIT.md Β§9a).",
|
| 438 |
+
DeprecationWarning,
|
| 439 |
+
stacklevel=2,
|
| 440 |
+
)
|
| 441 |
audio, sr = sf.read(str(input_path), always_2d=True)
|
| 442 |
audio = audio.astype(np.float32, copy=False)
|
| 443 |
+
samples_per_beat = sr * 60.0 / float(target_bpm)
|
| 444 |
+
target_samples = int(round(target_bars * beats_per_bar * samples_per_beat))
|
| 445 |
+
|
| 446 |
+
if beatsync_v2_enabled():
|
| 447 |
+
out = _stage_a_v2(
|
| 448 |
+
np.ascontiguousarray(audio), sr,
|
| 449 |
+
target_samples=target_samples, target_bpm=float(target_bpm),
|
| 450 |
+
deadband=_BARS_MODE_DEADBAND,
|
| 451 |
+
)
|
| 452 |
+
# 3 ms head fade-in masks any click at the new sample-0 transient.
|
| 453 |
+
_apply_fade(out, _HEAD_FADE_SEC, sr, fade_in=True)
|
| 454 |
+
sf.write(str(input_path), out, sr, subtype="PCM_16")
|
| 455 |
+
logger.info("align_to_grid[v2]: %d samples (exact target %d)",
|
| 456 |
+
out.shape[0], target_samples)
|
| 457 |
+
return input_path
|
| 458 |
|
| 459 |
mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
|
| 460 |
|
| 461 |
+
detected_bpm, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
|
|
|
|
|
|
|
| 462 |
|
| 463 |
+
# --- Head trim ---------------------------------------------------------
|
| 464 |
head_offset = 0
|
| 465 |
+
if beat_samples is not None and len(beat_samples) > 0:
|
| 466 |
+
first_beat = int(beat_samples[0])
|
| 467 |
+
if 0 < first_beat < sr * 1.5:
|
| 468 |
+
head_offset = first_beat
|
| 469 |
+
logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms to first beat")
|
| 470 |
+
elif beat_samples is None:
|
| 471 |
head_offset = _detect_first_onset_sample(mono, sr)
|
| 472 |
if head_offset > 0:
|
| 473 |
logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms (onset fallback)")
|
|
|
|
| 475 |
if head_offset > 0:
|
| 476 |
audio = audio[head_offset:]
|
| 477 |
mono = mono[head_offset:]
|
| 478 |
+
if beat_samples is not None:
|
| 479 |
+
shifted = np.asarray(beat_samples, dtype=np.int64) - head_offset
|
| 480 |
+
beat_samples = shifted[shifted > 0]
|
| 481 |
+
# Head fade-in: 3 ms equal-power so the trim seam doesn't click.
|
| 482 |
+
_apply_fade(audio, _HEAD_FADE_SEC, sr, fade_in=True)
|
| 483 |
|
| 484 |
+
# --- Tempo conform -----------------------------------------------------
|
| 485 |
if detected_bpm is not None:
|
| 486 |
+
rate, effective_bpm = _best_stretch_rate(
|
| 487 |
+
detected_bpm,
|
| 488 |
+
target_bpm,
|
| 489 |
+
safe_min=_BARS_MODE_STRETCH_MIN,
|
| 490 |
+
safe_max=_BARS_MODE_STRETCH_MAX,
|
| 491 |
+
)
|
| 492 |
+
if rate is not None and abs(rate - 1.0) > _BARS_MODE_DEADBAND:
|
| 493 |
+
audio = _time_stretch_multichannel(audio, rate)
|
| 494 |
+
# Beats have moved β re-detect from the warped audio so the
|
| 495 |
+
# end-snap step below sees current beat positions.
|
| 496 |
+
mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
|
| 497 |
+
_, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
|
| 498 |
interp_note = (
|
| 499 |
f" (interpreted as {effective_bpm:.2f} BPM, "
|
| 500 |
f"octave={effective_bpm / detected_bpm:.2f}Γ)"
|
|
|
|
| 505 |
f"align_to_grid: detected {detected_bpm:.2f} BPM{interp_note}, "
|
| 506 |
f"stretched by {rate:.4f} to match target {target_bpm:.2f} BPM"
|
| 507 |
)
|
| 508 |
+
elif rate is not None:
|
| 509 |
+
logger.info(
|
| 510 |
+
f"align_to_grid: detected {detected_bpm:.2f} BPM is within "
|
| 511 |
+
f"{_BARS_MODE_DEADBAND * 100:.0f}% of target {target_bpm:.2f}; "
|
| 512 |
+
f"skipping stretch to preserve transients"
|
| 513 |
+
)
|
| 514 |
else:
|
| 515 |
logger.info(
|
| 516 |
f"align_to_grid: detected {detected_bpm:.2f} BPM has no safe "
|
| 517 |
+
f"interpretation vs target {target_bpm:.2f} within "
|
| 518 |
+
f"[{_BARS_MODE_STRETCH_MIN:.2f}, {_BARS_MODE_STRETCH_MAX:.2f}]; "
|
| 519 |
+
f"skipping warp (user re-roll recommended)"
|
| 520 |
)
|
| 521 |
else:
|
| 522 |
logger.info("align_to_grid: no usable tempo detected; skipping warp")
|
| 523 |
|
| 524 |
+
# --- Trim trailing silence --------------------------------------------
|
| 525 |
+
# Done before end-snap so the snap operates on real audio, not on
|
| 526 |
+
# beats that happen to fall inside a quiet tail.
|
| 527 |
+
new_len = _trailing_audio_end(audio, sr)
|
| 528 |
+
if new_len < audio.shape[0]:
|
| 529 |
+
trimmed_ms = (audio.shape[0] - new_len) / sr * 1000
|
| 530 |
+
logger.info(f"align_to_grid: trimmed {trimmed_ms:.0f} ms trailing silence")
|
| 531 |
+
audio = audio[:new_len]
|
| 532 |
+
if beat_samples is not None:
|
| 533 |
+
beat_samples = beat_samples[beat_samples < new_len]
|
| 534 |
+
|
| 535 |
+
# --- End-anchored truncation ------------------------------------------
|
| 536 |
if audio.shape[0] > target_samples:
|
| 537 |
+
end = _snap_to_beat(target_samples, beat_samples, samples_per_beat, audio.shape[0])
|
| 538 |
+
cut_on_beat = beat_samples is not None and end in beat_samples.tolist()
|
| 539 |
+
audio = audio[:end]
|
| 540 |
+
if not cut_on_beat:
|
| 541 |
+
# Mid-note cut β short fade hides the click. On a clean beat
|
| 542 |
+
# boundary the cut is on a natural transient edge, so the fade
|
| 543 |
+
# would only "duck" the start of the next beat at the loop
|
| 544 |
+
# seam without preventing any audible click.
|
| 545 |
+
_apply_fade(audio, _TAIL_FADE_SEC, sr, fade_in=False)
|
| 546 |
+
# If we came in shorter than target, return the actual audio without
|
| 547 |
+
# zero-padding. A 7.5-bar clip that loops cleanly beats an 8-bar clip
|
| 548 |
+
# with 0.5 bars of silence at the loop seam.
|
| 549 |
|
| 550 |
sf.write(str(input_path), audio, sr, subtype="PCM_16")
|
| 551 |
return input_path
|
| 552 |
|
| 553 |
|
| 554 |
+
# --- Phase 7 loop alignment -----------------------------------------------
|
| 555 |
+
|
| 556 |
+
# DEPRECATED: superseded by app/core/loop_quantizer (see task_1.md / AUDIT.md Β§9a).
|
| 557 |
+
# Public entry; emits DeprecationWarning at runtime. Scheduled for removal once
|
| 558 |
+
# the new module passes acceptance.
|
| 559 |
+
def align_for_loop(
|
| 560 |
+
audio: np.ndarray,
|
| 561 |
+
sr: int,
|
| 562 |
+
*,
|
| 563 |
+
target_samples: int,
|
| 564 |
+
target_bpm: float,
|
| 565 |
+
) -> np.ndarray:
|
| 566 |
+
"""Align a baseline clip for seamless looping at an exact length.
|
| 567 |
+
|
| 568 |
+
DEPRECATED β superseded by ``app/core/loop_quantizer`` (see ``task_1.md`` /
|
| 569 |
+
``AUDIT.md`` Β§9a). Scheduled for removal once the new module ships.
|
| 570 |
+
|
| 571 |
+
Pipeline (in-memory, no disk I/O):
|
| 572 |
+
1. Detect tempo + beat grid via librosa.
|
| 573 |
+
2. Time-stretch (uniformly) if detected BPM drifts past the bars-mode
|
| 574 |
+
deadband AND the required rate is in the safe range. Drift
|
| 575 |
+
beyond the safe range is left alone (caller can re-roll).
|
| 576 |
+
3. Head-trim to the first detected beat (or first onset as fallback),
|
| 577 |
+
within the first ~1.5 s. This is the phase-alignment step β it
|
| 578 |
+
puts the loop's "downbeat" at sample 0 so multiple channels'
|
| 579 |
+
beats coincide when launched on a bar boundary.
|
| 580 |
+
4. Crop or zero-pad to exactly `target_samples`. No end-snap: the
|
| 581 |
+
loop iteration length is sample-exact so it stays phase-locked
|
| 582 |
+
to the master clock across iterations.
|
| 583 |
+
|
| 584 |
+
Returns a `np.ndarray` of shape `(target_samples, channels)` (or 1-D
|
| 585 |
+
if input was 1-D). The caller is expected to wrap-and-inpaint the
|
| 586 |
+
output to smooth the seam β `align_for_loop` does no fade.
|
| 587 |
+
"""
|
| 588 |
+
warnings.warn(
|
| 589 |
+
"align_for_loop is deprecated and will be removed once "
|
| 590 |
+
"app/core/loop_quantizer ships (see task_1.md / AUDIT.md Β§9a).",
|
| 591 |
+
DeprecationWarning,
|
| 592 |
+
stacklevel=2,
|
| 593 |
+
)
|
| 594 |
+
if audio.ndim == 1:
|
| 595 |
+
audio = audio[:, np.newaxis]
|
| 596 |
+
squeeze_out = True
|
| 597 |
+
else:
|
| 598 |
+
squeeze_out = False
|
| 599 |
+
audio = np.ascontiguousarray(audio, dtype=np.float32)
|
| 600 |
+
|
| 601 |
+
if beatsync_v2_enabled():
|
| 602 |
+
out = _stage_a_v2(
|
| 603 |
+
audio, sr,
|
| 604 |
+
target_samples=target_samples, target_bpm=float(target_bpm),
|
| 605 |
+
deadband=_LOOP_MODE_DEADBAND,
|
| 606 |
+
)
|
| 607 |
+
return out.squeeze(1) if squeeze_out else out
|
| 608 |
+
|
| 609 |
+
mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
|
| 610 |
+
detected_bpm, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
|
| 611 |
+
|
| 612 |
+
# --- 1+2: tempo conform ---------------------------------------------
|
| 613 |
+
if detected_bpm is not None:
|
| 614 |
+
rate, effective_bpm = _best_stretch_rate(
|
| 615 |
+
detected_bpm,
|
| 616 |
+
target_bpm,
|
| 617 |
+
safe_min=_BARS_MODE_STRETCH_MIN,
|
| 618 |
+
safe_max=_BARS_MODE_STRETCH_MAX,
|
| 619 |
+
)
|
| 620 |
+
if rate is not None and abs(rate - 1.0) > _LOOP_MODE_DEADBAND:
|
| 621 |
+
audio = _time_stretch_multichannel(audio, rate)
|
| 622 |
+
mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
|
| 623 |
+
_, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
|
| 624 |
+
interp = (
|
| 625 |
+
f" (interpreted as {effective_bpm:.2f} BPM)"
|
| 626 |
+
if abs(effective_bpm - detected_bpm) > 1e-2 else ""
|
| 627 |
+
)
|
| 628 |
+
logger.info(
|
| 629 |
+
"align_for_loop: detected %.2f BPM%s, stretched by %.4f to "
|
| 630 |
+
"match %.2f target",
|
| 631 |
+
detected_bpm, interp, rate, target_bpm,
|
| 632 |
+
)
|
| 633 |
+
elif rate is not None:
|
| 634 |
+
logger.info(
|
| 635 |
+
"align_for_loop: detected %.2f BPM within %.2f%% of %.2f target; "
|
| 636 |
+
"no stretch",
|
| 637 |
+
detected_bpm, _LOOP_MODE_DEADBAND * 100, target_bpm,
|
| 638 |
+
)
|
| 639 |
+
else:
|
| 640 |
+
logger.info(
|
| 641 |
+
"align_for_loop: detected %.2f BPM has no safe stretch to "
|
| 642 |
+
"%.2f target within [%.2f, %.2f]; leaving tempo as-is",
|
| 643 |
+
detected_bpm, target_bpm,
|
| 644 |
+
_BARS_MODE_STRETCH_MIN, _BARS_MODE_STRETCH_MAX,
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
logger.info("align_for_loop: no usable tempo detected; skipping stretch")
|
| 648 |
+
|
| 649 |
+
# --- 3: head-trim to first beat / onset (phase alignment) -----------
|
| 650 |
+
head_offset = 0
|
| 651 |
+
if beat_samples is not None and len(beat_samples) > 0:
|
| 652 |
+
first_beat = int(beat_samples[0])
|
| 653 |
+
if 0 < first_beat < sr * 1.5:
|
| 654 |
+
head_offset = first_beat
|
| 655 |
+
if head_offset == 0:
|
| 656 |
+
# Onset fallback when beat tracking didn't lock β gives at least
|
| 657 |
+
# a transient-aligned start instead of mid-attack on sample 0.
|
| 658 |
+
head_offset = _detect_first_onset_sample(mono, sr)
|
| 659 |
+
if head_offset >= sr * 1.5:
|
| 660 |
+
head_offset = 0
|
| 661 |
+
if head_offset > 0:
|
| 662 |
+
audio = audio[head_offset:]
|
| 663 |
+
logger.info(
|
| 664 |
+
"align_for_loop: head-trimmed %.1f ms to first beat/onset",
|
| 665 |
+
head_offset / sr * 1000,
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
# --- 4: crop or pad to exact target_samples -------------------------
|
| 669 |
+
if audio.shape[0] > target_samples:
|
| 670 |
+
audio = audio[:target_samples]
|
| 671 |
+
elif audio.shape[0] < target_samples:
|
| 672 |
+
pad = target_samples - audio.shape[0]
|
| 673 |
+
audio = np.concatenate(
|
| 674 |
+
[audio, np.zeros((pad, audio.shape[1]), dtype=audio.dtype)],
|
| 675 |
+
axis=0,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
return audio.squeeze(1) if squeeze_out else audio
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
# --- helpers ---------------------------------------------------------------
|
| 682 |
+
|
| 683 |
+
# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
|
| 684 |
+
def _trailing_audio_end(audio: np.ndarray, sr: int) -> int:
|
| 685 |
+
"""Return the sample index just past the last audible content.
|
| 686 |
+
|
| 687 |
+
Walks backwards in non-overlapping windows of `_SILENCE_WINDOW_SEC` and
|
| 688 |
+
finds the last window whose RMS exceeds `_SILENCE_THRESHOLD_DB`. Returns
|
| 689 |
+
the end of that window plus a small natural-decay tail.
|
| 690 |
+
|
| 691 |
+
Falls back to the original audio length when the entire clip is below
|
| 692 |
+
threshold (silent input) or shorter than one window.
|
| 693 |
+
"""
|
| 694 |
+
n = audio.shape[0]
|
| 695 |
+
window = int(sr * _SILENCE_WINDOW_SEC)
|
| 696 |
+
if n <= window:
|
| 697 |
+
return n
|
| 698 |
+
mono = audio.mean(axis=1) if audio.ndim > 1 else audio
|
| 699 |
+
# Squared amplitudes β comparing to thresholdΒ² is equivalent to RMS vs
|
| 700 |
+
# threshold but avoids a sqrt per window.
|
| 701 |
+
sq = (mono ** 2)
|
| 702 |
+
thresh_sq = (10.0 ** (_SILENCE_THRESHOLD_DB / 20.0)) ** 2
|
| 703 |
+
tail_keep = int(sr * _SILENCE_TAIL_KEEP_SEC)
|
| 704 |
+
end = n
|
| 705 |
+
while end > 0:
|
| 706 |
+
start = max(0, end - window)
|
| 707 |
+
if float(sq[start:end].mean()) > thresh_sq:
|
| 708 |
+
return min(n, end + tail_keep)
|
| 709 |
+
end = start
|
| 710 |
+
# Whole clip is below threshold β leave as-is rather than truncate to 0.
|
| 711 |
+
return n
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
|
| 715 |
+
def _snap_to_beat(
|
| 716 |
+
target_samples: int,
|
| 717 |
+
beat_samples: Optional[np.ndarray],
|
| 718 |
+
samples_per_beat: float,
|
| 719 |
+
audio_len: int,
|
| 720 |
+
) -> int:
|
| 721 |
+
"""Return the cut point: the nearest detected beat within Β±Β½ beat of
|
| 722 |
+
target_samples, falling back to target_samples itself if no beat is in
|
| 723 |
+
range. Never overshoots audio length."""
|
| 724 |
+
fallback = min(target_samples, audio_len)
|
| 725 |
+
if beat_samples is None or len(beat_samples) == 0:
|
| 726 |
+
return fallback
|
| 727 |
+
tol = samples_per_beat * 0.5
|
| 728 |
+
valid = beat_samples[(beat_samples > 0) & (beat_samples <= audio_len)]
|
| 729 |
+
if len(valid) == 0:
|
| 730 |
+
return fallback
|
| 731 |
+
diffs = np.abs(valid - target_samples)
|
| 732 |
+
idx = int(np.argmin(diffs))
|
| 733 |
+
if diffs[idx] <= tol:
|
| 734 |
+
return int(valid[idx])
|
| 735 |
+
return fallback
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
# DEPRECATED: superseded by loop_quantizer (AUDIT.md Β§9b); may be ported if reused.
|
| 739 |
+
def _apply_fade(audio: np.ndarray, duration_sec: float, sr: int, *, fade_in: bool) -> None:
|
| 740 |
+
"""In-place equal-power fade on the head (fade_in=True) or tail."""
|
| 741 |
+
n = min(int(duration_sec * sr), audio.shape[0])
|
| 742 |
+
if n <= 1:
|
| 743 |
+
return
|
| 744 |
+
ramp = _equal_power_ramp(n, fade_in=fade_in, dtype=audio.dtype)
|
| 745 |
+
if audio.ndim > 1:
|
| 746 |
+
ramp = ramp[:, np.newaxis]
|
| 747 |
+
if fade_in:
|
| 748 |
+
audio[:n] *= ramp
|
| 749 |
+
else:
|
| 750 |
+
audio[-n:] *= ramp
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# DEPRECATED: superseded by loop_quantizer (AUDIT.md Β§9b); may be ported if reused.
|
| 754 |
+
def _equal_power_ramp(n: int, *, fade_in: bool, dtype) -> np.ndarray:
|
| 755 |
+
"""Cosine-shaped equal-power fade. Energy at the midpoint is preserved
|
| 756 |
+
when summing fade-out + fade-in of complementary segments, avoiding the
|
| 757 |
+
perceptible 'duck' that linear ramps produce at loop seams."""
|
| 758 |
+
t = np.linspace(0.0, np.pi / 2.0, n).astype(dtype, copy=False)
|
| 759 |
+
return np.sin(t) if fade_in else np.cos(t)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
|
| 763 |
def _best_stretch_rate(
|
| 764 |
detected_bpm: float,
|
| 765 |
target_bpm: float,
|
| 766 |
+
*,
|
| 767 |
+
safe_min: float = _STRETCH_SAFE_MIN,
|
| 768 |
+
safe_max: float = _STRETCH_SAFE_MAX,
|
| 769 |
) -> Tuple[Optional[float], float]:
|
| 770 |
"""Pick the time-stretch rate that maps detected β target, considering
|
| 771 |
half-time and double-time interpretations of the detected tempo. Returns
|
|
|
|
| 774 |
nothing safe is available.
|
| 775 |
|
| 776 |
Order of preference:
|
| 777 |
+
1. Detected as-is, if it lands inside [safe_min, safe_max].
|
| 778 |
2. Octave-corrected (detected Γ 0.5 or Γ 2.0), only when the as-is
|
| 779 |
interpretation is out of range. This is the librosa half-/double-
|
| 780 |
time error recovery path.
|
|
|
|
|
|
|
|
|
|
| 781 |
"""
|
|
|
|
| 782 |
rate_asis = target_bpm / detected_bpm
|
| 783 |
+
if safe_min <= rate_asis <= safe_max:
|
| 784 |
return rate_asis, detected_bpm
|
| 785 |
|
|
|
|
|
|
|
|
|
|
| 786 |
candidates = []
|
| 787 |
for octave_factor in (0.5, 2.0):
|
| 788 |
interpreted = detected_bpm * octave_factor
|
| 789 |
rate = target_bpm / interpreted
|
| 790 |
+
if safe_min <= rate <= safe_max:
|
| 791 |
candidates.append((abs(rate - 1.0), rate, interpreted))
|
| 792 |
if not candidates:
|
| 793 |
return None, detected_bpm
|
|
|
|
| 796 |
return best_rate, best_interp
|
| 797 |
|
| 798 |
|
| 799 |
+
# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
|
| 800 |
def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
|
| 801 |
"""Return the sample index of the first detected onset, or 0 if none found."""
|
| 802 |
try:
|
|
|
|
| 814 |
return first
|
| 815 |
|
| 816 |
|
| 817 |
+
# DEPRECATED: superseded by loop_quantizer detector (AUDIT.md Β§9c); port or replace.
|
| 818 |
+
def _detect_grid(
|
| 819 |
mono: np.ndarray,
|
| 820 |
sr: int,
|
| 821 |
start_bpm: Optional[float] = None,
|
| 822 |
+
) -> Tuple[Optional[float], Optional[np.ndarray]]:
|
| 823 |
+
"""Run librosa beat tracking with the target tempo as a prior. Returns
|
| 824 |
+
(bpm, beat_samples_array). Passing start_bpm reduces (but doesn't
|
| 825 |
+
eliminate) half-time / double-time errors; the octave-correction in
|
| 826 |
+
_best_stretch_rate handles whatever librosa still gets wrong."""
|
| 827 |
try:
|
| 828 |
kwargs = {"y": mono, "sr": sr, "units": "samples"}
|
| 829 |
if start_bpm is not None and start_bpm > 0:
|
|
|
|
| 837 |
bpm = float(np.atleast_1d(tempo).flatten()[0])
|
| 838 |
if not (40.0 <= bpm <= 240.0):
|
| 839 |
return None, None
|
| 840 |
+
return bpm, np.asarray(beats, dtype=np.int64)
|
| 841 |
|
| 842 |
|
| 843 |
+
# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
|
| 844 |
def _time_stretch_multichannel(audio: np.ndarray, rate: float) -> np.ndarray:
|
| 845 |
"""Phase-vocoder time stretch, applied per channel and re-stacked."""
|
| 846 |
stretched = librosa.effects.time_stretch(audio.T, rate=rate)
|
app/core/model_manager.py
CHANGED
|
@@ -1,478 +1,669 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import json
|
|
|
|
| 3 |
import shutil
|
| 4 |
-
|
| 5 |
-
|
|
|
|
| 6 |
from datetime import datetime
|
| 7 |
-
import
|
| 8 |
-
from
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
| 12 |
class ModelManager:
|
|
|
|
| 13 |
|
| 14 |
-
def __init__(self, config):
|
| 15 |
self.config = config
|
| 16 |
-
self.models_dir = config.get_path("models_pretrained")
|
| 17 |
self.models_dir.mkdir(exist_ok=True, parents=True)
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
'description': 'Fast generation, good quality, lower memory usage',
|
| 39 |
-
'best_for': 'Beginners, quick experiments, limited GPU',
|
| 40 |
-
'license': 'Stability AI License',
|
| 41 |
-
'checksum': 'sha256:abc123...'
|
| 42 |
-
},
|
| 43 |
-
'stable-audio-open-1.0': {
|
| 44 |
-
'name': 'Stable Audio Open 1.0',
|
| 45 |
-
'repo': models_repo if use_custom_repo else models_repo_large,
|
| 46 |
-
'files': [large_file],
|
| 47 |
-
'size': '8.2 GB',
|
| 48 |
-
'description': 'Highest quality, more detailed audio',
|
| 49 |
-
'best_for': 'Professional use, high-end GPUs',
|
| 50 |
-
'license': 'Stability AI License',
|
| 51 |
-
'checksum': 'sha256:def456...'
|
| 52 |
-
}
|
| 53 |
}
|
| 54 |
|
| 55 |
-
self.
|
| 56 |
-
self.
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
if not file_path.exists() or not file_path.is_file():
|
| 97 |
-
return "0 B"
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
return f"{total_size:.1f} TB"
|
| 117 |
|
| 118 |
-
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
info['terms_accepted'] = self.is_terms_accepted(model_id)
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
| 129 |
|
| 130 |
def is_model_downloaded(self, model_id: str) -> bool:
|
| 131 |
-
|
| 132 |
-
if model_id == 'stable-audio-open-small':
|
| 133 |
-
model_file = self.models_dir / 'stable-audio-open-small-model.safetensors'
|
| 134 |
-
return model_file.exists() and model_file.is_file()
|
| 135 |
-
elif model_id == 'stable-audio-open-1.0':
|
| 136 |
-
model_file = self.models_dir / 'stable-audio-open-model.safetensors'
|
| 137 |
-
return model_file.exists() and model_file.is_file()
|
| 138 |
-
else:
|
| 139 |
-
model_path = self.models_dir / model_id
|
| 140 |
-
if model_path.exists() and model_path.is_dir():
|
| 141 |
-
return any(model_path.iterdir())
|
| 142 |
-
pattern = f"*{model_id}*.safetensors"
|
| 143 |
-
matching_files = list(self.models_dir.glob(pattern))
|
| 144 |
-
return len(matching_files) > 0
|
| 145 |
-
|
| 146 |
-
def is_terms_accepted(self, model_id: str) -> bool:
|
| 147 |
-
|
| 148 |
-
if not self.terms_file.exists():
|
| 149 |
return False
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
return False
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
return False
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
try:
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
return False
|
| 185 |
-
|
| 186 |
-
def download_model(self, model_id: str, progress_callback: Optional[Callable] = None) -> bool:
|
| 187 |
-
|
| 188 |
-
if model_id not in self.available_models:
|
| 189 |
return False
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
try:
|
| 211 |
-
|
| 212 |
-
print(f"Authenticated as: {user}")
|
| 213 |
if progress_callback:
|
| 214 |
-
progress_callback(
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
"
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
if progress_callback:
|
| 228 |
-
progress_callback(
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
if self.callback:
|
| 259 |
-
self.callback(
|
| 260 |
-
percent,
|
| 261 |
-
f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB"
|
| 262 |
-
)
|
| 263 |
-
return inner
|
| 264 |
-
|
| 265 |
-
downloaded_files = []
|
| 266 |
-
total_files = len(model_info['files'])
|
| 267 |
-
|
| 268 |
-
for i, file_pattern in enumerate(model_info['files']):
|
| 269 |
if progress_callback:
|
| 270 |
progress_callback(
|
| 271 |
-
|
| 272 |
-
f"
|
| 273 |
)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
final_filename = f"{model_id}-model.safetensors"
|
| 283 |
-
else:
|
| 284 |
-
final_filename = f"{model_id}-{file_pattern}"
|
| 285 |
-
|
| 286 |
-
tqdm_callback = TqdmToCallback(progress_callback, i, total_files)
|
| 287 |
-
|
| 288 |
-
# hf_hub_download drives its own tqdm β monkey-patch its init/update so we
|
| 289 |
-
# forward byte progress to progress_callback without a second progress bar.
|
| 290 |
-
original_tqdm_init = tqdm.__init__
|
| 291 |
-
|
| 292 |
-
def patched_tqdm_init(self, *args, **kwargs):
|
| 293 |
-
original_tqdm_init(self, *args, **kwargs)
|
| 294 |
-
original_update = self.update
|
| 295 |
-
def new_update(n=1):
|
| 296 |
-
result = original_update(n)
|
| 297 |
-
if progress_callback and self.total:
|
| 298 |
-
file_progress = (self.n / self.total)
|
| 299 |
-
overall_progress = (i + file_progress) / total_files
|
| 300 |
-
percent = 20 + int(overall_progress * 70)
|
| 301 |
-
downloaded_mb = self.n / (1024 * 1024)
|
| 302 |
-
total_mb = self.total / (1024 * 1024)
|
| 303 |
-
progress_callback(
|
| 304 |
-
percent,
|
| 305 |
-
f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB"
|
| 306 |
-
)
|
| 307 |
-
return result
|
| 308 |
-
self.update = new_update
|
| 309 |
-
|
| 310 |
-
tqdm.__init__ = patched_tqdm_init
|
| 311 |
-
|
| 312 |
-
try:
|
| 313 |
-
downloaded_file = hf_hub_download(
|
| 314 |
-
repo_id=model_info['repo'],
|
| 315 |
-
filename=file_pattern,
|
| 316 |
-
resume_download=True
|
| 317 |
-
)
|
| 318 |
-
finally:
|
| 319 |
-
tqdm.__init__ = original_tqdm_init
|
| 320 |
-
|
| 321 |
-
downloaded_path = Path(downloaded_file)
|
| 322 |
-
final_path = target_dir / final_filename
|
| 323 |
-
|
| 324 |
-
final_path.parent.mkdir(parents=True, exist_ok=True)
|
| 325 |
-
|
| 326 |
-
shutil.copy2(str(downloaded_path), str(final_path))
|
| 327 |
-
print(f"Saved as {final_filename}")
|
| 328 |
-
|
| 329 |
-
downloaded_files.append(str(final_path))
|
| 330 |
-
|
| 331 |
-
if progress_callback:
|
| 332 |
-
progress_callback(
|
| 333 |
-
20 + int(((i + 1) / total_files) * 70),
|
| 334 |
-
f"Completed {file_pattern}"
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
except Exception as file_error:
|
| 338 |
-
print(
|
| 339 |
-
f"Failed to download {file_pattern}: {file_error}")
|
| 340 |
-
if progress_callback:
|
| 341 |
-
progress_callback(
|
| 342 |
-
0, f"Failed to download {file_pattern}")
|
| 343 |
-
continue
|
| 344 |
-
|
| 345 |
-
print(f"Downloaded {len(downloaded_files)} files")
|
| 346 |
-
|
| 347 |
-
if progress_callback:
|
| 348 |
-
progress_callback(
|
| 349 |
-
95, "Download completed, verifying files...")
|
| 350 |
-
|
| 351 |
-
except Exception as download_error:
|
| 352 |
-
print(f"Error during download: {download_error}")
|
| 353 |
-
if progress_callback:
|
| 354 |
-
progress_callback(
|
| 355 |
-
0, f"Download failed: {str(download_error)}")
|
| 356 |
-
return False
|
| 357 |
-
|
| 358 |
if progress_callback:
|
| 359 |
-
progress_callback(
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
return True
|
| 378 |
-
else:
|
| 379 |
-
if progress_callback:
|
| 380 |
-
progress_callback(0, "Download verification failed")
|
| 381 |
-
print(f"Expected files not found: {expected_files}")
|
| 382 |
-
return False
|
| 383 |
-
|
| 384 |
-
except Exception as e:
|
| 385 |
-
print(f"Error downloading {model_info['name']}: {e}")
|
| 386 |
-
if progress_callback:
|
| 387 |
-
progress_callback(0, f"Error: {str(e)}")
|
| 388 |
-
|
| 389 |
-
if "403" in str(e) and "gated repositories" in str(e).lower():
|
| 390 |
-
print("Token permission issue detected!")
|
| 391 |
-
print(
|
| 392 |
-
"Your Hugging Face token needs 'Read access to public gated repositories'")
|
| 393 |
-
print("Please:")
|
| 394 |
-
print("1. Go to https://huggingface.co/settings/tokens")
|
| 395 |
-
print("2. Edit your token or create a new one")
|
| 396 |
-
print("3. Enable 'Read access to public gated repositories'")
|
| 397 |
-
print("4. Try the download again")
|
| 398 |
-
elif "401" in str(e) or "restricted" in str(e).lower():
|
| 399 |
-
print("This model requires Hugging Face authentication.")
|
| 400 |
-
print("Please visit the model page and accept terms first:")
|
| 401 |
-
print(f"https://huggingface.co/{model_info['repo']}")
|
| 402 |
-
return False
|
| 403 |
|
| 404 |
def delete_model(self, model_id: str) -> bool:
|
| 405 |
-
|
| 406 |
-
deleted_something = False
|
| 407 |
-
|
| 408 |
-
if model_id == 'stable-audio-open-small':
|
| 409 |
-
model_file = self.models_dir / 'stable-audio-open-small-model.safetensors'
|
| 410 |
-
config_file = self.models_dir / 'stable-audio-open-small-config.json'
|
| 411 |
-
elif model_id == 'stable-audio-open-1.0':
|
| 412 |
-
model_file = self.models_dir / 'stable-audio-open-model.safetensors'
|
| 413 |
-
config_file = self.models_dir / 'stable-audio-open-1.0-config.json'
|
| 414 |
-
else:
|
| 415 |
-
model_file = self.models_dir / f"{model_id}-model.safetensors"
|
| 416 |
-
config_file = self.models_dir / f"{model_id}-config.json"
|
| 417 |
-
|
| 418 |
-
for file_path in [model_file, config_file]:
|
| 419 |
-
if file_path.exists():
|
| 420 |
-
try:
|
| 421 |
-
file_path.unlink()
|
| 422 |
-
print(f"Deleted {file_path.name}")
|
| 423 |
-
deleted_something = True
|
| 424 |
-
except Exception as e:
|
| 425 |
-
print(f"Error deleting {file_path.name}: {e}")
|
| 426 |
-
|
| 427 |
-
model_path = self.models_dir / model_id
|
| 428 |
-
if model_path.exists() and model_path.is_dir():
|
| 429 |
-
try:
|
| 430 |
-
shutil.rmtree(model_path)
|
| 431 |
-
print(f"Deleted {model_id} directory")
|
| 432 |
-
deleted_something = True
|
| 433 |
-
except Exception as e:
|
| 434 |
-
print(f"Error deleting {model_id} directory: {e}")
|
| 435 |
-
|
| 436 |
-
if deleted_something:
|
| 437 |
-
print(f"Deleted {model_id}")
|
| 438 |
-
return True
|
| 439 |
-
else:
|
| 440 |
-
print(f"No files found for {model_id}")
|
| 441 |
return False
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
return {
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
'models_dir': str(self.models_dir)
|
| 470 |
}
|
| 471 |
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Checkpoint Manager β SA3 catalog, HF downloads, license + auth.
|
| 2 |
+
|
| 3 |
+
Phase 2a in SA3_INTEGRATION_PLAN.md. Replaces the SA2-era SAO catalog.
|
| 4 |
+
Eight downloadable artifacts (3 post-trained + 3 base + 2 autoencoders);
|
| 5 |
+
each is fetched via `huggingface_hub.snapshot_download` with cooperative
|
| 6 |
+
cancel + progress reporting.
|
| 7 |
+
|
| 8 |
+
The Phase 2b frontend (CheckpointManagerWindow.js) consumes the JSON shapes
|
| 9 |
+
returned by the `/api/checkpoints/*` endpoints in `app/backend/app.py`.
|
| 10 |
+
"""
|
| 11 |
import json
|
| 12 |
+
import os
|
| 13 |
import shutil
|
| 14 |
+
import threading
|
| 15 |
+
import uuid
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
from datetime import datetime
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 20 |
+
|
| 21 |
+
from huggingface_hub import get_token, snapshot_download, whoami
|
| 22 |
+
from huggingface_hub.errors import GatedRepoError, RepositoryNotFoundError
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# --- Catalog ------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
# Approximate sizes; the frontend can refine these by hitting
|
| 28 |
+
# `huggingface_hub.HfApi().model_info(repo_id)` lazily. Numbers come from the
|
| 29 |
+
# HF model cards (paragraph parameter counts Γ bytes/param, rounded).
|
| 30 |
+
_SA3_CATALOG: Dict[str, Dict[str, Any]] = {
|
| 31 |
+
# --- Generation models (post-trained) ----------------------------------
|
| 32 |
+
"sa3-small-music": {
|
| 33 |
+
"user_visible": True,
|
| 34 |
+
"kind": "post-trained",
|
| 35 |
+
"name": "Small - Music",
|
| 36 |
+
"sa3_name": "small-music",
|
| 37 |
+
"repo": "stabilityai/stable-audio-3-small-music",
|
| 38 |
+
"size_bytes": 2_270_000_000,
|
| 39 |
+
"hardware": "cpu", # CPU / MPS / CUDA all work
|
| 40 |
+
"max_duration_sec": 120,
|
| 41 |
+
"description": "Fast distilled music generation. Locked to 8 steps, cfg 1.0.",
|
| 42 |
+
},
|
| 43 |
+
"sa3-small-sfx": {
|
| 44 |
+
"user_visible": True,
|
| 45 |
+
"kind": "post-trained",
|
| 46 |
+
"name": "Small - SFX",
|
| 47 |
+
"sa3_name": "small-sfx",
|
| 48 |
+
"repo": "stabilityai/stable-audio-3-small-sfx",
|
| 49 |
+
"size_bytes": 2_270_000_000,
|
| 50 |
+
"hardware": "cpu",
|
| 51 |
+
"max_duration_sec": 120,
|
| 52 |
+
"description": "Fast distilled SFX/foley generation. Locked to 8 steps, cfg 1.0.",
|
| 53 |
+
},
|
| 54 |
+
"sa3-medium": {
|
| 55 |
+
"user_visible": True,
|
| 56 |
+
"kind": "post-trained",
|
| 57 |
+
"name": "Medium",
|
| 58 |
+
"sa3_name": "medium",
|
| 59 |
+
"repo": "stabilityai/stable-audio-3-medium",
|
| 60 |
+
"size_bytes": 9_220_000_000,
|
| 61 |
+
"hardware": "cuda+flash-attn",
|
| 62 |
+
"max_duration_sec": 380,
|
| 63 |
+
"description": "Fast distilled hi-fi generation, up to 380s. Locked to 8 steps, cfg 1.0.",
|
| 64 |
+
},
|
| 65 |
+
# --- Base checkpoints (full artist control) ----------------------------
|
| 66 |
+
# These are the CFG-aware pre-distillation models. Slower (~50 steps,
|
| 67 |
+
# cfg ~7), but the user controls cfg_scale, steps, and the inference
|
| 68 |
+
# trajectory. Also the canonical targets for LoRA training.
|
| 69 |
+
"sa3-small-music-base": {
|
| 70 |
+
"user_visible": True,
|
| 71 |
+
"kind": "base",
|
| 72 |
+
"name": "Small - Music (Base)",
|
| 73 |
+
"sa3_name": "small-music-base",
|
| 74 |
+
"repo": "stabilityai/stable-audio-3-small-music-base",
|
| 75 |
+
"size_bytes": 2_270_000_000,
|
| 76 |
+
"hardware": "cpu",
|
| 77 |
+
"max_duration_sec": 120,
|
| 78 |
+
"description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.",
|
| 79 |
+
},
|
| 80 |
+
"sa3-small-sfx-base": {
|
| 81 |
+
"user_visible": True,
|
| 82 |
+
"kind": "base",
|
| 83 |
+
"name": "Small - SFX (Base)",
|
| 84 |
+
"sa3_name": "small-sfx-base",
|
| 85 |
+
"repo": "stabilityai/stable-audio-3-small-sfx-base",
|
| 86 |
+
"size_bytes": 2_270_000_000,
|
| 87 |
+
"hardware": "cpu",
|
| 88 |
+
"max_duration_sec": 120,
|
| 89 |
+
"description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.",
|
| 90 |
+
},
|
| 91 |
+
"sa3-medium-base": {
|
| 92 |
+
"user_visible": True,
|
| 93 |
+
"kind": "base",
|
| 94 |
+
"name": "Medium (Base)",
|
| 95 |
+
"sa3_name": "medium-base",
|
| 96 |
+
"repo": "stabilityai/stable-audio-3-medium-base",
|
| 97 |
+
"size_bytes": 9_220_000_000,
|
| 98 |
+
"hardware": "cuda+flash-attn",
|
| 99 |
+
"max_duration_sec": 380,
|
| 100 |
+
"description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.",
|
| 101 |
+
},
|
| 102 |
+
# Standalone autoencoders: the AE is bundled INSIDE each DiT repo
|
| 103 |
+
# already (StableAudioModel.from_pretrained loads it from there), so
|
| 104 |
+
# we don't surface SAME-S / SAME-L in the manager. They remain
|
| 105 |
+
# downloadable via /api/checkpoints?include=all for advanced uses
|
| 106 |
+
# (autoencoder-only workflows, pre-encoding datasets for training).
|
| 107 |
+
"sa3-same-s": {
|
| 108 |
+
"user_visible": False,
|
| 109 |
+
"kind": "autoencoder",
|
| 110 |
+
"name": "SAME-S",
|
| 111 |
+
"sa3_name": "same-s",
|
| 112 |
+
"repo": "stabilityai/SAME-S",
|
| 113 |
+
"size_bytes": 530_000_000,
|
| 114 |
+
"hardware": "cpu",
|
| 115 |
+
"description": "Standalone autoencoder (266M). Already bundled with the small-* DiTs.",
|
| 116 |
+
},
|
| 117 |
+
"sa3-same-l": {
|
| 118 |
+
"user_visible": False,
|
| 119 |
+
"kind": "autoencoder",
|
| 120 |
+
"name": "SAME-L",
|
| 121 |
+
"sa3_name": "same-l",
|
| 122 |
+
"repo": "stabilityai/SAME-L",
|
| 123 |
+
"size_bytes": 3_400_000_000,
|
| 124 |
+
"hardware": "cuda",
|
| 125 |
+
"description": "Standalone autoencoder (1.7B). Already bundled with medium.",
|
| 126 |
+
},
|
| 127 |
+
# --- Auto-annotation tools ---------------------------------------------
|
| 128 |
+
# Single-file HF download, lives under <models_pretrained>/clap/.
|
| 129 |
+
# `is_model_downloaded` and `_run_download` special-case kind=="tagger".
|
| 130 |
+
"clap-music": {
|
| 131 |
+
"user_visible": True,
|
| 132 |
+
"kind": "tagger",
|
| 133 |
+
"name": "LAION-CLAP (music)",
|
| 134 |
+
"sa3_name": "clap-music",
|
| 135 |
+
"repo": "lukewys/laion_clap",
|
| 136 |
+
"filename": "music_audioset_epoch_15_esc_90.14.pt",
|
| 137 |
+
# ~2.35 GB .pt + ~1.4 GB of text-encoder snapshots (roberta-base,
|
| 138 |
+
# bert-base-uncased, facebook/bart-base) that laion_clap loads at
|
| 139 |
+
# construction. download_clap_checkpoint pulls all of them.
|
| 140 |
+
"size_bytes": 3_800_000_000,
|
| 141 |
+
"hardware": "cpu",
|
| 142 |
+
"description": (
|
| 143 |
+
"Zero-shot tagger used by the dataset prep's rich-tier annotation. "
|
| 144 |
+
"Scores each clip against your genre / mood / instrument vocabulary."
|
| 145 |
+
),
|
| 146 |
+
},
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
# --- Job state for in-flight downloads ----------------------------------------
|
| 150 |
+
|
| 151 |
+
@dataclass
|
| 152 |
+
class _DownloadJob:
|
| 153 |
+
"""In-memory record of one download attempt."""
|
| 154 |
+
job_id: str
|
| 155 |
+
model_id: str
|
| 156 |
+
status: str = "queued" # queued | running | complete | failed | cancelled
|
| 157 |
+
downloaded_bytes: int = 0
|
| 158 |
+
total_bytes: int = 0
|
| 159 |
+
error: Optional[str] = None
|
| 160 |
+
started_at: Optional[str] = None
|
| 161 |
+
finished_at: Optional[str] = None
|
| 162 |
+
_cancel_flag: threading.Event = field(default_factory=threading.Event)
|
| 163 |
+
_thread: Optional[threading.Thread] = None
|
| 164 |
+
|
| 165 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 166 |
+
return {
|
| 167 |
+
"job_id": self.job_id,
|
| 168 |
+
"model_id": self.model_id,
|
| 169 |
+
"status": self.status,
|
| 170 |
+
"downloaded_bytes": self.downloaded_bytes,
|
| 171 |
+
"total_bytes": self.total_bytes,
|
| 172 |
+
"error": self.error,
|
| 173 |
+
"started_at": self.started_at,
|
| 174 |
+
"finished_at": self.finished_at,
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class _DownloadCancelled(Exception):
|
| 179 |
+
"""Raised inside the tqdm hook when a job's cancel flag fires."""
|
| 180 |
|
| 181 |
|
| 182 |
+
# --- ModelManager -------------------------------------------------------------
|
| 183 |
+
|
| 184 |
class ModelManager:
|
| 185 |
+
"""Owns the SA3 catalog and the on-disk pretrained directory."""
|
| 186 |
|
| 187 |
+
def __init__(self, config: Any) -> None:
|
| 188 |
self.config = config
|
| 189 |
+
self.models_dir: Path = config.get_path("models_pretrained")
|
| 190 |
self.models_dir.mkdir(exist_ok=True, parents=True)
|
| 191 |
|
| 192 |
+
# Project-wide policy: every HF download lands inside
|
| 193 |
+
# <app>/models/pretrained/. SA3 generation + training uses
|
| 194 |
+
# <pretrained>/sa3/hub/; CLAP text deps use <pretrained>/clap/hub/.
|
| 195 |
+
# Both are HF cache layout so snapshot_download / hf_hub_download /
|
| 196 |
+
# from_pretrained resolve there transparently.
|
| 197 |
+
self.hub_dir: Path = self.models_dir / "sa3" / "hub"
|
| 198 |
+
self.hub_dir.mkdir(exist_ok=True, parents=True)
|
| 199 |
+
# Hard-force the resolution vars β never let an external env leak
|
| 200 |
+
# downloads into ~/.cache/huggingface or anywhere else outside the
|
| 201 |
+
# app folder. Covers huggingface_hub (current + legacy name) and
|
| 202 |
+
# transformers (which still consults TRANSFORMERS_CACHE).
|
| 203 |
+
os.environ["HF_HUB_CACHE"] = str(self.hub_dir)
|
| 204 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = str(self.hub_dir)
|
| 205 |
+
os.environ["TRANSFORMERS_CACHE"] = str(self.hub_dir)
|
| 206 |
+
|
| 207 |
+
# available_models is exposed for backwards compat with the existing
|
| 208 |
+
# /api/models/available endpoint. New code should use get_catalog().
|
| 209 |
+
self.available_models: Dict[str, Dict] = {
|
| 210 |
+
mid: dict(meta) for mid, meta in _SA3_CATALOG.items()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
}
|
| 212 |
|
| 213 |
+
self._jobs: Dict[str, _DownloadJob] = {}
|
| 214 |
+
self._jobs_lock = threading.Lock()
|
| 215 |
+
|
| 216 |
+
# --- Catalog --------------------------------------------------------------
|
| 217 |
+
|
| 218 |
+
def get_catalog(self, include_hidden: bool = False) -> List[Dict[str, Any]]:
|
| 219 |
+
"""Checkpoint Manager catalog with per-item state.
|
| 220 |
+
|
| 221 |
+
Default returns only user-visible entries (the three generation
|
| 222 |
+
models). `include_hidden=True` also returns base + standalone-AE
|
| 223 |
+
entries β used by the Phase 5 training subprocess to ensure the
|
| 224 |
+
right base variant is on disk before kicking train_lora.py.
|
| 225 |
+
"""
|
| 226 |
+
return [
|
| 227 |
+
self._catalog_entry(mid)
|
| 228 |
+
for mid, info in _SA3_CATALOG.items()
|
| 229 |
+
if include_hidden or info.get("user_visible")
|
| 230 |
+
]
|
| 231 |
+
|
| 232 |
+
def _catalog_entry(self, model_id: str) -> Dict[str, Any]:
|
| 233 |
+
info = _SA3_CATALOG[model_id]
|
| 234 |
+
downloaded = self.is_model_downloaded(model_id)
|
| 235 |
+
bytes_total = 0
|
| 236 |
+
if downloaded:
|
| 237 |
+
for d in (self._hub_cache_dir_for(model_id), self._legacy_flat_dir_for(model_id)):
|
| 238 |
+
if d.exists():
|
| 239 |
+
bytes_total += self._dir_size(d)
|
| 240 |
+
|
| 241 |
+
# Surface the most recent in-flight job for this model so the
|
| 242 |
+
# frontend can resume the progress bar after the Checkpoint Manager
|
| 243 |
+
# dialog is closed and reopened. The job lives on the backend; only
|
| 244 |
+
# the polling died with the dismissed UI.
|
| 245 |
+
active_job = None
|
| 246 |
+
with self._jobs_lock:
|
| 247 |
+
in_flight = [
|
| 248 |
+
j for j in self._jobs.values()
|
| 249 |
+
if j.model_id == model_id and j.status in ("queued", "running")
|
| 250 |
+
]
|
| 251 |
+
if in_flight:
|
| 252 |
+
in_flight.sort(key=lambda j: j.started_at or "", reverse=True)
|
| 253 |
+
active_job = in_flight[0].to_dict()
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
return {
|
| 256 |
+
"id": model_id,
|
| 257 |
+
"kind": info.get("kind"),
|
| 258 |
+
"name": info["name"],
|
| 259 |
+
"sa3_name": info["sa3_name"],
|
| 260 |
+
"repo": info["repo"],
|
| 261 |
+
"size_bytes": info["size_bytes"],
|
| 262 |
+
"hardware": info["hardware"],
|
| 263 |
+
"max_duration_sec": info.get("max_duration_sec"),
|
| 264 |
+
"description": info["description"],
|
| 265 |
+
"user_visible": info.get("user_visible", False),
|
| 266 |
+
"downloaded": downloaded,
|
| 267 |
+
"downloaded_bytes": bytes_total,
|
| 268 |
+
"active_job": active_job,
|
| 269 |
+
}
|
| 270 |
|
| 271 |
+
def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
|
| 272 |
+
if model_id not in _SA3_CATALOG:
|
| 273 |
+
return None
|
| 274 |
+
return self._catalog_entry(model_id)
|
|
|
|
| 275 |
|
| 276 |
+
# --- Filesystem layout ----------------------------------------------------
|
| 277 |
|
| 278 |
+
def _hub_cache_dir_for(self, model_id: str) -> Path:
|
| 279 |
+
"""HF-cache-shaped directory inside the app folder."""
|
| 280 |
+
info = _SA3_CATALOG.get(model_id)
|
| 281 |
+
if info is None:
|
| 282 |
+
return self.hub_dir / "_unknown"
|
| 283 |
+
safe = "models--" + info["repo"].replace("/", "--")
|
| 284 |
+
return self.hub_dir / safe
|
| 285 |
|
| 286 |
+
def _legacy_flat_dir_for(self, model_id: str) -> Path:
|
| 287 |
+
"""Pre-unification per-model dir. Read-only fallback for migration."""
|
| 288 |
+
return self.models_dir / "sa3" / model_id
|
|
|
|
| 289 |
|
| 290 |
+
def _local_dir_for(self, model_id: str) -> Path:
|
| 291 |
+
"""Public: returns the canonical (HF cache) directory for a model."""
|
| 292 |
+
return self._hub_cache_dir_for(model_id)
|
| 293 |
|
| 294 |
def is_model_downloaded(self, model_id: str) -> bool:
|
| 295 |
+
if model_id not in _SA3_CATALOG:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
return False
|
| 297 |
+
info = _SA3_CATALOG[model_id]
|
| 298 |
+
if info.get("kind") == "tagger":
|
| 299 |
+
# Single-file artifacts live in <models_pretrained>/<group>/<filename>.
|
| 300 |
+
# auto_annotator owns the exact path for CLAP, so we delegate.
|
| 301 |
+
from app.backend.data.auto_annotator import clap_checkpoint_available
|
| 302 |
+
return clap_checkpoint_available(self.models_dir)
|
| 303 |
+
# Canonical: HF cache layout under <app>/models/pretrained/sa3/hub/.
|
| 304 |
+
# Look for the *top-level* model.safetensors only β NOT recursive β
|
| 305 |
+
# because a sibling repo may have only its conditioner subfolder
|
| 306 |
+
# downloaded (e.g. via the eager T5Gemma companion fetch when the
|
| 307 |
+
# user installed the matching *-base), and that doesn't make the
|
| 308 |
+
# post-trained model "downloaded".
|
| 309 |
+
main_present = False
|
| 310 |
+
hub = self._hub_cache_dir_for(model_id)
|
| 311 |
+
if hub.is_dir():
|
| 312 |
+
snaps = hub / "snapshots"
|
| 313 |
+
if snaps.is_dir():
|
| 314 |
+
for sub in snaps.iterdir():
|
| 315 |
+
if any(sub.glob("*.safetensors")):
|
| 316 |
+
main_present = True
|
| 317 |
+
break
|
| 318 |
+
if not main_present:
|
| 319 |
+
# Fallback: legacy flat layout (predates the unification). Counts
|
| 320 |
+
# as downloaded for inference purposes; trainer will re-stage into
|
| 321 |
+
# hub.
|
| 322 |
+
legacy = self._legacy_flat_dir_for(model_id)
|
| 323 |
+
if legacy.is_dir() and any(legacy.glob("*.safetensors")):
|
| 324 |
+
main_present = True
|
| 325 |
+
if not main_present:
|
| 326 |
return False
|
| 327 |
|
| 328 |
+
# Base models need a T5Gemma conditioner that lives in a subfolder
|
| 329 |
+
# of the *post-trained sibling* repo. "Installed" must mean "ready
|
| 330 |
+
# to train / generate" β without the companion the first run blocks
|
| 331 |
+
# for 30s+ on an HF fetch.
|
| 332 |
+
return self._is_companion_present(model_id)
|
| 333 |
+
|
| 334 |
+
def _is_companion_present(self, model_id: str) -> bool:
|
| 335 |
+
from app.core.training.sa3_lora_runner import SA3_T5GEMMA_SIBLINGS
|
| 336 |
+
sibling = SA3_T5GEMMA_SIBLINGS.get(model_id)
|
| 337 |
+
if not sibling:
|
| 338 |
+
return True # nothing to check (post-trained / autoencoder / tagger)
|
| 339 |
+
sib_repo, sib_subfolder = sibling
|
| 340 |
+
safe = "models--" + sib_repo.replace("/", "--")
|
| 341 |
+
sib_hub = self.hub_dir / safe
|
| 342 |
+
snaps = sib_hub / "snapshots"
|
| 343 |
+
if not snaps.is_dir():
|
| 344 |
return False
|
| 345 |
+
for sub in snaps.iterdir():
|
| 346 |
+
if (sub / sib_subfolder).is_dir():
|
| 347 |
+
# Any non-empty file presence is good enough β the eager
|
| 348 |
+
# fetch always pulls the tokenizer + config + safetensors.
|
| 349 |
+
if any((sub / sib_subfolder).iterdir()):
|
| 350 |
+
return True
|
| 351 |
+
return False
|
| 352 |
+
|
| 353 |
+
# --- HF auth --------------------------------------------------------------
|
| 354 |
+
|
| 355 |
+
@staticmethod
|
| 356 |
+
def hf_auth_status() -> Dict[str, Any]:
|
| 357 |
+
token = get_token()
|
| 358 |
+
if not token:
|
| 359 |
+
return {"signed_in": False, "username": None}
|
|
|
|
| 360 |
try:
|
| 361 |
+
user = whoami(token=token)
|
| 362 |
+
return {"signed_in": True, "username": user.get("name") or user.get("fullname")}
|
| 363 |
+
except Exception as err:
|
| 364 |
+
return {"signed_in": False, "username": None, "error": str(err)}
|
| 365 |
+
|
| 366 |
+
# --- Downloads ------------------------------------------------------------
|
| 367 |
+
|
| 368 |
+
def start_download(
|
| 369 |
+
self,
|
| 370 |
+
model_id: str,
|
| 371 |
+
progress_callback: Optional[Callable[[int, str], None]] = None,
|
| 372 |
+
) -> Dict[str, Any]:
|
| 373 |
+
"""Spawn a background download job. Returns the job descriptor."""
|
| 374 |
+
if model_id not in _SA3_CATALOG:
|
| 375 |
+
return {"error": f"Unknown checkpoint: {model_id}"}
|
| 376 |
+
|
| 377 |
+
job = _DownloadJob(
|
| 378 |
+
job_id=str(uuid.uuid4()),
|
| 379 |
+
model_id=model_id,
|
| 380 |
+
total_bytes=_SA3_CATALOG[model_id]["size_bytes"],
|
| 381 |
+
)
|
| 382 |
+
with self._jobs_lock:
|
| 383 |
+
self._jobs[job.job_id] = job
|
| 384 |
+
|
| 385 |
+
thread = threading.Thread(
|
| 386 |
+
target=self._run_download,
|
| 387 |
+
args=(job, progress_callback),
|
| 388 |
+
daemon=True,
|
| 389 |
+
name=f"sa3-download:{model_id}",
|
| 390 |
+
)
|
| 391 |
+
job._thread = thread
|
| 392 |
+
thread.start()
|
| 393 |
+
return job.to_dict()
|
| 394 |
+
|
| 395 |
+
def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
|
| 396 |
+
with self._jobs_lock:
|
| 397 |
+
job = self._jobs.get(job_id)
|
| 398 |
+
return job.to_dict() if job else None
|
| 399 |
+
|
| 400 |
+
def list_jobs(self) -> List[Dict[str, Any]]:
|
| 401 |
+
with self._jobs_lock:
|
| 402 |
+
return [j.to_dict() for j in self._jobs.values()]
|
| 403 |
+
|
| 404 |
+
def cancel_job(self, job_id: str) -> bool:
|
| 405 |
+
with self._jobs_lock:
|
| 406 |
+
job = self._jobs.get(job_id)
|
| 407 |
+
if not job:
|
| 408 |
return False
|
| 409 |
+
if job.status not in ("queued", "running"):
|
|
|
|
|
|
|
|
|
|
| 410 |
return False
|
| 411 |
+
job._cancel_flag.set()
|
| 412 |
+
return True
|
| 413 |
+
|
| 414 |
+
def _run_download(
|
| 415 |
+
self,
|
| 416 |
+
job: _DownloadJob,
|
| 417 |
+
progress_callback: Optional[Callable[[int, str], None]],
|
| 418 |
+
) -> None:
|
| 419 |
+
info = _SA3_CATALOG[job.model_id]
|
| 420 |
+
job.status = "running"
|
| 421 |
+
job.started_at = datetime.now().isoformat()
|
| 422 |
+
|
| 423 |
+
# Tagger kind (e.g. CLAP) is a .pt file plus auxiliary HF snapshots
|
| 424 |
+
# living outside the sa3/hub layout. Multi-phase: 1 hf_hub_download
|
| 425 |
+
# for the audio .pt, then N sequential snapshot_downloads for the
|
| 426 |
+
# text encoders. Each spawns its own tqdm bars, so we use the
|
| 427 |
+
# cumulative hook to accumulate bytes across phases, and a phase_cb
|
| 428 |
+
# to prefix the message with which step the user is on.
|
| 429 |
+
if info.get("kind") == "tagger":
|
|
|
|
| 430 |
try:
|
| 431 |
+
from app.backend.data.auto_annotator import download_clap_checkpoint
|
|
|
|
| 432 |
if progress_callback:
|
| 433 |
+
progress_callback(0, f"Downloading {info['name']}β¦")
|
| 434 |
+
# Pin total to the catalog estimate so the % stays anchored
|
| 435 |
+
# even before tqdm reports any file's size.
|
| 436 |
+
job.total_bytes = info["size_bytes"]
|
| 437 |
+
current_phase = {"label": ""}
|
| 438 |
+
|
| 439 |
+
def phase_cb(idx: int, total: int, label: str) -> None:
|
| 440 |
+
current_phase["label"] = f"[{idx}/{total}] {label}"
|
| 441 |
+
if progress_callback:
|
| 442 |
+
pct = (int(job.downloaded_bytes / job.total_bytes * 100)
|
| 443 |
+
if job.total_bytes else 0)
|
| 444 |
+
progress_callback(pct, current_phase["label"])
|
| 445 |
+
|
| 446 |
+
with _cumulative_tqdm_hook(job, progress_callback, current_phase):
|
| 447 |
+
download_clap_checkpoint(self.models_dir, phase_cb=phase_cb)
|
| 448 |
+
job.downloaded_bytes = job.total_bytes
|
| 449 |
+
job.status = "complete"
|
| 450 |
+
job.finished_at = datetime.now().isoformat()
|
| 451 |
if progress_callback:
|
| 452 |
+
progress_callback(100, f"Downloaded {info['name']}")
|
| 453 |
+
except _DownloadCancelled:
|
| 454 |
+
job.status = "cancelled"
|
| 455 |
+
job.error = "Cancelled by user"
|
| 456 |
+
job.finished_at = datetime.now().isoformat()
|
| 457 |
+
except Exception as err:
|
| 458 |
+
job.status = "failed"
|
| 459 |
+
job.error = f"{type(err).__name__}: {err}"
|
| 460 |
+
job.finished_at = datetime.now().isoformat()
|
| 461 |
+
return
|
| 462 |
+
|
| 463 |
+
cache_dir = self._hub_cache_dir_for(job.model_id).parent # = self.hub_dir
|
| 464 |
+
cache_dir.mkdir(exist_ok=True, parents=True)
|
| 465 |
+
target = self._hub_cache_dir_for(job.model_id)
|
| 466 |
+
|
| 467 |
+
token = get_token()
|
| 468 |
|
| 469 |
+
try:
|
| 470 |
+
with _tqdm_progress_hook(job, progress_callback):
|
| 471 |
+
# Write into hub/ in HF cache layout. snapshot_download in
|
| 472 |
+
# hf-hub 1.x populates `<cache_dir>/models--<org>--<name>/`
|
| 473 |
+
# with the blobs/refs/snapshots structure that
|
| 474 |
+
# hf_hub_download() and StableAudioModel.from_pretrained()
|
| 475 |
+
# both consume.
|
| 476 |
+
snapshot_download(
|
| 477 |
+
repo_id=info["repo"],
|
| 478 |
+
cache_dir=str(cache_dir),
|
| 479 |
+
token=token,
|
| 480 |
+
allow_patterns=[
|
| 481 |
+
"*.safetensors", "*.json", "*.txt", "*.model",
|
| 482 |
+
"tokenizer*", "*.tiktoken",
|
| 483 |
+
],
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# Companion fetch: base models reference their T5Gemma
|
| 487 |
+
# conditioner in a subfolder of the *post-trained sibling*
|
| 488 |
+
# repo. Without it the training subprocess crashes at
|
| 489 |
+
# AutoTokenizer.from_pretrained, and inference can't build
|
| 490 |
+
# the conditioner either. Pull it eagerly so "Installed"
|
| 491 |
+
# actually means "ready to use".
|
| 492 |
+
from app.core.training.sa3_lora_runner import SA3_T5GEMMA_SIBLINGS
|
| 493 |
+
sibling = SA3_T5GEMMA_SIBLINGS.get(job.model_id)
|
| 494 |
+
if sibling:
|
| 495 |
+
sib_repo, sib_subfolder = sibling
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
if progress_callback:
|
| 497 |
progress_callback(
|
| 498 |
+
min(99, int(job.downloaded_bytes / max(1, job.total_bytes) * 100)),
|
| 499 |
+
f"Fetching T5Gemma conditioner from {sib_repo}β¦",
|
| 500 |
)
|
| 501 |
+
snapshot_download(
|
| 502 |
+
repo_id=sib_repo,
|
| 503 |
+
cache_dir=str(cache_dir),
|
| 504 |
+
token=token,
|
| 505 |
+
allow_patterns=[f"{sib_subfolder}/*"],
|
| 506 |
+
)
|
| 507 |
+
job.status = "complete"
|
| 508 |
+
job.downloaded_bytes = self._dir_size(target)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
if progress_callback:
|
| 510 |
+
progress_callback(100, f"Downloaded {info['name']}")
|
| 511 |
+
except _DownloadCancelled:
|
| 512 |
+
job.status = "cancelled"
|
| 513 |
+
job.error = "Cancelled by user"
|
| 514 |
+
shutil.rmtree(target, ignore_errors=True)
|
| 515 |
+
except GatedRepoError as err:
|
| 516 |
+
job.status = "failed"
|
| 517 |
+
job.error = f"hf_auth_required: {err}"
|
| 518 |
+
except RepositoryNotFoundError as err:
|
| 519 |
+
job.status = "failed"
|
| 520 |
+
job.error = f"Repository not found: {err}"
|
| 521 |
+
except Exception as err:
|
| 522 |
+
job.status = "failed"
|
| 523 |
+
job.error = str(err)
|
| 524 |
+
finally:
|
| 525 |
+
job.finished_at = datetime.now().isoformat()
|
| 526 |
+
|
| 527 |
+
# --- Delete ---------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
|
| 529 |
def delete_model(self, model_id: str) -> bool:
|
| 530 |
+
if model_id not in _SA3_CATALOG:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
return False
|
| 532 |
+
# Remove both the canonical hub copy and the legacy flat copy if
|
| 533 |
+
# they exist. Either being present is enough to consider the
|
| 534 |
+
# model "downloaded", so both must be cleaned for the row to
|
| 535 |
+
# flip back to "Get".
|
| 536 |
+
hub = self._hub_cache_dir_for(model_id)
|
| 537 |
+
legacy = self._legacy_flat_dir_for(model_id)
|
| 538 |
+
any_existed = hub.exists() or legacy.exists()
|
| 539 |
+
if hub.exists():
|
| 540 |
+
shutil.rmtree(hub, ignore_errors=True)
|
| 541 |
+
if legacy.exists():
|
| 542 |
+
shutil.rmtree(legacy, ignore_errors=True)
|
| 543 |
+
return any_existed and not (hub.exists() or legacy.exists())
|
| 544 |
+
|
| 545 |
+
# --- Storage --------------------------------------------------------------
|
| 546 |
+
|
| 547 |
+
def get_storage_info(self) -> Dict[str, Any]:
|
| 548 |
+
per_model: List[Dict[str, Any]] = []
|
| 549 |
+
total_used = 0
|
| 550 |
+
for mid in _SA3_CATALOG:
|
| 551 |
+
bytes_ = 0
|
| 552 |
+
for d in (self._hub_cache_dir_for(mid), self._legacy_flat_dir_for(mid)):
|
| 553 |
+
if d.exists():
|
| 554 |
+
bytes_ += self._dir_size(d)
|
| 555 |
+
per_model.append({
|
| 556 |
+
"id": mid,
|
| 557 |
+
"downloaded": self.is_model_downloaded(mid),
|
| 558 |
+
"bytes": bytes_,
|
| 559 |
+
})
|
| 560 |
+
total_used += bytes_
|
| 561 |
return {
|
| 562 |
+
"total_used_bytes": total_used,
|
| 563 |
+
"total_free_bytes": shutil.disk_usage(self.models_dir).free,
|
| 564 |
+
"per_model": per_model,
|
|
|
|
| 565 |
}
|
| 566 |
|
| 567 |
+
# --- Helpers --------------------------------------------------------------
|
| 568 |
+
|
| 569 |
+
@staticmethod
|
| 570 |
+
def _dir_size(path: Path) -> int:
|
| 571 |
+
if not path.exists():
|
| 572 |
+
return 0
|
| 573 |
+
return sum(p.stat().st_size for p in path.rglob("*") if p.is_file())
|
| 574 |
+
|
| 575 |
+
# --- tqdm hook ----------------------------------------------------------------
|
| 576 |
+
|
| 577 |
+
import contextlib
|
| 578 |
+
|
| 579 |
+
@contextlib.contextmanager
|
| 580 |
+
def _tqdm_progress_hook(
|
| 581 |
+
job: _DownloadJob,
|
| 582 |
+
progress_callback: Optional[Callable[[int, str], None]],
|
| 583 |
+
):
|
| 584 |
+
"""Monkey-patch tqdm so snapshot_download updates flow into the job state.
|
| 585 |
+
|
| 586 |
+
`snapshot_download` doesn't expose a progress callback. tqdm is its
|
| 587 |
+
internal progress bar β we wrap `update` to update job state and raise
|
| 588 |
+
`_DownloadCancelled` when the job's cancel flag fires.
|
| 589 |
+
"""
|
| 590 |
+
from tqdm.auto import tqdm
|
| 591 |
+
original_init = tqdm.__init__
|
| 592 |
+
|
| 593 |
+
def patched_init(self, *args: Any, **kwargs: Any) -> None:
|
| 594 |
+
original_init(self, *args, **kwargs)
|
| 595 |
+
original_update = self.update
|
| 596 |
+
|
| 597 |
+
def new_update(n: int = 1) -> Any:
|
| 598 |
+
if job._cancel_flag.is_set():
|
| 599 |
+
raise _DownloadCancelled()
|
| 600 |
+
result = original_update(n)
|
| 601 |
+
if self.total:
|
| 602 |
+
job.downloaded_bytes = max(job.downloaded_bytes, self.n)
|
| 603 |
+
if job.total_bytes < self.total:
|
| 604 |
+
job.total_bytes = self.total
|
| 605 |
+
if progress_callback:
|
| 606 |
+
pct = int(self.n / self.total * 100) if self.total else 0
|
| 607 |
+
mb_done = self.n / (1024 * 1024)
|
| 608 |
+
mb_total = self.total / (1024 * 1024)
|
| 609 |
+
progress_callback(pct, f"Downloading: {mb_done:.1f}MB / {mb_total:.1f}MB")
|
| 610 |
+
return result
|
| 611 |
+
|
| 612 |
+
self.update = new_update # type: ignore[method-assign]
|
| 613 |
+
|
| 614 |
+
tqdm.__init__ = patched_init # type: ignore[method-assign]
|
| 615 |
+
try:
|
| 616 |
+
yield
|
| 617 |
+
finally:
|
| 618 |
+
tqdm.__init__ = original_init # type: ignore[method-assign]
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
@contextlib.contextmanager
|
| 622 |
+
def _cumulative_tqdm_hook(
|
| 623 |
+
job: _DownloadJob,
|
| 624 |
+
progress_callback: Optional[Callable[[int, str], None]],
|
| 625 |
+
current_phase: Dict[str, str],
|
| 626 |
+
):
|
| 627 |
+
"""Like _tqdm_progress_hook, but sums bytes across sequential bars.
|
| 628 |
+
|
| 629 |
+
Each tqdm bar reports `self.n` cumulative within ITS file. The single-bar
|
| 630 |
+
hook uses max() which freezes the UI when a fresh bar starts smaller than
|
| 631 |
+
the previous bar's total. Here we track the previous `self.n` per bar id
|
| 632 |
+
and add only the delta to job.downloaded_bytes β so progress climbs
|
| 633 |
+
monotonically across all phases.
|
| 634 |
+
"""
|
| 635 |
+
from tqdm.auto import tqdm
|
| 636 |
+
original_init = tqdm.__init__
|
| 637 |
+
prev_n: Dict[int, int] = {}
|
| 638 |
+
|
| 639 |
+
def patched_init(self, *args: Any, **kwargs: Any) -> None:
|
| 640 |
+
original_init(self, *args, **kwargs)
|
| 641 |
+
original_update = self.update
|
| 642 |
+
prev_n[id(self)] = 0
|
| 643 |
+
|
| 644 |
+
def new_update(n: int = 1) -> Any:
|
| 645 |
+
if job._cancel_flag.is_set():
|
| 646 |
+
raise _DownloadCancelled()
|
| 647 |
+
result = original_update(n)
|
| 648 |
+
prev = prev_n.get(id(self), 0)
|
| 649 |
+
delta = self.n - prev
|
| 650 |
+
prev_n[id(self)] = self.n
|
| 651 |
+
if delta > 0:
|
| 652 |
+
job.downloaded_bytes += delta
|
| 653 |
+
if progress_callback and job.total_bytes:
|
| 654 |
+
pct = min(int(job.downloaded_bytes / job.total_bytes * 100), 99)
|
| 655 |
+
mb_done = job.downloaded_bytes / (1024 * 1024)
|
| 656 |
+
mb_total = job.total_bytes / (1024 * 1024)
|
| 657 |
+
label = current_phase.get("label", "")
|
| 658 |
+
msg = (f"{label} Β· {mb_done:.0f} MB / {mb_total:.0f} MB"
|
| 659 |
+
if label else f"{mb_done:.0f} MB / {mb_total:.0f} MB")
|
| 660 |
+
progress_callback(pct, msg)
|
| 661 |
+
return result
|
| 662 |
+
|
| 663 |
+
self.update = new_update # type: ignore[method-assign]
|
| 664 |
+
|
| 665 |
+
tqdm.__init__ = patched_init # type: ignore[method-assign]
|
| 666 |
+
try:
|
| 667 |
+
yield
|
| 668 |
+
finally:
|
| 669 |
+
tqdm.__init__ = original_init # type: ignore[method-assign]
|
app/core/training/hyperparam_suggester.py
CHANGED
|
@@ -1,76 +1,67 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
-
import
|
| 13 |
-
import os
|
| 14 |
-
import subprocess
|
| 15 |
from pathlib import Path
|
| 16 |
-
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
AUDIO_EXTS = {".wav", ".mp3", ".flac", ".m4a"}
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
# 10-30s; we don't want to pay that on every button click. Cache key is the
|
| 22 |
-
# (file_count, max_mtime_int) pair β invalidates automatically when files
|
| 23 |
-
# are added/removed/touched.
|
| 24 |
-
_DURATION_CACHE_NAME = ".duration_cache.json"
|
| 25 |
|
| 26 |
|
| 27 |
def _list_audio_files(data_dir: Path) -> List[Path]:
|
|
|
|
| 28 |
if not data_dir.exists():
|
| 29 |
return []
|
| 30 |
return [
|
| 31 |
p for p in data_dir.iterdir()
|
| 32 |
-
if p.is_file() and p.suffix.lower() in
|
| 33 |
]
|
| 34 |
|
| 35 |
|
| 36 |
-
def
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
file_count = len(audio_files)
|
| 41 |
-
max_mtime = int(max(p.stat().st_mtime for p in audio_files))
|
| 42 |
-
cache_key = f"{file_count}:{max_mtime}"
|
| 43 |
-
|
| 44 |
-
if cache_path.exists():
|
| 45 |
-
try:
|
| 46 |
-
cached = json.loads(cache_path.read_text())
|
| 47 |
-
if cached.get("key") == cache_key:
|
| 48 |
-
return float(cached["duration_sec"])
|
| 49 |
-
except Exception:
|
| 50 |
-
pass
|
| 51 |
-
|
| 52 |
-
total = 0.0
|
| 53 |
for f in audio_files:
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
}))
|
| 70 |
-
except Exception:
|
| 71 |
-
pass
|
| 72 |
-
|
| 73 |
-
return total
|
| 74 |
|
| 75 |
|
| 76 |
def _detect_vram_gb() -> Optional[float]:
|
|
@@ -83,6 +74,9 @@ def _detect_vram_gb() -> Optional[float]:
|
|
| 83 |
return None
|
| 84 |
|
| 85 |
|
|
|
|
|
|
|
|
|
|
| 86 |
def _bucket(file_count: int) -> str:
|
| 87 |
if file_count < 20:
|
| 88 |
return "tiny"
|
|
@@ -93,66 +87,136 @@ def _bucket(file_count: int) -> str:
|
|
| 93 |
return "large"
|
| 94 |
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
bucket = _bucket(file_count)
|
| 101 |
-
|
| 102 |
-
constrained = (
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
"medium": 1500,
|
| 110 |
-
"large": 3000,
|
| 111 |
-
}
|
| 112 |
-
target_steps = target_steps_by_bucket[bucket]
|
| 113 |
-
|
| 114 |
-
# Rank/LR/alpha scale with how much "capacity per data point" the run needs.
|
| 115 |
-
# Small dataset trick: keep rank moderate (16) and conservative LR (1e-4 β
|
| 116 |
-
# 2e-4 caused overshoot/flat loss in testing), but boost alpha so the
|
| 117 |
-
# LoRA delta trains at higher effective voltage (scaling = alpha/rank).
|
| 118 |
-
# This produces a stronger imprint without the parameter bloat of rank=32
|
| 119 |
-
# or the instability of higher LR.
|
| 120 |
-
if bucket in ("tiny", "small"):
|
| 121 |
-
rank, alpha, lr = 16, 32, 1e-4
|
| 122 |
-
else:
|
| 123 |
-
rank, alpha, lr = 16, 16, 1e-4
|
| 124 |
-
|
| 125 |
-
# Batch size: smaller on small datasets (more updates per epoch + better
|
| 126 |
-
# gradient noise); larger on medium/large for throughput. VRAM caps the top.
|
| 127 |
-
if bucket == "tiny":
|
| 128 |
-
batch = 1 if constrained else 2
|
| 129 |
-
elif bucket == "small":
|
| 130 |
-
# Hold batch=2 even on roomy VRAM β the noise benefit on a small
|
| 131 |
-
# dataset outweighs the throughput win, and it keeps the epoch
|
| 132 |
-
# count to a reasonable display number.
|
| 133 |
-
batch = 2
|
| 134 |
-
elif bucket == "medium":
|
| 135 |
-
batch = 2 if constrained else 4
|
| 136 |
-
else:
|
| 137 |
-
batch = 4 if constrained else 8
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
| 141 |
|
| 142 |
return {
|
|
|
|
| 143 |
"batchSize": batch,
|
| 144 |
-
"learningRate":
|
| 145 |
-
"
|
| 146 |
-
"
|
| 147 |
-
"
|
| 148 |
-
"
|
| 149 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
"_meta": {
|
| 151 |
"bucket": bucket,
|
| 152 |
-
"target_steps":
|
| 153 |
-
"steps_per_epoch": steps_per_epoch,
|
| 154 |
-
"total_steps": steps_per_epoch * epochs,
|
| 155 |
"vram_constrained": constrained,
|
|
|
|
| 156 |
},
|
| 157 |
}
|
| 158 |
|
|
@@ -166,68 +230,162 @@ def _format_duration(seconds: float) -> str:
|
|
| 166 |
return f"{m}m {s}s"
|
| 167 |
|
| 168 |
|
| 169 |
-
def _compose_rationale(
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
bullets.append(
|
| 174 |
-
f"Dataset: {file_count}
|
| 175 |
-
f"total {_format_duration(
|
| 176 |
-
f"\"{meta['bucket']}\" bucket."
|
| 177 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
if vram_gb is not None:
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
| 181 |
else:
|
| 182 |
-
bullets.append("No GPU detected β
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
f"{meta['total_steps']} steps over the recommended epoch count."
|
| 187 |
-
)
|
| 188 |
-
if meta["bucket"] in ("tiny", "small"):
|
| 189 |
bullets.append(
|
| 190 |
-
"
|
| 191 |
-
"
|
| 192 |
-
"double voltage. Stronger imprint without overshoot risk."
|
| 193 |
)
|
| 194 |
else:
|
| 195 |
bullets.append(
|
| 196 |
-
"
|
| 197 |
-
"
|
|
|
|
| 198 |
)
|
| 199 |
-
return bullets
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
audio_files = _list_audio_files(data_dir)
|
| 205 |
file_count = len(audio_files)
|
| 206 |
if file_count == 0:
|
| 207 |
return {
|
| 208 |
"ok": False,
|
| 209 |
-
"error":
|
|
|
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
|
| 212 |
-
|
| 213 |
-
duration_sec = _measure_total_duration(audio_files, cache_path)
|
| 214 |
vram_gb = _detect_vram_gb()
|
| 215 |
|
| 216 |
-
suggestion = _heuristic(file_count,
|
| 217 |
meta = suggestion.pop("_meta")
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
return {
|
| 221 |
"ok": True,
|
| 222 |
"stats": {
|
| 223 |
"file_count": file_count,
|
| 224 |
-
"duration_sec":
|
| 225 |
-
"duration_human": _format_duration(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
"vram_gb": round(vram_gb, 2) if vram_gb is not None else None,
|
| 227 |
"bucket": meta["bucket"],
|
| 228 |
-
"
|
| 229 |
-
"
|
| 230 |
},
|
| 231 |
"config": suggestion,
|
| 232 |
-
"rationale":
|
|
|
|
| 233 |
}
|
|
|
|
| 1 |
+
"""SA3 LoRA hyperparameter suggester for the Training tab's "Suggest" button.
|
| 2 |
+
|
| 3 |
+
Reads a Dataset Workbench project directly β counts SA3-compatible audio
|
| 4 |
+
files, measures their durations via the same `soundfile.info()` header-only
|
| 5 |
+
probe used elsewhere in the app, factors in the user's picked base model
|
| 6 |
+
and detected GPU VRAM, and returns a config that:
|
| 7 |
+
|
| 8 |
+
* matches the upstream SA3 LoRA docs as the starting point
|
| 9 |
+
(see vendor/stable-audio-3/docs/workflows/lora.md)
|
| 10 |
+
* sets `--include transformer.layers` and `--exclude seconds_total
|
| 11 |
+
to_local_embed` by default (documented best practices, prevents the
|
| 12 |
+
"conditioner hijacking" failure mode on small datasets)
|
| 13 |
+
* picks a `-XS` adapter family when VRAM is tight for the chosen base
|
| 14 |
+
* proposes a `duration` derived from the actual clip lengths in the
|
| 15 |
+
project β not a hardcoded 30s
|
| 16 |
+
* warns when the dataset is below SA3's documented minimum (~20 clips)
|
| 17 |
+
or when clips are too short to learn from
|
| 18 |
+
|
| 19 |
+
Returns the same shape the frontend `trainingConfig` uses, so Apply can
|
| 20 |
+
spread the result into state directly.
|
| 21 |
"""
|
| 22 |
|
| 23 |
from __future__ import annotations
|
| 24 |
|
| 25 |
+
import math
|
|
|
|
|
|
|
| 26 |
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 28 |
+
|
| 29 |
+
from app.backend.data.projects import _clip_duration_sec
|
| 30 |
+
from app.core.training.sa3_lora_runner import SA3_AUDIO_EXTENSIONS, SA3_BASE_MODELS
|
| 31 |
|
|
|
|
| 32 |
|
| 33 |
+
# --- Discovery -------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def _list_audio_files(data_dir: Path) -> List[Path]:
|
| 37 |
+
"""Files SA3's loader would actually train on. Mirrors the loader's filter."""
|
| 38 |
if not data_dir.exists():
|
| 39 |
return []
|
| 40 |
return [
|
| 41 |
p for p in data_dir.iterdir()
|
| 42 |
+
if p.is_file() and p.suffix.lower() in SA3_AUDIO_EXTENSIONS
|
| 43 |
]
|
| 44 |
|
| 45 |
|
| 46 |
+
def _duration_stats(audio_files: List[Path]) -> Dict[str, Optional[float]]:
|
| 47 |
+
"""Header-only duration probe + summary stats. None-safe for unreadable files."""
|
| 48 |
+
durations: List[float] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
for f in audio_files:
|
| 50 |
+
d = _clip_duration_sec(f)
|
| 51 |
+
if d is not None and d > 0:
|
| 52 |
+
durations.append(d)
|
| 53 |
+
if not durations:
|
| 54 |
+
return {"count": 0, "total": 0.0, "median": None, "p95": None, "max": None, "min": None}
|
| 55 |
+
durations.sort()
|
| 56 |
+
n = len(durations)
|
| 57 |
+
return {
|
| 58 |
+
"count": n,
|
| 59 |
+
"total": float(sum(durations)),
|
| 60 |
+
"median": float(durations[n // 2]),
|
| 61 |
+
"p95": float(durations[min(n - 1, int(math.ceil(0.95 * n)) - 1)]),
|
| 62 |
+
"max": float(durations[-1]),
|
| 63 |
+
"min": float(durations[0]),
|
| 64 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def _detect_vram_gb() -> Optional[float]:
|
|
|
|
| 74 |
return None
|
| 75 |
|
| 76 |
|
| 77 |
+
# --- Bucketing & sizing ----------------------------------------------------
|
| 78 |
+
|
| 79 |
+
|
| 80 |
def _bucket(file_count: int) -> str:
|
| 81 |
if file_count < 20:
|
| 82 |
return "tiny"
|
|
|
|
| 87 |
return "large"
|
| 88 |
|
| 89 |
|
| 90 |
+
# SA3's documented quick-start: --steps 1000, with no dataset-size caveat.
|
| 91 |
+
# (vendor/stable-audio-3/docs/workflows/lora.md, "Standard (recommended starting point)".)
|
| 92 |
+
# SA3 trains by *windows seen*, not epochs, so a 5h dataset doesn't need more
|
| 93 |
+
# steps than a 30min one β it just produces more diverse sampling per step.
|
| 94 |
+
# We keep the SA3 default for tiny/small, and bump modestly only when a
|
| 95 |
+
# dataset is large enough that 1000 steps won't see all unique windows.
|
| 96 |
+
_STEPS_BY_BUCKET: Dict[str, int] = {
|
| 97 |
+
"tiny": 1000,
|
| 98 |
+
"small": 1000,
|
| 99 |
+
"medium": 2000,
|
| 100 |
+
"large": 4000,
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Per-base-model VRAM table from SA3 docs. (standard_gb, xs_bf16_gb)
|
| 105 |
+
# Source: docs/workflows/lora.md memory table.
|
| 106 |
+
_VRAM_REQ: Dict[str, Tuple[float, float]] = {
|
| 107 |
+
"sa3-small-music-base": (2.5, 2.0),
|
| 108 |
+
"sa3-small-sfx-base": (2.5, 2.0),
|
| 109 |
+
"sa3-medium-base": (6.5, 5.5),
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _pick_adapter(base_model: Optional[str], vram_gb: Optional[float]) -> Tuple[str, bool]:
|
| 114 |
+
"""Choose adapter family. Returns (adapter_type, vram_constrained_flag).
|
| 115 |
+
|
| 116 |
+
SA3 docs recommend the `-xs` family + bf16 base precision for VRAM-limited
|
| 117 |
+
hosts. Headroom rule: standard_gb + 4 GB activations is the comfort target;
|
| 118 |
+
below that we pick the xs family.
|
| 119 |
+
"""
|
| 120 |
+
default = "dora-rows"
|
| 121 |
+
if base_model is None or vram_gb is None:
|
| 122 |
+
return default, False
|
| 123 |
+
std_gb, _xs_gb = _VRAM_REQ.get(base_model, (2.5, 2.0))
|
| 124 |
+
comfort = std_gb + 4.0
|
| 125 |
+
constrained = vram_gb < comfort
|
| 126 |
+
return ("dora-rows-xs" if constrained else default), constrained
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _model_max_window_sec(base_model: Optional[str]) -> float:
|
| 130 |
+
"""SA3's native training length for the base, from its model config
|
| 131 |
+
sample_size / sample_rate: medium-base β380s, small bases β120s. The
|
| 132 |
+
`seconds_total` conditioner caps at 384s, so 380 is the safe medium ceiling.
|
| 133 |
+
Longer windows aren't a model limit below these β they're VRAM/time bound.
|
| 134 |
+
"""
|
| 135 |
+
if base_model and "medium" in base_model:
|
| 136 |
+
return 380.0
|
| 137 |
+
return 120.0
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _pick_duration(p95_clip_sec: Optional[float], base_model: Optional[str]) -> float:
|
| 141 |
+
"""Set training window from the project's actual p95 clip length.
|
| 142 |
+
|
| 143 |
+
Floors at 5s; caps at β and defaults to β the model's native length
|
| 144 |
+
(β120s small / β380s medium) rather than an arbitrary 30s. SA3 random-crops
|
| 145 |
+
longer files, so the only real limits are the model's sequence length and
|
| 146 |
+
VRAM. Rounds up p95 with 2s headroom so the window isn't cropping the tails
|
| 147 |
+
of typical clips. With no duration data, defaults to the model max.
|
| 148 |
+
"""
|
| 149 |
+
model_max = _model_max_window_sec(base_model)
|
| 150 |
+
if p95_clip_sec is None or p95_clip_sec <= 0:
|
| 151 |
+
return model_max
|
| 152 |
+
suggested = math.ceil(p95_clip_sec + 2.0)
|
| 153 |
+
return float(max(5, min(model_max, suggested)))
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _pick_batch_size(bucket: str, vram_gb: Optional[float]) -> int:
|
| 157 |
+
"""SA3 examples all use batch 1. Only go higher on roomy hardware + big data.
|
| 158 |
+
|
| 159 |
+
24 GB threshold for batch 2 leaves enough headroom for medium-base + bf16
|
| 160 |
+
activations across two samples. Going beyond batch 2 hits diminishing
|
| 161 |
+
returns and risks OOM mid-run.
|
| 162 |
+
"""
|
| 163 |
+
if vram_gb is None or vram_gb < 24:
|
| 164 |
+
return 1
|
| 165 |
+
if bucket in ("medium", "large"):
|
| 166 |
+
return 2
|
| 167 |
+
return 1
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Filter pattern straight from SA3 docs:
|
| 171 |
+
# --include transformer.layers --exclude seconds_total to_local_embed
|
| 172 |
+
# "Everything except local embedding and seconds_total conditioner" β prevents
|
| 173 |
+
# the conditioner-hijacking failure mode that bites small datasets hardest.
|
| 174 |
+
_INCLUDE_DEFAULT: List[str] = ["transformer.layers"]
|
| 175 |
+
_EXCLUDE_DEFAULT: List[str] = ["seconds_total", "to_local_embed"]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# --- Suggestion + rationale ------------------------------------------------
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _heuristic(
|
| 182 |
+
file_count: int,
|
| 183 |
+
dur_stats: Dict[str, Optional[float]],
|
| 184 |
+
base_model: Optional[str],
|
| 185 |
+
vram_gb: Optional[float],
|
| 186 |
+
) -> Dict[str, Any]:
|
| 187 |
bucket = _bucket(file_count)
|
| 188 |
+
steps = _STEPS_BY_BUCKET[bucket]
|
| 189 |
+
adapter, constrained = _pick_adapter(base_model, vram_gb)
|
| 190 |
+
duration = _pick_duration(dur_stats.get("p95"), base_model)
|
| 191 |
+
batch = _pick_batch_size(bucket, vram_gb)
|
| 192 |
+
|
| 193 |
+
# Mild dropout for tiny datasets only β extra regularization where overfit
|
| 194 |
+
# is most likely. SA3 default is 0.0; we deviate intentionally.
|
| 195 |
+
dropout = 0.05 if bucket == "tiny" else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
+
# Checkpoint cadence: ~10 checkpoints per run, but keep within sane bounds
|
| 198 |
+
# so we don't write a checkpoint every 50 steps on tiny runs or sit on a
|
| 199 |
+
# 2K-step gap on long ones.
|
| 200 |
+
checkpoint_every = max(250, min(1000, steps // 10))
|
| 201 |
|
| 202 |
return {
|
| 203 |
+
"steps": steps,
|
| 204 |
"batchSize": batch,
|
| 205 |
+
"learningRate": 1e-4,
|
| 206 |
+
"loraRank": 16,
|
| 207 |
+
"loraAlpha": 16,
|
| 208 |
+
"loraDropout": dropout,
|
| 209 |
+
"adapterType": adapter,
|
| 210 |
+
"precision": "bf16",
|
| 211 |
+
"duration": duration,
|
| 212 |
+
"checkpointSteps": checkpoint_every,
|
| 213 |
+
"include": list(_INCLUDE_DEFAULT),
|
| 214 |
+
"exclude": list(_EXCLUDE_DEFAULT),
|
| 215 |
"_meta": {
|
| 216 |
"bucket": bucket,
|
| 217 |
+
"target_steps": steps,
|
|
|
|
|
|
|
| 218 |
"vram_constrained": constrained,
|
| 219 |
+
"picked_adapter_for_vram": constrained,
|
| 220 |
},
|
| 221 |
}
|
| 222 |
|
|
|
|
| 230 |
return f"{m}m {s}s"
|
| 231 |
|
| 232 |
|
| 233 |
+
def _compose_rationale(
|
| 234 |
+
file_count: int,
|
| 235 |
+
dur_stats: Dict[str, Optional[float]],
|
| 236 |
+
base_model: Optional[str],
|
| 237 |
+
vram_gb: Optional[float],
|
| 238 |
+
config: Dict[str, Any],
|
| 239 |
+
meta: Dict[str, Any],
|
| 240 |
+
) -> Tuple[List[str], List[str]]:
|
| 241 |
+
"""Return (bullets, warnings). Warnings are surfaced separately in the UI."""
|
| 242 |
+
bullets: List[str] = []
|
| 243 |
+
warnings: List[str] = []
|
| 244 |
+
|
| 245 |
+
total = dur_stats.get("total") or 0.0
|
| 246 |
bullets.append(
|
| 247 |
+
f"Dataset: {file_count} clip{'s' if file_count != 1 else ''}, "
|
| 248 |
+
f"total {_format_duration(total)} β \"{meta['bucket']}\" bucket."
|
|
|
|
| 249 |
)
|
| 250 |
+
|
| 251 |
+
p95 = dur_stats.get("p95")
|
| 252 |
+
median = dur_stats.get("median")
|
| 253 |
+
if p95 is not None and median is not None:
|
| 254 |
+
bullets.append(
|
| 255 |
+
f"Clip durations: median {median:.1f}s, p95 {p95:.1f}s. "
|
| 256 |
+
f"Training window set to {config['duration']:.0f}s."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
if vram_gb is not None:
|
| 260 |
+
bullets.append(
|
| 261 |
+
f"Detected GPU: {vram_gb:.1f} GB"
|
| 262 |
+
+ (" (tight for the chosen base β switched adapter to a -XS variant)."
|
| 263 |
+
if meta["vram_constrained"] else " (comfortable headroom).")
|
| 264 |
+
)
|
| 265 |
else:
|
| 266 |
+
bullets.append("No CUDA GPU detected β adapter defaults to dora-rows; "
|
| 267 |
+
"training will run on CPU/MPS where supported.")
|
| 268 |
+
|
| 269 |
+
if meta["target_steps"] == 1000:
|
|
|
|
|
|
|
|
|
|
| 270 |
bullets.append(
|
| 271 |
+
"Target 1 000 optimizer steps β SA3's documented quick-start. "
|
| 272 |
+
"LoRAs typically overfit well before this; watch the loss curve."
|
|
|
|
| 273 |
)
|
| 274 |
else:
|
| 275 |
bullets.append(
|
| 276 |
+
f"Target {meta['target_steps']:,} optimizer steps β modest bump "
|
| 277 |
+
f"above SA3's 1 000-step default for larger datasets to see more "
|
| 278 |
+
"unique sampling windows."
|
| 279 |
)
|
|
|
|
| 280 |
|
| 281 |
+
bullets.append(
|
| 282 |
+
f"Layer filter: include `{config['include'][0]}`, exclude "
|
| 283 |
+
f"`{' '.join(config['exclude'])}`. "
|
| 284 |
+
"Documented SA3 default β prevents conditioner-hijacking on small sets."
|
| 285 |
+
)
|
| 286 |
|
| 287 |
+
bullets.append(
|
| 288 |
+
f"Adapter `{config['adapterType']}` Β· rank 16 Β· Ξ± 16 Β· "
|
| 289 |
+
f"dropout {config['loraDropout']} Β· {config['precision']} base."
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# --- Warnings (separate channel) ---------------------------------------
|
| 293 |
+
|
| 294 |
+
if file_count < 20:
|
| 295 |
+
warnings.append(
|
| 296 |
+
f"{file_count} clips is below SA3's documented minimum of ~20. "
|
| 297 |
+
"Expect heavy overfit and poor generalization β add more data if you can."
|
| 298 |
+
)
|
| 299 |
+
if median is not None and median < 2.0:
|
| 300 |
+
warnings.append(
|
| 301 |
+
f"Median clip is only {median:.1f}s β most of the training window "
|
| 302 |
+
f"({config['duration']:.0f}s) will be silence-padded. "
|
| 303 |
+
"Re-slice the source material to longer chunks for better signal."
|
| 304 |
+
)
|
| 305 |
+
if config["duration"] > 45:
|
| 306 |
+
warnings.append(
|
| 307 |
+
f"Training window is {config['duration']:.0f}s. Longer windows use "
|
| 308 |
+
"markedly more VRAM and step time (DiT attention scales with length). "
|
| 309 |
+
"If you hit OOM, lower the window or pre-encode the dataset first."
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# VRAM Γ base model crosscheck
|
| 313 |
+
if base_model in _VRAM_REQ:
|
| 314 |
+
std_gb, xs_gb = _VRAM_REQ[base_model]
|
| 315 |
+
if vram_gb is None:
|
| 316 |
+
if base_model == "sa3-medium-base":
|
| 317 |
+
warnings.append(
|
| 318 |
+
"No CUDA GPU detected, but you picked Medium-Base. "
|
| 319 |
+
"Medium-base needs CUDA + Flash-Attn 2 (Linux) and β₯5.5 GB VRAM. "
|
| 320 |
+
"Consider Small-Music-Base or Small-SFX-Base for CPU/MPS hosts."
|
| 321 |
+
)
|
| 322 |
+
elif vram_gb < xs_gb:
|
| 323 |
+
warnings.append(
|
| 324 |
+
f"GPU has {vram_gb:.1f} GB; even {base_model} with bf16+lora-xs needs "
|
| 325 |
+
f"~{xs_gb:.1f} GB. Training will likely OOM. Pick a smaller base."
|
| 326 |
+
)
|
| 327 |
+
elif vram_gb < std_gb:
|
| 328 |
+
warnings.append(
|
| 329 |
+
f"GPU has {vram_gb:.1f} GB; {base_model} standard config needs "
|
| 330 |
+
f"~{std_gb:.1f} GB. The -XS adapter (selected) brings it to ~{xs_gb:.1f} GB."
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
return bullets, warnings
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def suggest(data_dir: Path, base_model: Optional[str] = None) -> Dict[str, Any]:
|
| 337 |
+
"""Public entry point. SA3 is LoRA-only; no `mode` switch."""
|
| 338 |
audio_files = _list_audio_files(data_dir)
|
| 339 |
file_count = len(audio_files)
|
| 340 |
if file_count == 0:
|
| 341 |
return {
|
| 342 |
"ok": False,
|
| 343 |
+
"error": (
|
| 344 |
+
f"No SA3-compatible audio in {data_dir}. SA3's loader accepts "
|
| 345 |
+
+ ", ".join(SA3_AUDIO_EXTENSIONS) + "."
|
| 346 |
+
),
|
| 347 |
}
|
| 348 |
|
| 349 |
+
dur_stats = _duration_stats(audio_files)
|
|
|
|
| 350 |
vram_gb = _detect_vram_gb()
|
| 351 |
|
| 352 |
+
suggestion = _heuristic(file_count, dur_stats, base_model, vram_gb)
|
| 353 |
meta = suggestion.pop("_meta")
|
| 354 |
+
bullets, warnings = _compose_rationale(
|
| 355 |
+
file_count, dur_stats, base_model, vram_gb, suggestion, meta
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Caption coverage: SA3 trains on audio + matching .txt sidecars, and
|
| 359 |
+
# silently drops clips whose prompt is blank. Surface missing captions so
|
| 360 |
+
# the user isn't unknowingly training on a fraction of the dataset.
|
| 361 |
+
uncaptioned = sum(
|
| 362 |
+
1 for p in audio_files
|
| 363 |
+
if not (p.with_suffix(".txt").exists()
|
| 364 |
+
and p.with_suffix(".txt").read_text(encoding="utf-8", errors="ignore").strip())
|
| 365 |
+
)
|
| 366 |
+
if uncaptioned:
|
| 367 |
+
warnings.insert(0, (
|
| 368 |
+
f"{uncaptioned} of {file_count} clip{'s' if file_count != 1 else ''} "
|
| 369 |
+
"have no annotation. SA3 silently skips un-captioned clips at train "
|
| 370 |
+
"time β annotate them first or they won't contribute to the LoRA."
|
| 371 |
+
))
|
| 372 |
|
| 373 |
return {
|
| 374 |
"ok": True,
|
| 375 |
"stats": {
|
| 376 |
"file_count": file_count,
|
| 377 |
+
"duration_sec": dur_stats.get("total") or 0.0,
|
| 378 |
+
"duration_human": _format_duration(dur_stats.get("total") or 0.0),
|
| 379 |
+
"median_clip_sec": dur_stats.get("median"),
|
| 380 |
+
"p95_clip_sec": dur_stats.get("p95"),
|
| 381 |
+
"max_clip_sec": dur_stats.get("max"),
|
| 382 |
+
"min_clip_sec": dur_stats.get("min"),
|
| 383 |
"vram_gb": round(vram_gb, 2) if vram_gb is not None else None,
|
| 384 |
"bucket": meta["bucket"],
|
| 385 |
+
"total_steps": meta["target_steps"],
|
| 386 |
+
"base_model": base_model,
|
| 387 |
},
|
| 388 |
"config": suggestion,
|
| 389 |
+
"rationale": bullets,
|
| 390 |
+
"warnings": warnings,
|
| 391 |
}
|
app/core/training/sa3_lora_runner.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for the SA3 LoRA training pipeline.
|
| 2 |
+
|
| 3 |
+
Responsibilities:
|
| 4 |
+
* Pre-stage the base model in an app-folder HF cache so the training
|
| 5 |
+
subprocess finds it without falling back to ~/.cache/huggingface.
|
| 6 |
+
* Build the train_lora.py subprocess command + env.
|
| 7 |
+
* Convert PyTorch Lightning .ckpt LoRA outputs to SA3-native .safetensors
|
| 8 |
+
with the base_model and run name embedded in the metadata header.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# SA3 model_id β (sa3_name passed to train_lora.py --model, HF repo id)
|
| 20 |
+
# Only `*-base` variants are valid LoRA targets β SA3 won't train against
|
| 21 |
+
# the post-trained / distilled checkpoints.
|
| 22 |
+
SA3_BASE_MODELS: Dict[str, Tuple[str, str]] = {
|
| 23 |
+
"sa3-small-music-base": ("small-music-base", "stabilityai/stable-audio-3-small-music-base"),
|
| 24 |
+
"sa3-small-sfx-base": ("small-sfx-base", "stabilityai/stable-audio-3-small-sfx-base"),
|
| 25 |
+
"sa3-medium-base": ("medium-base", "stabilityai/stable-audio-3-medium-base"),
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Each *-base config references its T5Gemma conditioner at a subfolder of the
|
| 30 |
+
# *post-trained sibling* repo (e.g., medium-base's t5gemma lives at
|
| 31 |
+
# stabilityai/stable-audio-3-medium / t5gemma-b-b-ul2/). Without that subtree
|
| 32 |
+
# in the cache, training crashes inside the conditioner constructor when SA3
|
| 33 |
+
# does `AutoTokenizer.from_pretrained(repo_id, subfolder=...)`.
|
| 34 |
+
# Keep in sync with model_config.json's `conditioning.configs[0].config.repo_id`.
|
| 35 |
+
SA3_T5GEMMA_SIBLINGS: Dict[str, Tuple[str, str]] = {
|
| 36 |
+
"sa3-small-music-base": ("stabilityai/stable-audio-3-small-music", "t5gemma-b-b-ul2"),
|
| 37 |
+
"sa3-small-sfx-base": ("stabilityai/stable-audio-3-small-sfx", "t5gemma-b-b-ul2"),
|
| 38 |
+
"sa3-medium-base": ("stabilityai/stable-audio-3-medium", "t5gemma-b-b-ul2"),
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Extensions SA3's training data loader actually accepts.
|
| 43 |
+
# Source: vendor/stable-audio-3/stable_audio_3/data/dataset.py:91.
|
| 44 |
+
# Single source of truth β both the health check and the hyperparam suggester
|
| 45 |
+
# use this so what we count matches what the loader will train on.
|
| 46 |
+
SA3_AUDIO_EXTENSIONS: Tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg", ".aif", ".opus")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# --- Base model pre-staging -------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def prestage_base_model(
|
| 52 |
+
sa3_model_id: str,
|
| 53 |
+
hub_dir: Path,
|
| 54 |
+
token: Optional[str] = None,
|
| 55 |
+
progress_callback: Optional[Any] = None,
|
| 56 |
+
) -> Path:
|
| 57 |
+
"""Ensure the base model is in `hub_dir` (HF-cache layout, inside app folder).
|
| 58 |
+
|
| 59 |
+
train_lora.py calls `model_cfg.resolve()` which is hf_hub_download under
|
| 60 |
+
the hood β it reads from the HF cache root. We point it at hub_dir via
|
| 61 |
+
the HF_HUB_CACHE env var on the subprocess; for that to actually find
|
| 62 |
+
files we need to download into hub_dir using snapshot_download with
|
| 63 |
+
`cache_dir=hub_dir`.
|
| 64 |
+
|
| 65 |
+
Idempotent: if the model is already cached there, returns the cached
|
| 66 |
+
snapshot dir without re-downloading.
|
| 67 |
+
"""
|
| 68 |
+
if sa3_model_id not in SA3_BASE_MODELS:
|
| 69 |
+
raise ValueError(
|
| 70 |
+
f"'{sa3_model_id}' is not a valid LoRA base. Pick one of "
|
| 71 |
+
f"{list(SA3_BASE_MODELS)} (only *-base variants are CFG-aware)."
|
| 72 |
+
)
|
| 73 |
+
sa3_name, repo_id = SA3_BASE_MODELS[sa3_model_id]
|
| 74 |
+
hub_dir.mkdir(parents=True, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
from huggingface_hub import snapshot_download
|
| 77 |
+
|
| 78 |
+
allow_patterns = [
|
| 79 |
+
"*.safetensors", "*.json", "*.txt", "*.model",
|
| 80 |
+
"tokenizer*", "*.tiktoken",
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
if progress_callback:
|
| 84 |
+
progress_callback(5, f"Staging {sa3_name} base model in {hub_dir.name}/...")
|
| 85 |
+
|
| 86 |
+
# Prefer cache. snapshot_download otherwise phones home on every run to
|
| 87 |
+
# check the model's revision β wasteful and noisy when the user just
|
| 88 |
+
# downloaded the weights through the Checkpoint Manager. If anything's
|
| 89 |
+
# missing, fall back to an online fetch.
|
| 90 |
+
try:
|
| 91 |
+
local_snap = snapshot_download(
|
| 92 |
+
repo_id=repo_id,
|
| 93 |
+
cache_dir=str(hub_dir),
|
| 94 |
+
token=token,
|
| 95 |
+
allow_patterns=allow_patterns,
|
| 96 |
+
local_files_only=True,
|
| 97 |
+
)
|
| 98 |
+
if progress_callback:
|
| 99 |
+
progress_callback(15, "Base model ready (from cache).")
|
| 100 |
+
except Exception:
|
| 101 |
+
if progress_callback:
|
| 102 |
+
progress_callback(8, "Cache miss β fetching from HuggingFaceβ¦")
|
| 103 |
+
local_snap = snapshot_download(
|
| 104 |
+
repo_id=repo_id,
|
| 105 |
+
cache_dir=str(hub_dir),
|
| 106 |
+
token=token,
|
| 107 |
+
allow_patterns=allow_patterns,
|
| 108 |
+
)
|
| 109 |
+
if progress_callback:
|
| 110 |
+
progress_callback(15, "Base model ready.")
|
| 111 |
+
|
| 112 |
+
# Pre-stage the T5Gemma conditioner from the post-trained sibling repo.
|
| 113 |
+
# SA3's *-base model_config.json points the prompt conditioner at
|
| 114 |
+
# e.g. stabilityai/stable-audio-3-medium / t5gemma-b-b-ul2/, NOT at the
|
| 115 |
+
# base repo. Without this subtree in the cache, the training subprocess
|
| 116 |
+
# (HF_HUB_OFFLINE=1) crashes when AutoTokenizer.from_pretrained tries
|
| 117 |
+
# to phone home.
|
| 118 |
+
sibling = SA3_T5GEMMA_SIBLINGS.get(sa3_model_id)
|
| 119 |
+
if sibling:
|
| 120 |
+
sib_repo, sib_subfolder = sibling
|
| 121 |
+
sib_patterns = [f"{sib_subfolder}/*"]
|
| 122 |
+
if progress_callback:
|
| 123 |
+
progress_callback(16, f"Staging T5Gemma conditioner from {sib_repo}β¦")
|
| 124 |
+
try:
|
| 125 |
+
snapshot_download(
|
| 126 |
+
repo_id=sib_repo,
|
| 127 |
+
cache_dir=str(hub_dir),
|
| 128 |
+
token=token,
|
| 129 |
+
allow_patterns=sib_patterns,
|
| 130 |
+
local_files_only=True,
|
| 131 |
+
)
|
| 132 |
+
if progress_callback:
|
| 133 |
+
progress_callback(18, "T5Gemma conditioner ready (from cache).")
|
| 134 |
+
except Exception:
|
| 135 |
+
if progress_callback:
|
| 136 |
+
progress_callback(17, f"T5Gemma cache miss β fetching from {sib_repo}β¦")
|
| 137 |
+
snapshot_download(
|
| 138 |
+
repo_id=sib_repo,
|
| 139 |
+
cache_dir=str(hub_dir),
|
| 140 |
+
token=token,
|
| 141 |
+
allow_patterns=sib_patterns,
|
| 142 |
+
)
|
| 143 |
+
if progress_callback:
|
| 144 |
+
progress_callback(18, "T5Gemma conditioner ready.")
|
| 145 |
+
|
| 146 |
+
return Path(local_snap)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# --- Subprocess command builder ---------------------------------------------
|
| 150 |
+
|
| 151 |
+
def build_train_command(
|
| 152 |
+
*,
|
| 153 |
+
venv_python: str,
|
| 154 |
+
sa3_vendor_dir: Path,
|
| 155 |
+
sa3_model_name: str,
|
| 156 |
+
data_dir: Path,
|
| 157 |
+
encoded_dir: Optional[Path] = None,
|
| 158 |
+
svd_bases_path: Optional[Path] = None,
|
| 159 |
+
save_dir: Path,
|
| 160 |
+
rank: int = 16,
|
| 161 |
+
lora_alpha: Optional[int] = None,
|
| 162 |
+
adapter_type: str = "dora-rows",
|
| 163 |
+
dropout: float = 0.0,
|
| 164 |
+
lr: float = 1e-4,
|
| 165 |
+
steps: int = 5000,
|
| 166 |
+
batch_size: int = 1,
|
| 167 |
+
duration: float = 30.0,
|
| 168 |
+
base_precision: str = "bf16",
|
| 169 |
+
include: Optional[List[str]] = None,
|
| 170 |
+
exclude: Optional[List[str]] = None,
|
| 171 |
+
seed: int = 42,
|
| 172 |
+
checkpoint_every: int = 500,
|
| 173 |
+
# `--log_every` controls how often DiffusionCondTrainingWrapper calls
|
| 174 |
+
# self.log(). 50 is SA3's example value and gives a much cleaner chart
|
| 175 |
+
# than per-step logging β diffusion loss is intrinsically noisy (each
|
| 176 |
+
# step samples a random timestep), so per-step values bounce wildly and
|
| 177 |
+
# the trend is hard to read. Sampling every 50 steps gives ~20 points
|
| 178 |
+
# for a 1000-step run, which the EMA smoother turns into a legible
|
| 179 |
+
# descent. First point arrives after step 49 (β15s on small, β50s on
|
| 180 |
+
# medium, dominated by first-step JIT warmup anyway).
|
| 181 |
+
log_every: int = 50,
|
| 182 |
+
num_workers: int = 2,
|
| 183 |
+
name: str = "fragmenta-lora",
|
| 184 |
+
) -> List[str]:
|
| 185 |
+
"""Construct the train_lora.py subprocess argv."""
|
| 186 |
+
cmd = [
|
| 187 |
+
venv_python,
|
| 188 |
+
str(sa3_vendor_dir / "scripts" / "train_lora.py"),
|
| 189 |
+
"--model", sa3_model_name,
|
| 190 |
+
"--data_dir", str(data_dir),
|
| 191 |
+
"--save_dir", str(save_dir),
|
| 192 |
+
"--rank", str(int(rank)),
|
| 193 |
+
"--adapter_type", adapter_type,
|
| 194 |
+
"--dropout", str(float(dropout)),
|
| 195 |
+
"--lr", str(float(lr)),
|
| 196 |
+
"--steps", str(int(steps)),
|
| 197 |
+
"--batch_size", str(int(batch_size)),
|
| 198 |
+
"--duration", str(float(duration)),
|
| 199 |
+
"--base_precision", base_precision,
|
| 200 |
+
"--seed", str(int(seed)),
|
| 201 |
+
"--checkpoint_every", str(int(checkpoint_every)),
|
| 202 |
+
"--log_every", str(int(log_every)),
|
| 203 |
+
"--num_workers", str(int(num_workers)),
|
| 204 |
+
"--name", name,
|
| 205 |
+
"--logger", "csv",
|
| 206 |
+
# demo_every set to a very large number β Fragmenta's training
|
| 207 |
+
# monitor doesn't surface demo audio, no need to spend cycles.
|
| 208 |
+
"--demo_every", "1000000",
|
| 209 |
+
]
|
| 210 |
+
if encoded_dir is not None:
|
| 211 |
+
# Phase 6 β feed pre-encoded latents directory. SA3's train_lora.py
|
| 212 |
+
# then uses PreEncodedDataset instead of SampleDataset and skips
|
| 213 |
+
# the SAME autoencoder pass per step.
|
| 214 |
+
cmd += ["--encoded_dir", str(encoded_dir)]
|
| 215 |
+
if svd_bases_path is not None and adapter_type.endswith("-xs"):
|
| 216 |
+
# -XS adapters factor weights against precomputed SVD bases. SA3 only
|
| 217 |
+
# *loads* bases from this path (it doesn't write them), so we pass it
|
| 218 |
+
# only when a cached .pt already exists β otherwise SA3 recomputes the
|
| 219 |
+
# SVD per layer on device (slower, but correct). See SA3Trainer for the
|
| 220 |
+
# cache path convention.
|
| 221 |
+
cmd += ["--svd_bases_path", str(svd_bases_path)]
|
| 222 |
+
if lora_alpha is not None:
|
| 223 |
+
cmd += ["--lora_alpha", str(int(lora_alpha))]
|
| 224 |
+
if include:
|
| 225 |
+
cmd += ["--include", *include]
|
| 226 |
+
if exclude:
|
| 227 |
+
cmd += ["--exclude", *exclude]
|
| 228 |
+
return cmd
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# --- Checkpoint conversion (.ckpt β .safetensors with base_model metadata) ---
|
| 232 |
+
|
| 233 |
+
def convert_run_checkpoints_to_safetensors(
|
| 234 |
+
run_dir: Path,
|
| 235 |
+
base_model: str,
|
| 236 |
+
model_name: Optional[str] = None,
|
| 237 |
+
delete_originals: bool = True,
|
| 238 |
+
) -> List[Path]:
|
| 239 |
+
"""Convert PyTorch Lightning .ckpt files in a run's checkpoints/ directory
|
| 240 |
+
to SA3's native .safetensors LoRA format, with `base_model` injected into
|
| 241 |
+
the safetensors metadata header so /api/loras can filter by it.
|
| 242 |
+
|
| 243 |
+
Why: SA3's `train_lora.py` writes Lightning .ckpt files. The inference
|
| 244 |
+
LoRA picker (/api/loras) globs for *.safetensors only. Without this
|
| 245 |
+
conversion, every trained LoRA is functionally orphaned β saved
|
| 246 |
+
correctly to disk but invisible to the inference loader.
|
| 247 |
+
|
| 248 |
+
Idempotent: skips any .ckpt whose .safetensors sibling already exists
|
| 249 |
+
with a non-zero size.
|
| 250 |
+
|
| 251 |
+
Returns the list of paths to the produced .safetensors files (sorted).
|
| 252 |
+
"""
|
| 253 |
+
ckpt_dir = run_dir / "checkpoints"
|
| 254 |
+
if not ckpt_dir.exists():
|
| 255 |
+
return []
|
| 256 |
+
|
| 257 |
+
# Imports deferred so this module can be imported without the SA3 vendor
|
| 258 |
+
# being on sys.path (e.g., during pure orchestrator construction).
|
| 259 |
+
from app.core.config import get_config
|
| 260 |
+
sa3_vendor = get_config().get_path("stable_audio_3")
|
| 261 |
+
pp = sys.path[:]
|
| 262 |
+
if str(sa3_vendor) not in pp:
|
| 263 |
+
sys.path.insert(0, str(sa3_vendor))
|
| 264 |
+
try:
|
| 265 |
+
from stable_audio_3.models.lora.utils import load_lora_checkpoint
|
| 266 |
+
from safetensors.torch import save_file as st_save_file
|
| 267 |
+
finally:
|
| 268 |
+
# Don't permanently mutate sys.path from a helper call.
|
| 269 |
+
if sys.path != pp:
|
| 270 |
+
sys.path[:] = pp
|
| 271 |
+
|
| 272 |
+
written: List[Path] = []
|
| 273 |
+
for ckpt_path in sorted(ckpt_dir.glob("*.ckpt")):
|
| 274 |
+
out_path = ckpt_path.with_suffix(".safetensors")
|
| 275 |
+
if out_path.exists() and out_path.stat().st_size > 0:
|
| 276 |
+
# Already converted (older artifact or a previous pass). Just
|
| 277 |
+
# bookkeep so the caller sees it in the return list.
|
| 278 |
+
written.append(out_path)
|
| 279 |
+
continue
|
| 280 |
+
try:
|
| 281 |
+
state_dict, lora_config = load_lora_checkpoint(ckpt_path)
|
| 282 |
+
except Exception:
|
| 283 |
+
# Corrupt or truncated ckpt β skip rather than crash the
|
| 284 |
+
# post-training pass.
|
| 285 |
+
continue
|
| 286 |
+
|
| 287 |
+
# Top-level metadata is what /api/loras' safetensors reader inspects
|
| 288 |
+
# directly. We also keep the canonical `lora_config` JSON blob so
|
| 289 |
+
# SA3's own load_lora_checkpoint() can parse the file as-is.
|
| 290 |
+
metadata = {
|
| 291 |
+
"lora_config": json.dumps(lora_config or {}),
|
| 292 |
+
"base_model": base_model,
|
| 293 |
+
}
|
| 294 |
+
if model_name:
|
| 295 |
+
metadata["model_name"] = model_name
|
| 296 |
+
# Cast fp16 to keep file sizes consistent with SA3's standard format.
|
| 297 |
+
fp16_dict = {k: (v.half() if v.is_floating_point() else v)
|
| 298 |
+
for k, v in state_dict.items()}
|
| 299 |
+
st_save_file(fp16_dict, str(out_path), metadata=metadata)
|
| 300 |
+
if delete_originals:
|
| 301 |
+
try:
|
| 302 |
+
ckpt_path.unlink()
|
| 303 |
+
except OSError:
|
| 304 |
+
pass
|
| 305 |
+
written.append(out_path)
|
| 306 |
+
return sorted(written)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def build_train_env(sa3_vendor_dir: Path, hub_dir: Path) -> Dict[str, str]:
|
| 310 |
+
"""Subprocess env: redirect HF cache into the app folder + silence WANDB."""
|
| 311 |
+
env = os.environ.copy()
|
| 312 |
+
# Make `import stable_audio_3` work without pip-installing the package.
|
| 313 |
+
pp = env.get("PYTHONPATH", "")
|
| 314 |
+
env["PYTHONPATH"] = (
|
| 315 |
+
f"{sa3_vendor_dir}{os.pathsep}{pp}" if pp else str(sa3_vendor_dir)
|
| 316 |
+
)
|
| 317 |
+
# Pin the HF cache to our app-folder hub dir; otherwise train_lora.py's
|
| 318 |
+
# model_cfg.resolve() would write into ~/.cache/huggingface/hub. Cover
|
| 319 |
+
# the legacy + transformers env names too for defense-in-depth.
|
| 320 |
+
env["HF_HUB_CACHE"] = str(hub_dir)
|
| 321 |
+
env["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
|
| 322 |
+
env["TRANSFORMERS_CACHE"] = str(hub_dir)
|
| 323 |
+
env["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 324 |
+
env["WANDB_DISABLED"] = "1"
|
| 325 |
+
# Force the training subprocess into offline mode for HF β we already
|
| 326 |
+
# pre-staged the base model in prestage_base_model(), so any remaining
|
| 327 |
+
# network call from the SA3 internals would be a noisy revision check
|
| 328 |
+
# against a cache we know is current.
|
| 329 |
+
env["HF_HUB_OFFLINE"] = "1"
|
| 330 |
+
env["TRANSFORMERS_OFFLINE"] = "1"
|
| 331 |
+
return env
|
app/core/training/sa3_trainer.py
ADDED
|
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SA3 LoRA training orchestrator β Phase 5.
|
| 2 |
+
|
| 3 |
+
Public surface (matches what app/backend/app.py imports):
|
| 4 |
+
start_training(config) -> dict
|
| 5 |
+
get_training_status() -> dict
|
| 6 |
+
stop_training() -> dict
|
| 7 |
+
preview_training_plan(config) -> dict
|
| 8 |
+
class SA3Trainer
|
| 9 |
+
|
| 10 |
+
Training is dispatched as a subprocess running
|
| 11 |
+
`vendor/stable-audio-3/scripts/train_lora.py`. Progress comes back through
|
| 12 |
+
two channels:
|
| 13 |
+
* stdout/stderr from the subprocess (parsed for tqdm "step X/Y" lines)
|
| 14 |
+
* metrics.csv that train_lora.py writes under --save_dir
|
| 15 |
+
|
| 16 |
+
Config shape (from the frontend training form):
|
| 17 |
+
{
|
| 18 |
+
"modelName": "my-lora", # used for run dir name
|
| 19 |
+
"baseModel": "sa3-medium-base", # must end in -base
|
| 20 |
+
"projectName": "my_first_track", # Dataset Workbench project name
|
| 21 |
+
"steps": 5000,
|
| 22 |
+
"checkpointSteps": 500, # checkpoint cadence
|
| 23 |
+
"batchSize": 1,
|
| 24 |
+
"learningRate": 1.0e-4,
|
| 25 |
+
"duration": 30.0, # max clip seconds per sample
|
| 26 |
+
"loraRank": 16,
|
| 27 |
+
"loraAlpha": 16, # null β defaults to rank
|
| 28 |
+
"loraDropout": 0.0,
|
| 29 |
+
"adapterType": "dora-rows",
|
| 30 |
+
"precision": "bf16", # bf16|fp16
|
| 31 |
+
"seed": 42,
|
| 32 |
+
"include": null, # list[str] or null
|
| 33 |
+
"exclude": null
|
| 34 |
+
}
|
| 35 |
+
"""
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
import csv
|
| 39 |
+
import json
|
| 40 |
+
import os
|
| 41 |
+
import re
|
| 42 |
+
import shlex
|
| 43 |
+
import signal
|
| 44 |
+
import subprocess
|
| 45 |
+
import sys
|
| 46 |
+
import threading
|
| 47 |
+
import time
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
from typing import Any, Dict, List, Optional
|
| 50 |
+
|
| 51 |
+
from app.backend.data.projects import project_path
|
| 52 |
+
from app.core.config import get_config
|
| 53 |
+
from app.core.training.sa3_lora_runner import (
|
| 54 |
+
SA3_BASE_MODELS,
|
| 55 |
+
build_train_command,
|
| 56 |
+
build_train_env,
|
| 57 |
+
convert_run_checkpoints_to_safetensors,
|
| 58 |
+
prestage_base_model,
|
| 59 |
+
)
|
| 60 |
+
from utils.logger import get_logger
|
| 61 |
+
|
| 62 |
+
logger = get_logger("SA3Trainer")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# --- Defaults --------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
DEFAULT_STEPS = 5000
|
| 68 |
+
DEFAULT_CHECKPOINT_STEPS = 500
|
| 69 |
+
DEFAULT_BATCH_SIZE = 1
|
| 70 |
+
DEFAULT_LR = 1e-4
|
| 71 |
+
DEFAULT_DURATION = 30.0
|
| 72 |
+
DEFAULT_RANK = 16
|
| 73 |
+
DEFAULT_ADAPTER = "dora-rows"
|
| 74 |
+
DEFAULT_PRECISION = "bf16"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# --- SA3Trainer singleton --------------------------------------------------
|
| 78 |
+
|
| 79 |
+
class SA3Trainer:
|
| 80 |
+
def __init__(self, config: Dict[str, Any]) -> None:
|
| 81 |
+
self.config: Dict[str, Any] = config or {}
|
| 82 |
+
self.process: Optional[subprocess.Popen] = None
|
| 83 |
+
self.run_dir: Optional[Path] = None
|
| 84 |
+
self.metrics_csv: Optional[Path] = None
|
| 85 |
+
self._monitor_thread: Optional[threading.Thread] = None
|
| 86 |
+
self.status: Dict[str, Any] = {
|
| 87 |
+
"is_training": False,
|
| 88 |
+
"status": "idle",
|
| 89 |
+
"step": 0,
|
| 90 |
+
"total_steps": 0,
|
| 91 |
+
"loss": None,
|
| 92 |
+
"message": "",
|
| 93 |
+
"started_at": None,
|
| 94 |
+
"ended_at": None,
|
| 95 |
+
"log_tail": [], # last ~50 stdout lines
|
| 96 |
+
"checkpoints": [], # safetensors written so far
|
| 97 |
+
"error": None,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
# --- Public API --------------------------------------------------------
|
| 101 |
+
|
| 102 |
+
def start(self) -> Dict[str, Any]:
|
| 103 |
+
# Fresh run on this trainer β clear any stop flag from a prior run.
|
| 104 |
+
self._stop_requested = False
|
| 105 |
+
# Mark training as in-flight BEFORE any blocking work. /api/start-training
|
| 106 |
+
# can block for tens of seconds (T5Gemma sibling fetch, base-model
|
| 107 |
+
# prestaging) β during that window the frontend polls
|
| 108 |
+
# /api/training-status and would otherwise see is_training=False from
|
| 109 |
+
# the __init__ default and interpret it as "training complete".
|
| 110 |
+
self.status.update({
|
| 111 |
+
"is_training": True,
|
| 112 |
+
"status": "staging",
|
| 113 |
+
"started_at": time.time(),
|
| 114 |
+
"ended_at": None,
|
| 115 |
+
"step": 0,
|
| 116 |
+
"total_steps": int(self.config.get("steps") or DEFAULT_STEPS),
|
| 117 |
+
"loss": None,
|
| 118 |
+
"error": None,
|
| 119 |
+
"checkpoints": [],
|
| 120 |
+
# Surface the concrete seed (the backend rolls a random one when the
|
| 121 |
+
# UI requests it) so the user can reproduce a run they liked.
|
| 122 |
+
"seed": (int(self.config["seed"]) if self.config.get("seed") is not None else None),
|
| 123 |
+
"message": "Preparing dataset and base modelβ¦",
|
| 124 |
+
})
|
| 125 |
+
try:
|
| 126 |
+
self._maybe_wipe_run_dir()
|
| 127 |
+
self._resolve_paths()
|
| 128 |
+
self._stage_dataset()
|
| 129 |
+
self._stage_base_model()
|
| 130 |
+
cmd, env = self._build_invocation()
|
| 131 |
+
self._spawn(cmd, env)
|
| 132 |
+
logger.info(
|
| 133 |
+
"Training started Β· project=%s Β· base=%s Β· adapter=%s Β· "
|
| 134 |
+
"rank=%s Β· steps=%s Β· batch=%s Β· lr=%s Β· duration=%ss",
|
| 135 |
+
self.config.get("projectName"),
|
| 136 |
+
self.config.get("baseModel"),
|
| 137 |
+
self.config.get("adapterType") or DEFAULT_ADAPTER,
|
| 138 |
+
self.config.get("loraRank") or DEFAULT_RANK,
|
| 139 |
+
self.config.get("steps") or DEFAULT_STEPS,
|
| 140 |
+
self.config.get("batchSize") or DEFAULT_BATCH_SIZE,
|
| 141 |
+
self.config.get("learningRate") or DEFAULT_LR,
|
| 142 |
+
self.config.get("duration") or DEFAULT_DURATION,
|
| 143 |
+
)
|
| 144 |
+
return {"success": True, "run_dir": str(self.run_dir)}
|
| 145 |
+
except Exception as e:
|
| 146 |
+
self.status["error"] = str(e)
|
| 147 |
+
self.status["status"] = "failed"
|
| 148 |
+
self.status["is_training"] = False
|
| 149 |
+
self.status["ended_at"] = time.time()
|
| 150 |
+
logger.error("Training failed to start: %s", e)
|
| 151 |
+
return {"error": str(e)}
|
| 152 |
+
|
| 153 |
+
def get_status(self) -> Dict[str, Any]:
|
| 154 |
+
# Snapshot + add a few derived fields the frontend already reads, so
|
| 155 |
+
# the polling loop in App.js doesn't have to know about both names.
|
| 156 |
+
# SA3 is step-based; we no longer expose `current_epoch`.
|
| 157 |
+
# If the on-disk checkpoint count looks stale (run finished, glob
|
| 158 |
+
# ran with the old filter, no live files surfaced), rescan once
|
| 159 |
+
# lazily so the UI catches up without needing a backend restart.
|
| 160 |
+
if not self.status.get("checkpoints") and self.run_dir is not None:
|
| 161 |
+
ckpt_dir = self.run_dir / "checkpoints"
|
| 162 |
+
if ckpt_dir.exists() and any(ckpt_dir.glob("*.ckpt")):
|
| 163 |
+
self._scan_checkpoints()
|
| 164 |
+
s = dict(self.status)
|
| 165 |
+
total = int(s.get("total_steps") or 0)
|
| 166 |
+
step = int(s.get("step") or 0)
|
| 167 |
+
s["current_step"] = step
|
| 168 |
+
s["progress"] = int(round(100 * step / total)) if total > 0 else 0
|
| 169 |
+
s["checkpoints_saved"] = len(s.get("checkpoints") or [])
|
| 170 |
+
return s
|
| 171 |
+
|
| 172 |
+
def stop(self) -> Dict[str, Any]:
|
| 173 |
+
if not self.process or self.process.poll() is not None:
|
| 174 |
+
return {"error": "Nothing to stop β no active training run."}
|
| 175 |
+
try:
|
| 176 |
+
# Flag the stop so the monitor thread labels the exit "stopped"
|
| 177 |
+
# rather than "failed" β SIGINT doesn't yield a stable rc==-2.
|
| 178 |
+
self._stop_requested = True
|
| 179 |
+
self.process.send_signal(signal.SIGINT)
|
| 180 |
+
try:
|
| 181 |
+
self.process.wait(timeout=10)
|
| 182 |
+
except subprocess.TimeoutExpired:
|
| 183 |
+
self.process.terminate()
|
| 184 |
+
try:
|
| 185 |
+
self.process.wait(timeout=5)
|
| 186 |
+
except subprocess.TimeoutExpired:
|
| 187 |
+
self.process.kill()
|
| 188 |
+
self.status["status"] = "stopped"
|
| 189 |
+
self.status["is_training"] = False
|
| 190 |
+
self.status["ended_at"] = time.time()
|
| 191 |
+
return {"success": True}
|
| 192 |
+
except Exception as e:
|
| 193 |
+
return {"error": str(e)}
|
| 194 |
+
|
| 195 |
+
def preview_plan(self) -> Dict[str, Any]:
|
| 196 |
+
try:
|
| 197 |
+
self._resolve_paths(create_dirs=False)
|
| 198 |
+
except FileNotFoundError as e:
|
| 199 |
+
return {"error": str(e)}
|
| 200 |
+
steps = int(self.config.get("steps") or DEFAULT_STEPS)
|
| 201 |
+
ckpt_every = int(self.config.get("checkpointSteps") or DEFAULT_CHECKPOINT_STEPS)
|
| 202 |
+
ckpts = max(1, steps // max(1, ckpt_every))
|
| 203 |
+
proj_name = self.config.get("projectName") or self.config.get("project_name")
|
| 204 |
+
data_dir = str(project_path(proj_name)) if proj_name else None
|
| 205 |
+
return {
|
| 206 |
+
"model_name": self.config.get("modelName", "fragmenta-lora"),
|
| 207 |
+
"base_model": self.config.get("baseModel"),
|
| 208 |
+
"project_name": proj_name,
|
| 209 |
+
"data_dir": data_dir,
|
| 210 |
+
"save_dir": str(self.run_dir / "checkpoints") if self.run_dir else None,
|
| 211 |
+
"steps": steps,
|
| 212 |
+
"checkpoint_every": ckpt_every,
|
| 213 |
+
"expected_checkpoints": ckpts,
|
| 214 |
+
"rank": int(self.config.get("loraRank") or DEFAULT_RANK),
|
| 215 |
+
"alpha": int(self.config.get("loraAlpha") or self.config.get("loraRank") or DEFAULT_RANK),
|
| 216 |
+
"adapter_type": self.config.get("adapterType") or DEFAULT_ADAPTER,
|
| 217 |
+
"batch_size": int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE),
|
| 218 |
+
"lr": float(self.config.get("learningRate") or DEFAULT_LR),
|
| 219 |
+
"duration": float(self.config.get("duration") or DEFAULT_DURATION),
|
| 220 |
+
"precision": self.config.get("precision") or DEFAULT_PRECISION,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
# --- Internals ---------------------------------------------------------
|
| 224 |
+
|
| 225 |
+
def _resolve_paths(self, create_dirs: bool = True) -> None:
|
| 226 |
+
cfg = get_config()
|
| 227 |
+
run_name = self._safe_name(self.config.get("modelName") or "lora-run")
|
| 228 |
+
self.run_dir = cfg.get_path("models_fine_tuned") / run_name
|
| 229 |
+
# Lightning's CSVLogger writes metrics.csv under
|
| 230 |
+
# `<save_dir>/lightning_logs/version_X/metrics.csv`. We don't know X
|
| 231 |
+
# upfront, so leave this unset and let _scrape_loss_history /
|
| 232 |
+
# _scrape_csv_loss rglob for it the first time they're called.
|
| 233 |
+
self.metrics_csv = None
|
| 234 |
+
if create_dirs:
|
| 235 |
+
self.run_dir.mkdir(parents=True, exist_ok=True)
|
| 236 |
+
(self.run_dir / "checkpoints").mkdir(exist_ok=True)
|
| 237 |
+
|
| 238 |
+
@classmethod
|
| 239 |
+
def existing_run_info(cls, model_name: str) -> Optional[Dict[str, Any]]:
|
| 240 |
+
"""Look up an existing run dir for a given LoRA name. Returns a dict
|
| 241 |
+
of countable artifacts if the dir exists with content, else None.
|
| 242 |
+
|
| 243 |
+
Used by /api/start-training to refuse a same-name run unless the
|
| 244 |
+
caller explicitly opts in to overwrite. Counts only *.ckpt and
|
| 245 |
+
*.safetensors so a half-set-up dir with only a metadata file
|
| 246 |
+
doesn't trip the prompt.
|
| 247 |
+
"""
|
| 248 |
+
import shutil # noqa: F401 # ensures shutil resolves if user calls _maybe_wipe later
|
| 249 |
+
cfg = get_config()
|
| 250 |
+
run_name = cls._safe_name(model_name or "lora-run")
|
| 251 |
+
run_dir = cfg.get_path("models_fine_tuned") / run_name
|
| 252 |
+
if not run_dir.exists():
|
| 253 |
+
return None
|
| 254 |
+
ckpt_dir = run_dir / "checkpoints"
|
| 255 |
+
files = []
|
| 256 |
+
if ckpt_dir.exists():
|
| 257 |
+
for ext in ("*.safetensors", "*.ckpt"):
|
| 258 |
+
files.extend(ckpt_dir.glob(ext))
|
| 259 |
+
if not files and not (run_dir / "training.log").exists():
|
| 260 |
+
return None
|
| 261 |
+
return {
|
| 262 |
+
"run_dir": str(run_dir),
|
| 263 |
+
"run_name": run_name,
|
| 264 |
+
"checkpoint_count": len(files),
|
| 265 |
+
"has_log": (run_dir / "training.log").exists(),
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
def _maybe_wipe_run_dir(self) -> None:
|
| 269 |
+
"""Honor the `overwrite` flag β wipe the run dir before staging."""
|
| 270 |
+
if not self.config.get("overwrite"):
|
| 271 |
+
return
|
| 272 |
+
cfg = get_config()
|
| 273 |
+
run_name = self._safe_name(self.config.get("modelName") or "lora-run")
|
| 274 |
+
run_dir = cfg.get_path("models_fine_tuned") / run_name
|
| 275 |
+
if run_dir.exists():
|
| 276 |
+
import shutil
|
| 277 |
+
shutil.rmtree(run_dir)
|
| 278 |
+
logger.info("Cleared existing run dir before restart: %s", run_dir)
|
| 279 |
+
|
| 280 |
+
def _stage_dataset(self) -> None:
|
| 281 |
+
"""Resolve --data_dir from a Dataset Workbench project.
|
| 282 |
+
|
| 283 |
+
Training reads the committed `.txt` sidecars sitting next to each
|
| 284 |
+
audio file inside `<projects_dir>/<projectName>/`. The Workbench's
|
| 285 |
+
"Create Dataset" action materialised those sidecars; we don't
|
| 286 |
+
rewrite anything here.
|
| 287 |
+
"""
|
| 288 |
+
project_name = self.config.get("projectName") or self.config.get("project_name")
|
| 289 |
+
if not project_name:
|
| 290 |
+
raise FileNotFoundError(
|
| 291 |
+
"projectName is required. Pick a project in the Training "
|
| 292 |
+
"tab's Dataset picker before starting a run."
|
| 293 |
+
)
|
| 294 |
+
proj_dir = project_path(project_name)
|
| 295 |
+
if not proj_dir.exists():
|
| 296 |
+
raise FileNotFoundError(f"project not found: {project_name}")
|
| 297 |
+
|
| 298 |
+
sidecars = list(proj_dir.glob("*.txt"))
|
| 299 |
+
if not sidecars:
|
| 300 |
+
raise RuntimeError(
|
| 301 |
+
f"project β{project_name}β has no committed prompts yet β "
|
| 302 |
+
"annotate the clips and click Create Dataset, then retry."
|
| 303 |
+
)
|
| 304 |
+
# SA3's caption_metadata_fn rejects clips whose sidecar is empty,
|
| 305 |
+
# so they silently drop out of the training set. Count them upfront
|
| 306 |
+
# so the user knows what they're actually training on (and refuse
|
| 307 |
+
# to start if NONE have prompts β that would just waste GPU hours).
|
| 308 |
+
non_empty = [p for p in sidecars if p.read_text(encoding="utf-8").strip()]
|
| 309 |
+
if not non_empty:
|
| 310 |
+
raise RuntimeError(
|
| 311 |
+
f"project β{project_name}β has {len(sidecars)} clip(s) but every "
|
| 312 |
+
"sidecar is empty β SA3 will reject all of them. Annotate at "
|
| 313 |
+
"least one clip and re-commit before training."
|
| 314 |
+
)
|
| 315 |
+
blank = len(sidecars) - len(non_empty)
|
| 316 |
+
if blank > 0:
|
| 317 |
+
logger.warning(
|
| 318 |
+
"%d of %d clip(s) in project '%s' have empty prompts β "
|
| 319 |
+
"SA3 will silently drop them. Training on %d clip(s).",
|
| 320 |
+
blank, len(sidecars), project_name, len(non_empty),
|
| 321 |
+
)
|
| 322 |
+
self.status["log_tail"].append(
|
| 323 |
+
f"Warning: {blank}/{len(sidecars)} clips have empty prompts and "
|
| 324 |
+
"will be dropped by SA3's data loader."
|
| 325 |
+
)
|
| 326 |
+
self.status["log_tail"].append(
|
| 327 |
+
f"Dataset: project '{project_name}' Β· {len(non_empty)} usable clip(s) Β· {proj_dir}"
|
| 328 |
+
)
|
| 329 |
+
self._data_dir = proj_dir
|
| 330 |
+
|
| 331 |
+
# Phase 6 β opt into pre-encoded latents if a compatible .latents/
|
| 332 |
+
# cache exists. SA3's `train_lora.py --encoded_dir` then skips the
|
| 333 |
+
# autoencoder pass per step. The cache is AE-bound (same-s vs
|
| 334 |
+
# same-l) so we verify the manifest matches the picked base before
|
| 335 |
+
# using it β otherwise we'd feed the DiT mis-shaped latents.
|
| 336 |
+
self._encoded_dir: Optional[Path] = None
|
| 337 |
+
try:
|
| 338 |
+
from app.backend.data.pre_encoder import (
|
| 339 |
+
latents_dir, latents_count, latents_match_base,
|
| 340 |
+
)
|
| 341 |
+
ldir = latents_dir(project_name)
|
| 342 |
+
base_model = self.config.get("baseModel")
|
| 343 |
+
if ldir.exists() and latents_count(project_name) > 0:
|
| 344 |
+
if latents_match_base(project_name, base_model):
|
| 345 |
+
self._encoded_dir = ldir
|
| 346 |
+
self.status["log_tail"].append(
|
| 347 |
+
f"Using pre-encoded latents: {latents_count(project_name)} "
|
| 348 |
+
f"file(s) Β· {ldir}"
|
| 349 |
+
)
|
| 350 |
+
logger.info(
|
| 351 |
+
"Pre-encoded latents detected for project '%s' (%d files) β "
|
| 352 |
+
"skipping SAME autoencoder per step.",
|
| 353 |
+
project_name, latents_count(project_name),
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
logger.warning(
|
| 357 |
+
"Pre-encoded latents exist for project '%s' but were "
|
| 358 |
+
"produced by a different autoencoder than the chosen "
|
| 359 |
+
"base (%s) β falling back to live encoding.",
|
| 360 |
+
project_name, base_model,
|
| 361 |
+
)
|
| 362 |
+
self.status["log_tail"].append(
|
| 363 |
+
f"Note: project has cached latents but they're for a "
|
| 364 |
+
f"different autoencoder than {base_model}. Training "
|
| 365 |
+
"will re-encode audio per step."
|
| 366 |
+
)
|
| 367 |
+
except Exception as exc:
|
| 368 |
+
logger.warning("Pre-encoded latents probe failed: %s", exc)
|
| 369 |
+
|
| 370 |
+
def _stage_base_model(self) -> None:
|
| 371 |
+
cfg = get_config()
|
| 372 |
+
base_model = self.config.get("baseModel")
|
| 373 |
+
if base_model not in SA3_BASE_MODELS:
|
| 374 |
+
raise ValueError(
|
| 375 |
+
f"baseModel must be one of {list(SA3_BASE_MODELS)}. "
|
| 376 |
+
"Post-trained checkpoints (no -base suffix) can't be used "
|
| 377 |
+
"as a LoRA training base β CFG distillation has collapsed "
|
| 378 |
+
"the gradient signal LoRAs target."
|
| 379 |
+
)
|
| 380 |
+
hub_dir = cfg.get_path("models_pretrained") / "sa3" / "hub"
|
| 381 |
+
try:
|
| 382 |
+
from huggingface_hub import get_token
|
| 383 |
+
token = get_token()
|
| 384 |
+
except Exception:
|
| 385 |
+
token = None
|
| 386 |
+
|
| 387 |
+
def _cb(pct: int, msg: str) -> None:
|
| 388 |
+
self.status["message"] = msg
|
| 389 |
+
self.status["log_tail"].append(f"[stage] {msg}")
|
| 390 |
+
# Mirror to the project logger so the terminal shows what's
|
| 391 |
+
# happening during long blocking operations (e.g. first-time
|
| 392 |
+
# T5Gemma sibling fetch can take ~30s on medium-base).
|
| 393 |
+
logger.info("[stage] %s", msg)
|
| 394 |
+
|
| 395 |
+
prestage_base_model(base_model, hub_dir, token=token, progress_callback=_cb)
|
| 396 |
+
self._hub_dir = hub_dir
|
| 397 |
+
|
| 398 |
+
def _build_invocation(self):
|
| 399 |
+
cfg = get_config()
|
| 400 |
+
sa3_vendor = cfg.get_path("stable_audio_3")
|
| 401 |
+
sa3_name, _repo = SA3_BASE_MODELS[self.config["baseModel"]]
|
| 402 |
+
|
| 403 |
+
# Use the Fragmenta venv's python so we share installed packages.
|
| 404 |
+
venv_python = sys.executable
|
| 405 |
+
|
| 406 |
+
precision_raw = (self.config.get("precision") or DEFAULT_PRECISION).lower()
|
| 407 |
+
precision = "bf16" if precision_raw in ("bf16", "bfloat16", "auto", "") else "fp16"
|
| 408 |
+
|
| 409 |
+
include = self.config.get("include")
|
| 410 |
+
if include and isinstance(include, str):
|
| 411 |
+
include = shlex.split(include)
|
| 412 |
+
exclude = self.config.get("exclude")
|
| 413 |
+
if exclude and isinstance(exclude, str):
|
| 414 |
+
exclude = shlex.split(exclude)
|
| 415 |
+
|
| 416 |
+
adapter_type = self.config.get("adapterType") or DEFAULT_ADAPTER
|
| 417 |
+
|
| 418 |
+
# -XS adapters can reuse a precomputed SVD-bases cache keyed by base
|
| 419 |
+
# model, skipping the per-layer SVD at startup. SA3 only loads (never
|
| 420 |
+
# writes) this file, so we pass it only when present; population is a
|
| 421 |
+
# manual/precompute step. Ensure the dir exists so it's discoverable.
|
| 422 |
+
svd_bases_path = None
|
| 423 |
+
if adapter_type.endswith("-xs"):
|
| 424 |
+
svd_cache_dir = get_config().get_path("models_fine_tuned") / ".svd_cache"
|
| 425 |
+
svd_cache_dir.mkdir(parents=True, exist_ok=True)
|
| 426 |
+
candidate = svd_cache_dir / f"{self.config['baseModel']}.pt"
|
| 427 |
+
if candidate.exists():
|
| 428 |
+
svd_bases_path = candidate
|
| 429 |
+
|
| 430 |
+
cmd = build_train_command(
|
| 431 |
+
venv_python=venv_python,
|
| 432 |
+
sa3_vendor_dir=sa3_vendor,
|
| 433 |
+
sa3_model_name=sa3_name,
|
| 434 |
+
data_dir=self._data_dir,
|
| 435 |
+
encoded_dir=getattr(self, "_encoded_dir", None),
|
| 436 |
+
svd_bases_path=svd_bases_path,
|
| 437 |
+
save_dir=self.run_dir / "checkpoints",
|
| 438 |
+
rank=int(self.config.get("loraRank") or DEFAULT_RANK),
|
| 439 |
+
lora_alpha=self.config.get("loraAlpha"),
|
| 440 |
+
adapter_type=adapter_type,
|
| 441 |
+
dropout=float(self.config.get("loraDropout") or 0.0),
|
| 442 |
+
lr=float(self.config.get("learningRate") or DEFAULT_LR),
|
| 443 |
+
steps=int(self.config.get("steps") or DEFAULT_STEPS),
|
| 444 |
+
batch_size=int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE),
|
| 445 |
+
# Default to AND clamp at the base model's native training length
|
| 446 |
+
# (medium β380s, small β120s) β SA3's DiT tops out at 4096 latent
|
| 447 |
+
# tokens, so a longer window would exceed the model, not just cost
|
| 448 |
+
# VRAM. A missing duration defaults to the model max.
|
| 449 |
+
duration=min(
|
| 450 |
+
float(self.config.get("duration") or (380.0 if "medium" in sa3_name else 120.0)),
|
| 451 |
+
380.0 if "medium" in sa3_name else 120.0,
|
| 452 |
+
),
|
| 453 |
+
base_precision=precision,
|
| 454 |
+
include=include,
|
| 455 |
+
exclude=exclude,
|
| 456 |
+
seed=(int(self.config["seed"]) if self.config.get("seed") is not None else 42),
|
| 457 |
+
checkpoint_every=int(self.config.get("checkpointSteps") or DEFAULT_CHECKPOINT_STEPS),
|
| 458 |
+
name=self.config.get("modelName") or "fragmenta-lora",
|
| 459 |
+
)
|
| 460 |
+
env = build_train_env(sa3_vendor, self._hub_dir)
|
| 461 |
+
return cmd, env
|
| 462 |
+
|
| 463 |
+
def _spawn(self, cmd: List[str], env: Dict[str, str]) -> None:
|
| 464 |
+
log_path = self.run_dir / "training.log"
|
| 465 |
+
rank = int(self.config.get("loraRank") or DEFAULT_RANK)
|
| 466 |
+
alpha_cfg = self.config.get("loraAlpha")
|
| 467 |
+
alpha = int(alpha_cfg) if alpha_cfg not in (None, "") else rank
|
| 468 |
+
# Stamp training_metadata.json so /api/loras can find the base_model
|
| 469 |
+
# if the embedded safetensors metadata is missing it (legacy paths).
|
| 470 |
+
(self.run_dir / "training_metadata.json").write_text(json.dumps({
|
| 471 |
+
"mode": "lora",
|
| 472 |
+
"engine": "sa3",
|
| 473 |
+
"base_model": self.config.get("baseModel"),
|
| 474 |
+
"model_name": self.config.get("modelName"),
|
| 475 |
+
"started_at": time.time(),
|
| 476 |
+
"lora_config": {
|
| 477 |
+
"rank": rank,
|
| 478 |
+
"alpha": alpha,
|
| 479 |
+
"adapter_type": self.config.get("adapterType") or DEFAULT_ADAPTER,
|
| 480 |
+
"dropout": float(self.config.get("loraDropout") or 0.0),
|
| 481 |
+
},
|
| 482 |
+
"steps": int(self.config.get("steps") or DEFAULT_STEPS),
|
| 483 |
+
"lr": float(self.config.get("learningRate") or DEFAULT_LR),
|
| 484 |
+
"batch_size": int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE),
|
| 485 |
+
}, indent=2))
|
| 486 |
+
|
| 487 |
+
self.status.update({
|
| 488 |
+
"is_training": True,
|
| 489 |
+
"status": "running",
|
| 490 |
+
"step": 0,
|
| 491 |
+
"total_steps": int(self.config.get("steps") or DEFAULT_STEPS),
|
| 492 |
+
"loss": None,
|
| 493 |
+
"error": None,
|
| 494 |
+
"started_at": time.time(),
|
| 495 |
+
"ended_at": None,
|
| 496 |
+
"checkpoints": [],
|
| 497 |
+
"message": "Starting training subprocess...",
|
| 498 |
+
})
|
| 499 |
+
|
| 500 |
+
self.process = subprocess.Popen(
|
| 501 |
+
cmd,
|
| 502 |
+
cwd=str(get_config().project_root),
|
| 503 |
+
env=env,
|
| 504 |
+
stdout=subprocess.PIPE,
|
| 505 |
+
stderr=subprocess.STDOUT,
|
| 506 |
+
text=True,
|
| 507 |
+
bufsize=1,
|
| 508 |
+
)
|
| 509 |
+
self._monitor_thread = threading.Thread(
|
| 510 |
+
target=self._monitor,
|
| 511 |
+
args=(log_path,),
|
| 512 |
+
daemon=True,
|
| 513 |
+
name=f"sa3-train-monitor:{self.run_dir.name}",
|
| 514 |
+
)
|
| 515 |
+
self._monitor_thread.start()
|
| 516 |
+
|
| 517 |
+
def _monitor(self, log_path: Path) -> None:
|
| 518 |
+
"""Pull stdout, parse PyTorch Lightning progress, scrape loss, watch checkpoints.
|
| 519 |
+
|
| 520 |
+
SA3 trains via PL whose default progress bar emits *per-epoch* step
|
| 521 |
+
counts ("Epoch 6: 50%|...| 25/50 [00:07<00:07, 3.36it/s, train/loss=0.559]").
|
| 522 |
+
We derive the global step as `epoch * batches_per_epoch + step_in_epoch`,
|
| 523 |
+
capture `batches_per_epoch` from the first such line (it's stable across
|
| 524 |
+
epochs since SampleDataset returns a fixed length), and clamp the
|
| 525 |
+
result to the configured max_steps so the percentage doesn't go past
|
| 526 |
+
100 if the final epoch overruns.
|
| 527 |
+
"""
|
| 528 |
+
epoch_pat = re.compile(r"Epoch\s+(\d+):")
|
| 529 |
+
in_epoch_pat = re.compile(r"\|\s*(\d+)/(\d+)\b") # tqdm's "current/total"
|
| 530 |
+
loss_pat = re.compile(r"train/loss=([\d.eE+\-]+)")
|
| 531 |
+
speed_pat = re.compile(r"([\d.]+)it/s")
|
| 532 |
+
last_log_flush = time.time()
|
| 533 |
+
last_ckpt_scan = 0.0
|
| 534 |
+
last_terminal_log = 0.0
|
| 535 |
+
last_logged_step = -1
|
| 536 |
+
prev_ckpt_count = 0
|
| 537 |
+
current_epoch = 0
|
| 538 |
+
batches_per_epoch = 0
|
| 539 |
+
try:
|
| 540 |
+
with open(log_path, "w") as logf:
|
| 541 |
+
if self.process and self.process.stdout:
|
| 542 |
+
for line in self.process.stdout:
|
| 543 |
+
line = line.rstrip()
|
| 544 |
+
logf.write(line + "\n")
|
| 545 |
+
if time.time() - last_log_flush > 1:
|
| 546 |
+
logf.flush()
|
| 547 |
+
last_log_flush = time.time()
|
| 548 |
+
self.status["log_tail"].append(line)
|
| 549 |
+
if len(self.status["log_tail"]) > 80:
|
| 550 |
+
self.status["log_tail"] = self.status["log_tail"][-50:]
|
| 551 |
+
|
| 552 |
+
# Only parse the step counter on lines that ARE the
|
| 553 |
+
# training progress bar (prefixed with "Epoch N:"),
|
| 554 |
+
# so unrelated tqdm bars during startup (e.g.
|
| 555 |
+
# "Loading checkpoint shards: 9/9") don't pollute
|
| 556 |
+
# batches_per_epoch.
|
| 557 |
+
m_epoch = epoch_pat.search(line)
|
| 558 |
+
if m_epoch:
|
| 559 |
+
current_epoch = int(m_epoch.group(1))
|
| 560 |
+
m_step = in_epoch_pat.search(line)
|
| 561 |
+
if m_step:
|
| 562 |
+
cur_in_epoch = int(m_step.group(1))
|
| 563 |
+
per_epoch = int(m_step.group(2))
|
| 564 |
+
if per_epoch > 0 and batches_per_epoch == 0:
|
| 565 |
+
batches_per_epoch = per_epoch
|
| 566 |
+
if batches_per_epoch > 0:
|
| 567 |
+
global_step = current_epoch * batches_per_epoch + cur_in_epoch
|
| 568 |
+
max_steps = self.status.get("total_steps") or 0
|
| 569 |
+
if max_steps > 0:
|
| 570 |
+
global_step = min(global_step, max_steps)
|
| 571 |
+
if global_step > self.status.get("step", 0):
|
| 572 |
+
self.status["step"] = global_step
|
| 573 |
+
|
| 574 |
+
m_loss = loss_pat.search(line)
|
| 575 |
+
if m_loss:
|
| 576 |
+
try:
|
| 577 |
+
self.status["loss"] = float(m_loss.group(1))
|
| 578 |
+
except ValueError:
|
| 579 |
+
pass
|
| 580 |
+
|
| 581 |
+
# Live checkpoint enumeration + loss history scrape.
|
| 582 |
+
# Lightning writes *.ckpt every N steps; we want the
|
| 583 |
+
# count to climb in the UI as files appear, not only
|
| 584 |
+
# at end-of-run. Bucketed to ~2s so we don't pound
|
| 585 |
+
# the FS. The loss history scrape backfills step
|
| 586 |
+
# 0..49 from metrics.csv since PL's stdout postfix
|
| 587 |
+
# doesn't show train/loss until end-of-epoch-0.
|
| 588 |
+
now = time.time()
|
| 589 |
+
if now - last_ckpt_scan > 2.0:
|
| 590 |
+
last_ckpt_scan = now
|
| 591 |
+
self._scan_checkpoints()
|
| 592 |
+
self._scrape_loss_history()
|
| 593 |
+
cur_ckpt_count = len(self.status.get("checkpoints") or [])
|
| 594 |
+
if cur_ckpt_count > prev_ckpt_count:
|
| 595 |
+
logger.info(
|
| 596 |
+
"Checkpoint saved Β· %d total Β· run=%s",
|
| 597 |
+
cur_ckpt_count, self.run_dir.name,
|
| 598 |
+
)
|
| 599 |
+
prev_ckpt_count = cur_ckpt_count
|
| 600 |
+
|
| 601 |
+
# Throttled progress to the backend terminal log.
|
| 602 |
+
# Lightning emits step lines ~3Γ per second; we
|
| 603 |
+
# condense to one tidy summary every 5s. Omit the
|
| 604 |
+
# loss segment when we don't have a value yet (the
|
| 605 |
+
# CSV scrape runs every 2s but PL may not have
|
| 606 |
+
# logged anything during the very first second).
|
| 607 |
+
cur_step = self.status.get("step") or 0
|
| 608 |
+
if (cur_step > last_logged_step
|
| 609 |
+
and now - last_terminal_log >= 5.0):
|
| 610 |
+
total = self.status.get("total_steps") or 0
|
| 611 |
+
loss = self.status.get("loss")
|
| 612 |
+
pct = round(100 * cur_step / total) if total > 0 else 0
|
| 613 |
+
speed_m = speed_pat.search(line)
|
| 614 |
+
parts = [f"step {cur_step}/{total} ({pct}%)"]
|
| 615 |
+
if isinstance(loss, (int, float)):
|
| 616 |
+
parts.append(f"loss {loss:.4f}")
|
| 617 |
+
if speed_m:
|
| 618 |
+
parts.append(f"{speed_m.group(1)} it/s")
|
| 619 |
+
logger.info(" Β· ".join(parts))
|
| 620 |
+
last_terminal_log = now
|
| 621 |
+
last_logged_step = cur_step
|
| 622 |
+
rc = self.process.wait() if self.process else 1
|
| 623 |
+
except Exception as e:
|
| 624 |
+
self.status["error"] = str(e)
|
| 625 |
+
rc = -1
|
| 626 |
+
|
| 627 |
+
self.status["ended_at"] = time.time()
|
| 628 |
+
self.status["is_training"] = False
|
| 629 |
+
# A user-requested stop wins regardless of the exit code (SIGINT can
|
| 630 |
+
# surface as various negative/non-zero codes across platforms).
|
| 631 |
+
if getattr(self, "_stop_requested", False):
|
| 632 |
+
self.status["status"] = "stopped"
|
| 633 |
+
else:
|
| 634 |
+
self.status["status"] = "complete" if rc == 0 else "failed"
|
| 635 |
+
if self.status["status"] == "failed" and not self.status.get("error"):
|
| 636 |
+
self.status["error"] = f"train_lora.py exited with code {rc}"
|
| 637 |
+
|
| 638 |
+
# Convert PyTorch Lightning .ckpt files to SA3's native .safetensors
|
| 639 |
+
# LoRA format β the inference loader (/api/loras) only sees
|
| 640 |
+
# .safetensors, so unconverted .ckpt files would be functionally
|
| 641 |
+
# orphaned. We also inject `base_model` into the safetensors header
|
| 642 |
+
# so /api/loras' metadata filter passes without a JSON fallback.
|
| 643 |
+
# Best-effort: failure here doesn't fail the run.
|
| 644 |
+
if self.status["status"] in ("complete", "stopped") and self.run_dir:
|
| 645 |
+
try:
|
| 646 |
+
produced = convert_run_checkpoints_to_safetensors(
|
| 647 |
+
self.run_dir,
|
| 648 |
+
base_model=self.config.get("baseModel"),
|
| 649 |
+
model_name=self.config.get("modelName"),
|
| 650 |
+
)
|
| 651 |
+
if produced:
|
| 652 |
+
logger.info(
|
| 653 |
+
"Converted %d checkpoint(s) to .safetensors Β· run=%s",
|
| 654 |
+
len(produced), self.run_dir.name,
|
| 655 |
+
)
|
| 656 |
+
except Exception as exc:
|
| 657 |
+
logger.warning("Checkpoint conversion failed: %s", exc)
|
| 658 |
+
|
| 659 |
+
# Final pass: enumerate written checkpoints + full loss history +
|
| 660 |
+
# latest single-value loss.
|
| 661 |
+
self._scan_checkpoints()
|
| 662 |
+
self._scrape_loss_history()
|
| 663 |
+
self._scrape_csv_loss()
|
| 664 |
+
|
| 665 |
+
final_step = self.status.get("step") or 0
|
| 666 |
+
final_total = self.status.get("total_steps") or 0
|
| 667 |
+
final_loss = self.status.get("loss")
|
| 668 |
+
final_ckpts = len(self.status.get("checkpoints") or [])
|
| 669 |
+
loss_str = f"{final_loss:.4f}" if isinstance(final_loss, (int, float)) else "β"
|
| 670 |
+
if self.status["status"] == "complete":
|
| 671 |
+
logger.info(
|
| 672 |
+
"Training complete Β· %d/%d steps Β· final loss %s Β· %d checkpoint(s) Β· run=%s",
|
| 673 |
+
final_step, final_total, loss_str, final_ckpts, self.run_dir.name,
|
| 674 |
+
)
|
| 675 |
+
elif self.status["status"] == "stopped":
|
| 676 |
+
logger.info(
|
| 677 |
+
"Training stopped at step %d/%d Β· %d checkpoint(s) Β· run=%s",
|
| 678 |
+
final_step, final_total, final_ckpts, self.run_dir.name,
|
| 679 |
+
)
|
| 680 |
+
else:
|
| 681 |
+
logger.error(
|
| 682 |
+
"Training failed (exit %s) Β· %d/%d steps Β· error: %s Β· run=%s",
|
| 683 |
+
rc, final_step, final_total, self.status.get("error"), self.run_dir.name,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
def _scrape_loss_history(self) -> None:
|
| 687 |
+
"""Refresh self.status['loss_history'] from Lightning's metrics.csv.
|
| 688 |
+
|
| 689 |
+
PL's tqdm postfix only surfaces `train/loss=` *after* the first
|
| 690 |
+
metrics flush (typically end-of-epoch-0), so step 0..49 of a fresh
|
| 691 |
+
run never appear in stdout. metrics.csv, on the other hand, has
|
| 692 |
+
per-step rows from step 0 β we just need to read it.
|
| 693 |
+
|
| 694 |
+
Cheap: even at 10K steps a CSV scan is sub-10ms. Skipped silently
|
| 695 |
+
if the file hasn't been created yet (early in the run, before PL's
|
| 696 |
+
CSVLogger flushes anything).
|
| 697 |
+
"""
|
| 698 |
+
if not self.metrics_csv or not self.metrics_csv.exists():
|
| 699 |
+
# CSVLogger writes under <save_dir>/lightning_logs/version_*/
|
| 700 |
+
if self.run_dir:
|
| 701 |
+
for p in (self.run_dir / "checkpoints").rglob("metrics.csv"):
|
| 702 |
+
self.metrics_csv = p
|
| 703 |
+
break
|
| 704 |
+
if not self.metrics_csv or not self.metrics_csv.exists():
|
| 705 |
+
return
|
| 706 |
+
try:
|
| 707 |
+
with open(self.metrics_csv) as f:
|
| 708 |
+
rows = list(csv.DictReader(f))
|
| 709 |
+
except Exception:
|
| 710 |
+
return
|
| 711 |
+
points: List[Dict[str, Any]] = []
|
| 712 |
+
loss_keys = ("train/loss", "loss", "train_loss")
|
| 713 |
+
for row in rows:
|
| 714 |
+
step_raw = row.get("step")
|
| 715 |
+
if step_raw in (None, ""):
|
| 716 |
+
continue
|
| 717 |
+
try:
|
| 718 |
+
step = int(step_raw)
|
| 719 |
+
except ValueError:
|
| 720 |
+
continue
|
| 721 |
+
for k in loss_keys:
|
| 722 |
+
v = row.get(k)
|
| 723 |
+
if v not in (None, ""):
|
| 724 |
+
try:
|
| 725 |
+
points.append({"step": step, "loss": float(v)})
|
| 726 |
+
except ValueError:
|
| 727 |
+
pass
|
| 728 |
+
break
|
| 729 |
+
# Dedupe: csv can have multiple rows per step (different metric flush
|
| 730 |
+
# boundaries) β keep the last loss seen for each step.
|
| 731 |
+
by_step: Dict[int, float] = {}
|
| 732 |
+
for p in points:
|
| 733 |
+
by_step[p["step"]] = p["loss"]
|
| 734 |
+
ordered = sorted(by_step.items())
|
| 735 |
+
self.status["loss_history"] = [{"step": s, "loss": l} for s, l in ordered]
|
| 736 |
+
# Also surface the most recent loss as the scalar so the terminal
|
| 737 |
+
# log and "Current Loss" field don't show "β" until end-of-epoch-0.
|
| 738 |
+
# PL's tqdm postfix is async; the CSV row lands a beat ahead.
|
| 739 |
+
if ordered:
|
| 740 |
+
self.status["loss"] = ordered[-1][1]
|
| 741 |
+
|
| 742 |
+
def _scan_checkpoints(self) -> None:
|
| 743 |
+
"""Update self.status['checkpoints'] from on-disk artifacts.
|
| 744 |
+
|
| 745 |
+
SA3's train_lora.py uses PyTorch Lightning's ModelCheckpoint, which
|
| 746 |
+
writes `.ckpt` files (Lightning pickle format). The diffusion wrapper's
|
| 747 |
+
`on_save_checkpoint` hook strips the state_dict to LoRA-only weights
|
| 748 |
+
plus the embedded `lora_config`, so each .ckpt IS a LoRA checkpoint.
|
| 749 |
+
We also accept .safetensors for forward-compat with a future export
|
| 750 |
+
path or manual conversion.
|
| 751 |
+
"""
|
| 752 |
+
if not self.run_dir:
|
| 753 |
+
return
|
| 754 |
+
ckpt_dir = self.run_dir / "checkpoints"
|
| 755 |
+
if not ckpt_dir.exists():
|
| 756 |
+
return
|
| 757 |
+
found = []
|
| 758 |
+
for ext in ("*.safetensors", "*.ckpt"):
|
| 759 |
+
found.extend(ckpt_dir.glob(ext))
|
| 760 |
+
# Lightning writes nested lightning_logs/version_X/* β those aren't
|
| 761 |
+
# the user-facing artifacts; skip recursion.
|
| 762 |
+
project_root = get_config().project_root
|
| 763 |
+
self.status["checkpoints"] = sorted(
|
| 764 |
+
str(p.relative_to(project_root)) for p in found
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
def _scrape_csv_loss(self) -> None:
|
| 768 |
+
if not self.metrics_csv or not self.metrics_csv.exists():
|
| 769 |
+
# train_lora.py writes its CSV under the lightning logger dir,
|
| 770 |
+
# which is `<save_dir>/<name>/version_*/metrics.csv`. Walk to
|
| 771 |
+
# find it.
|
| 772 |
+
ckpt_dir = self.run_dir / "checkpoints"
|
| 773 |
+
for p in ckpt_dir.rglob("metrics.csv"):
|
| 774 |
+
self.metrics_csv = p
|
| 775 |
+
break
|
| 776 |
+
if not self.metrics_csv or not self.metrics_csv.exists():
|
| 777 |
+
return
|
| 778 |
+
try:
|
| 779 |
+
with open(self.metrics_csv) as f:
|
| 780 |
+
rows = list(csv.DictReader(f))
|
| 781 |
+
for row in reversed(rows):
|
| 782 |
+
for k in ("train/loss", "loss", "train_loss"):
|
| 783 |
+
v = row.get(k)
|
| 784 |
+
if v not in (None, ""):
|
| 785 |
+
try:
|
| 786 |
+
self.status["loss"] = float(v)
|
| 787 |
+
return
|
| 788 |
+
except ValueError:
|
| 789 |
+
pass
|
| 790 |
+
except Exception:
|
| 791 |
+
pass
|
| 792 |
+
|
| 793 |
+
@staticmethod
|
| 794 |
+
def _safe_name(s: str) -> str:
|
| 795 |
+
return re.sub(r"[^a-zA-Z0-9_-]+", "_", s).strip("_") or "lora-run"
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
# --- Module-level singleton + public functions -----------------------------
|
| 799 |
+
|
| 800 |
+
_active: Optional[SA3Trainer] = None
|
| 801 |
+
_lock = threading.Lock()
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
def get_trainer() -> Optional[SA3Trainer]:
|
| 805 |
+
return _active
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def start_training(config: Dict[str, Any]) -> Dict[str, Any]:
|
| 809 |
+
global _active
|
| 810 |
+
with _lock:
|
| 811 |
+
if _active and _active.status.get("is_training"):
|
| 812 |
+
return {"error": "A training run is already in progress."}
|
| 813 |
+
_active = SA3Trainer(config)
|
| 814 |
+
return _active.start()
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
def get_training_status() -> Dict[str, Any]:
|
| 818 |
+
if _active is None:
|
| 819 |
+
return {
|
| 820 |
+
"is_training": False,
|
| 821 |
+
"status": "idle",
|
| 822 |
+
"message": "No training run has been started yet.",
|
| 823 |
+
"progress": 0,
|
| 824 |
+
"current_step": 0,
|
| 825 |
+
"total_steps": 0,
|
| 826 |
+
"checkpoints_saved": 0,
|
| 827 |
+
"loss": None,
|
| 828 |
+
}
|
| 829 |
+
return _active.get_status()
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
def stop_training() -> Dict[str, Any]:
|
| 833 |
+
if _active is None:
|
| 834 |
+
return {"error": "No training run to stop."}
|
| 835 |
+
return _active.stop()
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def preview_training_plan(config: Dict[str, Any]) -> Dict[str, Any]:
|
| 839 |
+
return SA3Trainer(config).preview_plan()
|
app/frontend/index.html
CHANGED
|
@@ -7,15 +7,38 @@
|
|
| 7 |
<meta name="theme-color" content="#000000" />
|
| 8 |
<meta
|
| 9 |
name="description"
|
| 10 |
-
content="Fragmenta
|
| 11 |
/>
|
| 12 |
<link rel="manifest" href="/manifest.json" />
|
| 13 |
|
| 14 |
-
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 15 |
-
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 16 |
-
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500;600&family=Space+Mono:wght@400;700&family=IBM+Plex+Mono:wght@300;400;500;600&display=swap" rel="stylesheet">
|
| 17 |
-
|
| 18 |
<style>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
@font-face {
|
| 20 |
font-family: 'Bitcount Single';
|
| 21 |
src: url('/BitcountSingle-VariableFont_CRSV,ELSH,ELXP,slnt,wght.ttf') format('truetype');
|
|
@@ -25,7 +48,7 @@
|
|
| 25 |
}
|
| 26 |
</style>
|
| 27 |
|
| 28 |
-
<title>Fragmenta
|
| 29 |
</head>
|
| 30 |
<body>
|
| 31 |
<noscript>You need to enable JavaScript to run this app.</noscript>
|
|
|
|
| 7 |
<meta name="theme-color" content="#000000" />
|
| 8 |
<meta
|
| 9 |
name="description"
|
| 10 |
+
content="Fragmenta β Stable Audio Fine-Tuning Application"
|
| 11 |
/>
|
| 12 |
<link rel="manifest" href="/manifest.json" />
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
<style>
|
| 15 |
+
/* Layout floor β channel grid + master strip need 1300px
|
| 16 |
+
horizontally to sit side-by-side, and the vertical layout (top
|
| 17 |
+
bar + channels + bottom bar) gets cramped below 830px. Below
|
| 18 |
+
either floor, scrollbars appear so the layout stays intact
|
| 19 |
+
instead of collapsing. The launcher (start.py) opens Chromium
|
| 20 |
+
at 1300Γ830 so the fresh-launch experience lands exactly at
|
| 21 |
+
the floor. */
|
| 22 |
+
html, body {
|
| 23 |
+
min-width: 1300px;
|
| 24 |
+
min-height: 830px;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
/* Local variable fonts β ship with the app, no network dependency. */
|
| 28 |
+
@font-face {
|
| 29 |
+
font-family: 'Bricolage Grotesque';
|
| 30 |
+
src: url('/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf') format('truetype');
|
| 31 |
+
font-weight: 200 800;
|
| 32 |
+
font-style: normal;
|
| 33 |
+
font-display: swap;
|
| 34 |
+
}
|
| 35 |
+
@font-face {
|
| 36 |
+
font-family: 'Inter Tight';
|
| 37 |
+
src: url('/InterTight-VariableFont_wght.ttf') format('truetype');
|
| 38 |
+
font-weight: 100 900;
|
| 39 |
+
font-style: normal;
|
| 40 |
+
font-display: swap;
|
| 41 |
+
}
|
| 42 |
@font-face {
|
| 43 |
font-family: 'Bitcount Single';
|
| 44 |
src: url('/BitcountSingle-VariableFont_CRSV,ELSH,ELXP,slnt,wght.ttf') format('truetype');
|
|
|
|
| 48 |
}
|
| 49 |
</style>
|
| 50 |
|
| 51 |
+
<title>Fragmenta</title>
|
| 52 |
</head>
|
| 53 |
<body>
|
| 54 |
<noscript>You need to enable JavaScript to run this app.</noscript>
|
app/frontend/logs/fragmenta_20260525.log
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-05-25 11:21:33 | [92mINFO[0m | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
|
| 2 |
+
2026-05-25 11:21:33 | [92mINFO[0m | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
|
| 3 |
+
2026-05-25 11:44:54 | [92mINFO[0m | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
|
| 4 |
+
2026-05-25 11:44:54 | [92mINFO[0m | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
|
| 5 |
+
2026-05-25 13:55:04 | [92mINFO[0m | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
|
| 6 |
+
2026-05-25 13:55:04 | [92mINFO[0m | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
|
| 7 |
+
2026-05-25 13:55:05 | [92mINFO[0m | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
|
| 8 |
+
2026-05-25 13:55:05 | [92mINFO[0m | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
|
app/frontend/package.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
{
|
| 2 |
"name": "fragmenta-desktop",
|
| 3 |
-
"version": "0.
|
| 4 |
-
"description": "Fragmenta
|
| 5 |
"type": "module",
|
| 6 |
"scripts": {
|
| 7 |
"dev": "vite",
|
|
|
|
| 1 |
{
|
| 2 |
"name": "fragmenta-desktop",
|
| 3 |
+
"version": "0.2.0",
|
| 4 |
+
"description": "Fragmenta",
|
| 5 |
"type": "module",
|
| 6 |
"scripts": {
|
| 7 |
"dev": "vite",
|
app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31b91d15aae398699fae58363dbc8ca1167faffe7d2cd62e68c716dcaa7d5fdd
|
| 3 |
+
size 407844
|
app/frontend/public/InterTight-VariableFont_wght.ttf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b8ef9ed255ebe7341aa566554c0f3e87ee10ce06d2085f07ccf66f41ef96c28
|
| 3 |
+
size 580572
|
app/frontend/public/fragmenta_background.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
app/frontend/public/interface.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
app/frontend/src/App.js
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/frontend/src/api.js
CHANGED
|
@@ -37,6 +37,7 @@ const api = {
|
|
| 37 |
get: (url, config) => request('GET', url, null, config),
|
| 38 |
post: (url, body, config) => request('POST', url, body, config),
|
| 39 |
put: (url, body, config) => request('PUT', url, body, config),
|
|
|
|
| 40 |
delete: (url, config) => request('DELETE', url, null, config),
|
| 41 |
};
|
| 42 |
|
|
|
|
| 37 |
get: (url, config) => request('GET', url, null, config),
|
| 38 |
post: (url, body, config) => request('POST', url, body, config),
|
| 39 |
put: (url, body, config) => request('PUT', url, body, config),
|
| 40 |
+
patch: (url, body, config) => request('PATCH', url, body, config),
|
| 41 |
delete: (url, config) => request('DELETE', url, null, config),
|
| 42 |
};
|
| 43 |
|
app/frontend/src/components/AboutDialog.js
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Box,
|
| 4 |
+
Button,
|
| 5 |
+
Dialog,
|
| 6 |
+
DialogActions,
|
| 7 |
+
DialogContent,
|
| 8 |
+
DialogTitle,
|
| 9 |
+
Typography,
|
| 10 |
+
} from '@mui/material';
|
| 11 |
+
import {
|
| 12 |
+
Info as InfoIcon,
|
| 13 |
+
BookOpen as BookOpenIcon,
|
| 14 |
+
} from 'lucide-react';
|
| 15 |
+
import { appStyles } from '../theme';
|
| 16 |
+
import { APP_VERSION } from '../version';
|
| 17 |
+
|
| 18 |
+
/**
|
| 19 |
+
* "About Fragmenta" dialog β logo + title, short intro, three doc buttons
|
| 20 |
+
* (About / Documentation / Tutorials), and the Stability AI Community
|
| 21 |
+
* License attribution footer.
|
| 22 |
+
*
|
| 23 |
+
* Props:
|
| 24 |
+
* open: bool
|
| 25 |
+
* onClose: () => void
|
| 26 |
+
* onOpenDocumentation: ('about' | 'documentation') => void
|
| 27 |
+
* isOpeningDocumentation: bool β disables the doc buttons while a
|
| 28 |
+
* native open-file call is in flight
|
| 29 |
+
*/
|
| 30 |
+
export default function AboutDialog({
|
| 31 |
+
open,
|
| 32 |
+
onClose,
|
| 33 |
+
onOpenDocumentation,
|
| 34 |
+
isOpeningDocumentation,
|
| 35 |
+
}) {
|
| 36 |
+
return (
|
| 37 |
+
<Dialog
|
| 38 |
+
open={open}
|
| 39 |
+
onClose={onClose}
|
| 40 |
+
aria-labelledby="about-documentation-dialog-title"
|
| 41 |
+
maxWidth="sm"
|
| 42 |
+
fullWidth
|
| 43 |
+
>
|
| 44 |
+
<DialogTitle id="about-documentation-dialog-title">
|
| 45 |
+
<Box sx={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: 1 }}>
|
| 46 |
+
<Box sx={{
|
| 47 |
+
...appStyles.logo,
|
| 48 |
+
width: 52, height: 52,
|
| 49 |
+
border: 'none',
|
| 50 |
+
boxShadow: 'none',
|
| 51 |
+
filter: 'none',
|
| 52 |
+
}} />
|
| 53 |
+
<Typography variant="h5" component="span" sx={appStyles.title}>
|
| 54 |
+
Fragmenta
|
| 55 |
+
</Typography>
|
| 56 |
+
<Typography variant="caption" color="text.secondary" sx={{ fontSize: '0.7rem', letterSpacing: '0.04em' }}>
|
| 57 |
+
v{APP_VERSION}
|
| 58 |
+
</Typography>
|
| 59 |
+
</Box>
|
| 60 |
+
</DialogTitle>
|
| 61 |
+
<DialogContent>
|
| 62 |
+
<Typography sx={appStyles.infoDialogIntro}>
|
| 63 |
+
Fragmenta is an open source, local-first suit to prepare datasets, train, generate and perform with text-to-audio diffusion models.
|
| 64 |
+
Made by the composer and researcher Misagh Azimi.
|
| 65 |
+
</Typography>
|
| 66 |
+
|
| 67 |
+
<Box sx={appStyles.infoDialogActionStack}>
|
| 68 |
+
<Button
|
| 69 |
+
variant="contained"
|
| 70 |
+
size="small"
|
| 71 |
+
startIcon={<InfoIcon size={16} />}
|
| 72 |
+
onClick={() => onOpenDocumentation('about')}
|
| 73 |
+
disabled={isOpeningDocumentation}
|
| 74 |
+
sx={appStyles.infoDocButton}
|
| 75 |
+
>
|
| 76 |
+
About
|
| 77 |
+
</Button>
|
| 78 |
+
<Button
|
| 79 |
+
variant="outlined"
|
| 80 |
+
size="small"
|
| 81 |
+
startIcon={<BookOpenIcon size={16} />}
|
| 82 |
+
onClick={() => onOpenDocumentation('documentation')}
|
| 83 |
+
disabled={isOpeningDocumentation}
|
| 84 |
+
sx={appStyles.infoDocButton}
|
| 85 |
+
>
|
| 86 |
+
Documentation
|
| 87 |
+
</Button>
|
| 88 |
+
<Button
|
| 89 |
+
variant="outlined"
|
| 90 |
+
size="small"
|
| 91 |
+
disabled
|
| 92 |
+
sx={appStyles.infoDocButton}
|
| 93 |
+
>
|
| 94 |
+
Tutorials (Coming soon...)
|
| 95 |
+
</Button>
|
| 96 |
+
</Box>
|
| 97 |
+
|
| 98 |
+
<Box sx={{ mt: 3, pt: 1.5, borderTop: '1px solid', borderColor: 'divider', textAlign: 'center' }}>
|
| 99 |
+
<Typography variant="caption" color="textSecondary" sx={{ display: 'block', fontStyle: 'italic', fontSize: '0.6rem', lineHeight: 1.5 }}>
|
| 100 |
+
Powered by{' '}
|
| 101 |
+
<Typography
|
| 102 |
+
component="a"
|
| 103 |
+
variant="caption"
|
| 104 |
+
href="https://github.com/Stability-AI/stable-audio-3"
|
| 105 |
+
target="_blank"
|
| 106 |
+
rel="noopener noreferrer"
|
| 107 |
+
sx={{ color: 'primary.main', textDecoration: 'underline', fontStyle: 'italic', fontSize: '0.6rem' }}
|
| 108 |
+
>
|
| 109 |
+
Stable Audio 3
|
| 110 |
+
</Typography>{' '}by Stability AI. "This Stability AI Model is licensed under the{' '}
|
| 111 |
+
<Typography
|
| 112 |
+
component="a"
|
| 113 |
+
variant="caption"
|
| 114 |
+
href="https://stability.ai/license"
|
| 115 |
+
target="_blank"
|
| 116 |
+
rel="noopener noreferrer"
|
| 117 |
+
sx={{ color: 'primary.main', textDecoration: 'underline', fontStyle: 'italic', fontSize: '0.6rem' }}
|
| 118 |
+
>
|
| 119 |
+
Stability AI Community License
|
| 120 |
+
</Typography>,{' '}
|
| 121 |
+
Copyright Β© Stability AI Ltd. All Rights Reserved"
|
| 122 |
+
</Typography>
|
| 123 |
+
</Box>
|
| 124 |
+
</DialogContent>
|
| 125 |
+
<DialogActions>
|
| 126 |
+
<Button onClick={onClose}>Close</Button>
|
| 127 |
+
</DialogActions>
|
| 128 |
+
</Dialog>
|
| 129 |
+
);
|
| 130 |
+
}
|
app/frontend/src/components/AudioWaveform.js
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useEffect, useRef, useState, useCallback } from 'react';
|
| 2 |
+
import { Box, Typography } from '@mui/material';
|
| 3 |
+
|
| 4 |
+
/**
|
| 5 |
+
* Canvas waveform with a single draggable region (for SA3 inpaint UX).
|
| 6 |
+
*
|
| 7 |
+
* Decodes the supplied File via the Web Audio API (no network round-trip),
|
| 8 |
+
* computes per-pixel min/max peaks once per (file, width) pair, and renders
|
| 9 |
+
* a region overlay + two draggable handles. Region drag in three modes:
|
| 10 |
+
* - drag the left handle β adjust start
|
| 11 |
+
* - drag the right handle β adjust end
|
| 12 |
+
* - drag the body β shift the whole region in place
|
| 13 |
+
*
|
| 14 |
+
* Region is controlled: parent owns `start` / `end` in seconds.
|
| 15 |
+
*
|
| 16 |
+
* Props:
|
| 17 |
+
* file: File | null β source audio
|
| 18 |
+
* duration: number β clip length in seconds (must be passed; we
|
| 19 |
+
* don't infer it from decoded length so the
|
| 20 |
+
* caller can drive a probe before decode
|
| 21 |
+
* finishes)
|
| 22 |
+
* start, end: number β region in seconds
|
| 23 |
+
* onRegionChange: (start, end) => void
|
| 24 |
+
* minRegionSec: number β default 0.1
|
| 25 |
+
* height: number β canvas height in px (default 96)
|
| 26 |
+
* color: CSS color β waveform peak color (default theme accent)
|
| 27 |
+
* regionColor: CSS color β fill for the region rect
|
| 28 |
+
*/
|
| 29 |
+
export default function AudioWaveform({
|
| 30 |
+
file,
|
| 31 |
+
duration,
|
| 32 |
+
start,
|
| 33 |
+
end,
|
| 34 |
+
onRegionChange,
|
| 35 |
+
minRegionSec = 0.1,
|
| 36 |
+
height = 96,
|
| 37 |
+
color = '#279FBB',
|
| 38 |
+
regionColor = 'rgba(253, 162, 43, 0.28)',
|
| 39 |
+
}) {
|
| 40 |
+
const canvasRef = useRef(null);
|
| 41 |
+
const containerRef = useRef(null);
|
| 42 |
+
const [width, setWidth] = useState(0);
|
| 43 |
+
const [peaks, setPeaks] = useState(null);
|
| 44 |
+
const [decoding, setDecoding] = useState(false);
|
| 45 |
+
const [decodeError, setDecodeError] = useState(null);
|
| 46 |
+
// Drag state lives in a ref to avoid re-renders during pointer move.
|
| 47 |
+
const dragRef = useRef(null);
|
| 48 |
+
|
| 49 |
+
// --- responsive width via ResizeObserver -----------------------------
|
| 50 |
+
useEffect(() => {
|
| 51 |
+
const el = containerRef.current;
|
| 52 |
+
if (!el) return;
|
| 53 |
+
const ro = new ResizeObserver((entries) => {
|
| 54 |
+
const w = Math.max(1, Math.floor(entries[0].contentRect.width));
|
| 55 |
+
setWidth(w);
|
| 56 |
+
});
|
| 57 |
+
ro.observe(el);
|
| 58 |
+
return () => ro.disconnect();
|
| 59 |
+
}, []);
|
| 60 |
+
|
| 61 |
+
// --- decode + peak computation ---------------------------------------
|
| 62 |
+
useEffect(() => {
|
| 63 |
+
if (!file || !width) return;
|
| 64 |
+
let cancelled = false;
|
| 65 |
+
setDecoding(true);
|
| 66 |
+
setDecodeError(null);
|
| 67 |
+
|
| 68 |
+
(async () => {
|
| 69 |
+
try {
|
| 70 |
+
const buf = await file.arrayBuffer();
|
| 71 |
+
if (cancelled) return;
|
| 72 |
+
// Reuse one AudioContext where possible. Safari and Chrome both
|
| 73 |
+
// permit creating an offline one for pure decode without user
|
| 74 |
+
// gesture, which is what we want.
|
| 75 |
+
const Ctx = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
| 76 |
+
const tmpCtx = Ctx
|
| 77 |
+
? new Ctx(1, 44100, 44100)
|
| 78 |
+
: new (window.AudioContext || window.webkitAudioContext)();
|
| 79 |
+
const audio = await tmpCtx.decodeAudioData(buf.slice(0));
|
| 80 |
+
if (cancelled) return;
|
| 81 |
+
|
| 82 |
+
// Average across channels into mono peaks, then bucket into
|
| 83 |
+
// `width` columns. Each column gets (min, max) in [-1, 1].
|
| 84 |
+
const ch0 = audio.getChannelData(0);
|
| 85 |
+
const ch1 = audio.numberOfChannels > 1 ? audio.getChannelData(1) : null;
|
| 86 |
+
const totalSamples = ch0.length;
|
| 87 |
+
const bucketSize = Math.max(1, Math.floor(totalSamples / width));
|
| 88 |
+
const out = new Float32Array(width * 2);
|
| 89 |
+
for (let i = 0; i < width; i++) {
|
| 90 |
+
const s = i * bucketSize;
|
| 91 |
+
const e = Math.min(totalSamples, s + bucketSize);
|
| 92 |
+
let mn = 0, mx = 0;
|
| 93 |
+
for (let j = s; j < e; j++) {
|
| 94 |
+
const v = ch1 ? (ch0[j] + ch1[j]) * 0.5 : ch0[j];
|
| 95 |
+
if (v < mn) mn = v;
|
| 96 |
+
if (v > mx) mx = v;
|
| 97 |
+
}
|
| 98 |
+
out[i * 2] = mn;
|
| 99 |
+
out[i * 2 + 1] = mx;
|
| 100 |
+
}
|
| 101 |
+
setPeaks(out);
|
| 102 |
+
} catch (err) {
|
| 103 |
+
setDecodeError(err.message || 'Failed to decode audio');
|
| 104 |
+
} finally {
|
| 105 |
+
if (!cancelled) setDecoding(false);
|
| 106 |
+
}
|
| 107 |
+
})();
|
| 108 |
+
|
| 109 |
+
return () => { cancelled = true; };
|
| 110 |
+
}, [file, width]);
|
| 111 |
+
|
| 112 |
+
// --- canvas drawing --------------------------------------------------
|
| 113 |
+
const draw = useCallback(() => {
|
| 114 |
+
const canvas = canvasRef.current;
|
| 115 |
+
if (!canvas || !width || !height) return;
|
| 116 |
+
const dpr = window.devicePixelRatio || 1;
|
| 117 |
+
canvas.width = width * dpr;
|
| 118 |
+
canvas.height = height * dpr;
|
| 119 |
+
const ctx = canvas.getContext('2d');
|
| 120 |
+
ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
|
| 121 |
+
ctx.clearRect(0, 0, width, height);
|
| 122 |
+
|
| 123 |
+
// Background: faint center line so empty audio still shows scale.
|
| 124 |
+
ctx.fillStyle = 'rgba(255, 255, 255, 0.05)';
|
| 125 |
+
ctx.fillRect(0, height / 2 - 0.5, width, 1);
|
| 126 |
+
|
| 127 |
+
// Peaks
|
| 128 |
+
if (peaks) {
|
| 129 |
+
ctx.fillStyle = color;
|
| 130 |
+
const mid = height / 2;
|
| 131 |
+
const scale = (height - 4) / 2;
|
| 132 |
+
for (let i = 0; i < width; i++) {
|
| 133 |
+
const mn = peaks[i * 2];
|
| 134 |
+
const mx = peaks[i * 2 + 1];
|
| 135 |
+
const y0 = mid - mx * scale;
|
| 136 |
+
const y1 = mid - mn * scale;
|
| 137 |
+
ctx.fillRect(i, y0, 1, Math.max(1, y1 - y0));
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// Region overlay
|
| 142 |
+
if (duration > 0 && Number.isFinite(start) && Number.isFinite(end)) {
|
| 143 |
+
const sPx = Math.max(0, Math.min(width, (start / duration) * width));
|
| 144 |
+
const ePx = Math.max(0, Math.min(width, (end / duration) * width));
|
| 145 |
+
const rectW = Math.max(1, ePx - sPx);
|
| 146 |
+
ctx.fillStyle = regionColor;
|
| 147 |
+
ctx.fillRect(sPx, 0, rectW, height);
|
| 148 |
+
// Handles
|
| 149 |
+
ctx.fillStyle = '#FDA22B';
|
| 150 |
+
ctx.fillRect(sPx - 1, 0, 2, height);
|
| 151 |
+
ctx.fillRect(ePx - 1, 0, 2, height);
|
| 152 |
+
}
|
| 153 |
+
}, [width, height, peaks, color, regionColor, start, end, duration]);
|
| 154 |
+
|
| 155 |
+
useEffect(() => { draw(); }, [draw]);
|
| 156 |
+
|
| 157 |
+
// --- pointer interaction --------------------------------------------
|
| 158 |
+
const HIT_PX = 8;
|
| 159 |
+
const pxToSec = useCallback((px) => {
|
| 160 |
+
return Math.max(0, Math.min(duration, (px / width) * duration));
|
| 161 |
+
}, [width, duration]);
|
| 162 |
+
|
| 163 |
+
const onPointerDown = (e) => {
|
| 164 |
+
if (!duration || !width) return;
|
| 165 |
+
const rect = canvasRef.current.getBoundingClientRect();
|
| 166 |
+
const px = e.clientX - rect.left;
|
| 167 |
+
const sPx = (start / duration) * width;
|
| 168 |
+
const ePx = (end / duration) * width;
|
| 169 |
+
let mode;
|
| 170 |
+
if (Math.abs(px - sPx) <= HIT_PX) mode = 'start';
|
| 171 |
+
else if (Math.abs(px - ePx) <= HIT_PX) mode = 'end';
|
| 172 |
+
else if (px > sPx && px < ePx) mode = 'body';
|
| 173 |
+
else mode = 'new'; // start a new region by drag
|
| 174 |
+
dragRef.current = {
|
| 175 |
+
mode,
|
| 176 |
+
startPx: px,
|
| 177 |
+
origStart: start,
|
| 178 |
+
origEnd: end,
|
| 179 |
+
};
|
| 180 |
+
canvasRef.current.setPointerCapture(e.pointerId);
|
| 181 |
+
if (mode === 'new') {
|
| 182 |
+
const t = pxToSec(px);
|
| 183 |
+
onRegionChange?.(t, Math.min(duration, t + minRegionSec));
|
| 184 |
+
dragRef.current.mode = 'end';
|
| 185 |
+
dragRef.current.origStart = t;
|
| 186 |
+
dragRef.current.origEnd = t + minRegionSec;
|
| 187 |
+
}
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
const onPointerMove = (e) => {
|
| 191 |
+
const d = dragRef.current;
|
| 192 |
+
if (!d) return;
|
| 193 |
+
const rect = canvasRef.current.getBoundingClientRect();
|
| 194 |
+
const px = e.clientX - rect.left;
|
| 195 |
+
const delta = pxToSec(px) - pxToSec(d.startPx);
|
| 196 |
+
let s = d.origStart;
|
| 197 |
+
let en = d.origEnd;
|
| 198 |
+
if (d.mode === 'start') {
|
| 199 |
+
s = Math.max(0, Math.min(d.origEnd - minRegionSec, d.origStart + delta));
|
| 200 |
+
} else if (d.mode === 'end') {
|
| 201 |
+
en = Math.max(d.origStart + minRegionSec, Math.min(duration, d.origEnd + delta));
|
| 202 |
+
} else if (d.mode === 'body') {
|
| 203 |
+
const span = d.origEnd - d.origStart;
|
| 204 |
+
s = Math.max(0, Math.min(duration - span, d.origStart + delta));
|
| 205 |
+
en = s + span;
|
| 206 |
+
}
|
| 207 |
+
onRegionChange?.(s, en);
|
| 208 |
+
};
|
| 209 |
+
|
| 210 |
+
const onPointerUp = (e) => {
|
| 211 |
+
if (dragRef.current) {
|
| 212 |
+
canvasRef.current.releasePointerCapture(e.pointerId);
|
| 213 |
+
dragRef.current = null;
|
| 214 |
+
}
|
| 215 |
+
};
|
| 216 |
+
|
| 217 |
+
// --- render ----------------------------------------------------------
|
| 218 |
+
return (
|
| 219 |
+
<Box ref={containerRef} sx={{ width: '100%', position: 'relative' }}>
|
| 220 |
+
<canvas
|
| 221 |
+
ref={canvasRef}
|
| 222 |
+
style={{
|
| 223 |
+
width: '100%',
|
| 224 |
+
height,
|
| 225 |
+
display: 'block',
|
| 226 |
+
cursor: dragRef.current ? 'grabbing' : 'crosshair',
|
| 227 |
+
touchAction: 'none',
|
| 228 |
+
borderRadius: 4,
|
| 229 |
+
background: 'rgba(255,255,255,0.02)',
|
| 230 |
+
}}
|
| 231 |
+
onPointerDown={onPointerDown}
|
| 232 |
+
onPointerMove={onPointerMove}
|
| 233 |
+
onPointerUp={onPointerUp}
|
| 234 |
+
onPointerCancel={onPointerUp}
|
| 235 |
+
/>
|
| 236 |
+
{(decoding || decodeError || !file) && (
|
| 237 |
+
<Box
|
| 238 |
+
sx={{
|
| 239 |
+
position: 'absolute',
|
| 240 |
+
inset: 0,
|
| 241 |
+
display: 'flex',
|
| 242 |
+
alignItems: 'center',
|
| 243 |
+
justifyContent: 'center',
|
| 244 |
+
pointerEvents: 'none',
|
| 245 |
+
}}
|
| 246 |
+
>
|
| 247 |
+
<Typography variant="caption" color="text.secondary">
|
| 248 |
+
{decodeError
|
| 249 |
+
? `decode failed: ${decodeError}`
|
| 250 |
+
: !file
|
| 251 |
+
? 'no source loaded'
|
| 252 |
+
: 'decodingβ¦'}
|
| 253 |
+
</Typography>
|
| 254 |
+
</Box>
|
| 255 |
+
)}
|
| 256 |
+
</Box>
|
| 257 |
+
);
|
| 258 |
+
}
|
app/frontend/src/components/ChannelFragmentHistory.js
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useState } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Box,
|
| 4 |
+
IconButton,
|
| 5 |
+
Dialog,
|
| 6 |
+
DialogTitle,
|
| 7 |
+
DialogContent,
|
| 8 |
+
DialogContentText,
|
| 9 |
+
DialogActions,
|
| 10 |
+
Button,
|
| 11 |
+
} from '@mui/material';
|
| 12 |
+
import { TIPS } from '../tooltips';
|
| 13 |
+
import Tooltip from './Tooltip';
|
| 14 |
+
import {
|
| 15 |
+
Play as PlayIcon,
|
| 16 |
+
Square as StopIcon,
|
| 17 |
+
Star as StarIcon,
|
| 18 |
+
Trash2 as DeleteIcon,
|
| 19 |
+
Check as CommitIcon,
|
| 20 |
+
Eraser as ClearAllIcon,
|
| 21 |
+
} from 'lucide-react';
|
| 22 |
+
import { performanceChannelStyles as styles } from '../theme';
|
| 23 |
+
import { MidiMappable } from './MidiContext';
|
| 24 |
+
|
| 25 |
+
/**
|
| 26 |
+
* Per-channel rolling fragment history. Always visible (empty-state included)
|
| 27 |
+
* so the user knows the strip exists. Chronological order β oldest at
|
| 28 |
+
* the top, newest at the bottom; scrolls vertically when the list grows
|
| 29 |
+
* past ~4 visible rows.
|
| 30 |
+
*
|
| 31 |
+
* Each row exposes four actions, all visible by default (no hover-reveal β
|
| 32 |
+
* Performance use is fast, can't afford the discoverability tax):
|
| 33 |
+
* β’ Cue βΆ/β β audition through the cue output (separate from main mix)
|
| 34 |
+
* β’ Star β
/β β mark as a keeper. Starred fragments survive the cap
|
| 35 |
+
* eviction; unstarred get dropped FIFO when over cap.
|
| 36 |
+
* β’ Delete β« β remove this fragment from history (cancellable confirm not
|
| 37 |
+
* shown for single deletes β the entry can be regenerated
|
| 38 |
+
* or audition can be retriggered after a quick re-tap).
|
| 39 |
+
* β’ Load β β commit this fragment to the channel strip (becomes the
|
| 40 |
+
* audio the channel plays). Disabled while already loaded.
|
| 41 |
+
*
|
| 42 |
+
* Props:
|
| 43 |
+
* fragments: [{ id, audioUrl, blob, prompt, duration, createdAt,
|
| 44 |
+
* starred, number }]
|
| 45 |
+
* color: channel accent color
|
| 46 |
+
* auditioningId: the id currently playing through cue, or null
|
| 47 |
+
* committedId: the id currently loaded into the channel strip, or null
|
| 48 |
+
* maxFragments: cap, default 50 (informational; eviction lives in parent)
|
| 49 |
+
* on{Audition,Commit,ToggleStar,Delete}: (fragmentId) => void
|
| 50 |
+
* onClearAll: () => void (parent confirms separately β we still show
|
| 51 |
+
* a confirm dialog here for the trash-everything action)
|
| 52 |
+
*/
|
| 53 |
+
export default function ChannelFragmentHistory({
|
| 54 |
+
fragments,
|
| 55 |
+
color,
|
| 56 |
+
channelIndex,
|
| 57 |
+
auditioningId,
|
| 58 |
+
committedId,
|
| 59 |
+
maxFragments = 50,
|
| 60 |
+
onAudition,
|
| 61 |
+
onCommit,
|
| 62 |
+
onToggleStar,
|
| 63 |
+
onDelete,
|
| 64 |
+
onClearAll,
|
| 65 |
+
}) {
|
| 66 |
+
const [clearConfirmOpen, setClearConfirmOpen] = useState(false);
|
| 67 |
+
// Channel-scoped MIME type for drag-and-drop. The waveform drop target on
|
| 68 |
+
// this same channel listens for this exact type β cross-channel drags
|
| 69 |
+
// won't highlight or accept because the mime won't match.
|
| 70 |
+
const dragMime = `application/x-fragmenta-fragment-ch${channelIndex}`;
|
| 71 |
+
|
| 72 |
+
return (
|
| 73 |
+
<Box sx={styles.fragmentHistoryPanel}>
|
| 74 |
+
<Box sx={styles.fragmentHistoryHeader}>
|
| 75 |
+
<Box component="span" sx={styles.fragmentHistoryHeaderText}>
|
| 76 |
+
Fragments
|
| 77 |
+
</Box>
|
| 78 |
+
{fragments.length > 0 && (
|
| 79 |
+
<IconButton
|
| 80 |
+
size="small"
|
| 81 |
+
onClick={() => setClearConfirmOpen(true)}
|
| 82 |
+
sx={styles.fragmentHistoryHeaderBtn}
|
| 83 |
+
aria-label="Clear all fragments"
|
| 84 |
+
>
|
| 85 |
+
<ClearAllIcon size={12} />
|
| 86 |
+
</IconButton>
|
| 87 |
+
)}
|
| 88 |
+
</Box>
|
| 89 |
+
|
| 90 |
+
{fragments.length === 0 ? (
|
| 91 |
+
<Box sx={styles.fragmentHistoryEmpty}>Empty</Box>
|
| 92 |
+
) : (
|
| 93 |
+
<Box sx={styles.fragmentHistoryList}>
|
| 94 |
+
{fragments.map((fragment) => {
|
| 95 |
+
const isAuditioning = auditioningId === fragment.id;
|
| 96 |
+
const isCommitted = committedId === fragment.id;
|
| 97 |
+
return (
|
| 98 |
+
<Box
|
| 99 |
+
key={fragment.id}
|
| 100 |
+
draggable
|
| 101 |
+
onDragStart={(e) => {
|
| 102 |
+
e.dataTransfer.setData(dragMime, fragment.id);
|
| 103 |
+
e.dataTransfer.effectAllowed = 'copy';
|
| 104 |
+
}}
|
| 105 |
+
sx={{
|
| 106 |
+
...styles.fragmentRow(color, isCommitted, isAuditioning),
|
| 107 |
+
cursor: 'grab',
|
| 108 |
+
'&:active': { cursor: 'grabbing' },
|
| 109 |
+
}}
|
| 110 |
+
>
|
| 111 |
+
<MidiMappable
|
| 112 |
+
id={`channel.${channelIndex}.fragment.${fragment.id}.audition`}
|
| 113 |
+
label={`Ch ${channelIndex + 1} Β· Fragment ${fragment.number} audition`}
|
| 114 |
+
kind="trigger"
|
| 115 |
+
onChange={() => onAudition(fragment.id)}
|
| 116 |
+
>
|
| 117 |
+
<Tooltip
|
| 118 |
+
title={TIPS.fragments.audition(isAuditioning)}
|
| 119 |
+
placement="top"
|
| 120 |
+
arrow
|
| 121 |
+
enterDelay={300}
|
| 122 |
+
>
|
| 123 |
+
<IconButton
|
| 124 |
+
size="small"
|
| 125 |
+
onClick={() => onAudition(fragment.id)}
|
| 126 |
+
sx={styles.fragmentIconBtn(color, isAuditioning, true)}
|
| 127 |
+
aria-label={isAuditioning ? 'Stop cue' : 'Audition'}
|
| 128 |
+
>
|
| 129 |
+
{isAuditioning
|
| 130 |
+
? <StopIcon size={12} />
|
| 131 |
+
: <PlayIcon size={12} />}
|
| 132 |
+
</IconButton>
|
| 133 |
+
</Tooltip>
|
| 134 |
+
</MidiMappable>
|
| 135 |
+
|
| 136 |
+
<Box sx={styles.fragmentMeta}>
|
| 137 |
+
<Box component="span" sx={styles.fragmentOrdinal}>
|
| 138 |
+
F{fragment.number}
|
| 139 |
+
</Box>
|
| 140 |
+
</Box>
|
| 141 |
+
|
| 142 |
+
<Tooltip
|
| 143 |
+
title={TIPS.fragments.star(fragment.starred)}
|
| 144 |
+
placement="top"
|
| 145 |
+
arrow
|
| 146 |
+
enterDelay={300}
|
| 147 |
+
>
|
| 148 |
+
<IconButton
|
| 149 |
+
size="small"
|
| 150 |
+
onClick={() => onToggleStar(fragment.id)}
|
| 151 |
+
sx={styles.fragmentIconBtn(color, fragment.starred)}
|
| 152 |
+
aria-label={fragment.starred ? 'Unstar fragment' : 'Star fragment'}
|
| 153 |
+
>
|
| 154 |
+
<StarIcon
|
| 155 |
+
size={12}
|
| 156 |
+
fill={fragment.starred ? color : 'none'}
|
| 157 |
+
strokeWidth={2}
|
| 158 |
+
/>
|
| 159 |
+
</IconButton>
|
| 160 |
+
</Tooltip>
|
| 161 |
+
|
| 162 |
+
<IconButton
|
| 163 |
+
size="small"
|
| 164 |
+
onClick={() => onDelete(fragment.id)}
|
| 165 |
+
sx={styles.fragmentDeleteBtn}
|
| 166 |
+
aria-label="Delete fragment"
|
| 167 |
+
>
|
| 168 |
+
<DeleteIcon size={12} />
|
| 169 |
+
</IconButton>
|
| 170 |
+
|
| 171 |
+
<Tooltip
|
| 172 |
+
title={TIPS.fragments.commit(isCommitted)}
|
| 173 |
+
placement="top"
|
| 174 |
+
arrow
|
| 175 |
+
enterDelay={300}
|
| 176 |
+
>
|
| 177 |
+
<span>
|
| 178 |
+
<IconButton
|
| 179 |
+
size="small"
|
| 180 |
+
onClick={() => onCommit(fragment.id)}
|
| 181 |
+
disabled={isCommitted}
|
| 182 |
+
sx={styles.fragmentIconBtn(color, isCommitted, true)}
|
| 183 |
+
aria-label="Load fragment into channel"
|
| 184 |
+
>
|
| 185 |
+
<CommitIcon size={12} strokeWidth={isCommitted ? 3 : 2} />
|
| 186 |
+
</IconButton>
|
| 187 |
+
</span>
|
| 188 |
+
</Tooltip>
|
| 189 |
+
</Box>
|
| 190 |
+
);
|
| 191 |
+
})}
|
| 192 |
+
</Box>
|
| 193 |
+
)}
|
| 194 |
+
|
| 195 |
+
<Dialog open={clearConfirmOpen} onClose={() => setClearConfirmOpen(false)}>
|
| 196 |
+
<DialogTitle>Clear fragment history?</DialogTitle>
|
| 197 |
+
<DialogContent>
|
| 198 |
+
<DialogContentText>
|
| 199 |
+
Removes all {fragments.length} fragments from this channel's history,
|
| 200 |
+
including starred ones. The currently loaded clip stays loaded
|
| 201 |
+
β only the history entries are dropped.
|
| 202 |
+
</DialogContentText>
|
| 203 |
+
</DialogContent>
|
| 204 |
+
<DialogActions>
|
| 205 |
+
<Button onClick={() => setClearConfirmOpen(false)}>Cancel</Button>
|
| 206 |
+
<Button
|
| 207 |
+
onClick={() => { setClearConfirmOpen(false); onClearAll?.(); }}
|
| 208 |
+
color="error"
|
| 209 |
+
variant="contained"
|
| 210 |
+
>
|
| 211 |
+
Clear all
|
| 212 |
+
</Button>
|
| 213 |
+
</DialogActions>
|
| 214 |
+
</Dialog>
|
| 215 |
+
</Box>
|
| 216 |
+
);
|
| 217 |
+
}
|
app/frontend/src/components/CheckpointManagerWindow.js
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useCallback, useEffect, useState } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Dialog,
|
| 4 |
+
DialogTitle,
|
| 5 |
+
DialogContent,
|
| 6 |
+
DialogActions,
|
| 7 |
+
Box,
|
| 8 |
+
Typography,
|
| 9 |
+
Button,
|
| 10 |
+
IconButton,
|
| 11 |
+
Stack,
|
| 12 |
+
Alert,
|
| 13 |
+
TextField,
|
| 14 |
+
LinearProgress,
|
| 15 |
+
} from '@mui/material';
|
| 16 |
+
import {
|
| 17 |
+
X as CloseIcon,
|
| 18 |
+
HardDrive as StorageIcon,
|
| 19 |
+
LogIn as LoginIcon,
|
| 20 |
+
LogOut as LogoutIcon,
|
| 21 |
+
} from 'lucide-react';
|
| 22 |
+
import api from '../api';
|
| 23 |
+
import CheckpointRow from './CheckpointRow';
|
| 24 |
+
import StorageDrilldown from './StorageDrilldown';
|
| 25 |
+
|
| 26 |
+
const fmtBytes = (n) => {
|
| 27 |
+
if (!n && n !== 0) return 'β';
|
| 28 |
+
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
|
| 29 |
+
let v = n;
|
| 30 |
+
let u = 0;
|
| 31 |
+
while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; }
|
| 32 |
+
return `${v.toFixed(v < 10 ? 2 : 1)} ${units[u]}`;
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
export default function CheckpointManagerWindow({ open, onClose }) {
|
| 36 |
+
const [catalog, setCatalog] = useState([]);
|
| 37 |
+
const [storage, setStorage] = useState(null);
|
| 38 |
+
const [env, setEnv] = useState(null);
|
| 39 |
+
const [hfAuth, setHfAuth] = useState({ signed_in: false, username: null });
|
| 40 |
+
const [tokenDraft, setTokenDraft] = useState('');
|
| 41 |
+
const [showTokenInput, setShowTokenInput] = useState(false);
|
| 42 |
+
const [authError, setAuthError] = useState(null);
|
| 43 |
+
const [showStorage, setShowStorage] = useState(false);
|
| 44 |
+
const [loading, setLoading] = useState(false);
|
| 45 |
+
const [error, setError] = useState(null);
|
| 46 |
+
|
| 47 |
+
const refresh = useCallback(async () => {
|
| 48 |
+
setLoading(true);
|
| 49 |
+
setError(null);
|
| 50 |
+
try {
|
| 51 |
+
const [cat, store, auth, environment] = await Promise.all([
|
| 52 |
+
api.get('/api/checkpoints'),
|
| 53 |
+
api.get('/api/checkpoints/storage'),
|
| 54 |
+
api.get('/api/hf-auth/status'),
|
| 55 |
+
api.get('/api/environment'),
|
| 56 |
+
]);
|
| 57 |
+
setCatalog(cat.data.checkpoints);
|
| 58 |
+
setStorage(store.data);
|
| 59 |
+
setHfAuth(auth.data);
|
| 60 |
+
setEnv(environment.data);
|
| 61 |
+
} catch (e) {
|
| 62 |
+
setError(e.response?.data?.error || e.message);
|
| 63 |
+
} finally {
|
| 64 |
+
setLoading(false);
|
| 65 |
+
}
|
| 66 |
+
}, []);
|
| 67 |
+
|
| 68 |
+
useEffect(() => {
|
| 69 |
+
if (open) refresh();
|
| 70 |
+
}, [open, refresh]);
|
| 71 |
+
|
| 72 |
+
const submitToken = async () => {
|
| 73 |
+
setAuthError(null);
|
| 74 |
+
try {
|
| 75 |
+
await api.post('/api/hf-auth', { token: tokenDraft.trim() });
|
| 76 |
+
setTokenDraft('');
|
| 77 |
+
setShowTokenInput(false);
|
| 78 |
+
refresh();
|
| 79 |
+
} catch (e) {
|
| 80 |
+
setAuthError(e.response?.data?.error || e.message);
|
| 81 |
+
}
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
const logout = async () => {
|
| 85 |
+
try {
|
| 86 |
+
await api.delete('/api/hf-auth');
|
| 87 |
+
refresh();
|
| 88 |
+
} catch (e) {
|
| 89 |
+
setAuthError(e.response?.data?.error || e.message);
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
const anyInstalled = catalog.some(c => c.downloaded);
|
| 94 |
+
|
| 95 |
+
return (
|
| 96 |
+
<>
|
| 97 |
+
<Dialog open={open} onClose={onClose} maxWidth="md" fullWidth scroll="paper">
|
| 98 |
+
<DialogTitle sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
|
| 99 |
+
<Box sx={{ flex: 1 }}>Checkpoint Manager</Box>
|
| 100 |
+
<IconButton size="small" onClick={onClose}><CloseIcon size={18} /></IconButton>
|
| 101 |
+
</DialogTitle>
|
| 102 |
+
|
| 103 |
+
<DialogContent dividers>
|
| 104 |
+
<Box sx={{ mb: 2 }}>
|
| 105 |
+
<Stack direction="row" alignItems="center" spacing={2} flexWrap="wrap">
|
| 106 |
+
<Button
|
| 107 |
+
size="small"
|
| 108 |
+
variant="text"
|
| 109 |
+
startIcon={<StorageIcon size={14} />}
|
| 110 |
+
onClick={() => setShowStorage(true)}
|
| 111 |
+
disabled={!storage}
|
| 112 |
+
>
|
| 113 |
+
{storage
|
| 114 |
+
? `${fmtBytes(storage.total_used_bytes)} used Β· ${fmtBytes(storage.total_free_bytes)} free`
|
| 115 |
+
: 'β'}
|
| 116 |
+
</Button>
|
| 117 |
+
|
| 118 |
+
<Box sx={{ flex: 1 }} />
|
| 119 |
+
|
| 120 |
+
{hfAuth.signed_in ? (
|
| 121 |
+
<Stack direction="row" alignItems="center" spacing={1}>
|
| 122 |
+
<Typography variant="caption" color="text.secondary">
|
| 123 |
+
HuggingFace: signed in as <strong>{hfAuth.username}</strong>
|
| 124 |
+
</Typography>
|
| 125 |
+
<Button
|
| 126 |
+
size="small"
|
| 127 |
+
variant="text"
|
| 128 |
+
startIcon={<LogoutIcon size={14} />}
|
| 129 |
+
onClick={logout}
|
| 130 |
+
>
|
| 131 |
+
Sign out
|
| 132 |
+
</Button>
|
| 133 |
+
</Stack>
|
| 134 |
+
) : showTokenInput ? (
|
| 135 |
+
<Stack direction="row" alignItems="center" spacing={1}>
|
| 136 |
+
<TextField
|
| 137 |
+
size="small"
|
| 138 |
+
placeholder="hf_..."
|
| 139 |
+
value={tokenDraft}
|
| 140 |
+
onChange={(e) => setTokenDraft(e.target.value)}
|
| 141 |
+
type="password"
|
| 142 |
+
sx={{ width: 240 }}
|
| 143 |
+
/>
|
| 144 |
+
<Button size="small" variant="contained" onClick={submitToken}>
|
| 145 |
+
Sign in
|
| 146 |
+
</Button>
|
| 147 |
+
<Button size="small" onClick={() => { setShowTokenInput(false); setTokenDraft(''); }}>
|
| 148 |
+
Cancel
|
| 149 |
+
</Button>
|
| 150 |
+
</Stack>
|
| 151 |
+
) : (
|
| 152 |
+
<Button
|
| 153 |
+
size="small"
|
| 154 |
+
variant="outlined"
|
| 155 |
+
startIcon={<LoginIcon size={14} />}
|
| 156 |
+
onClick={() => setShowTokenInput(true)}
|
| 157 |
+
>
|
| 158 |
+
Sign in to HuggingFace
|
| 159 |
+
</Button>
|
| 160 |
+
)}
|
| 161 |
+
</Stack>
|
| 162 |
+
{authError && <Alert severity="error" sx={{ mt: 1 }}>{authError}</Alert>}
|
| 163 |
+
</Box>
|
| 164 |
+
|
| 165 |
+
{!hfAuth.signed_in ? (
|
| 166 |
+
<Alert severity="info" sx={{ mb: 2 }}>
|
| 167 |
+
SA3 checkpoints are gated on HuggingFace. You need a{' '}
|
| 168 |
+
<a href="https://huggingface.co/join" target="_blank" rel="noreferrer">
|
| 169 |
+
HuggingFace account
|
| 170 |
+
</a>
|
| 171 |
+
{' '}to continue. Then{' '}
|
| 172 |
+
<a href="https://huggingface.co/settings/tokens" target="_blank" rel="noreferrer">
|
| 173 |
+
create a Read access token
|
| 174 |
+
</a>
|
| 175 |
+
{' '}and sign in above.
|
| 176 |
+
</Alert>
|
| 177 |
+
) : (
|
| 178 |
+
<Alert severity="info" sx={{ mb: 2 }}>
|
| 179 |
+
You're signed in. Each model is gated β click its name below to open the
|
| 180 |
+
HuggingFace page and accept the model's terms before downloading.
|
| 181 |
+
</Alert>
|
| 182 |
+
)}
|
| 183 |
+
|
| 184 |
+
{error && <Alert severity="error" sx={{ mb: 2 }}>{error}</Alert>}
|
| 185 |
+
{loading && <LinearProgress sx={{ mb: 2 }} />}
|
| 186 |
+
|
| 187 |
+
{!loading && !anyInstalled && catalog.length > 0 && (
|
| 188 |
+
<Box sx={{
|
| 189 |
+
p: 2, mb: 2, borderRadius: 1, bgcolor: 'action.hover',
|
| 190 |
+
}}>
|
| 191 |
+
<Typography variant="body2" fontWeight={500}>
|
| 192 |
+
Pick a model to get started.
|
| 193 |
+
</Typography>
|
| 194 |
+
<Typography variant="caption" color="text.secondary">
|
| 195 |
+
Small - Music (1.2 GB) is a good first choice on a laptop or any GPU.
|
| 196 |
+
</Typography>
|
| 197 |
+
</Box>
|
| 198 |
+
)}
|
| 199 |
+
|
| 200 |
+
{[
|
| 201 |
+
{ kind: 'post-trained', label: 'Distilled (fast)', hint: '8 steps, cfg locked at 1.0. Prompt, duration and seed only.' },
|
| 202 |
+
{ kind: 'base', label: 'Base (full control)', hint: 'CFG-aware. ~50 steps, cfg ~7. Cfg-scale and steps are live controls.' },
|
| 203 |
+
{ kind: 'tagger', label: 'Auto-annotation tools', hint: 'Optional helpers for dataset prep. CLAP scores audio against your vocabulary.' },
|
| 204 |
+
].map(group => {
|
| 205 |
+
const rows = catalog.filter(c => c.kind === group.kind);
|
| 206 |
+
if (!rows.length) return null;
|
| 207 |
+
return (
|
| 208 |
+
<Box key={group.kind} sx={{ mb: 2 }}>
|
| 209 |
+
<Typography variant="subtitle2" sx={{ mb: 0.25 }}>{group.label}</Typography>
|
| 210 |
+
<Typography variant="caption" color="text.secondary" sx={{ display: 'block', mb: 0.75 }}>
|
| 211 |
+
{group.hint}
|
| 212 |
+
</Typography>
|
| 213 |
+
<Box sx={{ border: '1px solid', borderColor: 'divider', borderRadius: 1 }}>
|
| 214 |
+
{rows.map(c => (
|
| 215 |
+
<CheckpointRow
|
| 216 |
+
key={c.id}
|
| 217 |
+
checkpoint={c}
|
| 218 |
+
env={env}
|
| 219 |
+
onAuthRequired={() => setShowTokenInput(true)}
|
| 220 |
+
onChanged={refresh}
|
| 221 |
+
/>
|
| 222 |
+
))}
|
| 223 |
+
</Box>
|
| 224 |
+
</Box>
|
| 225 |
+
);
|
| 226 |
+
})}
|
| 227 |
+
</DialogContent>
|
| 228 |
+
|
| 229 |
+
<DialogActions>
|
| 230 |
+
<Button onClick={refresh} disabled={loading}>Refresh</Button>
|
| 231 |
+
<Button onClick={onClose} variant="contained">Close</Button>
|
| 232 |
+
</DialogActions>
|
| 233 |
+
</Dialog>
|
| 234 |
+
|
| 235 |
+
<StorageDrilldown
|
| 236 |
+
open={showStorage}
|
| 237 |
+
onClose={() => setShowStorage(false)}
|
| 238 |
+
storage={storage}
|
| 239 |
+
catalog={catalog}
|
| 240 |
+
/>
|
| 241 |
+
</>
|
| 242 |
+
);
|
| 243 |
+
}
|
app/frontend/src/components/CheckpointRow.js
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useEffect, useRef, useState } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Box,
|
| 4 |
+
Typography,
|
| 5 |
+
Button,
|
| 6 |
+
Chip,
|
| 7 |
+
LinearProgress,
|
| 8 |
+
Stack,
|
| 9 |
+
IconButton,
|
| 10 |
+
} from '@mui/material';
|
| 11 |
+
import { TIPS } from '../tooltips';
|
| 12 |
+
import Tooltip from './Tooltip';
|
| 13 |
+
import {
|
| 14 |
+
CloudDownload as DownloadIcon,
|
| 15 |
+
Trash2 as DeleteIcon,
|
| 16 |
+
X as CancelIcon,
|
| 17 |
+
} from 'lucide-react';
|
| 18 |
+
import api from '../api';
|
| 19 |
+
|
| 20 |
+
const fmtBytes = (n) => {
|
| 21 |
+
if (!n && n !== 0) return 'β';
|
| 22 |
+
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
|
| 23 |
+
let v = n;
|
| 24 |
+
let u = 0;
|
| 25 |
+
while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; }
|
| 26 |
+
return `${v.toFixed(v < 10 ? 2 : 1)} ${units[u]}`;
|
| 27 |
+
};
|
| 28 |
+
|
| 29 |
+
const hardwareLabel = (hw) => ({
|
| 30 |
+
'cpu': 'CPU / GPU',
|
| 31 |
+
'cuda': 'CUDA',
|
| 32 |
+
'cuda+flash-attn': 'CUDA + Flash-Attn',
|
| 33 |
+
}[hw] || hw);
|
| 34 |
+
|
| 35 |
+
// Why this host can't run a given model, or null if it can. Mirrors the gate
|
| 36 |
+
// in audio_generator._ensure_model. `env` comes from GET /api/environment.
|
| 37 |
+
const hostIncompatReason = (hw, env) => {
|
| 38 |
+
if (!env) return null; // capabilities unknown β don't block
|
| 39 |
+
if (hw === 'cuda+flash-attn') {
|
| 40 |
+
if (!env.cuda_available) {
|
| 41 |
+
return 'Requires an NVIDIA CUDA GPU. Use a Small model β those run on CPU, Apple Silicon, or any GPU.';
|
| 42 |
+
}
|
| 43 |
+
// Gate on the real capability, not the platform: Windows works once a
|
| 44 |
+
// matching flash-attn wheel is installed (Blackwell/Ampere + cu12x).
|
| 45 |
+
// No wheel β guide the user to install one (or use Docker on WSL2).
|
| 46 |
+
if (!env.flash_attn_available) {
|
| 47 |
+
return env.platform === 'Windows'
|
| 48 |
+
? 'Requires Flash Attention 2 (flash-attn). No official Windows wheel β install a matching prebuilt/built wheel for your torch+CUDA, or run via Docker on WSL2.'
|
| 49 |
+
: 'Requires Flash Attention 2 (flash-attn) β not installed. Install it, or use a Small model.';
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
if (hw === 'cuda' && !env.cuda_available) {
|
| 53 |
+
return 'Recommended on an NVIDIA CUDA GPU; this host has none.';
|
| 54 |
+
}
|
| 55 |
+
return null;
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
export default function CheckpointRow({ checkpoint, env, onAuthRequired, onChanged }) {
|
| 59 |
+
const [jobId, setJobId] = useState(checkpoint.active_job?.job_id || null);
|
| 60 |
+
const [job, setJob] = useState(checkpoint.active_job || null);
|
| 61 |
+
const [error, setError] = useState(null);
|
| 62 |
+
const [busy, setBusy] = useState(false);
|
| 63 |
+
const pollTimer = useRef(null);
|
| 64 |
+
|
| 65 |
+
// If the parent's refresh tells us about an in-flight job and we don't
|
| 66 |
+
// already have one locally (typical case: dialog was closed mid-download
|
| 67 |
+
// and just got reopened), adopt it. Don't stomp a freshly-started local
|
| 68 |
+
// job_id with stale catalog data β only sync when the local state is empty
|
| 69 |
+
// or a *different* job is now active for this checkpoint.
|
| 70 |
+
useEffect(() => {
|
| 71 |
+
const incoming = checkpoint.active_job?.job_id || null;
|
| 72 |
+
if (incoming && incoming !== jobId) {
|
| 73 |
+
setJobId(incoming);
|
| 74 |
+
setJob(checkpoint.active_job);
|
| 75 |
+
}
|
| 76 |
+
}, [checkpoint.active_job, jobId]);
|
| 77 |
+
|
| 78 |
+
useEffect(() => {
|
| 79 |
+
if (!jobId) return undefined;
|
| 80 |
+
const tick = async () => {
|
| 81 |
+
try {
|
| 82 |
+
const r = await api.get(`/api/checkpoints/jobs/${jobId}`);
|
| 83 |
+
setJob(r.data);
|
| 84 |
+
if (['complete', 'failed', 'cancelled'].includes(r.data.status)) {
|
| 85 |
+
if (r.data.status === 'failed' && (r.data.error || '').startsWith('hf_auth_required')) {
|
| 86 |
+
onAuthRequired?.();
|
| 87 |
+
} else if (r.data.status === 'failed') {
|
| 88 |
+
setError(r.data.error);
|
| 89 |
+
}
|
| 90 |
+
setJobId(null);
|
| 91 |
+
onChanged?.();
|
| 92 |
+
}
|
| 93 |
+
} catch (e) {
|
| 94 |
+
setError(e.response?.data?.error || e.message);
|
| 95 |
+
setJobId(null);
|
| 96 |
+
}
|
| 97 |
+
};
|
| 98 |
+
tick();
|
| 99 |
+
pollTimer.current = setInterval(tick, 1500);
|
| 100 |
+
return () => clearInterval(pollTimer.current);
|
| 101 |
+
}, [jobId, onAuthRequired, onChanged]);
|
| 102 |
+
|
| 103 |
+
const startDownload = async () => {
|
| 104 |
+
setBusy(true);
|
| 105 |
+
setError(null);
|
| 106 |
+
try {
|
| 107 |
+
const r = await api.post(`/api/checkpoints/${checkpoint.id}/download`);
|
| 108 |
+
setJobId(r.data.job_id);
|
| 109 |
+
} catch (e) {
|
| 110 |
+
setError(e.response?.data?.error || e.message);
|
| 111 |
+
} finally {
|
| 112 |
+
setBusy(false);
|
| 113 |
+
}
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
const cancelDownload = async () => {
|
| 117 |
+
try {
|
| 118 |
+
await api.post(`/api/checkpoints/${checkpoint.id}/cancel-download`);
|
| 119 |
+
} catch (e) {
|
| 120 |
+
setError(e.response?.data?.error || e.message);
|
| 121 |
+
}
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
const deleteCheckpoint = async () => {
|
| 125 |
+
if (!window.confirm(`Delete ${checkpoint.name} (${fmtBytes(checkpoint.downloaded_bytes)})?`)) return;
|
| 126 |
+
setBusy(true);
|
| 127 |
+
try {
|
| 128 |
+
await api.delete(`/api/checkpoints/${checkpoint.id}`);
|
| 129 |
+
onChanged?.();
|
| 130 |
+
} catch (e) {
|
| 131 |
+
setError(e.response?.data?.error || e.message);
|
| 132 |
+
} finally {
|
| 133 |
+
setBusy(false);
|
| 134 |
+
}
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
const downloading = !!jobId && job?.status === 'running';
|
| 138 |
+
const queued = !!jobId && job?.status === 'queued';
|
| 139 |
+
const pct = job?.total_bytes ? (job.downloaded_bytes / job.total_bytes) * 100 : 0;
|
| 140 |
+
const incompatReason = hostIncompatReason(checkpoint.hardware, env);
|
| 141 |
+
|
| 142 |
+
const renderAction = () => {
|
| 143 |
+
if (downloading || queued) {
|
| 144 |
+
return (
|
| 145 |
+
<IconButton size="small" onClick={cancelDownload} aria-label="Cancel download"><CancelIcon size={16} /></IconButton>
|
| 146 |
+
);
|
| 147 |
+
}
|
| 148 |
+
if (checkpoint.downloaded) {
|
| 149 |
+
return (
|
| 150 |
+
<IconButton size="small" onClick={deleteCheckpoint} disabled={busy} aria-label="Delete from disk">
|
| 151 |
+
<DeleteIcon size={16} />
|
| 152 |
+
</IconButton>
|
| 153 |
+
);
|
| 154 |
+
}
|
| 155 |
+
if (incompatReason) {
|
| 156 |
+
return (
|
| 157 |
+
<Tooltip title={incompatReason}>
|
| 158 |
+
{/* span wrapper so the tooltip works on a disabled button */}
|
| 159 |
+
<span>
|
| 160 |
+
<Button
|
| 161 |
+
size="small"
|
| 162 |
+
variant="outlined"
|
| 163 |
+
startIcon={<DownloadIcon size={14} />}
|
| 164 |
+
disabled
|
| 165 |
+
>
|
| 166 |
+
Get
|
| 167 |
+
</Button>
|
| 168 |
+
</span>
|
| 169 |
+
</Tooltip>
|
| 170 |
+
);
|
| 171 |
+
}
|
| 172 |
+
return (
|
| 173 |
+
<Button
|
| 174 |
+
size="small"
|
| 175 |
+
variant="contained"
|
| 176 |
+
startIcon={<DownloadIcon size={14} />}
|
| 177 |
+
onClick={startDownload}
|
| 178 |
+
disabled={busy}
|
| 179 |
+
>
|
| 180 |
+
Get
|
| 181 |
+
</Button>
|
| 182 |
+
);
|
| 183 |
+
};
|
| 184 |
+
|
| 185 |
+
return (
|
| 186 |
+
<Box
|
| 187 |
+
sx={{
|
| 188 |
+
py: 1.25,
|
| 189 |
+
px: 1.5,
|
| 190 |
+
borderBottom: '1px solid',
|
| 191 |
+
borderColor: 'divider',
|
| 192 |
+
'&:last-child': { borderBottom: 'none' },
|
| 193 |
+
}}
|
| 194 |
+
>
|
| 195 |
+
<Stack direction="row" alignItems="center" spacing={2}>
|
| 196 |
+
<Box sx={{ flex: 1, minWidth: 0, opacity: (incompatReason && !checkpoint.downloaded) ? 0.55 : 1 }}>
|
| 197 |
+
<Stack direction="row" alignItems="center" spacing={1}>
|
| 198 |
+
<Tooltip title={TIPS.checkpoints.gatedAccess}>
|
| 199 |
+
<Typography
|
| 200 |
+
component="a"
|
| 201 |
+
href={`https://huggingface.co/${checkpoint.repo}`}
|
| 202 |
+
target="_blank"
|
| 203 |
+
rel="noreferrer"
|
| 204 |
+
variant="body2"
|
| 205 |
+
sx={{
|
| 206 |
+
fontWeight: 500,
|
| 207 |
+
color: 'inherit',
|
| 208 |
+
textDecoration: 'none',
|
| 209 |
+
borderBottom: '1px dashed',
|
| 210 |
+
borderColor: 'text.disabled',
|
| 211 |
+
'&:hover': { color: 'primary.main', borderColor: 'primary.main' },
|
| 212 |
+
}}
|
| 213 |
+
>
|
| 214 |
+
{checkpoint.name}
|
| 215 |
+
</Typography>
|
| 216 |
+
</Tooltip>
|
| 217 |
+
<Chip
|
| 218 |
+
size="small"
|
| 219 |
+
label={hardwareLabel(checkpoint.hardware)}
|
| 220 |
+
variant="outlined"
|
| 221 |
+
sx={{ height: 18, fontSize: 10 }}
|
| 222 |
+
/>
|
| 223 |
+
{checkpoint.downloaded && (
|
| 224 |
+
<Chip
|
| 225 |
+
size="small"
|
| 226 |
+
label="installed"
|
| 227 |
+
sx={{
|
| 228 |
+
height: 18,
|
| 229 |
+
fontSize: 10,
|
| 230 |
+
fontWeight: 600,
|
| 231 |
+
bgcolor: 'success.main',
|
| 232 |
+
color: 'common.white',
|
| 233 |
+
}}
|
| 234 |
+
/>
|
| 235 |
+
)}
|
| 236 |
+
</Stack>
|
| 237 |
+
<Typography variant="caption" color="text.secondary">
|
| 238 |
+
{fmtBytes(checkpoint.size_bytes)}
|
| 239 |
+
{checkpoint.max_duration_sec && ` Β· up to ${checkpoint.max_duration_sec}s`}
|
| 240 |
+
</Typography>
|
| 241 |
+
{incompatReason && !checkpoint.downloaded && (
|
| 242 |
+
<Typography variant="caption" color="warning.main" sx={{ display: 'block' }}>
|
| 243 |
+
Not supported on this machine
|
| 244 |
+
</Typography>
|
| 245 |
+
)}
|
| 246 |
+
</Box>
|
| 247 |
+
<Box>{renderAction()}</Box>
|
| 248 |
+
</Stack>
|
| 249 |
+
|
| 250 |
+
{(downloading || queued) && (
|
| 251 |
+
<Box sx={{ mt: 1 }}>
|
| 252 |
+
<LinearProgress
|
| 253 |
+
variant={queued ? 'indeterminate' : 'determinate'}
|
| 254 |
+
value={Math.min(100, pct)}
|
| 255 |
+
sx={{ height: 4, borderRadius: 2 }}
|
| 256 |
+
/>
|
| 257 |
+
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
| 258 |
+
{queued ? 'Queuedβ¦' : `${fmtBytes(job?.downloaded_bytes)} / ${fmtBytes(job?.total_bytes)}`}
|
| 259 |
+
</Typography>
|
| 260 |
+
</Box>
|
| 261 |
+
)}
|
| 262 |
+
|
| 263 |
+
{error && (
|
| 264 |
+
<Typography variant="caption" color="error" sx={{ mt: 0.5, display: 'block' }}>
|
| 265 |
+
{error}
|
| 266 |
+
</Typography>
|
| 267 |
+
)}
|
| 268 |
+
</Box>
|
| 269 |
+
);
|
| 270 |
+
}
|
app/frontend/src/components/DatasetPrep.js
ADDED
|
@@ -0,0 +1,1823 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useCallback, useEffect, useRef, useState } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Accordion,
|
| 4 |
+
AccordionDetails,
|
| 5 |
+
AccordionSummary,
|
| 6 |
+
Alert,
|
| 7 |
+
Autocomplete,
|
| 8 |
+
Box,
|
| 9 |
+
Button,
|
| 10 |
+
Checkbox,
|
| 11 |
+
Chip,
|
| 12 |
+
Dialog,
|
| 13 |
+
DialogActions,
|
| 14 |
+
DialogContent,
|
| 15 |
+
DialogTitle,
|
| 16 |
+
FormControl,
|
| 17 |
+
FormControlLabel,
|
| 18 |
+
IconButton,
|
| 19 |
+
InputLabel,
|
| 20 |
+
LinearProgress,
|
| 21 |
+
MenuItem,
|
| 22 |
+
Paper,
|
| 23 |
+
Portal,
|
| 24 |
+
Radio,
|
| 25 |
+
RadioGroup,
|
| 26 |
+
Select,
|
| 27 |
+
Snackbar,
|
| 28 |
+
Stack,
|
| 29 |
+
Switch,
|
| 30 |
+
Table,
|
| 31 |
+
TableBody,
|
| 32 |
+
TableCell,
|
| 33 |
+
TableContainer,
|
| 34 |
+
TableHead,
|
| 35 |
+
TableRow,
|
| 36 |
+
TextField,
|
| 37 |
+
Typography,
|
| 38 |
+
useTheme,
|
| 39 |
+
} from '@mui/material';
|
| 40 |
+
import { TIPS } from '../tooltips';
|
| 41 |
+
import Tooltip from './Tooltip';
|
| 42 |
+
import {
|
| 43 |
+
ChevronDown as ChevronDownIcon,
|
| 44 |
+
FolderOpenIcon,
|
| 45 |
+
PlusIcon,
|
| 46 |
+
WandSparkles,
|
| 47 |
+
SaveIcon,
|
| 48 |
+
Database as Database,
|
| 49 |
+
DatabaseZap as DatasetIcon,
|
| 50 |
+
Square as StopIcon,
|
| 51 |
+
Trash2 as TrashIcon,
|
| 52 |
+
Play as PlayIcon,
|
| 53 |
+
Pause as PauseIcon,
|
| 54 |
+
Scissors as ScissorsIcon,
|
| 55 |
+
Music as MusicIcon,
|
| 56 |
+
Activity as HealthIcon,
|
| 57 |
+
} from 'lucide-react';
|
| 58 |
+
import api from '../api';
|
| 59 |
+
import { appStyles } from '../theme';
|
| 60 |
+
|
| 61 |
+
/**
|
| 62 |
+
* DatasetPrep β sidecar-native dataset surface with a buffered editing model.
|
| 63 |
+
*
|
| 64 |
+
* One page, no modes. Pick or create a project. The dataset folder on disk
|
| 65 |
+
* is the *committed* state. Edits, auto-annotate output, and just-ingested
|
| 66 |
+
* audio all live in an in-memory session until the user explicitly hits
|
| 67 |
+
* Save (writes a draft) or Commit (writes .txt sidecars).
|
| 68 |
+
*/
|
| 69 |
+
export default function DatasetPrep({ onOpenCheckpointManager }) {
|
| 70 |
+
const [projects, setProjects] = useState([]);
|
| 71 |
+
const [selectedName, setSelectedName] = useState(() => {
|
| 72 |
+
try { return window.localStorage.getItem('fragmenta.datasetPrep.lastProject') || ''; }
|
| 73 |
+
catch { return ''; }
|
| 74 |
+
});
|
| 75 |
+
const [project, setProject] = useState(null);
|
| 76 |
+
const [createOpen, setCreateOpen] = useState(false);
|
| 77 |
+
const [loadOpen, setLoadOpen] = useState(false);
|
| 78 |
+
const [ingestOpen, setIngestOpen] = useState(false);
|
| 79 |
+
const [sliceTarget, setSliceTarget] = useState(null); // file_name or null
|
| 80 |
+
// Single confirm-dialog state powering destructive actions. Mirrors the
|
| 81 |
+
// Free GPU / Start Fresh confirm style from App.js β replaces the
|
| 82 |
+
// browser-native window.confirm() prompts so the UX is consistent.
|
| 83 |
+
const [confirm, setConfirm] = useState(null);
|
| 84 |
+
const [confirmBusy, setConfirmBusy] = useState(false);
|
| 85 |
+
const [error, setError] = useState('');
|
| 86 |
+
|
| 87 |
+
const [errorCode, setErrorCode] = useState('');
|
| 88 |
+
const [errorExtra, setErrorExtra] = useState(null);
|
| 89 |
+
const [annotateJob, setAnnotateJob] = useState(null);
|
| 90 |
+
const [notice, setNotice] = useState(null); // { severity, message } | null
|
| 91 |
+
// Phase 6 β pre-encoded latents
|
| 92 |
+
const [preEncodeJob, setPreEncodeJob] = useState(null);
|
| 93 |
+
const [preEncodeOffer, setPreEncodeOffer] = useState(false); // post-commit dialog
|
| 94 |
+
const [tier, setTier] = useState(() => {
|
| 95 |
+
try { return window.localStorage.getItem('fragmenta.datasetPrep.tier') || 'basic'; }
|
| 96 |
+
catch { return 'basic'; }
|
| 97 |
+
});
|
| 98 |
+
const [skipExisting, setSkipExisting] = useState(true);
|
| 99 |
+
|
| 100 |
+
const pollHandleRef = useRef(null);
|
| 101 |
+
const preEncodePollRef = useRef(null);
|
| 102 |
+
const isAnnotating = annotateJob?.state === 'running';
|
| 103 |
+
const isPreEncoding = preEncodeJob?.state === 'running' || preEncodeJob?.state === 'queued';
|
| 104 |
+
|
| 105 |
+
// --- Multi-row selection (for bulk Slice) -----------------------------
|
| 106 |
+
// Set<string> of clip file_names. Reset whenever the active project
|
| 107 |
+
// changes, since selections from a different project are meaningless.
|
| 108 |
+
const [selectedFiles, setSelectedFiles] = useState(() => new Set());
|
| 109 |
+
useEffect(() => { setSelectedFiles(new Set()); }, [selectedName]);
|
| 110 |
+
|
| 111 |
+
const toggleSelected = useCallback((fileName) => {
|
| 112 |
+
setSelectedFiles((prev) => {
|
| 113 |
+
const next = new Set(prev);
|
| 114 |
+
if (next.has(fileName)) next.delete(fileName);
|
| 115 |
+
else next.add(fileName);
|
| 116 |
+
return next;
|
| 117 |
+
});
|
| 118 |
+
}, []);
|
| 119 |
+
const toggleSelectAll = useCallback((clips) => {
|
| 120 |
+
setSelectedFiles((prev) => {
|
| 121 |
+
const allNames = clips.map((c) => c.file_name);
|
| 122 |
+
const allSelected = allNames.length > 0 && allNames.every((n) => prev.has(n));
|
| 123 |
+
return allSelected ? new Set() : new Set(allNames);
|
| 124 |
+
});
|
| 125 |
+
}, []);
|
| 126 |
+
const clearSelection = useCallback(() => setSelectedFiles(new Set()), []);
|
| 127 |
+
|
| 128 |
+
// --- Per-row audio preview --------------------------------------------
|
| 129 |
+
// One <audio> for the whole table. Rows just say "play me" / "pause";
|
| 130 |
+
// the parent reconciles which file is loaded and where the playhead is.
|
| 131 |
+
const audioRef = useRef(null);
|
| 132 |
+
const [playingFile, setPlayingFile] = useState(null);
|
| 133 |
+
const [playProgress, setPlayProgress] = useState(0); // 0..1
|
| 134 |
+
|
| 135 |
+
const stopPlayback = useCallback(() => {
|
| 136 |
+
const audio = audioRef.current;
|
| 137 |
+
if (audio) { audio.pause(); }
|
| 138 |
+
setPlayingFile(null);
|
| 139 |
+
setPlayProgress(0);
|
| 140 |
+
}, []);
|
| 141 |
+
|
| 142 |
+
const handlePlayToggle = useCallback((fileName) => {
|
| 143 |
+
if (!selectedName) return;
|
| 144 |
+
const audio = audioRef.current;
|
| 145 |
+
if (!audio) return;
|
| 146 |
+
if (playingFile === fileName) {
|
| 147 |
+
audio.pause();
|
| 148 |
+
setPlayingFile(null);
|
| 149 |
+
return;
|
| 150 |
+
}
|
| 151 |
+
const url = `/api/projects/${encodeURIComponent(selectedName)}/clip/${encodeURIComponent(fileName)}/audio`;
|
| 152 |
+
audio.src = url;
|
| 153 |
+
setPlayProgress(0);
|
| 154 |
+
setPlayingFile(fileName);
|
| 155 |
+
audio.play().catch(() => {
|
| 156 |
+
setPlayingFile(null);
|
| 157 |
+
});
|
| 158 |
+
}, [selectedName, playingFile]);
|
| 159 |
+
|
| 160 |
+
// Stop playback when the project changes β the audio element's src would
|
| 161 |
+
// suddenly refer to a different project's file.
|
| 162 |
+
useEffect(() => { stopPlayback(); }, [selectedName, stopPlayback]);
|
| 163 |
+
|
| 164 |
+
const refreshProjects = useCallback(async () => {
|
| 165 |
+
try {
|
| 166 |
+
const { data } = await api.get('/api/projects');
|
| 167 |
+
setProjects(data.projects || []);
|
| 168 |
+
} catch (e) { setError(extractError(e, 'Failed to list projects')); }
|
| 169 |
+
}, []);
|
| 170 |
+
|
| 171 |
+
const [health, setHealth] = useState(null);
|
| 172 |
+
const refreshHealth = useCallback(async (name) => {
|
| 173 |
+
if (!name) { setHealth(null); return; }
|
| 174 |
+
try {
|
| 175 |
+
const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}/health`);
|
| 176 |
+
setHealth(data);
|
| 177 |
+
} catch {
|
| 178 |
+
// Non-fatal β strip just hides until next refresh.
|
| 179 |
+
setHealth(null);
|
| 180 |
+
}
|
| 181 |
+
}, []);
|
| 182 |
+
|
| 183 |
+
const refreshProject = useCallback(async (name) => {
|
| 184 |
+
if (!name) { setProject(null); setHealth(null); return; }
|
| 185 |
+
try {
|
| 186 |
+
const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}`);
|
| 187 |
+
setProject(data);
|
| 188 |
+
refreshHealth(name);
|
| 189 |
+
} catch (e) {
|
| 190 |
+
if (e?.response?.status === 404) {
|
| 191 |
+
setSelectedName('');
|
| 192 |
+
setProject(null);
|
| 193 |
+
setHealth(null);
|
| 194 |
+
await refreshProjects();
|
| 195 |
+
return;
|
| 196 |
+
}
|
| 197 |
+
setError(extractError(e, 'Failed to load project'));
|
| 198 |
+
}
|
| 199 |
+
}, [refreshProjects, refreshHealth]);
|
| 200 |
+
|
| 201 |
+
useEffect(() => { refreshProjects(); }, [refreshProjects]);
|
| 202 |
+
|
| 203 |
+
const pollAnnotateStatus = useCallback(async function poll(name) {
|
| 204 |
+
try {
|
| 205 |
+
const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}/annotate/status`);
|
| 206 |
+
setAnnotateJob(data.job);
|
| 207 |
+
if (data.job.state === 'done') {
|
| 208 |
+
await refreshProject(name);
|
| 209 |
+
return;
|
| 210 |
+
}
|
| 211 |
+
if (data.job.state === 'error') {
|
| 212 |
+
setError(data.job.error || 'Annotation failed');
|
| 213 |
+
return;
|
| 214 |
+
}
|
| 215 |
+
// Only keep polling while the backend is actively annotating. Other
|
| 216 |
+
// states ('idle', 'cancelled', missing) terminate the loop so a
|
| 217 |
+
// freshly-mounted tab doesn't poll forever for a non-existent job.
|
| 218 |
+
if (data.job.state === 'running') {
|
| 219 |
+
pollHandleRef.current = window.setTimeout(() => poll(name), 500);
|
| 220 |
+
}
|
| 221 |
+
} catch (e) { setError(extractError(e, 'Status poll failed')); }
|
| 222 |
+
}, [refreshProject]);
|
| 223 |
+
|
| 224 |
+
// Phase 6 β pre-encode polling. Same survives-tab-switch shape as the
|
| 225 |
+
// annotate poller above.
|
| 226 |
+
const pollPreEncodeStatus = useCallback(async function poll(name) {
|
| 227 |
+
try {
|
| 228 |
+
const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}/pre-encode/status`);
|
| 229 |
+
setPreEncodeJob(data.job);
|
| 230 |
+
if (data.job.state === 'complete') {
|
| 231 |
+
refreshProject(name);
|
| 232 |
+
return;
|
| 233 |
+
}
|
| 234 |
+
if (data.job.state === 'failed') {
|
| 235 |
+
setError(data.job.error || 'Pre-encoding failed');
|
| 236 |
+
return;
|
| 237 |
+
}
|
| 238 |
+
if (data.job.state === 'running' || data.job.state === 'queued') {
|
| 239 |
+
preEncodePollRef.current = window.setTimeout(() => poll(name), 750);
|
| 240 |
+
}
|
| 241 |
+
} catch (e) { /* non-fatal β bar just freezes */ }
|
| 242 |
+
}, [refreshProject]);
|
| 243 |
+
|
| 244 |
+
useEffect(() => {
|
| 245 |
+
if (selectedName) {
|
| 246 |
+
try { window.localStorage.setItem('fragmenta.datasetPrep.lastProject', selectedName); } catch {}
|
| 247 |
+
refreshProject(selectedName);
|
| 248 |
+
// Re-bootstrap progress polling on (re)mount or project switch, so
|
| 249 |
+
// the progress strip survives tab changes while a job runs.
|
| 250 |
+
pollAnnotateStatus(selectedName);
|
| 251 |
+
pollPreEncodeStatus(selectedName);
|
| 252 |
+
} else {
|
| 253 |
+
setProject(null);
|
| 254 |
+
setAnnotateJob(null);
|
| 255 |
+
setPreEncodeJob(null);
|
| 256 |
+
}
|
| 257 |
+
return () => {
|
| 258 |
+
if (pollHandleRef.current) {
|
| 259 |
+
window.clearTimeout(pollHandleRef.current);
|
| 260 |
+
pollHandleRef.current = null;
|
| 261 |
+
}
|
| 262 |
+
if (preEncodePollRef.current) {
|
| 263 |
+
window.clearTimeout(preEncodePollRef.current);
|
| 264 |
+
preEncodePollRef.current = null;
|
| 265 |
+
}
|
| 266 |
+
};
|
| 267 |
+
}, [selectedName, refreshProject, pollAnnotateStatus, pollPreEncodeStatus]);
|
| 268 |
+
|
| 269 |
+
function changeTier(value) {
|
| 270 |
+
setTier(value);
|
| 271 |
+
try { window.localStorage.setItem('fragmenta.datasetPrep.tier', value); } catch {}
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
function trySelectProject(nextName) {
|
| 275 |
+
// Confirm before switching if there are unsaved or uncommitted edits.
|
| 276 |
+
if (project && (project.dirty || project.has_unsaved_changes) && nextName !== project.name) {
|
| 277 |
+
const ok = window.confirm(
|
| 278 |
+
`β${project.name}β has unsaved or uncommitted changes. Switch anyway? They'll stay in memory until you reload the project β but a backend restart will lose them.`,
|
| 279 |
+
);
|
| 280 |
+
if (!ok) return;
|
| 281 |
+
}
|
| 282 |
+
setSelectedName(nextName);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
async function handleAnnotate(scope /* "all" | [file_names] */, opts = {}) {
|
| 286 |
+
if (!project) return;
|
| 287 |
+
setError(''); setErrorCode(''); setErrorExtra(null);
|
| 288 |
+
try {
|
| 289 |
+
await api.post(`/api/projects/${encodeURIComponent(project.name)}/annotate`, {
|
| 290 |
+
tier,
|
| 291 |
+
scope: scope ?? 'all',
|
| 292 |
+
skip_existing: opts.skip_existing ?? skipExisting,
|
| 293 |
+
});
|
| 294 |
+
pollAnnotateStatus(project.name);
|
| 295 |
+
} catch (e) {
|
| 296 |
+
const body = e?.response?.data || {};
|
| 297 |
+
setError(extractError(e, 'Failed to start annotation'));
|
| 298 |
+
setErrorCode(body.code || '');
|
| 299 |
+
setErrorExtra(body.install_command ? { install_command: body.install_command } : null);
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
async function handleCancelAnnotate() {
|
| 304 |
+
if (!project) return;
|
| 305 |
+
try {
|
| 306 |
+
await api.post(`/api/projects/${encodeURIComponent(project.name)}/annotate/cancel`);
|
| 307 |
+
} catch (e) { setError(extractError(e, 'Cancel failed')); }
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
async function handleSave() {
|
| 311 |
+
if (!project) return;
|
| 312 |
+
setError('');
|
| 313 |
+
try {
|
| 314 |
+
const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/save`);
|
| 315 |
+
setProject(data);
|
| 316 |
+
setNotice({ severity: 'success', message: `Draft saved Β· ${data.clip_count} clips` });
|
| 317 |
+
} catch (e) { setError(extractError(e, 'Save failed')); }
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
async function handleStartPreEncode() {
|
| 321 |
+
if (!project) return;
|
| 322 |
+
setError('');
|
| 323 |
+
try {
|
| 324 |
+
const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/pre-encode`);
|
| 325 |
+
setPreEncodeJob(data.job);
|
| 326 |
+
pollPreEncodeStatus(project.name);
|
| 327 |
+
} catch (e) { setError(extractError(e, 'Pre-encode failed to start')); }
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
async function handleCancelPreEncode() {
|
| 331 |
+
if (!project) return;
|
| 332 |
+
try {
|
| 333 |
+
await api.post(`/api/projects/${encodeURIComponent(project.name)}/pre-encode/cancel`);
|
| 334 |
+
} catch (e) { setError(extractError(e, 'Cancel failed')); }
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
async function persistPreEncodeSuppression(suppress) {
|
| 338 |
+
if (!project) return;
|
| 339 |
+
try {
|
| 340 |
+
const { data } = await api.patch(
|
| 341 |
+
`/api/projects/${encodeURIComponent(project.name)}/pre-encode/prompt`,
|
| 342 |
+
{ suppress: !!suppress },
|
| 343 |
+
);
|
| 344 |
+
setProject(data);
|
| 345 |
+
} catch (e) { /* non-fatal β dialog still closes */ }
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
async function handleCommit() {
|
| 349 |
+
if (!project) return;
|
| 350 |
+
setError('');
|
| 351 |
+
try {
|
| 352 |
+
const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/commit`);
|
| 353 |
+
setProject(data);
|
| 354 |
+
await refreshProjects();
|
| 355 |
+
// Phase 6 β post-commit pre-encode prompt.
|
| 356 |
+
// Open the dialog unless: (a) latents already present (re-commit
|
| 357 |
+
// wiped them but we still avoid re-asking immediately), or
|
| 358 |
+
// (b) the user previously chose "Don't ask again".
|
| 359 |
+
if (!data.suppress_pre_encode_prompt && !data.latents_present && data.clip_count > 0) {
|
| 360 |
+
setPreEncodeOffer(true);
|
| 361 |
+
}
|
| 362 |
+
setNotice({
|
| 363 |
+
severity: 'success',
|
| 364 |
+
message: `Dataset created Β· ${data.clip_count} clips written to disk`,
|
| 365 |
+
});
|
| 366 |
+
} catch (e) { setError(extractError(e, 'Create Dataset failed')); }
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
function handleDiscard() {
|
| 370 |
+
if (!project) return;
|
| 371 |
+
setConfirm({
|
| 372 |
+
title: 'Delete unsaved changes',
|
| 373 |
+
body: `Delete all changes in β${project.name}β since the last created dataset? Audio files added since then will be removed.`,
|
| 374 |
+
warning: 'This cannot be undone.',
|
| 375 |
+
confirmLabel: 'Delete',
|
| 376 |
+
busyLabel: 'Deletingβ¦',
|
| 377 |
+
danger: true,
|
| 378 |
+
onConfirm: async () => {
|
| 379 |
+
setError('');
|
| 380 |
+
try {
|
| 381 |
+
const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/discard`);
|
| 382 |
+
setProject(data);
|
| 383 |
+
await refreshProjects();
|
| 384 |
+
setNotice({ severity: 'info', message: 'Unsaved changes discarded' });
|
| 385 |
+
} catch (e) { setError(extractError(e, 'Delete failed')); }
|
| 386 |
+
},
|
| 387 |
+
});
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
function handleDeleteProject(name) {
|
| 391 |
+
if (!name) return;
|
| 392 |
+
setConfirm({
|
| 393 |
+
title: 'Delete project',
|
| 394 |
+
body: `Permanently delete project β${name}β? Audio files, sidecars, and any drafts will be removed from disk.`,
|
| 395 |
+
warning: 'This cannot be undone.',
|
| 396 |
+
confirmLabel: 'Delete',
|
| 397 |
+
busyLabel: 'Deletingβ¦',
|
| 398 |
+
danger: true,
|
| 399 |
+
onConfirm: async () => {
|
| 400 |
+
setError('');
|
| 401 |
+
try {
|
| 402 |
+
await api.delete(`/api/projects/${encodeURIComponent(name)}`);
|
| 403 |
+
if (selectedName === name) {
|
| 404 |
+
stopPlayback();
|
| 405 |
+
setSelectedName('');
|
| 406 |
+
setProject(null);
|
| 407 |
+
try { window.localStorage.removeItem('fragmenta.datasetPrep.lastProject'); } catch {}
|
| 408 |
+
}
|
| 409 |
+
await refreshProjects();
|
| 410 |
+
} catch (e) { setError(extractError(e, 'Delete project failed')); }
|
| 411 |
+
},
|
| 412 |
+
});
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
async function handleChangeTemplatePreset(presetId) {
|
| 416 |
+
if (!project) return;
|
| 417 |
+
try {
|
| 418 |
+
const { data } = await api.patch(
|
| 419 |
+
`/api/projects/${encodeURIComponent(project.name)}/template`,
|
| 420 |
+
{ preset: presetId },
|
| 421 |
+
);
|
| 422 |
+
setProject(data);
|
| 423 |
+
} catch (e) {
|
| 424 |
+
setError(extractError(e, 'Could not update annotation style'));
|
| 425 |
+
}
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
function handleClearSelectedAnnotations() {
|
| 429 |
+
if (!project || selectedFiles.size === 0) return;
|
| 430 |
+
const count = selectedFiles.size;
|
| 431 |
+
const files = Array.from(selectedFiles);
|
| 432 |
+
setConfirm({
|
| 433 |
+
title: 'Clear',
|
| 434 |
+
body: `Clear annotations on ${count} clip${count === 1 ? '' : 's'}? Buffered in memory until you Save or Create Dataset.`,
|
| 435 |
+
warning: 'Use the Delete button to revert; this action itself canβt be undone in place.',
|
| 436 |
+
confirmLabel: `Clear (${count})`,
|
| 437 |
+
busyLabel: 'Clearingβ¦',
|
| 438 |
+
danger: true,
|
| 439 |
+
onConfirm: async () => {
|
| 440 |
+
setError('');
|
| 441 |
+
try {
|
| 442 |
+
for (const f of files) {
|
| 443 |
+
await api.patch(
|
| 444 |
+
`/api/projects/${encodeURIComponent(project.name)}/clip/${encodeURIComponent(f)}`,
|
| 445 |
+
{ prompt: '' },
|
| 446 |
+
);
|
| 447 |
+
}
|
| 448 |
+
clearSelection();
|
| 449 |
+
await refreshProject(project.name);
|
| 450 |
+
} catch (e) { setError(extractError(e, 'Clear annotations failed')); }
|
| 451 |
+
},
|
| 452 |
+
});
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
async function handleClipPromptChange(fileName, newPrompt) {
|
| 456 |
+
if (!project) return;
|
| 457 |
+
try {
|
| 458 |
+
await api.patch(
|
| 459 |
+
`/api/projects/${encodeURIComponent(project.name)}/clip/${encodeURIComponent(fileName)}`,
|
| 460 |
+
{ prompt: newPrompt },
|
| 461 |
+
);
|
| 462 |
+
// Reload to pick up dirty-state flip in the header.
|
| 463 |
+
await refreshProject(project.name);
|
| 464 |
+
} catch (e) { setError(extractError(e, 'Failed to save prompt')); }
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
async function handleClipDelete(fileName) {
|
| 468 |
+
if (!project) return;
|
| 469 |
+
if (!window.confirm(`Remove ${fileName} from this project? (Deletes the audio file from disk immediately β cannot be discarded back.)`)) return;
|
| 470 |
+
try {
|
| 471 |
+
await api.delete(
|
| 472 |
+
`/api/projects/${encodeURIComponent(project.name)}/clip/${encodeURIComponent(fileName)}`,
|
| 473 |
+
);
|
| 474 |
+
await refreshProject(project.name);
|
| 475 |
+
} catch (e) { setError(extractError(e, 'Failed to delete clip')); }
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
return (
|
| 479 |
+
<Paper variant="outlined" sx={{ p: { xs: 2.25, sm: 3 }, borderRadius: 2.5 }}>
|
| 480 |
+
<Stack spacing={2.5}>
|
| 481 |
+
<Box>
|
| 482 |
+
<Box sx={{ ...appStyles.sectionCardHeader, mb: 0.5 }}>
|
| 483 |
+
<Box component="span" sx={appStyles.sectionCardIcon}>
|
| 484 |
+
<Database size={20} />
|
| 485 |
+
</Box>
|
| 486 |
+
<Typography variant="h6" sx={appStyles.sectionCardTitle}>
|
| 487 |
+
Dataset Workbench
|
| 488 |
+
</Typography>
|
| 489 |
+
<Box sx={{ flex: 1 }} />
|
| 490 |
+
<Button
|
| 491 |
+
variant="outlined"
|
| 492 |
+
size="small"
|
| 493 |
+
startIcon={<FolderOpenIcon size={16} />}
|
| 494 |
+
onClick={() => setLoadOpen(true)}
|
| 495 |
+
disabled={projects.length === 0}
|
| 496 |
+
>
|
| 497 |
+
Load project
|
| 498 |
+
</Button>
|
| 499 |
+
<Button
|
| 500 |
+
variant="outlined"
|
| 501 |
+
size="small"
|
| 502 |
+
startIcon={<PlusIcon size={16} />}
|
| 503 |
+
onClick={() => setCreateOpen(true)}
|
| 504 |
+
>
|
| 505 |
+
New project
|
| 506 |
+
</Button>
|
| 507 |
+
</Box>
|
| 508 |
+
<Typography variant="body2" color="text.secondary">
|
| 509 |
+
Create a new dataset or load and edit one.
|
| 510 |
+
</Typography>
|
| 511 |
+
<Typography variant="body2" color="text.secondary" paddingBottom={2}>
|
| 512 |
+
You can auto-annotate using Librosa and CLAP or annotate everything manually.
|
| 513 |
+
</Typography>
|
| 514 |
+
</Box>
|
| 515 |
+
|
| 516 |
+
{error && (
|
| 517 |
+
<Alert
|
| 518 |
+
severity={(errorCode === 'clap_not_available' || errorCode === 'clap_package_missing') ? 'warning' : 'error'}
|
| 519 |
+
onClose={() => { setError(''); setErrorCode(''); setErrorExtra(null); }}
|
| 520 |
+
action={
|
| 521 |
+
errorCode === 'clap_not_available' && onOpenCheckpointManager ? (
|
| 522 |
+
<Button
|
| 523 |
+
color="inherit"
|
| 524 |
+
size="small"
|
| 525 |
+
onClick={() => { setError(''); setErrorCode(''); setErrorExtra(null); onOpenCheckpointManager(); }}
|
| 526 |
+
>
|
| 527 |
+
Open Model Management
|
| 528 |
+
</Button>
|
| 529 |
+
) : null
|
| 530 |
+
}
|
| 531 |
+
>
|
| 532 |
+
<Box>
|
| 533 |
+
<Typography variant="body2">{error}</Typography>
|
| 534 |
+
{errorCode === 'clap_package_missing' && errorExtra?.install_command && (
|
| 535 |
+
<Box
|
| 536 |
+
component="pre"
|
| 537 |
+
sx={{
|
| 538 |
+
mt: 1,
|
| 539 |
+
mb: 0,
|
| 540 |
+
p: 1,
|
| 541 |
+
borderRadius: 1,
|
| 542 |
+
bgcolor: 'action.hover',
|
| 543 |
+
fontSize: '0.8rem',
|
| 544 |
+
fontFamily: 'monospace',
|
| 545 |
+
overflowX: 'auto',
|
| 546 |
+
}}
|
| 547 |
+
>
|
| 548 |
+
{errorExtra.install_command}
|
| 549 |
+
</Box>
|
| 550 |
+
)}
|
| 551 |
+
</Box>
|
| 552 |
+
</Alert>
|
| 553 |
+
)}
|
| 554 |
+
|
| 555 |
+
{project && (
|
| 556 |
+
<Stack spacing={2}>
|
| 557 |
+
<ProjectHeader
|
| 558 |
+
project={project}
|
| 559 |
+
onSave={handleSave}
|
| 560 |
+
onCommit={handleCommit}
|
| 561 |
+
onDiscard={handleDiscard}
|
| 562 |
+
onAddAudio={() => setIngestOpen(true)}
|
| 563 |
+
disabled={isAnnotating}
|
| 564 |
+
/>
|
| 565 |
+
|
| 566 |
+
<HealthStrip
|
| 567 |
+
health={health}
|
| 568 |
+
onSelectFiles={(files) => setSelectedFiles(new Set(files))}
|
| 569 |
+
/>
|
| 570 |
+
|
| 571 |
+
{isAnnotating && annotateJob && (
|
| 572 |
+
<Box>
|
| 573 |
+
<LinearProgress
|
| 574 |
+
variant={annotateJob.total > 0 ? 'determinate' : 'indeterminate'}
|
| 575 |
+
value={annotateJob.total > 0 ? (annotateJob.current / annotateJob.total) * 100 : undefined}
|
| 576 |
+
/>
|
| 577 |
+
<Box sx={{ mt: 0.75, display: 'flex', alignItems: 'center', gap: 1.5 }}>
|
| 578 |
+
<Typography variant="caption" color="text.secondary" sx={{ flex: 1 }}>
|
| 579 |
+
Annotating {annotateJob.current} / {annotateJob.total}
|
| 580 |
+
{annotateJob.current_file ? ` Β· ${annotateJob.current_file}` : ''}
|
| 581 |
+
</Typography>
|
| 582 |
+
<Button
|
| 583 |
+
size="small"
|
| 584 |
+
variant="outlined"
|
| 585 |
+
color="error"
|
| 586 |
+
startIcon={<StopIcon size={14} />}
|
| 587 |
+
onClick={handleCancelAnnotate}
|
| 588 |
+
>
|
| 589 |
+
Stop
|
| 590 |
+
</Button>
|
| 591 |
+
</Box>
|
| 592 |
+
</Box>
|
| 593 |
+
)}
|
| 594 |
+
|
| 595 |
+
{isPreEncoding && preEncodeJob && (
|
| 596 |
+
<Box>
|
| 597 |
+
<LinearProgress
|
| 598 |
+
variant={preEncodeJob.total > 0 ? 'determinate' : 'indeterminate'}
|
| 599 |
+
value={preEncodeJob.total > 0 ? (preEncodeJob.current / preEncodeJob.total) * 100 : undefined}
|
| 600 |
+
/>
|
| 601 |
+
<Box sx={{ mt: 0.75, display: 'flex', alignItems: 'center', gap: 1.5 }}>
|
| 602 |
+
<Typography variant="caption" color="text.secondary" sx={{ flex: 1 }}>
|
| 603 |
+
Pre-encoding latents Β· {preEncodeJob.current} / {preEncodeJob.total}
|
| 604 |
+
{preEncodeJob.autoencoder ? ` Β· ${preEncodeJob.autoencoder}` : ''}
|
| 605 |
+
</Typography>
|
| 606 |
+
<Button
|
| 607 |
+
size="small"
|
| 608 |
+
variant="outlined"
|
| 609 |
+
color="error"
|
| 610 |
+
startIcon={<StopIcon size={14} />}
|
| 611 |
+
onClick={handleCancelPreEncode}
|
| 612 |
+
>
|
| 613 |
+
Stop
|
| 614 |
+
</Button>
|
| 615 |
+
</Box>
|
| 616 |
+
</Box>
|
| 617 |
+
)}
|
| 618 |
+
|
| 619 |
+
<ClipTable
|
| 620 |
+
projectName={selectedName}
|
| 621 |
+
clips={project.clips}
|
| 622 |
+
playingFile={playingFile}
|
| 623 |
+
playProgress={playProgress}
|
| 624 |
+
onPlayToggle={handlePlayToggle}
|
| 625 |
+
onPromptChange={handleClipPromptChange}
|
| 626 |
+
onAnnotate={(fname) => handleAnnotate([fname], { skip_existing: false })}
|
| 627 |
+
onDelete={(fname) => {
|
| 628 |
+
if (playingFile === fname) stopPlayback();
|
| 629 |
+
return handleClipDelete(fname);
|
| 630 |
+
}}
|
| 631 |
+
onSlice={(fname) => {
|
| 632 |
+
if (playingFile === fname) stopPlayback();
|
| 633 |
+
setSliceTarget(fname);
|
| 634 |
+
}}
|
| 635 |
+
selectedFiles={selectedFiles}
|
| 636 |
+
onToggleSelected={toggleSelected}
|
| 637 |
+
onToggleSelectAll={() => toggleSelectAll(project.clips)}
|
| 638 |
+
disabled={isAnnotating}
|
| 639 |
+
toolbar={
|
| 640 |
+
<Stack spacing={1}>
|
| 641 |
+
<Box sx={{ display: 'flex', alignItems: 'center', flexWrap: 'wrap', gap: 1.5 }}>
|
| 642 |
+
<Button
|
| 643 |
+
variant="contained"
|
| 644 |
+
color="warm"
|
| 645 |
+
size="small"
|
| 646 |
+
startIcon={<WandSparkles size={16} />}
|
| 647 |
+
onClick={() => handleAnnotate('all')}
|
| 648 |
+
disabled={isAnnotating || project.clip_count === 0}
|
| 649 |
+
>
|
| 650 |
+
Auto-annotate all
|
| 651 |
+
</Button>
|
| 652 |
+
<FormControl size="small" sx={{ minWidth: 180 }}>
|
| 653 |
+
<Select
|
| 654 |
+
value={project.prompt_template_preset || 'music'}
|
| 655 |
+
onChange={(e) => handleChangeTemplatePreset(e.target.value)}
|
| 656 |
+
disabled={isAnnotating}
|
| 657 |
+
renderValue={(v) => {
|
| 658 |
+
const p = (project.prompt_template_presets || []).find((x) => x.id === v);
|
| 659 |
+
return p ? p.label : v;
|
| 660 |
+
}}
|
| 661 |
+
>
|
| 662 |
+
{(project.prompt_template_presets || []).map((p) => (
|
| 663 |
+
<MenuItem key={p.id} value={p.id}>
|
| 664 |
+
<Box>
|
| 665 |
+
<Typography variant="body2">{p.label}</Typography>
|
| 666 |
+
<Typography variant="caption" color="text.secondary">
|
| 667 |
+
{p.description}
|
| 668 |
+
</Typography>
|
| 669 |
+
</Box>
|
| 670 |
+
</MenuItem>
|
| 671 |
+
))}
|
| 672 |
+
</Select>
|
| 673 |
+
</FormControl>
|
| 674 |
+
<Tooltip title={TIPS.dataset.richAnnotate}>
|
| 675 |
+
<FormControlLabel
|
| 676 |
+
control={
|
| 677 |
+
<Switch
|
| 678 |
+
size="small"
|
| 679 |
+
checked={tier === 'rich'}
|
| 680 |
+
onChange={(e) => changeTier(e.target.checked ? 'rich' : 'basic')}
|
| 681 |
+
disabled={isAnnotating}
|
| 682 |
+
/>
|
| 683 |
+
}
|
| 684 |
+
label={<Typography variant="caption" color="text.secondary">Rich annotation</Typography>}
|
| 685 |
+
sx={{ mr: 0 }}
|
| 686 |
+
/>
|
| 687 |
+
</Tooltip>
|
| 688 |
+
<Tooltip title={TIPS.dataset.skipAnnotated}>
|
| 689 |
+
<FormControlLabel
|
| 690 |
+
control={
|
| 691 |
+
<Switch
|
| 692 |
+
size="small"
|
| 693 |
+
checked={skipExisting}
|
| 694 |
+
onChange={(e) => setSkipExisting(e.target.checked)}
|
| 695 |
+
disabled={isAnnotating}
|
| 696 |
+
/>
|
| 697 |
+
}
|
| 698 |
+
label={<Typography variant="caption" color="text.secondary">Skip already annotated</Typography>}
|
| 699 |
+
sx={{ mr: 0 }}
|
| 700 |
+
/>
|
| 701 |
+
</Tooltip>
|
| 702 |
+
<Box sx={{ flex: 1 }} />
|
| 703 |
+
{selectedFiles.size > 0 && (
|
| 704 |
+
<Button
|
| 705 |
+
variant="outlined"
|
| 706 |
+
color="error"
|
| 707 |
+
size="small"
|
| 708 |
+
startIcon={<TrashIcon size={16} />}
|
| 709 |
+
onClick={handleClearSelectedAnnotations}
|
| 710 |
+
disabled={isAnnotating}
|
| 711 |
+
>
|
| 712 |
+
Clear annotations ({selectedFiles.size})
|
| 713 |
+
</Button>
|
| 714 |
+
)}
|
| 715 |
+
</Box>
|
| 716 |
+
{tier === 'rich' && (
|
| 717 |
+
<ClapVocabAccordion disabled={isAnnotating} />
|
| 718 |
+
)}
|
| 719 |
+
</Stack>
|
| 720 |
+
}
|
| 721 |
+
/>
|
| 722 |
+
<audio
|
| 723 |
+
ref={audioRef}
|
| 724 |
+
style={{ display: 'none' }}
|
| 725 |
+
onTimeUpdate={(e) => {
|
| 726 |
+
const a = e.currentTarget;
|
| 727 |
+
if (a.duration && isFinite(a.duration)) {
|
| 728 |
+
setPlayProgress(a.currentTime / a.duration);
|
| 729 |
+
}
|
| 730 |
+
}}
|
| 731 |
+
onEnded={() => { setPlayingFile(null); setPlayProgress(0); }}
|
| 732 |
+
onError={() => { setPlayingFile(null); setPlayProgress(0); }}
|
| 733 |
+
/>
|
| 734 |
+
</Stack>
|
| 735 |
+
)}
|
| 736 |
+
|
| 737 |
+
<CreateProjectDialog
|
| 738 |
+
open={createOpen}
|
| 739 |
+
existingNames={projects.map((p) => p.name)}
|
| 740 |
+
onClose={() => setCreateOpen(false)}
|
| 741 |
+
onCreated={async (name) => {
|
| 742 |
+
setCreateOpen(false);
|
| 743 |
+
await refreshProjects();
|
| 744 |
+
setSelectedName(name);
|
| 745 |
+
}}
|
| 746 |
+
/>
|
| 747 |
+
|
| 748 |
+
<LoadProjectDialog
|
| 749 |
+
open={loadOpen}
|
| 750 |
+
projects={projects}
|
| 751 |
+
currentName={selectedName}
|
| 752 |
+
onClose={() => setLoadOpen(false)}
|
| 753 |
+
onLoad={(name) => {
|
| 754 |
+
setLoadOpen(false);
|
| 755 |
+
trySelectProject(name);
|
| 756 |
+
}}
|
| 757 |
+
onDeleteProject={handleDeleteProject}
|
| 758 |
+
/>
|
| 759 |
+
|
| 760 |
+
<IngestDialog
|
| 761 |
+
open={ingestOpen}
|
| 762 |
+
projectName={project?.name}
|
| 763 |
+
onClose={() => setIngestOpen(false)}
|
| 764 |
+
onIngested={async () => {
|
| 765 |
+
setIngestOpen(false);
|
| 766 |
+
if (project) await refreshProject(project.name);
|
| 767 |
+
await refreshProjects();
|
| 768 |
+
}}
|
| 769 |
+
/>
|
| 770 |
+
|
| 771 |
+
<SliceDialog
|
| 772 |
+
open={Boolean(sliceTarget)}
|
| 773 |
+
projectName={project?.name}
|
| 774 |
+
fileName={sliceTarget}
|
| 775 |
+
onClose={() => setSliceTarget(null)}
|
| 776 |
+
onSliced={async () => {
|
| 777 |
+
clearSelection();
|
| 778 |
+
if (project) await refreshProject(project.name);
|
| 779 |
+
await refreshProjects();
|
| 780 |
+
}}
|
| 781 |
+
/>
|
| 782 |
+
|
| 783 |
+
<Dialog
|
| 784 |
+
open={Boolean(confirm)}
|
| 785 |
+
onClose={confirmBusy ? undefined : () => setConfirm(null)}
|
| 786 |
+
aria-labelledby="dataset-confirm-title"
|
| 787 |
+
>
|
| 788 |
+
<DialogTitle id="dataset-confirm-title">
|
| 789 |
+
{confirm?.title}
|
| 790 |
+
</DialogTitle>
|
| 791 |
+
<DialogContent>
|
| 792 |
+
<Typography sx={appStyles.dialogBodyText}>
|
| 793 |
+
{confirm?.body}
|
| 794 |
+
</Typography>
|
| 795 |
+
{confirm?.warning && (
|
| 796 |
+
<Typography variant="body2" color="warning.main" sx={appStyles.dialogErrorText}>
|
| 797 |
+
{confirm.warning}
|
| 798 |
+
</Typography>
|
| 799 |
+
)}
|
| 800 |
+
</DialogContent>
|
| 801 |
+
<DialogActions>
|
| 802 |
+
<Button onClick={() => setConfirm(null)} disabled={confirmBusy}>
|
| 803 |
+
Cancel
|
| 804 |
+
</Button>
|
| 805 |
+
<Button
|
| 806 |
+
onClick={async () => {
|
| 807 |
+
if (!confirm?.onConfirm) { setConfirm(null); return; }
|
| 808 |
+
setConfirmBusy(true);
|
| 809 |
+
try {
|
| 810 |
+
await confirm.onConfirm();
|
| 811 |
+
} finally {
|
| 812 |
+
setConfirmBusy(false);
|
| 813 |
+
setConfirm(null);
|
| 814 |
+
}
|
| 815 |
+
}}
|
| 816 |
+
color={confirm?.danger ? 'error' : 'primary'}
|
| 817 |
+
variant="contained"
|
| 818 |
+
disabled={confirmBusy}
|
| 819 |
+
>
|
| 820 |
+
{confirmBusy ? (confirm?.busyLabel || 'Workingβ¦') : (confirm?.confirmLabel || 'Confirm')}
|
| 821 |
+
</Button>
|
| 822 |
+
</DialogActions>
|
| 823 |
+
</Dialog>
|
| 824 |
+
|
| 825 |
+
{/* Phase 6 β post-commit pre-encode dialog. Surfaces after a
|
| 826 |
+
successful Create Dataset commit unless the user previously
|
| 827 |
+
chose "Don't ask again". */}
|
| 828 |
+
<Dialog
|
| 829 |
+
open={preEncodeOffer}
|
| 830 |
+
onClose={() => setPreEncodeOffer(false)}
|
| 831 |
+
maxWidth="xs"
|
| 832 |
+
fullWidth
|
| 833 |
+
>
|
| 834 |
+
<DialogTitle>Pre-encode latents?</DialogTitle>
|
| 835 |
+
<DialogContent>
|
| 836 |
+
<Typography variant="body2" sx={{ mb: 1 }}>
|
| 837 |
+
Encode your audio into SA3 latents now to speed up training. The
|
| 838 |
+
autoencoder runs once up-front instead of every training step.
|
| 839 |
+
</Typography>
|
| 840 |
+
<Typography variant="caption" color="text.secondary">
|
| 841 |
+
Takes a few minutes for ~50 clips. Latents live in
|
| 842 |
+
<code> {project?.name}/.latents/</code> and get wiped automatically
|
| 843 |
+
when you next commit or edit a clip.
|
| 844 |
+
</Typography>
|
| 845 |
+
</DialogContent>
|
| 846 |
+
<DialogActions sx={{ flexWrap: 'wrap' }}>
|
| 847 |
+
<Button
|
| 848 |
+
onClick={() => {
|
| 849 |
+
persistPreEncodeSuppression(true);
|
| 850 |
+
setPreEncodeOffer(false);
|
| 851 |
+
}}
|
| 852 |
+
sx={{ mr: 'auto' }}
|
| 853 |
+
>
|
| 854 |
+
Don't ask again
|
| 855 |
+
</Button>
|
| 856 |
+
<Button onClick={() => setPreEncodeOffer(false)}>Not now</Button>
|
| 857 |
+
<Button
|
| 858 |
+
variant="contained"
|
| 859 |
+
onClick={() => {
|
| 860 |
+
setPreEncodeOffer(false);
|
| 861 |
+
handleStartPreEncode();
|
| 862 |
+
}}
|
| 863 |
+
>
|
| 864 |
+
Pre-encode now
|
| 865 |
+
</Button>
|
| 866 |
+
</DialogActions>
|
| 867 |
+
</Dialog>
|
| 868 |
+
|
| 869 |
+
<Portal>
|
| 870 |
+
<Snackbar
|
| 871 |
+
open={Boolean(notice)}
|
| 872 |
+
autoHideDuration={4000}
|
| 873 |
+
onClose={() => setNotice(null)}
|
| 874 |
+
anchorOrigin={{ vertical: 'bottom', horizontal: 'right' }}
|
| 875 |
+
>
|
| 876 |
+
{notice ? (
|
| 877 |
+
<Alert
|
| 878 |
+
onClose={() => setNotice(null)}
|
| 879 |
+
severity={notice.severity}
|
| 880 |
+
variant="filled"
|
| 881 |
+
sx={{ width: '100%' }}
|
| 882 |
+
>
|
| 883 |
+
{notice.message}
|
| 884 |
+
</Alert>
|
| 885 |
+
) : undefined}
|
| 886 |
+
</Snackbar>
|
| 887 |
+
</Portal>
|
| 888 |
+
</Stack>
|
| 889 |
+
</Paper>
|
| 890 |
+
);
|
| 891 |
+
}
|
| 892 |
+
|
| 893 |
+
// ---------- subcomponents --------------------------------------------------
|
| 894 |
+
|
| 895 |
+
function ClapVocabAccordion({ disabled }) {
|
| 896 |
+
const [labels, setLabels] = useState({ genre: [], mood: [], instruments: [] });
|
| 897 |
+
const [overridden, setOverridden] = useState(false);
|
| 898 |
+
const [dirty, setDirty] = useState(false);
|
| 899 |
+
const [busy, setBusy] = useState(false);
|
| 900 |
+
const [vocabError, setVocabError] = useState('');
|
| 901 |
+
|
| 902 |
+
useEffect(() => {
|
| 903 |
+
let cancelled = false;
|
| 904 |
+
(async () => {
|
| 905 |
+
try {
|
| 906 |
+
const { data } = await api.get('/api/annotator-labels');
|
| 907 |
+
if (cancelled) return;
|
| 908 |
+
setLabels(data.labels || { genre: [], mood: [], instruments: [] });
|
| 909 |
+
setOverridden(!!data.overridden);
|
| 910 |
+
setDirty(false);
|
| 911 |
+
} catch (e) {
|
| 912 |
+
if (!cancelled) setVocabError(extractError(e, 'Failed to load vocabulary'));
|
| 913 |
+
}
|
| 914 |
+
})();
|
| 915 |
+
return () => { cancelled = true; };
|
| 916 |
+
}, []);
|
| 917 |
+
|
| 918 |
+
function setCategory(cat, values) {
|
| 919 |
+
setLabels((prev) => ({ ...prev, [cat]: values }));
|
| 920 |
+
setDirty(true);
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
async function save() {
|
| 924 |
+
setBusy(true);
|
| 925 |
+
setVocabError('');
|
| 926 |
+
try {
|
| 927 |
+
await api.put('/api/annotator-labels', labels);
|
| 928 |
+
setDirty(false);
|
| 929 |
+
setOverridden(true);
|
| 930 |
+
} catch (e) {
|
| 931 |
+
setVocabError(extractError(e, 'Failed to save vocabulary'));
|
| 932 |
+
} finally {
|
| 933 |
+
setBusy(false);
|
| 934 |
+
}
|
| 935 |
+
}
|
| 936 |
+
|
| 937 |
+
async function reset() {
|
| 938 |
+
if (!window.confirm('Reset vocabulary to the built-in defaults? Your custom tags will be lost.')) return;
|
| 939 |
+
setBusy(true);
|
| 940 |
+
setVocabError('');
|
| 941 |
+
try {
|
| 942 |
+
await api.delete('/api/annotator-labels');
|
| 943 |
+
const { data } = await api.get('/api/annotator-labels');
|
| 944 |
+
setLabels(data.labels || { genre: [], mood: [], instruments: [] });
|
| 945 |
+
setOverridden(false);
|
| 946 |
+
setDirty(false);
|
| 947 |
+
} catch (e) {
|
| 948 |
+
setVocabError(extractError(e, 'Failed to reset vocabulary'));
|
| 949 |
+
} finally {
|
| 950 |
+
setBusy(false);
|
| 951 |
+
}
|
| 952 |
+
}
|
| 953 |
+
|
| 954 |
+
const tagCount = (labels.genre?.length || 0) + (labels.mood?.length || 0) + (labels.instruments?.length || 0);
|
| 955 |
+
|
| 956 |
+
return (
|
| 957 |
+
<Accordion
|
| 958 |
+
disableGutters
|
| 959 |
+
sx={{ '&, &.Mui-expanded': { mt: 0, mb: 0 } }}
|
| 960 |
+
>
|
| 961 |
+
<AccordionSummary
|
| 962 |
+
expandIcon={<ChevronDownIcon size={18} />}
|
| 963 |
+
sx={{
|
| 964 |
+
minHeight: 48,
|
| 965 |
+
'&.Mui-expanded': { minHeight: 48 },
|
| 966 |
+
'& .MuiAccordionSummary-content': {
|
| 967 |
+
margin: '12px 0',
|
| 968 |
+
'&.Mui-expanded': { margin: '12px 0' },
|
| 969 |
+
},
|
| 970 |
+
}}
|
| 971 |
+
>
|
| 972 |
+
<Typography variant="subtitle1">CLAP Vocabulary</Typography>
|
| 973 |
+
<Typography variant="caption" color="text.secondary" sx={{ ml: 1.5, alignSelf: 'center' }}>
|
| 974 |
+
{overridden ? 'custom' : 'defaults'} Β· {tagCount} tags
|
| 975 |
+
</Typography>
|
| 976 |
+
</AccordionSummary>
|
| 977 |
+
<AccordionDetails>
|
| 978 |
+
<Stack spacing={2}>
|
| 979 |
+
<Typography variant="body2" color="text.secondary">
|
| 980 |
+
Words CLAP scores each clip against. Empty categories are ignored. Tweak to match your dataset's territory.
|
| 981 |
+
</Typography>
|
| 982 |
+
<VocabCategory
|
| 983 |
+
label="Genre"
|
| 984 |
+
values={labels.genre || []}
|
| 985 |
+
onChange={(v) => setCategory('genre', v)}
|
| 986 |
+
disabled={disabled || busy}
|
| 987 |
+
/>
|
| 988 |
+
<VocabCategory
|
| 989 |
+
label="Mood"
|
| 990 |
+
values={labels.mood || []}
|
| 991 |
+
onChange={(v) => setCategory('mood', v)}
|
| 992 |
+
disabled={disabled || busy}
|
| 993 |
+
/>
|
| 994 |
+
<VocabCategory
|
| 995 |
+
label="Instruments"
|
| 996 |
+
values={labels.instruments || []}
|
| 997 |
+
onChange={(v) => setCategory('instruments', v)}
|
| 998 |
+
disabled={disabled || busy}
|
| 999 |
+
/>
|
| 1000 |
+
{vocabError && <Alert severity="error" onClose={() => setVocabError('')}>{vocabError}</Alert>}
|
| 1001 |
+
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1.5 }}>
|
| 1002 |
+
<Button
|
| 1003 |
+
variant="text"
|
| 1004 |
+
size="small"
|
| 1005 |
+
onClick={reset}
|
| 1006 |
+
disabled={disabled || busy || !overridden}
|
| 1007 |
+
>
|
| 1008 |
+
Reset to defaults
|
| 1009 |
+
</Button>
|
| 1010 |
+
<Box sx={{ flex: 1 }} />
|
| 1011 |
+
<Button
|
| 1012 |
+
variant="contained"
|
| 1013 |
+
size="small"
|
| 1014 |
+
onClick={save}
|
| 1015 |
+
disabled={disabled || busy || !dirty}
|
| 1016 |
+
>
|
| 1017 |
+
Save vocabulary
|
| 1018 |
+
</Button>
|
| 1019 |
+
</Box>
|
| 1020 |
+
</Stack>
|
| 1021 |
+
</AccordionDetails>
|
| 1022 |
+
</Accordion>
|
| 1023 |
+
);
|
| 1024 |
+
}
|
| 1025 |
+
|
| 1026 |
+
function VocabCategory({ label, values, onChange, disabled }) {
|
| 1027 |
+
return (
|
| 1028 |
+
<Autocomplete
|
| 1029 |
+
multiple
|
| 1030 |
+
freeSolo
|
| 1031 |
+
options={[]}
|
| 1032 |
+
value={values}
|
| 1033 |
+
onChange={(_e, newValues) => onChange(newValues)}
|
| 1034 |
+
disabled={disabled}
|
| 1035 |
+
renderTags={(value, getTagProps) =>
|
| 1036 |
+
value.map((option, index) => {
|
| 1037 |
+
const tagProps = getTagProps({ index });
|
| 1038 |
+
return (
|
| 1039 |
+
<Chip
|
| 1040 |
+
variant="outlined"
|
| 1041 |
+
size="small"
|
| 1042 |
+
label={option}
|
| 1043 |
+
{...tagProps}
|
| 1044 |
+
key={`${option}-${index}`}
|
| 1045 |
+
/>
|
| 1046 |
+
);
|
| 1047 |
+
})
|
| 1048 |
+
}
|
| 1049 |
+
renderInput={(params) => (
|
| 1050 |
+
<TextField
|
| 1051 |
+
{...params}
|
| 1052 |
+
label={label}
|
| 1053 |
+
placeholder="Add tag, press Enter"
|
| 1054 |
+
size="small"
|
| 1055 |
+
/>
|
| 1056 |
+
)}
|
| 1057 |
+
/>
|
| 1058 |
+
);
|
| 1059 |
+
}
|
| 1060 |
+
|
| 1061 |
+
function LoadProjectDialog({ open, projects, currentName, onClose, onLoad, onDeleteProject }) {
|
| 1062 |
+
const [picked, setPicked] = useState(currentName || '');
|
| 1063 |
+
|
| 1064 |
+
useEffect(() => {
|
| 1065 |
+
if (open) setPicked(currentName || (projects[0]?.name ?? ''));
|
| 1066 |
+
}, [open, currentName, projects]);
|
| 1067 |
+
|
| 1068 |
+
return (
|
| 1069 |
+
<Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
|
| 1070 |
+
<DialogTitle>Load project</DialogTitle>
|
| 1071 |
+
<DialogContent>
|
| 1072 |
+
{projects.length === 0 ? (
|
| 1073 |
+
<Typography variant="body2" color="text.secondary" sx={{ py: 2 }}>
|
| 1074 |
+
No projects yet. Create one first.
|
| 1075 |
+
</Typography>
|
| 1076 |
+
) : (
|
| 1077 |
+
<RadioGroup value={picked} onChange={(e) => setPicked(e.target.value)}>
|
| 1078 |
+
{projects.map((p) => (
|
| 1079 |
+
<Box
|
| 1080 |
+
key={p.name}
|
| 1081 |
+
sx={{ display: 'flex', alignItems: 'center', gap: 1, py: 0.25 }}
|
| 1082 |
+
>
|
| 1083 |
+
<FormControlLabel
|
| 1084 |
+
value={p.name}
|
| 1085 |
+
control={<Radio size="small" />}
|
| 1086 |
+
label={
|
| 1087 |
+
<Box>
|
| 1088 |
+
<Typography variant="body2">{p.name}</Typography>
|
| 1089 |
+
<Typography variant="caption" color="text.secondary">
|
| 1090 |
+
{p.clip_count} clip{p.clip_count === 1 ? '' : 's'}
|
| 1091 |
+
{p.has_draft ? ' Β· has unsaved draft' : ''}
|
| 1092 |
+
</Typography>
|
| 1093 |
+
</Box>
|
| 1094 |
+
}
|
| 1095 |
+
sx={{ alignItems: 'flex-start', flex: 1, mr: 0 }}
|
| 1096 |
+
/>
|
| 1097 |
+
<Tooltip title={TIPS.dataset.deleteProject}>
|
| 1098 |
+
<span>
|
| 1099 |
+
<IconButton
|
| 1100 |
+
size="small"
|
| 1101 |
+
sx={{ color: 'text.disabled', '&:hover': { color: 'error.main', bgcolor: 'action.hover' } }}
|
| 1102 |
+
onClick={() => onDeleteProject(p.name)}
|
| 1103 |
+
>
|
| 1104 |
+
<TrashIcon size={16} />
|
| 1105 |
+
</IconButton>
|
| 1106 |
+
</span>
|
| 1107 |
+
</Tooltip>
|
| 1108 |
+
</Box>
|
| 1109 |
+
))}
|
| 1110 |
+
</RadioGroup>
|
| 1111 |
+
)}
|
| 1112 |
+
</DialogContent>
|
| 1113 |
+
<DialogActions>
|
| 1114 |
+
<Button onClick={onClose}>Cancel</Button>
|
| 1115 |
+
<Button
|
| 1116 |
+
variant="contained"
|
| 1117 |
+
onClick={() => onLoad(picked)}
|
| 1118 |
+
disabled={!picked || projects.length === 0}
|
| 1119 |
+
>
|
| 1120 |
+
Load
|
| 1121 |
+
</Button>
|
| 1122 |
+
</DialogActions>
|
| 1123 |
+
</Dialog>
|
| 1124 |
+
);
|
| 1125 |
+
}
|
| 1126 |
+
|
| 1127 |
+
function ProjectHeader({ project, onSave, onCommit, onDiscard, onAddAudio, disabled }) {
|
| 1128 |
+
const stateLabel = (() => {
|
| 1129 |
+
if (project.dirty && project.has_unsaved_changes) return 'Unsaved changes';
|
| 1130 |
+
if (project.dirty && !project.has_unsaved_changes) return 'Draft saved Β· dataset not created';
|
| 1131 |
+
if (!project.dirty) return 'Dataset created';
|
| 1132 |
+
return '';
|
| 1133 |
+
})();
|
| 1134 |
+
return (
|
| 1135 |
+
<Box sx={{ display: 'flex', alignItems: 'center', gap: 2, flexWrap: 'wrap' }}>
|
| 1136 |
+
<Stack direction="row" spacing={2} alignItems="center" sx={{ flex: 1, minWidth: 240 }}>
|
| 1137 |
+
<Box>
|
| 1138 |
+
<Typography variant="h6">Project: “{project.name}”</Typography>
|
| 1139 |
+
<Typography variant="body2" color="text.secondary">
|
| 1140 |
+
{project.clip_count} clip{project.clip_count === 1 ? '' : 's'}
|
| 1141 |
+
{' Β· '}{stateLabel}
|
| 1142 |
+
</Typography>
|
| 1143 |
+
</Box>
|
| 1144 |
+
<Button
|
| 1145 |
+
variant="outlined"
|
| 1146 |
+
size="small"
|
| 1147 |
+
startIcon={<MusicIcon size={16} />}
|
| 1148 |
+
onClick={onAddAudio}
|
| 1149 |
+
disabled={disabled}
|
| 1150 |
+
>
|
| 1151 |
+
Add audio
|
| 1152 |
+
</Button>
|
| 1153 |
+
</Stack>
|
| 1154 |
+
<Stack direction="row" spacing={1}>
|
| 1155 |
+
<Tooltip title={TIPS.dataset.discardChanges}>
|
| 1156 |
+
<span>
|
| 1157 |
+
<Button
|
| 1158 |
+
variant="outlined"
|
| 1159 |
+
color="error"
|
| 1160 |
+
size="small"
|
| 1161 |
+
startIcon={<TrashIcon size={16} />}
|
| 1162 |
+
onClick={onDiscard}
|
| 1163 |
+
disabled={disabled || !project.dirty}
|
| 1164 |
+
>
|
| 1165 |
+
Delete
|
| 1166 |
+
</Button>
|
| 1167 |
+
</span>
|
| 1168 |
+
</Tooltip>
|
| 1169 |
+
<Tooltip title={TIPS.dataset.saveDraft}>
|
| 1170 |
+
<span>
|
| 1171 |
+
<Button
|
| 1172 |
+
variant="outlined"
|
| 1173 |
+
size="small"
|
| 1174 |
+
startIcon={<SaveIcon size={16} />}
|
| 1175 |
+
onClick={onSave}
|
| 1176 |
+
disabled={disabled || !project.has_unsaved_changes}
|
| 1177 |
+
>
|
| 1178 |
+
Save
|
| 1179 |
+
</Button>
|
| 1180 |
+
</span>
|
| 1181 |
+
</Tooltip>
|
| 1182 |
+
<Tooltip title={TIPS.dataset.createDataset}>
|
| 1183 |
+
<span>
|
| 1184 |
+
<Button
|
| 1185 |
+
variant="contained"
|
| 1186 |
+
size="small"
|
| 1187 |
+
startIcon={<DatasetIcon size={16} />}
|
| 1188 |
+
onClick={onCommit}
|
| 1189 |
+
disabled={disabled || !project.dirty}
|
| 1190 |
+
>
|
| 1191 |
+
Create Dataset
|
| 1192 |
+
</Button>
|
| 1193 |
+
</span>
|
| 1194 |
+
</Tooltip>
|
| 1195 |
+
</Stack>
|
| 1196 |
+
</Box>
|
| 1197 |
+
);
|
| 1198 |
+
}
|
| 1199 |
+
|
| 1200 |
+
function HealthStrip({ health, onSelectFiles }) {
|
| 1201 |
+
if (!health || health.total_clips === 0) return null;
|
| 1202 |
+
const empty = health.empty_prompts || { count: 0, files: [] };
|
| 1203 |
+
const tooShort = health.too_short || { count: 0, files: [] };
|
| 1204 |
+
const dups = health.duplicate_annotations || { count: 0, group_count: 0, files: [] };
|
| 1205 |
+
const unsupported = health.unsupported_format || { count: 0, accepted: [], files: [] };
|
| 1206 |
+
const issues = empty.count + tooShort.count
|
| 1207 |
+
+ dups.count + unsupported.count;
|
| 1208 |
+
|
| 1209 |
+
// Three-tier status driven by the share of unique clips touched by any
|
| 1210 |
+
// health check. A single file showing up in multiple categories only
|
| 1211 |
+
// counts once.
|
| 1212 |
+
const affected = new Set([
|
| 1213 |
+
...empty.files,
|
| 1214 |
+
...tooShort.files,
|
| 1215 |
+
...dups.files,
|
| 1216 |
+
...unsupported.files,
|
| 1217 |
+
]);
|
| 1218 |
+
const affectedRatio = health.total_clips > 0 ? affected.size / health.total_clips : 0;
|
| 1219 |
+
let status;
|
| 1220 |
+
if (affected.size === 0) status = 'ok';
|
| 1221 |
+
else if (affectedRatio > 0.5) status = 'bad';
|
| 1222 |
+
else status = 'warn';
|
| 1223 |
+
|
| 1224 |
+
const statusColor = (
|
| 1225 |
+
status === 'ok' ? 'success.main'
|
| 1226 |
+
: status === 'warn' ? 'warm.main'
|
| 1227 |
+
: 'error.main'
|
| 1228 |
+
);
|
| 1229 |
+
const statusText = (
|
| 1230 |
+
status === 'ok'
|
| 1231 |
+
? `All clean Β· ${health.total_clips} clip${health.total_clips === 1 ? '' : 's'} ready`
|
| 1232 |
+
: `${affected.size} of ${health.total_clips} clip${health.total_clips === 1 ? '' : 's'} flagged`
|
| 1233 |
+
);
|
| 1234 |
+
|
| 1235 |
+
return (
|
| 1236 |
+
<Paper variant="outlined" sx={{ borderRadius: 2.5 }}>
|
| 1237 |
+
<Box sx={{ px: 2, py: 1.25, display: 'flex', alignItems: 'center', gap: 1 }}>
|
| 1238 |
+
<Box component="span" sx={appStyles.sectionCardIcon}>
|
| 1239 |
+
<HealthIcon size={18} />
|
| 1240 |
+
</Box>
|
| 1241 |
+
<Typography variant="subtitle1" sx={{ fontWeight: 500, flex: 1 }}>
|
| 1242 |
+
Dataset health
|
| 1243 |
+
</Typography>
|
| 1244 |
+
<Box
|
| 1245 |
+
sx={{
|
| 1246 |
+
width: 8,
|
| 1247 |
+
height: 8,
|
| 1248 |
+
borderRadius: '50%',
|
| 1249 |
+
bgcolor: statusColor,
|
| 1250 |
+
// Soft halo so the dot reads as a status indicator,
|
| 1251 |
+
// not stray decoration.
|
| 1252 |
+
boxShadow: (theme) =>
|
| 1253 |
+
`0 0 0 3px ${theme.palette.mode === 'dark' ? 'rgba(255,255,255,0.04)' : 'rgba(0,0,0,0.04)'}`,
|
| 1254 |
+
}}
|
| 1255 |
+
/>
|
| 1256 |
+
<Typography variant="caption" sx={{ color: statusColor }}>
|
| 1257 |
+
{statusText}
|
| 1258 |
+
</Typography>
|
| 1259 |
+
</Box>
|
| 1260 |
+
|
| 1261 |
+
{issues > 0 && (
|
| 1262 |
+
<Box
|
| 1263 |
+
sx={{
|
| 1264 |
+
px: 2,
|
| 1265 |
+
py: 1.25,
|
| 1266 |
+
borderTop: 1,
|
| 1267 |
+
borderColor: 'divider',
|
| 1268 |
+
display: 'flex',
|
| 1269 |
+
alignItems: 'center',
|
| 1270 |
+
gap: 0.75,
|
| 1271 |
+
flexWrap: 'wrap',
|
| 1272 |
+
}}
|
| 1273 |
+
>
|
| 1274 |
+
{empty.count > 0 && (
|
| 1275 |
+
<Tooltip title={TIPS.dataset.selectClips}>
|
| 1276 |
+
<Chip
|
| 1277 |
+
size="small"
|
| 1278 |
+
variant="outlined"
|
| 1279 |
+
color="warning"
|
| 1280 |
+
label={`${empty.count} empty annotation${empty.count === 1 ? '' : 's'}`}
|
| 1281 |
+
onClick={() => onSelectFiles(empty.files)}
|
| 1282 |
+
/>
|
| 1283 |
+
</Tooltip>
|
| 1284 |
+
)}
|
| 1285 |
+
{tooShort.count > 0 && (
|
| 1286 |
+
<Tooltip title={TIPS.dataset.tooShort(tooShort.threshold_sec)}>
|
| 1287 |
+
<Chip
|
| 1288 |
+
size="small"
|
| 1289 |
+
variant="outlined"
|
| 1290 |
+
color="error"
|
| 1291 |
+
label={`${tooShort.count} too short (< ${tooShort.threshold_sec}s)`}
|
| 1292 |
+
onClick={() => onSelectFiles(tooShort.files)}
|
| 1293 |
+
/>
|
| 1294 |
+
</Tooltip>
|
| 1295 |
+
)}
|
| 1296 |
+
{dups.count > 0 && (
|
| 1297 |
+
<Tooltip title={TIPS.dataset.duplicates(dups.group_count)}>
|
| 1298 |
+
<Chip
|
| 1299 |
+
size="small"
|
| 1300 |
+
variant="outlined"
|
| 1301 |
+
color="warning"
|
| 1302 |
+
label={`${dups.count} duplicate annotation${dups.count === 1 ? '' : 's'}`}
|
| 1303 |
+
onClick={() => onSelectFiles(dups.files)}
|
| 1304 |
+
/>
|
| 1305 |
+
</Tooltip>
|
| 1306 |
+
)}
|
| 1307 |
+
{unsupported.count > 0 && (
|
| 1308 |
+
<Tooltip title={TIPS.dataset.unsupported(unsupported.accepted)}>
|
| 1309 |
+
<Chip
|
| 1310 |
+
size="small"
|
| 1311 |
+
variant="outlined"
|
| 1312 |
+
color="error"
|
| 1313 |
+
label={`${unsupported.count} unsupported format${unsupported.count === 1 ? '' : 's'}`}
|
| 1314 |
+
onClick={() => onSelectFiles(unsupported.files)}
|
| 1315 |
+
/>
|
| 1316 |
+
</Tooltip>
|
| 1317 |
+
)}
|
| 1318 |
+
</Box>
|
| 1319 |
+
)}
|
| 1320 |
+
</Paper>
|
| 1321 |
+
);
|
| 1322 |
+
}
|
| 1323 |
+
|
| 1324 |
+
function Waveform({ projectName, fileName, isActive, progress }) {
|
| 1325 |
+
const canvasRef = useRef(null);
|
| 1326 |
+
const theme = useTheme();
|
| 1327 |
+
const [peaks, setPeaks] = useState(null);
|
| 1328 |
+
const [failed, setFailed] = useState(false);
|
| 1329 |
+
|
| 1330 |
+
useEffect(() => {
|
| 1331 |
+
let cancelled = false;
|
| 1332 |
+
setPeaks(null);
|
| 1333 |
+
setFailed(false);
|
| 1334 |
+
if (!projectName || !fileName) return;
|
| 1335 |
+
const url = `/api/projects/${encodeURIComponent(projectName)}/clip/${encodeURIComponent(fileName)}/peaks?n=80`;
|
| 1336 |
+
api.get(url)
|
| 1337 |
+
.then(({ data }) => { if (!cancelled) setPeaks(data?.peaks || []); })
|
| 1338 |
+
.catch(() => { if (!cancelled) setFailed(true); });
|
| 1339 |
+
return () => { cancelled = true; };
|
| 1340 |
+
}, [projectName, fileName]);
|
| 1341 |
+
|
| 1342 |
+
useEffect(() => {
|
| 1343 |
+
const canvas = canvasRef.current;
|
| 1344 |
+
if (!canvas) return;
|
| 1345 |
+
const dpr = window.devicePixelRatio || 1;
|
| 1346 |
+
const w = canvas.clientWidth;
|
| 1347 |
+
const h = canvas.clientHeight;
|
| 1348 |
+
if (canvas.width !== w * dpr) canvas.width = w * dpr;
|
| 1349 |
+
if (canvas.height !== h * dpr) canvas.height = h * dpr;
|
| 1350 |
+
const ctx = canvas.getContext('2d');
|
| 1351 |
+
ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
|
| 1352 |
+
ctx.clearRect(0, 0, w, h);
|
| 1353 |
+
|
| 1354 |
+
if (!peaks || !peaks.length) {
|
| 1355 |
+
ctx.fillStyle = 'rgba(0,0,0,0.08)';
|
| 1356 |
+
const midY = h / 2;
|
| 1357 |
+
ctx.fillRect(0, midY - 0.5, w, 1);
|
| 1358 |
+
return;
|
| 1359 |
+
}
|
| 1360 |
+
|
| 1361 |
+
const barCount = peaks.length;
|
| 1362 |
+
const barWidth = Math.max(1, w / barCount - 1);
|
| 1363 |
+
const playedIdx = isActive ? Math.floor(progress * barCount) : -1;
|
| 1364 |
+
// Match the Generated-Fragments waveforms: teal accent for the played
|
| 1365 |
+
// portion, dimmed (35% alpha) for the rest.
|
| 1366 |
+
const playedColor = '#279FBB';
|
| 1367 |
+
const restColor = '#279FBB59';
|
| 1368 |
+
|
| 1369 |
+
for (let i = 0; i < barCount; i++) {
|
| 1370 |
+
const v = peaks[i];
|
| 1371 |
+
const barH = Math.max(1, v * (h - 2));
|
| 1372 |
+
const x = i * (w / barCount);
|
| 1373 |
+
const y = (h - barH) / 2;
|
| 1374 |
+
ctx.fillStyle = i <= playedIdx ? playedColor : restColor;
|
| 1375 |
+
ctx.fillRect(x, y, barWidth, barH);
|
| 1376 |
+
}
|
| 1377 |
+
}, [peaks, isActive, progress, theme]);
|
| 1378 |
+
|
| 1379 |
+
return (
|
| 1380 |
+
<Box sx={{ width: 120, height: 28, flexShrink: 0, opacity: failed ? 0.3 : 1 }}>
|
| 1381 |
+
<canvas
|
| 1382 |
+
ref={canvasRef}
|
| 1383 |
+
style={{ width: '100%', height: '100%', display: 'block' }}
|
| 1384 |
+
/>
|
| 1385 |
+
</Box>
|
| 1386 |
+
);
|
| 1387 |
+
}
|
| 1388 |
+
|
| 1389 |
+
function ClipTable({ projectName, clips, playingFile, playProgress, onPlayToggle, onPromptChange, onAnnotate, onDelete, onSlice, selectedFiles, onToggleSelected, onToggleSelectAll, disabled, toolbar }) {
|
| 1390 |
+
const totalSelected = selectedFiles ? selectedFiles.size : 0;
|
| 1391 |
+
const allSelected = clips && clips.length > 0 && totalSelected === clips.length;
|
| 1392 |
+
const partiallySelected = totalSelected > 0 && !allSelected;
|
| 1393 |
+
if (!clips || clips.length === 0) {
|
| 1394 |
+
return (
|
| 1395 |
+
<Paper variant="outlined" sx={{ borderRadius: 2.5, overflow: 'hidden' }}>
|
| 1396 |
+
{toolbar && (
|
| 1397 |
+
<Box sx={{ px: 1.5, py: 1, borderBottom: 1, borderColor: 'divider' }}>
|
| 1398 |
+
{toolbar}
|
| 1399 |
+
</Box>
|
| 1400 |
+
)}
|
| 1401 |
+
<Box sx={{ py: 4, textAlign: 'center', color: 'text.secondary' }}>
|
| 1402 |
+
<Typography variant="body2">
|
| 1403 |
+
No clips yet. Use βAdd audioβ to bring in a folder.
|
| 1404 |
+
</Typography>
|
| 1405 |
+
</Box>
|
| 1406 |
+
</Paper>
|
| 1407 |
+
);
|
| 1408 |
+
}
|
| 1409 |
+
return (
|
| 1410 |
+
<Paper variant="outlined" sx={{ borderRadius: 2.5, overflow: 'hidden' }}>
|
| 1411 |
+
{toolbar && (
|
| 1412 |
+
<Box sx={{ px: 1.5, py: 1, borderBottom: 1, borderColor: 'divider' }}>
|
| 1413 |
+
{toolbar}
|
| 1414 |
+
</Box>
|
| 1415 |
+
)}
|
| 1416 |
+
<TableContainer>
|
| 1417 |
+
<Table size="small">
|
| 1418 |
+
<TableHead>
|
| 1419 |
+
<TableRow>
|
| 1420 |
+
<TableCell padding="checkbox">
|
| 1421 |
+
<Checkbox
|
| 1422 |
+
size="small"
|
| 1423 |
+
checked={allSelected}
|
| 1424 |
+
indeterminate={partiallySelected}
|
| 1425 |
+
onChange={onToggleSelectAll}
|
| 1426 |
+
disabled={disabled || clips.length === 0}
|
| 1427 |
+
/>
|
| 1428 |
+
</TableCell>
|
| 1429 |
+
<TableCell sx={{ width: '36%' }}>File</TableCell>
|
| 1430 |
+
<TableCell>Annotation</TableCell>
|
| 1431 |
+
<TableCell sx={{ width: 132, textAlign: 'right' }}>Actions</TableCell>
|
| 1432 |
+
</TableRow>
|
| 1433 |
+
</TableHead>
|
| 1434 |
+
<TableBody>
|
| 1435 |
+
{clips.map((c) => (
|
| 1436 |
+
<ClipRow
|
| 1437 |
+
key={c.file_name}
|
| 1438 |
+
projectName={projectName}
|
| 1439 |
+
clip={c}
|
| 1440 |
+
isPlaying={playingFile === c.file_name}
|
| 1441 |
+
playProgress={playingFile === c.file_name ? playProgress : 0}
|
| 1442 |
+
onPlayToggle={onPlayToggle}
|
| 1443 |
+
onPromptChange={onPromptChange}
|
| 1444 |
+
onAnnotate={onAnnotate}
|
| 1445 |
+
onDelete={onDelete}
|
| 1446 |
+
onSlice={onSlice}
|
| 1447 |
+
selected={selectedFiles ? selectedFiles.has(c.file_name) : false}
|
| 1448 |
+
onToggleSelected={onToggleSelected}
|
| 1449 |
+
disabled={disabled}
|
| 1450 |
+
/>
|
| 1451 |
+
))}
|
| 1452 |
+
</TableBody>
|
| 1453 |
+
</Table>
|
| 1454 |
+
</TableContainer>
|
| 1455 |
+
</Paper>
|
| 1456 |
+
);
|
| 1457 |
+
}
|
| 1458 |
+
|
| 1459 |
+
// React.memo so the 60Hz audio-playhead ticks don't reconcile every row in
|
| 1460 |
+
// the table. Custom comparator: skip if visual props didn't change. Callback
|
| 1461 |
+
// identity intentionally ignored β they're stable in behavior, just inline
|
| 1462 |
+
// arrows from the parent, and re-creating a row only to re-bind a click
|
| 1463 |
+
// handler isn't worth the work. playProgress only matters on the active row.
|
| 1464 |
+
const ClipRow = React.memo(function ClipRow({ projectName, clip, isPlaying, playProgress, onPlayToggle, onPromptChange, onAnnotate, onDelete, onSlice, selected, onToggleSelected, disabled }) {
|
| 1465 |
+
const [draft, setDraft] = useState(clip.prompt);
|
| 1466 |
+
useEffect(() => { setDraft(clip.prompt); }, [clip.prompt]);
|
| 1467 |
+
|
| 1468 |
+
const dirty = draft !== clip.prompt;
|
| 1469 |
+
|
| 1470 |
+
return (
|
| 1471 |
+
<TableRow hover selected={selected}>
|
| 1472 |
+
<TableCell padding="checkbox">
|
| 1473 |
+
<Checkbox
|
| 1474 |
+
size="small"
|
| 1475 |
+
checked={!!selected}
|
| 1476 |
+
onChange={() => onToggleSelected && onToggleSelected(clip.file_name)}
|
| 1477 |
+
/>
|
| 1478 |
+
</TableCell>
|
| 1479 |
+
<TableCell sx={{ wordBreak: 'break-all' }}>
|
| 1480 |
+
<Stack direction="row" alignItems="center" spacing={1}>
|
| 1481 |
+
<IconButton
|
| 1482 |
+
size="small"
|
| 1483 |
+
onClick={() => onPlayToggle(clip.file_name)}
|
| 1484 |
+
sx={{ width: 28, height: 28 }}
|
| 1485 |
+
>
|
| 1486 |
+
{isPlaying ? <PauseIcon size={14} /> : <PlayIcon size={14} />}
|
| 1487 |
+
</IconButton>
|
| 1488 |
+
<Waveform
|
| 1489 |
+
projectName={projectName}
|
| 1490 |
+
fileName={clip.file_name}
|
| 1491 |
+
isActive={isPlaying}
|
| 1492 |
+
progress={playProgress}
|
| 1493 |
+
/>
|
| 1494 |
+
<Typography variant="body2" sx={{ flex: 1, minWidth: 0, wordBreak: 'break-all' }}>
|
| 1495 |
+
{clip.file_name}
|
| 1496 |
+
</Typography>
|
| 1497 |
+
</Stack>
|
| 1498 |
+
</TableCell>
|
| 1499 |
+
<TableCell>
|
| 1500 |
+
<TextField
|
| 1501 |
+
fullWidth
|
| 1502 |
+
size="small"
|
| 1503 |
+
variant="standard"
|
| 1504 |
+
value={draft}
|
| 1505 |
+
onChange={(e) => setDraft(e.target.value)}
|
| 1506 |
+
onBlur={() => { if (dirty) onPromptChange(clip.file_name, draft); }}
|
| 1507 |
+
placeholder="(empty β write a prompt or auto-annotate)"
|
| 1508 |
+
disabled={disabled}
|
| 1509 |
+
/>
|
| 1510 |
+
</TableCell>
|
| 1511 |
+
<TableCell sx={{ textAlign: 'right', whiteSpace: 'nowrap' }}>
|
| 1512 |
+
<Tooltip title={TIPS.dataset.autoAnnotateClip}>
|
| 1513 |
+
<span>
|
| 1514 |
+
<IconButton
|
| 1515 |
+
size="small"
|
| 1516 |
+
onClick={() => onAnnotate(clip.file_name)}
|
| 1517 |
+
disabled={disabled}
|
| 1518 |
+
sx={{ color: 'warm.main', '&:hover': { color: 'warm.light', bgcolor: 'action.hover' } }}
|
| 1519 |
+
>
|
| 1520 |
+
<WandSparkles size={16} />
|
| 1521 |
+
</IconButton>
|
| 1522 |
+
</span>
|
| 1523 |
+
</Tooltip>
|
| 1524 |
+
<Tooltip title={TIPS.dataset.sliceClip}>
|
| 1525 |
+
<span>
|
| 1526 |
+
<IconButton
|
| 1527 |
+
size="small"
|
| 1528 |
+
onClick={() => onSlice(clip.file_name)}
|
| 1529 |
+
disabled={disabled}
|
| 1530 |
+
>
|
| 1531 |
+
<ScissorsIcon size={16} />
|
| 1532 |
+
</IconButton>
|
| 1533 |
+
</span>
|
| 1534 |
+
</Tooltip>
|
| 1535 |
+
<Tooltip title={TIPS.dataset.removeClip}>
|
| 1536 |
+
<span>
|
| 1537 |
+
<IconButton
|
| 1538 |
+
size="small"
|
| 1539 |
+
onClick={() => onDelete(clip.file_name)}
|
| 1540 |
+
disabled={disabled}
|
| 1541 |
+
>
|
| 1542 |
+
<TrashIcon size={16} />
|
| 1543 |
+
</IconButton>
|
| 1544 |
+
</span>
|
| 1545 |
+
</Tooltip>
|
| 1546 |
+
</TableCell>
|
| 1547 |
+
</TableRow>
|
| 1548 |
+
);
|
| 1549 |
+
}, (prev, next) => {
|
| 1550 |
+
if (prev.clip !== next.clip) return false;
|
| 1551 |
+
if (prev.disabled !== next.disabled) return false;
|
| 1552 |
+
if (prev.projectName !== next.projectName) return false;
|
| 1553 |
+
if (prev.isPlaying !== next.isPlaying) return false;
|
| 1554 |
+
if (prev.selected !== next.selected) return false;
|
| 1555 |
+
// playProgress only matters when this row is the active one β inactive
|
| 1556 |
+
// rows always receive playProgress=0 from the parent, so they're skipped.
|
| 1557 |
+
if (next.isPlaying && prev.playProgress !== next.playProgress) return false;
|
| 1558 |
+
return true;
|
| 1559 |
+
});
|
| 1560 |
+
|
| 1561 |
+
function CreateProjectDialog({ open, existingNames, onClose, onCreated }) {
|
| 1562 |
+
const [name, setName] = useState('');
|
| 1563 |
+
const [busy, setBusy] = useState(false);
|
| 1564 |
+
const [dialogError, setDialogError] = useState('');
|
| 1565 |
+
|
| 1566 |
+
useEffect(() => {
|
| 1567 |
+
if (open) { setName(''); setDialogError(''); }
|
| 1568 |
+
}, [open]);
|
| 1569 |
+
|
| 1570 |
+
const duplicate = existingNames.includes(name.trim());
|
| 1571 |
+
|
| 1572 |
+
async function submit() {
|
| 1573 |
+
setDialogError('');
|
| 1574 |
+
setBusy(true);
|
| 1575 |
+
try {
|
| 1576 |
+
const { data } = await api.post('/api/projects', { name: name.trim() });
|
| 1577 |
+
await onCreated(data.name);
|
| 1578 |
+
} catch (e) {
|
| 1579 |
+
setDialogError(extractError(e, 'Failed to create project'));
|
| 1580 |
+
} finally {
|
| 1581 |
+
setBusy(false);
|
| 1582 |
+
}
|
| 1583 |
+
}
|
| 1584 |
+
|
| 1585 |
+
return (
|
| 1586 |
+
<Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
|
| 1587 |
+
<DialogTitle>New project</DialogTitle>
|
| 1588 |
+
<DialogContent>
|
| 1589 |
+
<Stack spacing={2} sx={{ pt: 1 }}>
|
| 1590 |
+
<TextField
|
| 1591 |
+
autoFocus
|
| 1592 |
+
label="Project name"
|
| 1593 |
+
value={name}
|
| 1594 |
+
onChange={(e) => setName(e.target.value)}
|
| 1595 |
+
helperText="Letters, digits, spaces, dashes, underscores, dots. Becomes a folder name on disk."
|
| 1596 |
+
error={duplicate}
|
| 1597 |
+
/>
|
| 1598 |
+
{duplicate && (
|
| 1599 |
+
<Typography variant="caption" color="error">
|
| 1600 |
+
A project with this name already exists.
|
| 1601 |
+
</Typography>
|
| 1602 |
+
)}
|
| 1603 |
+
{dialogError && <Alert severity="error">{dialogError}</Alert>}
|
| 1604 |
+
</Stack>
|
| 1605 |
+
</DialogContent>
|
| 1606 |
+
<DialogActions>
|
| 1607 |
+
<Button onClick={onClose} disabled={busy}>Cancel</Button>
|
| 1608 |
+
<Button
|
| 1609 |
+
variant="contained"
|
| 1610 |
+
onClick={submit}
|
| 1611 |
+
disabled={busy || !name.trim() || duplicate}
|
| 1612 |
+
>
|
| 1613 |
+
Create
|
| 1614 |
+
</Button>
|
| 1615 |
+
</DialogActions>
|
| 1616 |
+
</Dialog>
|
| 1617 |
+
);
|
| 1618 |
+
}
|
| 1619 |
+
|
| 1620 |
+
function IngestDialog({ open, projectName, onClose, onIngested }) {
|
| 1621 |
+
const [folder, setFolder] = useState('');
|
| 1622 |
+
const [mode, setMode] = useState('copy');
|
| 1623 |
+
const [busy, setBusy] = useState(false);
|
| 1624 |
+
const [dialogError, setDialogError] = useState('');
|
| 1625 |
+
|
| 1626 |
+
useEffect(() => {
|
| 1627 |
+
if (open) { setFolder(''); setMode('copy'); setDialogError(''); }
|
| 1628 |
+
}, [open]);
|
| 1629 |
+
|
| 1630 |
+
async function pick() {
|
| 1631 |
+
try {
|
| 1632 |
+
const { data } = await api.post('/api/pick-folder', {});
|
| 1633 |
+
if (data?.path) setFolder(data.path);
|
| 1634 |
+
} catch (e) {
|
| 1635 |
+
setDialogError(extractError(e, 'Folder picker failed'));
|
| 1636 |
+
}
|
| 1637 |
+
}
|
| 1638 |
+
|
| 1639 |
+
async function submit() {
|
| 1640 |
+
if (!projectName) return;
|
| 1641 |
+
setBusy(true);
|
| 1642 |
+
setDialogError('');
|
| 1643 |
+
try {
|
| 1644 |
+
await api.post(
|
| 1645 |
+
`/api/projects/${encodeURIComponent(projectName)}/ingest`,
|
| 1646 |
+
{ folder_path: folder, mode },
|
| 1647 |
+
);
|
| 1648 |
+
await onIngested();
|
| 1649 |
+
} catch (e) {
|
| 1650 |
+
setDialogError(extractError(e, 'Ingest failed'));
|
| 1651 |
+
} finally {
|
| 1652 |
+
setBusy(false);
|
| 1653 |
+
}
|
| 1654 |
+
}
|
| 1655 |
+
|
| 1656 |
+
return (
|
| 1657 |
+
<Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
|
| 1658 |
+
<DialogTitle>Add audio to {projectName}</DialogTitle>
|
| 1659 |
+
<DialogContent>
|
| 1660 |
+
<Stack spacing={2} sx={{ pt: 1 }}>
|
| 1661 |
+
<Stack direction="row" spacing={1.5} alignItems="center">
|
| 1662 |
+
<Button variant="outlined" startIcon={<FolderOpenIcon size={18} />} onClick={pick}>
|
| 1663 |
+
Pick folder
|
| 1664 |
+
</Button>
|
| 1665 |
+
<Typography variant="body2" color="text.secondary" sx={{ wordBreak: 'break-all' }}>
|
| 1666 |
+
{folder || 'No folder selected'}
|
| 1667 |
+
</Typography>
|
| 1668 |
+
</Stack>
|
| 1669 |
+
|
| 1670 |
+
<FormControl>
|
| 1671 |
+
<Typography variant="body2" gutterBottom>How to bring the audio in:</Typography>
|
| 1672 |
+
<RadioGroup value={mode} onChange={(e) => setMode(e.target.value)}>
|
| 1673 |
+
<FormControlLabel
|
| 1674 |
+
value="copy"
|
| 1675 |
+
control={<Radio size="small" />}
|
| 1676 |
+
label={<Typography variant="body2">Copy β duplicates audio into the project (safe, originals untouched)</Typography>}
|
| 1677 |
+
/>
|
| 1678 |
+
<FormControlLabel
|
| 1679 |
+
value="symlink"
|
| 1680 |
+
control={<Radio size="small" />}
|
| 1681 |
+
label={<Typography variant="body2">Symlink β points at the originals (saves disk, breaks if you move them)</Typography>}
|
| 1682 |
+
/>
|
| 1683 |
+
</RadioGroup>
|
| 1684 |
+
</FormControl>
|
| 1685 |
+
|
| 1686 |
+
{dialogError && <Alert severity="error">{dialogError}</Alert>}
|
| 1687 |
+
</Stack>
|
| 1688 |
+
</DialogContent>
|
| 1689 |
+
<DialogActions>
|
| 1690 |
+
<Button onClick={onClose} disabled={busy}>Cancel</Button>
|
| 1691 |
+
<Button variant="contained" onClick={submit} disabled={busy || !folder}>
|
| 1692 |
+
{busy ? 'Addingβ¦' : 'Add'}
|
| 1693 |
+
</Button>
|
| 1694 |
+
</DialogActions>
|
| 1695 |
+
</Dialog>
|
| 1696 |
+
);
|
| 1697 |
+
}
|
| 1698 |
+
|
| 1699 |
+
function SliceDialog({ open, projectName, fileName, onClose, onSliced }) {
|
| 1700 |
+
const [target, setTarget] = useState(30);
|
| 1701 |
+
const [overlap, setOverlap] = useState(0);
|
| 1702 |
+
const [strategy, setStrategy] = useState('hard');
|
| 1703 |
+
const [duration, setDuration] = useState(null);
|
| 1704 |
+
const [busy, setBusy] = useState(false);
|
| 1705 |
+
const [dialogError, setDialogError] = useState('');
|
| 1706 |
+
|
| 1707 |
+
useEffect(() => {
|
| 1708 |
+
if (!open) return;
|
| 1709 |
+
setTarget(30);
|
| 1710 |
+
setOverlap(0);
|
| 1711 |
+
setStrategy('hard');
|
| 1712 |
+
setDialogError('');
|
| 1713 |
+
setDuration(null);
|
| 1714 |
+
if (!projectName || !fileName) return;
|
| 1715 |
+
// Reuse the peaks endpoint to pull duration cheaply (cached server-side).
|
| 1716 |
+
api.get(`/api/projects/${encodeURIComponent(projectName)}/clip/${encodeURIComponent(fileName)}/peaks?n=20`)
|
| 1717 |
+
.then(({ data }) => setDuration(data?.duration || null))
|
| 1718 |
+
.catch(() => setDuration(null));
|
| 1719 |
+
}, [open, projectName, fileName]);
|
| 1720 |
+
|
| 1721 |
+
const stepSec = Math.max(0.5, target - overlap);
|
| 1722 |
+
const estChildren = duration && target > 0 ? Math.max(1, Math.ceil(duration / stepSec)) : null;
|
| 1723 |
+
const tooShort = duration !== null && duration <= target;
|
| 1724 |
+
|
| 1725 |
+
async function submit() {
|
| 1726 |
+
setBusy(true);
|
| 1727 |
+
setDialogError('');
|
| 1728 |
+
try {
|
| 1729 |
+
await api.post(
|
| 1730 |
+
`/api/projects/${encodeURIComponent(projectName)}/clip/${encodeURIComponent(fileName)}/slice`,
|
| 1731 |
+
{ target_duration: target, overlap_sec: overlap, strategy },
|
| 1732 |
+
);
|
| 1733 |
+
await onSliced();
|
| 1734 |
+
onClose();
|
| 1735 |
+
} catch (e) {
|
| 1736 |
+
setDialogError(extractError(e, 'Slice failed'));
|
| 1737 |
+
} finally {
|
| 1738 |
+
setBusy(false);
|
| 1739 |
+
}
|
| 1740 |
+
}
|
| 1741 |
+
|
| 1742 |
+
return (
|
| 1743 |
+
<Dialog open={open} onClose={busy ? undefined : onClose} maxWidth="sm" fullWidth>
|
| 1744 |
+
<DialogTitle>Slice {fileName || ''}</DialogTitle>
|
| 1745 |
+
<DialogContent>
|
| 1746 |
+
<Stack spacing={2.5} sx={{ pt: 1 }}>
|
| 1747 |
+
<Typography variant="body2" color="text.secondary">
|
| 1748 |
+
The original file will be replaced by the children on disk. Children inherit this clip's annotation. They stay in the project until you Create Dataset (Delete reverts them).
|
| 1749 |
+
</Typography>
|
| 1750 |
+
|
| 1751 |
+
<Stack direction="row" spacing={2}>
|
| 1752 |
+
<TextField
|
| 1753 |
+
label="Target duration (sec)"
|
| 1754 |
+
type="number"
|
| 1755 |
+
size="small"
|
| 1756 |
+
value={target}
|
| 1757 |
+
onChange={(e) => setTarget(Math.max(0.5, parseFloat(e.target.value) || 0))}
|
| 1758 |
+
inputProps={{ step: 0.5, min: 0.5, max: 60 }}
|
| 1759 |
+
fullWidth
|
| 1760 |
+
/>
|
| 1761 |
+
<TextField
|
| 1762 |
+
label="Overlap (sec)"
|
| 1763 |
+
type="number"
|
| 1764 |
+
size="small"
|
| 1765 |
+
value={overlap}
|
| 1766 |
+
onChange={(e) => setOverlap(Math.max(0, parseFloat(e.target.value) || 0))}
|
| 1767 |
+
inputProps={{ step: 0.1, min: 0, max: Math.max(0, target - 0.5) }}
|
| 1768 |
+
fullWidth
|
| 1769 |
+
helperText="Head-overlap on every child after the first"
|
| 1770 |
+
/>
|
| 1771 |
+
</Stack>
|
| 1772 |
+
|
| 1773 |
+
<FormControl>
|
| 1774 |
+
<Typography variant="body2" gutterBottom>Where each cut should land:</Typography>
|
| 1775 |
+
<RadioGroup value={strategy} onChange={(e) => setStrategy(e.target.value)}>
|
| 1776 |
+
<FormControlLabel
|
| 1777 |
+
value="hard"
|
| 1778 |
+
control={<Radio size="small" />}
|
| 1779 |
+
label={<Typography variant="body2">Hard cut β exact intervals; fastest, can split mid-note</Typography>}
|
| 1780 |
+
/>
|
| 1781 |
+
<FormControlLabel
|
| 1782 |
+
value="transient"
|
| 1783 |
+
control={<Radio size="small" />}
|
| 1784 |
+
label={<Typography variant="body2">Transient-aware β snaps each cut to the nearest onset (good for drums / rhythmic)</Typography>}
|
| 1785 |
+
/>
|
| 1786 |
+
<FormControlLabel
|
| 1787 |
+
value="silence"
|
| 1788 |
+
control={<Radio size="small" />}
|
| 1789 |
+
label={<Typography variant="body2">Silence-aware β snaps to the quietest moment in each window (good for melodic / phrased)</Typography>}
|
| 1790 |
+
/>
|
| 1791 |
+
</RadioGroup>
|
| 1792 |
+
</FormControl>
|
| 1793 |
+
|
| 1794 |
+
{duration !== null && (
|
| 1795 |
+
<Typography variant="caption" color="text.secondary">
|
| 1796 |
+
Source: {duration.toFixed(1)}s
|
| 1797 |
+
{estChildren !== null && !tooShort && ` Β· ~${estChildren} children at this setting`}
|
| 1798 |
+
{tooShort && ' Β· already shorter than the target β nothing to slice'}
|
| 1799 |
+
</Typography>
|
| 1800 |
+
)}
|
| 1801 |
+
|
| 1802 |
+
{dialogError && <Alert severity="error">{dialogError}</Alert>}
|
| 1803 |
+
</Stack>
|
| 1804 |
+
</DialogContent>
|
| 1805 |
+
<DialogActions>
|
| 1806 |
+
<Button onClick={onClose} disabled={busy}>Cancel</Button>
|
| 1807 |
+
<Button
|
| 1808 |
+
variant="contained"
|
| 1809 |
+
onClick={submit}
|
| 1810 |
+
disabled={busy || tooShort || target <= 0 || overlap >= target}
|
| 1811 |
+
>
|
| 1812 |
+
{busy ? 'Slicingβ¦' : 'Slice'}
|
| 1813 |
+
</Button>
|
| 1814 |
+
</DialogActions>
|
| 1815 |
+
</Dialog>
|
| 1816 |
+
);
|
| 1817 |
+
}
|
| 1818 |
+
|
| 1819 |
+
// ---------- utils ----------------------------------------------------------
|
| 1820 |
+
|
| 1821 |
+
function extractError(e, fallback) {
|
| 1822 |
+
return e?.response?.data?.error || e?.message || fallback;
|
| 1823 |
+
}
|
app/frontend/src/components/EditPanel.js
ADDED
|
@@ -0,0 +1,597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useState, useRef, useEffect } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Box,
|
| 4 |
+
Typography,
|
| 5 |
+
Button,
|
| 6 |
+
Stack,
|
| 7 |
+
TextField,
|
| 8 |
+
ToggleButton,
|
| 9 |
+
ToggleButtonGroup,
|
| 10 |
+
Slider,
|
| 11 |
+
Alert,
|
| 12 |
+
LinearProgress,
|
| 13 |
+
IconButton,
|
| 14 |
+
Switch,
|
| 15 |
+
FormControlLabel,
|
| 16 |
+
} from '@mui/material';
|
| 17 |
+
import { Upload as UploadIcon, X as ClearIcon, Play as PlayIcon, Square as StopIcon } from 'lucide-react';
|
| 18 |
+
import api from '../api';
|
| 19 |
+
import AudioWaveform from './AudioWaveform';
|
| 20 |
+
import { getFragmentDragPayload } from '../utils/fragmentDrag';
|
| 21 |
+
|
| 22 |
+
/**
|
| 23 |
+
* SA3 audio-to-audio + inpainting UI.
|
| 24 |
+
*
|
| 25 |
+
* Three modes:
|
| 26 |
+
* - Style transfer: feed a source clip + new prompt, init_noise_level
|
| 27 |
+
* controls how much character is preserved (0 = source-faithful,
|
| 28 |
+
* 1 = prompt-only).
|
| 29 |
+
* - Inpaint: regenerate a region of the source clip, keeping the rest.
|
| 30 |
+
* - Extend: append N seconds of new audio to the end of the source.
|
| 31 |
+
*
|
| 32 |
+
* All three send to /api/generate using SA3's init_audio / inpaint_audio
|
| 33 |
+
* params. The backend handles file resolution; this panel just uploads
|
| 34 |
+
* the source clip to /api/audio/upload and posts the returned path.
|
| 35 |
+
*
|
| 36 |
+
* Props:
|
| 37 |
+
* model_id: active SA3 model id
|
| 38 |
+
* negativePrompt: optional, passed through
|
| 39 |
+
* loraStack: [{path, strength, bypassed}] from the Generation panel β
|
| 40 |
+
* applied to the edit so style/inpaint/extend inherit the
|
| 41 |
+
* same LoRA character as plain generation.
|
| 42 |
+
* steps: sampler step count from the Generation panel.
|
| 43 |
+
* cfgScale: CFG from the Generation panel (only sent for *-base models;
|
| 44 |
+
* distilled models bake CFG at 1.0).
|
| 45 |
+
* onGenerated(blob, filename, params): called with the resulting WAV
|
| 46 |
+
*/
|
| 47 |
+
export default function EditPanel({ model_id, negativePrompt, loraStack, steps, cfgScale, onGenerated }) {
|
| 48 |
+
const [mode, setMode] = useState('style'); // 'style' | 'inpaint' | 'extend'
|
| 49 |
+
const [sourcePath, setSourcePath] = useState('');
|
| 50 |
+
const [sourceName, setSourceName] = useState('');
|
| 51 |
+
const [sourceFile, setSourceFile] = useState(null); // kept for in-browser decode (waveform)
|
| 52 |
+
const [sourceUploading, setSourceUploading] = useState(false);
|
| 53 |
+
const [dropActive, setDropActive] = useState(false);
|
| 54 |
+
const [prompt, setPrompt] = useState('');
|
| 55 |
+
const [duration, setDuration] = useState(8);
|
| 56 |
+
// Seed: random by default, mirroring the rest of the app. When off, the
|
| 57 |
+
// numeric field is honoured (0 included β a legitimate seed).
|
| 58 |
+
const [randomSeed, setRandomSeed] = useState(true);
|
| 59 |
+
const [seedValue, setSeedValue] = useState('');
|
| 60 |
+
|
| 61 |
+
// sa3-medium generates up to 380s; small models cap at 120s. Matches the
|
| 62 |
+
// generator's _MODEL_INFO so the slider can't request past the model max.
|
| 63 |
+
const maxDuration = (model_id || '').includes('medium') ? 380 : 120;
|
| 64 |
+
// Distilled (post-trained) models bake CFG at 1.0 and ignore cfg_scale; only
|
| 65 |
+
// *-base variants honour it. Same rule the Generation panel uses.
|
| 66 |
+
const isDistilledBase =
|
| 67 |
+
!!model_id && model_id.startsWith('sa3-') && !model_id.endsWith('-base');
|
| 68 |
+
|
| 69 |
+
// style transfer
|
| 70 |
+
const [initNoiseLevel, setInitNoiseLevel] = useState(0.7);
|
| 71 |
+
|
| 72 |
+
// inpaint
|
| 73 |
+
const [maskStart, setMaskStart] = useState(2.0);
|
| 74 |
+
const [maskEnd, setMaskEnd] = useState(4.0);
|
| 75 |
+
|
| 76 |
+
// extend
|
| 77 |
+
const [extendSeconds, setExtendSeconds] = useState(4.0);
|
| 78 |
+
const [sourceDurationSec, setSourceDurationSec] = useState(null);
|
| 79 |
+
|
| 80 |
+
const [generating, setGenerating] = useState(false);
|
| 81 |
+
const [error, setError] = useState(null);
|
| 82 |
+
const fileInputRef = useRef(null);
|
| 83 |
+
|
| 84 |
+
// Inpaint region audition β a hidden <audio> set to the source clip, played
|
| 85 |
+
// from maskStart and auto-stopped at maskEnd, so users can hear the segment
|
| 86 |
+
// they're about to regenerate before committing.
|
| 87 |
+
const regionAudioRef = useRef(null);
|
| 88 |
+
const regionStopRef = useRef(null); // removes the active timeupdate guard
|
| 89 |
+
const [regionUrl, setRegionUrl] = useState(null);
|
| 90 |
+
const [regionPlaying, setRegionPlaying] = useState(false);
|
| 91 |
+
|
| 92 |
+
useEffect(() => {
|
| 93 |
+
if (!sourceFile) { setRegionUrl(null); return undefined; }
|
| 94 |
+
const url = URL.createObjectURL(sourceFile);
|
| 95 |
+
setRegionUrl(url);
|
| 96 |
+
return () => URL.revokeObjectURL(url);
|
| 97 |
+
}, [sourceFile]);
|
| 98 |
+
|
| 99 |
+
// Stop any in-flight preview when the source changes or the mode switches
|
| 100 |
+
// away from inpaint (don't auto-stop on every region drag β the end is
|
| 101 |
+
// captured per play, so dragging mid-play just runs to the old boundary).
|
| 102 |
+
useEffect(() => {
|
| 103 |
+
const a = regionAudioRef.current;
|
| 104 |
+
if (a) { try { a.pause(); } catch { /* ignore */ } }
|
| 105 |
+
regionStopRef.current?.();
|
| 106 |
+
regionStopRef.current = null;
|
| 107 |
+
setRegionPlaying(false);
|
| 108 |
+
}, [regionUrl, mode]);
|
| 109 |
+
|
| 110 |
+
const toggleRegionPreview = () => {
|
| 111 |
+
const a = regionAudioRef.current;
|
| 112 |
+
if (!a || !regionUrl) return;
|
| 113 |
+
if (regionPlaying) {
|
| 114 |
+
a.pause();
|
| 115 |
+
regionStopRef.current?.();
|
| 116 |
+
regionStopRef.current = null;
|
| 117 |
+
setRegionPlaying(false);
|
| 118 |
+
return;
|
| 119 |
+
}
|
| 120 |
+
const start = Math.max(0, Number(maskStart) || 0);
|
| 121 |
+
const end = Math.max(start + 0.05, Number(maskEnd) || 0);
|
| 122 |
+
const onTime = () => {
|
| 123 |
+
if (a.currentTime >= end) {
|
| 124 |
+
a.pause();
|
| 125 |
+
a.removeEventListener('timeupdate', onTime);
|
| 126 |
+
regionStopRef.current = null;
|
| 127 |
+
setRegionPlaying(false);
|
| 128 |
+
}
|
| 129 |
+
};
|
| 130 |
+
try { a.currentTime = start; } catch { /* ignore */ }
|
| 131 |
+
a.addEventListener('timeupdate', onTime);
|
| 132 |
+
regionStopRef.current = () => a.removeEventListener('timeupdate', onTime);
|
| 133 |
+
a.play()
|
| 134 |
+
.then(() => setRegionPlaying(true))
|
| 135 |
+
.catch(() => {
|
| 136 |
+
a.removeEventListener('timeupdate', onTime);
|
| 137 |
+
regionStopRef.current = null;
|
| 138 |
+
setRegionPlaying(false);
|
| 139 |
+
});
|
| 140 |
+
};
|
| 141 |
+
|
| 142 |
+
// --- source upload ---------------------------------------------------
|
| 143 |
+
const onPickFile = () => fileInputRef.current?.click();
|
| 144 |
+
const uploadFile = async (f) => {
|
| 145 |
+
if (!f) return;
|
| 146 |
+
setSourceUploading(true);
|
| 147 |
+
setError(null);
|
| 148 |
+
try {
|
| 149 |
+
const form = new FormData();
|
| 150 |
+
form.append('file', f);
|
| 151 |
+
const r = await api.post('/api/audio/upload', form);
|
| 152 |
+
setSourcePath(r.data.path);
|
| 153 |
+
setSourceName(r.data.name);
|
| 154 |
+
setSourceFile(f); // keep for in-browser waveform decode
|
| 155 |
+
// Probe duration via a temp object URL β <audio>.
|
| 156 |
+
const url = URL.createObjectURL(f);
|
| 157 |
+
const a = new Audio(url);
|
| 158 |
+
a.addEventListener('loadedmetadata', () => {
|
| 159 |
+
if (Number.isFinite(a.duration)) {
|
| 160 |
+
setSourceDurationSec(a.duration);
|
| 161 |
+
// Default the output length to the source length (clamped to
|
| 162 |
+
// the model max). For inpaint this is mandatory β the mask is
|
| 163 |
+
// measured in source seconds, so the output must be the same
|
| 164 |
+
// length or the masked region drifts off the audio you see.
|
| 165 |
+
setDuration(Math.max(1, Math.min(maxDuration, Math.round(a.duration))));
|
| 166 |
+
// Seed inpaint region to the middle quarter so the
|
| 167 |
+
// waveform shows something sensible without a 4 s default
|
| 168 |
+
// landing past the end of short clips.
|
| 169 |
+
const q = a.duration / 4;
|
| 170 |
+
setMaskStart(Math.max(0, q));
|
| 171 |
+
setMaskEnd(Math.min(a.duration, q * 3));
|
| 172 |
+
}
|
| 173 |
+
URL.revokeObjectURL(url);
|
| 174 |
+
}, { once: true });
|
| 175 |
+
} catch (err) {
|
| 176 |
+
setError(err.response?.data?.error?.message || err.message || 'Upload failed');
|
| 177 |
+
} finally {
|
| 178 |
+
setSourceUploading(false);
|
| 179 |
+
}
|
| 180 |
+
};
|
| 181 |
+
const onFileChange = async (e) => {
|
| 182 |
+
const f = e.target.files?.[0];
|
| 183 |
+
e.target.value = '';
|
| 184 |
+
await uploadFile(f);
|
| 185 |
+
};
|
| 186 |
+
// Pull a fragment already on disk (dragged in from the Generated
|
| 187 |
+
// Fragments window) and run it through the same upload path so it gets a
|
| 188 |
+
// server path + waveform + duration probe, exactly like a picked file.
|
| 189 |
+
const loadFragmentByName = async (filename) => {
|
| 190 |
+
if (!filename) return;
|
| 191 |
+
setSourceUploading(true);
|
| 192 |
+
setError(null);
|
| 193 |
+
try {
|
| 194 |
+
const r = await api.get(`/api/fragments/${encodeURIComponent(filename)}`, { responseType: 'blob' });
|
| 195 |
+
const file = new File([r.data], filename, { type: r.data.type || 'audio/wav' });
|
| 196 |
+
await uploadFile(file);
|
| 197 |
+
} catch (err) {
|
| 198 |
+
setError(err.response?.data?.error?.message || err.message || 'Could not load fragment');
|
| 199 |
+
setSourceUploading(false);
|
| 200 |
+
}
|
| 201 |
+
};
|
| 202 |
+
|
| 203 |
+
const onDrop = async (e) => {
|
| 204 |
+
e.preventDefault();
|
| 205 |
+
setDropActive(false);
|
| 206 |
+
// In-app drag from the Generated Fragments window carries the
|
| 207 |
+
// fragment filename; OS file drags carry dataTransfer.files. Read the
|
| 208 |
+
// custom payload synchronously before any await.
|
| 209 |
+
const fragName = e.dataTransfer.getData('application/x-fragmenta-fragment');
|
| 210 |
+
if (fragName) {
|
| 211 |
+
// Prefer the in-memory blob handed off on dragStart β no disk
|
| 212 |
+
// round-trip, and immune to any in-memory vs on-disk name mismatch.
|
| 213 |
+
const payload = getFragmentDragPayload();
|
| 214 |
+
if (payload?.blob && payload.filename === fragName) {
|
| 215 |
+
const file = new File([payload.blob], fragName || 'fragment.wav', {
|
| 216 |
+
type: payload.blob.type || 'audio/wav',
|
| 217 |
+
});
|
| 218 |
+
await uploadFile(file);
|
| 219 |
+
} else {
|
| 220 |
+
// Fallback: blob wasn't preloaded β fetch it from disk by name.
|
| 221 |
+
await loadFragmentByName(fragName);
|
| 222 |
+
}
|
| 223 |
+
return;
|
| 224 |
+
}
|
| 225 |
+
const f = e.dataTransfer.files?.[0];
|
| 226 |
+
await uploadFile(f);
|
| 227 |
+
};
|
| 228 |
+
const onDragOver = (e) => { e.preventDefault(); setDropActive(true); };
|
| 229 |
+
const onDragLeave = (e) => { e.preventDefault(); setDropActive(false); };
|
| 230 |
+
const clearSource = () => {
|
| 231 |
+
setSourcePath('');
|
| 232 |
+
setSourceName('');
|
| 233 |
+
setSourceFile(null);
|
| 234 |
+
setSourceDurationSec(null);
|
| 235 |
+
};
|
| 236 |
+
|
| 237 |
+
// --- generate --------------------------------------------------------
|
| 238 |
+
const generate = async () => {
|
| 239 |
+
if (!model_id) {
|
| 240 |
+
setError('Pick a model in the Generation tab first.');
|
| 241 |
+
return;
|
| 242 |
+
}
|
| 243 |
+
if (!sourcePath) {
|
| 244 |
+
setError('Upload a source clip first.');
|
| 245 |
+
return;
|
| 246 |
+
}
|
| 247 |
+
if (!prompt.trim() && mode !== 'extend') {
|
| 248 |
+
setError('Enter a prompt describing the change.');
|
| 249 |
+
return;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
setGenerating(true);
|
| 253 |
+
setError(null);
|
| 254 |
+
try {
|
| 255 |
+
// Seed: -1 lets the backend pick (and record) a random one; an
|
| 256 |
+
// explicit value is parsed with parseInt so 0 stays 0 rather than
|
| 257 |
+
// collapsing to random via `|| -1`.
|
| 258 |
+
let seedToSend = -1;
|
| 259 |
+
if (!randomSeed) {
|
| 260 |
+
const parsed = parseInt(seedValue, 10);
|
| 261 |
+
if (Number.isNaN(parsed) || parsed < 0) {
|
| 262 |
+
setError('Enter a non-negative integer seed, or switch Seed to Random.');
|
| 263 |
+
setGenerating(false);
|
| 264 |
+
return;
|
| 265 |
+
}
|
| 266 |
+
seedToSend = parsed;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
const body = {
|
| 270 |
+
model_id,
|
| 271 |
+
prompt: prompt.trim() || 'continue',
|
| 272 |
+
duration,
|
| 273 |
+
seed: seedToSend,
|
| 274 |
+
steps,
|
| 275 |
+
};
|
| 276 |
+
if (negativePrompt) body.negative_prompt = negativePrompt;
|
| 277 |
+
// Only base models honour CFG; sending it on a distilled model is
|
| 278 |
+
// harmless (backend forces 1.0) but we keep the UI honest.
|
| 279 |
+
if (!isDistilledBase) body.cfg_scale = cfgScale;
|
| 280 |
+
// Inherit the Generation panel's LoRA stack. Bypassed slots stay in
|
| 281 |
+
// load order but contribute strength 0 (same as plain generation).
|
| 282 |
+
const activeLoras = (loraStack || [])
|
| 283 |
+
.filter((s) => s.path)
|
| 284 |
+
.map((s) => ({ path: s.path, strength: s.bypassed ? 0 : s.strength }));
|
| 285 |
+
if (activeLoras.length) body.loras = activeLoras;
|
| 286 |
+
|
| 287 |
+
if (mode === 'style') {
|
| 288 |
+
body.init_audio_path = sourcePath;
|
| 289 |
+
body.init_noise_level = initNoiseLevel;
|
| 290 |
+
} else if (mode === 'inpaint') {
|
| 291 |
+
// Pin output length to the source so the mask (measured in
|
| 292 |
+
// source seconds) maps onto the same timeline the user sees.
|
| 293 |
+
if (!Number.isFinite(sourceDurationSec)) {
|
| 294 |
+
setError("Couldn't read source duration β re-upload the file.");
|
| 295 |
+
setGenerating(false);
|
| 296 |
+
return;
|
| 297 |
+
}
|
| 298 |
+
body.duration = sourceDurationSec;
|
| 299 |
+
body.inpaint_audio_path = sourcePath;
|
| 300 |
+
body.inpaint_starts = [Number(maskStart)];
|
| 301 |
+
body.inpaint_ends = [Number(maskEnd)];
|
| 302 |
+
} else if (mode === 'extend') {
|
| 303 |
+
// Extend = inpaint where the mask is the new tail. Total clip
|
| 304 |
+
// duration = source length + extendSeconds; mask covers
|
| 305 |
+
// [source_length, source_length + extendSeconds].
|
| 306 |
+
if (!Number.isFinite(sourceDurationSec)) {
|
| 307 |
+
setError("Couldn't read source duration β re-upload the file.");
|
| 308 |
+
setGenerating(false);
|
| 309 |
+
return;
|
| 310 |
+
}
|
| 311 |
+
body.duration = sourceDurationSec + extendSeconds;
|
| 312 |
+
body.inpaint_audio_path = sourcePath;
|
| 313 |
+
body.inpaint_starts = [sourceDurationSec];
|
| 314 |
+
body.inpaint_ends = [sourceDurationSec + extendSeconds];
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
const resp = await api.post('/api/generate', body, { responseType: 'blob' });
|
| 318 |
+
// Use the backend's real on-disk name (header) so the fragment in
|
| 319 |
+
// the list resolves to an actual file for reveal/delete; only fall
|
| 320 |
+
// back to a synthetic name if the header is absent.
|
| 321 |
+
const fname = resp.headers?.['x-fragment-filename'] || `${mode}_${Date.now()}.wav`;
|
| 322 |
+
// Record the resolved seed (the backend picks a concrete one when we
|
| 323 |
+
// sent -1) so the fragment shows the real value, not "random".
|
| 324 |
+
const resolvedSeed = parseInt(resp.headers?.['x-fragment-seed'], 10);
|
| 325 |
+
const params = Number.isFinite(resolvedSeed) ? { ...body, seed: resolvedSeed } : body;
|
| 326 |
+
onGenerated?.(resp.data, fname, params);
|
| 327 |
+
} catch (err) {
|
| 328 |
+
setError(err.response?.data?.error?.message || err.message || 'Generation failed');
|
| 329 |
+
} finally {
|
| 330 |
+
setGenerating(false);
|
| 331 |
+
}
|
| 332 |
+
};
|
| 333 |
+
|
| 334 |
+
// --- render ----------------------------------------------------------
|
| 335 |
+
return (
|
| 336 |
+
<Box sx={{ p: 2 }}>
|
| 337 |
+
{/* Source picker (drag-and-drop or click) */}
|
| 338 |
+
<Box
|
| 339 |
+
sx={{ mb: 2 }}
|
| 340 |
+
onDragOver={onDragOver}
|
| 341 |
+
onDragLeave={onDragLeave}
|
| 342 |
+
onDrop={onDrop}
|
| 343 |
+
>
|
| 344 |
+
<Typography variant="caption" color="text.secondary" display="block" sx={{ mb: 0.5 }}>
|
| 345 |
+
Source clip
|
| 346 |
+
</Typography>
|
| 347 |
+
{sourcePath ? (
|
| 348 |
+
<Stack
|
| 349 |
+
direction="row"
|
| 350 |
+
alignItems="center"
|
| 351 |
+
spacing={1}
|
| 352 |
+
sx={{
|
| 353 |
+
p: 1,
|
| 354 |
+
border: '1px dashed',
|
| 355 |
+
borderColor: dropActive ? 'primary.main' : 'divider',
|
| 356 |
+
borderRadius: 1,
|
| 357 |
+
transition: 'border-color 120ms',
|
| 358 |
+
}}
|
| 359 |
+
>
|
| 360 |
+
<Typography variant="body2" sx={{ flex: 1, fontFamily: 'monospace', fontSize: 12, overflow: 'hidden', textOverflow: 'ellipsis' }}>
|
| 361 |
+
{sourceName}
|
| 362 |
+
{sourceDurationSec && ` Β· ${sourceDurationSec.toFixed(2)}s`}
|
| 363 |
+
</Typography>
|
| 364 |
+
<IconButton size="small" onClick={clearSource} aria-label="Remove source"><ClearIcon size={14} /></IconButton>
|
| 365 |
+
</Stack>
|
| 366 |
+
) : (
|
| 367 |
+
<Button
|
| 368 |
+
variant="outlined"
|
| 369 |
+
startIcon={<UploadIcon size={14} />}
|
| 370 |
+
onClick={onPickFile}
|
| 371 |
+
disabled={sourceUploading}
|
| 372 |
+
fullWidth
|
| 373 |
+
sx={{
|
| 374 |
+
borderStyle: 'dashed',
|
| 375 |
+
borderColor: dropActive ? 'primary.main' : undefined,
|
| 376 |
+
bgcolor: dropActive ? 'action.hover' : undefined,
|
| 377 |
+
transition: 'border-color 120ms, background-color 120ms',
|
| 378 |
+
}}
|
| 379 |
+
>
|
| 380 |
+
{sourceUploading ? 'Uploadingβ¦' : 'Drop a clip here, or click to pick a file'}
|
| 381 |
+
</Button>
|
| 382 |
+
)}
|
| 383 |
+
<input
|
| 384 |
+
ref={fileInputRef}
|
| 385 |
+
type="file"
|
| 386 |
+
accept=".wav,.mp3,.flac,.m4a,.ogg,.opus,audio/*"
|
| 387 |
+
style={{ display: 'none' }}
|
| 388 |
+
onChange={onFileChange}
|
| 389 |
+
/>
|
| 390 |
+
</Box>
|
| 391 |
+
|
| 392 |
+
{/* Mode selector */}
|
| 393 |
+
<ToggleButtonGroup
|
| 394 |
+
value={mode}
|
| 395 |
+
exclusive
|
| 396 |
+
size="small"
|
| 397 |
+
onChange={(_, v) => v && setMode(v)}
|
| 398 |
+
sx={{ mb: 2 }}
|
| 399 |
+
>
|
| 400 |
+
<ToggleButton value="style">Style transfer</ToggleButton>
|
| 401 |
+
<ToggleButton value="inpaint">Inpaint region</ToggleButton>
|
| 402 |
+
<ToggleButton value="extend">Extend</ToggleButton>
|
| 403 |
+
</ToggleButtonGroup>
|
| 404 |
+
|
| 405 |
+
{/* Mode-specific controls */}
|
| 406 |
+
{mode === 'style' && (
|
| 407 |
+
<Box sx={{ mb: 2 }}>
|
| 408 |
+
<Typography variant="caption" color="text.secondary">
|
| 409 |
+
Preserve source character ββ follow prompt
|
| 410 |
+
</Typography>
|
| 411 |
+
<Stack direction="row" alignItems="center" spacing={2}>
|
| 412 |
+
<Slider
|
| 413 |
+
value={initNoiseLevel}
|
| 414 |
+
onChange={(_, v) => setInitNoiseLevel(v)}
|
| 415 |
+
min={0}
|
| 416 |
+
max={1}
|
| 417 |
+
step={0.05}
|
| 418 |
+
valueLabelDisplay="auto"
|
| 419 |
+
marks={[
|
| 420 |
+
{ value: 0, label: '0' },
|
| 421 |
+
{ value: 0.5, label: '0.5' },
|
| 422 |
+
{ value: 1, label: '1' },
|
| 423 |
+
]}
|
| 424 |
+
sx={{ flex: 1 }}
|
| 425 |
+
/>
|
| 426 |
+
<Typography variant="body2" sx={{ width: 40, textAlign: 'right' }}>
|
| 427 |
+
{initNoiseLevel.toFixed(2)}
|
| 428 |
+
</Typography>
|
| 429 |
+
</Stack>
|
| 430 |
+
</Box>
|
| 431 |
+
)}
|
| 432 |
+
|
| 433 |
+
{mode === 'inpaint' && (
|
| 434 |
+
<Box sx={{ mb: 2 }}>
|
| 435 |
+
<Typography variant="caption" color="text.secondary" display="block" sx={{ mb: 0.5 }}>
|
| 436 |
+
Drag the highlighted region to inpaint
|
| 437 |
+
</Typography>
|
| 438 |
+
<AudioWaveform
|
| 439 |
+
file={sourceFile}
|
| 440 |
+
duration={sourceDurationSec || 0}
|
| 441 |
+
start={maskStart}
|
| 442 |
+
end={maskEnd}
|
| 443 |
+
onRegionChange={(s, e) => { setMaskStart(s); setMaskEnd(e); }}
|
| 444 |
+
/>
|
| 445 |
+
<Stack direction="row" alignItems="center" spacing={2} sx={{ mt: 1 }}>
|
| 446 |
+
<TextField
|
| 447 |
+
label="Start (s)"
|
| 448 |
+
type="number"
|
| 449 |
+
size="small"
|
| 450 |
+
value={maskStart.toFixed(2)}
|
| 451 |
+
onChange={(e) => setMaskStart(parseFloat(e.target.value) || 0)}
|
| 452 |
+
inputProps={{ min: 0, max: sourceDurationSec || 999, step: 0.05 }}
|
| 453 |
+
sx={{ width: 96 }}
|
| 454 |
+
/>
|
| 455 |
+
<TextField
|
| 456 |
+
label="End (s)"
|
| 457 |
+
type="number"
|
| 458 |
+
size="small"
|
| 459 |
+
value={maskEnd.toFixed(2)}
|
| 460 |
+
onChange={(e) => setMaskEnd(parseFloat(e.target.value) || 0)}
|
| 461 |
+
inputProps={{ min: 0, max: sourceDurationSec || 999, step: 0.05 }}
|
| 462 |
+
sx={{ width: 96 }}
|
| 463 |
+
/>
|
| 464 |
+
<Box sx={{ flex: 1, display: 'flex', alignItems: 'center', gap: 1 }}>
|
| 465 |
+
<Button
|
| 466 |
+
size="small"
|
| 467 |
+
variant="outlined"
|
| 468 |
+
startIcon={regionPlaying ? <StopIcon size={14} /> : <PlayIcon size={14} />}
|
| 469 |
+
onClick={toggleRegionPreview}
|
| 470 |
+
disabled={!regionUrl || (maskEnd - maskStart) < 0.05}
|
| 471 |
+
// Fixed width so swapping "Preview" β "Stop" doesn't
|
| 472 |
+
// resize the button. Sized to fit "Preview" + icon.
|
| 473 |
+
sx={{ width: 108, flexShrink: 0 }}
|
| 474 |
+
>
|
| 475 |
+
{regionPlaying ? 'Stop' : 'Preview'}
|
| 476 |
+
</Button>
|
| 477 |
+
<Typography variant="caption" color="text.secondary">
|
| 478 |
+
{(maskEnd - maskStart).toFixed(2)} s
|
| 479 |
+
</Typography>
|
| 480 |
+
</Box>
|
| 481 |
+
</Stack>
|
| 482 |
+
<Typography variant="caption" color="text.secondary" display="block" sx={{ mt: 1 }}>
|
| 483 |
+
Output is the same length as the source β only your selected region is replaced
|
| 484 |
+
</Typography>
|
| 485 |
+
</Box>
|
| 486 |
+
)}
|
| 487 |
+
|
| 488 |
+
{mode === 'extend' && (
|
| 489 |
+
<Box sx={{ mb: 2 }}>
|
| 490 |
+
<TextField
|
| 491 |
+
label="Seconds to add at the end"
|
| 492 |
+
type="number"
|
| 493 |
+
size="small"
|
| 494 |
+
value={extendSeconds}
|
| 495 |
+
onChange={(e) => setExtendSeconds(parseFloat(e.target.value) || 0)}
|
| 496 |
+
inputProps={{ min: 0.5, max: 60, step: 0.5 }}
|
| 497 |
+
fullWidth
|
| 498 |
+
/>
|
| 499 |
+
<Typography variant="caption" color="text.secondary">
|
| 500 |
+
Source is {sourceDurationSec ? sourceDurationSec.toFixed(2) : 'β'} s; final clip will be{' '}
|
| 501 |
+
{sourceDurationSec ? (sourceDurationSec + Number(extendSeconds || 0)).toFixed(2) : 'β'} s.
|
| 502 |
+
</Typography>
|
| 503 |
+
</Box>
|
| 504 |
+
)}
|
| 505 |
+
|
| 506 |
+
{/* Shared inputs */}
|
| 507 |
+
<TextField
|
| 508 |
+
label={mode === 'inpaint' ? 'Prompt for the inpainting region' : 'Prompt for the edit'}
|
| 509 |
+
placeholder={
|
| 510 |
+
mode === 'style' ? 'How the source should sound nowβ¦' :
|
| 511 |
+
mode === 'inpaint' ? 'What goes in the gapβ¦' :
|
| 512 |
+
'What the continuation should sound like (optional)'
|
| 513 |
+
}
|
| 514 |
+
multiline
|
| 515 |
+
minRows={1}
|
| 516 |
+
maxRows={3}
|
| 517 |
+
value={prompt}
|
| 518 |
+
onChange={(e) => setPrompt(e.target.value)}
|
| 519 |
+
fullWidth
|
| 520 |
+
sx={{ mb: 2 }}
|
| 521 |
+
/>
|
| 522 |
+
|
| 523 |
+
{mode === 'style' && (
|
| 524 |
+
<Stack direction="row" alignItems="center" spacing={2} sx={{ mb: 2 }}>
|
| 525 |
+
<Typography variant="body2" color="text.secondary" sx={{ minWidth: 80 }}>
|
| 526 |
+
Duration
|
| 527 |
+
</Typography>
|
| 528 |
+
<Slider
|
| 529 |
+
value={Math.min(duration, maxDuration)}
|
| 530 |
+
onChange={(_, v) => setDuration(v)}
|
| 531 |
+
min={1}
|
| 532 |
+
max={maxDuration}
|
| 533 |
+
step={1}
|
| 534 |
+
valueLabelDisplay="auto"
|
| 535 |
+
sx={{ flex: 1 }}
|
| 536 |
+
/>
|
| 537 |
+
<Typography variant="body2" sx={{ width: 40, textAlign: 'right' }}>
|
| 538 |
+
{duration}s
|
| 539 |
+
</Typography>
|
| 540 |
+
</Stack>
|
| 541 |
+
)}
|
| 542 |
+
|
| 543 |
+
{/* Seed β random by default, mirrors the Generation panel */}
|
| 544 |
+
<Stack direction="row" alignItems="center" spacing={2} sx={{ mb: 2 }}>
|
| 545 |
+
<Typography variant="body2" color="text.secondary" sx={{ minWidth: 80 }}>
|
| 546 |
+
Seed
|
| 547 |
+
</Typography>
|
| 548 |
+
<FormControlLabel
|
| 549 |
+
control={
|
| 550 |
+
<Switch
|
| 551 |
+
size="small"
|
| 552 |
+
checked={randomSeed}
|
| 553 |
+
onChange={(e) => setRandomSeed(e.target.checked)}
|
| 554 |
+
/>
|
| 555 |
+
}
|
| 556 |
+
label="Random"
|
| 557 |
+
sx={{ mr: 0 }}
|
| 558 |
+
/>
|
| 559 |
+
<TextField
|
| 560 |
+
size="small"
|
| 561 |
+
type="number"
|
| 562 |
+
value={seedValue}
|
| 563 |
+
disabled={randomSeed}
|
| 564 |
+
onChange={(e) => setSeedValue(e.target.value)}
|
| 565 |
+
placeholder={randomSeed ? 'Randomized each run (recorded)' : 'e.g. 42'}
|
| 566 |
+
inputProps={{ min: 0, step: 1 }}
|
| 567 |
+
sx={{ flex: 1 }}
|
| 568 |
+
/>
|
| 569 |
+
</Stack>
|
| 570 |
+
|
| 571 |
+
{/* Hidden element backing the inpaint region preview */}
|
| 572 |
+
<audio
|
| 573 |
+
ref={regionAudioRef}
|
| 574 |
+
src={regionUrl || undefined}
|
| 575 |
+
preload="auto"
|
| 576 |
+
style={{ display: 'none' }}
|
| 577 |
+
onEnded={() => setRegionPlaying(false)}
|
| 578 |
+
/>
|
| 579 |
+
|
| 580 |
+
{error && <Alert severity="error" sx={{ mb: 2 }}>{error}</Alert>}
|
| 581 |
+
{generating && <LinearProgress sx={{ mb: 2 }} />}
|
| 582 |
+
|
| 583 |
+
<Button
|
| 584 |
+
variant="contained"
|
| 585 |
+
fullWidth
|
| 586 |
+
onClick={generate}
|
| 587 |
+
disabled={generating || !sourcePath}
|
| 588 |
+
>
|
| 589 |
+
{generating
|
| 590 |
+
? 'Generatingβ¦'
|
| 591 |
+
: mode === 'style' ? 'Apply style'
|
| 592 |
+
: mode === 'inpaint' ? 'Inpaint region'
|
| 593 |
+
: 'Extend clip'}
|
| 594 |
+
</Button>
|
| 595 |
+
</Box>
|
| 596 |
+
);
|
| 597 |
+
}
|
app/frontend/src/components/GeneratedFragmentsWindow.js
CHANGED
|
@@ -1,27 +1,242 @@
|
|
| 1 |
-
import React, { useState, useRef, useCallback } from 'react';
|
| 2 |
-
import {
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import { generatedFragmentsWindowStyles } from '../theme';
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
const [playingFragment, setPlayingFragment] = useState(null);
|
|
|
|
|
|
|
| 9 |
const audioRefs = useRef({});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
const handlePlayPause = (fragment) => {
|
| 12 |
const audio = audioRefs.current[fragment.id];
|
| 13 |
if (!audio) return;
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
audio.pause();
|
|
|
|
| 17 |
setPlayingFragment(null);
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
}
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
};
|
| 26 |
|
| 27 |
const setAudioRef = useCallback((fragmentId, audioElement) => {
|
|
@@ -31,90 +246,225 @@ export default function GeneratedFragmentsWindow({ fragments, onDownload }) {
|
|
| 31 |
}, []);
|
| 32 |
|
| 33 |
return (
|
| 34 |
-
<Paper
|
| 35 |
-
variant="outlined"
|
| 36 |
-
sx={generatedFragmentsWindowStyles.rootPaper}
|
| 37 |
-
>
|
| 38 |
<Box sx={generatedFragmentsWindowStyles.headerRow}>
|
| 39 |
<Box sx={generatedFragmentsWindowStyles.titleRow}>
|
| 40 |
<Box component="span" sx={generatedFragmentsWindowStyles.titleIcon}>
|
| 41 |
-
<
|
| 42 |
</Box>
|
| 43 |
<Typography variant="h6" sx={generatedFragmentsWindowStyles.titleText}>
|
| 44 |
Generated Fragments
|
| 45 |
</Typography>
|
| 46 |
</Box>
|
| 47 |
-
<
|
| 48 |
-
{
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
</Box>
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
{fragments.length === 0 ? (
|
| 53 |
-
<Box
|
| 54 |
-
|
| 55 |
-
>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
<Typography variant="body2">
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
</Typography>
|
| 59 |
</Box>
|
| 60 |
) : (
|
| 61 |
-
<List
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
<Typography
|
| 72 |
-
variant="
|
| 73 |
sx={generatedFragmentsWindowStyles.fragmentPrompt}
|
|
|
|
| 74 |
>
|
| 75 |
{fragment.batchTotal > 1 && (
|
| 76 |
-
<Box component="span" sx={
|
| 77 |
-
|
| 78 |
</Box>
|
| 79 |
)}
|
| 80 |
{fragment.prompt}
|
| 81 |
</Typography>
|
| 82 |
-
<Typography variant="caption" color="textSecondary">
|
| 83 |
-
{fragment.duration}s
|
| 84 |
-
{fragment.cfgScale !== undefined && ` β’ CFG ${fragment.cfgScale}`}
|
| 85 |
-
{fragment.seed !== undefined && ` β’ Seed ${fragment.seed}`}
|
| 86 |
-
{' β’ '}{fragment.timestamp}
|
| 87 |
-
</Typography>
|
| 88 |
</Box>
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
>
|
| 104 |
-
|
| 105 |
-
</
|
| 106 |
-
</
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
</List>
|
| 119 |
)}
|
| 120 |
</Paper>
|
|
|
|
| 1 |
+
import React, { useState, useRef, useCallback, useEffect } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Paper, Box, Typography, List, ListItem, IconButton,
|
| 4 |
+
Dialog, DialogTitle, DialogContent, DialogContentText, DialogActions, Button,
|
| 5 |
+
CircularProgress,
|
| 6 |
+
} from '@mui/material';
|
| 7 |
+
import { TIPS } from '../tooltips';
|
| 8 |
+
import Tooltip from './Tooltip';
|
| 9 |
+
import {
|
| 10 |
+
Square as StopIcon,
|
| 11 |
+
Play as PlayIcon,
|
| 12 |
+
AudioLines as TitleIcon,
|
| 13 |
+
Info as InfoIcon,
|
| 14 |
+
Trash2 as DeleteIcon,
|
| 15 |
+
Eraser as ClearAllIcon,
|
| 16 |
+
FolderOpen as RevealIcon,
|
| 17 |
+
} from 'lucide-react';
|
| 18 |
import { generatedFragmentsWindowStyles } from '../theme';
|
| 19 |
+
import GenerationWaveform from './GenerationWaveform';
|
| 20 |
+
import api from '../api';
|
| 21 |
+
import { setFragmentDragPayload, clearFragmentDragPayload } from '../utils/fragmentDrag';
|
| 22 |
|
| 23 |
+
// Compact human-readable "X ago" with absolute fallback for stale items.
|
| 24 |
+
function relativeTime(createdAt) {
|
| 25 |
+
if (!createdAt) return '';
|
| 26 |
+
const sec = Math.max(0, (Date.now() - createdAt) / 1000);
|
| 27 |
+
if (sec < 10) return 'just now';
|
| 28 |
+
if (sec < 60) return `${Math.floor(sec)}s ago`;
|
| 29 |
+
const min = sec / 60;
|
| 30 |
+
if (min < 60) return `${Math.floor(min)}m ago`;
|
| 31 |
+
const hr = min / 60;
|
| 32 |
+
if (hr < 24) return `${Math.floor(hr)}h ago`;
|
| 33 |
+
const day = hr / 24;
|
| 34 |
+
if (day < 7) return `${Math.floor(day)}d ago`;
|
| 35 |
+
// Older than a week β show absolute date, no time
|
| 36 |
+
return new Date(createdAt).toLocaleDateString();
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
export default function GeneratedFragmentsWindow({ fragments, onDelete, onClearAll }) {
|
| 40 |
const [playingFragment, setPlayingFragment] = useState(null);
|
| 41 |
+
const [playingTime, setPlayingTime] = useState(0);
|
| 42 |
+
const [clearConfirmOpen, setClearConfirmOpen] = useState(false);
|
| 43 |
const audioRefs = useRef({});
|
| 44 |
+
// Tracks a play request that's between "user clicked Play" and "audio
|
| 45 |
+
// actually started". If the user clicks again during this window we
|
| 46 |
+
// need to either no-op (same fragment) or cleanly cancel (different
|
| 47 |
+
// fragment) β re-entering load() would abort the first play() and
|
| 48 |
+
// both attempts would fail with AbortError.
|
| 49 |
+
const playInFlightRef = useRef(null);
|
| 50 |
+
|
| 51 |
+
// Background-preload of disk-hydrated fragments. On app reload the parent
|
| 52 |
+
// gives us fragment metadata + the backend URL (/api/fragments/...) but
|
| 53 |
+
// no in-memory Blob. The first Play click on those would HTTP-fetch the
|
| 54 |
+
// file synchronously through the <audio> element and freeze briefly. We
|
| 55 |
+
// pre-fetch them in parallel on mount and gate the UI behind a single
|
| 56 |
+
// loading screen β once everything is ready, plays + waveform decodes
|
| 57 |
+
// are instant because they work off blob: URLs.
|
| 58 |
+
const fetchingIdsRef = useRef(new Set());
|
| 59 |
+
const loadedRef = useRef({}); // { [id]: { blob, blobUrl } }
|
| 60 |
+
const [loadedTick, setLoadedTick] = useState(0);
|
| 61 |
+
|
| 62 |
+
useEffect(() => {
|
| 63 |
+
let cancelled = false;
|
| 64 |
+
fragments.forEach((frag) => {
|
| 65 |
+
if (frag.audioBlob) return; // already in memory
|
| 66 |
+
if (loadedRef.current[frag.id]) return; // already preloaded
|
| 67 |
+
if (fetchingIdsRef.current.has(frag.id)) return;
|
| 68 |
+
if (!frag.audioUrl) return;
|
| 69 |
+
fetchingIdsRef.current.add(frag.id);
|
| 70 |
+
fetch(frag.audioUrl)
|
| 71 |
+
.then((r) => {
|
| 72 |
+
if (!r.ok) throw new Error(`HTTP ${r.status}`);
|
| 73 |
+
return r.blob();
|
| 74 |
+
})
|
| 75 |
+
.then((blob) => {
|
| 76 |
+
if (cancelled) return;
|
| 77 |
+
const blobUrl = URL.createObjectURL(blob);
|
| 78 |
+
loadedRef.current[frag.id] = { blob, blobUrl };
|
| 79 |
+
setLoadedTick((t) => t + 1);
|
| 80 |
+
})
|
| 81 |
+
.catch((err) => {
|
| 82 |
+
console.warn(`Fragment preload failed (${frag.filename || frag.id}):`, err);
|
| 83 |
+
})
|
| 84 |
+
.finally(() => {
|
| 85 |
+
fetchingIdsRef.current.delete(frag.id);
|
| 86 |
+
});
|
| 87 |
+
});
|
| 88 |
+
return () => { cancelled = true; };
|
| 89 |
+
}, [fragments]);
|
| 90 |
+
|
| 91 |
+
// Revoke all preload blob URLs on unmount so we don't leak.
|
| 92 |
+
useEffect(() => () => {
|
| 93 |
+
Object.values(loadedRef.current).forEach(({ blobUrl }) => {
|
| 94 |
+
try { URL.revokeObjectURL(blobUrl); } catch { /* ignore */ }
|
| 95 |
+
});
|
| 96 |
+
}, []);
|
| 97 |
+
|
| 98 |
+
// Per-fragment helpers that prefer the in-memory blob (immediate) over
|
| 99 |
+
// the HTTP URL. Defined after loadedTick is read so React knows to
|
| 100 |
+
// re-render when a new fragment finishes preloading.
|
| 101 |
+
void loadedTick;
|
| 102 |
+
const effectiveBlob = (frag) => frag.audioBlob || loadedRef.current[frag.id]?.blob || null;
|
| 103 |
+
const effectiveUrl = (frag) => loadedRef.current[frag.id]?.blobUrl || frag.audioUrl;
|
| 104 |
+
const isFragmentReady = (frag) => !!frag.audioBlob || !!loadedRef.current[frag.id];
|
| 105 |
|
| 106 |
+
const readyCount = fragments.filter(isFragmentReady).length;
|
| 107 |
+
const allReady = fragments.length === 0 || readyCount === fragments.length;
|
| 108 |
+
|
| 109 |
+
// Safety buffer: once everything reports ready, keep the loading overlay
|
| 110 |
+
// up for an extra 5s before revealing the list. Audio decodes that are
|
| 111 |
+
// still settling in the background can't be poked (and can't crash the
|
| 112 |
+
// list) while the user is gated behind the spinner.
|
| 113 |
+
const GRACE_MS = 5000;
|
| 114 |
+
const [graceDone, setGraceDone] = useState(false);
|
| 115 |
+
useEffect(() => {
|
| 116 |
+
if (fragments.length === 0) { setGraceDone(true); return undefined; }
|
| 117 |
+
if (!allReady) { setGraceDone(false); return undefined; }
|
| 118 |
+
const t = setTimeout(() => setGraceDone(true), GRACE_MS);
|
| 119 |
+
return () => clearTimeout(t);
|
| 120 |
+
}, [allReady, fragments.length]);
|
| 121 |
+
const showLoading = fragments.length > 0 && (!allReady || !graceDone);
|
| 122 |
+
|
| 123 |
+
// Strict single-play with first-click readiness gate.
|
| 124 |
+
//
|
| 125 |
+
// Race-fixes the old version had:
|
| 126 |
+
// 1. Iterate audioRefs.current and pause everything that isn't the
|
| 127 |
+
// new target β avoids losing the race when two play clicks land
|
| 128 |
+
// before React state settles.
|
| 129 |
+
// 2. For blob URLs, Chromium often doesn't actually pull bytes until
|
| 130 |
+
// the first play() call, and play() rejects/hangs if readyState
|
| 131 |
+
// is too low. If we're not ready, call load() and wait for
|
| 132 |
+
// `canplay` (with a 1500 ms safety timeout) before play().
|
| 133 |
+
// 3. Guard against the user clicking Play twice during loading. A
|
| 134 |
+
// second load() while the first play() is still pending aborts
|
| 135 |
+
// the first with AbortError. playInFlightRef tracks the active
|
| 136 |
+
// request: same-fragment second click is a no-op; different
|
| 137 |
+
// fragment cleanly cancels the prior load timer/listener.
|
| 138 |
const handlePlayPause = (fragment) => {
|
| 139 |
const audio = audioRefs.current[fragment.id];
|
| 140 |
if (!audio) return;
|
| 141 |
|
| 142 |
+
// Stop case: this fragment is currently playing β pause it.
|
| 143 |
+
if (!audio.paused) {
|
| 144 |
+
playInFlightRef.current?.cleanup?.();
|
| 145 |
+
playInFlightRef.current = null;
|
| 146 |
audio.pause();
|
| 147 |
+
audio.currentTime = 0;
|
| 148 |
setPlayingFragment(null);
|
| 149 |
+
setPlayingTime(0);
|
| 150 |
+
return;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
// Click during loading of the SAME fragment β ignore.
|
| 154 |
+
if (playInFlightRef.current?.fragmentId === fragment.id) {
|
| 155 |
+
return;
|
| 156 |
+
}
|
| 157 |
+
// Click during loading of a DIFFERENT fragment β cancel that.
|
| 158 |
+
if (playInFlightRef.current) {
|
| 159 |
+
playInFlightRef.current.cleanup?.();
|
| 160 |
+
playInFlightRef.current = null;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
Object.values(audioRefs.current).forEach((el) => {
|
| 164 |
+
if (el && el !== audio) {
|
| 165 |
+
el.pause();
|
| 166 |
+
el.currentTime = 0;
|
| 167 |
}
|
| 168 |
+
});
|
| 169 |
+
|
| 170 |
+
const startedFor = fragment.id;
|
| 171 |
+
setPlayingFragment(startedFor);
|
| 172 |
+
setPlayingTime(0);
|
| 173 |
+
|
| 174 |
+
const startPlayback = () => {
|
| 175 |
+
audio.currentTime = 0;
|
| 176 |
+
Promise.resolve(audio.play())
|
| 177 |
+
.then(() => {
|
| 178 |
+
// Successfully playing β clear the in-flight marker so
|
| 179 |
+
// the next Play click can fire a fresh request.
|
| 180 |
+
if (playInFlightRef.current?.fragmentId === startedFor) {
|
| 181 |
+
playInFlightRef.current = null;
|
| 182 |
+
}
|
| 183 |
+
})
|
| 184 |
+
.catch((err) => {
|
| 185 |
+
// AbortError is expected when the user cancels (clicks
|
| 186 |
+
// Stop or switches fragments) β don't noise the log.
|
| 187 |
+
if (err && err.name !== 'AbortError') {
|
| 188 |
+
console.warn(`Fragment play failed (${fragment.filename || fragment.id}):`, err);
|
| 189 |
+
}
|
| 190 |
+
setPlayingFragment((prev) => (prev === startedFor ? null : prev));
|
| 191 |
+
setPlayingTime(0);
|
| 192 |
+
if (playInFlightRef.current?.fragmentId === startedFor) {
|
| 193 |
+
playInFlightRef.current = null;
|
| 194 |
+
}
|
| 195 |
+
});
|
| 196 |
+
};
|
| 197 |
+
|
| 198 |
+
if (audio.readyState >= 2) {
|
| 199 |
+
playInFlightRef.current = { fragmentId: startedFor, cleanup: null };
|
| 200 |
+
startPlayback();
|
| 201 |
+
return;
|
| 202 |
}
|
| 203 |
+
|
| 204 |
+
// Not ready yet β load and wait for canplay (or 1.5 s timeout).
|
| 205 |
+
try { audio.load(); } catch { /* ignore */ }
|
| 206 |
+
let cancelled = false;
|
| 207 |
+
const onReady = () => {
|
| 208 |
+
audio.removeEventListener('canplay', onReady);
|
| 209 |
+
clearTimeout(timer);
|
| 210 |
+
if (cancelled) return;
|
| 211 |
+
startPlayback();
|
| 212 |
+
};
|
| 213 |
+
audio.addEventListener('canplay', onReady, { once: true });
|
| 214 |
+
// 5 s β disk-hydrated fragments fetch from /api/fragments/...
|
| 215 |
+
// over HTTP, which can take a couple of seconds on first request.
|
| 216 |
+
// Blob-URL fragments (in-memory) hit canplay almost instantly.
|
| 217 |
+
const timer = setTimeout(() => {
|
| 218 |
+
audio.removeEventListener('canplay', onReady);
|
| 219 |
+
if (!cancelled) startPlayback();
|
| 220 |
+
}, 5000);
|
| 221 |
+
playInFlightRef.current = {
|
| 222 |
+
fragmentId: startedFor,
|
| 223 |
+
cleanup: () => {
|
| 224 |
+
cancelled = true;
|
| 225 |
+
audio.removeEventListener('canplay', onReady);
|
| 226 |
+
clearTimeout(timer);
|
| 227 |
+
},
|
| 228 |
+
};
|
| 229 |
+
};
|
| 230 |
+
|
| 231 |
+
// Reveal a fragment in the OS file manager (folder opens with the file
|
| 232 |
+
// highlighted where the platform supports it). Disk-hydrated fragments
|
| 233 |
+
// always have a filename; in-memory-only ones (not yet flushed) won't.
|
| 234 |
+
const revealInFolder = (fragment) => {
|
| 235 |
+
if (!fragment.filename) return;
|
| 236 |
+
api.post('/api/reveal-fragment', { filename: fragment.filename })
|
| 237 |
+
.catch((err) => {
|
| 238 |
+
console.warn(`Reveal failed (${fragment.filename}):`, err);
|
| 239 |
+
});
|
| 240 |
};
|
| 241 |
|
| 242 |
const setAudioRef = useCallback((fragmentId, audioElement) => {
|
|
|
|
| 246 |
}, []);
|
| 247 |
|
| 248 |
return (
|
| 249 |
+
<Paper variant="outlined" sx={generatedFragmentsWindowStyles.rootPaper}>
|
|
|
|
|
|
|
|
|
|
| 250 |
<Box sx={generatedFragmentsWindowStyles.headerRow}>
|
| 251 |
<Box sx={generatedFragmentsWindowStyles.titleRow}>
|
| 252 |
<Box component="span" sx={generatedFragmentsWindowStyles.titleIcon}>
|
| 253 |
+
<TitleIcon size={20} />
|
| 254 |
</Box>
|
| 255 |
<Typography variant="h6" sx={generatedFragmentsWindowStyles.titleText}>
|
| 256 |
Generated Fragments
|
| 257 |
</Typography>
|
| 258 |
</Box>
|
| 259 |
+
<Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5 }}>
|
| 260 |
+
<Typography variant="caption" color="textSecondary" sx={generatedFragmentsWindowStyles.countText}>
|
| 261 |
+
{fragments.length}
|
| 262 |
+
</Typography>
|
| 263 |
+
{fragments.length > 0 && onClearAll && (
|
| 264 |
+
<Tooltip title={TIPS.fragments.clearAll} placement="top" arrow>
|
| 265 |
+
<IconButton
|
| 266 |
+
size="small"
|
| 267 |
+
onClick={() => setClearConfirmOpen(true)}
|
| 268 |
+
sx={{ color: 'text.disabled', '&:hover': { color: 'error.main' } }}
|
| 269 |
+
>
|
| 270 |
+
<ClearAllIcon size={14} />
|
| 271 |
+
</IconButton>
|
| 272 |
+
</Tooltip>
|
| 273 |
+
)}
|
| 274 |
+
</Box>
|
| 275 |
</Box>
|
| 276 |
|
| 277 |
+
<Dialog open={clearConfirmOpen} onClose={() => setClearConfirmOpen(false)}>
|
| 278 |
+
<DialogTitle>Clear all generated fragments?</DialogTitle>
|
| 279 |
+
<DialogContent>
|
| 280 |
+
<DialogContentText>
|
| 281 |
+
Permanently delete all {fragments.length} fragment{fragments.length === 1 ? '' : 's'} from disk.
|
| 282 |
+
Uploaded source clips (used by Edit mode) are not affected.
|
| 283 |
+
</DialogContentText>
|
| 284 |
+
</DialogContent>
|
| 285 |
+
<DialogActions>
|
| 286 |
+
<Button onClick={() => setClearConfirmOpen(false)}>Cancel</Button>
|
| 287 |
+
<Button
|
| 288 |
+
onClick={() => { setClearConfirmOpen(false); onClearAll?.(); }}
|
| 289 |
+
color="error"
|
| 290 |
+
variant="contained"
|
| 291 |
+
>
|
| 292 |
+
Delete all
|
| 293 |
+
</Button>
|
| 294 |
+
</DialogActions>
|
| 295 |
+
</Dialog>
|
| 296 |
+
|
| 297 |
{fragments.length === 0 ? (
|
| 298 |
+
<Box sx={generatedFragmentsWindowStyles.emptyState}>
|
| 299 |
+
<Typography variant="body2">No fragments generated yet</Typography>
|
| 300 |
+
</Box>
|
| 301 |
+
) : showLoading ? (
|
| 302 |
+
<Box sx={{
|
| 303 |
+
...generatedFragmentsWindowStyles.emptyState,
|
| 304 |
+
display: 'flex',
|
| 305 |
+
flexDirection: 'column',
|
| 306 |
+
alignItems: 'center',
|
| 307 |
+
gap: 1.5,
|
| 308 |
+
}}>
|
| 309 |
+
<CircularProgress size={28} />
|
| 310 |
<Typography variant="body2">
|
| 311 |
+
{allReady
|
| 312 |
+
? 'Finishing upβ¦'
|
| 313 |
+
: `Loading fragments⦠${readyCount} / ${fragments.length}`}
|
| 314 |
</Typography>
|
| 315 |
</Box>
|
| 316 |
) : (
|
| 317 |
+
<List sx={generatedFragmentsWindowStyles.listRoot}>
|
| 318 |
+
{fragments.slice().reverse().map((fragment) => {
|
| 319 |
+
const isPlaying = playingFragment === fragment.id;
|
| 320 |
+
const ago = relativeTime(fragment.createdAt);
|
| 321 |
+
// CFG, seed, full timestamp, and model go in the info
|
| 322 |
+
// tooltip β accessible but not pushing the row out.
|
| 323 |
+
const tooltipLines = [
|
| 324 |
+
// Pre-fix fragments stored -1 for a random seed;
|
| 325 |
+
// show that as "random" rather than a bare -1.
|
| 326 |
+
`Seed: ${(fragment.seed != null && fragment.seed >= 0) ? fragment.seed : 'random'}`,
|
| 327 |
+
// Distilled SA3 models have CFG distilled away β it's
|
| 328 |
+
// genuinely not applicable, not missing.
|
| 329 |
+
`CFG: ${fragment.cfgScale ?? 'n/a'}`,
|
| 330 |
+
fragment.steps != null ? `Steps: ${fragment.steps}` : null,
|
| 331 |
+
fragment.modelId ? `Model: ${fragment.modelId}` : null,
|
| 332 |
+
fragment.editMode ? `Mode: ${fragment.editMode}` : null,
|
| 333 |
+
`Duration: ${fragment.duration}s`,
|
| 334 |
+
ago ? `Generated: ${ago}` : null,
|
| 335 |
+
fragment.timestamp ? fragment.timestamp : null,
|
| 336 |
+
].filter(Boolean).join('\n');
|
| 337 |
+
|
| 338 |
+
return (
|
| 339 |
+
<ListItem
|
| 340 |
+
key={fragment.id}
|
| 341 |
+
sx={generatedFragmentsWindowStyles.listItem}
|
| 342 |
+
>
|
| 343 |
+
<IconButton
|
| 344 |
+
size="small"
|
| 345 |
+
onClick={() => handlePlayPause(fragment)}
|
| 346 |
+
aria-label={isPlaying ? 'Stop' : 'Play'}
|
| 347 |
+
sx={generatedFragmentsWindowStyles.playPauseButton(isPlaying)}
|
| 348 |
+
>
|
| 349 |
+
{isPlaying ? <StopIcon size={16} /> : <PlayIcon size={16} />}
|
| 350 |
+
</IconButton>
|
| 351 |
+
|
| 352 |
+
<Box
|
| 353 |
+
sx={{ ...generatedFragmentsWindowStyles.fragmentMeta, cursor: 'grab' }}
|
| 354 |
+
draggable
|
| 355 |
+
onDragStart={(e) => {
|
| 356 |
+
// In-app payload consumed by EditPanel's drop zone
|
| 357 |
+
// ("drag a clip into the Edit tab"). Keeps the
|
| 358 |
+
// waveform's separate OS drag-out untouched.
|
| 359 |
+
e.dataTransfer.setData(
|
| 360 |
+
'application/x-fragmenta-fragment',
|
| 361 |
+
fragment.filename || '',
|
| 362 |
+
);
|
| 363 |
+
e.dataTransfer.effectAllowed = 'copy';
|
| 364 |
+
// Hand off the in-memory blob too so the drop can
|
| 365 |
+
// use it directly β no disk fetch, immune to any
|
| 366 |
+
// name mismatch. Falls back to the filename when
|
| 367 |
+
// the blob isn't preloaded yet.
|
| 368 |
+
const blob = effectiveBlob(fragment);
|
| 369 |
+
if (blob) {
|
| 370 |
+
setFragmentDragPayload({
|
| 371 |
+
filename: fragment.filename || '',
|
| 372 |
+
blob,
|
| 373 |
+
});
|
| 374 |
+
}
|
| 375 |
+
}}
|
| 376 |
+
onDragEnd={() => clearFragmentDragPayload()}
|
| 377 |
+
title="Drag into the Edit tab to use as a source clip"
|
| 378 |
+
>
|
| 379 |
<Typography
|
| 380 |
+
variant="body2"
|
| 381 |
sx={generatedFragmentsWindowStyles.fragmentPrompt}
|
| 382 |
+
title={fragment.prompt}
|
| 383 |
>
|
| 384 |
{fragment.batchTotal > 1 && (
|
| 385 |
+
<Box component="span" sx={generatedFragmentsWindowStyles.batchTag}>
|
| 386 |
+
{fragment.batchIndex}/{fragment.batchTotal}
|
| 387 |
</Box>
|
| 388 |
)}
|
| 389 |
{fragment.prompt}
|
| 390 |
</Typography>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
</Box>
|
| 392 |
+
|
| 393 |
+
<GenerationWaveform
|
| 394 |
+
blob={effectiveBlob(fragment)}
|
| 395 |
+
audioUrl={effectiveUrl(fragment)}
|
| 396 |
+
filename={fragment.filename || 'fragment.wav'}
|
| 397 |
+
currentTime={isPlaying ? playingTime : 0}
|
| 398 |
+
duration={fragment.duration || 0}
|
| 399 |
+
/>
|
| 400 |
+
|
| 401 |
+
<Tooltip
|
| 402 |
+
title={
|
| 403 |
+
<Box component="span" sx={{ whiteSpace: 'pre-line' }}>
|
| 404 |
+
{tooltipLines}
|
| 405 |
+
</Box>
|
| 406 |
+
}
|
| 407 |
+
arrow
|
| 408 |
+
placement="top"
|
| 409 |
+
>
|
| 410 |
+
<Box
|
| 411 |
+
component="span"
|
| 412 |
+
sx={generatedFragmentsWindowStyles.fragmentInfoIcon}
|
| 413 |
>
|
| 414 |
+
<InfoIcon size={14} />
|
| 415 |
+
</Box>
|
| 416 |
+
</Tooltip>
|
| 417 |
+
|
| 418 |
+
{fragment.filename && (
|
| 419 |
+
<Tooltip title={TIPS.fragments.revealInFolder} placement="top" arrow>
|
| 420 |
+
<IconButton
|
| 421 |
+
size="small"
|
| 422 |
+
onClick={() => revealInFolder(fragment)}
|
| 423 |
+
aria-label="Show in folder"
|
| 424 |
+
sx={{ color: 'text.disabled', '&:hover': { color: 'primary.main', bgcolor: 'action.hover' } }}
|
| 425 |
+
>
|
| 426 |
+
<RevealIcon size={16} />
|
| 427 |
+
</IconButton>
|
| 428 |
+
</Tooltip>
|
| 429 |
+
)}
|
| 430 |
+
|
| 431 |
+
{onDelete && (
|
| 432 |
+
<Tooltip title={TIPS.fragments.deleteFromDisk} placement="top" arrow>
|
| 433 |
+
<IconButton
|
| 434 |
+
size="small"
|
| 435 |
+
onClick={() => onDelete(fragment)}
|
| 436 |
+
sx={{ color: 'text.disabled', '&:hover': { color: 'error.main', bgcolor: 'action.hover' } }}
|
| 437 |
+
>
|
| 438 |
+
<DeleteIcon size={16} />
|
| 439 |
+
</IconButton>
|
| 440 |
+
</Tooltip>
|
| 441 |
+
)}
|
| 442 |
+
|
| 443 |
+
<audio
|
| 444 |
+
ref={el => setAudioRef(fragment.id, el)}
|
| 445 |
+
src={effectiveUrl(fragment)}
|
| 446 |
+
preload="auto"
|
| 447 |
+
onTimeUpdate={(e) => {
|
| 448 |
+
if (playingFragment === fragment.id) {
|
| 449 |
+
setPlayingTime(e.target.currentTime);
|
| 450 |
+
}
|
| 451 |
+
}}
|
| 452 |
+
onEnded={() => {
|
| 453 |
+
if (playingFragment === fragment.id) {
|
| 454 |
+
setPlayingFragment(null);
|
| 455 |
+
setPlayingTime(0);
|
| 456 |
+
}
|
| 457 |
+
}}
|
| 458 |
+
onPause={() => {
|
| 459 |
+
if (playingFragment === fragment.id) {
|
| 460 |
+
setPlayingFragment(null);
|
| 461 |
+
}
|
| 462 |
+
}}
|
| 463 |
+
style={generatedFragmentsWindowStyles.hiddenAudio}
|
| 464 |
+
/>
|
| 465 |
+
</ListItem>
|
| 466 |
+
);
|
| 467 |
+
})}
|
| 468 |
</List>
|
| 469 |
)}
|
| 470 |
</Paper>
|
app/frontend/src/components/GenerationWaveform.js
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useEffect, useLayoutEffect, useRef, useState, useCallback } from 'react';
|
| 2 |
+
import { Box } from '@mui/material';
|
| 3 |
+
|
| 4 |
+
const DEFAULT_COLOR = '#279FBB';
|
| 5 |
+
// Fixed, low waveform resolution β matches the dataset-page waveforms
|
| 6 |
+
// (/peaks?n=80). Decoding to a constant bucket count (instead of one pair per
|
| 7 |
+
// pixel) means the decode runs once per clip rather than re-running on every
|
| 8 |
+
// resize, so fragments render faster.
|
| 9 |
+
const PEAK_COUNT = 80;
|
| 10 |
+
|
| 11 |
+
/**
|
| 12 |
+
* Compact waveform indicator for a single generated fragment.
|
| 13 |
+
*
|
| 14 |
+
* Decodes `blob` once per width and renders min/max peaks on a canvas.
|
| 15 |
+
* Played portion is rendered in `color`; unplayed in a dim version of it,
|
| 16 |
+
* with a thin playhead line at the current position. The whole element is
|
| 17 |
+
* draggable: dragstart sets a DownloadURL on the dataTransfer so the user
|
| 18 |
+
* can drag the fragment onto their desktop or into a DAW as a .wav file.
|
| 19 |
+
*
|
| 20 |
+
* Props:
|
| 21 |
+
* blob: Blob | null β audio source (Blob is required for the
|
| 22 |
+
* native drag-to-OS file write).
|
| 23 |
+
* audioUrl: string β blob: URL for the same audio. Used in
|
| 24 |
+
* the dataTransfer; we fall back to
|
| 25 |
+
* createObjectURL(blob) if it's missing.
|
| 26 |
+
* filename: string β file name the OS sees when the drag
|
| 27 |
+
* resolves.
|
| 28 |
+
* currentTime: number β playback head position in seconds.
|
| 29 |
+
* duration: number β total length in seconds.
|
| 30 |
+
* height: number β canvas height in px (default 28).
|
| 31 |
+
* color: string β accent color (default theme amber).
|
| 32 |
+
*/
|
| 33 |
+
export default function GenerationWaveform({
|
| 34 |
+
blob,
|
| 35 |
+
audioUrl,
|
| 36 |
+
filename = 'fragment.wav',
|
| 37 |
+
currentTime = 0,
|
| 38 |
+
duration = 0,
|
| 39 |
+
height = 28,
|
| 40 |
+
color = DEFAULT_COLOR,
|
| 41 |
+
}) {
|
| 42 |
+
const containerRef = useRef(null);
|
| 43 |
+
const canvasRef = useRef(null);
|
| 44 |
+
// Start at a sensible non-zero width so the decode useEffect (gated on
|
| 45 |
+
// width > 0) runs on first mount instead of waiting for the async
|
| 46 |
+
// ResizeObserver callback β which is what was leaving the canvas blank.
|
| 47 |
+
const [width, setWidth] = useState(200);
|
| 48 |
+
const [peaks, setPeaks] = useState(null);
|
| 49 |
+
|
| 50 |
+
// Measure synchronously on mount via useLayoutEffect so we never paint
|
| 51 |
+
// at the placeholder width; ResizeObserver then keeps it in sync with
|
| 52 |
+
// sidebar collapses / window resizes.
|
| 53 |
+
useLayoutEffect(() => {
|
| 54 |
+
const el = containerRef.current;
|
| 55 |
+
if (!el) return;
|
| 56 |
+
const rect = el.getBoundingClientRect();
|
| 57 |
+
if (rect.width > 0) {
|
| 58 |
+
setWidth(Math.max(1, Math.floor(rect.width)));
|
| 59 |
+
}
|
| 60 |
+
const ro = new ResizeObserver((entries) => {
|
| 61 |
+
const w = Math.max(1, Math.floor(entries[0].contentRect.width));
|
| 62 |
+
setWidth(w);
|
| 63 |
+
});
|
| 64 |
+
ro.observe(el);
|
| 65 |
+
return () => ro.disconnect();
|
| 66 |
+
}, []);
|
| 67 |
+
|
| 68 |
+
// Decode into PEAK_COUNT mono min/max pairs β a fixed low resolution,
|
| 69 |
+
// independent of pixel width, so the (expensive) decode runs once per clip
|
| 70 |
+
// and not again on every resize.
|
| 71 |
+
//
|
| 72 |
+
// Audio source can be either a Blob (in-memory, fresh generations) or
|
| 73 |
+
// an HTTP audioUrl (fragments hydrated from disk on app load have
|
| 74 |
+
// audioBlob=null and audioUrl=/api/fragments/...). The blob path is
|
| 75 |
+
// preferred when available; otherwise fetch the URL.
|
| 76 |
+
useEffect(() => {
|
| 77 |
+
if (!blob && !audioUrl) return;
|
| 78 |
+
let cancelled = false;
|
| 79 |
+
(async () => {
|
| 80 |
+
try {
|
| 81 |
+
let buf;
|
| 82 |
+
if (blob) {
|
| 83 |
+
buf = await blob.arrayBuffer();
|
| 84 |
+
} else {
|
| 85 |
+
const r = await fetch(audioUrl);
|
| 86 |
+
if (!r.ok) {
|
| 87 |
+
console.warn(`GenerationWaveform fetch failed (${r.status}): ${audioUrl}`);
|
| 88 |
+
return;
|
| 89 |
+
}
|
| 90 |
+
buf = await r.arrayBuffer();
|
| 91 |
+
}
|
| 92 |
+
if (cancelled) return;
|
| 93 |
+
if (!buf || buf.byteLength === 0) {
|
| 94 |
+
console.warn('GenerationWaveform: empty audio source');
|
| 95 |
+
return;
|
| 96 |
+
}
|
| 97 |
+
const Ctx = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
| 98 |
+
const tmpCtx = Ctx
|
| 99 |
+
? new Ctx(1, 44100, 44100)
|
| 100 |
+
: new (window.AudioContext || window.webkitAudioContext)();
|
| 101 |
+
const audio = await tmpCtx.decodeAudioData(buf.slice(0));
|
| 102 |
+
if (cancelled) return;
|
| 103 |
+
const ch0 = audio.getChannelData(0);
|
| 104 |
+
const ch1 = audio.numberOfChannels > 1 ? audio.getChannelData(1) : null;
|
| 105 |
+
const totalSamples = ch0.length;
|
| 106 |
+
const bucketSize = Math.max(1, Math.floor(totalSamples / PEAK_COUNT));
|
| 107 |
+
const out = new Float32Array(PEAK_COUNT * 2);
|
| 108 |
+
for (let i = 0; i < PEAK_COUNT; i++) {
|
| 109 |
+
const s = i * bucketSize;
|
| 110 |
+
const e = Math.min(totalSamples, s + bucketSize);
|
| 111 |
+
let mn = 0, mx = 0;
|
| 112 |
+
for (let j = s; j < e; j++) {
|
| 113 |
+
const v = ch1 ? (ch0[j] + ch1[j]) * 0.5 : ch0[j];
|
| 114 |
+
if (v < mn) mn = v;
|
| 115 |
+
if (v > mx) mx = v;
|
| 116 |
+
}
|
| 117 |
+
out[i * 2] = mn;
|
| 118 |
+
out[i * 2 + 1] = mx;
|
| 119 |
+
}
|
| 120 |
+
if (!cancelled) setPeaks(out);
|
| 121 |
+
} catch (err) {
|
| 122 |
+
console.warn('GenerationWaveform decode failed:', err);
|
| 123 |
+
}
|
| 124 |
+
})();
|
| 125 |
+
return () => { cancelled = true; };
|
| 126 |
+
}, [blob, audioUrl]);
|
| 127 |
+
|
| 128 |
+
// Draw β re-runs on every currentTime tick so the playhead moves.
|
| 129 |
+
const draw = useCallback(() => {
|
| 130 |
+
const canvas = canvasRef.current;
|
| 131 |
+
if (!canvas || !width || !height) return;
|
| 132 |
+
const dpr = window.devicePixelRatio || 1;
|
| 133 |
+
canvas.width = width * dpr;
|
| 134 |
+
canvas.height = height * dpr;
|
| 135 |
+
const ctx = canvas.getContext('2d');
|
| 136 |
+
ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
|
| 137 |
+
ctx.clearRect(0, 0, width, height);
|
| 138 |
+
|
| 139 |
+
// Always draw a faint center line so the row has a visible "this is
|
| 140 |
+
// a waveform area" cue even while decode is in flight or has failed.
|
| 141 |
+
ctx.fillStyle = `${color}33`;
|
| 142 |
+
ctx.fillRect(0, height / 2 - 0.5, width, 1);
|
| 143 |
+
|
| 144 |
+
if (!peaks) return;
|
| 145 |
+
|
| 146 |
+
const mid = height / 2;
|
| 147 |
+
const scale = (height - 2) / 2;
|
| 148 |
+
// Stretch the fixed PEAK_COUNT buckets across the canvas width as bars
|
| 149 |
+
// (with a 1px gap), matching the dataset-page waveform look.
|
| 150 |
+
const n = peaks.length / 2;
|
| 151 |
+
const step = width / n;
|
| 152 |
+
const barW = Math.max(1, step - 1);
|
| 153 |
+
const progressFrac = duration > 0
|
| 154 |
+
? Math.max(0, Math.min(1, currentTime / duration))
|
| 155 |
+
: 0;
|
| 156 |
+
const progressPx = progressFrac * width;
|
| 157 |
+
const splitIdx = Math.floor(progressFrac * n);
|
| 158 |
+
|
| 159 |
+
for (let i = 0; i < n; i++) {
|
| 160 |
+
const mn = peaks[i * 2];
|
| 161 |
+
const mx = peaks[i * 2 + 1];
|
| 162 |
+
const y0 = mid - mx * scale;
|
| 163 |
+
const y1 = mid - mn * scale;
|
| 164 |
+
// Played bars: full color; unplayed: dimmed (35% alpha of accent).
|
| 165 |
+
ctx.fillStyle = i < splitIdx ? color : `${color}59`;
|
| 166 |
+
ctx.fillRect(i * step, y0, barW, Math.max(1, y1 - y0));
|
| 167 |
+
}
|
| 168 |
+
// Thin playhead at the split.
|
| 169 |
+
if (progressPx > 0 && progressPx < width) {
|
| 170 |
+
ctx.fillStyle = color;
|
| 171 |
+
ctx.fillRect(progressPx - 0.5, 0, 1, height);
|
| 172 |
+
}
|
| 173 |
+
}, [width, height, peaks, color, currentTime, duration]);
|
| 174 |
+
|
| 175 |
+
useEffect(() => { draw(); }, [draw]);
|
| 176 |
+
|
| 177 |
+
// Native drag-to-OS as a file. The DownloadURL mime type is a Chromium
|
| 178 |
+
// extension the OS interprets as "this drag is a file the browser can
|
| 179 |
+
// serve from URL X with mime/name Y". Source is whichever URL we have:
|
| 180 |
+
// a blob: URL for in-memory fragments, or the backend /api/fragments/
|
| 181 |
+
// path for disk-hydrated ones. The OS needs an ABSOLUTE URL, so we
|
| 182 |
+
// resolve relative paths against window.location.origin.
|
| 183 |
+
const canDrag = !!(audioUrl || blob);
|
| 184 |
+
const handleDragStart = (e) => {
|
| 185 |
+
if (!canDrag) return;
|
| 186 |
+
const raw = audioUrl || URL.createObjectURL(blob);
|
| 187 |
+
const absolute = (raw.startsWith('http') || raw.startsWith('blob:'))
|
| 188 |
+
? raw
|
| 189 |
+
: `${window.location.origin}${raw.startsWith('/') ? '' : '/'}${raw}`;
|
| 190 |
+
e.dataTransfer.setData('DownloadURL', `audio/wav:${filename}:${absolute}`);
|
| 191 |
+
e.dataTransfer.effectAllowed = 'copy';
|
| 192 |
+
};
|
| 193 |
+
|
| 194 |
+
return (
|
| 195 |
+
<Box
|
| 196 |
+
ref={containerRef}
|
| 197 |
+
draggable={canDrag}
|
| 198 |
+
onDragStart={handleDragStart}
|
| 199 |
+
title={canDrag ? 'Drag to save or drop into a DAW' : undefined}
|
| 200 |
+
sx={{
|
| 201 |
+
// Floor the width so the container is never zero β without
|
| 202 |
+
// this, a tight flex row could collapse it before
|
| 203 |
+
// ResizeObserver fires, leaving the canvas un-sized.
|
| 204 |
+
flex: 1,
|
| 205 |
+
minWidth: 120,
|
| 206 |
+
height,
|
| 207 |
+
cursor: canDrag ? 'grab' : 'default',
|
| 208 |
+
'&:active': { cursor: canDrag ? 'grabbing' : 'default' },
|
| 209 |
+
}}
|
| 210 |
+
>
|
| 211 |
+
<canvas
|
| 212 |
+
ref={canvasRef}
|
| 213 |
+
style={{ display: 'block', width: '100%', height }}
|
| 214 |
+
/>
|
| 215 |
+
</Box>
|
| 216 |
+
);
|
| 217 |
+
}
|
app/frontend/src/components/InfoView.js
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { createContext, useCallback, useContext, useMemo, useState } from 'react';
|
| 2 |
+
import { Box, Typography } from '@mui/material';
|
| 3 |
+
import { Info as InfoIcon } from 'lucide-react';
|
| 4 |
+
|
| 5 |
+
/**
|
| 6 |
+
* Ableton-style "Info View".
|
| 7 |
+
*
|
| 8 |
+
* A toggleable strip pinned to the bottom of the window that shows the help
|
| 9 |
+
* text for whatever control the pointer (or keyboard focus) is over, instead
|
| 10 |
+
* of popping a tooltip on the control itself. The shared <Tooltip> feeds this
|
| 11 |
+
* panel when the view is enabled (see components/Tooltip.js).
|
| 12 |
+
*
|
| 13 |
+
* State design: `enabled` is owned by App (changes rarely, persisted). The
|
| 14 |
+
* *hint* β which changes on every hover β lives inside the provider and is
|
| 15 |
+
* read only by the bar, so updating it never re-renders the app tree (the app
|
| 16 |
+
* is passed as `children`, whose element identity is stable across the
|
| 17 |
+
* provider's internal state changes).
|
| 18 |
+
*/
|
| 19 |
+
export const InfoViewContext = createContext({ enabled: false, setHint: () => {} });
|
| 20 |
+
|
| 21 |
+
export const useInfoView = () => useContext(InfoViewContext);
|
| 22 |
+
|
| 23 |
+
export function InfoViewProvider({ enabled, children }) {
|
| 24 |
+
const [hint, setHint] = useState(null);
|
| 25 |
+
// Stable setter so the context value only changes when `enabled` flips β
|
| 26 |
+
// hover-driven hint updates don't churn every tooltip consumer.
|
| 27 |
+
const update = useCallback((value) => setHint(value ?? null), []);
|
| 28 |
+
const value = useMemo(() => ({ enabled, setHint: update }), [enabled, update]);
|
| 29 |
+
|
| 30 |
+
return (
|
| 31 |
+
<InfoViewContext.Provider value={value}>
|
| 32 |
+
{children}
|
| 33 |
+
{enabled && <InfoViewBar hint={hint} />}
|
| 34 |
+
</InfoViewContext.Provider>
|
| 35 |
+
);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
function InfoViewBar({ hint }) {
|
| 39 |
+
// Only present when there's something to say β no placeholder.
|
| 40 |
+
if (!hint) return null;
|
| 41 |
+
|
| 42 |
+
return (
|
| 43 |
+
// Full-width fixed row that centers the pill at the bottom of the page.
|
| 44 |
+
<Box
|
| 45 |
+
sx={{
|
| 46 |
+
position: 'fixed',
|
| 47 |
+
left: 0,
|
| 48 |
+
right: 0,
|
| 49 |
+
bottom: { xs: 16, md: 24 },
|
| 50 |
+
zIndex: 1340, // under the bottom dock (1350)
|
| 51 |
+
px: 2,
|
| 52 |
+
display: 'flex',
|
| 53 |
+
justifyContent: 'center',
|
| 54 |
+
pointerEvents: 'none', // pure overlay β never intercepts clicks
|
| 55 |
+
}}
|
| 56 |
+
>
|
| 57 |
+
<Box
|
| 58 |
+
role="status"
|
| 59 |
+
aria-live="polite"
|
| 60 |
+
sx={(theme) => ({
|
| 61 |
+
display: 'inline-flex',
|
| 62 |
+
alignItems: 'center',
|
| 63 |
+
gap: 1,
|
| 64 |
+
maxWidth: 'min(680px, 90vw)',
|
| 65 |
+
px: 1.75,
|
| 66 |
+
py: 0.9,
|
| 67 |
+
borderRadius: 999,
|
| 68 |
+
// Blurred translucent pill β just enough backing for the
|
| 69 |
+
// text to stay readable over any content behind it.
|
| 70 |
+
backgroundColor: theme.palette.mode === 'dark'
|
| 71 |
+
? 'rgba(20, 22, 24, 0.55)'
|
| 72 |
+
: 'rgba(248, 243, 234, 0.62)',
|
| 73 |
+
backdropFilter: 'blur(16px) saturate(160%)',
|
| 74 |
+
WebkitBackdropFilter: 'blur(16px) saturate(160%)',
|
| 75 |
+
border: `1px solid ${theme.palette.divider}`,
|
| 76 |
+
boxShadow: theme.palette.mode === 'dark'
|
| 77 |
+
? '0 8px 28px rgba(0,0,0,0.5)'
|
| 78 |
+
: '0 8px 28px rgba(43,31,18,0.16)',
|
| 79 |
+
animation: 'fragmenta-fade-up 240ms cubic-bezier(0.16, 1, 0.3, 1)',
|
| 80 |
+
})}
|
| 81 |
+
>
|
| 82 |
+
<Box component="span" sx={{ flexShrink: 0, display: 'inline-flex', color: 'primary.main' }}>
|
| 83 |
+
<InfoIcon size={15} />
|
| 84 |
+
</Box>
|
| 85 |
+
<Typography variant="body2" sx={{ color: 'text.primary', lineHeight: 1.3 }}>
|
| 86 |
+
{hint}
|
| 87 |
+
</Typography>
|
| 88 |
+
</Box>
|
| 89 |
+
</Box>
|
| 90 |
+
);
|
| 91 |
+
}
|
app/frontend/src/components/LoraStack.js
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { useEffect, useState } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Box,
|
| 4 |
+
Accordion,
|
| 5 |
+
AccordionSummary,
|
| 6 |
+
AccordionDetails,
|
| 7 |
+
Button,
|
| 8 |
+
Typography,
|
| 9 |
+
Stack,
|
| 10 |
+
MenuItem,
|
| 11 |
+
Select,
|
| 12 |
+
Slider,
|
| 13 |
+
IconButton,
|
| 14 |
+
Chip,
|
| 15 |
+
Alert,
|
| 16 |
+
} from '@mui/material';
|
| 17 |
+
import { TIPS } from '../tooltips';
|
| 18 |
+
import Tooltip from './Tooltip';
|
| 19 |
+
import {
|
| 20 |
+
Plus as AddIcon,
|
| 21 |
+
Trash2 as RemoveIcon,
|
| 22 |
+
GripVertical as DragIcon,
|
| 23 |
+
Power as BypassIcon,
|
| 24 |
+
ChevronDown as ChevronDownIcon,
|
| 25 |
+
} from 'lucide-react';
|
| 26 |
+
import api from '../api';
|
| 27 |
+
import { isLoraCompatible } from '../utils/loraMatch';
|
| 28 |
+
|
| 29 |
+
const MAX_SLOTS = 4;
|
| 30 |
+
|
| 31 |
+
/**
|
| 32 |
+
* Multi-LoRA stack for the Generation panel.
|
| 33 |
+
*
|
| 34 |
+
* Props:
|
| 35 |
+
* selectedModel: the currently-selected base model id (e.g. "sa3-medium-base")
|
| 36 |
+
* value: array of { path, strength, bypassed } slots
|
| 37 |
+
* onChange: (newSlots) => void
|
| 38 |
+
*
|
| 39 |
+
* The picker filters available LoRAs by base-model compatibility (a `*-base`
|
| 40 |
+
* LoRA also runs on its distilled sibling β see utils/loraMatch). Slot order
|
| 41 |
+
* is the load order (slot 0 first); drag the handle to reorder. Bypass keeps
|
| 42 |
+
* a slot in the stack but sends strength 0.
|
| 43 |
+
*/
|
| 44 |
+
export default function LoraStack({ selectedModel, value, onChange }) {
|
| 45 |
+
const [available, setAvailable] = useState([]);
|
| 46 |
+
const [loading, setLoading] = useState(false);
|
| 47 |
+
const [error, setError] = useState(null);
|
| 48 |
+
const [dragIndex, setDragIndex] = useState(null);
|
| 49 |
+
|
| 50 |
+
useEffect(() => {
|
| 51 |
+
let cancelled = false;
|
| 52 |
+
setLoading(true);
|
| 53 |
+
api.get('/api/loras')
|
| 54 |
+
.then(r => { if (!cancelled) setAvailable(r.data.loras || []); })
|
| 55 |
+
.catch(e => { if (!cancelled) setError(e.response?.data?.error || e.message); })
|
| 56 |
+
.finally(() => { if (!cancelled) setLoading(false); });
|
| 57 |
+
return () => { cancelled = true; };
|
| 58 |
+
}, []);
|
| 59 |
+
|
| 60 |
+
// LoRAs compatible with the current generation model. A LoRA trained
|
| 61 |
+
// against `*-base` is compatible with both that base and its distilled
|
| 62 |
+
// sibling (same backbone, differ only in CFG state) β loraMatch strips
|
| 63 |
+
// the trailing `-base` before comparing.
|
| 64 |
+
const compatible = available.filter(l =>
|
| 65 |
+
isLoraCompatible(l.base_model, selectedModel)
|
| 66 |
+
);
|
| 67 |
+
|
| 68 |
+
// The single-LoRA case stays one click: when no slots are populated AND
|
| 69 |
+
// there's a compatible LoRA, surface one empty slot so the user sees a
|
| 70 |
+
// "Pick a LoRA" dropdown immediately.
|
| 71 |
+
const slots = (value && value.length > 0)
|
| 72 |
+
? value
|
| 73 |
+
: (compatible.length ? [{ path: '', strength: 1.0, bypassed: false }] : []);
|
| 74 |
+
|
| 75 |
+
const addSlot = () => {
|
| 76 |
+
if (slots.length >= MAX_SLOTS) return;
|
| 77 |
+
onChange([...slots, { path: '', strength: 1.0, bypassed: false }]);
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
const removeSlot = (idx) => onChange(slots.filter((_, i) => i !== idx));
|
| 81 |
+
|
| 82 |
+
const setSlot = (idx, patch) => {
|
| 83 |
+
onChange(slots.map((s, i) => i === idx ? { ...s, ...patch } : s));
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
// --- drag-to-reorder (slot 0 is loaded first) ---------------------------
|
| 87 |
+
const onDrop = (target) => {
|
| 88 |
+
if (dragIndex === null || dragIndex === target) { setDragIndex(null); return; }
|
| 89 |
+
const next = [...slots];
|
| 90 |
+
const [moved] = next.splice(dragIndex, 1);
|
| 91 |
+
next.splice(target, 0, moved);
|
| 92 |
+
setDragIndex(null);
|
| 93 |
+
onChange(next);
|
| 94 |
+
};
|
| 95 |
+
|
| 96 |
+
const hint = (() => {
|
| 97 |
+
if (!selectedModel) return 'Pick a model first.';
|
| 98 |
+
if (!selectedModel.endsWith('-base')) {
|
| 99 |
+
return 'LoRAs need a Base model. Switch to a *-base checkpoint to use LoRAs.';
|
| 100 |
+
}
|
| 101 |
+
if (loading) return 'Loading LoRAsβ¦';
|
| 102 |
+
if (!compatible.length) {
|
| 103 |
+
return `No LoRAs trained against ${selectedModel} yet. Train one in the Training tab.`;
|
| 104 |
+
}
|
| 105 |
+
return null;
|
| 106 |
+
})();
|
| 107 |
+
|
| 108 |
+
return (
|
| 109 |
+
<Accordion
|
| 110 |
+
disableGutters
|
| 111 |
+
defaultExpanded={Boolean(value && value.some((s) => s.path))}
|
| 112 |
+
>
|
| 113 |
+
<AccordionSummary expandIcon={<ChevronDownIcon size={18} />}>
|
| 114 |
+
{/* Hover the title to surface the help in the Info View pill
|
| 115 |
+
(when it's on) β no inline "i", matching the rest of the app. */}
|
| 116 |
+
<Tooltip title={TIPS.lora.stackInfo(MAX_SLOTS)}>
|
| 117 |
+
<Typography variant="subtitle1">LoRA Stack</Typography>
|
| 118 |
+
</Tooltip>
|
| 119 |
+
</AccordionSummary>
|
| 120 |
+
<AccordionDetails>
|
| 121 |
+
{error && <Alert severity="error" sx={{ mb: 1 }}>{error}</Alert>}
|
| 122 |
+
{hint && (
|
| 123 |
+
<Typography variant="caption" color="text.secondary" sx={{ display: 'block', mb: 1 }}>
|
| 124 |
+
{hint}
|
| 125 |
+
</Typography>
|
| 126 |
+
)}
|
| 127 |
+
|
| 128 |
+
{slots.length > 0 && (
|
| 129 |
+
<Box sx={{ border: '1px solid', borderColor: 'divider', borderRadius: 1 }}>
|
| 130 |
+
{slots.map((slot, idx) => {
|
| 131 |
+
const choice = available.find(l => l.path === slot.path);
|
| 132 |
+
const bypassed = !!slot.bypassed;
|
| 133 |
+
return (
|
| 134 |
+
<Box
|
| 135 |
+
key={idx}
|
| 136 |
+
onDragOver={(e) => { if (dragIndex !== null) e.preventDefault(); }}
|
| 137 |
+
onDrop={() => onDrop(idx)}
|
| 138 |
+
sx={{
|
| 139 |
+
p: 1.5,
|
| 140 |
+
borderBottom: '1px solid',
|
| 141 |
+
borderColor: 'divider',
|
| 142 |
+
'&:last-child': { borderBottom: 'none' },
|
| 143 |
+
bgcolor: dragIndex === idx ? 'action.hover' : 'transparent',
|
| 144 |
+
opacity: bypassed ? 0.5 : 1,
|
| 145 |
+
}}
|
| 146 |
+
>
|
| 147 |
+
<Stack direction="row" alignItems="center" spacing={1}>
|
| 148 |
+
<Tooltip title={TIPS.lora.dragReorder}>
|
| 149 |
+
<Box
|
| 150 |
+
draggable={slots.length > 1}
|
| 151 |
+
onDragStart={() => setDragIndex(idx)}
|
| 152 |
+
onDragEnd={() => setDragIndex(null)}
|
| 153 |
+
sx={{
|
| 154 |
+
display: 'flex',
|
| 155 |
+
cursor: slots.length > 1 ? 'grab' : 'default',
|
| 156 |
+
color: 'text.disabled',
|
| 157 |
+
}}
|
| 158 |
+
>
|
| 159 |
+
<DragIcon size={16} />
|
| 160 |
+
</Box>
|
| 161 |
+
</Tooltip>
|
| 162 |
+
<Typography variant="caption" color="text.disabled" sx={{ width: 14 }}>
|
| 163 |
+
{idx}
|
| 164 |
+
</Typography>
|
| 165 |
+
<Select
|
| 166 |
+
size="small"
|
| 167 |
+
value={slot.path}
|
| 168 |
+
displayEmpty
|
| 169 |
+
onChange={(e) => setSlot(idx, { path: String(e.target.value) })}
|
| 170 |
+
sx={{ flex: 1, minWidth: 0 }}
|
| 171 |
+
>
|
| 172 |
+
<MenuItem value="" disabled>
|
| 173 |
+
<em>Pick a LoRA</em>
|
| 174 |
+
</MenuItem>
|
| 175 |
+
{compatible.map(l => (
|
| 176 |
+
<MenuItem key={l.id} value={l.path}>
|
| 177 |
+
<Box>
|
| 178 |
+
<Typography variant="body2">
|
| 179 |
+
{l.name} Β· {l.checkpoint}
|
| 180 |
+
</Typography>
|
| 181 |
+
<Stack direction="row" spacing={0.5} sx={{ mt: 0.25 }}>
|
| 182 |
+
<Chip size="small" label={l.adapter_type || 'lora'} sx={{ height: 16, fontSize: 9 }} />
|
| 183 |
+
{l.rank && <Chip size="small" label={`r=${l.rank}`} sx={{ height: 16, fontSize: 9 }} />}
|
| 184 |
+
</Stack>
|
| 185 |
+
</Box>
|
| 186 |
+
</MenuItem>
|
| 187 |
+
))}
|
| 188 |
+
</Select>
|
| 189 |
+
<Tooltip title={TIPS.lora.bypass(bypassed)}>
|
| 190 |
+
<IconButton
|
| 191 |
+
size="small"
|
| 192 |
+
color={bypassed ? 'default' : 'primary'}
|
| 193 |
+
onClick={() => setSlot(idx, { bypassed: !bypassed })}
|
| 194 |
+
>
|
| 195 |
+
<BypassIcon size={14} />
|
| 196 |
+
</IconButton>
|
| 197 |
+
</Tooltip>
|
| 198 |
+
<IconButton size="small" onClick={() => removeSlot(idx)} aria-label="Remove slot">
|
| 199 |
+
<RemoveIcon size={14} />
|
| 200 |
+
</IconButton>
|
| 201 |
+
</Stack>
|
| 202 |
+
|
| 203 |
+
<Stack direction="row" alignItems="center" spacing={1.5} sx={{ mt: 1, mb: 2 }}>
|
| 204 |
+
<Typography variant="caption" color="text.secondary" sx={{ width: 60 }}>
|
| 205 |
+
Strength
|
| 206 |
+
</Typography>
|
| 207 |
+
<Slider
|
| 208 |
+
size="small"
|
| 209 |
+
value={slot.strength}
|
| 210 |
+
disabled={bypassed}
|
| 211 |
+
onChange={(e, v) => setSlot(idx, { strength: v })}
|
| 212 |
+
min={-2}
|
| 213 |
+
max={2}
|
| 214 |
+
step={0.05}
|
| 215 |
+
valueLabelDisplay="auto"
|
| 216 |
+
marks={[
|
| 217 |
+
{ value: 0, label: '0' },
|
| 218 |
+
{ value: 1, label: '1' },
|
| 219 |
+
]}
|
| 220 |
+
sx={{ flex: 1 }}
|
| 221 |
+
/>
|
| 222 |
+
<Typography variant="body2" sx={{ width: 40, textAlign: 'right' }}>
|
| 223 |
+
{bypassed ? 'β' : slot.strength.toFixed(2)}
|
| 224 |
+
</Typography>
|
| 225 |
+
</Stack>
|
| 226 |
+
|
| 227 |
+
{choice && choice.base_model && (
|
| 228 |
+
<Typography variant="caption" color="text.secondary" sx={{ display: 'block', mt: 0.25 }}>
|
| 229 |
+
Trained on {choice.base_model}
|
| 230 |
+
</Typography>
|
| 231 |
+
)}
|
| 232 |
+
</Box>
|
| 233 |
+
);
|
| 234 |
+
})}
|
| 235 |
+
</Box>
|
| 236 |
+
)}
|
| 237 |
+
|
| 238 |
+
<Stack direction="row" sx={{ mt: 1 }}>
|
| 239 |
+
<Button
|
| 240 |
+
size="small"
|
| 241 |
+
variant="outlined"
|
| 242 |
+
startIcon={<AddIcon size={14} />}
|
| 243 |
+
disabled={slots.length >= MAX_SLOTS || !compatible.length}
|
| 244 |
+
onClick={addSlot}
|
| 245 |
+
>
|
| 246 |
+
Add LoRA
|
| 247 |
+
</Button>
|
| 248 |
+
</Stack>
|
| 249 |
+
</AccordionDetails>
|
| 250 |
+
</Accordion>
|
| 251 |
+
);
|
| 252 |
+
}
|
app/frontend/src/components/LossChart.js
CHANGED
|
@@ -1,19 +1,35 @@
|
|
| 1 |
import React, { useState } from 'react';
|
| 2 |
import { lossChartStyles } from '../theme';
|
| 3 |
|
| 4 |
-
//
|
| 5 |
-
//
|
| 6 |
-
//
|
| 7 |
-
//
|
| 8 |
-
//
|
| 9 |
-
//
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
function smoothEMA(values, alpha
|
| 13 |
if (values.length === 0) return [];
|
| 14 |
-
const
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
}
|
| 18 |
return out;
|
| 19 |
}
|
|
|
|
| 1 |
import React, { useState } from 'react';
|
| 2 |
import { lossChartStyles } from '../theme';
|
| 3 |
|
| 4 |
+
// Bias-corrected exponential moving average β same math as the EMA used in
|
| 5 |
+
// TensorBoard's loss curves and Adam's bias-corrected moments. Standard EMA
|
| 6 |
+
// (out[0] = values[0]) makes the smoothed line lag the data for the first
|
| 7 |
+
// 1/alpha steps; diffusion loss spikes high on step 0 (random init), so a
|
| 8 |
+
// naive EMA spends ~17 steps "catching down". The 1/(1-(1-Ξ±)^(i+1)) factor
|
| 9 |
+
// cancels that startup bias: by construction out[0] equals values[0], and
|
| 10 |
+
// out[i] converges to the plain EMA at steady state.
|
| 11 |
+
//
|
| 12 |
+
// alpha is *adaptive* to the run length: more data β more smoothing is OK
|
| 13 |
+
// because the underlying trend has more support; short runs need a tighter
|
| 14 |
+
// window so the smoothed line still resembles the data.
|
| 15 |
+
function pickAlpha(n) {
|
| 16 |
+
if (n < 50) return 0.25;
|
| 17 |
+
if (n < 200) return 0.15;
|
| 18 |
+
if (n < 1000) return 0.08;
|
| 19 |
+
return 0.05;
|
| 20 |
+
}
|
| 21 |
|
| 22 |
+
function smoothEMA(values, alpha) {
|
| 23 |
if (values.length === 0) return [];
|
| 24 |
+
const a = alpha ?? pickAlpha(values.length);
|
| 25 |
+
const w = 1 - a;
|
| 26 |
+
const out = [];
|
| 27 |
+
let ema = 0;
|
| 28 |
+
for (let i = 0; i < values.length; i++) {
|
| 29 |
+
ema = w * ema + a * values[i];
|
| 30 |
+
const correction = 1 - Math.pow(w, i + 1);
|
| 31 |
+
// correction β a at i=0 (so out[0] = values[0]) and β 1 at large i.
|
| 32 |
+
out.push(ema / Math.max(correction, 1e-9));
|
| 33 |
}
|
| 34 |
return out;
|
| 35 |
}
|
app/frontend/src/components/MidiConfigMenu.js
CHANGED
|
@@ -8,7 +8,6 @@ import {
|
|
| 8 |
MenuItem,
|
| 9 |
Button,
|
| 10 |
IconButton,
|
| 11 |
-
Tooltip,
|
| 12 |
Divider,
|
| 13 |
ToggleButton,
|
| 14 |
ToggleButtonGroup,
|
|
@@ -16,7 +15,7 @@ import {
|
|
| 16 |
} from '@mui/material';
|
| 17 |
import { Trash2 as DeleteIcon, X as CloseIcon } from 'lucide-react';
|
| 18 |
import { useMidi, formatMidi } from './MidiContext';
|
| 19 |
-
import { perfTokens } from '../theme';
|
| 20 |
|
| 21 |
const CHANNEL_OPTIONS = [
|
| 22 |
{ value: 0, label: 'Any' },
|
|
@@ -46,39 +45,60 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
|
|
| 46 |
anchorEl={anchorEl}
|
| 47 |
open={open}
|
| 48 |
onClose={onClose}
|
| 49 |
-
anchorOrigin={{ vertical: 'bottom', horizontal: '
|
| 50 |
-
transformOrigin={{ vertical: 'top', horizontal: '
|
| 51 |
slotProps={{
|
| 52 |
paper: {
|
| 53 |
sx: {
|
| 54 |
-
width:
|
| 55 |
maxHeight: '70vh',
|
| 56 |
-
p:
|
| 57 |
borderRadius: 2,
|
| 58 |
border: '1px solid',
|
| 59 |
borderColor: 'divider',
|
|
|
|
| 60 |
},
|
| 61 |
},
|
| 62 |
}}
|
| 63 |
>
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
MIDI Settings
|
| 67 |
</Typography>
|
| 68 |
-
<IconButton
|
| 69 |
-
<CloseIcon size={
|
| 70 |
</IconButton>
|
| 71 |
</Box>
|
| 72 |
|
|
|
|
|
|
|
| 73 |
{!supported && (
|
| 74 |
-
<
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
)}
|
| 78 |
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
<Box>
|
| 81 |
-
<Typography
|
| 82 |
Input device
|
| 83 |
</Typography>
|
| 84 |
<FormControl size="small" fullWidth>
|
|
@@ -92,19 +112,26 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
|
|
| 92 |
const found = inputs.find(i => i.id === value);
|
| 93 |
return found ? found.name : 'Disconnected';
|
| 94 |
}}
|
|
|
|
| 95 |
>
|
| 96 |
-
<MenuItem value="">
|
| 97 |
<em>None</em>
|
| 98 |
</MenuItem>
|
| 99 |
{inputs.map((input) => (
|
| 100 |
-
<MenuItem key={input.id} value={input.id}>
|
| 101 |
{input.name}
|
| 102 |
</MenuItem>
|
| 103 |
))}
|
| 104 |
</Select>
|
| 105 |
</FormControl>
|
| 106 |
{config.deviceName && !inputs.some(i => i.name === config.deviceName) && (
|
| 107 |
-
<Typography
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
Saved device "{config.deviceName}" not connected
|
| 109 |
</Typography>
|
| 110 |
)}
|
|
@@ -112,7 +139,7 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
|
|
| 112 |
|
| 113 |
<Box sx={{ display: 'flex', gap: 1 }}>
|
| 114 |
<Box sx={{ flex: 1 }}>
|
| 115 |
-
<Typography
|
| 116 |
Channel filter
|
| 117 |
</Typography>
|
| 118 |
<FormControl size="small" fullWidth>
|
|
@@ -120,16 +147,17 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
|
|
| 120 |
value={config.channelFilter}
|
| 121 |
onChange={(e) => setChannelFilter(Number(e.target.value))}
|
| 122 |
disabled={!supported}
|
|
|
|
| 123 |
>
|
| 124 |
{CHANNEL_OPTIONS.map(opt => (
|
| 125 |
-
<MenuItem key={opt.value} value={opt.value}>{opt.label}</MenuItem>
|
| 126 |
))}
|
| 127 |
</Select>
|
| 128 |
</FormControl>
|
| 129 |
</Box>
|
| 130 |
|
| 131 |
<Box sx={{ flex: 1 }}>
|
| 132 |
-
<Typography
|
| 133 |
Takeover
|
| 134 |
</Typography>
|
| 135 |
<ToggleButtonGroup
|
|
@@ -138,25 +166,42 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
|
|
| 138 |
exclusive
|
| 139 |
onChange={(_, v) => { if (v) setTakeover(v); }}
|
| 140 |
fullWidth
|
| 141 |
-
sx={{ height:
|
| 142 |
>
|
| 143 |
-
<ToggleButton value="jump" sx={{ fontSize: perfTokens.fontSize.
|
| 144 |
-
<ToggleButton value="pickup" sx={{ fontSize: perfTokens.fontSize.
|
| 145 |
</ToggleButtonGroup>
|
| 146 |
</Box>
|
| 147 |
</Box>
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
Mappings ({config.mappings.length})
|
| 154 |
</Typography>
|
| 155 |
<Button
|
| 156 |
size="small"
|
| 157 |
onClick={clearAll}
|
| 158 |
disabled={config.mappings.length === 0}
|
| 159 |
-
sx={{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
>
|
| 161 |
Clear all
|
| 162 |
</Button>
|
|
@@ -167,15 +212,21 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
|
|
| 167 |
border: '1px solid',
|
| 168 |
borderColor: 'divider',
|
| 169 |
borderRadius: 1,
|
| 170 |
-
maxHeight:
|
| 171 |
overflowY: 'auto',
|
| 172 |
bgcolor: 'background.default',
|
| 173 |
}}
|
| 174 |
>
|
| 175 |
{sortedMappings.length === 0 ? (
|
| 176 |
-
<Box sx={{
|
| 177 |
-
<Typography
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
</Typography>
|
| 180 |
</Box>
|
| 181 |
) : (
|
|
@@ -191,33 +242,54 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
|
|
| 191 |
borderBottom: '1px solid',
|
| 192 |
borderColor: 'divider',
|
| 193 |
'&:last-child': { borderBottom: 'none' },
|
|
|
|
|
|
|
| 194 |
}}
|
| 195 |
>
|
| 196 |
<Box sx={{ flex: 1, minWidth: 0 }}>
|
| 197 |
-
<Typography
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
{m.label}
|
| 199 |
</Typography>
|
| 200 |
-
<Typography
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
{formatMidi(m.midi)}
|
| 202 |
</Typography>
|
| 203 |
</Box>
|
| 204 |
-
<
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
</Tooltip>
|
| 213 |
</Box>
|
| 214 |
))
|
| 215 |
)}
|
| 216 |
</Box>
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
</Typography>
|
| 222 |
</Box>
|
| 223 |
</Popover>
|
|
|
|
| 8 |
MenuItem,
|
| 9 |
Button,
|
| 10 |
IconButton,
|
|
|
|
| 11 |
Divider,
|
| 12 |
ToggleButton,
|
| 13 |
ToggleButtonGroup,
|
|
|
|
| 15 |
} from '@mui/material';
|
| 16 |
import { Trash2 as DeleteIcon, X as CloseIcon } from 'lucide-react';
|
| 17 |
import { useMidi, formatMidi } from './MidiContext';
|
| 18 |
+
import { perfTokens, performancePanelStyles as panelStyles } from '../theme';
|
| 19 |
|
| 20 |
const CHANNEL_OPTIONS = [
|
| 21 |
{ value: 0, label: 'Any' },
|
|
|
|
| 45 |
anchorEl={anchorEl}
|
| 46 |
open={open}
|
| 47 |
onClose={onClose}
|
| 48 |
+
anchorOrigin={{ vertical: 'bottom', horizontal: 'left' }}
|
| 49 |
+
transformOrigin={{ vertical: 'top', horizontal: 'left' }}
|
| 50 |
slotProps={{
|
| 51 |
paper: {
|
| 52 |
sx: {
|
| 53 |
+
width: 360,
|
| 54 |
maxHeight: '70vh',
|
| 55 |
+
p: 0,
|
| 56 |
borderRadius: 2,
|
| 57 |
border: '1px solid',
|
| 58 |
borderColor: 'divider',
|
| 59 |
+
overflow: 'hidden',
|
| 60 |
},
|
| 61 |
},
|
| 62 |
}}
|
| 63 |
>
|
| 64 |
+
{/* Title bar β same pattern as Presets / Audio menus. */}
|
| 65 |
+
<Box sx={{
|
| 66 |
+
display: 'flex',
|
| 67 |
+
alignItems: 'center',
|
| 68 |
+
justifyContent: 'space-between',
|
| 69 |
+
px: 1.5,
|
| 70 |
+
pt: 1.25,
|
| 71 |
+
pb: 1,
|
| 72 |
+
}}>
|
| 73 |
+
<Typography sx={{ ...perfTokens.caps, color: 'text.secondary' }}>
|
| 74 |
MIDI Settings
|
| 75 |
</Typography>
|
| 76 |
+
<IconButton onClick={onClose} sx={panelStyles.compactIconBtn('md')}>
|
| 77 |
+
<CloseIcon size={perfTokens.icon.sm} />
|
| 78 |
</IconButton>
|
| 79 |
</Box>
|
| 80 |
|
| 81 |
+
<Divider />
|
| 82 |
+
|
| 83 |
{!supported && (
|
| 84 |
+
<Box sx={{ px: 1.5, pt: 1.25 }}>
|
| 85 |
+
<Alert severity="warning" sx={{ py: 0.5 }}>
|
| 86 |
+
{permissionError || 'Web MIDI is not available in this browser. Try Chrome / Edge / Electron.'}
|
| 87 |
+
</Alert>
|
| 88 |
+
</Box>
|
| 89 |
)}
|
| 90 |
|
| 91 |
+
{/* SETTINGS β input device + channel filter + takeover. */}
|
| 92 |
+
<Box sx={{
|
| 93 |
+
px: 1.5,
|
| 94 |
+
pt: 1.25,
|
| 95 |
+
pb: 1.25,
|
| 96 |
+
display: 'flex',
|
| 97 |
+
flexDirection: 'column',
|
| 98 |
+
gap: 1.25,
|
| 99 |
+
}}>
|
| 100 |
<Box>
|
| 101 |
+
<Typography sx={{ ...perfTokens.labelMuted, display: 'block', mb: 0.5 }}>
|
| 102 |
Input device
|
| 103 |
</Typography>
|
| 104 |
<FormControl size="small" fullWidth>
|
|
|
|
| 112 |
const found = inputs.find(i => i.id === value);
|
| 113 |
return found ? found.name : 'Disconnected';
|
| 114 |
}}
|
| 115 |
+
sx={{ fontSize: perfTokens.fontSize.sm }}
|
| 116 |
>
|
| 117 |
+
<MenuItem value="" sx={{ fontSize: perfTokens.fontSize.sm }}>
|
| 118 |
<em>None</em>
|
| 119 |
</MenuItem>
|
| 120 |
{inputs.map((input) => (
|
| 121 |
+
<MenuItem key={input.id} value={input.id} sx={{ fontSize: perfTokens.fontSize.sm }}>
|
| 122 |
{input.name}
|
| 123 |
</MenuItem>
|
| 124 |
))}
|
| 125 |
</Select>
|
| 126 |
</FormControl>
|
| 127 |
{config.deviceName && !inputs.some(i => i.name === config.deviceName) && (
|
| 128 |
+
<Typography sx={{
|
| 129 |
+
fontSize: perfTokens.fontSize.xs,
|
| 130 |
+
color: 'warning.main',
|
| 131 |
+
fontStyle: 'italic',
|
| 132 |
+
display: 'block',
|
| 133 |
+
mt: 0.5,
|
| 134 |
+
}}>
|
| 135 |
Saved device "{config.deviceName}" not connected
|
| 136 |
</Typography>
|
| 137 |
)}
|
|
|
|
| 139 |
|
| 140 |
<Box sx={{ display: 'flex', gap: 1 }}>
|
| 141 |
<Box sx={{ flex: 1 }}>
|
| 142 |
+
<Typography sx={{ ...perfTokens.labelMuted, display: 'block', mb: 0.5 }}>
|
| 143 |
Channel filter
|
| 144 |
</Typography>
|
| 145 |
<FormControl size="small" fullWidth>
|
|
|
|
| 147 |
value={config.channelFilter}
|
| 148 |
onChange={(e) => setChannelFilter(Number(e.target.value))}
|
| 149 |
disabled={!supported}
|
| 150 |
+
sx={{ fontSize: perfTokens.fontSize.sm }}
|
| 151 |
>
|
| 152 |
{CHANNEL_OPTIONS.map(opt => (
|
| 153 |
+
<MenuItem key={opt.value} value={opt.value} sx={{ fontSize: perfTokens.fontSize.sm }}>{opt.label}</MenuItem>
|
| 154 |
))}
|
| 155 |
</Select>
|
| 156 |
</FormControl>
|
| 157 |
</Box>
|
| 158 |
|
| 159 |
<Box sx={{ flex: 1 }}>
|
| 160 |
+
<Typography sx={{ ...perfTokens.labelMuted, display: 'block', mb: 0.5 }}>
|
| 161 |
Takeover
|
| 162 |
</Typography>
|
| 163 |
<ToggleButtonGroup
|
|
|
|
| 166 |
exclusive
|
| 167 |
onChange={(_, v) => { if (v) setTakeover(v); }}
|
| 168 |
fullWidth
|
| 169 |
+
sx={{ height: perfTokens.height.compact }}
|
| 170 |
>
|
| 171 |
+
<ToggleButton value="jump" sx={{ fontSize: perfTokens.fontSize.sm, textTransform: 'none' }}>Jump</ToggleButton>
|
| 172 |
+
<ToggleButton value="pickup" sx={{ fontSize: perfTokens.fontSize.sm, textTransform: 'none' }}>Pickup</ToggleButton>
|
| 173 |
</ToggleButtonGroup>
|
| 174 |
</Box>
|
| 175 |
</Box>
|
| 176 |
+
</Box>
|
| 177 |
|
| 178 |
+
<Divider />
|
| 179 |
|
| 180 |
+
{/* MAPPINGS β header row + bordered scrollable list. */}
|
| 181 |
+
<Box sx={{ px: 1.5, pt: 1.25, pb: 1.25 }}>
|
| 182 |
+
<Box sx={{
|
| 183 |
+
display: 'flex',
|
| 184 |
+
alignItems: 'center',
|
| 185 |
+
justifyContent: 'space-between',
|
| 186 |
+
mb: 0.75,
|
| 187 |
+
}}>
|
| 188 |
+
<Typography sx={{ ...perfTokens.labelMuted, display: 'block' }}>
|
| 189 |
Mappings ({config.mappings.length})
|
| 190 |
</Typography>
|
| 191 |
<Button
|
| 192 |
size="small"
|
| 193 |
onClick={clearAll}
|
| 194 |
disabled={config.mappings.length === 0}
|
| 195 |
+
sx={{
|
| 196 |
+
fontSize: perfTokens.fontSize.xs,
|
| 197 |
+
color: 'error.main',
|
| 198 |
+
textTransform: 'none',
|
| 199 |
+
py: 0,
|
| 200 |
+
px: 0.75,
|
| 201 |
+
minWidth: 0,
|
| 202 |
+
'&:hover': { bgcolor: 'action.hover' },
|
| 203 |
+
'&.Mui-disabled': { color: 'text.disabled' },
|
| 204 |
+
}}
|
| 205 |
>
|
| 206 |
Clear all
|
| 207 |
</Button>
|
|
|
|
| 212 |
border: '1px solid',
|
| 213 |
borderColor: 'divider',
|
| 214 |
borderRadius: 1,
|
| 215 |
+
maxHeight: 240,
|
| 216 |
overflowY: 'auto',
|
| 217 |
bgcolor: 'background.default',
|
| 218 |
}}
|
| 219 |
>
|
| 220 |
{sortedMappings.length === 0 ? (
|
| 221 |
+
<Box sx={{ px: 1.5, py: 1.5, textAlign: 'center' }}>
|
| 222 |
+
<Typography sx={{
|
| 223 |
+
fontSize: perfTokens.fontSize.xs,
|
| 224 |
+
color: 'text.disabled',
|
| 225 |
+
fontStyle: 'italic',
|
| 226 |
+
lineHeight: 1.4,
|
| 227 |
+
}}>
|
| 228 |
+
No mappings yet. Enable MIDI mode, click a control,
|
| 229 |
+
then move a hardware knob, fader, or button.
|
| 230 |
</Typography>
|
| 231 |
</Box>
|
| 232 |
) : (
|
|
|
|
| 242 |
borderBottom: '1px solid',
|
| 243 |
borderColor: 'divider',
|
| 244 |
'&:last-child': { borderBottom: 'none' },
|
| 245 |
+
'&:hover': { bgcolor: 'action.hover' },
|
| 246 |
+
transition: 'background-color 120ms',
|
| 247 |
}}
|
| 248 |
>
|
| 249 |
<Box sx={{ flex: 1, minWidth: 0 }}>
|
| 250 |
+
<Typography sx={{
|
| 251 |
+
fontSize: perfTokens.fontSize.sm,
|
| 252 |
+
fontWeight: 500,
|
| 253 |
+
overflow: 'hidden',
|
| 254 |
+
textOverflow: 'ellipsis',
|
| 255 |
+
whiteSpace: 'nowrap',
|
| 256 |
+
}}>
|
| 257 |
{m.label}
|
| 258 |
</Typography>
|
| 259 |
+
<Typography sx={{
|
| 260 |
+
color: 'text.secondary',
|
| 261 |
+
fontSize: perfTokens.fontSize.xs,
|
| 262 |
+
fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Consolas, monospace',
|
| 263 |
+
}}>
|
| 264 |
{formatMidi(m.midi)}
|
| 265 |
</Typography>
|
| 266 |
</Box>
|
| 267 |
+
<IconButton
|
| 268 |
+
size="small"
|
| 269 |
+
onClick={() => clearMapping(m.controlId)}
|
| 270 |
+
sx={panelStyles.compactIconBtn('sm', 'danger')}
|
| 271 |
+
aria-label="Remove mapping"
|
| 272 |
+
>
|
| 273 |
+
<DeleteIcon size={perfTokens.icon.sm} />
|
| 274 |
+
</IconButton>
|
|
|
|
| 275 |
</Box>
|
| 276 |
))
|
| 277 |
)}
|
| 278 |
</Box>
|
| 279 |
+
</Box>
|
| 280 |
+
|
| 281 |
+
<Divider />
|
| 282 |
|
| 283 |
+
{/* Footer help text. */}
|
| 284 |
+
<Box sx={{ px: 1.5, pt: 1, pb: 1.25 }}>
|
| 285 |
+
<Typography sx={{
|
| 286 |
+
color: 'text.disabled',
|
| 287 |
+
fontSize: perfTokens.fontSize.xs,
|
| 288 |
+
fontStyle: 'italic',
|
| 289 |
+
lineHeight: 1.4,
|
| 290 |
+
}}>
|
| 291 |
+
Pickup ignores the hardware until its position matches the on-screen
|
| 292 |
+
value. Right-click a control while in MIDI mode to clear its mapping.
|
| 293 |
</Typography>
|
| 294 |
</Box>
|
| 295 |
</Popover>
|
app/frontend/src/components/MidiContext.js
CHANGED
|
@@ -8,6 +8,7 @@ import React, {
|
|
| 8 |
useState,
|
| 9 |
} from 'react';
|
| 10 |
import { Box } from '@mui/material';
|
|
|
|
| 11 |
|
| 12 |
const STORAGE_KEY = 'fragmenta.midi.config.v1';
|
| 13 |
|
|
@@ -73,7 +74,6 @@ export function MidiProvider({ children }) {
|
|
| 73 |
const [learnMode, setLearnMode] = useState(false);
|
| 74 |
const [learnTarget, setLearnTarget] = useState(null);
|
| 75 |
|
| 76 |
-
const accessRef = useRef(null);
|
| 77 |
const subscribersRef = useRef(new Map());
|
| 78 |
const pickupArmedRef = useRef(new Map());
|
| 79 |
const configRef = useRef(config);
|
|
@@ -86,39 +86,23 @@ export function MidiProvider({ children }) {
|
|
| 86 |
|
| 87 |
useEffect(() => { learnTargetRef.current = learnTarget; }, [learnTarget]);
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
})
|
| 100 |
-
setInputs(list);
|
| 101 |
-
}, []);
|
| 102 |
-
|
| 103 |
-
useEffect(() => {
|
| 104 |
-
if (typeof navigator === 'undefined' || !navigator.requestMIDIAccess) {
|
| 105 |
setSupported(false);
|
| 106 |
-
|
| 107 |
}
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
if (cancelled) return;
|
| 112 |
-
accessRef.current = access;
|
| 113 |
-
refreshInputs();
|
| 114 |
-
access.onstatechange = refreshInputs;
|
| 115 |
-
})
|
| 116 |
-
.catch((err) => {
|
| 117 |
-
setPermissionError(err?.message || 'MIDI permission denied');
|
| 118 |
-
setSupported(false);
|
| 119 |
-
});
|
| 120 |
-
return () => { cancelled = true; };
|
| 121 |
-
}, [refreshInputs]);
|
| 122 |
|
| 123 |
useEffect(() => {
|
| 124 |
if (!inputs.length || !config.deviceName) return;
|
|
@@ -192,24 +176,30 @@ export function MidiProvider({ children }) {
|
|
| 192 |
}
|
| 193 |
}, [captureLearn]);
|
| 194 |
|
|
|
|
|
|
|
|
|
|
| 195 |
useEffect(() => {
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
input.onmidimessage = null;
|
| 205 |
-
}
|
| 206 |
-
});
|
| 207 |
-
|
| 208 |
-
pickupArmedRef.current = new Map();
|
| 209 |
-
return () => {
|
| 210 |
-
bound.forEach((i) => { i.onmidimessage = null; });
|
| 211 |
};
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
function applyContinuous(sub, mapping, midiValue, takeover) {
|
| 215 |
const norm = midiValue / 127;
|
|
|
|
| 8 |
useState,
|
| 9 |
} from 'react';
|
| 10 |
import { Box } from '@mui/material';
|
| 11 |
+
import api from '../api';
|
| 12 |
|
| 13 |
const STORAGE_KEY = 'fragmenta.midi.config.v1';
|
| 14 |
|
|
|
|
| 74 |
const [learnMode, setLearnMode] = useState(false);
|
| 75 |
const [learnTarget, setLearnTarget] = useState(null);
|
| 76 |
|
|
|
|
| 77 |
const subscribersRef = useRef(new Map());
|
| 78 |
const pickupArmedRef = useRef(new Map());
|
| 79 |
const configRef = useRef(config);
|
|
|
|
| 86 |
|
| 87 |
useEffect(() => { learnTargetRef.current = learnTarget; }, [learnTarget]);
|
| 88 |
|
| 89 |
+
// Device list comes from the native backend (python-rtmidi) instead of
|
| 90 |
+
// Web MIDI, so it works in every web engine.
|
| 91 |
+
const refreshInputs = useCallback(async () => {
|
| 92 |
+
try {
|
| 93 |
+
const { data } = await api.get('/api/midi/devices');
|
| 94 |
+
setSupported(!!data.available);
|
| 95 |
+
setInputs(Array.isArray(data.inputs) ? data.inputs : []);
|
| 96 |
+
setPermissionError(data.available
|
| 97 |
+
? null
|
| 98 |
+
: 'Native MIDI is unavailable (python-rtmidi not installed).');
|
| 99 |
+
} catch (err) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
setSupported(false);
|
| 101 |
+
setPermissionError(err?.message || 'Could not reach the MIDI backend.');
|
| 102 |
}
|
| 103 |
+
}, []);
|
| 104 |
+
|
| 105 |
+
useEffect(() => { refreshInputs(); }, [refreshInputs]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
useEffect(() => {
|
| 108 |
if (!inputs.length || !config.deviceName) return;
|
|
|
|
| 176 |
}
|
| 177 |
}, [captureLearn]);
|
| 178 |
|
| 179 |
+
// Stream incoming MIDI from the backend over SSE. Each event is the same
|
| 180 |
+
// {data:[status,d1,d2]} shape Web MIDI gave us, so dispatchMessage is
|
| 181 |
+
// unchanged. EventSource auto-reconnects on drop.
|
| 182 |
useEffect(() => {
|
| 183 |
+
if (typeof EventSource === 'undefined') {
|
| 184 |
+
setSupported(false);
|
| 185 |
+
return undefined;
|
| 186 |
+
}
|
| 187 |
+
const es = new EventSource('/api/midi/stream');
|
| 188 |
+
es.onmessage = (e) => {
|
| 189 |
+
try { dispatchMessage(JSON.parse(e.data)); }
|
| 190 |
+
catch { /* malformed line β ignore */ }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
};
|
| 192 |
+
pickupArmedRef.current = new Map();
|
| 193 |
+
return () => es.close();
|
| 194 |
+
}, [dispatchMessage]);
|
| 195 |
+
|
| 196 |
+
// Tell the backend which port to open. The stream only carries the open
|
| 197 |
+
// port's events, so device selection happens server-side.
|
| 198 |
+
useEffect(() => {
|
| 199 |
+
if (!supported) return;
|
| 200 |
+
api.post('/api/midi/select', { port_id: config.deviceId || null })
|
| 201 |
+
.catch(() => { /* non-fatal */ });
|
| 202 |
+
}, [config.deviceId, supported]);
|
| 203 |
|
| 204 |
function applyContinuous(sub, mapping, midiValue, takeover) {
|
| 205 |
const norm = midiValue / 127;
|
app/frontend/src/components/PerformanceChannel.js
CHANGED
|
@@ -5,27 +5,37 @@ import {
|
|
| 5 |
TextField,
|
| 6 |
IconButton,
|
| 7 |
Slider,
|
| 8 |
-
CircularProgress,
|
| 9 |
-
Tooltip,
|
| 10 |
Select,
|
| 11 |
MenuItem,
|
| 12 |
ButtonBase,
|
| 13 |
} from '@mui/material';
|
|
|
|
|
|
|
| 14 |
import {
|
| 15 |
Play as PlayIcon,
|
| 16 |
Square as StopIcon,
|
| 17 |
-
|
| 18 |
-
Sparkles as GenerateIcon,
|
| 19 |
Volume2 as VolumeIcon,
|
| 20 |
VolumeX as MuteIcon,
|
| 21 |
-
|
| 22 |
-
Check as CommitIcon,
|
| 23 |
} from 'lucide-react';
|
| 24 |
-
import { performanceChannelStyles as styles, perfTokens } from '../theme';
|
| 25 |
import { MidiMappable } from './MidiContext';
|
| 26 |
import { playBlob as playCueBlob, stopCue, isCueSupported } from '../utils/cueAudio';
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
const CHANNEL_COLORS = [
|
|
|
|
|
|
|
|
|
|
| 29 |
'#35C2D4', '#9F8AE6', '#53C18A', '#E3A34B',
|
| 30 |
'#E36C61', '#F08AD2', '#5BA0F0', '#A8D86B',
|
| 31 |
];
|
|
@@ -40,19 +50,31 @@ const gainDbToLinear = (db) => (db <= GAIN_DB_MIN ? 0 : Math.pow(10, db / 20));
|
|
| 40 |
|
| 41 |
const KNOB_DEFS = [
|
| 42 |
{ key: 'gain', label: 'GAIN', min: GAIN_DB_MIN, max: GAIN_DB_MAX, step: 0.5, default: GAIN_DB_DEFAULT },
|
| 43 |
-
//
|
| 44 |
-
//
|
| 45 |
-
// the
|
| 46 |
-
|
|
|
|
| 47 |
{ key: 'delay', label: 'DLY', min: 0, max: 1.0, step: 0.01, default: 0.0 },
|
| 48 |
{ key: 'reverb', label: 'REV', min: 0, max: 1.0, step: 0.01, default: 0.0 },
|
| 49 |
];
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
const PAN_CENTER_SNAP = 0.06;
|
| 52 |
|
| 53 |
const BARS_OPTIONS = [1, 2, 4, 8, 16];
|
| 54 |
const BEATS_PER_BAR = 4;
|
| 55 |
const BATCH_OPTIONS = [1, 2, 3, 4];
|
|
|
|
|
|
|
| 56 |
|
| 57 |
export default function PerformanceChannel({
|
| 58 |
index,
|
|
@@ -65,13 +87,16 @@ export default function PerformanceChannel({
|
|
| 65 |
onStateChange,
|
| 66 |
onFormStateChange,
|
| 67 |
initialFormState,
|
| 68 |
-
maxDuration =
|
| 69 |
bpm = 120,
|
| 70 |
}) {
|
| 71 |
const color = CHANNEL_COLORS[index % CHANNEL_COLORS.length];
|
| 72 |
const canvasRef = useRef(null);
|
| 73 |
const meterRef = useRef(null);
|
| 74 |
const meterRafRef = useRef(null);
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
const init = initialFormState || {};
|
| 77 |
const initKnobs = init.knobs || {};
|
|
@@ -83,7 +108,7 @@ export default function PerformanceChannel({
|
|
| 83 |
|
| 84 |
const [prompt, setPrompt] = useState(init.prompt ?? '');
|
| 85 |
const [duration, setDuration] = useState(init.duration ?? 8);
|
| 86 |
-
const [durationMode, setDurationMode] = useState(init.durationMode ?? '
|
| 87 |
const [bars, setBars] = useState(init.bars ?? 4);
|
| 88 |
const [generating, setGenerating] = useState(false);
|
| 89 |
const [loaded, setLoaded] = useState(false);
|
|
@@ -91,31 +116,147 @@ export default function PerformanceChannel({
|
|
| 91 |
const [muted, setMuted] = useState(init.muted ?? false);
|
| 92 |
const [soloed, setSoloed] = useState(init.soloed ?? false);
|
| 93 |
const [batchSize, setBatchSize] = useState(init.batchSize ?? 1);
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
const cueSupported = useMemo(() => isCueSupported(), []);
|
| 103 |
|
| 104 |
// Stop any active cue audition when the channel unmounts.
|
| 105 |
useEffect(() => () => stopCue(), []);
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
// Mirror form state up to the panel so it can persist the session. Skip the
|
| 108 |
// first render so we don't re-write what we just loaded from localStorage.
|
|
|
|
|
|
|
| 109 |
const initialReportSkippedRef = useRef(false);
|
| 110 |
useEffect(() => {
|
| 111 |
if (!initialReportSkippedRef.current) {
|
| 112 |
initialReportSkippedRef.current = true;
|
| 113 |
return;
|
| 114 |
}
|
|
|
|
| 115 |
onFormStateChange?.(index, {
|
| 116 |
prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs,
|
|
|
|
|
|
|
| 117 |
});
|
| 118 |
-
}, [prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
const secondsFromBars = useMemo(
|
| 121 |
() => bars * (60 / Math.max(bpm, 1)) * BEATS_PER_BAR,
|
|
@@ -145,12 +286,31 @@ export default function PerformanceChannel({
|
|
| 145 |
|
| 146 |
const drawWave = useCallback(() => {
|
| 147 |
if (strip && canvasRef.current) {
|
|
|
|
| 148 |
strip.drawWaveform(canvasRef.current, color);
|
| 149 |
}
|
| 150 |
}, [strip, color]);
|
| 151 |
|
| 152 |
useEffect(() => { drawWave(); }, [drawWave, loaded]);
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
// One-shot: push restored knob/loop values into the audio strip when it
|
| 155 |
// first becomes available, so the persisted session matches what's heard.
|
| 156 |
// Mute/solo applies through the parent's mix handler so the panel can
|
|
@@ -161,7 +321,11 @@ export default function PerformanceChannel({
|
|
| 161 |
stripStateAppliedRef.current = true;
|
| 162 |
strip.setUserGain(gainDbToLinear(knobs.gain));
|
| 163 |
strip.setPan(knobs.pan);
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
strip.setDelayMix(knobs.delay);
|
| 166 |
strip.setReverbMix(knobs.reverb);
|
| 167 |
strip.setLoop(looping);
|
|
@@ -183,34 +347,89 @@ export default function PerformanceChannel({
|
|
| 183 |
}
|
| 184 |
}, [availableBars, bars]);
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
const handleGenerate = async () => {
|
| 187 |
if (!prompt.trim() || generating) return;
|
| 188 |
const inBarsMode = durationMode === 'bars';
|
| 189 |
const effectiveDuration = inBarsMode ? secondsFromBars : duration;
|
| 190 |
setGenerating(true);
|
| 191 |
-
// Stop any in-flight cue audition
|
| 192 |
-
//
|
| 193 |
stopCue();
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
| 195 |
try {
|
| 196 |
-
|
| 197 |
prompt,
|
| 198 |
duration: effectiveDuration,
|
| 199 |
batchSize,
|
| 200 |
// Only forward alignment params in bars mode β seconds mode
|
| 201 |
// generates raw audio with no post-processing.
|
| 202 |
...(inBarsMode ? { alignBars: bars, alignBpm: bpm } : {}),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
});
|
| 204 |
-
const blobs = Array.isArray(result) ? result : [result];
|
| 205 |
-
const next = blobs.map((b, i) => ({ index: i, blob: b }));
|
| 206 |
-
setCandidates(next);
|
| 207 |
-
// First candidate auto-loads into the channel strip; the rest sit
|
| 208 |
-
// in the audition row until the user commits a different one.
|
| 209 |
-
await strip.loadBlob(blobs[0]);
|
| 210 |
-
setCommittedIndex(0);
|
| 211 |
-
setLoaded(true);
|
| 212 |
-
onStateChange?.(index, { loaded: true });
|
| 213 |
-
requestAnimationFrame(drawWave);
|
| 214 |
} catch (err) {
|
| 215 |
console.error(`Channel ${index + 1} generate failed:`, err);
|
| 216 |
} finally {
|
|
@@ -218,37 +437,166 @@ export default function PerformanceChannel({
|
|
| 218 |
}
|
| 219 |
};
|
| 220 |
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
stopCue();
|
| 226 |
-
|
| 227 |
return;
|
| 228 |
}
|
| 229 |
-
|
| 230 |
try {
|
| 231 |
-
await playCueBlob(
|
| 232 |
-
onEnded: () =>
|
| 233 |
});
|
| 234 |
} catch (err) {
|
| 235 |
console.warn(`Channel ${index + 1} audition failed:`, err);
|
| 236 |
-
|
| 237 |
}
|
| 238 |
};
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
requestAnimationFrame(drawWave);
|
| 250 |
};
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
const handlePlay = () => {
|
| 253 |
if (!loaded) return;
|
| 254 |
if (engine) engine.playChannel(index, looping);
|
|
@@ -285,7 +633,11 @@ export default function PerformanceChannel({
|
|
| 285 |
setKnobs(prev => ({ ...prev, [key]: value }));
|
| 286 |
if (key === 'gain') strip.setUserGain(gainDbToLinear(value));
|
| 287 |
else if (key === 'pan') strip.setPan(value);
|
| 288 |
-
else if (key === 'filter')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
else if (key === 'delay') strip.setDelayMix(value);
|
| 290 |
else if (key === 'reverb') strip.setReverbMix(value);
|
| 291 |
};
|
|
@@ -307,15 +659,45 @@ export default function PerformanceChannel({
|
|
| 307 |
return (
|
| 308 |
<Box sx={styles.strip(color, playing)}>
|
| 309 |
<Box sx={styles.stripHeader(color)}>
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
<Box sx={styles.muteSoloRow}>
|
| 312 |
<MidiMappable id={ctrlId('mute')} label={ctrlLabel('Mute')} kind="trigger" onChange={handleMuteToggle}>
|
| 313 |
-
<Tooltip title=
|
| 314 |
<IconButton size="small" onClick={handleMuteToggle} sx={styles.muteBtn(muted)}>M</IconButton>
|
| 315 |
</Tooltip>
|
| 316 |
</MidiMappable>
|
| 317 |
<MidiMappable id={ctrlId('solo')} label={ctrlLabel('Solo')} kind="trigger" onChange={handleSoloToggle}>
|
| 318 |
-
<Tooltip title=
|
| 319 |
<IconButton size="small" onClick={handleSoloToggle} sx={styles.soloBtn(soloed)}>S</IconButton>
|
| 320 |
</Tooltip>
|
| 321 |
</MidiMappable>
|
|
@@ -324,18 +706,50 @@ export default function PerformanceChannel({
|
|
| 324 |
|
| 325 |
<Box sx={styles.promptBox}>
|
| 326 |
<TextField
|
| 327 |
-
placeholder="
|
| 328 |
value={prompt}
|
| 329 |
onChange={(e) => setPrompt(e.target.value)}
|
| 330 |
multiline
|
| 331 |
minRows={2}
|
| 332 |
-
maxRows={
|
| 333 |
size="small"
|
| 334 |
fullWidth
|
| 335 |
sx={styles.promptField}
|
| 336 |
disabled={generating}
|
| 337 |
/>
|
| 338 |
<Box sx={{ ...styles.durationRow, minHeight: 26, height: 26 }}>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
<Box
|
| 340 |
sx={{
|
| 341 |
display: 'inline-flex',
|
|
@@ -344,9 +758,13 @@ export default function PerformanceChannel({
|
|
| 344 |
borderRadius: 0.75,
|
| 345 |
overflow: 'hidden',
|
| 346 |
height: '100%',
|
|
|
|
| 347 |
}}
|
| 348 |
>
|
| 349 |
-
{[
|
|
|
|
|
|
|
|
|
|
| 350 |
const value = mode === 'sec' ? 'seconds' : 'bars';
|
| 351 |
const active = durationMode === value;
|
| 352 |
return (
|
|
@@ -354,215 +772,170 @@ export default function PerformanceChannel({
|
|
| 354 |
key={mode}
|
| 355 |
onClick={() => setDurationMode(value)}
|
| 356 |
sx={{
|
| 357 |
-
fontSize: perfTokens.fontSize.
|
| 358 |
-
letterSpacing: perfTokens.letterSpacing.wide,
|
| 359 |
-
textTransform: 'uppercase',
|
| 360 |
-
fontFamily: 'inherit',
|
| 361 |
px: 0.7,
|
| 362 |
-
minWidth:
|
| 363 |
bgcolor: active ? color : 'transparent',
|
|
|
|
|
|
|
| 364 |
color: active ? 'rgba(0,0,0,0.88)' : 'text.disabled',
|
| 365 |
-
fontWeight: active ?
|
| 366 |
-
transition: 'background-color 120ms, color 120ms',
|
| 367 |
'&:hover': {
|
| 368 |
bgcolor: active ? color : 'action.hover',
|
| 369 |
color: active ? 'rgba(0,0,0,0.88)' : 'text.secondary',
|
| 370 |
},
|
| 371 |
}}
|
| 372 |
>
|
| 373 |
-
{
|
| 374 |
</ButtonBase>
|
| 375 |
);
|
| 376 |
})}
|
| 377 |
</Box>
|
| 378 |
-
|
| 379 |
-
{durationMode === 'seconds' ? (
|
| 380 |
-
<>
|
| 381 |
-
<Typography variant="caption" sx={styles.durationLabel}>{duration.toFixed(0)}s</Typography>
|
| 382 |
-
<Slider
|
| 383 |
-
value={duration}
|
| 384 |
-
onChange={(_, v) => setDuration(v)}
|
| 385 |
-
min={2}
|
| 386 |
-
max={maxDuration}
|
| 387 |
-
step={1}
|
| 388 |
-
size="small"
|
| 389 |
-
sx={styles.durationSlider(color)}
|
| 390 |
-
/>
|
| 391 |
-
</>
|
| 392 |
-
) : (
|
| 393 |
-
<Select
|
| 394 |
-
value={availableBars.includes(bars) ? bars : availableBars[availableBars.length - 1]}
|
| 395 |
-
onChange={(e) => setBars(Number(e.target.value))}
|
| 396 |
-
size="small"
|
| 397 |
-
sx={{
|
| 398 |
-
flex: 1,
|
| 399 |
-
fontSize: perfTokens.fontSize.body,
|
| 400 |
-
height: '100%',
|
| 401 |
-
'& .MuiOutlinedInput-input': {
|
| 402 |
-
py: 0,
|
| 403 |
-
pl: 1,
|
| 404 |
-
minHeight: 'unset',
|
| 405 |
-
},
|
| 406 |
-
'& .MuiSelect-select': {
|
| 407 |
-
py: 0,
|
| 408 |
-
pl: 1,
|
| 409 |
-
minHeight: 'unset',
|
| 410 |
-
},
|
| 411 |
-
}}
|
| 412 |
-
>
|
| 413 |
-
{availableBars.map(b => (
|
| 414 |
-
<MenuItem key={b} value={b} sx={{ fontSize: perfTokens.fontSize.body }}>
|
| 415 |
-
{b} {b === 1 ? 'bar' : 'bars'}
|
| 416 |
-
</MenuItem>
|
| 417 |
-
))}
|
| 418 |
-
</Select>
|
| 419 |
-
)}
|
| 420 |
</Box>
|
| 421 |
<Box sx={{
|
| 422 |
display: 'flex',
|
| 423 |
alignItems: 'center',
|
| 424 |
-
|
| 425 |
-
gap: 1.5,
|
| 426 |
mt: 0.5,
|
| 427 |
width: '100%',
|
| 428 |
}}>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
<Tooltip
|
| 430 |
-
title=
|
| 431 |
placement="top"
|
| 432 |
-
disableFocusListener
|
| 433 |
-
disableTouchListener
|
| 434 |
enterDelay={500}
|
| 435 |
>
|
| 436 |
<Select
|
| 437 |
value={batchSize}
|
| 438 |
onChange={(e) => setBatchSize(Number(e.target.value))}
|
| 439 |
-
size="small"
|
| 440 |
disabled={generating}
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
minWidth: 64,
|
| 445 |
-
'& .MuiOutlinedInput-input': { py: 0, pl: 1.25, pr: '28px !important', minHeight: 'unset' },
|
| 446 |
-
'& .MuiSelect-select': { py: 0, pl: 1.25, pr: '28px !important', minHeight: 'unset' },
|
| 447 |
-
}}
|
| 448 |
>
|
| 449 |
-
{BATCH_OPTIONS.map(n => (
|
| 450 |
-
<MenuItem
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
Γ{n}
|
| 452 |
</MenuItem>
|
| 453 |
))}
|
| 454 |
</Select>
|
| 455 |
</Tooltip>
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
|
|
|
| 462 |
>
|
| 463 |
-
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
</MidiMappable>
|
| 466 |
</Box>
|
| 467 |
</Box>
|
| 468 |
|
| 469 |
-
<Box
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
<canvas
|
| 471 |
ref={canvasRef}
|
| 472 |
width={140}
|
| 473 |
height={42}
|
| 474 |
-
style={{ width: '100%', height: 42, display: 'block' }}
|
| 475 |
/>
|
| 476 |
{!loaded && (
|
| 477 |
-
<Typography
|
| 478 |
-
|
| 479 |
</Typography>
|
| 480 |
)}
|
| 481 |
</Box>
|
| 482 |
|
| 483 |
-
{
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
sx={{
|
| 501 |
-
display: 'inline-flex',
|
| 502 |
-
alignItems: 'center',
|
| 503 |
-
border: '1px solid',
|
| 504 |
-
borderColor: isCommitted ? color : 'divider',
|
| 505 |
-
borderRadius: 0.75,
|
| 506 |
-
overflow: 'hidden',
|
| 507 |
-
bgcolor: isCommitted ? `${color}1a` : 'transparent',
|
| 508 |
-
}}
|
| 509 |
-
>
|
| 510 |
-
<Tooltip
|
| 511 |
-
title={
|
| 512 |
-
cueSupported
|
| 513 |
-
? (isAuditioning ? 'Stop cue audition' : 'Audition this take through cue output')
|
| 514 |
-
: 'Cue audition requires Chrome/Edge. Plays through main output.'
|
| 515 |
-
}
|
| 516 |
-
>
|
| 517 |
-
<IconButton
|
| 518 |
-
onClick={() => handleAudition(i)}
|
| 519 |
-
size="small"
|
| 520 |
-
sx={{
|
| 521 |
-
color: isAuditioning ? color : 'text.secondary',
|
| 522 |
-
px: 0.5,
|
| 523 |
-
borderRadius: 0,
|
| 524 |
-
}}
|
| 525 |
-
>
|
| 526 |
-
<CueIcon size={12} />
|
| 527 |
-
<Box
|
| 528 |
-
component="span"
|
| 529 |
-
sx={{
|
| 530 |
-
ml: 0.4,
|
| 531 |
-
fontSize: perfTokens.fontSize.small,
|
| 532 |
-
fontWeight: isAuditioning ? 700 : 500,
|
| 533 |
-
}}
|
| 534 |
-
>
|
| 535 |
-
{i + 1}
|
| 536 |
-
</Box>
|
| 537 |
-
</IconButton>
|
| 538 |
-
</Tooltip>
|
| 539 |
-
<Tooltip title={isCommitted ? 'Currently in channel' : 'Use this take in the channel'}>
|
| 540 |
-
<span>
|
| 541 |
-
<IconButton
|
| 542 |
-
onClick={() => handleCommit(i)}
|
| 543 |
-
size="small"
|
| 544 |
-
disabled={isCommitted}
|
| 545 |
-
sx={{
|
| 546 |
-
color: isCommitted ? color : 'text.disabled',
|
| 547 |
-
px: 0.4,
|
| 548 |
-
borderRadius: 0,
|
| 549 |
-
borderLeft: '1px solid',
|
| 550 |
-
borderColor: 'divider',
|
| 551 |
-
}}
|
| 552 |
-
>
|
| 553 |
-
<CommitIcon size={12} />
|
| 554 |
-
</IconButton>
|
| 555 |
-
</span>
|
| 556 |
-
</Tooltip>
|
| 557 |
-
</Box>
|
| 558 |
-
);
|
| 559 |
-
})}
|
| 560 |
-
</Box>
|
| 561 |
-
)}
|
| 562 |
|
| 563 |
<Box sx={{ px: 1, py: 1 }}>
|
| 564 |
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1, mb: 1 }}>
|
| 565 |
-
<Box component="span" sx={{
|
| 566 |
<MidiMappable
|
| 567 |
id={ctrlId('pan')}
|
| 568 |
label={ctrlLabel('Pan')}
|
|
@@ -584,6 +957,10 @@ export default function PerformanceChannel({
|
|
| 584 |
marks={[{ value: 0 }]}
|
| 585 |
sx={{
|
| 586 |
flex: 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
'& .MuiSlider-mark': {
|
| 588 |
width: 2,
|
| 589 |
height: 10,
|
|
@@ -604,6 +981,7 @@ export default function PerformanceChannel({
|
|
| 604 |
<Box sx={styles.knobsGrid}>
|
| 605 |
{KNOB_DEFS.map((k) => {
|
| 606 |
const isLog = k.scale === 'log';
|
|
|
|
| 607 |
// For log knobs, the slider drives a 0..1 position and we
|
| 608 |
// convert to/from the underlying value (Hz) on the audio
|
| 609 |
// boundary. The knob value stored in state stays in the
|
|
@@ -635,7 +1013,24 @@ export default function PerformanceChannel({
|
|
| 635 |
max={isLog ? 1 : k.max}
|
| 636 |
step={isLog ? 0.001 : k.step}
|
| 637 |
size="small"
|
| 638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
/>
|
| 640 |
</MidiMappable>
|
| 641 |
<Box component="span" sx={styles.knobLabel}>{k.label}</Box>
|
|
@@ -644,26 +1039,10 @@ export default function PerformanceChannel({
|
|
| 644 |
})}
|
| 645 |
</Box>
|
| 646 |
|
|
|
|
|
|
|
|
|
|
| 647 |
<Box sx={styles.transportRow}>
|
| 648 |
-
<MidiMappable id={ctrlId('transport')} label={ctrlLabel('Play/Stop')} kind="trigger" onChange={handleTransportToggle}>
|
| 649 |
-
<IconButton
|
| 650 |
-
onClick={playing ? handleStop : handlePlay}
|
| 651 |
-
disabled={!loaded}
|
| 652 |
-
sx={styles.transportBtn(color, playing)}
|
| 653 |
-
size="small"
|
| 654 |
-
>
|
| 655 |
-
{playing ? <StopIcon size={16} /> : <PlayIcon size={16} />}
|
| 656 |
-
</IconButton>
|
| 657 |
-
</MidiMappable>
|
| 658 |
-
<MidiMappable id={ctrlId('loop')} label={ctrlLabel('Loop')} kind="trigger" onChange={handleLoopToggle}>
|
| 659 |
-
<IconButton
|
| 660 |
-
onClick={handleLoopToggle}
|
| 661 |
-
sx={styles.loopBtn(color, looping)}
|
| 662 |
-
size="small"
|
| 663 |
-
>
|
| 664 |
-
<LoopIcon size={14} />
|
| 665 |
-
</IconButton>
|
| 666 |
-
</MidiMappable>
|
| 667 |
<Box sx={styles.meterTrack}>
|
| 668 |
<Box ref={meterRef} sx={styles.meterFill(color)} />
|
| 669 |
</Box>
|
|
|
|
| 5 |
TextField,
|
| 6 |
IconButton,
|
| 7 |
Slider,
|
|
|
|
|
|
|
| 8 |
Select,
|
| 9 |
MenuItem,
|
| 10 |
ButtonBase,
|
| 11 |
} from '@mui/material';
|
| 12 |
+
import { TIPS } from '../tooltips';
|
| 13 |
+
import Tooltip from './Tooltip';
|
| 14 |
import {
|
| 15 |
Play as PlayIcon,
|
| 16 |
Square as StopIcon,
|
| 17 |
+
ArrowRight as GenerateArrowIcon,
|
|
|
|
| 18 |
Volume2 as VolumeIcon,
|
| 19 |
VolumeX as MuteIcon,
|
| 20 |
+
Shuffle as VariationIcon,
|
|
|
|
| 21 |
} from 'lucide-react';
|
| 22 |
+
import { performanceChannelStyles as styles, performancePanelStyles as panelStyles, perfTokens, SHEEN_DARK, RAISE_DARK } from '../theme';
|
| 23 |
import { MidiMappable } from './MidiContext';
|
| 24 |
import { playBlob as playCueBlob, stopCue, isCueSupported } from '../utils/cueAudio';
|
| 25 |
+
import {
|
| 26 |
+
channelScope,
|
| 27 |
+
putFragmentBlob,
|
| 28 |
+
getFragmentBlob,
|
| 29 |
+
deleteFragmentBlob,
|
| 30 |
+
clearScope as clearFragmentScope,
|
| 31 |
+
} from '../utils/fragmentStorage';
|
| 32 |
+
import api from '../api';
|
| 33 |
+
import ChannelFragmentHistory from './ChannelFragmentHistory';
|
| 34 |
|
| 35 |
const CHANNEL_COLORS = [
|
| 36 |
+
// Original introduction palette. Ch1 teal, ch2 violet, ch3 green, ch4 amber.
|
| 37 |
+
// Light enough that the black label text on active toggles stays legible.
|
| 38 |
+
// Slots 4β7 are spares (only 4 channels render).
|
| 39 |
'#35C2D4', '#9F8AE6', '#53C18A', '#E3A34B',
|
| 40 |
'#E36C61', '#F08AD2', '#5BA0F0', '#A8D86B',
|
| 41 |
];
|
|
|
|
| 50 |
|
| 51 |
const KNOB_DEFS = [
|
| 52 |
{ key: 'gain', label: 'GAIN', min: GAIN_DB_MIN, max: GAIN_DB_MAX, step: 0.5, default: GAIN_DB_DEFAULT },
|
| 53 |
+
// Bipolar "DJ-filter" knob. -1..+1 with 0 = bypass. Negative side drives
|
| 54 |
+
// the LPF cutoff down from 20 kHz β 20 Hz (kills highs). Positive side
|
| 55 |
+
// drives the HPF cutoff up from 20 Hz β 20 kHz (kills lows). The two
|
| 56 |
+
// biquads sit in series in the engine; only one side ever cuts at a time.
|
| 57 |
+
{ key: 'filter', label: 'FLT', min: -1, max: 1, step: 0.001, default: 0, scale: 'bipolar' },
|
| 58 |
{ key: 'delay', label: 'DLY', min: 0, max: 1.0, step: 0.01, default: 0.0 },
|
| 59 |
{ key: 'reverb', label: 'REV', min: 0, max: 1.0, step: 0.01, default: 0.0 },
|
| 60 |
];
|
| 61 |
|
| 62 |
+
// Map a bipolar filter position (-1..+1) to the (LPF, HPF) frequencies that
|
| 63 |
+
// the engine's two biquads need. 20 Hz / 20 kHz are the bypass anchors on
|
| 64 |
+
// each side; log-scaled so each octave gets equal slider travel.
|
| 65 |
+
function bipolarToFilterFreqs(pos) {
|
| 66 |
+
const lpf = pos <= 0 ? 20 * Math.pow(1000, 1 + pos) : 20000;
|
| 67 |
+
const hpf = pos >= 0 ? 20 * Math.pow(1000, pos) : 20;
|
| 68 |
+
return { lpf, hpf };
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
const PAN_CENTER_SNAP = 0.06;
|
| 72 |
|
| 73 |
const BARS_OPTIONS = [1, 2, 4, 8, 16];
|
| 74 |
const BEATS_PER_BAR = 4;
|
| 75 |
const BATCH_OPTIONS = [1, 2, 3, 4];
|
| 76 |
+
// Per-channel rolling fragment history cap. Starred fragments survive eviction.
|
| 77 |
+
const FRAGMENT_CAP = 200;
|
| 78 |
|
| 79 |
export default function PerformanceChannel({
|
| 80 |
index,
|
|
|
|
| 87 |
onStateChange,
|
| 88 |
onFormStateChange,
|
| 89 |
initialFormState,
|
| 90 |
+
maxDuration = 380,
|
| 91 |
bpm = 120,
|
| 92 |
}) {
|
| 93 |
const color = CHANNEL_COLORS[index % CHANNEL_COLORS.length];
|
| 94 |
const canvasRef = useRef(null);
|
| 95 |
const meterRef = useRef(null);
|
| 96 |
const meterRafRef = useRef(null);
|
| 97 |
+
// IDB scope key for this channel's fragment blobs. Stable across the
|
| 98 |
+
// component's lifetime since the channel index doesn't change.
|
| 99 |
+
const scope = channelScope(index);
|
| 100 |
|
| 101 |
const init = initialFormState || {};
|
| 102 |
const initKnobs = init.knobs || {};
|
|
|
|
| 108 |
|
| 109 |
const [prompt, setPrompt] = useState(init.prompt ?? '');
|
| 110 |
const [duration, setDuration] = useState(init.duration ?? 8);
|
| 111 |
+
const [durationMode, setDurationMode] = useState(init.durationMode ?? 'bars');
|
| 112 |
const [bars, setBars] = useState(init.bars ?? 4);
|
| 113 |
const [generating, setGenerating] = useState(false);
|
| 114 |
const [loaded, setLoaded] = useState(false);
|
|
|
|
| 116 |
const [muted, setMuted] = useState(init.muted ?? false);
|
| 117 |
const [soloed, setSoloed] = useState(init.soloed ?? false);
|
| 118 |
const [batchSize, setBatchSize] = useState(init.batchSize ?? 1);
|
| 119 |
+
// Live progress for the Generate pill while a generation is in flight.
|
| 120 |
+
// 0β100; polled from /api/generation-progress. Resets on each new run.
|
| 121 |
+
const [progress, setProgress] = useState(0);
|
| 122 |
+
const [knobs, setKnobs] = useState(() => {
|
| 123 |
+
const merged = { ...defaultKnobs, ...initKnobs };
|
| 124 |
+
// Migration: pre-bipolar `filter` was a raw Hz value (20..20000).
|
| 125 |
+
// Anything outside the new -1..+1 range is a legacy save β reset
|
| 126 |
+
// to bypass (0) rather than feeding nonsense into the engine.
|
| 127 |
+
if (merged.filter < -1 || merged.filter > 1) merged.filter = 0;
|
| 128 |
+
return merged;
|
| 129 |
+
});
|
| 130 |
+
|
| 131 |
+
// Per-channel rolling fragment history. Each fragment:
|
| 132 |
+
// { id, blob, audioUrl, prompt, duration, createdAt, starred, number }
|
| 133 |
+
// Oldest-first. Capped at FRAGMENT_CAP via FIFO eviction with star
|
| 134 |
+
// priority (starred fragments survive until everything is starred, then
|
| 135 |
+
// oldest go first regardless). `nextFragmentNumberRef` provides a stable
|
| 136 |
+
// F# even after deletes β so F1 stays F1.
|
| 137 |
+
const [fragments, setFragments] = useState([]);
|
| 138 |
+
const [auditioningFragmentId, setAuditioningFragmentId] = useState(null);
|
| 139 |
+
const [committedFragmentId, setCommittedFragmentId] = useState(null);
|
| 140 |
+
const nextFragmentNumberRef = useRef(1);
|
| 141 |
const cueSupported = useMemo(() => isCueSupported(), []);
|
| 142 |
|
| 143 |
// Stop any active cue audition when the channel unmounts.
|
| 144 |
useEffect(() => () => stopCue(), []);
|
| 145 |
|
| 146 |
+
// Poll /api/generation-progress while a generation is in flight so the
|
| 147 |
+
// Generate pill renders a real fill bar instead of a vague spinner. The
|
| 148 |
+
// backend exposes a single in-flight state; performance generations are
|
| 149 |
+
// sequential (the backend serves one at a time), so this naturally
|
| 150 |
+
// reflects whichever channel is currently busy.
|
| 151 |
+
useEffect(() => {
|
| 152 |
+
if (!generating) {
|
| 153 |
+
setProgress(0);
|
| 154 |
+
return;
|
| 155 |
+
}
|
| 156 |
+
let cancelled = false;
|
| 157 |
+
const tick = async () => {
|
| 158 |
+
if (cancelled) return;
|
| 159 |
+
try {
|
| 160 |
+
const r = await api.get('/api/generation-progress');
|
| 161 |
+
const pct = Number(r.data?.progress) || 0;
|
| 162 |
+
if (!cancelled) {
|
| 163 |
+
// Cap at 95 until handleGenerate resolves so the bar
|
| 164 |
+
// doesn't sit at 100 while waiting for the WAV blob.
|
| 165 |
+
setProgress((prev) => Math.max(prev, Math.min(95, pct)));
|
| 166 |
+
}
|
| 167 |
+
} catch { /* non-fatal β bar just freezes briefly */ }
|
| 168 |
+
};
|
| 169 |
+
tick();
|
| 170 |
+
const id = window.setInterval(tick, 250);
|
| 171 |
+
return () => { cancelled = true; window.clearInterval(id); };
|
| 172 |
+
}, [generating]);
|
| 173 |
+
|
| 174 |
// Mirror form state up to the panel so it can persist the session. Skip the
|
| 175 |
// first render so we don't re-write what we just loaded from localStorage.
|
| 176 |
+
// Fragments mirror as metadata only β the Blob bodies live in IndexedDB
|
| 177 |
+
// and get rehydrated on mount by the effect below.
|
| 178 |
const initialReportSkippedRef = useRef(false);
|
| 179 |
useEffect(() => {
|
| 180 |
if (!initialReportSkippedRef.current) {
|
| 181 |
initialReportSkippedRef.current = true;
|
| 182 |
return;
|
| 183 |
}
|
| 184 |
+
const fragmentsMeta = fragments.map(({ blob, audioUrl, ...rest }) => rest);
|
| 185 |
onFormStateChange?.(index, {
|
| 186 |
prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs,
|
| 187 |
+
fragments: fragmentsMeta,
|
| 188 |
+
committedFragmentId,
|
| 189 |
});
|
| 190 |
+
}, [prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs,
|
| 191 |
+
fragments, committedFragmentId, index, onFormStateChange]);
|
| 192 |
+
|
| 193 |
+
// Hydrate fragments on mount from the session metadata + IDB blobs. Runs
|
| 194 |
+
// once, tolerates missing blobs (skips the entry), and rewinds the
|
| 195 |
+
// fragment numbering counter so newly generated fragments don't collide
|
| 196 |
+
// with the restored ones.
|
| 197 |
+
const hydrationRef = useRef(false);
|
| 198 |
+
useEffect(() => {
|
| 199 |
+
if (hydrationRef.current) return;
|
| 200 |
+
hydrationRef.current = true;
|
| 201 |
+
// Backward compat: pre-rename saves used `takes`/`committedTakeId`.
|
| 202 |
+
// The session loader migrates them into `fragments`/`committedFragmentId`,
|
| 203 |
+
// but we also fall back here defensively in case `initialFormState`
|
| 204 |
+
// came from somewhere unmigrated.
|
| 205 |
+
const meta = initialFormState?.fragments
|
| 206 |
+
?? initialFormState?.takes
|
| 207 |
+
?? [];
|
| 208 |
+
const persistedCommittedId = initialFormState?.committedFragmentId
|
| 209 |
+
?? initialFormState?.committedTakeId
|
| 210 |
+
?? null;
|
| 211 |
+
if (meta.length === 0) {
|
| 212 |
+
if (persistedCommittedId) setCommittedFragmentId(null);
|
| 213 |
+
return;
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
let cancelled = false;
|
| 217 |
+
(async () => {
|
| 218 |
+
const hydrated = [];
|
| 219 |
+
for (const m of meta) {
|
| 220 |
+
try {
|
| 221 |
+
const blob = await getFragmentBlob(scope, m.id);
|
| 222 |
+
if (cancelled) {
|
| 223 |
+
hydrated.forEach(t => URL.revokeObjectURL(t.audioUrl));
|
| 224 |
+
return;
|
| 225 |
+
}
|
| 226 |
+
if (!blob) continue;
|
| 227 |
+
hydrated.push({
|
| 228 |
+
...m,
|
| 229 |
+
blob,
|
| 230 |
+
audioUrl: URL.createObjectURL(blob),
|
| 231 |
+
});
|
| 232 |
+
} catch {
|
| 233 |
+
/* one bad fetch β keep going */
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
if (cancelled) {
|
| 237 |
+
hydrated.forEach(t => URL.revokeObjectURL(t.audioUrl));
|
| 238 |
+
return;
|
| 239 |
+
}
|
| 240 |
+
const maxNumber = hydrated.reduce((a, t) => Math.max(a, t.number || 0), 0);
|
| 241 |
+
nextFragmentNumberRef.current = maxNumber + 1;
|
| 242 |
+
setFragments(hydrated);
|
| 243 |
+
if (persistedCommittedId && hydrated.some(t => t.id === persistedCommittedId)) {
|
| 244 |
+
setCommittedFragmentId(persistedCommittedId);
|
| 245 |
+
setLoaded(true);
|
| 246 |
+
onStateChange?.(index, { loaded: true });
|
| 247 |
+
}
|
| 248 |
+
})();
|
| 249 |
+
|
| 250 |
+
return () => { cancelled = true; };
|
| 251 |
+
// eslint-disable-next-line react-hooks/exhaustive-deps
|
| 252 |
+
}, []);
|
| 253 |
+
|
| 254 |
+
// When the audio strip becomes available AND we have a hydrated committed
|
| 255 |
+
// fragment, load that fragment's blob into the strip so the channel comes back
|
| 256 |
+
// ready to play after reload. Declared here as a ref so the effect that
|
| 257 |
+
// actually does the work (after drawWave is defined below) can guard
|
| 258 |
+
// against multiple loads.
|
| 259 |
+
const autoLoadDoneRef = useRef(false);
|
| 260 |
|
| 261 |
const secondsFromBars = useMemo(
|
| 262 |
() => bars * (60 / Math.max(bpm, 1)) * BEATS_PER_BAR,
|
|
|
|
| 286 |
|
| 287 |
const drawWave = useCallback(() => {
|
| 288 |
if (strip && canvasRef.current) {
|
| 289 |
+
// Each channel's waveform is drawn in that channel's own color.
|
| 290 |
strip.drawWaveform(canvasRef.current, color);
|
| 291 |
}
|
| 292 |
}, [strip, color]);
|
| 293 |
|
| 294 |
useEffect(() => { drawWave(); }, [drawWave, loaded]);
|
| 295 |
|
| 296 |
+
// Auto-load the persisted committed fragment into the strip once Tone.js
|
| 297 |
+
// is ready. Runs at most once per mount; the ref guards against re-trigger
|
| 298 |
+
// when the user later commits a different fragment (handled by
|
| 299 |
+
// handleCommitFragment).
|
| 300 |
+
useEffect(() => {
|
| 301 |
+
if (autoLoadDoneRef.current) return;
|
| 302 |
+
if (!strip || !committedFragmentId) return;
|
| 303 |
+
const fragment = fragments.find(f => f.id === committedFragmentId);
|
| 304 |
+
if (!fragment) return;
|
| 305 |
+
autoLoadDoneRef.current = true;
|
| 306 |
+
strip.loadBlob(fragment.blob).then(() => {
|
| 307 |
+
requestAnimationFrame(drawWave);
|
| 308 |
+
}).catch(err => {
|
| 309 |
+
console.warn(`Channel ${index + 1} auto-load failed:`, err);
|
| 310 |
+
autoLoadDoneRef.current = false;
|
| 311 |
+
});
|
| 312 |
+
}, [strip, committedFragmentId, fragments, drawWave, index]);
|
| 313 |
+
|
| 314 |
// One-shot: push restored knob/loop values into the audio strip when it
|
| 315 |
// first becomes available, so the persisted session matches what's heard.
|
| 316 |
// Mute/solo applies through the parent's mix handler so the panel can
|
|
|
|
| 321 |
stripStateAppliedRef.current = true;
|
| 322 |
strip.setUserGain(gainDbToLinear(knobs.gain));
|
| 323 |
strip.setPan(knobs.pan);
|
| 324 |
+
{
|
| 325 |
+
const { lpf, hpf } = bipolarToFilterFreqs(knobs.filter);
|
| 326 |
+
strip.setFilter(lpf);
|
| 327 |
+
strip.setHighpass(hpf);
|
| 328 |
+
}
|
| 329 |
strip.setDelayMix(knobs.delay);
|
| 330 |
strip.setReverbMix(knobs.reverb);
|
| 331 |
strip.setLoop(looping);
|
|
|
|
| 347 |
}
|
| 348 |
}, [availableBars, bars]);
|
| 349 |
|
| 350 |
+
// Per-fragment handler factory β fires as each blob returns. Fragment #0
|
| 351 |
+
// auto-loads into the strip so the user can audition while #1..N render.
|
| 352 |
+
// Shared by Generate and Variation so both feed channel history identically.
|
| 353 |
+
const makeOnBlob = (promptSnap, effectiveDuration) => async (blob, i) => {
|
| 354 |
+
const fragmentNumber = nextFragmentNumberRef.current;
|
| 355 |
+
nextFragmentNumberRef.current = fragmentNumber + 1;
|
| 356 |
+
const fragment = {
|
| 357 |
+
id: `${Date.now()}_${i}`,
|
| 358 |
+
blob,
|
| 359 |
+
audioUrl: URL.createObjectURL(blob),
|
| 360 |
+
prompt: promptSnap,
|
| 361 |
+
duration: effectiveDuration,
|
| 362 |
+
createdAt: Date.now(),
|
| 363 |
+
starred: false,
|
| 364 |
+
number: fragmentNumber,
|
| 365 |
+
};
|
| 366 |
+
|
| 367 |
+
// Persist the blob to IndexedDB so it survives reload. Fire-and-forget.
|
| 368 |
+
putFragmentBlob(scope, fragment.id, blob).catch((err) => {
|
| 369 |
+
console.warn(`Channel ${index + 1} fragment persist failed:`, err);
|
| 370 |
+
});
|
| 371 |
+
|
| 372 |
+
// Append to history with FRAGMENT_CAP eviction (oldest unstarred first).
|
| 373 |
+
setFragments((prev) => {
|
| 374 |
+
const combined = [...prev, fragment];
|
| 375 |
+
if (combined.length <= FRAGMENT_CAP) return combined;
|
| 376 |
+
const trimmed = combined.slice();
|
| 377 |
+
while (trimmed.length > FRAGMENT_CAP) {
|
| 378 |
+
let idx = -1;
|
| 379 |
+
for (let j = 0; j < trimmed.length; j++) {
|
| 380 |
+
if (!trimmed[j].starred) { idx = j; break; }
|
| 381 |
+
}
|
| 382 |
+
if (idx < 0) idx = 0; // all starred β drop oldest
|
| 383 |
+
const dying = trimmed[idx];
|
| 384 |
+
if (dying.audioUrl?.startsWith('blob:')) {
|
| 385 |
+
try { URL.revokeObjectURL(dying.audioUrl); } catch { /* ignore */ }
|
| 386 |
+
}
|
| 387 |
+
deleteFragmentBlob(scope, dying.id).catch(() => { /* ignore */ });
|
| 388 |
+
trimmed.splice(idx, 1);
|
| 389 |
+
}
|
| 390 |
+
return trimmed;
|
| 391 |
+
});
|
| 392 |
+
|
| 393 |
+
// Generating must never disturb playback: a playing channel keeps
|
| 394 |
+
// looping its current clip while new fragments just pile into the
|
| 395 |
+
// history list. Only auto-load when the channel has nothing loaded yet
|
| 396 |
+
// (first-ever fragment) β harmless since nothing is playing β so the
|
| 397 |
+
// user still gets a ready-to-play clip on a fresh channel. To start a
|
| 398 |
+
// newly generated fragment, pick it from the list (handleCommitFragment).
|
| 399 |
+
if (i === 0 && !loaded) {
|
| 400 |
+
await strip.loadBlob(blob);
|
| 401 |
+
setCommittedFragmentId(fragment.id);
|
| 402 |
+
setLoaded(true);
|
| 403 |
+
onStateChange?.(index, { loaded: true });
|
| 404 |
+
requestAnimationFrame(drawWave);
|
| 405 |
+
}
|
| 406 |
+
};
|
| 407 |
+
|
| 408 |
const handleGenerate = async () => {
|
| 409 |
if (!prompt.trim() || generating) return;
|
| 410 |
const inBarsMode = durationMode === 'bars';
|
| 411 |
const effectiveDuration = inBarsMode ? secondsFromBars : duration;
|
| 412 |
setGenerating(true);
|
| 413 |
+
// Stop any in-flight cue audition so the old preview doesn't keep
|
| 414 |
+
// playing while we generate the new fragment.
|
| 415 |
stopCue();
|
| 416 |
+
setAuditioningFragmentId(null);
|
| 417 |
+
|
| 418 |
+
const promptSnap = prompt.trim();
|
| 419 |
+
|
| 420 |
try {
|
| 421 |
+
await onGenerate({
|
| 422 |
prompt,
|
| 423 |
duration: effectiveDuration,
|
| 424 |
batchSize,
|
| 425 |
// Only forward alignment params in bars mode β seconds mode
|
| 426 |
// generates raw audio with no post-processing.
|
| 427 |
...(inBarsMode ? { alignBars: bars, alignBpm: bpm } : {}),
|
| 428 |
+
// Phase 7: bars-mode + channel-looping β ask the backend
|
| 429 |
+
// to wrap-inpaint the seam so the clip loops seamlessly.
|
| 430 |
+
...(inBarsMode && looping ? { loopStitch: 'inpaint' } : {}),
|
| 431 |
+
onBlob: makeOnBlob(promptSnap, effectiveDuration),
|
| 432 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
} catch (err) {
|
| 434 |
console.error(`Channel ${index + 1} generate failed:`, err);
|
| 435 |
} finally {
|
|
|
|
| 437 |
}
|
| 438 |
};
|
| 439 |
|
| 440 |
+
// Phase 8 "Variation": re-roll the channel using its current fragment as
|
| 441 |
+
// init_audio at a high noise level β gives a related-but-different take
|
| 442 |
+
// (A/B/A/C/A live sets). Uploads the source blob to get a server path,
|
| 443 |
+
// then routes through the same generate flow.
|
| 444 |
+
const handleVariation = async () => {
|
| 445 |
+
if (generating) return;
|
| 446 |
+
const src = fragments.find((f) => f.id === committedFragmentId)
|
| 447 |
+
|| fragments[fragments.length - 1];
|
| 448 |
+
if (!src?.blob) return;
|
| 449 |
+
const inBarsMode = durationMode === 'bars';
|
| 450 |
+
const effectiveDuration = inBarsMode ? secondsFromBars : duration;
|
| 451 |
+
const promptSnap = (prompt || '').trim() || src.prompt || 'variation';
|
| 452 |
+
setGenerating(true);
|
| 453 |
+
stopCue();
|
| 454 |
+
setAuditioningFragmentId(null);
|
| 455 |
+
try {
|
| 456 |
+
const form = new FormData();
|
| 457 |
+
form.append('file', new File([src.blob], `${scope}_variation_src.wav`, { type: 'audio/wav' }));
|
| 458 |
+
const up = await api.post('/api/audio/upload', form);
|
| 459 |
+
await onGenerate({
|
| 460 |
+
prompt: promptSnap,
|
| 461 |
+
duration: effectiveDuration,
|
| 462 |
+
batchSize: 1,
|
| 463 |
+
initAudioPath: up.data.path,
|
| 464 |
+
initNoiseLevel: 0.9,
|
| 465 |
+
onBlob: makeOnBlob(promptSnap, effectiveDuration),
|
| 466 |
+
});
|
| 467 |
+
} catch (err) {
|
| 468 |
+
console.error(`Channel ${index + 1} variation failed:`, err);
|
| 469 |
+
} finally {
|
| 470 |
+
setGenerating(false);
|
| 471 |
+
}
|
| 472 |
+
};
|
| 473 |
+
|
| 474 |
+
// Fragment history actions β toggle audition through cue, commit a
|
| 475 |
+
// fragment to the channel buffer, star/unstar, delete one, or clear
|
| 476 |
+
// the whole list.
|
| 477 |
+
const handleAuditionFragment = async (fragmentId) => {
|
| 478 |
+
const fragment = fragments.find((f) => f.id === fragmentId);
|
| 479 |
+
if (!fragment) return;
|
| 480 |
+
if (auditioningFragmentId === fragmentId) {
|
| 481 |
stopCue();
|
| 482 |
+
setAuditioningFragmentId(null);
|
| 483 |
return;
|
| 484 |
}
|
| 485 |
+
setAuditioningFragmentId(fragmentId);
|
| 486 |
try {
|
| 487 |
+
await playCueBlob(fragment.blob, {
|
| 488 |
+
onEnded: () => setAuditioningFragmentId((prev) => (prev === fragmentId ? null : prev)),
|
| 489 |
});
|
| 490 |
} catch (err) {
|
| 491 |
console.warn(`Channel ${index + 1} audition failed:`, err);
|
| 492 |
+
setAuditioningFragmentId(null);
|
| 493 |
}
|
| 494 |
};
|
| 495 |
|
| 496 |
+
// Choosing a fragment launches it from the beginning. The currently
|
| 497 |
+
// playing clip (if any) keeps sounding until the launch point: immediately
|
| 498 |
+
// in seconds mode or when launch quantization is None, otherwise at the
|
| 499 |
+
// next launch-quantization bar. The buffer is decoded WITHOUT stopping the
|
| 500 |
+
// live source, so the swap is gapless (the engine schedules the handoff).
|
| 501 |
+
const handleCommitFragment = async (fragmentId) => {
|
| 502 |
+
const fragment = fragments.find((f) => f.id === fragmentId);
|
| 503 |
+
if (!fragment) return;
|
| 504 |
+
const sameFragment = committedFragmentId === fragmentId;
|
| 505 |
+
// Already looping this exact clip β nothing to (re)launch.
|
| 506 |
+
if (sameFragment && playing) return;
|
| 507 |
+
|
| 508 |
+
// Decode the new clip without cutting the live source; skip the decode
|
| 509 |
+
// when this fragment's buffer is already loaded.
|
| 510 |
+
if (!sameFragment) {
|
| 511 |
+
await strip.loadBufferFromBlob(fragment.blob);
|
| 512 |
+
setCommittedFragmentId(fragmentId);
|
| 513 |
+
}
|
| 514 |
+
// Mark loaded so the play button enables (covers preset/hydrated flows
|
| 515 |
+
// where the first commit happens here rather than via generate).
|
| 516 |
+
if (!loaded) {
|
| 517 |
+
setLoaded(true);
|
| 518 |
+
onStateChange?.(index, { loaded: true });
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
// Launch from the top. Seconds mode is always immediate; bars mode
|
| 522 |
+
// defers to the engine's launch-quantization (ASAP when quantum=None).
|
| 523 |
+
const immediate = durationMode === 'seconds';
|
| 524 |
+
if (engine) engine.relaunchChannel(index, looping, immediate);
|
| 525 |
+
else strip.playAt(looping, 0);
|
| 526 |
+
onStateChange?.(index, { playing: true });
|
| 527 |
requestAnimationFrame(drawWave);
|
| 528 |
};
|
| 529 |
|
| 530 |
+
// Drag-and-drop: a fragment row from this channel's history can be
|
| 531 |
+
// dropped onto the waveform monitor to load it (same effect as the row's
|
| 532 |
+
// commit β button). The MIME type is channel-scoped, so a row from
|
| 533 |
+
// channel 1 won't even highlight channel 2's waveform β the browser
|
| 534 |
+
// filters at dragOver level via dataTransfer.types matching.
|
| 535 |
+
const dragMime = `application/x-fragmenta-fragment-ch${index}`;
|
| 536 |
+
const [dropActive, setDropActive] = useState(false);
|
| 537 |
+
// Counter pattern β dragenter/leave also fire when the cursor crosses
|
| 538 |
+
// into child elements (canvas, overlay). Without the counter, dropActive
|
| 539 |
+
// would flicker false whenever the cursor moved over a child.
|
| 540 |
+
const dragCounterRef = useRef(0);
|
| 541 |
+
|
| 542 |
+
const handleWaveDragEnter = (e) => {
|
| 543 |
+
if (!e.dataTransfer.types.includes(dragMime)) return;
|
| 544 |
+
e.preventDefault();
|
| 545 |
+
dragCounterRef.current += 1;
|
| 546 |
+
if (dragCounterRef.current === 1) setDropActive(true);
|
| 547 |
+
};
|
| 548 |
+
const handleWaveDragOver = (e) => {
|
| 549 |
+
if (!e.dataTransfer.types.includes(dragMime)) return;
|
| 550 |
+
e.preventDefault();
|
| 551 |
+
e.dataTransfer.dropEffect = 'copy';
|
| 552 |
+
};
|
| 553 |
+
const handleWaveDragLeave = () => {
|
| 554 |
+
dragCounterRef.current = Math.max(0, dragCounterRef.current - 1);
|
| 555 |
+
if (dragCounterRef.current === 0) setDropActive(false);
|
| 556 |
+
};
|
| 557 |
+
const handleWaveDrop = (e) => {
|
| 558 |
+
e.preventDefault();
|
| 559 |
+
dragCounterRef.current = 0;
|
| 560 |
+
setDropActive(false);
|
| 561 |
+
const fragmentId = e.dataTransfer.getData(dragMime);
|
| 562 |
+
if (fragmentId) handleCommitFragment(fragmentId);
|
| 563 |
+
};
|
| 564 |
+
|
| 565 |
+
const handleToggleStar = (fragmentId) => {
|
| 566 |
+
setFragments((prev) => prev.map((f) =>
|
| 567 |
+
f.id === fragmentId ? { ...f, starred: !f.starred } : f,
|
| 568 |
+
));
|
| 569 |
+
};
|
| 570 |
+
|
| 571 |
+
const handleDeleteFragment = (fragmentId) => {
|
| 572 |
+
const target = fragments.find((f) => f.id === fragmentId);
|
| 573 |
+
if (target?.audioUrl?.startsWith('blob:')) {
|
| 574 |
+
try { URL.revokeObjectURL(target.audioUrl); } catch { /* ignore */ }
|
| 575 |
+
}
|
| 576 |
+
deleteFragmentBlob(scope, fragmentId).catch(() => { /* ignore */ });
|
| 577 |
+
setFragments((prev) => prev.filter((f) => f.id !== fragmentId));
|
| 578 |
+
if (committedFragmentId === fragmentId) setCommittedFragmentId(null);
|
| 579 |
+
if (auditioningFragmentId === fragmentId) {
|
| 580 |
+
stopCue();
|
| 581 |
+
setAuditioningFragmentId(null);
|
| 582 |
+
}
|
| 583 |
+
};
|
| 584 |
+
|
| 585 |
+
const handleClearFragments = () => {
|
| 586 |
+
// Stop any in-flight audition and revoke every blob URL before
|
| 587 |
+
// dropping references β otherwise the URLs leak until reload.
|
| 588 |
+
stopCue();
|
| 589 |
+
setAuditioningFragmentId(null);
|
| 590 |
+
fragments.forEach((f) => {
|
| 591 |
+
if (f.audioUrl?.startsWith('blob:')) {
|
| 592 |
+
try { URL.revokeObjectURL(f.audioUrl); } catch { /* ignore */ }
|
| 593 |
+
}
|
| 594 |
+
});
|
| 595 |
+
clearFragmentScope(scope).catch(() => { /* ignore */ });
|
| 596 |
+
setFragments([]);
|
| 597 |
+
setCommittedFragmentId(null);
|
| 598 |
+
};
|
| 599 |
+
|
| 600 |
const handlePlay = () => {
|
| 601 |
if (!loaded) return;
|
| 602 |
if (engine) engine.playChannel(index, looping);
|
|
|
|
| 633 |
setKnobs(prev => ({ ...prev, [key]: value }));
|
| 634 |
if (key === 'gain') strip.setUserGain(gainDbToLinear(value));
|
| 635 |
else if (key === 'pan') strip.setPan(value);
|
| 636 |
+
else if (key === 'filter') {
|
| 637 |
+
const { lpf, hpf } = bipolarToFilterFreqs(value);
|
| 638 |
+
strip.setFilter(lpf);
|
| 639 |
+
strip.setHighpass(hpf);
|
| 640 |
+
}
|
| 641 |
else if (key === 'delay') strip.setDelayMix(value);
|
| 642 |
else if (key === 'reverb') strip.setReverbMix(value);
|
| 643 |
};
|
|
|
|
| 659 |
return (
|
| 660 |
<Box sx={styles.strip(color, playing)}>
|
| 661 |
<Box sx={styles.stripHeader(color)}>
|
| 662 |
+
{/* Transport (Play / Loop) on the left, Mute / Solo on the
|
| 663 |
+
right β replaces the old "01" channel badge so the channel
|
| 664 |
+
number isn't using up that slot. */}
|
| 665 |
+
<Box sx={styles.muteSoloRow}>
|
| 666 |
+
<MidiMappable id={ctrlId('transport')} label={ctrlLabel('Play/Stop')} kind="trigger" onChange={handleTransportToggle}>
|
| 667 |
+
<IconButton
|
| 668 |
+
onClick={playing ? handleStop : handlePlay}
|
| 669 |
+
disabled={!loaded}
|
| 670 |
+
sx={styles.transportBtn(color, playing)}
|
| 671 |
+
size="small"
|
| 672 |
+
>
|
| 673 |
+
{playing ? <StopIcon size={16} /> : <PlayIcon size={16} />}
|
| 674 |
+
</IconButton>
|
| 675 |
+
</MidiMappable>
|
| 676 |
+
<MidiMappable id={ctrlId('loop')} label={ctrlLabel('Loop')} kind="trigger" onChange={handleLoopToggle}>
|
| 677 |
+
<Tooltip
|
| 678 |
+
title={TIPS.channel.loop(looping, durationMode)}
|
| 679 |
+
placement="top"
|
| 680 |
+
enterDelay={400}
|
| 681 |
+
>
|
| 682 |
+
<IconButton
|
| 683 |
+
onClick={handleLoopToggle}
|
| 684 |
+
sx={styles.loopBtn(color, looping)}
|
| 685 |
+
size="small"
|
| 686 |
+
aria-label={looping ? 'Loop on' : 'Loop off'}
|
| 687 |
+
>
|
| 688 |
+
L
|
| 689 |
+
</IconButton>
|
| 690 |
+
</Tooltip>
|
| 691 |
+
</MidiMappable>
|
| 692 |
+
</Box>
|
| 693 |
<Box sx={styles.muteSoloRow}>
|
| 694 |
<MidiMappable id={ctrlId('mute')} label={ctrlLabel('Mute')} kind="trigger" onChange={handleMuteToggle}>
|
| 695 |
+
<Tooltip title={TIPS.channel.mute}>
|
| 696 |
<IconButton size="small" onClick={handleMuteToggle} sx={styles.muteBtn(muted)}>M</IconButton>
|
| 697 |
</Tooltip>
|
| 698 |
</MidiMappable>
|
| 699 |
<MidiMappable id={ctrlId('solo')} label={ctrlLabel('Solo')} kind="trigger" onChange={handleSoloToggle}>
|
| 700 |
+
<Tooltip title={TIPS.channel.solo}>
|
| 701 |
<IconButton size="small" onClick={handleSoloToggle} sx={styles.soloBtn(soloed)}>S</IconButton>
|
| 702 |
</Tooltip>
|
| 703 |
</MidiMappable>
|
|
|
|
| 706 |
|
| 707 |
<Box sx={styles.promptBox}>
|
| 708 |
<TextField
|
| 709 |
+
placeholder="Promptβ¦"
|
| 710 |
value={prompt}
|
| 711 |
onChange={(e) => setPrompt(e.target.value)}
|
| 712 |
multiline
|
| 713 |
minRows={2}
|
| 714 |
+
maxRows={2}
|
| 715 |
size="small"
|
| 716 |
fullWidth
|
| 717 |
sx={styles.promptField}
|
| 718 |
disabled={generating}
|
| 719 |
/>
|
| 720 |
<Box sx={{ ...styles.durationRow, minHeight: 26, height: 26 }}>
|
| 721 |
+
{durationMode === 'seconds' ? (
|
| 722 |
+
<>
|
| 723 |
+
<Typography sx={styles.durationLabel}>{duration.toFixed(0)}s</Typography>
|
| 724 |
+
<Slider
|
| 725 |
+
value={duration}
|
| 726 |
+
onChange={(_, v) => setDuration(v)}
|
| 727 |
+
min={2}
|
| 728 |
+
max={maxDuration}
|
| 729 |
+
step={1}
|
| 730 |
+
size="small"
|
| 731 |
+
sx={styles.durationSlider(color)}
|
| 732 |
+
/>
|
| 733 |
+
</>
|
| 734 |
+
) : (
|
| 735 |
+
<Select
|
| 736 |
+
value={availableBars.includes(bars) ? bars : availableBars[availableBars.length - 1]}
|
| 737 |
+
onChange={(e) => setBars(Number(e.target.value))}
|
| 738 |
+
size="small"
|
| 739 |
+
sx={{ ...panelStyles.pillControl, flex: 1 }}
|
| 740 |
+
>
|
| 741 |
+
{availableBars.map(b => (
|
| 742 |
+
<MenuItem key={b} value={b} sx={{ fontSize: perfTokens.fontSize.sm }}>
|
| 743 |
+
{b} {b === 1 ? 'bar' : 'bars'}
|
| 744 |
+
</MenuItem>
|
| 745 |
+
))}
|
| 746 |
+
</Select>
|
| 747 |
+
)}
|
| 748 |
+
|
| 749 |
+
{/* Sec/Bars mode toggle β moved to the right of the row so
|
| 750 |
+
it mirrors the Generate row layout (content fills left,
|
| 751 |
+
modifier sits right). Width matches the ΓN selector so
|
| 752 |
+
the right column reads as a uniform stack. */}
|
| 753 |
<Box
|
| 754 |
sx={{
|
| 755 |
display: 'inline-flex',
|
|
|
|
| 758 |
borderRadius: 0.75,
|
| 759 |
overflow: 'hidden',
|
| 760 |
height: '100%',
|
| 761 |
+
flexShrink: 0,
|
| 762 |
}}
|
| 763 |
>
|
| 764 |
+
{[
|
| 765 |
+
{ mode: 'sec', label: 'Sec' },
|
| 766 |
+
{ mode: 'bars', label: 'Bars' },
|
| 767 |
+
].map(({ mode, label }) => {
|
| 768 |
const value = mode === 'sec' ? 'seconds' : 'bars';
|
| 769 |
const active = durationMode === value;
|
| 770 |
return (
|
|
|
|
| 772 |
key={mode}
|
| 773 |
onClick={() => setDurationMode(value)}
|
| 774 |
sx={{
|
| 775 |
+
fontSize: perfTokens.fontSize.sm,
|
|
|
|
|
|
|
|
|
|
| 776 |
px: 0.7,
|
| 777 |
+
minWidth: 36,
|
| 778 |
bgcolor: active ? color : 'transparent',
|
| 779 |
+
backgroundImage: active ? SHEEN_DARK : 'none',
|
| 780 |
+
boxShadow: active ? RAISE_DARK : 'none',
|
| 781 |
color: active ? 'rgba(0,0,0,0.88)' : 'text.disabled',
|
| 782 |
+
fontWeight: active ? perfTokens.weight.bold : perfTokens.weight.regular,
|
| 783 |
+
transition: 'background-color 120ms, color 120ms, box-shadow 120ms',
|
| 784 |
'&:hover': {
|
| 785 |
bgcolor: active ? color : 'action.hover',
|
| 786 |
color: active ? 'rgba(0,0,0,0.88)' : 'text.secondary',
|
| 787 |
},
|
| 788 |
}}
|
| 789 |
>
|
| 790 |
+
{label}
|
| 791 |
</ButtonBase>
|
| 792 |
);
|
| 793 |
})}
|
| 794 |
</Box>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
</Box>
|
| 796 |
<Box sx={{
|
| 797 |
display: 'flex',
|
| 798 |
alignItems: 'center',
|
| 799 |
+
gap: 1,
|
|
|
|
| 800 |
mt: 0.5,
|
| 801 |
width: '100%',
|
| 802 |
}}>
|
| 803 |
+
{/* Generate pill β wide CTA on the left so the eye lands
|
| 804 |
+
on the primary action first. Fills left-to-right with
|
| 805 |
+
live progress while generating; resets when complete. */}
|
| 806 |
+
<MidiMappable id={ctrlId('generate')} label={ctrlLabel('Generate')} kind="trigger" onChange={handleGenerate}>
|
| 807 |
+
<Tooltip
|
| 808 |
+
title={TIPS.channel.generateDisabled(generating, canGenerate, prompt.trim())}
|
| 809 |
+
placement="top"
|
| 810 |
+
>
|
| 811 |
+
<span style={{ display: 'inline-flex', flex: 1, minWidth: 0 }}>
|
| 812 |
+
<ButtonBase
|
| 813 |
+
onClick={handleGenerate}
|
| 814 |
+
disabled={!canGenerate || !prompt.trim() || generating}
|
| 815 |
+
sx={styles.generatePill(color, {
|
| 816 |
+
generating,
|
| 817 |
+
disabled: !canGenerate || !prompt.trim(),
|
| 818 |
+
})}
|
| 819 |
+
>
|
| 820 |
+
{generating && (
|
| 821 |
+
<Box sx={styles.generatePillFill(color, progress)} />
|
| 822 |
+
)}
|
| 823 |
+
<Box component="span" sx={styles.generatePillLabel}>
|
| 824 |
+
{generating
|
| 825 |
+
? `Generating Β· ${Math.round(progress)}%`
|
| 826 |
+
: 'Generate'}
|
| 827 |
+
{!generating && <GenerateArrowIcon size={14} strokeWidth={2.25} />}
|
| 828 |
+
</Box>
|
| 829 |
+
</ButtonBase>
|
| 830 |
+
</span>
|
| 831 |
+
</Tooltip>
|
| 832 |
+
</MidiMappable>
|
| 833 |
+
|
| 834 |
+
{/* Batch selector β sits right of Generate so the row
|
| 835 |
+
reads "Generate Γ 4" (action then modifier). Sized to
|
| 836 |
+
its content (Γ1οΏ½οΏ½Γ8 + dropdown arrow); no need to match
|
| 837 |
+
the wider Sec/Bars toggle above. */}
|
| 838 |
<Tooltip
|
| 839 |
+
title={TIPS.channel.batch}
|
| 840 |
placement="top"
|
|
|
|
|
|
|
| 841 |
enterDelay={500}
|
| 842 |
>
|
| 843 |
<Select
|
| 844 |
value={batchSize}
|
| 845 |
onChange={(e) => setBatchSize(Number(e.target.value))}
|
|
|
|
| 846 |
disabled={generating}
|
| 847 |
+
size="small"
|
| 848 |
+
sx={{ ...styles.channelPillControl, width: 54, flexShrink: 0 }}
|
| 849 |
+
renderValue={(v) => `Γ${v}`}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 850 |
>
|
| 851 |
+
{BATCH_OPTIONS.map((n) => (
|
| 852 |
+
<MenuItem
|
| 853 |
+
key={n}
|
| 854 |
+
value={n}
|
| 855 |
+
sx={{ fontSize: perfTokens.fontSize.sm, fontVariantNumeric: 'tabular-nums' }}
|
| 856 |
+
>
|
| 857 |
Γ{n}
|
| 858 |
</MenuItem>
|
| 859 |
))}
|
| 860 |
</Select>
|
| 861 |
</Tooltip>
|
| 862 |
+
|
| 863 |
+
{/* Variation β re-roll from the current fragment as
|
| 864 |
+
init_audio (Phase 8). Disabled until a fragment exists. */}
|
| 865 |
+
<MidiMappable id={ctrlId('variation')} label={ctrlLabel('Variation')} kind="trigger" onChange={handleVariation}>
|
| 866 |
+
<Tooltip
|
| 867 |
+
title={TIPS.channel.variation(loaded)}
|
| 868 |
+
placement="top"
|
| 869 |
>
|
| 870 |
+
<span style={{ display: 'inline-flex', flexShrink: 0 }}>
|
| 871 |
+
<ButtonBase
|
| 872 |
+
onClick={handleVariation}
|
| 873 |
+
disabled={!loaded || generating}
|
| 874 |
+
sx={{
|
| 875 |
+
...styles.channelPillControl,
|
| 876 |
+
width: 40,
|
| 877 |
+
justifyContent: 'center',
|
| 878 |
+
'&.Mui-disabled': { opacity: 0.4, color: 'text.disabled' },
|
| 879 |
+
}}
|
| 880 |
+
aria-label="Variation"
|
| 881 |
+
>
|
| 882 |
+
<VariationIcon size={15} strokeWidth={2.25} />
|
| 883 |
+
</ButtonBase>
|
| 884 |
+
</span>
|
| 885 |
+
</Tooltip>
|
| 886 |
</MidiMappable>
|
| 887 |
</Box>
|
| 888 |
</Box>
|
| 889 |
|
| 890 |
+
<Box
|
| 891 |
+
onDragEnter={handleWaveDragEnter}
|
| 892 |
+
onDragOver={handleWaveDragOver}
|
| 893 |
+
onDragLeave={handleWaveDragLeave}
|
| 894 |
+
onDrop={handleWaveDrop}
|
| 895 |
+
sx={[
|
| 896 |
+
styles.waveformWrap,
|
| 897 |
+
dropActive && {
|
| 898 |
+
borderColor: color,
|
| 899 |
+
boxShadow: `inset 0 0 0 2px ${color}`,
|
| 900 |
+
backgroundColor: `${color}1F`,
|
| 901 |
+
transition: 'border-color 120ms, box-shadow 120ms, background-color 120ms',
|
| 902 |
+
},
|
| 903 |
+
]}
|
| 904 |
+
>
|
| 905 |
<canvas
|
| 906 |
ref={canvasRef}
|
| 907 |
width={140}
|
| 908 |
height={42}
|
| 909 |
+
style={{ width: '100%', height: 42, display: 'block', pointerEvents: 'none' }}
|
| 910 |
/>
|
| 911 |
{!loaded && (
|
| 912 |
+
<Typography sx={styles.waveformPlaceholder}>
|
| 913 |
+
{dropActive ? 'Drop to load' : 'Waveform'}
|
| 914 |
</Typography>
|
| 915 |
)}
|
| 916 |
</Box>
|
| 917 |
|
| 918 |
+
{/* Per-channel rolling fragment history. Always rendered (empty
|
| 919 |
+
state included). Star/keep, delete, audition, load β all
|
| 920 |
+
inline per row. Capped at FRAGMENT_CAP via FIFO with star
|
| 921 |
+
priority. */}
|
| 922 |
+
<ChannelFragmentHistory
|
| 923 |
+
fragments={fragments}
|
| 924 |
+
color={color}
|
| 925 |
+
channelIndex={index}
|
| 926 |
+
auditioningId={auditioningFragmentId}
|
| 927 |
+
committedId={committedFragmentId}
|
| 928 |
+
maxFragments={FRAGMENT_CAP}
|
| 929 |
+
onAudition={handleAuditionFragment}
|
| 930 |
+
onCommit={handleCommitFragment}
|
| 931 |
+
onToggleStar={handleToggleStar}
|
| 932 |
+
onDelete={handleDeleteFragment}
|
| 933 |
+
onClearAll={handleClearFragments}
|
| 934 |
+
/>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 935 |
|
| 936 |
<Box sx={{ px: 1, py: 1 }}>
|
| 937 |
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1, mb: 1 }}>
|
| 938 |
+
<Box component="span" sx={{ ...perfTokens.caps, color: 'text.secondary', minWidth: 28 }}>PAN</Box>
|
| 939 |
<MidiMappable
|
| 940 |
id={ctrlId('pan')}
|
| 941 |
label={ctrlLabel('Pan')}
|
|
|
|
| 957 |
marks={[{ value: 0 }]}
|
| 958 |
sx={{
|
| 959 |
flex: 1,
|
| 960 |
+
// Match the channel main color (the global
|
| 961 |
+
// MuiSlider override is amber; the vertical
|
| 962 |
+
// knobs already pass `color` via knobSlider).
|
| 963 |
+
color,
|
| 964 |
'& .MuiSlider-mark': {
|
| 965 |
width: 2,
|
| 966 |
height: 10,
|
|
|
|
| 981 |
<Box sx={styles.knobsGrid}>
|
| 982 |
{KNOB_DEFS.map((k) => {
|
| 983 |
const isLog = k.scale === 'log';
|
| 984 |
+
const isBipolar = k.scale === 'bipolar';
|
| 985 |
// For log knobs, the slider drives a 0..1 position and we
|
| 986 |
// convert to/from the underlying value (Hz) on the audio
|
| 987 |
// boundary. The knob value stored in state stays in the
|
|
|
|
| 1013 |
max={isLog ? 1 : k.max}
|
| 1014 |
step={isLog ? 0.001 : k.step}
|
| 1015 |
size="small"
|
| 1016 |
+
track={isBipolar ? false : undefined}
|
| 1017 |
+
marks={isBipolar ? [{ value: 0 }] : undefined}
|
| 1018 |
+
sx={{
|
| 1019 |
+
...styles.knobSlider(color, k.key === 'gain'),
|
| 1020 |
+
...(isBipolar && {
|
| 1021 |
+
'& .MuiSlider-mark': {
|
| 1022 |
+
width: 10,
|
| 1023 |
+
height: 2,
|
| 1024 |
+
borderRadius: 1,
|
| 1025 |
+
backgroundColor: 'text.secondary',
|
| 1026 |
+
opacity: 0.7,
|
| 1027 |
+
},
|
| 1028 |
+
'& .MuiSlider-markActive': {
|
| 1029 |
+
backgroundColor: 'text.secondary',
|
| 1030 |
+
opacity: 0.7,
|
| 1031 |
+
},
|
| 1032 |
+
}),
|
| 1033 |
+
}}
|
| 1034 |
/>
|
| 1035 |
</MidiMappable>
|
| 1036 |
<Box component="span" sx={styles.knobLabel}>{k.label}</Box>
|
|
|
|
| 1039 |
})}
|
| 1040 |
</Box>
|
| 1041 |
|
| 1042 |
+
{/* Bottom row is now just the channel level meter β Play and Loop
|
| 1043 |
+
moved to the top header so the channel reads "controls on top,
|
| 1044 |
+
signal flow below". */}
|
| 1045 |
<Box sx={styles.transportRow}>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
<Box sx={styles.meterTrack}>
|
| 1047 |
<Box ref={meterRef} sx={styles.meterFill(color)} />
|
| 1048 |
</Box>
|
app/frontend/src/components/PerformancePanel.js
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/frontend/src/components/StorageDrilldown.js
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Dialog,
|
| 4 |
+
DialogTitle,
|
| 5 |
+
DialogContent,
|
| 6 |
+
DialogActions,
|
| 7 |
+
Button,
|
| 8 |
+
Typography,
|
| 9 |
+
Box,
|
| 10 |
+
LinearProgress,
|
| 11 |
+
Table,
|
| 12 |
+
TableBody,
|
| 13 |
+
TableCell,
|
| 14 |
+
TableHead,
|
| 15 |
+
TableRow,
|
| 16 |
+
} from '@mui/material';
|
| 17 |
+
|
| 18 |
+
const fmtBytes = (n) => {
|
| 19 |
+
if (!n) return 'β';
|
| 20 |
+
// Decimal (SI) units β matches what HuggingFace shows next to safetensors files.
|
| 21 |
+
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
|
| 22 |
+
let v = n;
|
| 23 |
+
let u = 0;
|
| 24 |
+
while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; }
|
| 25 |
+
return `${v.toFixed(v < 10 ? 2 : 1)} ${units[u]}`;
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
export default function StorageDrilldown({ open, onClose, storage, catalog }) {
|
| 29 |
+
if (!storage) return null;
|
| 30 |
+
|
| 31 |
+
const usedPct = storage.total_used_bytes && (storage.total_used_bytes + storage.total_free_bytes)
|
| 32 |
+
? (storage.total_used_bytes / (storage.total_used_bytes + storage.total_free_bytes)) * 100
|
| 33 |
+
: 0;
|
| 34 |
+
|
| 35 |
+
const nameFor = (id) => catalog?.find(c => c.id === id)?.name || id;
|
| 36 |
+
|
| 37 |
+
const rows = (storage.per_model || [])
|
| 38 |
+
.filter(m => m.downloaded)
|
| 39 |
+
.sort((a, b) => b.bytes - a.bytes);
|
| 40 |
+
|
| 41 |
+
return (
|
| 42 |
+
<Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
|
| 43 |
+
<DialogTitle>Storage</DialogTitle>
|
| 44 |
+
<DialogContent dividers>
|
| 45 |
+
<Box sx={{ mb: 2 }}>
|
| 46 |
+
<Typography variant="body2" color="text.secondary">
|
| 47 |
+
{fmtBytes(storage.total_used_bytes)} used Β· {fmtBytes(storage.total_free_bytes)} free
|
| 48 |
+
</Typography>
|
| 49 |
+
<LinearProgress
|
| 50 |
+
variant="determinate"
|
| 51 |
+
value={Math.min(100, usedPct)}
|
| 52 |
+
sx={{ mt: 1, height: 6, borderRadius: 3 }}
|
| 53 |
+
/>
|
| 54 |
+
</Box>
|
| 55 |
+
|
| 56 |
+
{rows.length === 0 ? (
|
| 57 |
+
<Typography variant="body2" color="text.secondary" sx={{ py: 2 }}>
|
| 58 |
+
Nothing downloaded yet.
|
| 59 |
+
</Typography>
|
| 60 |
+
) : (
|
| 61 |
+
<Table size="small">
|
| 62 |
+
<TableHead>
|
| 63 |
+
<TableRow>
|
| 64 |
+
<TableCell>Checkpoint</TableCell>
|
| 65 |
+
<TableCell align="right">Size</TableCell>
|
| 66 |
+
</TableRow>
|
| 67 |
+
</TableHead>
|
| 68 |
+
<TableBody>
|
| 69 |
+
{rows.map(m => (
|
| 70 |
+
<TableRow key={m.id}>
|
| 71 |
+
<TableCell>{nameFor(m.id)}</TableCell>
|
| 72 |
+
<TableCell align="right">{fmtBytes(m.bytes)}</TableCell>
|
| 73 |
+
</TableRow>
|
| 74 |
+
))}
|
| 75 |
+
</TableBody>
|
| 76 |
+
</Table>
|
| 77 |
+
)}
|
| 78 |
+
</DialogContent>
|
| 79 |
+
<DialogActions>
|
| 80 |
+
<Button onClick={onClose}>Close</Button>
|
| 81 |
+
</DialogActions>
|
| 82 |
+
</Dialog>
|
| 83 |
+
);
|
| 84 |
+
}
|
app/frontend/src/components/Tooltip.js
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { cloneElement } from 'react';
|
| 2 |
+
import { useInfoView } from './InfoView';
|
| 3 |
+
|
| 4 |
+
/**
|
| 5 |
+
* App-wide Tooltip.
|
| 6 |
+
*
|
| 7 |
+
* Help text is shown exclusively through the Info View (see
|
| 8 |
+
* components/InfoView.js) β there are no popup tooltips on the controls:
|
| 9 |
+
*
|
| 10 |
+
* β’ Info View ON β on hover/focus the `title` is reported to the bottom
|
| 11 |
+
* Info View pill, so help shows in one fixed place rather than over the
|
| 12 |
+
* control itself.
|
| 13 |
+
* β’ Info View OFF β no hover help at all; the child renders untouched.
|
| 14 |
+
*
|
| 15 |
+
* The API matches MUI's Tooltip (drop-in for the existing call sites): pass a
|
| 16 |
+
* `title` plus a single child element. Placement/arrow/delay props are
|
| 17 |
+
* accepted but ignored, since there's no popup.
|
| 18 |
+
*/
|
| 19 |
+
export default function Tooltip({ children, title }) {
|
| 20 |
+
const { enabled, setHint } = useInfoView();
|
| 21 |
+
const child = React.Children.only(children);
|
| 22 |
+
|
| 23 |
+
if (!enabled) {
|
| 24 |
+
// No popup tooltips β the Info View is the only help surface.
|
| 25 |
+
return child;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// Info View mode: route the tip to the bottom pill on hover/focus.
|
| 29 |
+
return cloneElement(child, {
|
| 30 |
+
onMouseEnter: (e) => { setHint(title); child.props?.onMouseEnter?.(e); },
|
| 31 |
+
onMouseLeave: (e) => { setHint(null); child.props?.onMouseLeave?.(e); },
|
| 32 |
+
onFocus: (e) => { setHint(title); child.props?.onFocus?.(e); },
|
| 33 |
+
onBlur: (e) => { setHint(null); child.props?.onBlur?.(e); },
|
| 34 |
+
});
|
| 35 |
+
}
|
app/frontend/src/components/TrainingMonitor.js
CHANGED
|
@@ -1,14 +1,26 @@
|
|
| 1 |
-
import React from 'react';
|
| 2 |
import { Paper, Box, Typography, LinearProgress, Grid, Alert } from '@mui/material';
|
| 3 |
import { Activity as ActivityIcon } from 'lucide-react';
|
| 4 |
import LossChart from './LossChart';
|
| 5 |
import { trainingMonitorStyles } from '../theme';
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
export default function TrainingMonitor({
|
| 8 |
trainingProgress,
|
| 9 |
trainingStatus,
|
|
|
|
| 10 |
trainingError,
|
| 11 |
-
trainingConfig,
|
| 12 |
indicatorState,
|
| 13 |
}) {
|
| 14 |
const getProgressColor = () => {
|
|
@@ -21,6 +33,39 @@ export default function TrainingMonitor({
|
|
| 21 |
const label = indicatorState?.label || 'Idle';
|
| 22 |
const animate = indicatorState?.animate || false;
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
return (
|
| 25 |
<Paper sx={trainingMonitorStyles.rootPaper}>
|
| 26 |
<Box sx={trainingMonitorStyles.headerRow}>
|
|
@@ -53,60 +98,56 @@ export default function TrainingMonitor({
|
|
| 53 |
/>
|
| 54 |
</Box>
|
| 55 |
|
| 56 |
-
{trainingStatus?.device_info && (
|
| 57 |
-
<Box sx={trainingMonitorStyles.deviceSection}>
|
| 58 |
-
<Typography variant="body2" color="textSecondary">
|
| 59 |
-
<strong>{
|
| 60 |
-
trainingStatus.device_info.type === 'cuda' ? 'CUDA' :
|
| 61 |
-
trainingStatus.device_info.type === 'mps' ? 'MPS' : 'CPU'
|
| 62 |
-
}</strong>
|
| 63 |
-
{' Β· '}{trainingStatus.device_info.device}
|
| 64 |
-
{trainingStatus.device_info.memory_gb
|
| 65 |
-
? ` Β· ${trainingStatus.device_info.memory_gb.toFixed(1)} GB`
|
| 66 |
-
: ''}
|
| 67 |
-
</Typography>
|
| 68 |
-
</Box>
|
| 69 |
-
)}
|
| 70 |
-
|
| 71 |
<Grid container spacing={2} sx={trainingMonitorStyles.metricsGrid}>
|
| 72 |
<Grid item xs={12} sm={6}>
|
| 73 |
-
<Typography variant="body2" color="textSecondary">
|
| 74 |
-
<Typography variant="body1">
|
| 75 |
-
{
|
| 76 |
-
`${trainingStatus.current_epoch + 1} / ${trainingConfig.epochs}` :
|
| 77 |
-
'0 / ' + trainingConfig.epochs}
|
| 78 |
</Typography>
|
| 79 |
</Grid>
|
| 80 |
<Grid item xs={12} sm={6}>
|
| 81 |
-
<Typography variant="body2" color="textSecondary">
|
| 82 |
-
<Typography variant="body1"
|
| 83 |
-
{
|
| 84 |
-
`${trainingStatus.global_step} / ${trainingStatus.total_steps}` :
|
| 85 |
-
'N/A'}
|
| 86 |
</Typography>
|
| 87 |
</Grid>
|
| 88 |
<Grid item xs={12} sm={6}>
|
| 89 |
<Typography variant="body2" color="textSecondary">Checkpoints Saved</Typography>
|
| 90 |
-
<Typography variant="body1">
|
| 91 |
-
{trainingStatus?.checkpoints_saved || 0}
|
| 92 |
-
</Typography>
|
| 93 |
</Grid>
|
| 94 |
<Grid item xs={12} sm={6}>
|
| 95 |
-
<Typography variant="body2" color="textSecondary">
|
| 96 |
-
<Typography variant="body1">
|
| 97 |
-
{trainingStatus?.
|
| 98 |
</Typography>
|
| 99 |
</Grid>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
</Grid>
|
| 101 |
|
| 102 |
-
{
|
| 103 |
<Box sx={trainingMonitorStyles.lossSection}>
|
| 104 |
<Typography variant="body2" color="textSecondary" gutterBottom>
|
| 105 |
<strong>Loss History</strong>
|
| 106 |
</Typography>
|
| 107 |
<Box sx={trainingMonitorStyles.lossChartBox}>
|
| 108 |
-
<LossChart data={
|
| 109 |
</Box>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
</Box>
|
| 111 |
)}
|
| 112 |
|
|
|
|
| 1 |
+
import React, { useMemo } from 'react';
|
| 2 |
import { Paper, Box, Typography, LinearProgress, Grid, Alert } from '@mui/material';
|
| 3 |
import { Activity as ActivityIcon } from 'lucide-react';
|
| 4 |
import LossChart from './LossChart';
|
| 5 |
import { trainingMonitorStyles } from '../theme';
|
| 6 |
|
| 7 |
+
/**
|
| 8 |
+
* TrainingMonitor β right-pane status card for the Training tab.
|
| 9 |
+
*
|
| 10 |
+
* Reads SA3-shaped status from the backend:
|
| 11 |
+
* { is_training, status, step, total_steps, current_step, progress,
|
| 12 |
+
* loss, checkpoints, checkpoints_saved, error, ... }
|
| 13 |
+
*
|
| 14 |
+
* SA3 trains by step count, not epochs β the panel surfaces step / total
|
| 15 |
+
* directly. Loss curve is built frontend-side from successive poll snapshots
|
| 16 |
+
* (trainingHistory) so we don't depend on the backend emitting a history
|
| 17 |
+
* array.
|
| 18 |
+
*/
|
| 19 |
export default function TrainingMonitor({
|
| 20 |
trainingProgress,
|
| 21 |
trainingStatus,
|
| 22 |
+
trainingHistory,
|
| 23 |
trainingError,
|
|
|
|
| 24 |
indicatorState,
|
| 25 |
}) {
|
| 26 |
const getProgressColor = () => {
|
|
|
|
| 33 |
const label = indicatorState?.label || 'Idle';
|
| 34 |
const animate = indicatorState?.animate || false;
|
| 35 |
|
| 36 |
+
// Loss points for the chart. We prefer the backend's loss_history
|
| 37 |
+
// (built from Lightning's metrics.csv, which records per-step loss
|
| 38 |
+
// from step 0) so the chart shows the full curve even before PL's
|
| 39 |
+
// tqdm postfix surfaces train/loss (which only appears after the
|
| 40 |
+
// first metrics flush, typically end of epoch 0). Falls back to the
|
| 41 |
+
// frontend-built trainingHistory if the backend hasn't populated
|
| 42 |
+
// loss_history yet (very early in the run, before PL writes CSV).
|
| 43 |
+
const lossPoints = useMemo(() => {
|
| 44 |
+
const fromBackend = trainingStatus?.loss_history;
|
| 45 |
+
if (Array.isArray(fromBackend) && fromBackend.length > 0) {
|
| 46 |
+
return fromBackend
|
| 47 |
+
.filter(p => Number.isFinite(p?.step) && Number.isFinite(p?.loss))
|
| 48 |
+
.sort((a, b) => a.step - b.step);
|
| 49 |
+
}
|
| 50 |
+
if (!trainingHistory || trainingHistory.length === 0) return [];
|
| 51 |
+
const byStep = new Map();
|
| 52 |
+
for (const h of trainingHistory) {
|
| 53 |
+
const step = h.current_step ?? h.step;
|
| 54 |
+
const loss = typeof h.loss === 'number' ? h.loss : parseFloat(h.loss);
|
| 55 |
+
if (Number.isFinite(step) && Number.isFinite(loss)) {
|
| 56 |
+
byStep.set(step, { step, loss });
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
return Array.from(byStep.values()).sort((a, b) => a.step - b.step);
|
| 60 |
+
}, [trainingHistory, trainingStatus?.loss_history]);
|
| 61 |
+
|
| 62 |
+
const step = trainingStatus?.current_step ?? trainingStatus?.step ?? 0;
|
| 63 |
+
const totalSteps = trainingStatus?.total_steps ?? 0;
|
| 64 |
+
const checkpointsSaved = trainingStatus?.checkpoints_saved
|
| 65 |
+
?? trainingStatus?.checkpoints?.length
|
| 66 |
+
?? 0;
|
| 67 |
+
const currentLoss = trainingStatus?.loss;
|
| 68 |
+
|
| 69 |
return (
|
| 70 |
<Paper sx={trainingMonitorStyles.rootPaper}>
|
| 71 |
<Box sx={trainingMonitorStyles.headerRow}>
|
|
|
|
| 98 |
/>
|
| 99 |
</Box>
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
<Grid container spacing={2} sx={trainingMonitorStyles.metricsGrid}>
|
| 102 |
<Grid item xs={12} sm={6}>
|
| 103 |
+
<Typography variant="body2" color="textSecondary">Step</Typography>
|
| 104 |
+
<Typography variant="body1" color="primary">
|
| 105 |
+
{totalSteps > 0 ? `${step} / ${totalSteps}` : `${step}`}
|
|
|
|
|
|
|
| 106 |
</Typography>
|
| 107 |
</Grid>
|
| 108 |
<Grid item xs={12} sm={6}>
|
| 109 |
+
<Typography variant="body2" color="textSecondary">Current Loss</Typography>
|
| 110 |
+
<Typography variant="body1">
|
| 111 |
+
{Number.isFinite(currentLoss) ? parseFloat(currentLoss).toFixed(4) : 'N/A'}
|
|
|
|
|
|
|
| 112 |
</Typography>
|
| 113 |
</Grid>
|
| 114 |
<Grid item xs={12} sm={6}>
|
| 115 |
<Typography variant="body2" color="textSecondary">Checkpoints Saved</Typography>
|
| 116 |
+
<Typography variant="body1">{checkpointsSaved}</Typography>
|
|
|
|
|
|
|
| 117 |
</Grid>
|
| 118 |
<Grid item xs={12} sm={6}>
|
| 119 |
+
<Typography variant="body2" color="textSecondary">Phase</Typography>
|
| 120 |
+
<Typography variant="body1" sx={{ textTransform: 'capitalize' }}>
|
| 121 |
+
{trainingStatus?.status || 'idle'}
|
| 122 |
</Typography>
|
| 123 |
</Grid>
|
| 124 |
+
{Number.isFinite(trainingStatus?.seed) && (
|
| 125 |
+
<Grid item xs={12} sm={6}>
|
| 126 |
+
<Typography variant="body2" color="textSecondary">Seed</Typography>
|
| 127 |
+
<Typography variant="body1" sx={{ fontVariantNumeric: 'tabular-nums' }}>
|
| 128 |
+
{trainingStatus.seed}
|
| 129 |
+
</Typography>
|
| 130 |
+
</Grid>
|
| 131 |
+
)}
|
| 132 |
</Grid>
|
| 133 |
|
| 134 |
+
{lossPoints.length > 1 && (
|
| 135 |
<Box sx={trainingMonitorStyles.lossSection}>
|
| 136 |
<Typography variant="body2" color="textSecondary" gutterBottom>
|
| 137 |
<strong>Loss History</strong>
|
| 138 |
</Typography>
|
| 139 |
<Box sx={trainingMonitorStyles.lossChartBox}>
|
| 140 |
+
<LossChart data={lossPoints} />
|
| 141 |
</Box>
|
| 142 |
+
<Typography
|
| 143 |
+
variant="caption"
|
| 144 |
+
color="textSecondary"
|
| 145 |
+
sx={trainingMonitorStyles.lossDisclaimer}
|
| 146 |
+
>
|
| 147 |
+
LoRA diffusion loss is noisy by design β each step samples
|
| 148 |
+
a random noise level. Judge the result with your ears, not
|
| 149 |
+
only with this chart.
|
| 150 |
+
</Typography>
|
| 151 |
</Box>
|
| 152 |
)}
|
| 153 |
|
app/frontend/src/components/WelcomePage.js
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import React, { useState, useEffect } from 'react';
|
| 2 |
-
import { Backdrop, Box, Fade, Typography, Button, Checkbox, FormControlLabel } from '@mui/material';
|
| 3 |
import { welcomePageStyles } from '../theme';
|
|
|
|
| 4 |
|
| 5 |
export default function WelcomePage({ open, onClose }) {
|
| 6 |
const [titleVisible, setTitleVisible] = useState(false);
|
|
@@ -48,53 +49,41 @@ export default function WelcomePage({ open, onClose }) {
|
|
| 48 |
</Fade>
|
| 49 |
|
| 50 |
<Fade in={textVisible} timeout={1000}>
|
| 51 |
-
<
|
| 52 |
-
<
|
| 53 |
-
variant="
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
| 58 |
|
| 59 |
-
|
| 60 |
-
<
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
>
|
| 70 |
-
Version 0.1.1
|
| 71 |
-
</Typography>
|
| 72 |
-
<Button
|
| 73 |
-
variant="contained"
|
| 74 |
-
onClick={() => onClose(dontShowAgain)}
|
| 75 |
-
sx={welcomePageStyles.ctaButton}
|
| 76 |
-
>
|
| 77 |
-
Get Started
|
| 78 |
-
</Button>
|
| 79 |
-
<Box sx={{ mt: 1.5 }}>
|
| 80 |
<FormControlLabel
|
| 81 |
control={
|
| 82 |
<Checkbox
|
| 83 |
checked={dontShowAgain}
|
| 84 |
onChange={(e) => setDontShowAgain(e.target.checked)}
|
| 85 |
size="small"
|
| 86 |
-
sx={{ color: 'text.secondary' }}
|
| 87 |
/>
|
| 88 |
}
|
| 89 |
label={
|
| 90 |
-
<Typography variant="caption"
|
| 91 |
Don't show this again
|
| 92 |
</Typography>
|
| 93 |
}
|
| 94 |
/>
|
| 95 |
</Box>
|
| 96 |
-
|
| 97 |
-
</Box>
|
| 98 |
</Fade>
|
| 99 |
</Box>
|
| 100 |
</Backdrop>
|
|
|
|
| 1 |
import React, { useState, useEffect } from 'react';
|
| 2 |
+
import { Backdrop, Box, Fade, Typography, Button, Checkbox, FormControlLabel, Stack } from '@mui/material';
|
| 3 |
import { welcomePageStyles } from '../theme';
|
| 4 |
+
import { APP_VERSION } from '../version';
|
| 5 |
|
| 6 |
export default function WelcomePage({ open, onClose }) {
|
| 7 |
const [titleVisible, setTitleVisible] = useState(false);
|
|
|
|
| 49 |
</Fade>
|
| 50 |
|
| 51 |
<Fade in={textVisible} timeout={1000}>
|
| 52 |
+
<Stack alignItems="center">
|
| 53 |
+
<Stack alignItems="center">
|
| 54 |
+
<Typography variant="body3" color="text.secondary">
|
| 55 |
+
Β©2025-2026 Misagh Azimi
|
| 56 |
+
</Typography>
|
| 57 |
+
<Typography variant="body3" color="text.secondary">
|
| 58 |
+
Version {APP_VERSION}
|
| 59 |
+
</Typography>
|
| 60 |
|
| 61 |
+
</Stack>
|
| 62 |
+
<Box mt={5}>
|
| 63 |
+
<Button
|
| 64 |
+
variant="contained"
|
| 65 |
+
onClick={() => onClose(dontShowAgain)}
|
| 66 |
+
>
|
| 67 |
+
Get Started
|
| 68 |
+
</Button>
|
| 69 |
+
</Box>
|
| 70 |
+
<Box mt={6}>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
<FormControlLabel
|
| 72 |
control={
|
| 73 |
<Checkbox
|
| 74 |
checked={dontShowAgain}
|
| 75 |
onChange={(e) => setDontShowAgain(e.target.checked)}
|
| 76 |
size="small"
|
|
|
|
| 77 |
/>
|
| 78 |
}
|
| 79 |
label={
|
| 80 |
+
<Typography variant="caption" color="text.secondary">
|
| 81 |
Don't show this again
|
| 82 |
</Typography>
|
| 83 |
}
|
| 84 |
/>
|
| 85 |
</Box>
|
| 86 |
+
</Stack>
|
|
|
|
| 87 |
</Fade>
|
| 88 |
</Box>
|
| 89 |
</Backdrop>
|
app/frontend/src/components/usePerformanceSession.js
CHANGED
|
@@ -66,13 +66,21 @@ export function loadPresetIntoSession(name) {
|
|
| 66 |
const CHANNEL_DEFAULT = {
|
| 67 |
prompt: '',
|
| 68 |
duration: 8,
|
| 69 |
-
durationMode: '
|
| 70 |
bars: 4,
|
| 71 |
looping: true,
|
| 72 |
muted: false,
|
| 73 |
soloed: false,
|
| 74 |
batchSize: 1,
|
| 75 |
-
knobs: { gain: -6, pan: 0, filter:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
};
|
| 77 |
|
| 78 |
function defaultSession(channelCount) {
|
|
@@ -88,6 +96,18 @@ function defaultSession(channelCount) {
|
|
| 88 |
randomSeed: true,
|
| 89 |
seedValue: '',
|
| 90 |
cueDeviceId: '',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
channels: Array.from({ length: channelCount }, () => ({
|
| 92 |
...CHANNEL_DEFAULT,
|
| 93 |
knobs: { ...CHANNEL_DEFAULT.knobs },
|
|
@@ -104,11 +124,21 @@ function loadSession(channelCount) {
|
|
| 104 |
// Merge against defaults so older saves don't crash on missing fields.
|
| 105 |
// Length shifts (channel count change between releases) are absorbed
|
| 106 |
// by always producing exactly `channelCount` channels.
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
return { ...fallback, ...parsed, channels };
|
| 113 |
} catch {
|
| 114 |
return fallback;
|
|
|
|
| 66 |
const CHANNEL_DEFAULT = {
|
| 67 |
prompt: '',
|
| 68 |
duration: 8,
|
| 69 |
+
durationMode: 'bars',
|
| 70 |
bars: 4,
|
| 71 |
looping: true,
|
| 72 |
muted: false,
|
| 73 |
soloed: false,
|
| 74 |
batchSize: 1,
|
| 75 |
+
knobs: { gain: -6, pan: 0, filter: 0, delay: 0, reverb: 0 },
|
| 76 |
+
// Fragment history metadata (id, prompt, duration, createdAt, starred,
|
| 77 |
+
// number). The Blob audio bodies live in IndexedDB under the
|
| 78 |
+
// `session-ch{N}` scope β see utils/fragmentStorage.js. Cleared on
|
| 79 |
+
// Fresh Start and overwritten on preset load.
|
| 80 |
+
fragments: [],
|
| 81 |
+
// Which fragment was loaded into the channel strip last; restored on
|
| 82 |
+
// reload so the channel comes back ready to play instead of empty.
|
| 83 |
+
committedFragmentId: null,
|
| 84 |
};
|
| 85 |
|
| 86 |
function defaultSession(channelCount) {
|
|
|
|
| 96 |
randomSeed: true,
|
| 97 |
seedValue: '',
|
| 98 |
cueDeviceId: '',
|
| 99 |
+
// Master FX defaults β the FX are always-on; the wet level on the
|
| 100 |
+
// master bus is determined entirely by per-channel DLY/REV send
|
| 101 |
+
// levels. We only persist the IR choice and the delay division.
|
| 102 |
+
masterReverbIR: 'hall',
|
| 103 |
+
masterDelayDivision: '1/4',
|
| 104 |
+
// Prompt auto-inject fields. Each is appended (comma-separated) to
|
| 105 |
+
// every generated prompt when set. Key and Time accept any text;
|
| 106 |
+
// empty = no injection. BPM is a toggle that, when on, grabs the
|
| 107 |
+
// live master BPM (top-bar value) at generation time.
|
| 108 |
+
promptKey: '',
|
| 109 |
+
promptInjectBpm: false,
|
| 110 |
+
promptTimeSig: '',
|
| 111 |
channels: Array.from({ length: channelCount }, () => ({
|
| 112 |
...CHANNEL_DEFAULT,
|
| 113 |
knobs: { ...CHANNEL_DEFAULT.knobs },
|
|
|
|
| 124 |
// Merge against defaults so older saves don't crash on missing fields.
|
| 125 |
// Length shifts (channel count change between releases) are absorbed
|
| 126 |
// by always producing exactly `channelCount` channels.
|
| 127 |
+
//
|
| 128 |
+
// Migration: pre-rename saves used `takes` / `committedTakeId`. Copy
|
| 129 |
+
// those into the new `fragments` / `committedFragmentId` slots when
|
| 130 |
+
// present, so users' existing generations carry over after the
|
| 131 |
+
// "Takes β Fragments" rename. Old fields are left in place but unused.
|
| 132 |
+
const channels = Array.from({ length: channelCount }, (_, i) => {
|
| 133 |
+
const ch = parsed.channels?.[i] || {};
|
| 134 |
+
return {
|
| 135 |
+
...CHANNEL_DEFAULT,
|
| 136 |
+
...ch,
|
| 137 |
+
fragments: ch.fragments ?? ch.takes ?? [],
|
| 138 |
+
committedFragmentId: ch.committedFragmentId ?? ch.committedTakeId ?? null,
|
| 139 |
+
knobs: { ...CHANNEL_DEFAULT.knobs, ...(ch.knobs || {}) },
|
| 140 |
+
};
|
| 141 |
+
});
|
| 142 |
return { ...fallback, ...parsed, channels };
|
| 143 |
} catch {
|
| 144 |
return fallback;
|
app/frontend/src/theme.js
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/frontend/src/tooltips.js
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export const TIPS = {
|
| 2 |
+
// App.js β LoRA training hyperparameters + model row actions.
|
| 3 |
+
training: {
|
| 4 |
+
downloadModel: 'Download this model',
|
| 5 |
+
deleteFineTuned: 'Delete fine-tuned model',
|
| 6 |
+
steps: "SA3's documented quick-start is 1,000 steps.",
|
| 7 |
+
adapter: "DoRA-rows is SA3's upstream default and works best for most stylistic LoRAs. The -xs variants freeze SVD bases and only train a tiny core matrix β far fewer parameters, useful when VRAM is tight. BoRA scales both rows and columns independently (more expressive, more parameters).",
|
| 8 |
+
checkpointEvery: 'How often a LoRA .safetensors snapshot gets written. Auto picks ~10 checkpoints per run (capped 250β1 000 steps). Lower = more granular but more disk; higher = fewer files to compare.',
|
| 9 |
+
batchSize: 'SA3 examples use 1. Each extra sample adds ~1β2 GB of activations. Raise only on roomy GPUs (β₯24 GB); medium-base activations are heavy. Lower if you hit CUDA OOM.',
|
| 10 |
+
precision: 'Cast applied to the frozen base weights only; LoRA parameters stay in fp32 for the optimizer. bf16 halves the VRAM used by the base with negligible quality cost on Ampere and newer cards.',
|
| 11 |
+
rank: "Capacity of the LoRA update β rank-k matrices A (kΓin) and B (outΓk) are trained. Higher rank = more expressive but larger file and more VRAM. r=16 fits comfortably on 16 GB and is SA3's default.",
|
| 12 |
+
alpha: 'Scaling factor for the LoRA update. Effective scaling is alpha / rank β setting alpha = rank gives a scaling of 1.0. Conventional choice: alpha = rank.',
|
| 13 |
+
dropout: 'Regularization probability applied to LoRA inputs during training. 0 is fine for most cases β raise to ~0.05 if you see overfitting on small datasets.',
|
| 14 |
+
seed: 'Random seed for reproducibility β same dataset + same hyperparameters + same seed produces the same LoRA. Change it to re-roll with different sampling behaviour.',
|
| 15 |
+
learningRate: "AdamW step size for the LoRA weights (base stays frozen). SA3's default is 1e-4, which works for most runs. Too high destabilizes training (loss spikes, artifacts); too low barely moves the adapter. Halve it if loss is erratic.",
|
| 16 |
+
sampleLength: 'Audio fed to the model per training step. Long clips get random-cropped to this length each step; short clips get silence-padded. Capped at the base model\'s native length (~120s small, ~380s medium) β longer windows cost markedly more VRAM and step time, so raise it only for long-form material (pre-encoding helps).',
|
| 17 |
+
includeLayers: 'Space-separated substrings β only layers whose fully-qualified name contains one of these get LoRA. Empty = all matching Linear/Conv1d layers. Example: transformer.layers.',
|
| 18 |
+
excludeLayers: 'Space-separated substrings β matching layers are skipped, even if they also match Include. SA3-docs default (seconds_total to_local_embed) prevents conditioner-hijacking on small datasets.',
|
| 19 |
+
},
|
| 20 |
+
|
| 21 |
+
// PerformancePanel.js β top transport bar + bottom controls.
|
| 22 |
+
perf: {
|
| 23 |
+
notDownloaded: 'Not downloaded β open Checkpoint Manager',
|
| 24 |
+
midiSettings: 'MIDI settings & mappings',
|
| 25 |
+
presets: 'Save / load presets',
|
| 26 |
+
deletePreset: 'Delete preset',
|
| 27 |
+
launchQuant: "Launch quantization β match Ableton's",
|
| 28 |
+
deleteFineTuned: 'Delete fine-tuned model',
|
| 29 |
+
deleteLora: 'Delete LoRA',
|
| 30 |
+
promptKey: 'Auto-inject Key. Leave empty to skip.',
|
| 31 |
+
timeSig: 'Auto-inject Time signature. Leave empty to skip.',
|
| 32 |
+
link: ({ installing, available, enabled, peers }) =>
|
| 33 |
+
installing
|
| 34 |
+
? 'Installing LinkPython-externβ¦'
|
| 35 |
+
: !available
|
| 36 |
+
? 'Click to install Ableton Link script'
|
| 37 |
+
: enabled
|
| 38 |
+
? `Link on β ${peers} peer${peers === 1 ? '' : 's'} (click to disable)`
|
| 39 |
+
: 'Click to sync BPM with Ableton Link',
|
| 40 |
+
midiMode: ({ supported, permissionError, learnMode }) =>
|
| 41 |
+
!supported
|
| 42 |
+
? (permissionError || 'Web MIDI is not available')
|
| 43 |
+
: learnMode
|
| 44 |
+
? 'Exit MIDI mode (Esc)'
|
| 45 |
+
: 'Enter MIDI mode β click a control then move a hardware knob/button to bind',
|
| 46 |
+
audioSetup: (cueSupported) =>
|
| 47 |
+
cueSupported
|
| 48 |
+
? 'Audio setup β choose output device'
|
| 49 |
+
: 'Audio device selection requires Chrome/Edge (AudioContext.setSinkId). Output falls back to system default.',
|
| 50 |
+
restoreDefaults: (armed) =>
|
| 51 |
+
armed
|
| 52 |
+
? 'Click again within 3s to confirm β clears session, fragments, and MIDI mappings'
|
| 53 |
+
: 'Reset all panel settings, clear fragments, and clear MIDI mappings',
|
| 54 |
+
steps: (isDistilled) =>
|
| 55 |
+
isDistilled
|
| 56 |
+
? 'Locked at 8 steps for distilled SA3 models β pick a *-base checkpoint to override'
|
| 57 |
+
: 'Diffusion steps per generation (more = higher quality, slower)',
|
| 58 |
+
bpmInject: (on, bpm) =>
|
| 59 |
+
on
|
| 60 |
+
? `Injecting master BPM (${Math.round(bpm)}) into prompts β click to disable`
|
| 61 |
+
: 'Click to auto-inject the master BPM (top bar) into every prompt',
|
| 62 |
+
},
|
| 63 |
+
|
| 64 |
+
// PerformanceChannel.js β per-channel strip.
|
| 65 |
+
channel: {
|
| 66 |
+
mute: 'Mute',
|
| 67 |
+
solo: 'Solo',
|
| 68 |
+
batch: "Batch generate Fragments and cue below.",
|
| 69 |
+
loop: (looping, durationMode) =>
|
| 70 |
+
looping
|
| 71 |
+
? (durationMode === 'bars'
|
| 72 |
+
? 'Loop'
|
| 73 |
+
: 'Playback loop on')
|
| 74 |
+
: 'Loop off',
|
| 75 |
+
generateDisabled: (generating, canGenerate, hasPrompt) =>
|
| 76 |
+
generating
|
| 77 |
+
? ''
|
| 78 |
+
: !canGenerate
|
| 79 |
+
? 'Pick a model in the Generation tab first'
|
| 80 |
+
: !hasPrompt
|
| 81 |
+
? 'Enter a prompt to generate'
|
| 82 |
+
: '',
|
| 83 |
+
variation: (loaded) =>
|
| 84 |
+
loaded
|
| 85 |
+
? 'Variation from the current fragment'
|
| 86 |
+
: 'Generate a fragment first, then create variations of it',
|
| 87 |
+
},
|
| 88 |
+
|
| 89 |
+
// DatasetPrep.js β dataset workbench.
|
| 90 |
+
dataset: {
|
| 91 |
+
richAnnotate: 'Adds genre / mood / instrument tags using LAION-CLAP. Requires the CLAP weights β downloadable from the Checkpoint Manager.',
|
| 92 |
+
skipAnnotated: 'When on, Auto-annotate skips clips that already have an annotation. Off means every run overwrites existing prompts.',
|
| 93 |
+
deleteProject: 'Delete this project (folder, audio, sidecars, drafts) β irreversible',
|
| 94 |
+
discardChanges: 'Delete unsaved changes β reverts to the last created dataset (removes any audio added since)',
|
| 95 |
+
saveDraft: "Save a draft β persists across app restarts but isn't the SA3 sidecar form",
|
| 96 |
+
createDataset: 'Create Dataset β writes the .txt sidecars (overwrites the previous dataset)',
|
| 97 |
+
selectClips: 'Click to select these clips β then Auto-annotate them.',
|
| 98 |
+
autoAnnotateClip: 'Auto-annotate this clip (overwrites any current prompt)',
|
| 99 |
+
sliceClip: 'Slice this clip into shorter children (immediate)',
|
| 100 |
+
removeClip: 'Remove this clip from the project (immediate)',
|
| 101 |
+
tooShort: (thresholdSec) =>
|
| 102 |
+
`Shorter than ${thresholdSec}s β gets silence-padded into each batch. Consider deleting. Click to select.`,
|
| 103 |
+
duplicates: (count) =>
|
| 104 |
+
`${count} group${count === 1 ? '' : 's'} of clips share the same annotation. Bad for training diversity β click to select all of them.`,
|
| 105 |
+
unsupported: (accepted) =>
|
| 106 |
+
`SA3 only trains on ${(accepted || []).join(', ')}. These clips will be silently skipped at train time β re-export them as .wav (or another accepted format) before committing. Click to select.`,
|
| 107 |
+
},
|
| 108 |
+
|
| 109 |
+
// LoraStack.js β LoRA slot stack.
|
| 110 |
+
lora: {
|
| 111 |
+
stackInfo: (max) => `Blend up to ${max} LoRAs at any strength`,
|
| 112 |
+
dragReorder: 'Drag to reorder (slot 0 loads first)',
|
| 113 |
+
bypass: (bypassed) =>
|
| 114 |
+
bypassed ? 'Bypassed (strength 0) β click to enable' : 'Bypass this slot',
|
| 115 |
+
},
|
| 116 |
+
|
| 117 |
+
// Fragment lists β ChannelFragmentHistory.js + GeneratedFragmentsWindow.js.
|
| 118 |
+
fragments: {
|
| 119 |
+
clearAll: 'Clear all (delete every fragment from disk)',
|
| 120 |
+
deleteFromDisk: 'Delete from disk',
|
| 121 |
+
revealInFolder: 'Show in folder (reveal this file on disk)',
|
| 122 |
+
audition: (isAuditioning) =>
|
| 123 |
+
isAuditioning ? 'Stop cue' : 'Audition through cue output',
|
| 124 |
+
star: (starred) =>
|
| 125 |
+
starred ? 'Unstar' : 'Star (keep through eviction)',
|
| 126 |
+
commit: (committed) =>
|
| 127 |
+
committed ? 'Currently loaded' : 'Load into channel',
|
| 128 |
+
},
|
| 129 |
+
|
| 130 |
+
// CheckpointRow.js β checkpoint catalog rows.
|
| 131 |
+
checkpoints: {
|
| 132 |
+
gatedAccess: "Open on HuggingFace to accept the model's gated-access terms",
|
| 133 |
+
},
|
| 134 |
+
};
|
app/frontend/src/utils/cueAudio.js
CHANGED
|
@@ -11,6 +11,7 @@
|
|
| 11 |
|
| 12 |
let ctx = null;
|
| 13 |
let currentSource = null;
|
|
|
|
| 14 |
let currentEndedHandler = null;
|
| 15 |
let currentSinkId = '';
|
| 16 |
|
|
@@ -132,19 +133,26 @@ export async function playBlob(blob, { onEnded } = {}) {
|
|
| 132 |
|
| 133 |
const src = c.createBufferSource();
|
| 134 |
src.buffer = buf;
|
| 135 |
-
//
|
| 136 |
-
//
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
const handler = () => {
|
| 140 |
if (currentSource === src) {
|
| 141 |
currentSource = null;
|
|
|
|
| 142 |
currentEndedHandler = null;
|
| 143 |
}
|
| 144 |
onEnded?.();
|
| 145 |
};
|
| 146 |
src.addEventListener('ended', handler);
|
| 147 |
currentSource = src;
|
|
|
|
| 148 |
currentEndedHandler = handler;
|
| 149 |
src.start();
|
| 150 |
|
|
@@ -157,12 +165,27 @@ export async function playBlob(blob, { onEnded } = {}) {
|
|
| 157 |
|
| 158 |
export function stopCue() {
|
| 159 |
if (currentSource) {
|
|
|
|
|
|
|
| 160 |
if (currentEndedHandler) {
|
| 161 |
-
|
| 162 |
}
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
currentSource = null;
|
|
|
|
| 166 |
currentEndedHandler = null;
|
| 167 |
}
|
| 168 |
}
|
|
|
|
| 11 |
|
| 12 |
let ctx = null;
|
| 13 |
let currentSource = null;
|
| 14 |
+
let currentSourceFade = null; // per-source gain used by stopCue() to ramp out
|
| 15 |
let currentEndedHandler = null;
|
| 16 |
let currentSinkId = '';
|
| 17 |
|
|
|
|
| 133 |
|
| 134 |
const src = c.createBufferSource();
|
| 135 |
src.buffer = buf;
|
| 136 |
+
// Per-source fade gain so stopCue() can ramp out instead of hard-cut
|
| 137 |
+
// (a hard cut at non-zero samples is what produces the click /
|
| 138 |
+
// crackle when switching fragments rapidly). The fade graph is:
|
| 139 |
+
// source β fadeGain β cueSplitter β cueMerger β destination
|
| 140 |
+
const fadeGain = c.createGain();
|
| 141 |
+
fadeGain.gain.value = 1;
|
| 142 |
+
src.connect(fadeGain);
|
| 143 |
+
fadeGain.connect(cueSplitter);
|
| 144 |
|
| 145 |
const handler = () => {
|
| 146 |
if (currentSource === src) {
|
| 147 |
currentSource = null;
|
| 148 |
+
currentSourceFade = null;
|
| 149 |
currentEndedHandler = null;
|
| 150 |
}
|
| 151 |
onEnded?.();
|
| 152 |
};
|
| 153 |
src.addEventListener('ended', handler);
|
| 154 |
currentSource = src;
|
| 155 |
+
currentSourceFade = fadeGain;
|
| 156 |
currentEndedHandler = handler;
|
| 157 |
src.start();
|
| 158 |
|
|
|
|
| 165 |
|
| 166 |
export function stopCue() {
|
| 167 |
if (currentSource) {
|
| 168 |
+
const src = currentSource;
|
| 169 |
+
const fade = currentSourceFade;
|
| 170 |
if (currentEndedHandler) {
|
| 171 |
+
src.removeEventListener('ended', currentEndedHandler);
|
| 172 |
}
|
| 173 |
+
const now = ctx ? ctx.currentTime : 0;
|
| 174 |
+
const FADE = 0.012;
|
| 175 |
+
try {
|
| 176 |
+
if (fade) {
|
| 177 |
+
fade.gain.cancelScheduledValues(now);
|
| 178 |
+
fade.gain.setValueAtTime(fade.gain.value, now);
|
| 179 |
+
fade.gain.linearRampToValueAtTime(0, now + FADE);
|
| 180 |
+
}
|
| 181 |
+
src.stop(now + FADE + 0.005);
|
| 182 |
+
} catch { /* already stopped */ }
|
| 183 |
+
window.setTimeout(() => {
|
| 184 |
+
try { src.disconnect(); } catch { /* ok */ }
|
| 185 |
+
try { fade && fade.disconnect(); } catch { /* ok */ }
|
| 186 |
+
}, Math.ceil((FADE + 0.02) * 1000));
|
| 187 |
currentSource = null;
|
| 188 |
+
currentSourceFade = null;
|
| 189 |
currentEndedHandler = null;
|
| 190 |
}
|
| 191 |
}
|
app/frontend/src/utils/fragmentDrag.js
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// In-app drag handoff for generated fragments.
|
| 2 |
+
//
|
| 3 |
+
// The HTML drag-and-drop dataTransfer can only carry strings, so when a
|
| 4 |
+
// fragment is dragged from the Generated Fragments window into the Edit tab's
|
| 5 |
+
// source dropzone we stash its in-memory audio Blob here on dragStart and read
|
| 6 |
+
// it back on drop. This lets EditPanel use the blob directly instead of
|
| 7 |
+
// re-fetching by filename β which is immune to any divergence between the
|
| 8 |
+
// fragment's in-memory name and what actually exists on disk.
|
| 9 |
+
//
|
| 10 |
+
// Falls back gracefully: if no blob was stashed (e.g. a not-yet-preloaded
|
| 11 |
+
// disk fragment), the consumer drops back to the filename-based fetch.
|
| 12 |
+
|
| 13 |
+
let _payload = null; // { filename: string, blob: Blob } | null
|
| 14 |
+
|
| 15 |
+
export function setFragmentDragPayload(payload) {
|
| 16 |
+
_payload = payload;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
export function getFragmentDragPayload() {
|
| 20 |
+
return _payload;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
export function clearFragmentDragPayload() {
|
| 24 |
+
_payload = null;
|
| 25 |
+
}
|