MazCodes commited on
Commit
9ea28c1
Β·
verified Β·
1 Parent(s): 14269fb

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. Dockerfile +4 -2
  3. README.md +4 -3
  4. app/backend/app.py +0 -0
  5. app/backend/data/auto_annotator.py +157 -34
  6. app/backend/data/pre_encoder.py +354 -0
  7. app/backend/data/projects.py +1023 -0
  8. app/backend/data/slicing.py +183 -0
  9. app/core/audio/midi_input.py +172 -0
  10. app/core/config.py +16 -86
  11. app/core/generation/audio_generator.py +490 -473
  12. app/core/generation/audio_post_process.py +713 -44
  13. app/core/model_manager.py +628 -437
  14. app/core/training/hyperparam_suggester.py +299 -141
  15. app/core/training/sa3_lora_runner.py +331 -0
  16. app/core/training/sa3_trainer.py +839 -0
  17. app/frontend/index.html +29 -6
  18. app/frontend/logs/fragmenta_20260525.log +8 -0
  19. app/frontend/package.json +2 -2
  20. app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf +3 -0
  21. app/frontend/public/InterTight-VariableFont_wght.ttf +3 -0
  22. app/frontend/public/fragmenta_background.png +2 -2
  23. app/frontend/public/interface.png +2 -2
  24. app/frontend/src/App.js +0 -0
  25. app/frontend/src/api.js +1 -0
  26. app/frontend/src/components/AboutDialog.js +130 -0
  27. app/frontend/src/components/AudioWaveform.js +258 -0
  28. app/frontend/src/components/ChannelFragmentHistory.js +217 -0
  29. app/frontend/src/components/CheckpointManagerWindow.js +243 -0
  30. app/frontend/src/components/CheckpointRow.js +270 -0
  31. app/frontend/src/components/DatasetPrep.js +1823 -0
  32. app/frontend/src/components/EditPanel.js +597 -0
  33. app/frontend/src/components/GeneratedFragmentsWindow.js +420 -70
  34. app/frontend/src/components/GenerationWaveform.js +217 -0
  35. app/frontend/src/components/InfoView.js +91 -0
  36. app/frontend/src/components/LoraStack.js +252 -0
  37. app/frontend/src/components/LossChart.js +27 -11
  38. app/frontend/src/components/MidiConfigMenu.js +118 -46
  39. app/frontend/src/components/MidiContext.js +38 -48
  40. app/frontend/src/components/PerformanceChannel.js +618 -239
  41. app/frontend/src/components/PerformancePanel.js +0 -0
  42. app/frontend/src/components/StorageDrilldown.js +84 -0
  43. app/frontend/src/components/Tooltip.js +35 -0
  44. app/frontend/src/components/TrainingMonitor.js +76 -35
  45. app/frontend/src/components/WelcomePage.js +22 -33
  46. app/frontend/src/components/usePerformanceSession.js +37 -7
  47. app/frontend/src/theme.js +0 -0
  48. app/frontend/src/tooltips.js +134 -0
  49. app/frontend/src/utils/cueAudio.js +29 -6
  50. app/frontend/src/utils/fragmentDrag.js +25 -0
.gitattributes CHANGED
@@ -47,3 +47,5 @@ utils/vendor/wheels/antlr4_python3_runtime-4.9.3-py3-none-any.whl filter=lfs dif
47
  vendor/stable-audio-tools/demo_cfg_3_00000001.wav filter=lfs diff=lfs merge=lfs -text
48
  vendor/stable-audio-tools/demo_cfg_6_00000001.wav filter=lfs diff=lfs merge=lfs -text
49
  vendor/stable-audio-tools/demo_cfg_9_00000001.wav filter=lfs diff=lfs merge=lfs -text
 
 
 
47
  vendor/stable-audio-tools/demo_cfg_3_00000001.wav filter=lfs diff=lfs merge=lfs -text
48
  vendor/stable-audio-tools/demo_cfg_6_00000001.wav filter=lfs diff=lfs merge=lfs -text
49
  vendor/stable-audio-tools/demo_cfg_9_00000001.wav filter=lfs diff=lfs merge=lfs -text
50
+ app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf filter=lfs diff=lfs merge=lfs -text
51
+ app/frontend/public/InterTight-VariableFont_wght.ttf filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED
@@ -60,8 +60,9 @@ RUN grep -ivE 'flash-attn|extra-index-url|pycairo|pygobject|pywebview' requireme
60
  COPY . .
61
  COPY --from=frontend-builder /build/frontend/build ./app/frontend/build
62
 
63
- # Install stable-audio-tools in-tree
64
- RUN pip install --no-cache-dir --root-user-action=ignore -e ./vendor/stable-audio-tools/
 
65
 
66
  # Create writable directories
67
  RUN mkdir -p /app/models/pretrained \
@@ -105,6 +106,7 @@ ENV FLASK_HOST=0.0.0.0
105
  ENV FLASK_PORT=7860
106
  ENV FRAGMENTA_LOG_LEVEL=INFO
107
  ENV FRAGMENTA_DOCKER=1
 
108
  ENV FRAGMENTA_USE_CUSTOM_MODELS=true
109
  ENV HOME=/home/user
110
  ENV PATH="/home/user/.local/bin:${PATH}"
 
60
  COPY . .
61
  COPY --from=frontend-builder /build/frontend/build ./app/frontend/build
62
 
63
+ # Install vendored Stable Audio 3 in-tree (--no-deps: runtime deps come from
64
+ # requirements.txt). Makes `import stable_audio_3` resolve.
65
+ RUN pip install --no-cache-dir --root-user-action=ignore --no-deps -e ./vendor/stable-audio-3/
66
 
67
  # Create writable directories
68
  RUN mkdir -p /app/models/pretrained \
 
106
  ENV FLASK_PORT=7860
107
  ENV FRAGMENTA_LOG_LEVEL=INFO
108
  ENV FRAGMENTA_DOCKER=1
109
+ ENV PYTHONPATH=/app/vendor/stable-audio-3
110
  ENV FRAGMENTA_USE_CUSTOM_MODELS=true
111
  ENV HOME=/home/user
112
  ENV PATH="/home/user/.local/bin:${PATH}"
README.md CHANGED
@@ -17,9 +17,10 @@ Generate and fine-tune audio from text prompts using Stable Audio Open.
17
 
18
  ## Getting Started
19
 
20
- 1. Upload your model weights (`.safetensors`) to `models/pretrained/` in the Space Files tab.
21
- - `stable-audio-open-small-model.safetensors` (recommended for CPU)
22
- - `stable-audio-open-model.safetensors` (full model, recommended for GPU)
 
23
  2. The Space will auto-rebuild after the upload.
24
  3. Use the **Data Processing** tab to upload audio + prompts.
25
  4. Use the **Training** tab to fine-tune.
 
17
 
18
  ## Getting Started
19
 
20
+ 1. Download an SA3 checkpoint via the in-app Checkpoint Manager, or place one
21
+ under `models/pretrained/sa3/hub/` in the Space Files tab.
22
+ - `sa3-small-music` (recommended for CPU Spaces)
23
+ - `sa3-medium` (recommended for GPU Spaces with Flash Attention 2)
24
  2. The Space will auto-rebuild after the upload.
25
  3. Use the **Data Processing** tab to upload audio + prompts.
26
  4. Use the **Training** tab to fine-tune.
app/backend/app.py CHANGED
The diff for this file is too large to render. See raw diff
 
app/backend/data/auto_annotator.py CHANGED
@@ -22,6 +22,11 @@ AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".m4a", ".ogg", ".aac")
22
  CLAP_CKPT_FILENAME = "music_audioset_epoch_15_esc_90.14.pt"
23
  CLAP_REPO = "lukewys/laion_clap"
24
 
 
 
 
 
 
25
  KEY_NAMES_SHARP = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
26
  KEY_NAMES_FLAT = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"]
27
 
@@ -44,9 +49,18 @@ def _iter_audio_files(folder: Path) -> List[Path]:
44
 
45
  def _estimate_tempo(y, sr) -> Optional[int]:
46
  import librosa
 
47
  try:
48
  tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
49
- bpm = float(tempo if hasattr(tempo, "__float__") else tempo[0])
 
 
 
 
 
 
 
 
50
  if bpm <= 0:
51
  return None
52
  return int(round(bpm))
@@ -178,12 +192,50 @@ class _ClapTagger:
178
  f"CLAP checkpoint not found at {self.ckpt_path}. "
179
  "Download it first via /api/bulk-annotate/download-clap."
180
  )
181
- import laion_clap
182
- import torch
183
  logging.getLogger("transformers").setLevel(logging.ERROR)
184
 
185
- device = "cuda" if torch.cuda.is_available() else "cpu"
186
- model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  # torch >= 2.6 flipped torch.load(weights_only=True) and newer
189
  # transformers dropped the roberta position_ids buffer, so
@@ -268,44 +320,97 @@ def clap_checkpoint_path(models_pretrained_dir: Path) -> Path:
268
  return models_pretrained_dir / "clap" / CLAP_CKPT_FILENAME
269
 
270
 
 
 
 
 
 
271
  def clap_checkpoint_available(models_pretrained_dir: Path) -> bool:
272
  return clap_checkpoint_path(models_pretrained_dir).exists()
273
 
274
 
 
 
 
 
 
 
 
 
275
  def download_clap_checkpoint(
276
  models_pretrained_dir: Path,
277
  progress_cb: Optional[Callable[[str], None]] = None,
 
278
  ) -> Path:
 
 
 
 
 
 
 
 
 
279
  target = clap_checkpoint_path(models_pretrained_dir)
280
  target.parent.mkdir(parents=True, exist_ok=True)
281
- if target.exists():
282
- return target
283
 
284
- from huggingface_hub import hf_hub_download
285
  import os
286
 
287
- if progress_cb:
288
- progress_cb("Downloading CLAP checkpoint (~630 MB)…")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
- # Use custom CLAP from fragmenta-models on HF Spaces
291
- use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true'
292
- if use_custom_repo:
293
- repo_id = "MazCodes/fragmenta-models"
294
- else:
295
- repo_id = CLAP_REPO
296
-
297
- downloaded = hf_hub_download(
298
- repo_id=repo_id,
299
- filename=CLAP_CKPT_FILENAME,
300
- local_dir=str(target.parent),
301
- )
302
- downloaded_path = Path(downloaded)
303
- if downloaded_path != target:
304
- try:
305
- downloaded_path.replace(target)
306
- except OSError:
307
- import shutil
308
- shutil.copy2(downloaded_path, target)
309
  return target
310
 
311
 
@@ -328,8 +433,10 @@ def annotate_file(
328
  label_sets: Dict[str, List[str]],
329
  sr: int = 22050,
330
  max_seconds: float = 60.0,
 
331
  ) -> Dict[str, Any]:
332
  import librosa
 
333
 
334
  parts: Dict[str, Any] = {}
335
  try:
@@ -343,10 +450,19 @@ def annotate_file(
343
  "error": f"load failed: {exc}",
344
  }
345
 
346
- parts["bpm"] = _estimate_tempo(y, loaded_sr)
347
- parts["key"] = _estimate_key(y, loaded_sr)
348
- parts["brightness"] = _estimate_brightness(y, loaded_sr)
349
- parts["character"] = _estimate_character(y, loaded_sr)
 
 
 
 
 
 
 
 
 
350
 
351
  if tier == "rich" and clap_tagger is not None:
352
  try:
@@ -355,7 +471,14 @@ def annotate_file(
355
  except Exception as exc:
356
  logger.warning("CLAP tagging failed for %s: %s", audio_path.name, exc)
357
 
358
- prompt = _compose_prompt(parts)
 
 
 
 
 
 
 
359
  return {
360
  "file_name": audio_path.name,
361
  "prompt": prompt,
 
22
  CLAP_CKPT_FILENAME = "music_audioset_epoch_15_esc_90.14.pt"
23
  CLAP_REPO = "lukewys/laion_clap"
24
 
25
+ # Text-side dependencies laion_clap pulls from HF on construction.
26
+ # We stage these into models/pretrained/clap/hub/ so the rich tier is
27
+ # fully offline after a single download and nothing leaks to ~/.cache.
28
+ CLAP_TEXT_DEPS = ("roberta-base", "bert-base-uncased", "facebook/bart-base")
29
+
30
  KEY_NAMES_SHARP = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
31
  KEY_NAMES_FLAT = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"]
32
 
 
49
 
50
  def _estimate_tempo(y, sr) -> Optional[int]:
51
  import librosa
52
+ import numpy as np
53
  try:
54
  tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
55
+ # librosa 0.10+ returns tempo as np.ndarray (shape (1,) typically).
56
+ # numpy 2.x removed implicit float() conversion of N-d arrays β€”
57
+ # `float(np.array([120.]))` now raises TypeError instead of returning
58
+ # 120.0 like numpy 1.x did. Normalize via .flat[0] which handles
59
+ # scalar, 0-d, 1-d, and N-d uniformly.
60
+ arr = np.atleast_1d(np.asarray(tempo))
61
+ if arr.size == 0:
62
+ return None
63
+ bpm = float(arr.flat[0])
64
  if bpm <= 0:
65
  return None
66
  return int(round(bpm))
 
192
  f"CLAP checkpoint not found at {self.ckpt_path}. "
193
  "Download it first via /api/bulk-annotate/download-clap."
194
  )
 
 
195
  logging.getLogger("transformers").setLevel(logging.ERROR)
196
 
197
+ # Point HF resolution at our project-local cache and disable the
198
+ # HEAD-revalidation traffic. After download_clap_checkpoint() has
199
+ # staged the text deps under <pretrained>/clap/hub/, CLAP_Module
200
+ # loads them offline with zero HF hub requests.
201
+ #
202
+ # Two reasons env vars alone aren't enough:
203
+ # 1. huggingface_hub.constants.HF_HUB_OFFLINE is captured at
204
+ # module-import time (constants.py:185). model_manager.py
205
+ # imports huggingface_hub at app startup, so the constant is
206
+ # already False by the time we set the env var here.
207
+ # transformers.utils.hub.is_offline_mode reads that same
208
+ # constant β€” patching the attribute makes both libraries see
209
+ # offline mode.
210
+ # 2. laion_clap/training/data.py:44-46 runs three from_pretrained
211
+ # calls at MODULE LEVEL β€” those fire the first time we do
212
+ # `import laion_clap` and predate any patch we do after the
213
+ # import. So we patch BEFORE the import, not after.
214
+ hub_dir = self.ckpt_path.parent / "hub"
215
+ env_keys = ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE", "TRANSFORMERS_CACHE",
216
+ "HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE")
217
+ prev_env = {k: os.environ.get(k) for k in env_keys}
218
+ os.environ["HF_HUB_CACHE"] = str(hub_dir)
219
+ os.environ["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
220
+ os.environ["TRANSFORMERS_CACHE"] = str(hub_dir)
221
+ os.environ["HF_HUB_OFFLINE"] = "1"
222
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
223
+
224
+ import huggingface_hub.constants as _hhc
225
+ prev_offline_attr = _hhc.HF_HUB_OFFLINE
226
+ _hhc.HF_HUB_OFFLINE = True
227
+ try:
228
+ import laion_clap # noqa: E402 β€” must follow the offline patch
229
+ import torch
230
+ device = "cuda" if torch.cuda.is_available() else "cpu"
231
+ model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device=device)
232
+ finally:
233
+ _hhc.HF_HUB_OFFLINE = prev_offline_attr
234
+ for k, v in prev_env.items():
235
+ if v is None:
236
+ os.environ.pop(k, None)
237
+ else:
238
+ os.environ[k] = v
239
 
240
  # torch >= 2.6 flipped torch.load(weights_only=True) and newer
241
  # transformers dropped the roberta position_ids buffer, so
 
320
  return models_pretrained_dir / "clap" / CLAP_CKPT_FILENAME
321
 
322
 
323
+ def clap_hub_dir(models_pretrained_dir: Path) -> Path:
324
+ """HF cache for laion_clap's text-side deps. Sibling of the .pt."""
325
+ return models_pretrained_dir / "clap" / "hub"
326
+
327
+
328
  def clap_checkpoint_available(models_pretrained_dir: Path) -> bool:
329
  return clap_checkpoint_path(models_pretrained_dir).exists()
330
 
331
 
332
+ def _text_dep_snapshot_present(hub_dir: Path, repo_id: str) -> bool:
333
+ safe = "models--" + repo_id.replace("/", "--")
334
+ snap_root = hub_dir / safe / "snapshots"
335
+ if not snap_root.exists():
336
+ return False
337
+ return any(snap_root.iterdir())
338
+
339
+
340
  def download_clap_checkpoint(
341
  models_pretrained_dir: Path,
342
  progress_cb: Optional[Callable[[str], None]] = None,
343
+ phase_cb: Optional[Callable[[int, int, str], None]] = None,
344
  ) -> Path:
345
+ """Download the CLAP audio .pt plus laion_clap's text-side HF snapshots.
346
+
347
+ Four sequential phases β€” emit a phase update (current, total, label) at the
348
+ start of each so a multi-phase progress UI can show real context. Skips
349
+ phases whose artifacts are already on disk.
350
+
351
+ `progress_cb` (str-only) is kept for the bulk-annotate API.
352
+ `phase_cb` (current, total, label) is the structured channel.
353
+ """
354
  target = clap_checkpoint_path(models_pretrained_dir)
355
  target.parent.mkdir(parents=True, exist_ok=True)
356
+ hub_dir = clap_hub_dir(models_pretrained_dir)
357
+ hub_dir.mkdir(parents=True, exist_ok=True)
358
 
359
+ from huggingface_hub import hf_hub_download, snapshot_download
360
  import os
361
 
362
+ total_phases = 1 + len(CLAP_TEXT_DEPS)
363
+
364
+ def _emit(phase_index: int, label: str) -> None:
365
+ if phase_cb:
366
+ phase_cb(phase_index, total_phases, label)
367
+ if progress_cb:
368
+ progress_cb(f"[{phase_index}/{total_phases}] {label}")
369
+
370
+ if not target.exists():
371
+ _emit(1, "CLAP audio model (~2.35 GB)")
372
+
373
+ # Use custom CLAP from fragmenta-models on HF Spaces
374
+ use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true'
375
+ if use_custom_repo:
376
+ repo_id = "MazCodes/fragmenta-models"
377
+ else:
378
+ repo_id = CLAP_REPO
379
+
380
+ downloaded = hf_hub_download(
381
+ repo_id=repo_id,
382
+ filename=CLAP_CKPT_FILENAME,
383
+ local_dir=str(target.parent),
384
+ )
385
+ downloaded_path = Path(downloaded)
386
+ if downloaded_path != target:
387
+ try:
388
+ downloaded_path.replace(target)
389
+ except OSError:
390
+ import shutil
391
+ shutil.copy2(downloaded_path, target)
392
+
393
+ # laion_clap's CLAP_Module(...) constructor instantiates a Roberta text
394
+ # branch plus bert/bart tokenizers at import time. Pre-stage them into
395
+ # our own cache so the rich tier is fully offline after this step.
396
+ # safetensors only β€” pytorch_model.bin is a redundant copy.
397
+ for i, repo_id in enumerate(CLAP_TEXT_DEPS, start=2):
398
+ if _text_dep_snapshot_present(hub_dir, repo_id):
399
+ continue
400
+ _emit(i, f"Text encoder: {repo_id}")
401
+ snapshot_download(
402
+ repo_id=repo_id,
403
+ cache_dir=str(hub_dir),
404
+ allow_patterns=[
405
+ "config.json",
406
+ "tokenizer*",
407
+ "vocab*",
408
+ "merges.txt",
409
+ "special_tokens_map.json",
410
+ "model.safetensors",
411
+ ],
412
+ )
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  return target
415
 
416
 
 
433
  label_sets: Dict[str, List[str]],
434
  sr: int = 22050,
435
  max_seconds: float = 60.0,
436
+ prompt_template: Optional[str] = None,
437
  ) -> Dict[str, Any]:
438
  import librosa
439
+ import warnings
440
 
441
  parts: Dict[str, Any] = {}
442
  try:
 
450
  "error": f"load failed: {exc}",
451
  }
452
 
453
+ # Silent / harmonically flat clips trip librosa's "Trying to estimate
454
+ # tuning from empty frequency set" warning during chroma extraction.
455
+ # The warning is benign β€” the analysis returns sensible defaults β€” but
456
+ # it spams stderr on every silent file, so we mute it here.
457
+ with warnings.catch_warnings():
458
+ warnings.filterwarnings(
459
+ "ignore",
460
+ message="Trying to estimate tuning from empty frequency set",
461
+ )
462
+ parts["bpm"] = _estimate_tempo(y, loaded_sr)
463
+ parts["key"] = _estimate_key(y, loaded_sr)
464
+ parts["brightness"] = _estimate_brightness(y, loaded_sr)
465
+ parts["character"] = _estimate_character(y, loaded_sr)
466
 
467
  if tier == "rich" and clap_tagger is not None:
468
  try:
 
471
  except Exception as exc:
472
  logger.warning("CLAP tagging failed for %s: %s", audio_path.name, exc)
473
 
474
+ # Template-driven prompt assembly. Falls back to the legacy descriptive
475
+ # prose if no template is supplied (call sites that haven't been
476
+ # threaded with project metadata yet).
477
+ if prompt_template is not None and prompt_template.strip():
478
+ from app.backend.data.projects import apply_template
479
+ prompt = apply_template(prompt_template, parts)
480
+ else:
481
+ prompt = _compose_prompt(parts)
482
  return {
483
  "file_name": audio_path.name,
484
  "prompt": prompt,
app/backend/data/pre_encoder.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SA3 pre-encoding job runner β€” Phase 6.
2
+
3
+ Encodes every audio clip in a Dataset Workbench project into SA3 latents
4
+ once, ahead of training, so the training subprocess can skip the SAME
5
+ autoencoder pass per step. Mirrors the shape of `_project_annotate_jobs`
6
+ in app.py (background thread, per-project state, cooperative cancel).
7
+
8
+ Latents land in `<project>/.latents/` β€” a hidden subdirectory inside the
9
+ project folder. Disk layout matches SA3's `pre_encode_dataset.py`:
10
+
11
+ <project>/.latents/
12
+ 000000000000.npy # latent tensor (shape (256, T_lat))
13
+ 000000000000.json # {"prompt": "...", "padding_mask": [...], ...}
14
+ 000001000000.npy
15
+ 000001000000.json
16
+ ...
17
+ silence.npy # padding latent (auto-generated)
18
+ _meta.json # Fragmenta-specific: AE used, source clip count
19
+
20
+ SA3's `train_lora.py --encoded_dir <project>/.latents` consumes this layout
21
+ directly. `SA3Trainer._stage_dataset` auto-detects the directory and feeds
22
+ `--encoded_dir` to the subprocess when latents are present.
23
+
24
+ Cache invalidation lives in projects.py β€” any project mutation that could
25
+ desync the latents (commit, delete_clip, slice_clip) wipes the directory.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import json
31
+ import os
32
+ import re
33
+ import signal
34
+ import subprocess
35
+ import sys
36
+ import threading
37
+ import time
38
+ from pathlib import Path
39
+ from typing import Any, Dict, Optional
40
+
41
+ from app.backend.data.projects import project_path
42
+ from app.core.config import get_config
43
+ from utils.logger import get_logger
44
+
45
+ logger = get_logger("PreEncoder")
46
+
47
+
48
+ # --- Per-project job registry ----------------------------------------------
49
+
50
+ _pre_encode_jobs: Dict[str, Dict[str, Any]] = {}
51
+ _pre_encode_jobs_lock = threading.Lock()
52
+ _pre_encode_processes: Dict[str, subprocess.Popen] = {}
53
+
54
+
55
+ def get_pre_encode_job(project_name: str) -> Dict[str, Any]:
56
+ """Snapshot of the current job state for a project. Always returns a
57
+ well-formed dict so the frontend can render against it without guards."""
58
+ with _pre_encode_jobs_lock:
59
+ job = _pre_encode_jobs.get(project_name)
60
+ if job is None:
61
+ return _idle_job()
62
+ return dict(job)
63
+
64
+
65
+ def _idle_job() -> Dict[str, Any]:
66
+ return {
67
+ "state": "idle", # idle | queued | running | complete | failed | cancelled
68
+ "current": 0, # batch index (0-based)
69
+ "total": 0, # total batches (derived from clip count)
70
+ "current_file": "",
71
+ "error": None,
72
+ "started_at": None,
73
+ "finished_at": None,
74
+ "autoencoder": None,
75
+ }
76
+
77
+
78
+ # --- Autoencoder selection -------------------------------------------------
79
+
80
+ # Bind latents to a specific SA3 autoencoder. Latents from same-s only work
81
+ # with small-music / small-sfx DiTs; same-l latents only work with medium.
82
+ # For v1 we default to same-s (covers the most common base) and leave a
83
+ # manifest in .latents/_meta.json that training reads to verify
84
+ # compatibility. If a user trains against medium with same-s latents,
85
+ # SA3Trainer falls back to non-encoded training and logs a warning.
86
+ DEFAULT_AUTOENCODER = "same-s"
87
+
88
+ # Audio length (samples per channel) the dataset pads/crops to before
89
+ # encoding. SA3's pre_encode_dataset.py defaults to ~285s at 44.1 kHz, which
90
+ # covers any training-time --duration up to that limit (and SA3 small caps
91
+ # at 120s anyway). Longer clips in the project will be cropped to this
92
+ # length during encoding β€” a documented limitation for v1.
93
+ DEFAULT_SAMPLE_SIZE = 12_582_912
94
+
95
+
96
+ # --- Job lifecycle ---------------------------------------------------------
97
+
98
+ def latents_dir(project_name: str) -> Path:
99
+ return project_path(project_name) / ".latents"
100
+
101
+
102
+ def latents_count(project_name: str) -> int:
103
+ d = latents_dir(project_name)
104
+ if not d.exists():
105
+ return 0
106
+ return sum(
107
+ 1 for p in d.glob("*.npy")
108
+ if p.name != "silence.npy"
109
+ )
110
+
111
+
112
+ def latents_meta(project_name: str) -> Optional[Dict[str, Any]]:
113
+ """Read the manifest we drop alongside the .npy files."""
114
+ p = latents_dir(project_name) / "_meta.json"
115
+ if not p.exists():
116
+ return None
117
+ try:
118
+ return json.loads(p.read_text(encoding="utf-8"))
119
+ except Exception:
120
+ return None
121
+
122
+
123
+ def latents_match_base(project_name: str, base_model: str) -> bool:
124
+ """Whether the cached latents are compatible with the chosen base.
125
+
126
+ same-s ↔ small-music / small-sfx (and their *-base variants).
127
+ same-l ↔ medium (and medium-base).
128
+ """
129
+ meta = latents_meta(project_name)
130
+ if not meta:
131
+ return False
132
+ ae = meta.get("autoencoder")
133
+ if ae == "same-s":
134
+ return base_model in ("sa3-small-music", "sa3-small-music-base",
135
+ "sa3-small-sfx", "sa3-small-sfx-base")
136
+ if ae == "same-l":
137
+ return base_model in ("sa3-medium", "sa3-medium-base")
138
+ return False
139
+
140
+
141
+ def cancel_pre_encode(project_name: str) -> bool:
142
+ """Send a cancel signal to an in-flight job. Returns True if cancelled."""
143
+ with _pre_encode_jobs_lock:
144
+ job = _pre_encode_jobs.get(project_name)
145
+ if not job or job.get("state") not in ("queued", "running"):
146
+ return False
147
+ job["state"] = "cancelled"
148
+ job["cancelled"] = True
149
+
150
+ proc = _pre_encode_processes.get(project_name)
151
+ if proc is not None and proc.poll() is None:
152
+ try:
153
+ proc.send_signal(signal.SIGINT)
154
+ try:
155
+ proc.wait(timeout=5)
156
+ except subprocess.TimeoutExpired:
157
+ proc.terminate()
158
+ try:
159
+ proc.wait(timeout=3)
160
+ except subprocess.TimeoutExpired:
161
+ proc.kill()
162
+ except Exception as exc:
163
+ logger.warning("Failed to signal pre-encode subprocess: %s", exc)
164
+ return True
165
+
166
+
167
+ def start_pre_encode(
168
+ project_name: str,
169
+ autoencoder: Optional[str] = None,
170
+ sample_size: Optional[int] = None,
171
+ ) -> Dict[str, Any]:
172
+ """Spawn the pre-encode subprocess in a background thread. Returns the
173
+ job state β€” frontend polls /pre-encode/status thereafter.
174
+ """
175
+ proj_dir = project_path(project_name)
176
+ if not proj_dir.exists():
177
+ raise FileNotFoundError(f"project not found: {project_name}")
178
+
179
+ ae = autoencoder or DEFAULT_AUTOENCODER
180
+ if ae not in ("same-s", "same-l"):
181
+ raise ValueError(f"autoencoder must be 'same-s' or 'same-l'; got {ae!r}")
182
+
183
+ with _pre_encode_jobs_lock:
184
+ existing = _pre_encode_jobs.get(project_name)
185
+ if existing and existing.get("state") in ("queued", "running"):
186
+ return dict(existing)
187
+
188
+ # Count source clips (sidecars committed) so we know the denominator.
189
+ sidecars = list(proj_dir.glob("*.txt"))
190
+ clip_count = sum(
191
+ 1 for p in sidecars
192
+ if p.read_text(encoding="utf-8").strip()
193
+ and p.with_suffix(".wav").exists() # cheap & accurate enough
194
+ )
195
+
196
+ job: Dict[str, Any] = {
197
+ "state": "queued",
198
+ "current": 0,
199
+ "total": clip_count,
200
+ "current_file": "",
201
+ "error": None,
202
+ "started_at": time.time(),
203
+ "finished_at": None,
204
+ "autoencoder": ae,
205
+ "cancelled": False,
206
+ }
207
+ _pre_encode_jobs[project_name] = job
208
+
209
+ thread = threading.Thread(
210
+ target=_run_pre_encode,
211
+ args=(project_name, ae, sample_size or DEFAULT_SAMPLE_SIZE),
212
+ daemon=True,
213
+ name=f"sa3-pre-encode:{project_name}",
214
+ )
215
+ thread.start()
216
+ return get_pre_encode_job(project_name)
217
+
218
+
219
+ # --- Worker ----------------------------------------------------------------
220
+
221
+ def _update_job(project_name: str, **fields: Any) -> None:
222
+ with _pre_encode_jobs_lock:
223
+ job = _pre_encode_jobs.get(project_name)
224
+ if job is None:
225
+ return
226
+ job.update(fields)
227
+
228
+
229
+ def _run_pre_encode(project_name: str, ae: str, sample_size: int) -> None:
230
+ """Background-thread target. Spawns the SA3 pre_encode_dataset.py script,
231
+ streams stdout for progress, writes a _meta.json manifest on success."""
232
+ cfg = get_config()
233
+ proj_dir = project_path(project_name)
234
+ out_dir = latents_dir(project_name)
235
+ out_dir.mkdir(parents=True, exist_ok=True)
236
+
237
+ sa3_vendor = cfg.get_path("stable_audio_3")
238
+ venv_python = sys.executable
239
+
240
+ cmd = [
241
+ venv_python,
242
+ str(sa3_vendor / "scripts" / "pre_encode_dataset.py"),
243
+ "--model", ae,
244
+ "--data_dir", str(proj_dir),
245
+ "--output_path", str(out_dir),
246
+ "--batch_size", "1",
247
+ "--sample_size", str(int(sample_size)),
248
+ ]
249
+
250
+ env = os.environ.copy()
251
+ pp = env.get("PYTHONPATH", "")
252
+ env["PYTHONPATH"] = (
253
+ f"{sa3_vendor}{os.pathsep}{pp}" if pp else str(sa3_vendor)
254
+ )
255
+ hub_dir = cfg.get_path("models_pretrained") / "sa3" / "hub"
256
+ env["HF_HUB_CACHE"] = str(hub_dir)
257
+ env["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
258
+ env["TRANSFORMERS_CACHE"] = str(hub_dir)
259
+ env["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
260
+ env["HF_HUB_OFFLINE"] = "1"
261
+ env["TRANSFORMERS_OFFLINE"] = "1"
262
+
263
+ _update_job(project_name, state="running")
264
+ logger.info(
265
+ "Pre-encoding started Β· project=%s Β· autoencoder=%s Β· clips=%d Β· sample_size=%d",
266
+ project_name, ae, get_pre_encode_job(project_name)["total"], sample_size,
267
+ )
268
+
269
+ batch_pat = re.compile(r"Processing batch (\d+)")
270
+ process: Optional[subprocess.Popen] = None
271
+ try:
272
+ process = subprocess.Popen(
273
+ cmd,
274
+ cwd=str(cfg.project_root),
275
+ env=env,
276
+ stdout=subprocess.PIPE,
277
+ stderr=subprocess.STDOUT,
278
+ text=True,
279
+ bufsize=1,
280
+ )
281
+ _pre_encode_processes[project_name] = process
282
+
283
+ if process.stdout is not None:
284
+ for line in process.stdout:
285
+ line = line.rstrip()
286
+ m = batch_pat.search(line)
287
+ if m:
288
+ # Subprocess prints "Processing batch N" once per batch
289
+ # (and batch_size=1 β†’ one batch per clip). N starts at 0.
290
+ _update_job(project_name, current=int(m.group(1)) + 1)
291
+
292
+ rc = process.wait() if process else 1
293
+
294
+ # Check whether we got cancelled mid-flight.
295
+ snapshot = get_pre_encode_job(project_name)
296
+ if snapshot.get("cancelled"):
297
+ _update_job(
298
+ project_name,
299
+ state="cancelled",
300
+ finished_at=time.time(),
301
+ )
302
+ logger.info("Pre-encoding cancelled Β· project=%s", project_name)
303
+ return
304
+
305
+ if rc != 0:
306
+ _update_job(
307
+ project_name,
308
+ state="failed",
309
+ error=f"pre_encode_dataset.py exited with code {rc}",
310
+ finished_at=time.time(),
311
+ )
312
+ logger.error(
313
+ "Pre-encoding failed (exit %s) Β· project=%s",
314
+ rc, project_name,
315
+ )
316
+ return
317
+
318
+ # Success β€” write manifest so SA3Trainer can verify AE compatibility.
319
+ manifest = {
320
+ "autoencoder": ae,
321
+ "sample_size": sample_size,
322
+ "created_at": time.time(),
323
+ "source_clip_count": snapshot.get("total", 0),
324
+ "encoded_count": latents_count(project_name),
325
+ }
326
+ try:
327
+ (out_dir / "_meta.json").write_text(
328
+ json.dumps(manifest, indent=2), encoding="utf-8",
329
+ )
330
+ except Exception as exc:
331
+ logger.warning("Failed to write latents manifest: %s", exc)
332
+
333
+ _update_job(
334
+ project_name,
335
+ state="complete",
336
+ current=manifest["encoded_count"],
337
+ total=manifest["encoded_count"] or snapshot.get("total", 0),
338
+ finished_at=time.time(),
339
+ )
340
+ logger.info(
341
+ "Pre-encoding complete Β· project=%s Β· %d latent(s) Β· ae=%s",
342
+ project_name, manifest["encoded_count"], ae,
343
+ )
344
+
345
+ except Exception as exc:
346
+ _update_job(
347
+ project_name,
348
+ state="failed",
349
+ error=str(exc),
350
+ finished_at=time.time(),
351
+ )
352
+ logger.exception("Pre-encoding crashed for project=%s", project_name)
353
+ finally:
354
+ _pre_encode_processes.pop(project_name, None)
app/backend/data/projects.py ADDED
@@ -0,0 +1,1023 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """On-disk projects + buffered in-memory editing for SA3 sidecar datasets.
2
+
3
+ A *project* is a folder under `<user_data_dir>/projects/<name>/` (or wherever
4
+ `FRAGMENTA_PROJECTS_DIR` points) holding audio + `.txt` sidecar pairs plus a
5
+ hidden `.project.json` with Fragmenta metadata. The on-disk folder is the
6
+ **committed** dataset β€” what training reads, what survives across app
7
+ restarts.
8
+
9
+ The UI works against an **in-memory session** per loaded project. Prompt
10
+ edits, auto-annotate output, and just-ingested audio all live in memory
11
+ until the user explicitly persists them via:
12
+
13
+ Save β†’ write `.draft.json` (transient, hidden). Survives app restart
14
+ but is not the SA3 deliverable.
15
+ Commit β†’ flush prompts to `.txt` sidecars, mark current audio as
16
+ committed in `.project.json`, delete `.draft.json`.
17
+ Discard β†’ drop the in-memory session, delete `.draft.json`, remove any
18
+ audio files added since the last commit.
19
+
20
+ See DATASET_PREP_REDESIGN.md for the full design and rationale.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import json
26
+ import logging
27
+ import os
28
+ import re
29
+ import shutil
30
+ import threading
31
+ import time
32
+ from dataclasses import dataclass, field
33
+ from pathlib import Path
34
+ from typing import Any, Dict, List, Optional, Tuple
35
+
36
+ from app.backend.data.auto_annotator import AUDIO_EXTENSIONS, _iter_audio_files
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ PROJECT_METADATA_FILENAME = ".project.json"
41
+ PROJECT_DRAFT_FILENAME = ".draft.json"
42
+ DEFAULT_INGEST_MODE = "copy" # copy | symlink
43
+ INGEST_MODES = ("copy", "symlink")
44
+
45
+ # SA3's prompting guide (vendor/stable-audio-3/docs/guides/prompting.md)
46
+ # distinguishes three generation modes β€” music, stems / solo instruments,
47
+ # and audio samples / SFX β€” each with its own AudioSparx-tag convention.
48
+ # We ship one preset per mode and let the user pick a single id; the rest
49
+ # is opinionated defaults. Each segment is rendered by apply_template's
50
+ # segment-drop semantics, so missing CLAP attributes never leave dangling
51
+ # punctuation.
52
+ PROMPT_TEMPLATE_PRESETS: Dict[str, Dict[str, str]] = {
53
+ "music": {
54
+ "label": "Music",
55
+ "description": "Full instrumental tracks (SA3's `TrackType: Music` convention).",
56
+ "template": (
57
+ "TrackType: Music, VocalType: Instrumental, "
58
+ "Genre: {genre}, Mood: {mood}, Instruments: {instruments}, "
59
+ "BPM: {bpm}, Key: {key}"
60
+ ),
61
+ },
62
+ "instrument": {
63
+ "label": "Instrument / Stem",
64
+ "description": "Isolated parts or single-instrument pieces (`TrackType: Instrument`).",
65
+ "template": (
66
+ "TrackType: Instrument, "
67
+ "Instruments: {instruments}, Genre: {genre}, "
68
+ "BPM: {bpm}, Key: {key}, Mood: {mood}"
69
+ ),
70
+ },
71
+ "sfx": {
72
+ "label": "Sample / SFX",
73
+ "description": "Sound effects, one-shots, samples (`TrackType: SFX`).",
74
+ "template": "TrackType: SFX, {brightness}, {character}",
75
+ },
76
+ }
77
+ DEFAULT_PROMPT_TEMPLATE_PRESET = "music"
78
+
79
+ # Names must look like reasonable filesystem folders.
80
+ _VALID_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9 _\-.]{0,99}$")
81
+
82
+
83
+ # ---------- Locations -------------------------------------------------------
84
+
85
+
86
+ def get_projects_dir() -> Path:
87
+ """Resolve the projects root.
88
+
89
+ Honors `FRAGMENTA_PROJECTS_DIR` for power users; otherwise sits next to
90
+ `data/` and `models/` under the configured user_data_dir.
91
+ """
92
+ override = os.environ.get("FRAGMENTA_PROJECTS_DIR")
93
+ if override:
94
+ root = Path(override).expanduser()
95
+ else:
96
+ from app.core.config import get_config
97
+ root = get_config().user_data_dir / "projects"
98
+ root.mkdir(parents=True, exist_ok=True)
99
+ return root
100
+
101
+
102
+ def project_path(name: str) -> Path:
103
+ return get_projects_dir() / name
104
+
105
+
106
+ def project_metadata_path(name: str) -> Path:
107
+ return project_path(name) / PROJECT_METADATA_FILENAME
108
+
109
+
110
+ def project_draft_path(name: str) -> Path:
111
+ return project_path(name) / PROJECT_DRAFT_FILENAME
112
+
113
+
114
+ # ---------- Validation ------------------------------------------------------
115
+
116
+
117
+ def sanitize_project_name(raw: Any) -> str:
118
+ if not isinstance(raw, str):
119
+ raise ValueError("Project name must be a string.")
120
+ name = raw.strip()
121
+ if not name:
122
+ raise ValueError("Project name cannot be empty.")
123
+ if name in (".", ".."):
124
+ raise ValueError("Invalid project name.")
125
+ if not _VALID_NAME_RE.match(name):
126
+ raise ValueError(
127
+ "Project name must start with a letter or digit and may only "
128
+ "contain letters, digits, spaces, dashes, underscores, and dots."
129
+ )
130
+ return name
131
+
132
+
133
+ # ---------- Disk persistence: committed state -------------------------------
134
+
135
+
136
+ def _default_metadata(name: str) -> Dict[str, Any]:
137
+ now = time.time()
138
+ return {
139
+ "name": name,
140
+ "created_at": now,
141
+ "modified_at": now,
142
+ "committed_at": None,
143
+ "ingest_mode": DEFAULT_INGEST_MODE,
144
+ "prompt_template_preset": DEFAULT_PROMPT_TEMPLATE_PRESET,
145
+ "source_folders": [],
146
+ "committed_files": [], # files written to disk + already committed
147
+ }
148
+
149
+
150
+ def _read_metadata(name: str) -> Dict[str, Any]:
151
+ path = project_metadata_path(name)
152
+ if not path.exists():
153
+ return _default_metadata(name)
154
+ try:
155
+ with open(path, "r", encoding="utf-8") as f:
156
+ data = json.load(f)
157
+ except (OSError, json.JSONDecodeError) as exc:
158
+ logger.warning("Could not read project metadata %s: %s; using defaults.", path, exc)
159
+ return _default_metadata(name)
160
+ defaults = _default_metadata(name)
161
+ for k, v in defaults.items():
162
+ data.setdefault(k, v)
163
+ return data
164
+
165
+
166
+ def _write_metadata(name: str, metadata: Dict[str, Any]) -> None:
167
+ metadata["modified_at"] = time.time()
168
+ path = project_metadata_path(name)
169
+ path.parent.mkdir(parents=True, exist_ok=True)
170
+ tmp = path.with_suffix(path.suffix + ".tmp")
171
+ with open(tmp, "w", encoding="utf-8") as f:
172
+ json.dump(metadata, f, indent=2)
173
+ os.replace(tmp, path)
174
+
175
+
176
+ def _sidecar_for(audio_path: Path) -> Path:
177
+ return audio_path.with_suffix(".txt")
178
+
179
+
180
+ def _read_sidecar(audio_path: Path) -> str:
181
+ txt = _sidecar_for(audio_path)
182
+ if not txt.exists():
183
+ return ""
184
+ try:
185
+ return txt.read_text(encoding="utf-8").strip()
186
+ except OSError:
187
+ return ""
188
+
189
+
190
+ def _write_sidecar(audio_path: Path, prompt: str) -> None:
191
+ _sidecar_for(audio_path).write_text(prompt or "", encoding="utf-8")
192
+
193
+
194
+ # ---------- Disk persistence: draft state -----------------------------------
195
+
196
+
197
+ def _read_draft(name: str) -> Optional[Dict[str, Any]]:
198
+ path = project_draft_path(name)
199
+ if not path.exists():
200
+ return None
201
+ try:
202
+ with open(path, "r", encoding="utf-8") as f:
203
+ return json.load(f)
204
+ except (OSError, json.JSONDecodeError) as exc:
205
+ logger.warning("Could not read draft %s: %s; treating as no draft.", path, exc)
206
+ return None
207
+
208
+
209
+ def _write_draft(name: str, draft: Dict[str, Any]) -> None:
210
+ draft["saved_at"] = time.time()
211
+ path = project_draft_path(name)
212
+ path.parent.mkdir(parents=True, exist_ok=True)
213
+ tmp = path.with_suffix(path.suffix + ".tmp")
214
+ with open(tmp, "w", encoding="utf-8") as f:
215
+ json.dump(draft, f, indent=2)
216
+ os.replace(tmp, path)
217
+
218
+
219
+ def _delete_draft(name: str) -> None:
220
+ path = project_draft_path(name)
221
+ if path.exists():
222
+ path.unlink()
223
+
224
+
225
+ # ---------- In-memory session ----------------------------------------------
226
+
227
+
228
+ @dataclass
229
+ class ClipState:
230
+ """One clip in an active project session.
231
+
232
+ `prompt` is the live in-memory value (what the UI shows). `committed_prompt`
233
+ is what's on disk in the sidecar β€” used to compute dirtiness.
234
+
235
+ `parent` is the original clip's file_name if this clip was produced by a
236
+ slice operation in the current session. In-memory only; not persisted
237
+ across restart (yet). Future merge-back will need disk-level lineage.
238
+ """
239
+ file_name: str
240
+ path: str
241
+ prompt: str = ""
242
+ committed_prompt: str = ""
243
+ committed: bool = True # False if audio was added since last commit
244
+ parent: Optional[str] = None
245
+
246
+ def to_dict(self) -> Dict[str, Any]:
247
+ return {
248
+ "file_name": self.file_name,
249
+ "path": self.path,
250
+ "prompt": self.prompt,
251
+ "committed_prompt": self.committed_prompt,
252
+ "committed": self.committed,
253
+ "dirty": self.prompt != self.committed_prompt,
254
+ "parent": self.parent,
255
+ }
256
+
257
+
258
+ @dataclass
259
+ class ProjectSession:
260
+ """In-memory view of a project. One per loaded project name.
261
+
262
+ Loading happens lazily on first GET. The session stays alive until
263
+ the user discards, commits, or the process exits.
264
+ """
265
+ name: str
266
+ clips: Dict[str, ClipState] = field(default_factory=dict) # by file_name
267
+ saved_at: Optional[float] = None # last time .draft.json was written
268
+ last_save_snapshot: Dict[str, str] = field(default_factory=dict)
269
+ metadata: Dict[str, Any] = field(default_factory=dict)
270
+ cancel_event: threading.Event = field(default_factory=threading.Event)
271
+ lock: threading.Lock = field(default_factory=threading.Lock)
272
+ # file_name -> (peaks, duration). Lazily filled by get_or_compute_peaks.
273
+ # Cleared on Discard. Survives an annotate; safe to recompute on miss.
274
+ peaks_cache: Dict[str, Tuple[List[float], float]] = field(default_factory=dict)
275
+ # file_name -> duration_sec. Same lifecycle, but populated cheaply via
276
+ # soundfile.info() instead of waiting for a peaks fetch.
277
+ duration_cache: Dict[str, float] = field(default_factory=dict)
278
+
279
+ def _draft_snapshot(self) -> Dict[str, str]:
280
+ """Map file_name -> prompt, only for clips whose prompt differs from
281
+ the committed sidecar. Used both to decide if a Save is needed and
282
+ to compute the on-disk draft contents."""
283
+ return {c.file_name: c.prompt for c in self.clips.values() if c.prompt != c.committed_prompt}
284
+
285
+ def has_dirty_prompts(self) -> bool:
286
+ return any(c.prompt != c.committed_prompt for c in self.clips.values())
287
+
288
+ def has_uncommitted_files(self) -> bool:
289
+ return any(not c.committed for c in self.clips.values())
290
+
291
+ def has_unsaved_changes(self) -> bool:
292
+ """True if the in-memory state differs from the saved draft."""
293
+ return self._draft_snapshot() != self.last_save_snapshot
294
+
295
+ def to_dict(self) -> Dict[str, Any]:
296
+ ordered = sorted(self.clips.values(), key=lambda c: c.file_name)
297
+ # Phase 6 β€” pre-encoded latents state. The latents live inside the
298
+ # project at .latents/. Surface presence + count for the UI, plus
299
+ # the per-project "don't ask again" flag for the post-commit dialog.
300
+ proj_path = project_path(self.name)
301
+ latents_dir = proj_path / ".latents"
302
+ latents_npy = (
303
+ [p for p in latents_dir.glob("*.npy") if p.name != "silence.npy"]
304
+ if latents_dir.exists() else []
305
+ )
306
+ return {
307
+ "name": self.name,
308
+ "created_at": self.metadata.get("created_at"),
309
+ "modified_at": self.metadata.get("modified_at"),
310
+ "committed_at": self.metadata.get("committed_at"),
311
+ "ingest_mode": self.metadata.get("ingest_mode", DEFAULT_INGEST_MODE),
312
+ "prompt_template_preset": (
313
+ self.metadata.get("prompt_template_preset") or DEFAULT_PROMPT_TEMPLATE_PRESET
314
+ ),
315
+ "prompt_template_presets": [
316
+ {"id": k, "label": v["label"], "description": v["description"], "template": v["template"]}
317
+ for k, v in PROMPT_TEMPLATE_PRESETS.items()
318
+ ],
319
+ "source_folders": list(self.metadata.get("source_folders", [])),
320
+ "saved_at": self.saved_at,
321
+ "dirty": self.has_dirty_prompts() or self.has_uncommitted_files(),
322
+ "has_unsaved_changes": self.has_unsaved_changes(),
323
+ "uncommitted_files": [c.file_name for c in ordered if not c.committed],
324
+ "clips": [c.to_dict() for c in ordered],
325
+ "clip_count": len(self.clips),
326
+ "latents_present": bool(latents_npy),
327
+ "latents_count": len(latents_npy),
328
+ "suppress_pre_encode_prompt": bool(self.metadata.get("suppress_pre_encode_prompt")),
329
+ }
330
+
331
+
332
+ # Registry of active sessions keyed by project name.
333
+ _sessions: Dict[str, ProjectSession] = {}
334
+ _sessions_lock = threading.Lock()
335
+
336
+
337
+ def _get_or_load_session(name: str) -> ProjectSession:
338
+ """Return the active session for `name`, loading from disk if needed."""
339
+ with _sessions_lock:
340
+ existing = _sessions.get(name)
341
+ if existing is not None:
342
+ return existing
343
+
344
+ # Validate folder exists.
345
+ path = project_path(name)
346
+ if not path.exists() or not path.is_dir():
347
+ raise FileNotFoundError(f"Project not found: {name}")
348
+
349
+ metadata = _read_metadata(name)
350
+ committed_files = set(metadata.get("committed_files") or [])
351
+
352
+ # Build clip states from the disk layout. `committed_prompt` is whatever's
353
+ # in the .txt sidecar today.
354
+ clips: Dict[str, ClipState] = {}
355
+ for audio_path in sorted(path.iterdir()):
356
+ if not audio_path.is_file():
357
+ continue
358
+ if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
359
+ continue
360
+ committed_prompt = _read_sidecar(audio_path)
361
+ is_committed = audio_path.name in committed_files
362
+ clips[audio_path.name] = ClipState(
363
+ file_name=audio_path.name,
364
+ path=str(audio_path),
365
+ prompt=committed_prompt,
366
+ committed_prompt=committed_prompt,
367
+ committed=is_committed,
368
+ )
369
+
370
+ session = ProjectSession(name=name, clips=clips, metadata=metadata)
371
+
372
+ # Overlay any draft prompts on top of committed values.
373
+ draft = _read_draft(name)
374
+ if draft:
375
+ for file_name, prompt in (draft.get("prompts") or {}).items():
376
+ clip = session.clips.get(file_name)
377
+ if clip is not None:
378
+ clip.prompt = prompt
379
+ session.saved_at = draft.get("saved_at")
380
+ session.last_save_snapshot = dict(draft.get("prompts") or {})
381
+
382
+ with _sessions_lock:
383
+ # Race: another thread may have loaded concurrently. Use whichever
384
+ # got in first.
385
+ existing = _sessions.get(name)
386
+ if existing is not None:
387
+ return existing
388
+ _sessions[name] = session
389
+ return session
390
+
391
+
392
+ def _drop_session(name: str) -> None:
393
+ with _sessions_lock:
394
+ _sessions.pop(name, None)
395
+
396
+
397
+ # ---------- CRUD ------------------------------------------------------------
398
+
399
+
400
+ def list_projects() -> List[Dict[str, Any]]:
401
+ root = get_projects_dir()
402
+ out: List[Dict[str, Any]] = []
403
+ for entry in sorted(root.iterdir()):
404
+ if not entry.is_dir() or entry.name.startswith("."):
405
+ continue
406
+ try:
407
+ meta = _read_metadata(entry.name)
408
+ except Exception as exc:
409
+ logger.warning("Skipping project %s: %s", entry.name, exc)
410
+ continue
411
+ clip_count = sum(
412
+ 1 for f in entry.iterdir()
413
+ if f.is_file() and f.suffix.lower() in AUDIO_EXTENSIONS
414
+ )
415
+ has_draft = project_draft_path(entry.name).exists()
416
+ out.append({
417
+ "name": entry.name,
418
+ "created_at": meta.get("created_at"),
419
+ "modified_at": meta.get("modified_at"),
420
+ "committed_at": meta.get("committed_at"),
421
+ "clip_count": clip_count,
422
+ "has_draft": has_draft,
423
+ })
424
+ return out
425
+
426
+
427
+ def create_project(name: str) -> Dict[str, Any]:
428
+ name = sanitize_project_name(name)
429
+ path = project_path(name)
430
+ if path.exists():
431
+ raise FileExistsError(f"Project '{name}' already exists.")
432
+ path.mkdir(parents=True)
433
+ metadata = _default_metadata(name)
434
+ _write_metadata(name, metadata)
435
+ return get_project(name)
436
+
437
+
438
+ def get_project(name: str) -> Dict[str, Any]:
439
+ session = _get_or_load_session(name)
440
+ with session.lock:
441
+ return session.to_dict()
442
+
443
+
444
+ def _stage_file(src: Path, dst: Path, mode: str) -> str:
445
+ """Place `src` at `dst` using the requested ingest mode."""
446
+ if dst.exists() or dst.is_symlink():
447
+ return "skipped"
448
+ if mode == "symlink":
449
+ try:
450
+ dst.symlink_to(src.resolve())
451
+ return "symlinked"
452
+ except OSError as exc:
453
+ logger.warning("Symlink failed for %s -> %s: %s; falling back to copy.", src, dst, exc)
454
+ shutil.copy2(src, dst)
455
+ return "copied"
456
+ else:
457
+ shutil.copy2(src, dst)
458
+ return "copied"
459
+
460
+
461
+ def ingest_folder(name: str, source_folder: Path, mode: str) -> Dict[str, Any]:
462
+ """Add every audio file under `source_folder` to project `name`.
463
+
464
+ Audio is written to disk immediately (we don't buffer gigabytes). The
465
+ new files are flagged as uncommitted in the session so a later Discard
466
+ can remove them.
467
+ """
468
+ if mode not in INGEST_MODES:
469
+ raise ValueError(f"Invalid ingest mode: {mode}")
470
+ if not source_folder.exists() or not source_folder.is_dir():
471
+ raise FileNotFoundError(f"Source folder not found: {source_folder}")
472
+
473
+ session = _get_or_load_session(name)
474
+ proj_path = project_path(name)
475
+
476
+ files = _iter_audio_files(source_folder)
477
+ if not files:
478
+ raise ValueError(f"No audio files found in {source_folder}")
479
+
480
+ copied = 0
481
+ symlinked = 0
482
+ skipped = 0
483
+ with session.lock:
484
+ for src in files:
485
+ dst = proj_path / src.name
486
+ tag = _stage_file(src, dst, mode)
487
+ if tag == "copied":
488
+ copied += 1
489
+ elif tag == "symlinked":
490
+ symlinked += 1
491
+ else:
492
+ skipped += 1
493
+ if tag != "skipped" and src.name not in session.clips:
494
+ # Newly added file β€” uncommitted.
495
+ session.clips[src.name] = ClipState(
496
+ file_name=src.name,
497
+ path=str(dst),
498
+ prompt="",
499
+ committed_prompt="",
500
+ committed=False,
501
+ )
502
+
503
+ session.metadata["ingest_mode"] = mode
504
+ src_abs = str(source_folder.resolve())
505
+ if src_abs not in session.metadata.setdefault("source_folders", []):
506
+ session.metadata["source_folders"].append(src_abs)
507
+
508
+ return {
509
+ "copied": copied,
510
+ "symlinked": symlinked,
511
+ "skipped": skipped,
512
+ "added": copied + symlinked,
513
+ }
514
+
515
+
516
+ def update_clip_prompt(name: str, file_name: str, prompt: str) -> Dict[str, Any]:
517
+ """In-memory only. Disk is not touched until Save or Commit."""
518
+ session = _get_or_load_session(name)
519
+ with session.lock:
520
+ clip = session.clips.get(file_name)
521
+ if clip is None:
522
+ raise FileNotFoundError(f"Clip not found in project '{name}': {file_name}")
523
+ clip.prompt = prompt or ""
524
+ return clip.to_dict()
525
+
526
+
527
+ def delete_clip(name: str, file_name: str) -> None:
528
+ """Remove a clip immediately (audio + sidecar + session entry).
529
+
530
+ Treated like ingest: the disk change happens now, since carrying a
531
+ pending-deletion in memory complicates everything for no real win.
532
+ Discard cannot recover deleted files.
533
+ """
534
+ session = _get_or_load_session(name)
535
+ proj_path = project_path(name)
536
+ with session.lock:
537
+ audio_path = proj_path / file_name
538
+ txt_path = _sidecar_for(audio_path)
539
+ if audio_path.exists():
540
+ audio_path.unlink()
541
+ if txt_path.exists():
542
+ txt_path.unlink()
543
+ session.clips.pop(file_name, None)
544
+ # Evict any cached peaks for this file (regardless of N).
545
+ for key in list(session.peaks_cache):
546
+ if key.startswith(f"{file_name}:"):
547
+ del session.peaks_cache[key]
548
+ session.duration_cache.pop(file_name, None)
549
+ committed = session.metadata.get("committed_files") or []
550
+ if file_name in committed:
551
+ session.metadata["committed_files"] = [f for f in committed if f != file_name]
552
+ # Invalidate latents β€” outside the lock so we don't block under FS I/O.
553
+ _invalidate_latents(name)
554
+
555
+
556
+ # ---------- Save / Commit / Discard -----------------------------------------
557
+
558
+
559
+ def save_project(name: str) -> Dict[str, Any]:
560
+ """Persist the current in-memory prompt diffs as a hidden draft."""
561
+ session = _get_or_load_session(name)
562
+ with session.lock:
563
+ snapshot = session._draft_snapshot()
564
+ draft = {
565
+ "prompts": snapshot,
566
+ "uncommitted_files": [c.file_name for c in session.clips.values() if not c.committed],
567
+ }
568
+ _write_draft(name, draft)
569
+ session.saved_at = time.time()
570
+ session.last_save_snapshot = dict(snapshot)
571
+ return session.to_dict()
572
+
573
+
574
+ def _invalidate_latents(name: str) -> None:
575
+ """Phase 6 β€” wipe any pre-encoded latents for this project.
576
+
577
+ Latents are bound to specific source-clip content; any mutation that
578
+ changes the source set (commit, delete_clip, slice_clip) renders them
579
+ misaligned. v1 strategy is wipe-and-recompute; per-clip invalidation
580
+ is a follow-up (not worth the complexity for the speed-up we get).
581
+ """
582
+ latents_dir = project_path(name) / ".latents"
583
+ if latents_dir.exists():
584
+ shutil.rmtree(latents_dir, ignore_errors=True)
585
+
586
+
587
+ def update_pre_encode_suppression(name: str, suppress: bool) -> Dict[str, Any]:
588
+ """Persist the 'Don't ask again' choice from the post-commit dialog.
589
+
590
+ Stored on .project.json so it survives restart. The Training-tab
591
+ fallback button is always available regardless of this flag.
592
+ """
593
+ session = _get_or_load_session(name)
594
+ with session.lock:
595
+ session.metadata["suppress_pre_encode_prompt"] = bool(suppress)
596
+ _write_metadata(name, session.metadata)
597
+ return session.to_dict()
598
+
599
+
600
+ def commit_project(name: str) -> Dict[str, Any]:
601
+ """Flush in-memory state to disk as the canonical SA3 dataset.
602
+
603
+ Overwrites existing sidecars. Marks all current audio as committed.
604
+ Deletes any draft. Wipes any pre-encoded latents β€” re-encode is
605
+ explicit via the post-commit dialog or the Training-tab button.
606
+ """
607
+ _invalidate_latents(name)
608
+ session = _get_or_load_session(name)
609
+ proj_path = project_path(name)
610
+ with session.lock:
611
+ # Write a sidecar for every clip, even if the prompt didn't change.
612
+ # This guarantees the on-disk state is exactly the in-memory state
613
+ # after Commit, no surprises.
614
+ for clip in session.clips.values():
615
+ audio_path = proj_path / clip.file_name
616
+ _write_sidecar(audio_path, clip.prompt)
617
+ clip.committed_prompt = clip.prompt
618
+ clip.committed = True
619
+
620
+ session.metadata["committed_files"] = sorted(session.clips.keys())
621
+ session.metadata["committed_at"] = time.time()
622
+ _write_metadata(name, session.metadata)
623
+ _delete_draft(name)
624
+ session.saved_at = None
625
+ session.last_save_snapshot = {}
626
+ return session.to_dict()
627
+
628
+
629
+ def delete_project(name: str) -> None:
630
+ """Permanently remove a project β€” folder, sidecars, drafts, session.
631
+
632
+ Destructive: there is no recovery path. Caller should confirm with
633
+ the user before invoking.
634
+ """
635
+ proj_path = project_path(name)
636
+ if not proj_path.exists():
637
+ raise FileNotFoundError(f"Project not found: {name}")
638
+
639
+ # Cancel any in-flight annotate first, drop the session, then nuke
640
+ # the folder. Order matters: if we rm the folder while another
641
+ # thread is writing to it (e.g. annotate writing prompts to memory
642
+ # is fine, but the audio-stream endpoint could be holding a file
643
+ # handle), at least the session is gone so no fresh writes happen.
644
+ with _sessions_lock:
645
+ existing = _sessions.pop(name, None)
646
+ if existing is not None:
647
+ existing.cancel_event.set()
648
+ shutil.rmtree(proj_path, ignore_errors=True)
649
+
650
+
651
+ def discard_project(name: str) -> Dict[str, Any]:
652
+ """Throw away all uncommitted work.
653
+
654
+ - Delete the draft.
655
+ - Delete audio files added since the last commit (and their sidecars).
656
+ - Drop the in-memory session so the next GET rebuilds from disk.
657
+ """
658
+ session = _get_or_load_session(name)
659
+ proj_path = project_path(name)
660
+ with session.lock:
661
+ # Cancel any in-flight annotate before we tear state apart.
662
+ session.cancel_event.set()
663
+
664
+ uncommitted = [c.file_name for c in session.clips.values() if not c.committed]
665
+ for file_name in uncommitted:
666
+ audio_path = proj_path / file_name
667
+ txt_path = _sidecar_for(audio_path)
668
+ if audio_path.exists():
669
+ audio_path.unlink()
670
+ if txt_path.exists():
671
+ txt_path.unlink()
672
+
673
+ _delete_draft(name)
674
+
675
+ _drop_session(name)
676
+ return get_project(name)
677
+
678
+
679
+ # ---------- Annotate cancellation handle ------------------------------------
680
+
681
+
682
+ def get_session_handle(name: str) -> ProjectSession:
683
+ """Used by the annotate endpoint to share a cancel handle + clip dict."""
684
+ return _get_or_load_session(name)
685
+
686
+
687
+ def reset_cancel(session: ProjectSession) -> None:
688
+ session.cancel_event.clear()
689
+
690
+
691
+ # ---------- Prompt template -------------------------------------------------
692
+
693
+
694
+ _TEMPLATE_VAR_RE = re.compile(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}")
695
+
696
+
697
+ def _render_value(name: str, raw: Any) -> str:
698
+ """Stringify one variable value. Lists get joined; falsy is empty."""
699
+ if raw is None:
700
+ return ""
701
+ if isinstance(raw, (list, tuple)):
702
+ parts = [str(x).strip() for x in raw if str(x).strip()]
703
+ return ", ".join(parts)
704
+ text = str(raw).strip()
705
+ return text
706
+
707
+
708
+ def apply_template(template: str, attributes: Dict[str, Any]) -> str:
709
+ """Segment-based templating with graceful missing-value handling.
710
+
711
+ The template is split on ',' (segments). For each segment, every
712
+ {var} placeholder is resolved against `attributes`. If any placeholder
713
+ in the segment resolves to empty/missing, the whole segment is dropped
714
+ β€” so a missing key/BPM/whatever doesn't leave dangling punctuation.
715
+
716
+ Segments without any placeholders (e.g. "TrackType: Music") always
717
+ appear.
718
+ """
719
+ if not template:
720
+ return ""
721
+ out_segments: List[str] = []
722
+ for raw_segment in template.split(","):
723
+ segment = raw_segment.strip()
724
+ if not segment:
725
+ continue
726
+ var_names = _TEMPLATE_VAR_RE.findall(segment)
727
+ if var_names:
728
+ resolved = {n: _render_value(n, attributes.get(n)) for n in var_names}
729
+ if any(not v for v in resolved.values()):
730
+ continue # drop the segment β€” one of its vars is missing
731
+ segment = _TEMPLATE_VAR_RE.sub(
732
+ lambda m: resolved[m.group(1)],
733
+ segment,
734
+ )
735
+ out_segments.append(segment)
736
+ return ", ".join(out_segments)
737
+
738
+
739
+ def resolve_prompt_template(session: "ProjectSession") -> str:
740
+ """Return the active template string for the project's selected preset.
741
+
742
+ Falls back to the music default if the stored preset id is unknown
743
+ (e.g. someone hand-edited .project.json to a bad value).
744
+ """
745
+ preset_id = (session.metadata.get("prompt_template_preset")
746
+ or DEFAULT_PROMPT_TEMPLATE_PRESET)
747
+ preset = PROMPT_TEMPLATE_PRESETS.get(preset_id)
748
+ if preset is None:
749
+ preset = PROMPT_TEMPLATE_PRESETS[DEFAULT_PROMPT_TEMPLATE_PRESET]
750
+ return preset["template"]
751
+
752
+
753
+ def update_project_template_preset(name: str, preset_id: str) -> Dict[str, Any]:
754
+ """Persist the user-selected preset id and return updated project state."""
755
+ if not isinstance(preset_id, str) or preset_id not in PROMPT_TEMPLATE_PRESETS:
756
+ valid = ", ".join(PROMPT_TEMPLATE_PRESETS.keys())
757
+ raise ValueError(f"Unknown preset id: {preset_id!r}. Valid: {valid}")
758
+ session = _get_or_load_session(name)
759
+ with session.lock:
760
+ session.metadata["prompt_template_preset"] = preset_id
761
+ # Drop the legacy free-form field so we stop carrying two parallel
762
+ # ways to configure annotation shape.
763
+ session.metadata.pop("prompt_template", None)
764
+ _write_metadata(name, session.metadata)
765
+ return get_project(name)
766
+
767
+
768
+ # ---------- Waveform peaks --------------------------------------------------
769
+
770
+
771
+ def _compute_peaks(audio_path: Path, n: int) -> Tuple[List[float], float]:
772
+ """Return N normalized peak amplitudes + duration in seconds.
773
+
774
+ Reads N short blocks at evenly spaced offsets via soundfile.seek instead
775
+ of decoding the whole file. ~40x faster than librosa.load on a typical
776
+ 30s clip; bounded I/O regardless of file length (a 5-minute clip costs
777
+ the same as a 30s one).
778
+
779
+ Falls back to a librosa-based decode for formats soundfile can't open
780
+ on this build (typically m4a/aac without ffmpeg-libsndfile).
781
+ """
782
+ import numpy as np
783
+ try:
784
+ import soundfile as sf
785
+ with sf.SoundFile(str(audio_path)) as src:
786
+ total = src.frames
787
+ sr = src.samplerate
788
+ if total == 0:
789
+ return ([0.0] * n, 0.0)
790
+ duration = float(total / sr)
791
+ # ~6 buckets-worth of samples per probe gives stable peaks without
792
+ # devolving into "read the whole file."
793
+ block = max(256, total // (n * 6))
794
+ peaks = np.zeros(n, dtype="float32")
795
+ for i in range(n):
796
+ center = int((i + 0.5) * total / n)
797
+ start = max(0, center - block // 2)
798
+ src.seek(start)
799
+ data = src.read(block, dtype="float32", always_2d=False)
800
+ if data.ndim > 1:
801
+ data = data.max(axis=1)
802
+ if len(data):
803
+ peaks[i] = float(np.abs(data).max())
804
+ max_peak = float(peaks.max())
805
+ if max_peak > 0:
806
+ peaks = peaks / max_peak
807
+ return (peaks.tolist(), duration)
808
+ except Exception as exc:
809
+ logger.debug("soundfile peak path failed for %s (%s); falling back to librosa", audio_path.name, exc)
810
+
811
+ # Fallback: librosa.load handles every codec we register, at the cost of
812
+ # a full-file decode + resample. Slower but bulletproof.
813
+ import librosa
814
+ y, sr = librosa.load(str(audio_path), sr=8000, mono=True)
815
+ if len(y) == 0:
816
+ return ([0.0] * n, 0.0)
817
+ duration = float(len(y) / sr)
818
+ chunks = np.array_split(y, n)
819
+ peaks = np.array([float(np.abs(c).max()) if len(c) else 0.0 for c in chunks])
820
+ max_peak = peaks.max()
821
+ if max_peak > 0:
822
+ peaks = peaks / max_peak
823
+ return (peaks.tolist(), duration)
824
+
825
+
826
+ def get_or_compute_peaks(
827
+ session: ProjectSession,
828
+ file_name: str,
829
+ audio_path: Path,
830
+ n: int = 200,
831
+ ) -> Tuple[List[float], float]:
832
+ """Memoized per-session peak computation. Cache key is `file_name:N`."""
833
+ cache_key = f"{file_name}:{n}"
834
+ cached = session.peaks_cache.get(cache_key)
835
+ if cached is not None:
836
+ return cached
837
+ result = _compute_peaks(audio_path, n)
838
+ session.peaks_cache[cache_key] = result
839
+ return result
840
+
841
+
842
+ # ---------- Health checks ---------------------------------------------------
843
+
844
+
845
+ def _clip_duration_sec(audio_path: Path) -> Optional[float]:
846
+ """Cheap duration probe via soundfile.info() β€” header read, no decode."""
847
+ try:
848
+ import soundfile as sf
849
+ info = sf.info(str(audio_path))
850
+ if info.samplerate <= 0:
851
+ return None
852
+ return float(info.frames / info.samplerate)
853
+ except Exception:
854
+ return None
855
+
856
+
857
+ def compute_health(
858
+ name: str,
859
+ short_threshold_sec: float = 1.0,
860
+ ) -> Dict[str, Any]:
861
+ """Per-clip checks that surface dataset problems before training.
862
+
863
+ Note: we don't flag "too long" clips. The SA3 dataloader handles them
864
+ via random-crop per __getitem__ β€” long files just get sampled at
865
+ different windows across epochs. Slicing remains useful for annotation
866
+ granularity and CLAP's 10s window, but it's not a correctness issue.
867
+
868
+ We also don't flag mixed sample rates or loudness: SA3 resamples every
869
+ file to its model rate (T.Resample in its dataset loader) and Fragmenta
870
+ enables SA3's built-in -16 LUFS VolumeNorm at train/pre-encode time, so
871
+ both are handled automatically downstream.
872
+
873
+ short_threshold_sec defaults to 1s β€” clips below this end up mostly
874
+ silence-padded into the training window.
875
+ """
876
+ from collections import defaultdict
877
+
878
+ # Single source of truth for what SA3's loader actually accepts. Fragmenta
879
+ # ingest accepts a wider set (.m4a, .aac) β€” those files would be silently
880
+ # skipped at train time, so we surface them here.
881
+ from app.core.training.sa3_lora_runner import SA3_AUDIO_EXTENSIONS
882
+
883
+ session = _get_or_load_session(name)
884
+ with session.lock:
885
+ clips = list(session.clips.values())
886
+
887
+ empty_prompts: List[str] = []
888
+ too_short: List[str] = []
889
+ unsupported_format: List[str] = []
890
+ prompt_groups: Dict[str, List[str]] = defaultdict(list)
891
+
892
+ for c in clips:
893
+ if not (c.prompt or "").strip():
894
+ empty_prompts.append(c.file_name)
895
+ else:
896
+ prompt_groups[c.prompt.strip().lower()].append(c.file_name)
897
+
898
+ ext = Path(c.file_name).suffix.lower()
899
+ if ext not in SA3_AUDIO_EXTENSIONS:
900
+ unsupported_format.append(c.file_name)
901
+
902
+ # Duration (header-only, ~free) β€” only used for the too-short check now.
903
+ dur = session.duration_cache.get(c.file_name)
904
+ if dur is None:
905
+ dur = _clip_duration_sec(Path(c.path))
906
+ if dur is not None:
907
+ session.duration_cache[c.file_name] = dur
908
+ if dur is not None and dur < short_threshold_sec:
909
+ too_short.append(c.file_name)
910
+
911
+ # --- Duplicate annotations: any non-empty prompt shared by 2+ clips.
912
+ dup_groups = [files for files in prompt_groups.values() if len(files) > 1]
913
+ dup_files = sorted({f for group in dup_groups for f in group})
914
+
915
+ empty_prompts.sort()
916
+ too_short.sort()
917
+ unsupported_format.sort()
918
+
919
+ return {
920
+ "total_clips": len(clips),
921
+ "empty_prompts": {"count": len(empty_prompts), "files": empty_prompts},
922
+ "too_short": {
923
+ "count": len(too_short),
924
+ "threshold_sec": short_threshold_sec,
925
+ "files": too_short,
926
+ },
927
+ "unsupported_format": {
928
+ "count": len(unsupported_format),
929
+ "accepted": sorted(SA3_AUDIO_EXTENSIONS),
930
+ "files": unsupported_format,
931
+ },
932
+ "duplicate_annotations": {
933
+ "count": len(dup_files),
934
+ "group_count": len(dup_groups),
935
+ "files": dup_files,
936
+ },
937
+ }
938
+
939
+
940
+ # ---------- Slicing ---------------------------------------------------------
941
+
942
+
943
+ def slice_clip(
944
+ name: str,
945
+ file_name: str,
946
+ target_sec: float,
947
+ overlap_sec: float,
948
+ strategy: str,
949
+ ) -> Dict[str, Any]:
950
+ """Split one clip into N children. Disk-level β€” happens immediately.
951
+
952
+ The parent file (and its sidecar) is deleted. Each child:
953
+ - lives in the project folder as `<stem>__NNN.wav`
954
+ - inherits the parent's in-memory prompt verbatim
955
+ - is uncommitted (so Discard rolls it back)
956
+ - keeps `parent=<parent_file_name>` in its session state
957
+
958
+ Discard cannot recover the parent file from children β€” same rule as
959
+ delete_clip. Commit makes the slice permanent.
960
+ """
961
+ from app.backend.data.slicing import plan_slices, write_slices
962
+
963
+ session = _get_or_load_session(name)
964
+ proj_path = project_path(name)
965
+ audio_path = proj_path / file_name
966
+
967
+ if not audio_path.exists():
968
+ raise FileNotFoundError(f"Clip not on disk: {file_name}")
969
+
970
+ plans = plan_slices(audio_path, target_sec, overlap_sec, strategy)
971
+ if len(plans) <= 1:
972
+ raise ValueError(
973
+ f"{file_name} is shorter than the target duration "
974
+ f"({target_sec:.1f}s); nothing to slice."
975
+ )
976
+
977
+ stem = audio_path.stem
978
+ children = write_slices(audio_path, plans, proj_path, stem)
979
+ if not children:
980
+ raise RuntimeError("Slice produced no children β€” check the audio file.")
981
+
982
+ with session.lock:
983
+ parent_clip = session.clips.get(file_name)
984
+ inherited_prompt = parent_clip.prompt if parent_clip else ""
985
+
986
+ # Remove the parent from session + disk.
987
+ session.clips.pop(file_name, None)
988
+ for key in list(session.peaks_cache):
989
+ if key.startswith(f"{file_name}:"):
990
+ del session.peaks_cache[key]
991
+ session.duration_cache.pop(file_name, None)
992
+ sidecar = _sidecar_for(audio_path)
993
+ if audio_path.exists():
994
+ audio_path.unlink()
995
+ if sidecar.exists():
996
+ sidecar.unlink()
997
+ committed = session.metadata.get("committed_files") or []
998
+ if file_name in committed:
999
+ session.metadata["committed_files"] = [f for f in committed if f != file_name]
1000
+
1001
+ # Register children as uncommitted clips with parent linkage.
1002
+ for child_path in children:
1003
+ session.clips[child_path.name] = ClipState(
1004
+ file_name=child_path.name,
1005
+ path=str(child_path),
1006
+ prompt=inherited_prompt,
1007
+ committed_prompt="",
1008
+ committed=False,
1009
+ parent=file_name,
1010
+ )
1011
+
1012
+ # Slicing replaces the parent's audio with N children β†’ any cached
1013
+ # latents reference the deleted parent and are now misaligned.
1014
+ _invalidate_latents(name)
1015
+
1016
+ return {
1017
+ "parent": file_name,
1018
+ "children": [
1019
+ {"file_name": p.name, "start_sec": pl.start_sec, "end_sec": pl.end_sec}
1020
+ for p, pl in zip(children, plans)
1021
+ ],
1022
+ "project": get_project(name),
1023
+ }
app/backend/data/slicing.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio slicing for the Dataset Workbench.
2
+
3
+ Splits one audio file into N children. Three strategies:
4
+
5
+ hard β€” uniform cuts every `target_duration` seconds.
6
+ transient β€” uniform anchor points, each snapped to the nearest onset
7
+ (librosa.onset.onset_detect).
8
+ silence β€” uniform anchor points, each snapped to the nearest low-RMS
9
+ window (cleanest splice between phrases).
10
+
11
+ All three honor `overlap_sec`, applied as a head-overlap on every child
12
+ after the first: child i starts at (end of child i-1) - overlap_sec.
13
+
14
+ Writes WAV regardless of source format (lossless, no codec deps). Parent
15
+ prompt is inherited verbatim; the user edits children individually after.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import logging
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import List, Literal, Tuple
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ SliceStrategy = Literal["hard", "transient", "silence"]
28
+ VALID_STRATEGIES = ("hard", "transient", "silence")
29
+
30
+ # How far a snap is allowed to move from the uniform anchor. Beyond this we
31
+ # just take the anchor β€” better a tidy cut than a wildly off-target chunk.
32
+ SNAP_WINDOW_FRAC = 0.35
33
+
34
+
35
+ @dataclass
36
+ class SlicePlan:
37
+ """One child's location inside the parent. Times are in seconds."""
38
+ index: int # 1-based
39
+ start_sec: float
40
+ end_sec: float
41
+
42
+
43
+ def _uniform_anchors(duration_sec: float, target_sec: float, overlap_sec: float) -> List[Tuple[float, float]]:
44
+ """Return [(start, end), ...] for uniform cuts, before any snapping."""
45
+ if target_sec <= 0:
46
+ raise ValueError("target_duration must be positive")
47
+ if overlap_sec < 0 or overlap_sec >= target_sec:
48
+ raise ValueError("overlap_sec must be >= 0 and < target_duration")
49
+ step = target_sec - overlap_sec
50
+ anchors: List[Tuple[float, float]] = []
51
+ start = 0.0
52
+ while start < duration_sec - 0.05: # don't emit a sub-50ms tail
53
+ end = min(start + target_sec, duration_sec)
54
+ anchors.append((start, end))
55
+ if end >= duration_sec:
56
+ break
57
+ start += step
58
+ return anchors
59
+
60
+
61
+ def _snap_to_onsets(anchors: List[Tuple[float, float]], y, sr: int, target_sec: float) -> List[Tuple[float, float]]:
62
+ """Snap each cut boundary to the nearest detected onset within a window."""
63
+ import librosa
64
+ import numpy as np
65
+ onsets = librosa.onset.onset_detect(y=y, sr=sr, units="time", backtrack=True)
66
+ if len(onsets) == 0:
67
+ return anchors
68
+ snap_window = target_sec * SNAP_WINDOW_FRAC
69
+ out: List[Tuple[float, float]] = []
70
+ for i, (s, e) in enumerate(anchors):
71
+ if i > 0:
72
+ # Snap the start (= previous end) to nearest onset within window.
73
+ candidates = onsets[(onsets >= s - snap_window) & (onsets <= s + snap_window)]
74
+ if len(candidates):
75
+ s = float(min(candidates, key=lambda t: abs(t - s)))
76
+ out.append((s, e))
77
+ # Stitch ends to match next start so no gap/overlap drift creeps in.
78
+ for i in range(len(out) - 1):
79
+ s, _ = out[i]
80
+ next_s, _ = out[i + 1]
81
+ out[i] = (s, next_s + (target_sec * 0.0)) # next_s alone β€” overlap is in next_s already from caller
82
+ return out
83
+
84
+
85
+ def _snap_to_silence(anchors: List[Tuple[float, float]], y, sr: int, target_sec: float) -> List[Tuple[float, float]]:
86
+ """Snap each cut boundary to the lowest-RMS frame within a window."""
87
+ import librosa
88
+ import numpy as np
89
+ # Frame-level RMS at ~20ms hop.
90
+ hop = max(1, sr // 50)
91
+ rms = librosa.feature.rms(y=y, frame_length=hop * 2, hop_length=hop)[0]
92
+ if len(rms) == 0:
93
+ return anchors
94
+ frame_times = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop)
95
+ snap_window = target_sec * SNAP_WINDOW_FRAC
96
+ out: List[Tuple[float, float]] = []
97
+ for i, (s, e) in enumerate(anchors):
98
+ if i > 0:
99
+ mask = (frame_times >= s - snap_window) & (frame_times <= s + snap_window)
100
+ if mask.any():
101
+ local_idx = int(np.argmin(rms[mask]))
102
+ # Map masked-index back to absolute time.
103
+ masked_times = frame_times[mask]
104
+ s = float(masked_times[local_idx])
105
+ out.append((s, e))
106
+ return out
107
+
108
+
109
+ def plan_slices(
110
+ audio_path: Path,
111
+ target_sec: float,
112
+ overlap_sec: float,
113
+ strategy: SliceStrategy,
114
+ ) -> List[SlicePlan]:
115
+ """Compute the (start, end) for each child without writing anything yet."""
116
+ if strategy not in VALID_STRATEGIES:
117
+ raise ValueError(f"Unknown strategy: {strategy}")
118
+ import librosa
119
+ # Use mono for boundary detection only; final write uses the original.
120
+ y, sr = librosa.load(str(audio_path), sr=22050, mono=True)
121
+ duration = float(len(y) / sr) if len(y) else 0.0
122
+ if duration <= 0:
123
+ raise ValueError(f"{audio_path.name} has zero duration")
124
+ if duration < target_sec:
125
+ # Single child = the whole file. Skip the slice loop entirely.
126
+ return [SlicePlan(index=1, start_sec=0.0, end_sec=duration)]
127
+
128
+ anchors = _uniform_anchors(duration, target_sec, overlap_sec)
129
+ if strategy == "transient":
130
+ anchors = _snap_to_onsets(anchors, y, sr, target_sec)
131
+ elif strategy == "silence":
132
+ anchors = _snap_to_silence(anchors, y, sr, target_sec)
133
+
134
+ return [
135
+ SlicePlan(index=i + 1, start_sec=s, end_sec=e)
136
+ for i, (s, e) in enumerate(anchors)
137
+ ]
138
+
139
+
140
+ def write_slices(
141
+ audio_path: Path,
142
+ plans: List[SlicePlan],
143
+ out_dir: Path,
144
+ stem: str,
145
+ ) -> List[Path]:
146
+ """Write children as `<stem>__001.wav`, `<stem>__002.wav`, ... in `out_dir`.
147
+
148
+ Uses soundfile for lossless WAV write at the source's native sample rate.
149
+ Skips names that already exist on disk to avoid clobbering.
150
+ """
151
+ import soundfile as sf
152
+ import numpy as np
153
+
154
+ info = sf.info(str(audio_path))
155
+ sr = info.samplerate
156
+ total_frames = info.frames
157
+ written: List[Path] = []
158
+ width = max(3, len(str(len(plans))))
159
+
160
+ with sf.SoundFile(str(audio_path)) as src:
161
+ for plan in plans:
162
+ start_frame = max(0, int(plan.start_sec * sr))
163
+ end_frame = min(total_frames, int(plan.end_sec * sr))
164
+ if end_frame <= start_frame:
165
+ logger.warning("Skipping empty slice %s [%.2f-%.2f]", plan.index, plan.start_sec, plan.end_sec)
166
+ continue
167
+ src.seek(start_frame)
168
+ data = src.read(end_frame - start_frame, dtype="float32", always_2d=True)
169
+
170
+ child_name = f"{stem}__{plan.index:0{width}d}.wav"
171
+ child_path = out_dir / child_name
172
+ if child_path.exists():
173
+ # Don't silently overwrite; bump the suffix until free.
174
+ k = 2
175
+ while True:
176
+ candidate = out_dir / f"{stem}__{plan.index:0{width}d}_{k}.wav"
177
+ if not candidate.exists():
178
+ child_path = candidate
179
+ break
180
+ k += 1
181
+ sf.write(str(child_path), data, sr, subtype="PCM_16")
182
+ written.append(child_path)
183
+ return written
app/core/audio/midi_input.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Native MIDI input.
2
+
3
+ Reads hardware MIDI via python-rtmidi (CoreMIDI on macOS, WinMM on Windows,
4
+ ALSA on Linux) so MIDI works regardless of the web engine the OS gives us β€”
5
+ WKWebView has no Web MIDI, WebView2's is flaky. Same pattern as the native
6
+ Ableton Link binding in link_sync.py: wrap an optional native lib and no-op
7
+ gracefully if it isn't importable.
8
+
9
+ The backend owns the *transport* only: it enumerates input ports, opens one,
10
+ and broadcasts incoming messages to subscribers (drained by the SSE endpoint
11
+ in app.py). All mapping / learn / takeover logic stays in the frontend
12
+ MidiContext β€” it just consumes these events instead of Web MIDI.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import ctypes
17
+ import glob
18
+ import os
19
+ import queue
20
+ import sys
21
+ import threading
22
+ from typing import Any, Dict, List, Optional
23
+
24
+
25
+ def _preload_bundled_jack() -> None:
26
+ """Work around a broken RPATH in python-rtmidi's manylinux wheel.
27
+
28
+ The wheel bundles libjack as `python_rtmidi/libjack-<hash>.so.*`, but the
29
+ `_rtmidi` extension's RPATH points at a directory that doesn't exist
30
+ (`$ORIGIN/../python_rtmidi.` β€” note the stray trailing dot), so the loader
31
+ can't find it and `import rtmidi` dies with
32
+ `ImportError: libjack-<hash>.so...: cannot open shared object file`.
33
+
34
+ The bundled lib's soname matches the extension's DT_NEEDED exactly, so
35
+ dlopen'ing it with RTLD_GLOBAL first lets the loader satisfy the dependency
36
+ from the already-loaded object. Doing it here (rather than patching the
37
+ venv) survives a pip reinstall and needs no patchelf/root. Linux-only; a
38
+ no-op everywhere the glob finds nothing.
39
+ """
40
+ if not sys.platform.startswith("linux"):
41
+ return
42
+ for base in sys.path:
43
+ if not base or not os.path.isdir(base):
44
+ continue
45
+ for lib in glob.glob(os.path.join(base, "python_rtmidi*", "libjack-*.so*")):
46
+ try:
47
+ ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL)
48
+ except OSError:
49
+ pass
50
+
51
+
52
+ try:
53
+ import rtmidi # python-rtmidi
54
+ _RTMIDI_OK = True
55
+ except Exception: # pragma: no cover - import guard
56
+ # Most likely the bundled-libjack RPATH bug β€” preload it and retry once.
57
+ try:
58
+ _preload_bundled_jack()
59
+ import rtmidi
60
+ _RTMIDI_OK = True
61
+ except Exception:
62
+ rtmidi = None
63
+ _RTMIDI_OK = False
64
+
65
+ _lock = threading.Lock()
66
+ _midi_in: Any = None # the open rtmidi.MidiIn, or None
67
+ _current_port: Optional[str] = None # name of the open port, or None
68
+ _subscribers: List["queue.Queue"] = []
69
+
70
+
71
+ def is_available() -> bool:
72
+ """True if the native MIDI backend is importable."""
73
+ return _RTMIDI_OK
74
+
75
+
76
+ def list_inputs() -> List[Dict[str, Any]]:
77
+ """Enumerate input ports. `id` is the port name (stable across index
78
+ shuffles); `index` is its current rtmidi index."""
79
+ if not _RTMIDI_OK:
80
+ return []
81
+ mi = rtmidi.MidiIn()
82
+ try:
83
+ names = mi.get_ports()
84
+ finally:
85
+ mi.delete()
86
+ return [{"id": name, "name": name, "index": i} for i, name in enumerate(names)]
87
+
88
+
89
+ def current_port() -> Optional[str]:
90
+ with _lock:
91
+ return _current_port
92
+
93
+
94
+ def _on_message(event, _data=None) -> None:
95
+ """rtmidi callback (runs on its own thread). `event` is (message, delta).
96
+ Broadcast the raw status/data bytes so the frontend can reuse its existing
97
+ Web-MIDI-shaped dispatcher unchanged."""
98
+ message, _delta = event
99
+ payload = {"data": list(message)}
100
+ with _lock:
101
+ subs = list(_subscribers)
102
+ for q in subs:
103
+ try:
104
+ q.put_nowait(payload)
105
+ except queue.Full:
106
+ pass # slow consumer β€” drop rather than block the MIDI thread
107
+
108
+
109
+ def close_input() -> None:
110
+ global _midi_in, _current_port
111
+ with _lock:
112
+ mi = _midi_in
113
+ _midi_in = None
114
+ _current_port = None
115
+ if mi is not None:
116
+ try:
117
+ mi.cancel_callback()
118
+ except Exception:
119
+ pass
120
+ try:
121
+ mi.close_port()
122
+ except Exception:
123
+ pass
124
+ try:
125
+ mi.delete()
126
+ except Exception:
127
+ pass
128
+
129
+
130
+ def open_input(port_id: Optional[str]) -> bool:
131
+ """Open the input port whose name == port_id. A falsy port_id just closes
132
+ the current port. Returns True on success (or on a pure close)."""
133
+ if not _RTMIDI_OK:
134
+ return False
135
+ close_input()
136
+ if not port_id:
137
+ return True
138
+
139
+ mi = rtmidi.MidiIn()
140
+ idx = None
141
+ for i, name in enumerate(mi.get_ports()):
142
+ if name == port_id:
143
+ idx = i
144
+ break
145
+ if idx is None:
146
+ mi.delete()
147
+ return False
148
+
149
+ mi.open_port(idx)
150
+ # Drop sysex / timing-clock / active-sensing so the stream stays to the
151
+ # control messages the mapper cares about (CC + notes).
152
+ mi.ignore_types(sysex=True, timing=True, active_sense=True)
153
+ mi.set_callback(_on_message)
154
+
155
+ global _midi_in, _current_port
156
+ with _lock:
157
+ _midi_in = mi
158
+ _current_port = port_id
159
+ return True
160
+
161
+
162
+ def subscribe() -> "queue.Queue":
163
+ q: "queue.Queue" = queue.Queue(maxsize=512)
164
+ with _lock:
165
+ _subscribers.append(q)
166
+ return q
167
+
168
+
169
+ def unsubscribe(q: "queue.Queue") -> None:
170
+ with _lock:
171
+ if q in _subscribers:
172
+ _subscribers.remove(q)
app/core/config.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
4
  from typing import Dict, Any, Optional
5
  import json
6
 
 
7
  class ProjectConfig:
8
 
9
  def __init__(self, project_root: Optional[Path] = None) -> None:
@@ -18,11 +19,11 @@ class ProjectConfig:
18
  self.user_data_dir = Path.home() / "Library" / "Application Support" / "FragmentaDesktop"
19
  else:
20
  self.user_data_dir = Path.home() / ".local" / "share" / "FragmentaDesktop"
21
-
22
  self.user_data_dir.mkdir(parents=True, exist_ok=True)
23
  print(f"Running in frozen mode. Project root: {self.project_root}")
24
  print(f"User data directory: {self.user_data_dir}")
25
-
26
  else:
27
  self.frozen = False
28
  if project_root is None:
@@ -37,123 +38,52 @@ class ProjectConfig:
37
  break
38
  else:
39
  project_root = config_file_dir
40
-
41
  self.project_root: Path = Path(project_root).resolve()
42
  self.user_data_dir = self.project_root
43
 
44
  fine_tuned_override = os.environ.get("FRAGMENTA_FINE_TUNED_DIR")
45
  fine_tuned_dir = Path(fine_tuned_override) if fine_tuned_override else self.user_data_dir / "models" / "fine_tuned"
46
 
47
- data_override = os.environ.get("FRAGMENTA_DATA_DIR")
48
- data_dir = Path(data_override) if data_override else self.user_data_dir / "data"
 
 
 
49
 
50
  self.paths: Dict[str, Path] = {
51
  "models": self.user_data_dir / "models",
52
  "models_config": self.user_data_dir / "models" / "config",
53
  "models_pretrained": self.user_data_dir / "models" / "pretrained",
54
  "models_fine_tuned": fine_tuned_dir,
55
- "data": data_dir,
56
  "logs": self.user_data_dir / "logs",
57
  "output": self.user_data_dir / "output",
58
 
59
  "application": self.project_root,
60
  "backend": self.project_root / "app" / "backend",
61
  "frontend": self.project_root / "app" / "frontend",
62
- "stable_audio_tools": self.project_root / "vendor" / "stable-audio-tools",
63
- "loraw_vendor": self.project_root / "vendor" / "loraw_vendor",
64
  "venv": self.project_root / "venv",
65
  }
66
 
67
  self._ensure_directories()
68
- self.model_configs: Dict[str, Dict[str, str]
69
- ] = self._load_model_configs()
 
 
70
 
71
  def _ensure_directories(self) -> None:
72
 
73
  for path_name, path in self.paths.items():
74
- if path_name.endswith(('_fine_tuned', 'data')):
75
  path.mkdir(parents=True, exist_ok=True)
76
 
77
- def _load_model_configs(self) -> Dict[str, Dict[str, str]]:
78
-
79
- return {
80
- "stable-audio-open-1.0": {
81
- "config": str(self.paths["models_config"] / "model_config.json"),
82
- "ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-model.safetensors")
83
- },
84
- "stable-audio-open-small": {
85
- "config": str(self.paths["models_config"] / "model_config_small.json"),
86
- "ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-small-model.safetensors")
87
- },
88
- "custom": {
89
- "config": str(self.paths["models_config"] / "model_config_small.json"),
90
- "ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-small-model.safetensors")
91
- }
92
- }
93
-
94
  def get_path(self, path_name: str) -> Path:
95
  if path_name not in self.paths:
96
  raise ValueError(f"Unknown path name: {path_name}")
97
  return self.paths[path_name]
98
 
99
- def get_model_config(self, model_name: str) -> Dict[str, str]:
100
- if model_name not in self.model_configs:
101
- raise ValueError(f"Unknown model: {model_name}")
102
- return self.model_configs[model_name]
103
-
104
- def get_dataset_config_path(self) -> str:
105
- return str(self.paths["models_config"] / "dataset-config.json")
106
-
107
- def get_custom_metadata_path(self) -> str:
108
- return str(self.project_root / "vendor" / "stable-audio-tools" / "custom_metadata.py")
109
-
110
- def get_metadata_json_path(self) -> str:
111
- return str(self.paths["data"] / "metadata.json")
112
-
113
- def update_dataset_config(self) -> None:
114
- from app.backend.data.simple_audio_processor import SimpleAudioProcessor
115
-
116
- try:
117
- processor = SimpleAudioProcessor(
118
- model_config_path=self.paths["models_config"] / "model_config.json"
119
- )
120
-
121
- result = processor.create_dataset_config(
122
- input_dir=self.paths["data"],
123
- output_dir=self.paths["data"]
124
- )
125
-
126
- target_config = self.paths["models_config"] / "dataset-config.json"
127
- with open(target_config, 'w') as f:
128
- json.dump(result["dataset_config"], f, indent=4)
129
-
130
- print(f"Updated dataset config: {target_config}")
131
- print(f"Points to {result['file_count']} original audio files")
132
- print(f"Sample size: {result['sample_size']} samples ({result['sample_size']/result['sample_rate']:.1f}s)")
133
- print(f"Random cropping during training (correct!)")
134
-
135
- except Exception as e:
136
- print(f"Failed to update dataset config: {e}")
137
- print("Falling back to basic dataset config...")
138
-
139
- dataset_config: Dict[str, Any] = {
140
- "dataset_type": "audio_dir",
141
- "datasets": [
142
- {
143
- "id": "fine_tune_data",
144
- "path": str(self.paths["data"]),
145
- "custom_metadata_module": "custom_metadata"
146
- }
147
- ],
148
- "random_crop": True
149
- }
150
-
151
- config_path = self.paths["models_config"] / "dataset-config.json"
152
- with open(config_path, 'w') as f:
153
- json.dump(dataset_config, f, indent=4)
154
-
155
- print(f"Updated fallback dataset config: {config_path}")
156
-
157
  def to_dict(self) -> Dict[str, Any]:
158
  return {
159
  "project_root": str(self.project_root),
 
4
  from typing import Dict, Any, Optional
5
  import json
6
 
7
+
8
  class ProjectConfig:
9
 
10
  def __init__(self, project_root: Optional[Path] = None) -> None:
 
19
  self.user_data_dir = Path.home() / "Library" / "Application Support" / "FragmentaDesktop"
20
  else:
21
  self.user_data_dir = Path.home() / ".local" / "share" / "FragmentaDesktop"
22
+
23
  self.user_data_dir.mkdir(parents=True, exist_ok=True)
24
  print(f"Running in frozen mode. Project root: {self.project_root}")
25
  print(f"User data directory: {self.user_data_dir}")
26
+
27
  else:
28
  self.frozen = False
29
  if project_root is None:
 
38
  break
39
  else:
40
  project_root = config_file_dir
41
+
42
  self.project_root: Path = Path(project_root).resolve()
43
  self.user_data_dir = self.project_root
44
 
45
  fine_tuned_override = os.environ.get("FRAGMENTA_FINE_TUNED_DIR")
46
  fine_tuned_dir = Path(fine_tuned_override) if fine_tuned_override else self.user_data_dir / "models" / "fine_tuned"
47
 
48
+ # Scratch area for browser folder uploads (/api/upload-folder). The
49
+ # SA2-era "data" dataset directory is gone in 0.2.0 β€” datasets are now
50
+ # Dataset Workbench projects under projects/.
51
+ uploads_override = os.environ.get("FRAGMENTA_UPLOADS_DIR")
52
+ uploads_dir = Path(uploads_override) if uploads_override else self.user_data_dir / "uploads"
53
 
54
  self.paths: Dict[str, Path] = {
55
  "models": self.user_data_dir / "models",
56
  "models_config": self.user_data_dir / "models" / "config",
57
  "models_pretrained": self.user_data_dir / "models" / "pretrained",
58
  "models_fine_tuned": fine_tuned_dir,
59
+ "uploads": uploads_dir,
60
  "logs": self.user_data_dir / "logs",
61
  "output": self.user_data_dir / "output",
62
 
63
  "application": self.project_root,
64
  "backend": self.project_root / "app" / "backend",
65
  "frontend": self.project_root / "app" / "frontend",
66
+ "stable_audio_3": self.project_root / "vendor" / "stable-audio-3",
 
67
  "venv": self.project_root / "venv",
68
  }
69
 
70
  self._ensure_directories()
71
+ # The SA3 catalog lives in app/core/model_manager.py. This dict stays
72
+ # empty; it's retained only because to_dict()/print_paths() and the
73
+ # config validator still reference it.
74
+ self.model_configs: Dict[str, Dict[str, str]] = {}
75
 
76
  def _ensure_directories(self) -> None:
77
 
78
  for path_name, path in self.paths.items():
79
+ if path_name.endswith(('_fine_tuned', 'uploads')):
80
  path.mkdir(parents=True, exist_ok=True)
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def get_path(self, path_name: str) -> Path:
83
  if path_name not in self.paths:
84
  raise ValueError(f"Unknown path name: {path_name}")
85
  return self.paths[path_name]
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def to_dict(self) -> Dict[str, Any]:
88
  return {
89
  "project_root": str(self.project_root),
app/core/generation/audio_generator.py CHANGED
@@ -1,519 +1,536 @@
1
- import torch
2
- import soundfile as sf
3
- import numpy as np
4
- from pathlib import Path
5
- from typing import Dict, Any, Optional, List, Tuple
6
- import logging
 
 
 
 
 
 
 
7
  import re
8
  import sys
9
  import threading
10
  import time
11
  import warnings
12
- from datetime import datetime
13
-
14
 
15
- class GenerationStopped(Exception):
16
- """Raised by the per-step callback when a stop has been requested."""
17
- pass
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def _slugify_prompt(text: str, max_len: int = 40) -> str:
21
- s = re.sub(r'[^a-zA-Z0-9]+', '_', text.strip().lower())
22
- s = re.sub(r'_+', '_', s).strip('_')
23
- return s[:max_len] or 'untitled'
24
 
25
- sys.path.append(
26
- str(Path(__file__).parent.parent.parent.parent / "vendor" / "stable-audio-tools"))
27
- # LoRAW lives at <project>/vendor/loraw_vendor; expose its `loraw` package for inference.
28
- sys.path.append(
29
- str(Path(__file__).parent.parent.parent.parent / "vendor" / "loraw_vendor"))
30
 
31
 
32
- warnings.filterwarnings(
33
- "ignore",
34
- message=r"pkg_resources is deprecated as an API.*",
35
- category=UserWarning,
36
- )
37
 
38
- from stable_audio_tools.models.utils import load_ckpt_state_dict
39
- from stable_audio_tools.inference.generation import generate_diffusion_cond
40
- from stable_audio_tools.models import create_model_from_config
41
- from loraw.network import create_lora_from_config
42
 
43
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  class AudioGenerator:
47
- def __init__(self, config):
48
- self.config = config
49
- self.model = None
50
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
51
- self.current_model_name = None
52
- self.current_model_path = None
53
- self.current_model_key = None
54
- self.is_distilled_small = False
55
- self.is_fine_tuned = False
56
- # LoRA state. `lora` holds the LoRAWrapper instance when one is active;
57
- # `_active_lora_path` / `_active_lora_multiplier` are used (along with
58
- # the base-model identifier) in `current_model_key` so the cache
59
- # invalidates whenever the LoRA selection changes β€” forcing a fresh
60
- # base reload because LoRAW's `activate()` is not reversible in-place.
61
- self.lora = None
62
- self._active_lora_path = None
63
- self._active_lora_multiplier = 1.0
64
- self._stop_event = threading.Event()
65
- logger.info(f"Using device: {self.device}")
66
-
67
- def _apply_lora(self, lora_path: str, lora_config: Dict[str, Any], multiplier: float = 1.0):
68
- """Wrap the currently-loaded base model with a LoRA from LoRAW.
69
-
70
- Caller is responsible for ensuring the base model is fresh (no prior
71
- LoRA injected) β€” typically by routing through `generate_audio`'s cache
72
- invalidation, which reloads the base when the LoRA selection changes.
73
- """
74
- if self.model is None:
75
- raise RuntimeError("Base model must be loaded before applying a LoRA")
76
- # torch.compile wraps in OptimizedModule, which prefixes named_modules()
77
- # with `_orig_mod/`. LoRAW's saved state has no such prefix (training
78
- # didn't compile). Operate on the underlying module so scan_model keys
79
- # match the checkpoint exactly. The compiled wrapper still dispatches
80
- # forward through this same module, so the LoRA stays active.
81
- target = getattr(self.model, "_orig_mod", self.model)
82
- full_config = {
83
- "model_type": getattr(target, "model_type", "diffusion_cond"),
84
- "lora": lora_config,
85
- }
86
- self.lora = create_lora_from_config(full_config, target)
87
- state = torch.load(lora_path, map_location=self.device)
88
- self.lora.load_weights(state, multiplier=multiplier)
89
- self.lora.activate()
90
- self._active_lora_path = lora_path
91
- self._active_lora_multiplier = multiplier
92
- logger.info(f"LoRA applied: {Path(lora_path).name} (multiplier={multiplier})")
93
 
 
 
 
 
 
 
 
 
 
 
 
94
  def request_stop(self) -> bool:
95
- """Signal the in-flight diffusion loop (if any) to abort at the next step."""
96
- already_set = self._stop_event.is_set()
97
- self._stop_event.set()
98
- return not already_set
99
-
100
- def load_local_base_model(self, model_name: str = "stable-audio-open-small") -> bool:
101
- try:
102
- logger.info(f"Loading local base model: {model_name}")
103
-
104
- self.current_model_name = model_name
105
-
106
- from stable_audio_tools.models.factory import create_model_from_config
107
- from stable_audio_tools.models.utils import load_ckpt_state_dict
108
- if "small" in model_name:
109
- config_file = "model_config_small.json"
110
- else:
111
- config_file = "model_config.json"
112
- self.is_distilled_small = "small" in model_name.lower()
113
- self.is_fine_tuned = False
114
-
115
- config_path = Path(__file__).parent.parent.parent.parent / "models" / "config" / config_file
116
- logger.info(f"Using config file: {config_path}")
117
-
118
- with open(config_path, 'r') as f:
119
- import json
120
- model_config = json.load(f)
121
-
122
- self.model = create_model_from_config(model_config)
123
- if model_name == 'stable-audio-open-small':
124
- model_file_name = 'stable-audio-open-small-model.safetensors'
125
- elif model_name == 'stable-audio-open-1.0':
126
- model_file_name = 'stable-audio-open-model.safetensors'
127
- else:
128
- model_file_name = f"{model_name}-model.safetensors"
129
-
130
- model_file = Path(__file__).parent.parent.parent.parent / "models" / "pretrained" / model_file_name
131
- self.current_model_path = str(model_file)
132
- logger.info(f"Loading weights from: {model_file}")
133
-
134
- if not model_file.exists():
135
- raise FileNotFoundError(f"Local model file not found: {model_file}")
136
-
137
- state_dict = load_ckpt_state_dict(str(model_file))
138
- self.model.load_state_dict(state_dict, strict=False)
139
-
140
- self.model = self.model.to(self.device)
141
- self.model.eval()
142
- self.model.requires_grad_(False)
143
- if self.device.startswith("cuda"):
144
- self.model = torch.compile(self.model, mode="reduce-overhead")
145
-
146
- logger.info("Local base model loaded successfully")
147
- return True
148
-
149
- except Exception as e:
150
- logger.error(f"Failed to load local base model: {e}")
151
  return False
 
 
152
 
153
- def load_model(self, model_path: Optional[Path] = None) -> bool:
154
- try:
155
- print(f"Loading model from {model_path}")
156
-
157
- if model_path is None:
158
- return self.load_local_base_model("stable-audio-open-small")
159
- else:
160
- safetensors_files = list(model_path.glob("*.safetensors"))
161
- if safetensors_files:
162
- unwrapped_path = str(safetensors_files[0])
163
- print(f"Found safetensors file: {unwrapped_path}")
164
- return self.load_unwrapped_model(unwrapped_path)
165
- else:
166
- print(f"No safetensors files found in {model_path}, using local base model")
167
- return self.load_local_base_model("stable-audio-open-small")
168
-
169
- except Exception as e:
170
- print(f"Failed to load model: {e}")
171
- return False
172
-
173
- def load_unwrapped_model(self, unwrapped_model_path: str, config_file: str = None) -> bool:
174
- try:
175
- print(f"Loading unwrapped model from {unwrapped_model_path}")
176
-
177
- self.current_model_path = unwrapped_model_path
178
-
179
- from stable_audio_tools.models.factory import create_model_from_config
180
- from stable_audio_tools.models.utils import load_ckpt_state_dict
181
- if config_file is None:
182
- config_file = "model_config_small.json"
183
- self.is_distilled_small = "small" in config_file.lower()
184
-
185
-
186
- metadata_path = Path(unwrapped_model_path).parent.parent / "training_metadata.json"
187
- self.is_fine_tuned = metadata_path.exists()
188
- if self.is_fine_tuned:
189
- logger.info(
190
- f"Detected fine-tuned model via {metadata_path}; "
191
- f"using full diffusion sampler recipe instead of distilled 8-step pingpong"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  )
193
 
194
- config_path = Path(__file__).parent.parent.parent.parent / \
195
- "models" / "config" / config_file
196
- print(f"Using config file: {config_path}")
197
-
198
- with open(config_path, 'r') as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  import json
200
- model_config = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- self.model = create_model_from_config(model_config)
 
 
 
 
 
203
 
204
- state_dict = load_ckpt_state_dict(unwrapped_model_path)
205
- self.model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- self.model = self.model.to(self.device)
208
- self.model.eval()
209
- self.model.requires_grad_(False)
210
 
211
- if self.device.startswith("cuda"):
212
- self.model = torch.compile(self.model, mode="reduce-overhead")
 
 
 
 
213
 
214
- print(f"AUDIO GENERATOR: Unwrapped model loaded successfully")
215
- return True
216
 
217
- except Exception as e:
218
- print(f"Failed to load unwrapped model: {e}")
 
219
  return False
 
 
 
220
 
 
221
  def generate_audio(
222
  self,
223
  prompt: str,
224
- model_path: Optional[Path] = None,
225
- unwrapped_model_path: Optional[str] = None,
226
- config_file: Optional[str] = None,
227
  duration: float = 10.0,
228
- cfg_scale: float = 7.0,
229
- steps: int = 250,
230
  seed: int = -1,
231
- output_path: Optional[Path] = None,
232
- batch_index: int = 1,
233
- batch_total: int = 1,
234
- loop_mode: bool = False,
235
- lora_path: Optional[str] = None,
236
- lora_config: Optional[Dict[str, Any]] = None,
237
- lora_multiplier: float = 1.0,
 
 
 
 
 
 
 
 
 
 
 
238
  ) -> Path:
239
- print(f"\nAUDIO GENERATOR: generate_audio called")
240
- print(f" - Prompt: '{prompt}'")
241
- print(f" - Duration: {duration}s")
242
- if lora_path:
243
- print(f" - LoRA: {lora_path} (Γ—{lora_multiplier})")
244
-
245
- # The cache key includes LoRA selection so the base reloads whenever
246
- # the LoRA changes (LoRAW's activate() is not reversible in-place;
247
- # the only safe way to drop or swap a LoRA is to reload the base).
248
- lora_signature = (lora_path, lora_multiplier) if lora_path else (None, 1.0)
249
- if unwrapped_model_path:
250
- target_key = ('unwrapped', str(unwrapped_model_path), lora_signature)
251
- elif model_path:
252
- target_key = ('path', str(model_path), lora_signature)
253
- else:
254
- target_key = ('default', 'stable-audio-open-small', lora_signature)
255
-
256
- if self.model is not None and self.current_model_key == target_key:
257
- print(f"AUDIO GENERATOR: Reusing already-loaded model")
258
- else:
259
- print(f"AUDIO GENERATOR: Loading new model")
260
- # Reset any prior LoRA state β€” load_*_model rebuilds self.model fresh.
261
- self.lora = None
262
- self._active_lora_path = None
263
- self._active_lora_multiplier = 1.0
264
-
265
- if unwrapped_model_path:
266
- print(f"AUDIO GENERATOR: Loading unwrapped model from {unwrapped_model_path}")
267
- if not self.load_unwrapped_model(unwrapped_model_path, config_file):
268
- raise ValueError(f"Failed to load unwrapped model from {unwrapped_model_path}")
269
- elif model_path:
270
- model_path_str = str(model_path)
271
- print(f"AUDIO GENERATOR: Checking model path: {model_path_str}")
272
-
273
- if "stable-audio-open-small" in model_path_str:
274
- print(f"AUDIO GENERATOR: Loading local small base model")
275
- if not self.load_local_base_model("stable-audio-open-small"):
276
- raise ValueError("Failed to load local small base model")
277
- elif "stable-audio-open-model" in model_path_str:
278
- print(f"AUDIO GENERATOR: Loading local large base model")
279
- if not self.load_local_base_model("stable-audio-open-1.0"):
280
- raise ValueError("Failed to load local large base model")
281
- else:
282
- print(f"AUDIO GENERATOR: Loading fine-tuned model from {model_path}")
283
- if not self.load_model(model_path):
284
- raise ValueError(f"Failed to load model from {model_path}")
285
- else:
286
- print(f"AUDIO GENERATOR: Loading default local small base model")
287
- if not self.load_local_base_model("stable-audio-open-small"):
288
- raise ValueError("Failed to load default local base model")
289
-
290
- # Attach the LoRA (if requested) onto the freshly loaded base.
291
- if lora_path:
292
- if not lora_config:
293
- raise ValueError("lora_config required when lora_path is set")
294
- self._apply_lora(lora_path, lora_config, lora_multiplier)
295
-
296
- self.current_model_key = target_key
297
-
298
- print(f"AUDIO GENERATOR: Model loaded successfully")
299
-
300
- self._stop_event.clear()
301
-
302
- def _stop_callback(state):
303
- if self._stop_event.is_set():
304
- raise GenerationStopped("Stop requested mid-diffusion")
305
-
306
- try:
307
- # Three recipes, picked by what the loaded weights actually are:
308
- # 1. Original distilled small β€” rectified-flow + CFG distillation
309
- # baked in. Requires pingpong / 8 steps / CFG 1.0.
310
- # 2. Fine-tuned small β€” distillation destroyed by SFT but the
311
- # objective is still rectified-flow, so the sampler name must
312
- # come from the rectified-flow family (euler|rk4|dpmpp|pingpong),
313
- # NOT from the v-diffusion family. Use external CFG.
314
- # 3. Large model β€” standard v-diffusion, accepts dpmpp-3m-sde.
315
- use_distilled_recipe = self.is_distilled_small and not self.is_fine_tuned
316
- if use_distilled_recipe:
317
- effective_sampler = "pingpong"
318
- effective_steps = 8
319
- effective_cfg = 1.0
320
- sigma_kwargs = {}
321
- elif self.is_distilled_small:
322
- effective_sampler = "dpmpp"
323
- effective_steps = steps
324
- effective_cfg = cfg_scale
325
- sigma_kwargs = {"sigma_max": 1.0}
326
- else:
327
- effective_sampler = "dpmpp-3m-sde"
328
- effective_steps = steps
329
- effective_cfg = cfg_scale
330
- sigma_kwargs = {"sigma_min": 0.03, "sigma_max": 1000}
331
-
332
- print(f"Generating audio for prompt: '{prompt}'")
333
- recipe_note = ""
334
- if use_distilled_recipe:
335
- recipe_note = " (distilled small overrides applied)"
336
- elif self.is_fine_tuned and self.is_distilled_small:
337
- recipe_note = " (fine-tuned small: rectified-flow dpmpp + external CFG)"
338
- print(
339
- f"Duration: {duration}s, CFG scale: {effective_cfg}, "
340
- f"Steps: {effective_steps}, Sampler: {effective_sampler}"
341
- + recipe_note
342
- )
343
- requested_sample_size = int(duration * self.model.sample_rate)
344
- max_sample_size = None
345
- try:
346
- max_sample_size = self.model.sample_size
347
- except AttributeError:
348
- if hasattr(self.model, 'model') and hasattr(self.model.model, 'sample_size'):
349
- max_sample_size = self.model.model.sample_size
350
- else:
351
- config_path = Path(__file__).parent.parent.parent.parent / "models" / "config"
352
- if hasattr(self, 'current_model_name') and self.current_model_name:
353
- if 'small' in self.current_model_name:
354
- config_file = config_path / "model_config_small.json"
355
- else:
356
- config_file = config_path / "model_config.json"
357
- else:
358
- if hasattr(self, 'current_model_path') and self.current_model_path:
359
- model_file = Path(self.current_model_path)
360
- if model_file.exists():
361
- file_size_gb = model_file.stat().st_size / (1024**3)
362
- if file_size_gb < 2.0:
363
- config_file = config_path / "model_config_small.json"
364
- else:
365
- config_file = config_path / "model_config.json"
366
- else:
367
- config_file = config_path / "model_config_small.json"
368
- else:
369
- config_file = config_path / "model_config_small.json"
370
-
371
- if config_file.exists():
372
- with open(config_file, 'r') as f:
373
- import json
374
- config_data = json.load(f)
375
- max_sample_size = config_data.get('sample_size', 44100 * 10)
376
- else:
377
- max_sample_size = 44100 * 10
378
- if max_sample_size and requested_sample_size > max_sample_size:
379
- print(f"Requested duration {duration}s exceeds model maximum. Truncating.")
380
- requested_sample_size = max_sample_size
381
- duration = requested_sample_size / self.model.sample_rate
382
-
383
- if seed == -1:
384
- import numpy as np
385
- seed = np.random.randint(0, 2**32 - 1, dtype=np.int64)
386
-
387
- print(f"Using seed: {seed}")
388
-
389
- if loop_mode and max_sample_size:
390
- song_seconds = max(int(duration),
391
- int(max_sample_size / self.model.sample_rate))
392
- else:
393
- song_seconds = int(duration)
394
-
395
- conditioning = [{
396
- "prompt": prompt,
397
- "seconds_start": 0,
398
- "seconds_total": song_seconds,
399
- }]
400
-
401
- device = next(self.model.parameters()).device
402
- print(f"Using device: {device}")
403
-
404
- with warnings.catch_warnings():
405
- # Known torchsde float-boundary chatter from dpmpp-3m-sde.
406
- warnings.filterwarnings(
407
- "ignore",
408
- message=r"Should have tb<=t1 but got tb=.*",
409
- category=UserWarning,
410
- module=r"torchsde\._brownian\.brownian_interval",
411
  )
412
- warnings.filterwarnings(
413
- "ignore",
414
- message=r"Should have ta>=t0 but got ta=.*",
415
- category=UserWarning,
416
- module=r"torchsde\._brownian\.brownian_interval",
417
- )
418
-
419
- audio = generate_diffusion_cond(
420
- model=self.model,
421
- steps=effective_steps,
422
- cfg_scale=effective_cfg,
423
- conditioning=conditioning,
424
- batch_size=1,
425
- sample_size=requested_sample_size,
426
- seed=seed,
427
- device=str(device),
428
- sampler_type=effective_sampler,
429
- callback=_stop_callback,
430
- **sigma_kwargs,
431
  )
432
 
433
- print(f"Generation complete, audio shape: {audio.shape}")
434
-
435
- from einops import rearrange
436
- audio = rearrange(audio, "b d n -> d (b n)").to(torch.float32)
437
- audio = audio / audio.abs().max()
438
- audio_int16 = (audio.clamp(-1, 1) * 32767).to(torch.int16).cpu()
439
-
440
- if output_path is None:
441
- output_dir = Path(__file__).parent.parent.parent.parent / "output"
442
- output_dir.mkdir(exist_ok=True)
443
- ts = datetime.now().strftime('%Y%m%d_%H%M%S')
444
- slug = _slugify_prompt(prompt)
445
- suffix = f"_{batch_index}" if batch_total > 1 else ""
446
- output_path = output_dir / f"fragmenta_{ts}_{slug}{suffix}.wav"
447
-
448
- self.save_audio(audio_int16, output_path, self.model.sample_rate)
449
-
450
- print(f"AUDIO GENERATOR: Generation complete")
451
- print(f" - Output file: {output_path}")
452
- print(f" - Output file size: {output_path.stat().st_size} bytes")
453
-
454
- return output_path
455
-
456
  except GenerationStopped:
457
- print("AUDIO GENERATOR: Generation stopped by user request")
458
  raise
459
- except Exception as e:
460
- print(f"AUDIO GENERATOR: Error during generation: {str(e)}")
461
- import traceback
462
- traceback.print_exc()
463
  raise
464
- finally:
465
- self._stop_event.clear()
466
-
467
- def generate_batch(
468
- self,
469
- prompts: List[str],
470
- duration: float = 10.0,
471
- cfg_scale: float = 6.0,
472
- steps: int = 250,
473
- seed: int = -1,
474
- output_dir: Optional[Path] = None
475
- ) -> List[Path]:
476
- results = []
477
-
478
- for i, prompt in enumerate(prompts):
479
- print(f"Generating audio {i+1}/{len(prompts)}")
480
 
481
- current_seed = seed if seed != -1 else seed + i
482
- output_path = None
483
- if output_dir:
484
- output_dir.mkdir(exist_ok=True, parents=True)
485
- output_path = output_dir / f"generated_{i+1:03d}.wav"
486
-
487
- try:
488
- output_path = self.generate_audio(
489
- prompt=prompt,
490
- duration=duration,
491
- cfg_scale=cfg_scale,
492
- steps=steps,
493
- seed=current_seed,
494
- output_path=output_path
495
- )
496
- results.append(output_path)
497
-
498
- except Exception as e:
499
- print(f"Failed to generate audio for prompt {i+1}: {e}")
500
- results.append(None)
501
-
502
- return results
503
 
504
- def save_audio(self, audio: torch.Tensor, output_path: Path, sample_rate: int):
505
- output_path.parent.mkdir(exist_ok=True, parents=True)
506
- audio_np = audio.detach().cpu().transpose(0, 1).numpy()
507
- sf.write(str(output_path), audio_np, sample_rate, subtype="PCM_16")
508
 
509
- def get_model_info(self) -> Dict[str, Any]:
510
- if self.model is None:
511
- return {"status": "no_model_loaded"}
512
-
513
- return {
514
- "status": "loaded",
515
- "sample_rate": self.model.sample_rate,
516
- "device": str(self.device),
517
- "model_type": getattr(self.model, 'model_type', 'unknown'),
518
- "io_channels": getattr(self.model, 'io_channels', 'unknown')
519
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SA3 inference engine.
2
+
3
+ Thin wrapper around stable_audio_3.StableAudioModel.from_pretrained() that
4
+ caches the loaded model between requests (eviction on model_id change),
5
+ auto-detects the device, and writes 44.1 kHz stereo int16 WAV.
6
+
7
+ Cancellation is wired via `request_stop()` for API parity, but SA3's
8
+ generate() doesn't expose a per-step callback yet β€” the flag is checked
9
+ between calls, not inside them. A finer-grained cancel hook is a Phase
10
+ 3.1 follow-up.
11
+ """
12
+ import os
13
+ import platform
14
  import re
15
  import sys
16
  import threading
17
  import time
18
  import warnings
19
+ from pathlib import Path
20
+ from typing import Any, Callable, Dict, Optional, Tuple
21
 
22
+ import numpy as np
23
+ import soundfile as sf
24
+ import torch
25
 
26
+ from utils.logger import get_logger
27
+
28
+ logger = get_logger("AudioGenerator")
29
+
30
+
31
+ # Live progress from the SA3 sampler. SA3's `model.generate(**sampler_kwargs)`
32
+ # forwards `callback=fn` into the sampler, which fires it per ODE step with
33
+ # `{'i': step_index, ...}`. We mirror that into this dict so the frontend can
34
+ # poll real progress instead of a fake ticker. Reset on each new generation.
35
+ _generation_state: Dict[str, Any] = {
36
+ "is_generating": False,
37
+ # idle | loading | sampling | decoding | complete | failed
38
+ "phase": "idle",
39
+ "step": 0,
40
+ "total_steps": 0,
41
+ "progress": 0, # 0-100, derived
42
+ "batch_index": 0,
43
+ "batch_total": 0,
44
+ "started_at": None,
45
+ "ended_at": None,
46
+ "error": None,
47
+ }
48
+ _generation_state_lock = threading.Lock()
49
+
50
+
51
+ def get_generation_progress() -> Dict[str, Any]:
52
+ """Snapshot of the current generation's live progress. Cheap to call."""
53
+ with _generation_state_lock:
54
+ return dict(_generation_state)
55
+
56
+
57
+ def _set_progress(**kwargs: Any) -> None:
58
+ """Merge fields into _generation_state under the lock. Recomputes
59
+ `progress` automatically when step/total_steps land in the same update."""
60
+ with _generation_state_lock:
61
+ _generation_state.update(kwargs)
62
+ total = int(_generation_state.get("total_steps") or 0)
63
+ step = int(_generation_state.get("step") or 0)
64
+ _generation_state["progress"] = (
65
+ int(round(100 * step / total)) if total > 0 else 0
66
+ )
67
+
68
+
69
+ def _reset_progress() -> None:
70
+ with _generation_state_lock:
71
+ _generation_state.update({
72
+ "is_generating": False, "phase": "idle",
73
+ "step": 0, "total_steps": 0, "progress": 0,
74
+ "batch_index": 0, "batch_total": 0,
75
+ "started_at": None, "ended_at": None, "error": None,
76
+ })
77
+
78
+ # Vendored SA3 lives at <repo>/vendor/stable-audio-3 β€” put it on sys.path so
79
+ # `import stable_audio_3` resolves without a global pip install.
80
+ _SA3_VENDOR = Path(__file__).resolve().parents[3] / "vendor" / "stable-audio-3"
81
+ if str(_SA3_VENDOR) not in sys.path:
82
+ sys.path.insert(0, str(_SA3_VENDOR))
83
+
84
+
85
+ # model_id -> (sa3_name passed to StableAudioModel.from_pretrained,
86
+ # "user-visible or base" tag, max duration seconds).
87
+ # Kept in sync manually with _SA3_CATALOG in app/core/model_manager.py.
88
+ _MODEL_INFO: Dict[str, Tuple[str, str, int]] = {
89
+ "sa3-small-music": ("small-music", "post", 120),
90
+ "sa3-small-sfx": ("small-sfx", "post", 120),
91
+ "sa3-medium": ("medium", "post", 380),
92
+ "sa3-small-music-base": ("small-music-base", "base", 120),
93
+ "sa3-small-sfx-base": ("small-sfx-base", "base", 120),
94
+ "sa3-medium-base": ("medium-base", "base", 380),
95
+ }
96
 
 
 
 
 
97
 
98
+ class GenerationStopped(Exception):
99
+ """Raised when an in-flight generation is interrupted by a stop request."""
 
 
 
100
 
101
 
102
+ def _slugify(text: str, max_len: int = 40) -> str:
103
+ s = re.sub(r"[^a-zA-Z0-9_-]+", "_", text or "")
104
+ return s[:max_len].strip("_").lower() or "audio"
 
 
105
 
 
 
 
 
106
 
107
+ def _autodetect_device() -> str:
108
+ """cuda β†’ mps β†’ cpu, with FRAGMENTA_FORCE_DEVICE override."""
109
+ override = os.environ.get("FRAGMENTA_FORCE_DEVICE")
110
+ if override:
111
+ return override
112
+ if torch.cuda.is_available():
113
+ return "cuda"
114
+ if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
115
+ return "mps"
116
+ return "cpu"
117
 
118
 
119
  class AudioGenerator:
120
+ """One-model warm cache. Reload only when model_id changes."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ def __init__(self, config: Any) -> None:
123
+ self.config = config
124
+ self.model: Any = None
125
+ self._model_id: Optional[str] = None
126
+ self._device: Optional[str] = None
127
+ self._stop_requested: bool = False
128
+ # Tracks LoRAs currently injected into self.model. List of
129
+ # {"path": str, "strength": float}. Empty when no LoRAs are active.
130
+ self._loaded_loras: list = []
131
+
132
+ # --- cooperative cancel ---------------------------------------------------
133
  def request_stop(self) -> bool:
134
+ if self._stop_requested:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  return False
136
+ self._stop_requested = True
137
+ return True
138
 
139
+ # --- model load -----------------------------------------------------------
140
+ def _ensure_model(
141
+ self,
142
+ model_id: str,
143
+ device: Optional[str] = None,
144
+ half: bool = True,
145
+ ) -> None:
146
+ if model_id not in _MODEL_INFO:
147
+ raise ValueError(f"Unknown SA3 model_id: {model_id}")
148
+ sa3_name, _kind, _max_dur = _MODEL_INFO[model_id]
149
+
150
+ if model_id in ("sa3-medium", "sa3-medium-base"):
151
+ # Medium normally requires Flash Attention 2 for its long-form (up
152
+ # to 380s) sliding-window attention. FRAGMENTA_MEDIUM_NO_FLASH=1 is
153
+ # the Path-B validation switch: it lets medium load WITHOUT
154
+ # flash_attn and fall back to PyTorch-native attention
155
+ # (flex_attention -> chunked-halo SDPA -> masked SDPA; see
156
+ # transformer.apply_attn). Output is math-equivalent, but VRAM is
157
+ # higher and sampling slower at long durations. Off by default, so
158
+ # the shipped behaviour is unchanged until the fallback is validated.
159
+ allow_no_flash = os.environ.get("FRAGMENTA_MEDIUM_NO_FLASH") == "1"
160
+ try:
161
+ import flash_attn # noqa: F401
162
+ have_flash = True
163
+ except ImportError as err:
164
+ have_flash = False
165
+ _flash_err = err
166
+
167
+ if not have_flash and not allow_no_flash:
168
+ if platform.system() == "Windows":
169
+ raise RuntimeError(
170
+ "sa3-medium requires Flash Attention 2, which doesn't "
171
+ "have Windows wheels. Use sa3-small-music / sa3-small-sfx, "
172
+ "run Fragmenta via Docker on WSL2, or set "
173
+ "FRAGMENTA_MEDIUM_NO_FLASH=1 to run on the (slower, "
174
+ "higher-memory) PyTorch attention fallback."
175
+ ) from _flash_err
176
+ raise RuntimeError(
177
+ "sa3-medium needs Flash Attention 2 (flash_attn) but the "
178
+ f"current install is unusable: {_flash_err}.\n"
179
+ "Pick the wheel matching your torch+ABI+Python+CUDA from\n"
180
+ " https://github.com/Dao-AILab/flash-attention/releases\n"
181
+ "and install with `pip install --no-deps <wheel-url>`. "
182
+ "See the note next to flash-attn in requirements.txt for an example.\n"
183
+ "Or set FRAGMENTA_MEDIUM_NO_FLASH=1 to use the PyTorch "
184
+ "attention fallback."
185
+ ) from _flash_err
186
+
187
+ if not have_flash:
188
+ logger.warning(
189
+ "sa3-medium loading WITHOUT Flash Attention 2 "
190
+ "(FRAGMENTA_MEDIUM_NO_FLASH=1). Using the PyTorch-native "
191
+ "attention fallback β€” expect higher VRAM and slower sampling "
192
+ "at long durations. Validate memory headroom before "
193
+ "generating long-form (up to 380s) clips."
194
  )
195
 
196
+ device = device or _autodetect_device()
197
+ if (
198
+ self.model is not None
199
+ and self._model_id == model_id
200
+ and self._device == device
201
+ ):
202
+ return # warm cache hit
203
+
204
+ if self.model is not None:
205
+ del self.model
206
+ self.model = None
207
+ if torch.cuda.is_available():
208
+ torch.cuda.empty_cache()
209
+
210
+ # Two layouts to support during the unification transition:
211
+ # 1. Canonical (post-Phase 5c): HF cache layout rooted at
212
+ # <app>/models/pretrained/sa3/hub/. model_manager sets
213
+ # HF_HUB_CACHE to that path, so StableAudioModel.from_pretrained
214
+ # finds files there without going to ~/.cache/huggingface.
215
+ # 2. Legacy: <app>/models/pretrained/sa3/<model_id>/ flat layout
216
+ # from earlier downloads. We fall back to direct load so
217
+ # pre-existing users don't have to re-download.
218
+ #
219
+ # Defense-in-depth: re-force the HF cache vars here too. model_manager
220
+ # sets them at construction, but if generation is reached via an
221
+ # alternate code path or the env was clobbered later, we still
222
+ # guarantee resolution into <pretrained>/sa3/hub/.
223
+ hub_dir = self.config.get_path("models_pretrained") / "sa3" / "hub"
224
+ hf_env_keys = ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE",
225
+ "TRANSFORMERS_CACHE", "HF_HUB_OFFLINE")
226
+ prev_env = {k: os.environ.get(k) for k in hf_env_keys}
227
+ os.environ["HF_HUB_CACHE"] = str(hub_dir)
228
+ os.environ["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
229
+ os.environ["TRANSFORMERS_CACHE"] = str(hub_dir)
230
+ os.environ["HF_HUB_OFFLINE"] = "1"
231
+ # huggingface_hub captures HF_HUB_CACHE and HF_HUB_OFFLINE as
232
+ # module-level constants AT IMPORT TIME. The Flask backend imports
233
+ # huggingface_hub (transitively, via model_manager.py) before we ever
234
+ # set these env vars, so the constants point at ~/.cache/huggingface/
235
+ # and offline=False. Setting os.environ now has no effect on already-
236
+ # captured constants. We have to monkey-patch them directly.
237
+ # Same trick we used for the CLAP loader.
238
+ prev_hub_constants = {}
239
+ try:
240
+ import huggingface_hub.constants as _hf_const
241
+ prev_hub_constants = {
242
+ "HF_HUB_CACHE": _hf_const.HF_HUB_CACHE,
243
+ "HF_HUB_OFFLINE": _hf_const.HF_HUB_OFFLINE,
244
+ }
245
+ _hf_const.HF_HUB_CACHE = str(hub_dir)
246
+ _hf_const.HF_HUB_OFFLINE = True
247
+ except Exception:
248
+ _hf_const = None
249
+ try:
250
+ try:
251
+ from stable_audio_3 import StableAudioModel
252
+ with warnings.catch_warnings():
253
+ warnings.simplefilter("ignore")
254
+ self.model = StableAudioModel.from_pretrained(
255
+ sa3_name, device=device, model_half=half,
256
+ )
257
+ except (FileNotFoundError, OSError) as primary_err:
258
+ # HF cache miss β€” fall back to flat layout.
259
+ legacy_dir = self.config.get_path("models_pretrained") / "sa3" / model_id
260
+ config_path = legacy_dir / "model_config.json"
261
+ ckpt_path = legacy_dir / "model.safetensors"
262
+ if not (config_path.exists() and ckpt_path.exists()):
263
+ raise FileNotFoundError(
264
+ f"Checkpoint '{model_id}' not found in HF cache "
265
+ f"({os.environ.get('HF_HUB_CACHE')}) or legacy flat "
266
+ f"layout ({legacy_dir}). Download it from the "
267
+ f"Checkpoint Manager."
268
+ ) from primary_err
269
  import json
270
+ with open(config_path) as fh:
271
+ model_config = json.load(fh)
272
+ from stable_audio_3.loading_utils import load_diffusion_cond
273
+ with warnings.catch_warnings():
274
+ warnings.simplefilter("ignore")
275
+ inner = load_diffusion_cond(
276
+ model_config, str(ckpt_path),
277
+ device=device, model_half=half,
278
+ )
279
+ inner.use_lora = False
280
+ inner.lora_names = []
281
+ self.model = StableAudioModel(inner, model_config, device, half)
282
+ finally:
283
+ for k, v in prev_env.items():
284
+ if v is None:
285
+ os.environ.pop(k, None)
286
+ else:
287
+ os.environ[k] = v
288
+ # Restore the patched constants so we don't permanently alter
289
+ # global huggingface_hub state for anything else in-process.
290
+ if _hf_const is not None and prev_hub_constants:
291
+ _hf_const.HF_HUB_CACHE = prev_hub_constants["HF_HUB_CACHE"]
292
+ _hf_const.HF_HUB_OFFLINE = prev_hub_constants["HF_HUB_OFFLINE"]
293
+ self._model_id = model_id
294
+ self._device = device
295
+
296
+ # --- LoRA stack -----------------------------------------------------------
297
+ def _apply_loras(self, loras: list) -> None:
298
+ """Inject the given LoRA stack into self.model (idempotent).
299
+
300
+ loras: [{"path": str, "strength": float}, ...]
301
+
302
+ Strategy:
303
+ * Same paths in same order β†’ just update strengths in place.
304
+ * Different paths β†’ remove all, load fresh.
305
+ """
306
+ if self.model is None:
307
+ return
308
+
309
+ new_paths = [l["path"] for l in loras]
310
+ cur_paths = [l["path"] for l in self._loaded_loras]
311
 
312
+ if new_paths == cur_paths:
313
+ # Path-set unchanged; only strengths may have moved.
314
+ for i, l in enumerate(loras):
315
+ self.model.set_lora_strength(l["strength"], lora_index=i)
316
+ self._loaded_loras = list(loras)
317
+ return
318
 
319
+ # Path-set changed. Remove any currently loaded, then load the new set.
320
+ if cur_paths:
321
+ try:
322
+ from stable_audio_3.models.lora import remove_lora
323
+ # SA3 applies LoRA to the DiffusionCond's DiT (.model) and
324
+ # conditioner (.conditioner) β€” mirror StableAudioModel's own
325
+ # set_lora_strength which iterates both submodules.
326
+ # `self.model` is StableAudioModel; `self.model.model` is the
327
+ # inner DiffusionCond.
328
+ #
329
+ # remove_lora() strips *every* LoRA parametrization in one
330
+ # pass. We use it instead of remove_lora_by_index(..., 0) in a
331
+ # loop: removal does NOT renumber the remaining adapters, so
332
+ # repeatedly popping index 0 only ever clears the first LoRA
333
+ # and leaves indices 1..n-1 stranded β€” stale adapters then
334
+ # contaminate every later generation with a different stack.
335
+ inner = self.model.model
336
+ remove_lora(inner.model)
337
+ remove_lora(inner.conditioner)
338
+ except Exception as exc:
339
+ # If removal fails (e.g. an upstream API change), force a
340
+ # base-model reload so we don't carry stale adapters. KEEP
341
+ # _model_id intact β€” _ensure_model needs it to know what to
342
+ # reload. (Previous code zeroed it; the reload then raised
343
+ # "Unknown SA3 model_id: None".)
344
+ logger.warning(
345
+ "LoRA removal failed (%s); reloading base model %s",
346
+ exc, self._model_id,
347
+ )
348
+ self.model = None
349
 
350
+ if self.model is None and self._model_id is not None:
351
+ # Forced full reload (only if remove failed above).
352
+ self._ensure_model(self._model_id, device=self._device, half=True)
353
 
354
+ if loras:
355
+ with warnings.catch_warnings():
356
+ warnings.simplefilter("ignore")
357
+ self.model.load_lora(new_paths)
358
+ for i, l in enumerate(loras):
359
+ self.model.set_lora_strength(l["strength"], lora_index=i)
360
 
361
+ self._loaded_loras = list(loras)
 
362
 
363
+ def set_lora_strength(self, index: int, strength: float) -> bool:
364
+ """Live-update one slot's strength. Returns False if index invalid."""
365
+ if not self.model or index < 0 or index >= len(self._loaded_loras):
366
  return False
367
+ self.model.set_lora_strength(float(strength), lora_index=index)
368
+ self._loaded_loras[index]["strength"] = float(strength)
369
+ return True
370
 
371
+ # --- public entry ---------------------------------------------------------
372
  def generate_audio(
373
  self,
374
  prompt: str,
375
+ *,
376
+ model_id: str,
 
377
  duration: float = 10.0,
378
+ steps: Optional[int] = None,
379
+ cfg_scale: Optional[float] = None,
380
  seed: int = -1,
381
+ negative_prompt: Optional[str] = None,
382
+ batch_size: int = 1,
383
+ device: Optional[str] = None,
384
+ half: bool = True,
385
+ chunked_decode: Optional[bool] = None,
386
+ loop_mode: bool = False, # bars-mode passthrough
387
+ loras: Optional[list] = None, # [{path, strength}, ...]
388
+ # Phase 7: audio-to-audio + inpainting -----------------------------
389
+ init_audio_path: Optional[str] = None,
390
+ init_noise_level: float = 1.0,
391
+ inpaint_audio_path: Optional[str] = None,
392
+ inpaint_starts: Optional[list] = None, # list[float], seconds
393
+ inpaint_ends: Optional[list] = None,
394
+ # Phase 7: seamless looping ----------------------------------------
395
+ loop_stitch: Optional[str] = None, # "inpaint" | "crossfade" | None
396
+ loop_bars: Optional[int] = None,
397
+ loop_bpm: Optional[float] = None,
398
+ **_ignored_legacy_kwargs: Any,
399
  ) -> Path:
400
+ self._stop_requested = False
401
+ if self._stop_requested: # honour pre-call stop
402
+ raise GenerationStopped()
403
+
404
+ # `loop_stitch` / `loop_bars` / `loop_bpm` are accepted for API
405
+ # compatibility but ignored β€” the seamless-loop pipeline was
406
+ # removed because user A/B testing showed it degraded audio
407
+ # quality on every prompt class. We deliver raw model output.
408
+
409
+ _set_progress(
410
+ is_generating=True, phase="loading",
411
+ step=0, total_steps=0, error=None,
412
+ started_at=time.time(), ended_at=None,
413
+ )
414
+
415
+ self._ensure_model(model_id, device=device, half=half)
416
+ self._apply_loras(loras or [])
417
+
418
+ init_audio = self._load_audio(init_audio_path) if init_audio_path else None
419
+ inpaint_audio = self._load_audio(inpaint_audio_path) if inpaint_audio_path else None
420
+
421
+ _, kind, max_dur = _MODEL_INFO[model_id]
422
+ is_base = (kind == "base")
423
+
424
+ # Defaults differ by model kind. Post-trained models distilled CFG
425
+ # away; we force cfg=1.0 there even if the caller overrides.
426
+ effective_steps = int(steps) if steps else (50 if is_base else 8)
427
+ effective_cfg = float(cfg_scale) if (cfg_scale is not None and is_base) else (
428
+ 7.0 if is_base else 1.0
429
+ )
430
+
431
+ duration = float(min(max(1.0, float(duration)), float(max_dur)))
432
+
433
+ target_samples = int(round(duration * 44100))
434
+ gen_duration = duration
435
+ total_steps_logical = effective_steps
436
+
437
+ if self._stop_requested: # one more check before the heavy call
438
+ raise GenerationStopped()
439
+
440
+ # Sampler callback β€” fires per ODE step. Also gives us a cheap
441
+ # cancellation hook: raising mid-callback aborts the sampler.
442
+ def _cb(info: Dict[str, Any]) -> None:
443
+ if self._stop_requested:
444
+ raise GenerationStopped()
445
+ i = info.get("i")
446
+ if isinstance(i, int):
447
+ _set_progress(step=min(i + 1, total_steps_logical))
448
+
449
+ _set_progress(phase="sampling", total_steps=int(total_steps_logical), step=0)
450
+
451
+ gen_kwargs = dict(
452
+ prompt=prompt,
453
+ negative_prompt=negative_prompt or None,
454
+ duration=gen_duration,
455
+ steps=effective_steps,
456
+ cfg_scale=effective_cfg,
457
+ seed=int(seed),
458
+ batch_size=int(batch_size),
459
+ chunked_decode=chunked_decode,
460
+ callback=_cb,
461
+ )
462
+ if init_audio is not None:
463
+ gen_kwargs["init_audio"] = init_audio
464
+ gen_kwargs["init_noise_level"] = float(init_noise_level)
465
+ if inpaint_audio is not None:
466
+ gen_kwargs["inpaint_audio"] = inpaint_audio
467
+ if inpaint_starts is not None and len(inpaint_starts) > 0:
468
+ # SA3 accepts a single float or a list for multi-region.
469
+ gen_kwargs["inpaint_mask_start_seconds"] = (
470
+ list(inpaint_starts) if len(inpaint_starts) > 1 else float(inpaint_starts[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  )
472
+ if inpaint_ends is not None and len(inpaint_ends) > 0:
473
+ gen_kwargs["inpaint_mask_end_seconds"] = (
474
+ list(inpaint_ends) if len(inpaint_ends) > 1 else float(inpaint_ends[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  )
476
 
477
+ try:
478
+ audio = self.model.generate(**gen_kwargs)
479
+ # audio: torch.Tensor[B, channels=2, samples] in [-1, 1] @ 44.1 kHz
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  except GenerationStopped:
481
+ _set_progress(phase="idle", is_generating=False, ended_at=time.time())
482
  raise
483
+ except Exception as exc:
484
+ _set_progress(phase="failed", is_generating=False,
485
+ error=str(exc), ended_at=time.time())
 
486
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
488
+ # Seamless-loop processing (quantize, inpaint, crossfade) was
489
+ # removed: the user A/B-compared raw SA3 output against the full
490
+ # pipeline and confirmed the post-processing made every prompt
491
+ # worse β€” silence-at-start on percussion, smeared transients,
492
+ # off-grid anchoring. We now deliver the raw model output. The
493
+ # `loop_stitch` / `loop_bars` / `loop_bpm` parameters are still
494
+ # accepted from the frontend for API compatibility but are
495
+ # ignored. Performance-Bars looping will have an audible click
496
+ # at the wrap point and multi-channel stacks will not be
497
+ # sample-aligned β€” both acceptable trade-offs vs. the artifacts
498
+ # the quantizer was introducing.
499
+ _set_progress(phase="decoding", step=total_steps_logical)
500
+ try:
501
+ return self._finalize(audio, prompt=prompt, model_id=model_id)
502
+ finally:
503
+ _set_progress(phase="complete", is_generating=False,
504
+ step=total_steps_logical, ended_at=time.time())
 
 
 
 
 
505
 
506
+ # --- audio loader (a2a + inpaint inputs) ----------------------------------
507
+ @staticmethod
508
+ def _load_audio(path: str):
509
+ """Load a wav/mp3/flac into the (sample_rate, tensor) tuple SA3 expects.
510
 
511
+ Returns a stereo float32 tensor of shape (channels, samples). Mono
512
+ inputs are duplicated to stereo (SA3 expects 2 channels); β‰₯3-channel
513
+ inputs are truncated to the first 2.
514
+ """
515
+ import torchaudio
516
+ wav, sr = torchaudio.load(str(path)) # (channels, samples), float32
517
+ if wav.shape[0] == 1:
518
+ wav = wav.repeat(2, 1)
519
+ elif wav.shape[0] > 2:
520
+ wav = wav[:2]
521
+ return int(sr), wav
522
+
523
+ # --- output --------------------------------------------------------------
524
+ def _finalize(self, audio: torch.Tensor, *, prompt: str, model_id: str) -> Path:
525
+ audio = audio.detach().clamp_(-1.0, 1.0).cpu()
526
+ if audio.ndim != 3:
527
+ raise RuntimeError(f"Unexpected SA3 output shape {tuple(audio.shape)}")
528
+ first = audio[0] # [C, samples]
529
+ pcm = (first.numpy() * 32767.0).astype(np.int16).T # β†’ [samples, C]
530
+
531
+ out_dir = self.config.get_path("output")
532
+ out_dir.mkdir(parents=True, exist_ok=True)
533
+ ts = time.strftime("%Y%m%d_%H%M%S")
534
+ out_path = out_dir / f"{ts}_{model_id}_{_slugify(prompt)}.wav"
535
+ sf.write(str(out_path), pcm, 44100, subtype="PCM_16")
536
+ return out_path
app/core/generation/audio_post_process.py CHANGED
@@ -1,9 +1,40 @@
1
  """Beat-align and tempo-conform a generated WAV to a target BPM and bar count.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  """
3
 
4
  from __future__ import annotations
5
 
6
  import logging
 
 
7
  from pathlib import Path
8
  from typing import Optional, Tuple
9
 
@@ -14,35 +45,429 @@ import soundfile as sf
14
  logger = logging.getLogger(__name__)
15
 
16
 
17
- # Safe range for phase-vocoder time-stretching. Wider than the previous
18
- # [0.7, 1.4] so we actually warp in more cases β€” librosa's vocoder produces
19
- # acceptable audio across this range for music, and the alternative
20
- # (no warp at all) drifts off the grid completely on loop.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  _STRETCH_SAFE_MIN = 0.6
22
  _STRETCH_SAFE_MAX = 1.7
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def align_to_grid(
26
  input_path: Path,
27
  target_bpm: float,
28
  target_bars: int,
29
  beats_per_bar: int = 4,
30
  ) -> Path:
 
 
 
 
 
 
31
  audio, sr = sf.read(str(input_path), always_2d=True)
32
  audio = audio.astype(np.float32, copy=False)
33
- target_samples = int(round(target_bars * beats_per_bar * 60.0 / target_bpm * sr))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
36
 
37
- # Pass target_bpm as a prior to librosa β€” biases the beat tracker away
38
- # from half-time / double-time interpretations of the same grid.
39
- detected_bpm, first_beat = _detect_grid_anchor(mono, sr, start_bpm=target_bpm)
40
 
 
41
  head_offset = 0
42
- if first_beat is not None and 0 < first_beat < sr * 1.5:
43
- head_offset = first_beat
44
- logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms to first beat")
45
- elif first_beat is None:
 
 
46
  head_offset = _detect_first_onset_sample(mono, sr)
47
  if head_offset > 0:
48
  logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms (onset fallback)")
@@ -50,12 +475,26 @@ def align_to_grid(
50
  if head_offset > 0:
51
  audio = audio[head_offset:]
52
  mono = mono[head_offset:]
 
 
 
 
 
53
 
 
54
  if detected_bpm is not None:
55
- rate, effective_bpm = _best_stretch_rate(detected_bpm, target_bpm)
56
- if rate is not None:
57
- if abs(rate - 1.0) > 1e-3:
58
- audio = _time_stretch_multichannel(audio, rate)
 
 
 
 
 
 
 
 
59
  interp_note = (
60
  f" (interpreted as {effective_bpm:.2f} BPM, "
61
  f"octave={effective_bpm / detected_bpm:.2f}Γ—)"
@@ -66,33 +505,267 @@ def align_to_grid(
66
  f"align_to_grid: detected {detected_bpm:.2f} BPM{interp_note}, "
67
  f"stretched by {rate:.4f} to match target {target_bpm:.2f} BPM"
68
  )
 
 
 
 
 
 
69
  else:
70
  logger.info(
71
  f"align_to_grid: detected {detected_bpm:.2f} BPM has no safe "
72
- f"interpretation vs target {target_bpm:.2f}; skipping warp"
 
 
73
  )
74
  else:
75
  logger.info("align_to_grid: no usable tempo detected; skipping warp")
76
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  if audio.shape[0] > target_samples:
78
- audio = audio[:target_samples]
79
- # 8ms tail fade prevents the click at the loop boundary when the
80
- # truncation point lands mid-waveform.
81
- fade_samples = min(int(0.008 * sr), audio.shape[0])
82
- if fade_samples > 1:
83
- fade = np.linspace(1.0, 0.0, fade_samples, dtype=audio.dtype)
84
- audio[-fade_samples:] *= fade[:, np.newaxis] if audio.ndim > 1 else fade
85
- elif audio.shape[0] < target_samples:
86
- pad = np.zeros((target_samples - audio.shape[0], audio.shape[1]), dtype=audio.dtype)
87
- audio = np.concatenate([audio, pad], axis=0)
 
 
88
 
89
  sf.write(str(input_path), audio, sr, subtype="PCM_16")
90
  return input_path
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def _best_stretch_rate(
94
  detected_bpm: float,
95
  target_bpm: float,
 
 
 
96
  ) -> Tuple[Optional[float], float]:
97
  """Pick the time-stretch rate that maps detected β†’ target, considering
98
  half-time and double-time interpretations of the detected tempo. Returns
@@ -101,27 +774,20 @@ def _best_stretch_rate(
101
  nothing safe is available.
102
 
103
  Order of preference:
104
- 1. Detected as-is, if it lands inside the safe stretch range.
105
  2. Octave-corrected (detected Γ— 0.5 or Γ— 2.0), only when the as-is
106
  interpretation is out of range. This is the librosa half-/double-
107
  time error recovery path.
108
-
109
- This biases the algorithm toward honesty: only re-interpret the
110
- detector's reading when it can't otherwise produce a usable stretch.
111
  """
112
- # First, try the detector's reading at face value.
113
  rate_asis = target_bpm / detected_bpm
114
- if _STRETCH_SAFE_MIN <= rate_asis <= _STRETCH_SAFE_MAX:
115
  return rate_asis, detected_bpm
116
 
117
- # As-is is out of safe range β€” almost certainly a librosa octave error.
118
- # Try the half-time and double-time reinterpretations and pick whichever
119
- # is closest to a no-op stretch.
120
  candidates = []
121
  for octave_factor in (0.5, 2.0):
122
  interpreted = detected_bpm * octave_factor
123
  rate = target_bpm / interpreted
124
- if _STRETCH_SAFE_MIN <= rate <= _STRETCH_SAFE_MAX:
125
  candidates.append((abs(rate - 1.0), rate, interpreted))
126
  if not candidates:
127
  return None, detected_bpm
@@ -130,6 +796,7 @@ def _best_stretch_rate(
130
  return best_rate, best_interp
131
 
132
 
 
133
  def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
134
  """Return the sample index of the first detected onset, or 0 if none found."""
135
  try:
@@ -147,15 +814,16 @@ def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
147
  return first
148
 
149
 
150
- def _detect_grid_anchor(
 
151
  mono: np.ndarray,
152
  sr: int,
153
  start_bpm: Optional[float] = None,
154
- ) -> Tuple[Optional[float], Optional[int]]:
155
- """Run librosa beat tracking with the target tempo as a prior. Passing
156
- start_bpm reduces (but doesn't eliminate) half-time / double-time errors.
157
- The octave-correction in _best_stretch_rate handles whatever librosa
158
- still gets wrong."""
159
  try:
160
  kwargs = {"y": mono, "sr": sr, "units": "samples"}
161
  if start_bpm is not None and start_bpm > 0:
@@ -169,9 +837,10 @@ def _detect_grid_anchor(
169
  bpm = float(np.atleast_1d(tempo).flatten()[0])
170
  if not (40.0 <= bpm <= 240.0):
171
  return None, None
172
- return bpm, int(beats[0])
173
 
174
 
 
175
  def _time_stretch_multichannel(audio: np.ndarray, rate: float) -> np.ndarray:
176
  """Phase-vocoder time stretch, applied per channel and re-stacked."""
177
  stretched = librosa.effects.time_stretch(audio.T, rate=rate)
 
1
  """Beat-align and tempo-conform a generated WAV to a target BPM and bar count.
2
+
3
+ DEPRECATED β€” this entire module is superseded by ``app/core/loop_quantizer/``
4
+ (see ``task_1.md`` and ``AUDIT.md`` Β§9 "Scheduled for removal"). The legacy
5
+ ``align_to_grid`` / ``align_for_loop`` path and the gated ``_stage_a_v2`` path
6
+ both live here until the new module passes acceptance; once it does, every
7
+ public symbol below is removed and the file deletes itself. Do NOT add new
8
+ callers, do NOT extend the v1 helpers, and prefer adding work directly under
9
+ ``app/core/loop_quantizer/`` for any new behaviour.
10
+
11
+ SA3 generates at the exact requested duration via variable-length flow
12
+ matching, so the post-processor's role is **drift correction**, not length
13
+ control: it only nudges the audio when librosa detects that the realised
14
+ tempo has drifted from the target. The tempo-conform gate is intentionally
15
+ tight β€” `|rate - 1| > 5%` AND `rate in [0.85, 1.15]` β€” so we never warp
16
+ audibly when SA3 was already close.
17
+
18
+ Pipeline (in order):
19
+ 1. Detect tempo + beat grid via librosa (with target BPM as prior).
20
+ 2. Head-trim to the first detected beat (or first onset as fallback),
21
+ followed by a 3 ms equal-power fade-in to mask the trim seam.
22
+ 3. Tempo-conform via phase-vocoder time-stretch, ONLY when the detected
23
+ tempo drifts >5% from target AND the resulting stretch lies inside
24
+ the safe range [0.85, 1.15]. Outside this window we leave the audio
25
+ alone and let the user re-roll.
26
+ 4. End-anchored truncation: snap the cut to the nearest detected beat
27
+ within Β±Β½ beat of the mathematical target sample count, so loops
28
+ don't end mid-note. Followed by an 8 ms equal-power fade-out so the
29
+ loop seam doesn't click.
30
+ 5. Zero-pad if the audio came out shorter than the target.
31
  """
32
 
33
  from __future__ import annotations
34
 
35
  import logging
36
+ import os
37
+ import warnings
38
  from pathlib import Path
39
  from typing import Optional, Tuple
40
 
 
45
  logger = logging.getLogger(__name__)
46
 
47
 
48
+ # DEPRECATED: flag goes away with the v1/v2 split (AUDIT.md Β§9d).
49
+ def beatsync_v2_enabled() -> bool:
50
+ """Feature gate for the hardened Stage A pipeline (sample-exact length,
51
+ first-transient-to-zero alignment, transient-preserving stretch).
52
+
53
+ Off by default: with the flag unset, every Stage A function takes its
54
+ legacy code path, so Bars-mode output is byte-identical to pre-flag
55
+ builds and Seconds mode (which never enters Stage A at all) is unaffected.
56
+ Enable with ``FRAGMENTA_BEATSYNC_V2=1``.
57
+ """
58
+ return os.environ.get("FRAGMENTA_BEATSYNC_V2", "0").strip().lower() in (
59
+ "1", "true", "yes", "on",
60
+ )
61
+
62
+
63
+ # DEPRECATED: flag goes away with the v1/v2 split (AUDIT.md Β§9d).
64
+ def _warp_enabled() -> bool:
65
+ """Per-beat (Ableton 'Beats'-style) warp gate β€” OFF by default.
66
+
67
+ The warp is only as reliable as librosa's per-beat detection; on real audio
68
+ a single mis-detected beat scrambles the groove. Anchor + exact-crop already
69
+ lands real loops at ~3 ms, so the warp is opt-in for experimentation only.
70
+ Enable with ``FRAGMENTA_BEATSYNC_WARP=1``."""
71
+ return os.environ.get("FRAGMENTA_BEATSYNC_WARP", "0").strip().lower() in (
72
+ "1", "true", "yes", "on",
73
+ )
74
+
75
+
76
+ # Liberal module-default range for `_best_stretch_rate`. Kept wide so any
77
+ # future force-warp caller has room; the bars-mode drift-correction path
78
+ # (`align_to_grid`) overrides with tighter bounds below.
79
  _STRETCH_SAFE_MIN = 0.6
80
  _STRETCH_SAFE_MAX = 1.7
81
 
82
+ # Bars-mode drift correction. SA3 hits the requested duration exactly via
83
+ # variable-length generation, so the post-processor only kicks in when the
84
+ # detected tempo of the generated audio drifts from the requested target.
85
+ # Tight gates avoid audible vocoder artifacts when SA3 was already close.
86
+ _BARS_MODE_STRETCH_MIN = 0.85
87
+ _BARS_MODE_STRETCH_MAX = 1.15
88
+ _BARS_MODE_DEADBAND = 0.05
89
+
90
+ # Loop-mode (Phase 7) is stricter β€” a 5% tempo slack compounds visibly when
91
+ # multiple loop channels run side-by-side, even though loop iteration
92
+ # lengths are sample-exact. 0.5% is below librosa's noise floor for beat
93
+ # detection on rhythmic content, so we won't be acting on noise, but we
94
+ # WILL correct anything detectable that the looser bars-mode would skip.
95
+ _LOOP_MODE_DEADBAND = 0.005
96
+
97
+ # Fade durations applied at trim points. Kept very short β€” the fade is
98
+ # click-prevention, not a perceptible ramp. Performance Mode loops these
99
+ # clips, and longer fades audibly "duck" the loop seam.
100
+ _HEAD_FADE_SEC = 0.003 # mask click at the trimmed head
101
+ _TAIL_FADE_SEC = 0.003 # mask click at a mid-note truncation; skipped on beats
102
+
103
+ # Trailing-silence detection. SA3 occasionally pads a generation with low-
104
+ # level tail; the post-processor used to keep that and fade over it, which
105
+ # produced perceptible "silence + duck" at the loop point.
106
+ _SILENCE_THRESHOLD_DB = -50.0 # anything below is silence
107
+ _SILENCE_WINDOW_SEC = 0.05 # RMS window granularity
108
+ _SILENCE_TAIL_KEEP_SEC = 0.010 # leave a tiny natural decay
109
+
110
+ # v2 first-transient search: a downbeat lands within the first bar or two of
111
+ # generated content, so we never hunt past this window for the musical "1".
112
+ _V2_TRANSIENT_SEARCH_SEC = 1.5
113
+ _V2_STRONG_RATIO = 0.30 # candidate must reach 30% of peak
114
+ _V2_RISE_RATIO = 0.15 # rising-edge threshold for refinement
115
+ _V2_REFINE_WIN_SEC = 0.03 # +/- window for sample-accurate refine
116
+
117
+ # Grid confidence. librosa's beat tracker emits a tempo for ANY input β€” on
118
+ # ambient/textural content it is essentially noise (measured: 49-161 BPM on a
119
+ # 120-BPM target, 130+ ms intra-beat drift). Warping toward a wrong detected
120
+ # tempo is worse than not warping, so we only tempo-conform when the detected
121
+ # grid is trustworthy: beats evenly spaced (low interval CV) AND a clear pulse
122
+ # in the onset envelope. Below the threshold we trust the *requested* grid and
123
+ # skip the stretch (still doing the safe, tempo-independent transient@0 + crop).
124
+ # Calibrated on real fixtures: clean drum/bass loops score 0.76-0.88, pure
125
+ # pads 0.00 (no trackable beat), and ambiguous textures 0.44-0.57 β€” often with
126
+ # a wrong detected tempo. 0.65 sits in that gap. (The safe-range gate in
127
+ # _best_stretch_rate independently rejects octave-wrong tempos like 49/161 BPM.)
128
+ _GRID_CONFIDENCE_MIN = 0.65
129
+ _CV_MAX = 0.20 # interval CV at which regularity -> 0
130
+
131
+ # Beat-synchronous warp (Ableton "Beats"-style). Measured: real drum loops are
132
+ # already coherent to ~3-6 ms, where anchor+exact-crop alone lands single-digit
133
+ # ms β€” so a global/elastic warp there only adds phase-vocoder jitter for no gain.
134
+ # We therefore warp ONLY when a confident grid still drifts past this threshold,
135
+ # and need enough beats to define segments.
136
+ _WARP_DRIFT_MIN_MS = 15.0
137
+ _WARP_MIN_BEATS = 6
138
+
139
+
140
+ # === Stage A v2 (FRAGMENTA_BEATSYNC_V2) ====================================
141
+ # DEPRECATED: every symbol in this section is scheduled for relocation into
142
+ # `app/core/loop_quantizer` (see AUDIT.md Β§9c). Port the logic, then delete
143
+ # the originals here. Do NOT add new callers to anything below.
144
+ # A single hardened core shared by both align entry points. It enforces the
145
+ # locked invariants directly instead of relying on librosa's beat[0] for
146
+ # phase and on end-snap/silence-trim for length:
147
+ # * tempo conform with a bounded phase-vocoder stretch (_conform_stretch) β€”
148
+ # gen-time warp only, no live tracking (decision: v1);
149
+ # * align the first STRONG transient to sample 0 (rotate-free head trim) so
150
+ # two independently-correct clips share a downbeat with zero per-clip code;
151
+ # * crop to the exact target sample count β€” overgenerate-then-trim, never
152
+ # zero-pad in the common path (pad only as a logged last resort).
153
+
154
+ def _stage_a_v2(
155
+ audio: np.ndarray,
156
+ sr: int,
157
+ *,
158
+ target_samples: int,
159
+ target_bpm: float,
160
+ deadband: float,
161
+ ) -> np.ndarray:
162
+ """Hardened Stage A core. Input/return: float32 ``[T, C]``.
163
+
164
+ Decides per clip how to land it on the grid:
165
+ * low grid confidence -> place as-is (trust the requested grid; no warp,
166
+ no trim β€” Ableton likewise won't warp a pulse-less texture);
167
+ * confident + non-uniform drift -> beat-synchronous warp (each inter-beat
168
+ segment stretched onto the exact grid, Ableton "Beats" warp);
169
+ * confident + already coherent -> anchor + (optional) whole-loop tempo
170
+ nudge; the measured workhorse path (single-digit ms on real loops).
171
+ Always finishes with: first-strong-transient -> sample 0, then exact crop.
172
+ """
173
+ mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
174
+ detected_bpm, beats = _detect_grid(mono, sr, start_bpm=target_bpm)
175
+ confidence = _grid_confidence(mono, sr, beats)
176
+ spb = sr * 60.0 / target_bpm
177
+
178
+ trusted = (
179
+ detected_bpm is not None
180
+ and confidence >= _GRID_CONFIDENCE_MIN
181
+ and beats is not None
182
+ and len(beats) >= _WARP_MIN_BEATS
183
+ )
184
+
185
+ if not trusted:
186
+ logger.info(
187
+ "stage_a_v2: %s; trusting requested %.2f BPM grid, exact-length only",
188
+ "low grid confidence (%.2f < %.2f)" % (confidence, _GRID_CONFIDENCE_MIN)
189
+ if detected_bpm is not None else "no usable grid",
190
+ target_bpm,
191
+ )
192
+ return _exact_len(audio, target_samples, sr)
193
+
194
+ # --- anchor the musical "1" to sample 0 (INV#4, enables INV#9) --------
195
+ # Anchor to the first TRACKED beat, not the "first loud onset": the tracked
196
+ # beat is the same metrical position across clips, so two loops coincide;
197
+ # "first loud onset" lands on whatever transient happens to be loudest and
198
+ # differs per clip (measured: 200+ ms apart). Refine beats[0] to the exact
199
+ # rising edge for sample accuracy.
200
+ anchor = _refine_to_transient(mono, int(beats[0]), sr)
201
+ if anchor > 0:
202
+ audio = audio[anchor:]
203
+ mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
204
+ beats = np.asarray(beats, dtype=np.int64) - anchor
205
+ beats = beats[beats >= 0]
206
+
207
+ drift = _grid_drift_samples(beats)
208
+ if (_warp_enabled() and drift > _WARP_DRIFT_MIN_MS * sr / 1000.0
209
+ and len(beats) >= 2):
210
+ # OFF BY DEFAULT (FRAGMENTA_BEATSYNC_WARP). Per-beat warp is only as good
211
+ # as librosa's beat detection β€” when detection is even slightly off it
212
+ # warps the wrong points onto the grid and SCRAMBLES the groove on real
213
+ # audio. Measured gain on clean drift was marginal (it merely halved it
214
+ # and added jitter), while anchor + exact-crop already lands real loops
215
+ # at ~3 ms. So it's opt-in for experiments, not the default path.
216
+ audio = _beat_sync_warp(audio, beats, spb)
217
+ logger.info("stage_a_v2: anchored + beat-sync warp (intra-loop drift "
218
+ "%.1f ms)", drift / sr * 1000)
219
+ else:
220
+ # Already coherent: a single global stretch is sufficient (and cleaner
221
+ # than per-segment warping) when the overall tempo is off; otherwise
222
+ # the anchor + exact crop is all that's needed.
223
+ rate, eff = _best_stretch_rate(
224
+ detected_bpm, target_bpm,
225
+ safe_min=_BARS_MODE_STRETCH_MIN, safe_max=_BARS_MODE_STRETCH_MAX,
226
+ )
227
+ if rate is not None and abs(rate - 1.0) > deadband:
228
+ audio = _conform_stretch(audio, rate, sr)
229
+ logger.info("stage_a_v2: anchored + global tempo conform x%.4f "
230
+ "(detected %.2f -> %.2f)", rate, detected_bpm, target_bpm)
231
+ else:
232
+ logger.info("stage_a_v2: anchored only (low drift, on-tempo)")
233
+
234
+ return _exact_len(audio, target_samples, sr)
235
+
236
+
237
+ def _exact_len(audio: np.ndarray, target_samples: int, sr: int) -> np.ndarray:
238
+ """Crop to exactly target_samples (INV#2/#3). Pads only as a logged last
239
+ resort β€” the generation overshoots duration so trimming is the norm."""
240
+ if audio.shape[0] >= target_samples:
241
+ return np.ascontiguousarray(audio[:target_samples], dtype=np.float32)
242
+ pad = target_samples - audio.shape[0]
243
+ logger.warning(
244
+ "stage_a_v2: content short by %d samp (%.0f ms) β€” padding as a last "
245
+ "resort; raise generation headroom or re-roll", pad, pad / sr * 1000,
246
+ )
247
+ return np.ascontiguousarray(
248
+ np.concatenate([audio, np.zeros((pad, audio.shape[1]), np.float32)], 0),
249
+ dtype=np.float32,
250
+ )
251
+
252
+
253
+ def _grid_drift_samples(beats: Optional[np.ndarray]) -> float:
254
+ """Std of detected-beat residuals vs a uniform least-squares grid (samples).
255
+ A coherent loop sits near 0; tempo wobble shows up as a large residual."""
256
+ if beats is None or len(beats) < 4:
257
+ return 0.0
258
+ idx = np.arange(len(beats))
259
+ A = np.vstack([idx, np.ones_like(idx)]).T
260
+ slope, icpt = np.linalg.lstsq(A, beats.astype(float), rcond=None)[0]
261
+ resid = beats.astype(float) - (slope * idx + icpt)
262
+ return float(np.std(resid))
263
+
264
+
265
+ def _refine_to_transient(mono: np.ndarray, approx: int, sr: int,
266
+ win_sec: float = 0.015) -> int:
267
+ """Snap a frame-resolution beat sample to the exact rising edge of the
268
+ transient AT that beat. librosa picks WHICH transient is the beat (good);
269
+ this gives it sample accuracy (INV#4). The window is deliberately tight
270
+ (~15 ms): wide enough to cover beat-tracker frame jitter, narrow enough not
271
+ to jump to a neighbouring transient (which would desync clips, INV#9)."""
272
+ n = len(mono)
273
+ if n == 0:
274
+ return 0
275
+ approx = int(max(0, min(approx, n - 1)))
276
+ lo = max(0, approx - int(sr * win_sec))
277
+ hi = min(n, approx + int(sr * win_sec))
278
+ if hi - lo < 2:
279
+ return approx
280
+ seg = np.abs(mono[lo:hi])
281
+ pk = float(seg.max())
282
+ if pk <= 1e-6:
283
+ return approx
284
+ above = np.flatnonzero(seg >= _V2_RISE_RATIO * pk)
285
+ return int(lo + above[0]) if len(above) else approx
286
+
287
+
288
+ def _beat_sync_warp(audio: np.ndarray, beats: np.ndarray, spb: float) -> np.ndarray:
289
+ """Ableton 'Beats'-style warp: stretch each inter-beat segment to exactly
290
+ round(spb) samples. Output starts at the first detected beat and has a
291
+ perfectly uniform grid, so two clips at the same tempo become sample-for-
292
+ sample periodic (INV#9). Phase-vocoder per segment; only invoked when drift
293
+ is high enough to be worth the boundary jitter."""
294
+ beats = np.asarray(beats, dtype=np.int64)
295
+ beats = beats[(beats >= 0) & (beats < audio.shape[0])]
296
+ if len(beats) < 2:
297
+ return audio
298
+ target_spb = int(round(spb))
299
+ segs = []
300
+ for i in range(len(beats) - 1):
301
+ s, e = int(beats[i]), int(beats[i + 1])
302
+ seg = audio[s:e]
303
+ if seg.shape[0] < 16:
304
+ continue
305
+ rate = float(np.clip(seg.shape[0] / spb, 0.5, 2.0))
306
+ w = librosa.effects.time_stretch(seg.T, rate=rate).T
307
+ if w.shape[0] >= target_spb:
308
+ w = w[:target_spb]
309
+ else:
310
+ w = np.concatenate(
311
+ [w, np.zeros((target_spb - w.shape[0], w.shape[1]), np.float32)], 0)
312
+ segs.append(np.ascontiguousarray(w, dtype=np.float32))
313
+ return np.concatenate(segs, 0) if segs else audio
314
+
315
+
316
+ def _grid_confidence(
317
+ mono: np.ndarray, sr: int, beats: Optional[np.ndarray]
318
+ ) -> float:
319
+ """Trustworthiness of the detected beat grid, in [0, 1].
320
+
321
+ Two evidence sources, averaged:
322
+ * regularity β€” how evenly spaced the detected beats are (1 - interval
323
+ coefficient of variation, clamped); a locked tracker gives near-even
324
+ intervals, ambient content gives erratic ones;
325
+ * pulse clarity β€” the strongest off-zero peak of the onset-envelope
326
+ autocorrelation relative to lag 0; high when there is a real periodic
327
+ pulse, low for drones/pads.
328
+ """
329
+ if beats is None or len(beats) < 4:
330
+ return 0.0
331
+ intervals = np.diff(beats.astype(np.float64))
332
+ mean_i = float(np.mean(intervals)) if len(intervals) else 0.0
333
+ if mean_i <= 0:
334
+ return 0.0
335
+ cv = float(np.std(intervals) / mean_i)
336
+ regularity = max(0.0, min(1.0, 1.0 - cv / _CV_MAX))
337
+
338
+ clarity = 0.0
339
+ try:
340
+ oenv = librosa.onset.onset_strength(y=mono, sr=sr)
341
+ oenv = oenv - float(np.mean(oenv))
342
+ ac = librosa.autocorrelate(oenv)
343
+ if len(ac) > 4 and ac[0] > 0:
344
+ clarity = float(np.max(ac[4:]) / ac[0])
345
+ clarity = max(0.0, min(1.0, clarity))
346
+ except Exception as exc:
347
+ logger.warning("grid-confidence clarity failed: %s", exc)
348
+
349
+ return 0.5 * regularity + 0.5 * clarity
350
+
351
+
352
+ def _first_strong_transient(mono: np.ndarray, sr: int) -> int:
353
+ """Sample index of the first STRONG transient, refined to the rising edge.
354
+
355
+ Two-stage so we neither latch onto low-level noise nor lose sample
356
+ accuracy to librosa's 512-sample hop:
357
+ 1. librosa onset candidates; take the first whose local peak reaches
358
+ ``_V2_STRONG_RATIO`` of the search-window peak;
359
+ 2. refine within a small window to the first sample crossing
360
+ ``_V2_RISE_RATIO`` of that local peak β€” the attack's true start.
361
+ Returns 0 when the clip is silent or no strong transient is found.
362
+ """
363
+ n = len(mono)
364
+ search = min(n, int(sr * _V2_TRANSIENT_SEARCH_SEC))
365
+ if search <= 0:
366
+ return 0
367
+ peak = float(np.max(np.abs(mono[:search])))
368
+ if peak <= 1e-6:
369
+ return 0
370
+
371
+ try:
372
+ onsets = librosa.onset.onset_detect(
373
+ y=mono, sr=sr, units="samples", backtrack=True
374
+ )
375
+ except Exception as exc:
376
+ logger.warning("v2 onset detection failed: %s", exc)
377
+ onsets = None
378
+
379
+ cand: Optional[int] = None
380
+ if onsets is not None and len(onsets) > 0:
381
+ look = int(sr * 0.05)
382
+ for o in np.asarray(onsets, dtype=np.int64):
383
+ if o >= search:
384
+ break
385
+ lo, hi = int(o), min(n, int(o) + look)
386
+ if float(np.max(np.abs(mono[lo:hi]))) >= _V2_STRONG_RATIO * peak:
387
+ cand = int(o)
388
+ break
389
+
390
+ if cand is None:
391
+ # No qualifying onset β€” fall back to the first sample that crosses a
392
+ # fraction of the window peak (handles smooth/pad content).
393
+ idx = np.flatnonzero(np.abs(mono[:search]) >= _V2_STRONG_RATIO * peak)
394
+ return int(idx[0]) if len(idx) else 0
395
 
396
+ win = int(sr * _V2_REFINE_WIN_SEC)
397
+ lo = max(0, cand - win)
398
+ hi = min(n, cand + win)
399
+ local_peak = float(np.max(np.abs(mono[lo:hi]))) or peak
400
+ seg = np.abs(mono[lo:hi])
401
+ above = np.flatnonzero(seg >= _V2_RISE_RATIO * local_peak)
402
+ return int(lo + above[0]) if len(above) else cand
403
+
404
+
405
+ def _conform_stretch(audio: np.ndarray, rate: float, sr: int) -> np.ndarray:
406
+ """Tempo-conform time-stretch β€” the INV#5 "justified equivalent".
407
+
408
+ We use the librosa phase vocoder (no external binary to ship) rather than
409
+ RubberBand's transient mode, justified by three properties that keep
410
+ transient smearing perceptually negligible here:
411
+
412
+ 1. Bounded rate. This only runs inside the safe range [0.85, 1.15] β€” at
413
+ most a 15% stretch β€” where phase-vocoder transient blur is minor.
414
+ 2. Rare path. It fires only on high grid-confidence, off-by->0.5%-tempo
415
+ loops; SA3 usually hits the target at gen-time and skips it entirely.
416
+ 3. The perceptually critical transient β€” the downbeat β€” is positioned by
417
+ the sample-accurate trim in `_stage_a_v2`, NOT by this stretch, so the
418
+ musical "1" is never vocoded.
419
+
420
+ `sr` is accepted for call-site symmetry (the phase vocoder is rate-only)."""
421
+ if abs(rate - 1.0) < 1e-9:
422
+ return audio
423
+ return _time_stretch_multichannel(audio, rate)
424
+
425
+
426
+ # DEPRECATED: superseded by app/core/loop_quantizer (see task_1.md / AUDIT.md Β§9a).
427
+ # Public entry; emits DeprecationWarning at runtime. Scheduled for removal once
428
+ # the new module passes acceptance.
429
  def align_to_grid(
430
  input_path: Path,
431
  target_bpm: float,
432
  target_bars: int,
433
  beats_per_bar: int = 4,
434
  ) -> Path:
435
+ warnings.warn(
436
+ "align_to_grid is deprecated and will be removed once "
437
+ "app/core/loop_quantizer ships (see task_1.md / AUDIT.md Β§9a).",
438
+ DeprecationWarning,
439
+ stacklevel=2,
440
+ )
441
  audio, sr = sf.read(str(input_path), always_2d=True)
442
  audio = audio.astype(np.float32, copy=False)
443
+ samples_per_beat = sr * 60.0 / float(target_bpm)
444
+ target_samples = int(round(target_bars * beats_per_bar * samples_per_beat))
445
+
446
+ if beatsync_v2_enabled():
447
+ out = _stage_a_v2(
448
+ np.ascontiguousarray(audio), sr,
449
+ target_samples=target_samples, target_bpm=float(target_bpm),
450
+ deadband=_BARS_MODE_DEADBAND,
451
+ )
452
+ # 3 ms head fade-in masks any click at the new sample-0 transient.
453
+ _apply_fade(out, _HEAD_FADE_SEC, sr, fade_in=True)
454
+ sf.write(str(input_path), out, sr, subtype="PCM_16")
455
+ logger.info("align_to_grid[v2]: %d samples (exact target %d)",
456
+ out.shape[0], target_samples)
457
+ return input_path
458
 
459
  mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
460
 
461
+ detected_bpm, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
 
 
462
 
463
+ # --- Head trim ---------------------------------------------------------
464
  head_offset = 0
465
+ if beat_samples is not None and len(beat_samples) > 0:
466
+ first_beat = int(beat_samples[0])
467
+ if 0 < first_beat < sr * 1.5:
468
+ head_offset = first_beat
469
+ logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms to first beat")
470
+ elif beat_samples is None:
471
  head_offset = _detect_first_onset_sample(mono, sr)
472
  if head_offset > 0:
473
  logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms (onset fallback)")
 
475
  if head_offset > 0:
476
  audio = audio[head_offset:]
477
  mono = mono[head_offset:]
478
+ if beat_samples is not None:
479
+ shifted = np.asarray(beat_samples, dtype=np.int64) - head_offset
480
+ beat_samples = shifted[shifted > 0]
481
+ # Head fade-in: 3 ms equal-power so the trim seam doesn't click.
482
+ _apply_fade(audio, _HEAD_FADE_SEC, sr, fade_in=True)
483
 
484
+ # --- Tempo conform -----------------------------------------------------
485
  if detected_bpm is not None:
486
+ rate, effective_bpm = _best_stretch_rate(
487
+ detected_bpm,
488
+ target_bpm,
489
+ safe_min=_BARS_MODE_STRETCH_MIN,
490
+ safe_max=_BARS_MODE_STRETCH_MAX,
491
+ )
492
+ if rate is not None and abs(rate - 1.0) > _BARS_MODE_DEADBAND:
493
+ audio = _time_stretch_multichannel(audio, rate)
494
+ # Beats have moved β€” re-detect from the warped audio so the
495
+ # end-snap step below sees current beat positions.
496
+ mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
497
+ _, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
498
  interp_note = (
499
  f" (interpreted as {effective_bpm:.2f} BPM, "
500
  f"octave={effective_bpm / detected_bpm:.2f}Γ—)"
 
505
  f"align_to_grid: detected {detected_bpm:.2f} BPM{interp_note}, "
506
  f"stretched by {rate:.4f} to match target {target_bpm:.2f} BPM"
507
  )
508
+ elif rate is not None:
509
+ logger.info(
510
+ f"align_to_grid: detected {detected_bpm:.2f} BPM is within "
511
+ f"{_BARS_MODE_DEADBAND * 100:.0f}% of target {target_bpm:.2f}; "
512
+ f"skipping stretch to preserve transients"
513
+ )
514
  else:
515
  logger.info(
516
  f"align_to_grid: detected {detected_bpm:.2f} BPM has no safe "
517
+ f"interpretation vs target {target_bpm:.2f} within "
518
+ f"[{_BARS_MODE_STRETCH_MIN:.2f}, {_BARS_MODE_STRETCH_MAX:.2f}]; "
519
+ f"skipping warp (user re-roll recommended)"
520
  )
521
  else:
522
  logger.info("align_to_grid: no usable tempo detected; skipping warp")
523
 
524
+ # --- Trim trailing silence --------------------------------------------
525
+ # Done before end-snap so the snap operates on real audio, not on
526
+ # beats that happen to fall inside a quiet tail.
527
+ new_len = _trailing_audio_end(audio, sr)
528
+ if new_len < audio.shape[0]:
529
+ trimmed_ms = (audio.shape[0] - new_len) / sr * 1000
530
+ logger.info(f"align_to_grid: trimmed {trimmed_ms:.0f} ms trailing silence")
531
+ audio = audio[:new_len]
532
+ if beat_samples is not None:
533
+ beat_samples = beat_samples[beat_samples < new_len]
534
+
535
+ # --- End-anchored truncation ------------------------------------------
536
  if audio.shape[0] > target_samples:
537
+ end = _snap_to_beat(target_samples, beat_samples, samples_per_beat, audio.shape[0])
538
+ cut_on_beat = beat_samples is not None and end in beat_samples.tolist()
539
+ audio = audio[:end]
540
+ if not cut_on_beat:
541
+ # Mid-note cut β€” short fade hides the click. On a clean beat
542
+ # boundary the cut is on a natural transient edge, so the fade
543
+ # would only "duck" the start of the next beat at the loop
544
+ # seam without preventing any audible click.
545
+ _apply_fade(audio, _TAIL_FADE_SEC, sr, fade_in=False)
546
+ # If we came in shorter than target, return the actual audio without
547
+ # zero-padding. A 7.5-bar clip that loops cleanly beats an 8-bar clip
548
+ # with 0.5 bars of silence at the loop seam.
549
 
550
  sf.write(str(input_path), audio, sr, subtype="PCM_16")
551
  return input_path
552
 
553
 
554
+ # --- Phase 7 loop alignment -----------------------------------------------
555
+
556
+ # DEPRECATED: superseded by app/core/loop_quantizer (see task_1.md / AUDIT.md Β§9a).
557
+ # Public entry; emits DeprecationWarning at runtime. Scheduled for removal once
558
+ # the new module passes acceptance.
559
+ def align_for_loop(
560
+ audio: np.ndarray,
561
+ sr: int,
562
+ *,
563
+ target_samples: int,
564
+ target_bpm: float,
565
+ ) -> np.ndarray:
566
+ """Align a baseline clip for seamless looping at an exact length.
567
+
568
+ DEPRECATED β€” superseded by ``app/core/loop_quantizer`` (see ``task_1.md`` /
569
+ ``AUDIT.md`` Β§9a). Scheduled for removal once the new module ships.
570
+
571
+ Pipeline (in-memory, no disk I/O):
572
+ 1. Detect tempo + beat grid via librosa.
573
+ 2. Time-stretch (uniformly) if detected BPM drifts past the bars-mode
574
+ deadband AND the required rate is in the safe range. Drift
575
+ beyond the safe range is left alone (caller can re-roll).
576
+ 3. Head-trim to the first detected beat (or first onset as fallback),
577
+ within the first ~1.5 s. This is the phase-alignment step β€” it
578
+ puts the loop's "downbeat" at sample 0 so multiple channels'
579
+ beats coincide when launched on a bar boundary.
580
+ 4. Crop or zero-pad to exactly `target_samples`. No end-snap: the
581
+ loop iteration length is sample-exact so it stays phase-locked
582
+ to the master clock across iterations.
583
+
584
+ Returns a `np.ndarray` of shape `(target_samples, channels)` (or 1-D
585
+ if input was 1-D). The caller is expected to wrap-and-inpaint the
586
+ output to smooth the seam β€” `align_for_loop` does no fade.
587
+ """
588
+ warnings.warn(
589
+ "align_for_loop is deprecated and will be removed once "
590
+ "app/core/loop_quantizer ships (see task_1.md / AUDIT.md Β§9a).",
591
+ DeprecationWarning,
592
+ stacklevel=2,
593
+ )
594
+ if audio.ndim == 1:
595
+ audio = audio[:, np.newaxis]
596
+ squeeze_out = True
597
+ else:
598
+ squeeze_out = False
599
+ audio = np.ascontiguousarray(audio, dtype=np.float32)
600
+
601
+ if beatsync_v2_enabled():
602
+ out = _stage_a_v2(
603
+ audio, sr,
604
+ target_samples=target_samples, target_bpm=float(target_bpm),
605
+ deadband=_LOOP_MODE_DEADBAND,
606
+ )
607
+ return out.squeeze(1) if squeeze_out else out
608
+
609
+ mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
610
+ detected_bpm, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
611
+
612
+ # --- 1+2: tempo conform ---------------------------------------------
613
+ if detected_bpm is not None:
614
+ rate, effective_bpm = _best_stretch_rate(
615
+ detected_bpm,
616
+ target_bpm,
617
+ safe_min=_BARS_MODE_STRETCH_MIN,
618
+ safe_max=_BARS_MODE_STRETCH_MAX,
619
+ )
620
+ if rate is not None and abs(rate - 1.0) > _LOOP_MODE_DEADBAND:
621
+ audio = _time_stretch_multichannel(audio, rate)
622
+ mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]
623
+ _, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm)
624
+ interp = (
625
+ f" (interpreted as {effective_bpm:.2f} BPM)"
626
+ if abs(effective_bpm - detected_bpm) > 1e-2 else ""
627
+ )
628
+ logger.info(
629
+ "align_for_loop: detected %.2f BPM%s, stretched by %.4f to "
630
+ "match %.2f target",
631
+ detected_bpm, interp, rate, target_bpm,
632
+ )
633
+ elif rate is not None:
634
+ logger.info(
635
+ "align_for_loop: detected %.2f BPM within %.2f%% of %.2f target; "
636
+ "no stretch",
637
+ detected_bpm, _LOOP_MODE_DEADBAND * 100, target_bpm,
638
+ )
639
+ else:
640
+ logger.info(
641
+ "align_for_loop: detected %.2f BPM has no safe stretch to "
642
+ "%.2f target within [%.2f, %.2f]; leaving tempo as-is",
643
+ detected_bpm, target_bpm,
644
+ _BARS_MODE_STRETCH_MIN, _BARS_MODE_STRETCH_MAX,
645
+ )
646
+ else:
647
+ logger.info("align_for_loop: no usable tempo detected; skipping stretch")
648
+
649
+ # --- 3: head-trim to first beat / onset (phase alignment) -----------
650
+ head_offset = 0
651
+ if beat_samples is not None and len(beat_samples) > 0:
652
+ first_beat = int(beat_samples[0])
653
+ if 0 < first_beat < sr * 1.5:
654
+ head_offset = first_beat
655
+ if head_offset == 0:
656
+ # Onset fallback when beat tracking didn't lock β€” gives at least
657
+ # a transient-aligned start instead of mid-attack on sample 0.
658
+ head_offset = _detect_first_onset_sample(mono, sr)
659
+ if head_offset >= sr * 1.5:
660
+ head_offset = 0
661
+ if head_offset > 0:
662
+ audio = audio[head_offset:]
663
+ logger.info(
664
+ "align_for_loop: head-trimmed %.1f ms to first beat/onset",
665
+ head_offset / sr * 1000,
666
+ )
667
+
668
+ # --- 4: crop or pad to exact target_samples -------------------------
669
+ if audio.shape[0] > target_samples:
670
+ audio = audio[:target_samples]
671
+ elif audio.shape[0] < target_samples:
672
+ pad = target_samples - audio.shape[0]
673
+ audio = np.concatenate(
674
+ [audio, np.zeros((pad, audio.shape[1]), dtype=audio.dtype)],
675
+ axis=0,
676
+ )
677
+
678
+ return audio.squeeze(1) if squeeze_out else audio
679
+
680
+
681
+ # --- helpers ---------------------------------------------------------------
682
+
683
+ # DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
684
+ def _trailing_audio_end(audio: np.ndarray, sr: int) -> int:
685
+ """Return the sample index just past the last audible content.
686
+
687
+ Walks backwards in non-overlapping windows of `_SILENCE_WINDOW_SEC` and
688
+ finds the last window whose RMS exceeds `_SILENCE_THRESHOLD_DB`. Returns
689
+ the end of that window plus a small natural-decay tail.
690
+
691
+ Falls back to the original audio length when the entire clip is below
692
+ threshold (silent input) or shorter than one window.
693
+ """
694
+ n = audio.shape[0]
695
+ window = int(sr * _SILENCE_WINDOW_SEC)
696
+ if n <= window:
697
+ return n
698
+ mono = audio.mean(axis=1) if audio.ndim > 1 else audio
699
+ # Squared amplitudes β€” comparing to thresholdΒ² is equivalent to RMS vs
700
+ # threshold but avoids a sqrt per window.
701
+ sq = (mono ** 2)
702
+ thresh_sq = (10.0 ** (_SILENCE_THRESHOLD_DB / 20.0)) ** 2
703
+ tail_keep = int(sr * _SILENCE_TAIL_KEEP_SEC)
704
+ end = n
705
+ while end > 0:
706
+ start = max(0, end - window)
707
+ if float(sq[start:end].mean()) > thresh_sq:
708
+ return min(n, end + tail_keep)
709
+ end = start
710
+ # Whole clip is below threshold β€” leave as-is rather than truncate to 0.
711
+ return n
712
+
713
+
714
+ # DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
715
+ def _snap_to_beat(
716
+ target_samples: int,
717
+ beat_samples: Optional[np.ndarray],
718
+ samples_per_beat: float,
719
+ audio_len: int,
720
+ ) -> int:
721
+ """Return the cut point: the nearest detected beat within Β±Β½ beat of
722
+ target_samples, falling back to target_samples itself if no beat is in
723
+ range. Never overshoots audio length."""
724
+ fallback = min(target_samples, audio_len)
725
+ if beat_samples is None or len(beat_samples) == 0:
726
+ return fallback
727
+ tol = samples_per_beat * 0.5
728
+ valid = beat_samples[(beat_samples > 0) & (beat_samples <= audio_len)]
729
+ if len(valid) == 0:
730
+ return fallback
731
+ diffs = np.abs(valid - target_samples)
732
+ idx = int(np.argmin(diffs))
733
+ if diffs[idx] <= tol:
734
+ return int(valid[idx])
735
+ return fallback
736
+
737
+
738
+ # DEPRECATED: superseded by loop_quantizer (AUDIT.md Β§9b); may be ported if reused.
739
+ def _apply_fade(audio: np.ndarray, duration_sec: float, sr: int, *, fade_in: bool) -> None:
740
+ """In-place equal-power fade on the head (fade_in=True) or tail."""
741
+ n = min(int(duration_sec * sr), audio.shape[0])
742
+ if n <= 1:
743
+ return
744
+ ramp = _equal_power_ramp(n, fade_in=fade_in, dtype=audio.dtype)
745
+ if audio.ndim > 1:
746
+ ramp = ramp[:, np.newaxis]
747
+ if fade_in:
748
+ audio[:n] *= ramp
749
+ else:
750
+ audio[-n:] *= ramp
751
+
752
+
753
+ # DEPRECATED: superseded by loop_quantizer (AUDIT.md Β§9b); may be ported if reused.
754
+ def _equal_power_ramp(n: int, *, fade_in: bool, dtype) -> np.ndarray:
755
+ """Cosine-shaped equal-power fade. Energy at the midpoint is preserved
756
+ when summing fade-out + fade-in of complementary segments, avoiding the
757
+ perceptible 'duck' that linear ramps produce at loop seams."""
758
+ t = np.linspace(0.0, np.pi / 2.0, n).astype(dtype, copy=False)
759
+ return np.sin(t) if fade_in else np.cos(t)
760
+
761
+
762
+ # DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
763
  def _best_stretch_rate(
764
  detected_bpm: float,
765
  target_bpm: float,
766
+ *,
767
+ safe_min: float = _STRETCH_SAFE_MIN,
768
+ safe_max: float = _STRETCH_SAFE_MAX,
769
  ) -> Tuple[Optional[float], float]:
770
  """Pick the time-stretch rate that maps detected β†’ target, considering
771
  half-time and double-time interpretations of the detected tempo. Returns
 
774
  nothing safe is available.
775
 
776
  Order of preference:
777
+ 1. Detected as-is, if it lands inside [safe_min, safe_max].
778
  2. Octave-corrected (detected Γ— 0.5 or Γ— 2.0), only when the as-is
779
  interpretation is out of range. This is the librosa half-/double-
780
  time error recovery path.
 
 
 
781
  """
 
782
  rate_asis = target_bpm / detected_bpm
783
+ if safe_min <= rate_asis <= safe_max:
784
  return rate_asis, detected_bpm
785
 
 
 
 
786
  candidates = []
787
  for octave_factor in (0.5, 2.0):
788
  interpreted = detected_bpm * octave_factor
789
  rate = target_bpm / interpreted
790
+ if safe_min <= rate <= safe_max:
791
  candidates.append((abs(rate - 1.0), rate, interpreted))
792
  if not candidates:
793
  return None, detected_bpm
 
796
  return best_rate, best_interp
797
 
798
 
799
+ # DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
800
  def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int:
801
  """Return the sample index of the first detected onset, or 0 if none found."""
802
  try:
 
814
  return first
815
 
816
 
817
+ # DEPRECATED: superseded by loop_quantizer detector (AUDIT.md Β§9c); port or replace.
818
+ def _detect_grid(
819
  mono: np.ndarray,
820
  sr: int,
821
  start_bpm: Optional[float] = None,
822
+ ) -> Tuple[Optional[float], Optional[np.ndarray]]:
823
+ """Run librosa beat tracking with the target tempo as a prior. Returns
824
+ (bpm, beat_samples_array). Passing start_bpm reduces (but doesn't
825
+ eliminate) half-time / double-time errors; the octave-correction in
826
+ _best_stretch_rate handles whatever librosa still gets wrong."""
827
  try:
828
  kwargs = {"y": mono, "sr": sr, "units": "samples"}
829
  if start_bpm is not None and start_bpm > 0:
 
837
  bpm = float(np.atleast_1d(tempo).flatten()[0])
838
  if not (40.0 <= bpm <= 240.0):
839
  return None, None
840
+ return bpm, np.asarray(beats, dtype=np.int64)
841
 
842
 
843
+ # DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md Β§9b).
844
  def _time_stretch_multichannel(audio: np.ndarray, rate: float) -> np.ndarray:
845
  """Phase-vocoder time stretch, applied per channel and re-stacked."""
846
  stretched = librosa.effects.time_stretch(audio.T, rate=rate)
app/core/model_manager.py CHANGED
@@ -1,478 +1,669 @@
1
- import os
 
 
 
 
 
 
 
 
 
2
  import json
 
3
  import shutil
4
- from pathlib import Path
5
- from typing import Dict, List, Optional, Callable
 
6
  from datetime import datetime
7
- import requests
8
- from huggingface_hub import snapshot_download, hf_hub_download
9
- import hashlib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
 
 
12
  class ModelManager:
 
13
 
14
- def __init__(self, config):
15
  self.config = config
16
- self.models_dir = config.get_path("models_pretrained")
17
  self.models_dir.mkdir(exist_ok=True, parents=True)
18
 
19
- # Use fragmenta-models repo on HF Spaces, Stability AI models elsewhere
20
- use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true'
21
-
22
- if use_custom_repo:
23
- models_repo = 'MazCodes/fragmenta-models'
24
- small_file = 'stable-audio-open-small-model.safetensors'
25
- large_file = 'stable-audio-open-model.safetensors'
26
- else:
27
- models_repo_small = 'stabilityai/stable-audio-open-small'
28
- models_repo_large = 'stabilityai/stable-audio-open-1.0'
29
- small_file = 'model.safetensors'
30
- large_file = 'model.safetensors'
31
-
32
- self.available_models = {
33
- 'stable-audio-open-small': {
34
- 'name': 'Stable Audio Open Small',
35
- 'repo': models_repo if use_custom_repo else models_repo_small,
36
- 'files': [small_file],
37
- 'size': '2.1 GB',
38
- 'description': 'Fast generation, good quality, lower memory usage',
39
- 'best_for': 'Beginners, quick experiments, limited GPU',
40
- 'license': 'Stability AI License',
41
- 'checksum': 'sha256:abc123...'
42
- },
43
- 'stable-audio-open-1.0': {
44
- 'name': 'Stable Audio Open 1.0',
45
- 'repo': models_repo if use_custom_repo else models_repo_large,
46
- 'files': [large_file],
47
- 'size': '8.2 GB',
48
- 'description': 'Highest quality, more detailed audio',
49
- 'best_for': 'Professional use, high-end GPUs',
50
- 'license': 'Stability AI License',
51
- 'checksum': 'sha256:def456...'
52
- }
53
  }
54
 
55
- self.terms_file = Path("config/terms_accepted.json")
56
- self.terms_file.parent.mkdir(exist_ok=True)
57
-
58
- def get_available_models(self) -> List[Dict]:
59
-
60
- models = []
61
-
62
- for model_id, info in self.available_models.items():
63
- is_downloaded = self.is_model_downloaded(model_id)
64
-
65
- downloaded_size = None
66
- if is_downloaded:
67
- if model_id == 'stable-audio-open-small':
68
- model_file = self.models_dir / 'stable-audio-open-small-model.safetensors'
69
- downloaded_size = self._get_file_size(
70
- model_file) if model_file.exists() else None
71
- elif model_id == 'stable-audio-open-1.0':
72
- model_file = self.models_dir / 'stable-audio-open-model.safetensors'
73
- downloaded_size = self._get_file_size(
74
- model_file) if model_file.exists() else None
75
- else:
76
- model_path = self.models_dir / model_id
77
- downloaded_size = self._get_downloaded_size(
78
- model_path) if model_path.exists() else None
79
-
80
- models.append({
81
- 'id': model_id,
82
- 'name': info['name'],
83
- 'size': info['size'],
84
- 'description': info['description'],
85
- 'best_for': info['best_for'],
86
- 'license': info['license'],
87
- 'downloaded': is_downloaded,
88
- 'downloaded_size': downloaded_size,
89
- 'terms_accepted': self.is_terms_accepted(model_id)
90
- })
91
-
92
- return models
93
-
94
- def _get_file_size(self, file_path: Path) -> str:
95
-
96
- if not file_path.exists() or not file_path.is_file():
97
- return "0 B"
98
 
99
- size = file_path.stat().st_size
100
- return self._bytes_to_human(size)
101
-
102
- def _get_downloaded_size(self, model_path: Path) -> str:
103
-
104
- if not model_path.exists():
105
- return "0 B"
106
-
107
- total_size = 0
108
- for file_path in model_path.rglob("*"):
109
- if file_path.is_file():
110
- total_size += file_path.stat().st_size
 
 
 
111
 
112
- for unit in ['B', 'KB', 'MB', 'GB']:
113
- if total_size < 1024.0:
114
- return f"{total_size:.1f} {unit}"
115
- total_size /= 1024.0
116
- return f"{total_size:.1f} TB"
117
 
118
- def get_model_info(self, model_id: str) -> Optional[Dict]:
119
 
120
- if model_id not in self.available_models:
121
- return None
 
 
 
 
 
122
 
123
- info = self.available_models[model_id].copy()
124
- info['id'] = model_id
125
- info['downloaded'] = self.is_model_downloaded(model_id)
126
- info['terms_accepted'] = self.is_terms_accepted(model_id)
127
 
128
- return info
 
 
129
 
130
  def is_model_downloaded(self, model_id: str) -> bool:
131
-
132
- if model_id == 'stable-audio-open-small':
133
- model_file = self.models_dir / 'stable-audio-open-small-model.safetensors'
134
- return model_file.exists() and model_file.is_file()
135
- elif model_id == 'stable-audio-open-1.0':
136
- model_file = self.models_dir / 'stable-audio-open-model.safetensors'
137
- return model_file.exists() and model_file.is_file()
138
- else:
139
- model_path = self.models_dir / model_id
140
- if model_path.exists() and model_path.is_dir():
141
- return any(model_path.iterdir())
142
- pattern = f"*{model_id}*.safetensors"
143
- matching_files = list(self.models_dir.glob(pattern))
144
- return len(matching_files) > 0
145
-
146
- def is_terms_accepted(self, model_id: str) -> bool:
147
-
148
- if not self.terms_file.exists():
149
  return False
150
-
151
- try:
152
- with open(self.terms_file, 'r') as f:
153
- terms_data = json.load(f)
154
- return terms_data.get(model_id, {}).get('accepted', False)
155
- except:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  return False
157
 
158
- def accept_terms(self, model_id: str) -> bool:
159
-
160
- if model_id not in self.available_models:
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  return False
162
-
163
- terms_data = {}
164
- if self.terms_file.exists():
165
- try:
166
- with open(self.terms_file, 'r') as f:
167
- terms_data = json.load(f)
168
- except:
169
- terms_data = {}
170
-
171
- terms_data[model_id] = {
172
- 'accepted': True,
173
- 'accepted_at': datetime.now().isoformat(),
174
- 'model_name': self.available_models[model_id]['name'],
175
- 'license': self.available_models[model_id]['license']
176
- }
177
-
178
  try:
179
- with open(self.terms_file, 'w') as f:
180
- json.dump(terms_data, f, indent=2)
181
- return True
182
- except Exception as e:
183
- print(f"Error saving terms acceptance: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  return False
185
-
186
- def download_model(self, model_id: str, progress_callback: Optional[Callable] = None) -> bool:
187
-
188
- if model_id not in self.available_models:
189
  return False
190
-
191
- if not self.is_terms_accepted(model_id):
192
- print(f"Terms not accepted for {model_id}")
193
- self.accept_terms(model_id)
194
- print(f"Automatically accepted terms for {model_id}")
195
-
196
- model_info = self.available_models[model_id]
197
- target_dir = self.models_dir
198
- target_dir.mkdir(exist_ok=True, parents=True)
199
-
200
- try:
201
- print(f"Downloading {model_info['name']} to {target_dir}")
202
-
203
- if progress_callback:
204
- progress_callback(
205
- 0, f"Starting download of {model_info['name']}...")
206
-
207
- from huggingface_hub import HfApi
208
- api = HfApi()
209
-
210
  try:
211
- user = api.whoami()
212
- print(f"Authenticated as: {user}")
213
  if progress_callback:
214
- progress_callback(10, "Authentication verified...")
215
- except Exception as auth_error:
216
- print(f"Not authenticated with Hugging Face: {auth_error}")
217
- if progress_callback:
218
- progress_callback(0, "Authentication required...")
219
- print("To download models, you need to:")
220
- print(
221
- "1. Visit https://huggingface.co/stabilityai/stable-audio-open-small")
222
- print("2. Accept the terms and conditions")
223
- print("3. Log in to your Hugging Face account")
224
- print(
225
- "4. Get your access token from https://huggingface.co/settings/tokens")
226
- print("5. Use the in-app Hugging Face login dialog")
 
 
 
 
 
227
  if progress_callback:
228
- progress_callback(0, "Please authenticate in the app first")
229
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- if progress_callback:
232
- progress_callback(20, "Starting file download...")
233
-
234
- try:
235
- from huggingface_hub import hf_hub_download
236
- import shutil
237
- from tqdm import tqdm
238
- import sys
239
-
240
- class TqdmToCallback:
241
- def __init__(self, callback, file_index, total_files):
242
- self.callback = callback
243
- self.file_index = file_index
244
- self.total_files = total_files
245
- self.last_percent = 0
246
-
247
- def __call__(self, t):
248
- def inner(bytes_amount=1):
249
- if t.total:
250
- file_progress = (t.n / t.total)
251
- overall_progress = (self.file_index + file_progress) / self.total_files
252
- percent = 20 + int(overall_progress * 70)
253
-
254
- if percent != self.last_percent:
255
- self.last_percent = percent
256
- downloaded_mb = t.n / (1024 * 1024)
257
- total_mb = t.total / (1024 * 1024)
258
- if self.callback:
259
- self.callback(
260
- percent,
261
- f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB"
262
- )
263
- return inner
264
-
265
- downloaded_files = []
266
- total_files = len(model_info['files'])
267
-
268
- for i, file_pattern in enumerate(model_info['files']):
269
  if progress_callback:
270
  progress_callback(
271
- 20 + int((i / total_files) * 70),
272
- f"Starting download of {file_pattern}..."
273
  )
274
-
275
- try:
276
- if file_pattern == 'model.safetensors':
277
- if model_id == 'stable-audio-open-small':
278
- final_filename = 'stable-audio-open-small-model.safetensors'
279
- elif model_id == 'stable-audio-open-1.0':
280
- final_filename = 'stable-audio-open-model.safetensors'
281
- else:
282
- final_filename = f"{model_id}-model.safetensors"
283
- else:
284
- final_filename = f"{model_id}-{file_pattern}"
285
-
286
- tqdm_callback = TqdmToCallback(progress_callback, i, total_files)
287
-
288
- # hf_hub_download drives its own tqdm β€” monkey-patch its init/update so we
289
- # forward byte progress to progress_callback without a second progress bar.
290
- original_tqdm_init = tqdm.__init__
291
-
292
- def patched_tqdm_init(self, *args, **kwargs):
293
- original_tqdm_init(self, *args, **kwargs)
294
- original_update = self.update
295
- def new_update(n=1):
296
- result = original_update(n)
297
- if progress_callback and self.total:
298
- file_progress = (self.n / self.total)
299
- overall_progress = (i + file_progress) / total_files
300
- percent = 20 + int(overall_progress * 70)
301
- downloaded_mb = self.n / (1024 * 1024)
302
- total_mb = self.total / (1024 * 1024)
303
- progress_callback(
304
- percent,
305
- f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB"
306
- )
307
- return result
308
- self.update = new_update
309
-
310
- tqdm.__init__ = patched_tqdm_init
311
-
312
- try:
313
- downloaded_file = hf_hub_download(
314
- repo_id=model_info['repo'],
315
- filename=file_pattern,
316
- resume_download=True
317
- )
318
- finally:
319
- tqdm.__init__ = original_tqdm_init
320
-
321
- downloaded_path = Path(downloaded_file)
322
- final_path = target_dir / final_filename
323
-
324
- final_path.parent.mkdir(parents=True, exist_ok=True)
325
-
326
- shutil.copy2(str(downloaded_path), str(final_path))
327
- print(f"Saved as {final_filename}")
328
-
329
- downloaded_files.append(str(final_path))
330
-
331
- if progress_callback:
332
- progress_callback(
333
- 20 + int(((i + 1) / total_files) * 70),
334
- f"Completed {file_pattern}"
335
- )
336
-
337
- except Exception as file_error:
338
- print(
339
- f"Failed to download {file_pattern}: {file_error}")
340
- if progress_callback:
341
- progress_callback(
342
- 0, f"Failed to download {file_pattern}")
343
- continue
344
-
345
- print(f"Downloaded {len(downloaded_files)} files")
346
-
347
- if progress_callback:
348
- progress_callback(
349
- 95, "Download completed, verifying files...")
350
-
351
- except Exception as download_error:
352
- print(f"Error during download: {download_error}")
353
- if progress_callback:
354
- progress_callback(
355
- 0, f"Download failed: {str(download_error)}")
356
- return False
357
-
358
  if progress_callback:
359
- progress_callback(95, "Verifying download...")
360
-
361
- expected_files = []
362
- if model_id == 'stable-audio-open-small':
363
- expected_files.append(
364
- 'stable-audio-open-small-model.safetensors')
365
- elif model_id == 'stable-audio-open-1.0':
366
- expected_files.append('stable-audio-open-model.safetensors')
367
- else:
368
- expected_files.append(f"{model_id}-model.safetensors")
369
-
370
- files_exist = any((target_dir / expected_file).exists()
371
- for expected_file in expected_files)
372
-
373
- if files_exist:
374
- if progress_callback:
375
- progress_callback(100, "Download complete!")
376
- print(f"Successfully downloaded {model_info['name']}")
377
- return True
378
- else:
379
- if progress_callback:
380
- progress_callback(0, "Download verification failed")
381
- print(f"Expected files not found: {expected_files}")
382
- return False
383
-
384
- except Exception as e:
385
- print(f"Error downloading {model_info['name']}: {e}")
386
- if progress_callback:
387
- progress_callback(0, f"Error: {str(e)}")
388
-
389
- if "403" in str(e) and "gated repositories" in str(e).lower():
390
- print("Token permission issue detected!")
391
- print(
392
- "Your Hugging Face token needs 'Read access to public gated repositories'")
393
- print("Please:")
394
- print("1. Go to https://huggingface.co/settings/tokens")
395
- print("2. Edit your token or create a new one")
396
- print("3. Enable 'Read access to public gated repositories'")
397
- print("4. Try the download again")
398
- elif "401" in str(e) or "restricted" in str(e).lower():
399
- print("This model requires Hugging Face authentication.")
400
- print("Please visit the model page and accept terms first:")
401
- print(f"https://huggingface.co/{model_info['repo']}")
402
- return False
403
 
404
  def delete_model(self, model_id: str) -> bool:
405
-
406
- deleted_something = False
407
-
408
- if model_id == 'stable-audio-open-small':
409
- model_file = self.models_dir / 'stable-audio-open-small-model.safetensors'
410
- config_file = self.models_dir / 'stable-audio-open-small-config.json'
411
- elif model_id == 'stable-audio-open-1.0':
412
- model_file = self.models_dir / 'stable-audio-open-model.safetensors'
413
- config_file = self.models_dir / 'stable-audio-open-1.0-config.json'
414
- else:
415
- model_file = self.models_dir / f"{model_id}-model.safetensors"
416
- config_file = self.models_dir / f"{model_id}-config.json"
417
-
418
- for file_path in [model_file, config_file]:
419
- if file_path.exists():
420
- try:
421
- file_path.unlink()
422
- print(f"Deleted {file_path.name}")
423
- deleted_something = True
424
- except Exception as e:
425
- print(f"Error deleting {file_path.name}: {e}")
426
-
427
- model_path = self.models_dir / model_id
428
- if model_path.exists() and model_path.is_dir():
429
- try:
430
- shutil.rmtree(model_path)
431
- print(f"Deleted {model_id} directory")
432
- deleted_something = True
433
- except Exception as e:
434
- print(f"Error deleting {model_id} directory: {e}")
435
-
436
- if deleted_something:
437
- print(f"Deleted {model_id}")
438
- return True
439
- else:
440
- print(f"No files found for {model_id}")
441
  return False
442
-
443
- def get_download_progress(self, model_id: str) -> Dict:
444
-
445
- return {
446
- 'model_id': model_id,
447
- 'downloaded': self.is_model_downloaded(model_id),
448
- 'size': self.available_models.get(model_id, {}).get('size', 'Unknown')
449
- }
450
-
451
- def get_storage_info(self) -> Dict:
452
-
453
- total_size = 0
454
- model_count = 0
455
-
456
- if self.models_dir.exists():
457
- for model_id in self.available_models.keys():
458
- if self.is_model_downloaded(model_id):
459
- model_count += 1
460
-
461
- for file_path in self.models_dir.rglob("*"):
462
- if file_path.is_file():
463
- total_size += file_path.stat().st_size
464
-
 
 
 
 
 
 
465
  return {
466
- 'total_size_bytes': total_size,
467
- 'total_size_human': self._bytes_to_human(total_size),
468
- 'model_count': model_count,
469
- 'models_dir': str(self.models_dir)
470
  }
471
 
472
- def _bytes_to_human(self, bytes_value: int) -> str:
473
-
474
- for unit in ['B', 'KB', 'MB', 'GB']:
475
- if bytes_value < 1024.0:
476
- return f"{bytes_value:.1f} {unit}"
477
- bytes_value /= 1024.0
478
- return f"{bytes_value:.1f} TB"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Checkpoint Manager β€” SA3 catalog, HF downloads, license + auth.
2
+
3
+ Phase 2a in SA3_INTEGRATION_PLAN.md. Replaces the SA2-era SAO catalog.
4
+ Eight downloadable artifacts (3 post-trained + 3 base + 2 autoencoders);
5
+ each is fetched via `huggingface_hub.snapshot_download` with cooperative
6
+ cancel + progress reporting.
7
+
8
+ The Phase 2b frontend (CheckpointManagerWindow.js) consumes the JSON shapes
9
+ returned by the `/api/checkpoints/*` endpoints in `app/backend/app.py`.
10
+ """
11
  import json
12
+ import os
13
  import shutil
14
+ import threading
15
+ import uuid
16
+ from dataclasses import dataclass, field
17
  from datetime import datetime
18
+ from pathlib import Path
19
+ from typing import Any, Callable, Dict, List, Optional
20
+
21
+ from huggingface_hub import get_token, snapshot_download, whoami
22
+ from huggingface_hub.errors import GatedRepoError, RepositoryNotFoundError
23
+
24
+
25
+ # --- Catalog ------------------------------------------------------------------
26
+
27
+ # Approximate sizes; the frontend can refine these by hitting
28
+ # `huggingface_hub.HfApi().model_info(repo_id)` lazily. Numbers come from the
29
+ # HF model cards (paragraph parameter counts Γ— bytes/param, rounded).
30
+ _SA3_CATALOG: Dict[str, Dict[str, Any]] = {
31
+ # --- Generation models (post-trained) ----------------------------------
32
+ "sa3-small-music": {
33
+ "user_visible": True,
34
+ "kind": "post-trained",
35
+ "name": "Small - Music",
36
+ "sa3_name": "small-music",
37
+ "repo": "stabilityai/stable-audio-3-small-music",
38
+ "size_bytes": 2_270_000_000,
39
+ "hardware": "cpu", # CPU / MPS / CUDA all work
40
+ "max_duration_sec": 120,
41
+ "description": "Fast distilled music generation. Locked to 8 steps, cfg 1.0.",
42
+ },
43
+ "sa3-small-sfx": {
44
+ "user_visible": True,
45
+ "kind": "post-trained",
46
+ "name": "Small - SFX",
47
+ "sa3_name": "small-sfx",
48
+ "repo": "stabilityai/stable-audio-3-small-sfx",
49
+ "size_bytes": 2_270_000_000,
50
+ "hardware": "cpu",
51
+ "max_duration_sec": 120,
52
+ "description": "Fast distilled SFX/foley generation. Locked to 8 steps, cfg 1.0.",
53
+ },
54
+ "sa3-medium": {
55
+ "user_visible": True,
56
+ "kind": "post-trained",
57
+ "name": "Medium",
58
+ "sa3_name": "medium",
59
+ "repo": "stabilityai/stable-audio-3-medium",
60
+ "size_bytes": 9_220_000_000,
61
+ "hardware": "cuda+flash-attn",
62
+ "max_duration_sec": 380,
63
+ "description": "Fast distilled hi-fi generation, up to 380s. Locked to 8 steps, cfg 1.0.",
64
+ },
65
+ # --- Base checkpoints (full artist control) ----------------------------
66
+ # These are the CFG-aware pre-distillation models. Slower (~50 steps,
67
+ # cfg ~7), but the user controls cfg_scale, steps, and the inference
68
+ # trajectory. Also the canonical targets for LoRA training.
69
+ "sa3-small-music-base": {
70
+ "user_visible": True,
71
+ "kind": "base",
72
+ "name": "Small - Music (Base)",
73
+ "sa3_name": "small-music-base",
74
+ "repo": "stabilityai/stable-audio-3-small-music-base",
75
+ "size_bytes": 2_270_000_000,
76
+ "hardware": "cpu",
77
+ "max_duration_sec": 120,
78
+ "description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.",
79
+ },
80
+ "sa3-small-sfx-base": {
81
+ "user_visible": True,
82
+ "kind": "base",
83
+ "name": "Small - SFX (Base)",
84
+ "sa3_name": "small-sfx-base",
85
+ "repo": "stabilityai/stable-audio-3-small-sfx-base",
86
+ "size_bytes": 2_270_000_000,
87
+ "hardware": "cpu",
88
+ "max_duration_sec": 120,
89
+ "description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.",
90
+ },
91
+ "sa3-medium-base": {
92
+ "user_visible": True,
93
+ "kind": "base",
94
+ "name": "Medium (Base)",
95
+ "sa3_name": "medium-base",
96
+ "repo": "stabilityai/stable-audio-3-medium-base",
97
+ "size_bytes": 9_220_000_000,
98
+ "hardware": "cuda+flash-attn",
99
+ "max_duration_sec": 380,
100
+ "description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.",
101
+ },
102
+ # Standalone autoencoders: the AE is bundled INSIDE each DiT repo
103
+ # already (StableAudioModel.from_pretrained loads it from there), so
104
+ # we don't surface SAME-S / SAME-L in the manager. They remain
105
+ # downloadable via /api/checkpoints?include=all for advanced uses
106
+ # (autoencoder-only workflows, pre-encoding datasets for training).
107
+ "sa3-same-s": {
108
+ "user_visible": False,
109
+ "kind": "autoencoder",
110
+ "name": "SAME-S",
111
+ "sa3_name": "same-s",
112
+ "repo": "stabilityai/SAME-S",
113
+ "size_bytes": 530_000_000,
114
+ "hardware": "cpu",
115
+ "description": "Standalone autoencoder (266M). Already bundled with the small-* DiTs.",
116
+ },
117
+ "sa3-same-l": {
118
+ "user_visible": False,
119
+ "kind": "autoencoder",
120
+ "name": "SAME-L",
121
+ "sa3_name": "same-l",
122
+ "repo": "stabilityai/SAME-L",
123
+ "size_bytes": 3_400_000_000,
124
+ "hardware": "cuda",
125
+ "description": "Standalone autoencoder (1.7B). Already bundled with medium.",
126
+ },
127
+ # --- Auto-annotation tools ---------------------------------------------
128
+ # Single-file HF download, lives under <models_pretrained>/clap/.
129
+ # `is_model_downloaded` and `_run_download` special-case kind=="tagger".
130
+ "clap-music": {
131
+ "user_visible": True,
132
+ "kind": "tagger",
133
+ "name": "LAION-CLAP (music)",
134
+ "sa3_name": "clap-music",
135
+ "repo": "lukewys/laion_clap",
136
+ "filename": "music_audioset_epoch_15_esc_90.14.pt",
137
+ # ~2.35 GB .pt + ~1.4 GB of text-encoder snapshots (roberta-base,
138
+ # bert-base-uncased, facebook/bart-base) that laion_clap loads at
139
+ # construction. download_clap_checkpoint pulls all of them.
140
+ "size_bytes": 3_800_000_000,
141
+ "hardware": "cpu",
142
+ "description": (
143
+ "Zero-shot tagger used by the dataset prep's rich-tier annotation. "
144
+ "Scores each clip against your genre / mood / instrument vocabulary."
145
+ ),
146
+ },
147
+ }
148
+
149
+ # --- Job state for in-flight downloads ----------------------------------------
150
+
151
+ @dataclass
152
+ class _DownloadJob:
153
+ """In-memory record of one download attempt."""
154
+ job_id: str
155
+ model_id: str
156
+ status: str = "queued" # queued | running | complete | failed | cancelled
157
+ downloaded_bytes: int = 0
158
+ total_bytes: int = 0
159
+ error: Optional[str] = None
160
+ started_at: Optional[str] = None
161
+ finished_at: Optional[str] = None
162
+ _cancel_flag: threading.Event = field(default_factory=threading.Event)
163
+ _thread: Optional[threading.Thread] = None
164
+
165
+ def to_dict(self) -> Dict[str, Any]:
166
+ return {
167
+ "job_id": self.job_id,
168
+ "model_id": self.model_id,
169
+ "status": self.status,
170
+ "downloaded_bytes": self.downloaded_bytes,
171
+ "total_bytes": self.total_bytes,
172
+ "error": self.error,
173
+ "started_at": self.started_at,
174
+ "finished_at": self.finished_at,
175
+ }
176
+
177
+
178
+ class _DownloadCancelled(Exception):
179
+ """Raised inside the tqdm hook when a job's cancel flag fires."""
180
 
181
 
182
+ # --- ModelManager -------------------------------------------------------------
183
+
184
  class ModelManager:
185
+ """Owns the SA3 catalog and the on-disk pretrained directory."""
186
 
187
+ def __init__(self, config: Any) -> None:
188
  self.config = config
189
+ self.models_dir: Path = config.get_path("models_pretrained")
190
  self.models_dir.mkdir(exist_ok=True, parents=True)
191
 
192
+ # Project-wide policy: every HF download lands inside
193
+ # <app>/models/pretrained/. SA3 generation + training uses
194
+ # <pretrained>/sa3/hub/; CLAP text deps use <pretrained>/clap/hub/.
195
+ # Both are HF cache layout so snapshot_download / hf_hub_download /
196
+ # from_pretrained resolve there transparently.
197
+ self.hub_dir: Path = self.models_dir / "sa3" / "hub"
198
+ self.hub_dir.mkdir(exist_ok=True, parents=True)
199
+ # Hard-force the resolution vars β€” never let an external env leak
200
+ # downloads into ~/.cache/huggingface or anywhere else outside the
201
+ # app folder. Covers huggingface_hub (current + legacy name) and
202
+ # transformers (which still consults TRANSFORMERS_CACHE).
203
+ os.environ["HF_HUB_CACHE"] = str(self.hub_dir)
204
+ os.environ["HUGGINGFACE_HUB_CACHE"] = str(self.hub_dir)
205
+ os.environ["TRANSFORMERS_CACHE"] = str(self.hub_dir)
206
+
207
+ # available_models is exposed for backwards compat with the existing
208
+ # /api/models/available endpoint. New code should use get_catalog().
209
+ self.available_models: Dict[str, Dict] = {
210
+ mid: dict(meta) for mid, meta in _SA3_CATALOG.items()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  }
212
 
213
+ self._jobs: Dict[str, _DownloadJob] = {}
214
+ self._jobs_lock = threading.Lock()
215
+
216
+ # --- Catalog --------------------------------------------------------------
217
+
218
+ def get_catalog(self, include_hidden: bool = False) -> List[Dict[str, Any]]:
219
+ """Checkpoint Manager catalog with per-item state.
220
+
221
+ Default returns only user-visible entries (the three generation
222
+ models). `include_hidden=True` also returns base + standalone-AE
223
+ entries β€” used by the Phase 5 training subprocess to ensure the
224
+ right base variant is on disk before kicking train_lora.py.
225
+ """
226
+ return [
227
+ self._catalog_entry(mid)
228
+ for mid, info in _SA3_CATALOG.items()
229
+ if include_hidden or info.get("user_visible")
230
+ ]
231
+
232
+ def _catalog_entry(self, model_id: str) -> Dict[str, Any]:
233
+ info = _SA3_CATALOG[model_id]
234
+ downloaded = self.is_model_downloaded(model_id)
235
+ bytes_total = 0
236
+ if downloaded:
237
+ for d in (self._hub_cache_dir_for(model_id), self._legacy_flat_dir_for(model_id)):
238
+ if d.exists():
239
+ bytes_total += self._dir_size(d)
240
+
241
+ # Surface the most recent in-flight job for this model so the
242
+ # frontend can resume the progress bar after the Checkpoint Manager
243
+ # dialog is closed and reopened. The job lives on the backend; only
244
+ # the polling died with the dismissed UI.
245
+ active_job = None
246
+ with self._jobs_lock:
247
+ in_flight = [
248
+ j for j in self._jobs.values()
249
+ if j.model_id == model_id and j.status in ("queued", "running")
250
+ ]
251
+ if in_flight:
252
+ in_flight.sort(key=lambda j: j.started_at or "", reverse=True)
253
+ active_job = in_flight[0].to_dict()
 
 
254
 
255
+ return {
256
+ "id": model_id,
257
+ "kind": info.get("kind"),
258
+ "name": info["name"],
259
+ "sa3_name": info["sa3_name"],
260
+ "repo": info["repo"],
261
+ "size_bytes": info["size_bytes"],
262
+ "hardware": info["hardware"],
263
+ "max_duration_sec": info.get("max_duration_sec"),
264
+ "description": info["description"],
265
+ "user_visible": info.get("user_visible", False),
266
+ "downloaded": downloaded,
267
+ "downloaded_bytes": bytes_total,
268
+ "active_job": active_job,
269
+ }
270
 
271
+ def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
272
+ if model_id not in _SA3_CATALOG:
273
+ return None
274
+ return self._catalog_entry(model_id)
 
275
 
276
+ # --- Filesystem layout ----------------------------------------------------
277
 
278
+ def _hub_cache_dir_for(self, model_id: str) -> Path:
279
+ """HF-cache-shaped directory inside the app folder."""
280
+ info = _SA3_CATALOG.get(model_id)
281
+ if info is None:
282
+ return self.hub_dir / "_unknown"
283
+ safe = "models--" + info["repo"].replace("/", "--")
284
+ return self.hub_dir / safe
285
 
286
+ def _legacy_flat_dir_for(self, model_id: str) -> Path:
287
+ """Pre-unification per-model dir. Read-only fallback for migration."""
288
+ return self.models_dir / "sa3" / model_id
 
289
 
290
+ def _local_dir_for(self, model_id: str) -> Path:
291
+ """Public: returns the canonical (HF cache) directory for a model."""
292
+ return self._hub_cache_dir_for(model_id)
293
 
294
  def is_model_downloaded(self, model_id: str) -> bool:
295
+ if model_id not in _SA3_CATALOG:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  return False
297
+ info = _SA3_CATALOG[model_id]
298
+ if info.get("kind") == "tagger":
299
+ # Single-file artifacts live in <models_pretrained>/<group>/<filename>.
300
+ # auto_annotator owns the exact path for CLAP, so we delegate.
301
+ from app.backend.data.auto_annotator import clap_checkpoint_available
302
+ return clap_checkpoint_available(self.models_dir)
303
+ # Canonical: HF cache layout under <app>/models/pretrained/sa3/hub/.
304
+ # Look for the *top-level* model.safetensors only β€” NOT recursive β€”
305
+ # because a sibling repo may have only its conditioner subfolder
306
+ # downloaded (e.g. via the eager T5Gemma companion fetch when the
307
+ # user installed the matching *-base), and that doesn't make the
308
+ # post-trained model "downloaded".
309
+ main_present = False
310
+ hub = self._hub_cache_dir_for(model_id)
311
+ if hub.is_dir():
312
+ snaps = hub / "snapshots"
313
+ if snaps.is_dir():
314
+ for sub in snaps.iterdir():
315
+ if any(sub.glob("*.safetensors")):
316
+ main_present = True
317
+ break
318
+ if not main_present:
319
+ # Fallback: legacy flat layout (predates the unification). Counts
320
+ # as downloaded for inference purposes; trainer will re-stage into
321
+ # hub.
322
+ legacy = self._legacy_flat_dir_for(model_id)
323
+ if legacy.is_dir() and any(legacy.glob("*.safetensors")):
324
+ main_present = True
325
+ if not main_present:
326
  return False
327
 
328
+ # Base models need a T5Gemma conditioner that lives in a subfolder
329
+ # of the *post-trained sibling* repo. "Installed" must mean "ready
330
+ # to train / generate" β€” without the companion the first run blocks
331
+ # for 30s+ on an HF fetch.
332
+ return self._is_companion_present(model_id)
333
+
334
+ def _is_companion_present(self, model_id: str) -> bool:
335
+ from app.core.training.sa3_lora_runner import SA3_T5GEMMA_SIBLINGS
336
+ sibling = SA3_T5GEMMA_SIBLINGS.get(model_id)
337
+ if not sibling:
338
+ return True # nothing to check (post-trained / autoencoder / tagger)
339
+ sib_repo, sib_subfolder = sibling
340
+ safe = "models--" + sib_repo.replace("/", "--")
341
+ sib_hub = self.hub_dir / safe
342
+ snaps = sib_hub / "snapshots"
343
+ if not snaps.is_dir():
344
  return False
345
+ for sub in snaps.iterdir():
346
+ if (sub / sib_subfolder).is_dir():
347
+ # Any non-empty file presence is good enough β€” the eager
348
+ # fetch always pulls the tokenizer + config + safetensors.
349
+ if any((sub / sib_subfolder).iterdir()):
350
+ return True
351
+ return False
352
+
353
+ # --- HF auth --------------------------------------------------------------
354
+
355
+ @staticmethod
356
+ def hf_auth_status() -> Dict[str, Any]:
357
+ token = get_token()
358
+ if not token:
359
+ return {"signed_in": False, "username": None}
 
360
  try:
361
+ user = whoami(token=token)
362
+ return {"signed_in": True, "username": user.get("name") or user.get("fullname")}
363
+ except Exception as err:
364
+ return {"signed_in": False, "username": None, "error": str(err)}
365
+
366
+ # --- Downloads ------------------------------------------------------------
367
+
368
+ def start_download(
369
+ self,
370
+ model_id: str,
371
+ progress_callback: Optional[Callable[[int, str], None]] = None,
372
+ ) -> Dict[str, Any]:
373
+ """Spawn a background download job. Returns the job descriptor."""
374
+ if model_id not in _SA3_CATALOG:
375
+ return {"error": f"Unknown checkpoint: {model_id}"}
376
+
377
+ job = _DownloadJob(
378
+ job_id=str(uuid.uuid4()),
379
+ model_id=model_id,
380
+ total_bytes=_SA3_CATALOG[model_id]["size_bytes"],
381
+ )
382
+ with self._jobs_lock:
383
+ self._jobs[job.job_id] = job
384
+
385
+ thread = threading.Thread(
386
+ target=self._run_download,
387
+ args=(job, progress_callback),
388
+ daemon=True,
389
+ name=f"sa3-download:{model_id}",
390
+ )
391
+ job._thread = thread
392
+ thread.start()
393
+ return job.to_dict()
394
+
395
+ def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
396
+ with self._jobs_lock:
397
+ job = self._jobs.get(job_id)
398
+ return job.to_dict() if job else None
399
+
400
+ def list_jobs(self) -> List[Dict[str, Any]]:
401
+ with self._jobs_lock:
402
+ return [j.to_dict() for j in self._jobs.values()]
403
+
404
+ def cancel_job(self, job_id: str) -> bool:
405
+ with self._jobs_lock:
406
+ job = self._jobs.get(job_id)
407
+ if not job:
408
  return False
409
+ if job.status not in ("queued", "running"):
 
 
 
410
  return False
411
+ job._cancel_flag.set()
412
+ return True
413
+
414
+ def _run_download(
415
+ self,
416
+ job: _DownloadJob,
417
+ progress_callback: Optional[Callable[[int, str], None]],
418
+ ) -> None:
419
+ info = _SA3_CATALOG[job.model_id]
420
+ job.status = "running"
421
+ job.started_at = datetime.now().isoformat()
422
+
423
+ # Tagger kind (e.g. CLAP) is a .pt file plus auxiliary HF snapshots
424
+ # living outside the sa3/hub layout. Multi-phase: 1 hf_hub_download
425
+ # for the audio .pt, then N sequential snapshot_downloads for the
426
+ # text encoders. Each spawns its own tqdm bars, so we use the
427
+ # cumulative hook to accumulate bytes across phases, and a phase_cb
428
+ # to prefix the message with which step the user is on.
429
+ if info.get("kind") == "tagger":
 
430
  try:
431
+ from app.backend.data.auto_annotator import download_clap_checkpoint
 
432
  if progress_callback:
433
+ progress_callback(0, f"Downloading {info['name']}…")
434
+ # Pin total to the catalog estimate so the % stays anchored
435
+ # even before tqdm reports any file's size.
436
+ job.total_bytes = info["size_bytes"]
437
+ current_phase = {"label": ""}
438
+
439
+ def phase_cb(idx: int, total: int, label: str) -> None:
440
+ current_phase["label"] = f"[{idx}/{total}] {label}"
441
+ if progress_callback:
442
+ pct = (int(job.downloaded_bytes / job.total_bytes * 100)
443
+ if job.total_bytes else 0)
444
+ progress_callback(pct, current_phase["label"])
445
+
446
+ with _cumulative_tqdm_hook(job, progress_callback, current_phase):
447
+ download_clap_checkpoint(self.models_dir, phase_cb=phase_cb)
448
+ job.downloaded_bytes = job.total_bytes
449
+ job.status = "complete"
450
+ job.finished_at = datetime.now().isoformat()
451
  if progress_callback:
452
+ progress_callback(100, f"Downloaded {info['name']}")
453
+ except _DownloadCancelled:
454
+ job.status = "cancelled"
455
+ job.error = "Cancelled by user"
456
+ job.finished_at = datetime.now().isoformat()
457
+ except Exception as err:
458
+ job.status = "failed"
459
+ job.error = f"{type(err).__name__}: {err}"
460
+ job.finished_at = datetime.now().isoformat()
461
+ return
462
+
463
+ cache_dir = self._hub_cache_dir_for(job.model_id).parent # = self.hub_dir
464
+ cache_dir.mkdir(exist_ok=True, parents=True)
465
+ target = self._hub_cache_dir_for(job.model_id)
466
+
467
+ token = get_token()
468
 
469
+ try:
470
+ with _tqdm_progress_hook(job, progress_callback):
471
+ # Write into hub/ in HF cache layout. snapshot_download in
472
+ # hf-hub 1.x populates `<cache_dir>/models--<org>--<name>/`
473
+ # with the blobs/refs/snapshots structure that
474
+ # hf_hub_download() and StableAudioModel.from_pretrained()
475
+ # both consume.
476
+ snapshot_download(
477
+ repo_id=info["repo"],
478
+ cache_dir=str(cache_dir),
479
+ token=token,
480
+ allow_patterns=[
481
+ "*.safetensors", "*.json", "*.txt", "*.model",
482
+ "tokenizer*", "*.tiktoken",
483
+ ],
484
+ )
485
+
486
+ # Companion fetch: base models reference their T5Gemma
487
+ # conditioner in a subfolder of the *post-trained sibling*
488
+ # repo. Without it the training subprocess crashes at
489
+ # AutoTokenizer.from_pretrained, and inference can't build
490
+ # the conditioner either. Pull it eagerly so "Installed"
491
+ # actually means "ready to use".
492
+ from app.core.training.sa3_lora_runner import SA3_T5GEMMA_SIBLINGS
493
+ sibling = SA3_T5GEMMA_SIBLINGS.get(job.model_id)
494
+ if sibling:
495
+ sib_repo, sib_subfolder = sibling
 
 
 
 
 
 
 
 
 
 
 
496
  if progress_callback:
497
  progress_callback(
498
+ min(99, int(job.downloaded_bytes / max(1, job.total_bytes) * 100)),
499
+ f"Fetching T5Gemma conditioner from {sib_repo}…",
500
  )
501
+ snapshot_download(
502
+ repo_id=sib_repo,
503
+ cache_dir=str(cache_dir),
504
+ token=token,
505
+ allow_patterns=[f"{sib_subfolder}/*"],
506
+ )
507
+ job.status = "complete"
508
+ job.downloaded_bytes = self._dir_size(target)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  if progress_callback:
510
+ progress_callback(100, f"Downloaded {info['name']}")
511
+ except _DownloadCancelled:
512
+ job.status = "cancelled"
513
+ job.error = "Cancelled by user"
514
+ shutil.rmtree(target, ignore_errors=True)
515
+ except GatedRepoError as err:
516
+ job.status = "failed"
517
+ job.error = f"hf_auth_required: {err}"
518
+ except RepositoryNotFoundError as err:
519
+ job.status = "failed"
520
+ job.error = f"Repository not found: {err}"
521
+ except Exception as err:
522
+ job.status = "failed"
523
+ job.error = str(err)
524
+ finally:
525
+ job.finished_at = datetime.now().isoformat()
526
+
527
+ # --- Delete ---------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
 
529
  def delete_model(self, model_id: str) -> bool:
530
+ if model_id not in _SA3_CATALOG:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  return False
532
+ # Remove both the canonical hub copy and the legacy flat copy if
533
+ # they exist. Either being present is enough to consider the
534
+ # model "downloaded", so both must be cleaned for the row to
535
+ # flip back to "Get".
536
+ hub = self._hub_cache_dir_for(model_id)
537
+ legacy = self._legacy_flat_dir_for(model_id)
538
+ any_existed = hub.exists() or legacy.exists()
539
+ if hub.exists():
540
+ shutil.rmtree(hub, ignore_errors=True)
541
+ if legacy.exists():
542
+ shutil.rmtree(legacy, ignore_errors=True)
543
+ return any_existed and not (hub.exists() or legacy.exists())
544
+
545
+ # --- Storage --------------------------------------------------------------
546
+
547
+ def get_storage_info(self) -> Dict[str, Any]:
548
+ per_model: List[Dict[str, Any]] = []
549
+ total_used = 0
550
+ for mid in _SA3_CATALOG:
551
+ bytes_ = 0
552
+ for d in (self._hub_cache_dir_for(mid), self._legacy_flat_dir_for(mid)):
553
+ if d.exists():
554
+ bytes_ += self._dir_size(d)
555
+ per_model.append({
556
+ "id": mid,
557
+ "downloaded": self.is_model_downloaded(mid),
558
+ "bytes": bytes_,
559
+ })
560
+ total_used += bytes_
561
  return {
562
+ "total_used_bytes": total_used,
563
+ "total_free_bytes": shutil.disk_usage(self.models_dir).free,
564
+ "per_model": per_model,
 
565
  }
566
 
567
+ # --- Helpers --------------------------------------------------------------
568
+
569
+ @staticmethod
570
+ def _dir_size(path: Path) -> int:
571
+ if not path.exists():
572
+ return 0
573
+ return sum(p.stat().st_size for p in path.rglob("*") if p.is_file())
574
+
575
+ # --- tqdm hook ----------------------------------------------------------------
576
+
577
+ import contextlib
578
+
579
+ @contextlib.contextmanager
580
+ def _tqdm_progress_hook(
581
+ job: _DownloadJob,
582
+ progress_callback: Optional[Callable[[int, str], None]],
583
+ ):
584
+ """Monkey-patch tqdm so snapshot_download updates flow into the job state.
585
+
586
+ `snapshot_download` doesn't expose a progress callback. tqdm is its
587
+ internal progress bar β€” we wrap `update` to update job state and raise
588
+ `_DownloadCancelled` when the job's cancel flag fires.
589
+ """
590
+ from tqdm.auto import tqdm
591
+ original_init = tqdm.__init__
592
+
593
+ def patched_init(self, *args: Any, **kwargs: Any) -> None:
594
+ original_init(self, *args, **kwargs)
595
+ original_update = self.update
596
+
597
+ def new_update(n: int = 1) -> Any:
598
+ if job._cancel_flag.is_set():
599
+ raise _DownloadCancelled()
600
+ result = original_update(n)
601
+ if self.total:
602
+ job.downloaded_bytes = max(job.downloaded_bytes, self.n)
603
+ if job.total_bytes < self.total:
604
+ job.total_bytes = self.total
605
+ if progress_callback:
606
+ pct = int(self.n / self.total * 100) if self.total else 0
607
+ mb_done = self.n / (1024 * 1024)
608
+ mb_total = self.total / (1024 * 1024)
609
+ progress_callback(pct, f"Downloading: {mb_done:.1f}MB / {mb_total:.1f}MB")
610
+ return result
611
+
612
+ self.update = new_update # type: ignore[method-assign]
613
+
614
+ tqdm.__init__ = patched_init # type: ignore[method-assign]
615
+ try:
616
+ yield
617
+ finally:
618
+ tqdm.__init__ = original_init # type: ignore[method-assign]
619
+
620
+
621
+ @contextlib.contextmanager
622
+ def _cumulative_tqdm_hook(
623
+ job: _DownloadJob,
624
+ progress_callback: Optional[Callable[[int, str], None]],
625
+ current_phase: Dict[str, str],
626
+ ):
627
+ """Like _tqdm_progress_hook, but sums bytes across sequential bars.
628
+
629
+ Each tqdm bar reports `self.n` cumulative within ITS file. The single-bar
630
+ hook uses max() which freezes the UI when a fresh bar starts smaller than
631
+ the previous bar's total. Here we track the previous `self.n` per bar id
632
+ and add only the delta to job.downloaded_bytes β€” so progress climbs
633
+ monotonically across all phases.
634
+ """
635
+ from tqdm.auto import tqdm
636
+ original_init = tqdm.__init__
637
+ prev_n: Dict[int, int] = {}
638
+
639
+ def patched_init(self, *args: Any, **kwargs: Any) -> None:
640
+ original_init(self, *args, **kwargs)
641
+ original_update = self.update
642
+ prev_n[id(self)] = 0
643
+
644
+ def new_update(n: int = 1) -> Any:
645
+ if job._cancel_flag.is_set():
646
+ raise _DownloadCancelled()
647
+ result = original_update(n)
648
+ prev = prev_n.get(id(self), 0)
649
+ delta = self.n - prev
650
+ prev_n[id(self)] = self.n
651
+ if delta > 0:
652
+ job.downloaded_bytes += delta
653
+ if progress_callback and job.total_bytes:
654
+ pct = min(int(job.downloaded_bytes / job.total_bytes * 100), 99)
655
+ mb_done = job.downloaded_bytes / (1024 * 1024)
656
+ mb_total = job.total_bytes / (1024 * 1024)
657
+ label = current_phase.get("label", "")
658
+ msg = (f"{label} Β· {mb_done:.0f} MB / {mb_total:.0f} MB"
659
+ if label else f"{mb_done:.0f} MB / {mb_total:.0f} MB")
660
+ progress_callback(pct, msg)
661
+ return result
662
+
663
+ self.update = new_update # type: ignore[method-assign]
664
+
665
+ tqdm.__init__ = patched_init # type: ignore[method-assign]
666
+ try:
667
+ yield
668
+ finally:
669
+ tqdm.__init__ = original_init # type: ignore[method-assign]
app/core/training/hyperparam_suggester.py CHANGED
@@ -1,76 +1,67 @@
1
- """Heuristic hyperparameter suggester for the Training tab's "Suggest" button.
2
-
3
- Given the dataset on disk and the current hardware, returns a config that
4
- trades off "small dataset, needs more updates per epoch" vs "big dataset,
5
- batch up for throughput", plus the practical VRAM ceilings of the LoRA path
6
- on Stable Audio Open 1.0. Returns the same shape the frontend `trainingConfig`
7
- uses, so Apply can spread the result into state directly.
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  """
9
 
10
  from __future__ import annotations
11
 
12
- import json
13
- import os
14
- import subprocess
15
  from pathlib import Path
16
- from typing import Any, Dict, List, Optional
 
 
 
17
 
18
- AUDIO_EXTS = {".wav", ".mp3", ".flac", ".m4a"}
19
 
20
- # Cache file for total-duration measurement. ffprobe across 500 files takes
21
- # 10-30s; we don't want to pay that on every button click. Cache key is the
22
- # (file_count, max_mtime_int) pair β€” invalidates automatically when files
23
- # are added/removed/touched.
24
- _DURATION_CACHE_NAME = ".duration_cache.json"
25
 
26
 
27
  def _list_audio_files(data_dir: Path) -> List[Path]:
 
28
  if not data_dir.exists():
29
  return []
30
  return [
31
  p for p in data_dir.iterdir()
32
- if p.is_file() and p.suffix.lower() in AUDIO_EXTS
33
  ]
34
 
35
 
36
- def _measure_total_duration(audio_files: List[Path], cache_path: Path) -> float:
37
- if not audio_files:
38
- return 0.0
39
-
40
- file_count = len(audio_files)
41
- max_mtime = int(max(p.stat().st_mtime for p in audio_files))
42
- cache_key = f"{file_count}:{max_mtime}"
43
-
44
- if cache_path.exists():
45
- try:
46
- cached = json.loads(cache_path.read_text())
47
- if cached.get("key") == cache_key:
48
- return float(cached["duration_sec"])
49
- except Exception:
50
- pass
51
-
52
- total = 0.0
53
  for f in audio_files:
54
- try:
55
- out = subprocess.check_output(
56
- ["ffprobe", "-v", "error", "-show_entries", "format=duration",
57
- "-of", "default=noprint_wrappers=1:nokey=1", str(f)],
58
- text=True, timeout=10,
59
- ).strip()
60
- total += float(out)
61
- except Exception:
62
- # Skip files ffprobe can't read; better to under-report than crash.
63
- continue
64
-
65
- try:
66
- cache_path.write_text(json.dumps({
67
- "key": cache_key,
68
- "duration_sec": total,
69
- }))
70
- except Exception:
71
- pass
72
-
73
- return total
74
 
75
 
76
  def _detect_vram_gb() -> Optional[float]:
@@ -83,6 +74,9 @@ def _detect_vram_gb() -> Optional[float]:
83
  return None
84
 
85
 
 
 
 
86
  def _bucket(file_count: int) -> str:
87
  if file_count < 20:
88
  return "tiny"
@@ -93,66 +87,136 @@ def _bucket(file_count: int) -> str:
93
  return "large"
94
 
95
 
96
- def _heuristic(file_count: int, vram_gb: Optional[float], mode: str) -> Dict[str, Any]:
97
- """The rules-of-thumb. Same shape regardless of mode; the frontend ignores
98
- LoRA-specific keys when mode='full'."""
99
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  bucket = _bucket(file_count)
101
- has_vram = vram_gb is not None
102
- constrained = (has_vram and vram_gb < 12)
103
-
104
- # Target total weight updates. Sublinear with dataset size so tiny sets
105
- # still get enough gradient steps, while large sets don't run forever.
106
- target_steps_by_bucket = {
107
- "tiny": 2500,
108
- "small": 2000,
109
- "medium": 1500,
110
- "large": 3000,
111
- }
112
- target_steps = target_steps_by_bucket[bucket]
113
-
114
- # Rank/LR/alpha scale with how much "capacity per data point" the run needs.
115
- # Small dataset trick: keep rank moderate (16) and conservative LR (1e-4 β€”
116
- # 2e-4 caused overshoot/flat loss in testing), but boost alpha so the
117
- # LoRA delta trains at higher effective voltage (scaling = alpha/rank).
118
- # This produces a stronger imprint without the parameter bloat of rank=32
119
- # or the instability of higher LR.
120
- if bucket in ("tiny", "small"):
121
- rank, alpha, lr = 16, 32, 1e-4
122
- else:
123
- rank, alpha, lr = 16, 16, 1e-4
124
-
125
- # Batch size: smaller on small datasets (more updates per epoch + better
126
- # gradient noise); larger on medium/large for throughput. VRAM caps the top.
127
- if bucket == "tiny":
128
- batch = 1 if constrained else 2
129
- elif bucket == "small":
130
- # Hold batch=2 even on roomy VRAM β€” the noise benefit on a small
131
- # dataset outweighs the throughput win, and it keeps the epoch
132
- # count to a reasonable display number.
133
- batch = 2
134
- elif bucket == "medium":
135
- batch = 2 if constrained else 4
136
- else:
137
- batch = 4 if constrained else 8
138
 
139
- steps_per_epoch = max(1, file_count // batch)
140
- epochs = max(20, round(target_steps / steps_per_epoch))
 
 
141
 
142
  return {
 
143
  "batchSize": batch,
144
- "learningRate": lr,
145
- "epochs": epochs,
146
- "loraRank": rank,
147
- "loraAlpha": alpha,
148
- "loraDropout": 0,
149
- "loraMultiplier": 1.0,
 
 
 
 
150
  "_meta": {
151
  "bucket": bucket,
152
- "target_steps": target_steps,
153
- "steps_per_epoch": steps_per_epoch,
154
- "total_steps": steps_per_epoch * epochs,
155
  "vram_constrained": constrained,
 
156
  },
157
  }
158
 
@@ -166,68 +230,162 @@ def _format_duration(seconds: float) -> str:
166
  return f"{m}m {s}s"
167
 
168
 
169
- def _compose_rationale(file_count: int, duration_sec: float, vram_gb: Optional[float],
170
- mode: str, meta: Dict[str, Any]) -> List[str]:
171
- """Human-readable explanation, returned as a list of bullet strings."""
172
- bullets = []
 
 
 
 
 
 
 
 
 
173
  bullets.append(
174
- f"Dataset: {file_count} audio file{'s' if file_count != 1 else ''}, "
175
- f"total {_format_duration(duration_sec)} β†’ "
176
- f"\"{meta['bucket']}\" bucket."
177
  )
 
 
 
 
 
 
 
 
 
178
  if vram_gb is not None:
179
- constraint = "VRAM-constrained" if meta["vram_constrained"] else "comfortable VRAM headroom"
180
- bullets.append(f"Detected GPU with {vram_gb:.1f} GB ({constraint}).")
 
 
 
181
  else:
182
- bullets.append("No GPU detected β€” assuming consumer-class constraints.")
183
- bullets.append(
184
- f"Targeting ~{meta['target_steps']} weight updates total; with batch_size "
185
- f"the dataset gives {meta['steps_per_epoch']} steps/epoch, so "
186
- f"{meta['total_steps']} steps over the recommended epoch count."
187
- )
188
- if meta["bucket"] in ("tiny", "small"):
189
  bullets.append(
190
- "Small dataset β†’ conservative 1e-4 LR + rank=16 for stability, "
191
- "but alpha=32 (alpha/rank = 2.0) so the LoRA delta trains at "
192
- "double voltage. Stronger imprint without overshoot risk."
193
  )
194
  else:
195
  bullets.append(
196
- "Larger dataset β†’ moderate batch + standard 1e-4 LR. Rank=16 has "
197
- "plenty of capacity for the prompt distribution this size implies."
 
198
  )
199
- return bullets
200
 
 
 
 
 
 
201
 
202
- def suggest(data_dir: Path, mode: str = "lora") -> Dict[str, Any]:
203
- """Public entry point. Returns the suggestion + a rationale + raw stats."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  audio_files = _list_audio_files(data_dir)
205
  file_count = len(audio_files)
206
  if file_count == 0:
207
  return {
208
  "ok": False,
209
- "error": f"No audio files found in {data_dir}",
 
 
 
210
  }
211
 
212
- cache_path = data_dir / _DURATION_CACHE_NAME
213
- duration_sec = _measure_total_duration(audio_files, cache_path)
214
  vram_gb = _detect_vram_gb()
215
 
216
- suggestion = _heuristic(file_count, vram_gb, mode)
217
  meta = suggestion.pop("_meta")
218
- rationale = _compose_rationale(file_count, duration_sec, vram_gb, mode, meta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  return {
221
  "ok": True,
222
  "stats": {
223
  "file_count": file_count,
224
- "duration_sec": duration_sec,
225
- "duration_human": _format_duration(duration_sec),
 
 
 
 
226
  "vram_gb": round(vram_gb, 2) if vram_gb is not None else None,
227
  "bucket": meta["bucket"],
228
- "steps_per_epoch": meta["steps_per_epoch"],
229
- "total_steps": meta["total_steps"],
230
  },
231
  "config": suggestion,
232
- "rationale": rationale,
 
233
  }
 
1
+ """SA3 LoRA hyperparameter suggester for the Training tab's "Suggest" button.
2
+
3
+ Reads a Dataset Workbench project directly β€” counts SA3-compatible audio
4
+ files, measures their durations via the same `soundfile.info()` header-only
5
+ probe used elsewhere in the app, factors in the user's picked base model
6
+ and detected GPU VRAM, and returns a config that:
7
+
8
+ * matches the upstream SA3 LoRA docs as the starting point
9
+ (see vendor/stable-audio-3/docs/workflows/lora.md)
10
+ * sets `--include transformer.layers` and `--exclude seconds_total
11
+ to_local_embed` by default (documented best practices, prevents the
12
+ "conditioner hijacking" failure mode on small datasets)
13
+ * picks a `-XS` adapter family when VRAM is tight for the chosen base
14
+ * proposes a `duration` derived from the actual clip lengths in the
15
+ project β€” not a hardcoded 30s
16
+ * warns when the dataset is below SA3's documented minimum (~20 clips)
17
+ or when clips are too short to learn from
18
+
19
+ Returns the same shape the frontend `trainingConfig` uses, so Apply can
20
+ spread the result into state directly.
21
  """
22
 
23
  from __future__ import annotations
24
 
25
+ import math
 
 
26
  from pathlib import Path
27
+ from typing import Any, Dict, List, Optional, Tuple
28
+
29
+ from app.backend.data.projects import _clip_duration_sec
30
+ from app.core.training.sa3_lora_runner import SA3_AUDIO_EXTENSIONS, SA3_BASE_MODELS
31
 
 
32
 
33
+ # --- Discovery -------------------------------------------------------------
 
 
 
 
34
 
35
 
36
  def _list_audio_files(data_dir: Path) -> List[Path]:
37
+ """Files SA3's loader would actually train on. Mirrors the loader's filter."""
38
  if not data_dir.exists():
39
  return []
40
  return [
41
  p for p in data_dir.iterdir()
42
+ if p.is_file() and p.suffix.lower() in SA3_AUDIO_EXTENSIONS
43
  ]
44
 
45
 
46
+ def _duration_stats(audio_files: List[Path]) -> Dict[str, Optional[float]]:
47
+ """Header-only duration probe + summary stats. None-safe for unreadable files."""
48
+ durations: List[float] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  for f in audio_files:
50
+ d = _clip_duration_sec(f)
51
+ if d is not None and d > 0:
52
+ durations.append(d)
53
+ if not durations:
54
+ return {"count": 0, "total": 0.0, "median": None, "p95": None, "max": None, "min": None}
55
+ durations.sort()
56
+ n = len(durations)
57
+ return {
58
+ "count": n,
59
+ "total": float(sum(durations)),
60
+ "median": float(durations[n // 2]),
61
+ "p95": float(durations[min(n - 1, int(math.ceil(0.95 * n)) - 1)]),
62
+ "max": float(durations[-1]),
63
+ "min": float(durations[0]),
64
+ }
 
 
 
 
 
65
 
66
 
67
  def _detect_vram_gb() -> Optional[float]:
 
74
  return None
75
 
76
 
77
+ # --- Bucketing & sizing ----------------------------------------------------
78
+
79
+
80
  def _bucket(file_count: int) -> str:
81
  if file_count < 20:
82
  return "tiny"
 
87
  return "large"
88
 
89
 
90
+ # SA3's documented quick-start: --steps 1000, with no dataset-size caveat.
91
+ # (vendor/stable-audio-3/docs/workflows/lora.md, "Standard (recommended starting point)".)
92
+ # SA3 trains by *windows seen*, not epochs, so a 5h dataset doesn't need more
93
+ # steps than a 30min one β€” it just produces more diverse sampling per step.
94
+ # We keep the SA3 default for tiny/small, and bump modestly only when a
95
+ # dataset is large enough that 1000 steps won't see all unique windows.
96
+ _STEPS_BY_BUCKET: Dict[str, int] = {
97
+ "tiny": 1000,
98
+ "small": 1000,
99
+ "medium": 2000,
100
+ "large": 4000,
101
+ }
102
+
103
+
104
+ # Per-base-model VRAM table from SA3 docs. (standard_gb, xs_bf16_gb)
105
+ # Source: docs/workflows/lora.md memory table.
106
+ _VRAM_REQ: Dict[str, Tuple[float, float]] = {
107
+ "sa3-small-music-base": (2.5, 2.0),
108
+ "sa3-small-sfx-base": (2.5, 2.0),
109
+ "sa3-medium-base": (6.5, 5.5),
110
+ }
111
+
112
+
113
+ def _pick_adapter(base_model: Optional[str], vram_gb: Optional[float]) -> Tuple[str, bool]:
114
+ """Choose adapter family. Returns (adapter_type, vram_constrained_flag).
115
+
116
+ SA3 docs recommend the `-xs` family + bf16 base precision for VRAM-limited
117
+ hosts. Headroom rule: standard_gb + 4 GB activations is the comfort target;
118
+ below that we pick the xs family.
119
+ """
120
+ default = "dora-rows"
121
+ if base_model is None or vram_gb is None:
122
+ return default, False
123
+ std_gb, _xs_gb = _VRAM_REQ.get(base_model, (2.5, 2.0))
124
+ comfort = std_gb + 4.0
125
+ constrained = vram_gb < comfort
126
+ return ("dora-rows-xs" if constrained else default), constrained
127
+
128
+
129
+ def _model_max_window_sec(base_model: Optional[str]) -> float:
130
+ """SA3's native training length for the base, from its model config
131
+ sample_size / sample_rate: medium-base β‰ˆ380s, small bases β‰ˆ120s. The
132
+ `seconds_total` conditioner caps at 384s, so 380 is the safe medium ceiling.
133
+ Longer windows aren't a model limit below these β€” they're VRAM/time bound.
134
+ """
135
+ if base_model and "medium" in base_model:
136
+ return 380.0
137
+ return 120.0
138
+
139
+
140
+ def _pick_duration(p95_clip_sec: Optional[float], base_model: Optional[str]) -> float:
141
+ """Set training window from the project's actual p95 clip length.
142
+
143
+ Floors at 5s; caps at β€” and defaults to β€” the model's native length
144
+ (β‰ˆ120s small / β‰ˆ380s medium) rather than an arbitrary 30s. SA3 random-crops
145
+ longer files, so the only real limits are the model's sequence length and
146
+ VRAM. Rounds up p95 with 2s headroom so the window isn't cropping the tails
147
+ of typical clips. With no duration data, defaults to the model max.
148
+ """
149
+ model_max = _model_max_window_sec(base_model)
150
+ if p95_clip_sec is None or p95_clip_sec <= 0:
151
+ return model_max
152
+ suggested = math.ceil(p95_clip_sec + 2.0)
153
+ return float(max(5, min(model_max, suggested)))
154
+
155
+
156
+ def _pick_batch_size(bucket: str, vram_gb: Optional[float]) -> int:
157
+ """SA3 examples all use batch 1. Only go higher on roomy hardware + big data.
158
+
159
+ 24 GB threshold for batch 2 leaves enough headroom for medium-base + bf16
160
+ activations across two samples. Going beyond batch 2 hits diminishing
161
+ returns and risks OOM mid-run.
162
+ """
163
+ if vram_gb is None or vram_gb < 24:
164
+ return 1
165
+ if bucket in ("medium", "large"):
166
+ return 2
167
+ return 1
168
+
169
+
170
+ # Filter pattern straight from SA3 docs:
171
+ # --include transformer.layers --exclude seconds_total to_local_embed
172
+ # "Everything except local embedding and seconds_total conditioner" β€” prevents
173
+ # the conditioner-hijacking failure mode that bites small datasets hardest.
174
+ _INCLUDE_DEFAULT: List[str] = ["transformer.layers"]
175
+ _EXCLUDE_DEFAULT: List[str] = ["seconds_total", "to_local_embed"]
176
+
177
+
178
+ # --- Suggestion + rationale ------------------------------------------------
179
+
180
+
181
+ def _heuristic(
182
+ file_count: int,
183
+ dur_stats: Dict[str, Optional[float]],
184
+ base_model: Optional[str],
185
+ vram_gb: Optional[float],
186
+ ) -> Dict[str, Any]:
187
  bucket = _bucket(file_count)
188
+ steps = _STEPS_BY_BUCKET[bucket]
189
+ adapter, constrained = _pick_adapter(base_model, vram_gb)
190
+ duration = _pick_duration(dur_stats.get("p95"), base_model)
191
+ batch = _pick_batch_size(bucket, vram_gb)
192
+
193
+ # Mild dropout for tiny datasets only β€” extra regularization where overfit
194
+ # is most likely. SA3 default is 0.0; we deviate intentionally.
195
+ dropout = 0.05 if bucket == "tiny" else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
+ # Checkpoint cadence: ~10 checkpoints per run, but keep within sane bounds
198
+ # so we don't write a checkpoint every 50 steps on tiny runs or sit on a
199
+ # 2K-step gap on long ones.
200
+ checkpoint_every = max(250, min(1000, steps // 10))
201
 
202
  return {
203
+ "steps": steps,
204
  "batchSize": batch,
205
+ "learningRate": 1e-4,
206
+ "loraRank": 16,
207
+ "loraAlpha": 16,
208
+ "loraDropout": dropout,
209
+ "adapterType": adapter,
210
+ "precision": "bf16",
211
+ "duration": duration,
212
+ "checkpointSteps": checkpoint_every,
213
+ "include": list(_INCLUDE_DEFAULT),
214
+ "exclude": list(_EXCLUDE_DEFAULT),
215
  "_meta": {
216
  "bucket": bucket,
217
+ "target_steps": steps,
 
 
218
  "vram_constrained": constrained,
219
+ "picked_adapter_for_vram": constrained,
220
  },
221
  }
222
 
 
230
  return f"{m}m {s}s"
231
 
232
 
233
+ def _compose_rationale(
234
+ file_count: int,
235
+ dur_stats: Dict[str, Optional[float]],
236
+ base_model: Optional[str],
237
+ vram_gb: Optional[float],
238
+ config: Dict[str, Any],
239
+ meta: Dict[str, Any],
240
+ ) -> Tuple[List[str], List[str]]:
241
+ """Return (bullets, warnings). Warnings are surfaced separately in the UI."""
242
+ bullets: List[str] = []
243
+ warnings: List[str] = []
244
+
245
+ total = dur_stats.get("total") or 0.0
246
  bullets.append(
247
+ f"Dataset: {file_count} clip{'s' if file_count != 1 else ''}, "
248
+ f"total {_format_duration(total)} β†’ \"{meta['bucket']}\" bucket."
 
249
  )
250
+
251
+ p95 = dur_stats.get("p95")
252
+ median = dur_stats.get("median")
253
+ if p95 is not None and median is not None:
254
+ bullets.append(
255
+ f"Clip durations: median {median:.1f}s, p95 {p95:.1f}s. "
256
+ f"Training window set to {config['duration']:.0f}s."
257
+ )
258
+
259
  if vram_gb is not None:
260
+ bullets.append(
261
+ f"Detected GPU: {vram_gb:.1f} GB"
262
+ + (" (tight for the chosen base β€” switched adapter to a -XS variant)."
263
+ if meta["vram_constrained"] else " (comfortable headroom).")
264
+ )
265
  else:
266
+ bullets.append("No CUDA GPU detected β€” adapter defaults to dora-rows; "
267
+ "training will run on CPU/MPS where supported.")
268
+
269
+ if meta["target_steps"] == 1000:
 
 
 
270
  bullets.append(
271
+ "Target 1 000 optimizer steps β€” SA3's documented quick-start. "
272
+ "LoRAs typically overfit well before this; watch the loss curve."
 
273
  )
274
  else:
275
  bullets.append(
276
+ f"Target {meta['target_steps']:,} optimizer steps β€” modest bump "
277
+ f"above SA3's 1 000-step default for larger datasets to see more "
278
+ "unique sampling windows."
279
  )
 
280
 
281
+ bullets.append(
282
+ f"Layer filter: include `{config['include'][0]}`, exclude "
283
+ f"`{' '.join(config['exclude'])}`. "
284
+ "Documented SA3 default β€” prevents conditioner-hijacking on small sets."
285
+ )
286
 
287
+ bullets.append(
288
+ f"Adapter `{config['adapterType']}` Β· rank 16 Β· Ξ± 16 Β· "
289
+ f"dropout {config['loraDropout']} Β· {config['precision']} base."
290
+ )
291
+
292
+ # --- Warnings (separate channel) ---------------------------------------
293
+
294
+ if file_count < 20:
295
+ warnings.append(
296
+ f"{file_count} clips is below SA3's documented minimum of ~20. "
297
+ "Expect heavy overfit and poor generalization β€” add more data if you can."
298
+ )
299
+ if median is not None and median < 2.0:
300
+ warnings.append(
301
+ f"Median clip is only {median:.1f}s β€” most of the training window "
302
+ f"({config['duration']:.0f}s) will be silence-padded. "
303
+ "Re-slice the source material to longer chunks for better signal."
304
+ )
305
+ if config["duration"] > 45:
306
+ warnings.append(
307
+ f"Training window is {config['duration']:.0f}s. Longer windows use "
308
+ "markedly more VRAM and step time (DiT attention scales with length). "
309
+ "If you hit OOM, lower the window or pre-encode the dataset first."
310
+ )
311
+
312
+ # VRAM Γ— base model crosscheck
313
+ if base_model in _VRAM_REQ:
314
+ std_gb, xs_gb = _VRAM_REQ[base_model]
315
+ if vram_gb is None:
316
+ if base_model == "sa3-medium-base":
317
+ warnings.append(
318
+ "No CUDA GPU detected, but you picked Medium-Base. "
319
+ "Medium-base needs CUDA + Flash-Attn 2 (Linux) and β‰₯5.5 GB VRAM. "
320
+ "Consider Small-Music-Base or Small-SFX-Base for CPU/MPS hosts."
321
+ )
322
+ elif vram_gb < xs_gb:
323
+ warnings.append(
324
+ f"GPU has {vram_gb:.1f} GB; even {base_model} with bf16+lora-xs needs "
325
+ f"~{xs_gb:.1f} GB. Training will likely OOM. Pick a smaller base."
326
+ )
327
+ elif vram_gb < std_gb:
328
+ warnings.append(
329
+ f"GPU has {vram_gb:.1f} GB; {base_model} standard config needs "
330
+ f"~{std_gb:.1f} GB. The -XS adapter (selected) brings it to ~{xs_gb:.1f} GB."
331
+ )
332
+
333
+ return bullets, warnings
334
+
335
+
336
+ def suggest(data_dir: Path, base_model: Optional[str] = None) -> Dict[str, Any]:
337
+ """Public entry point. SA3 is LoRA-only; no `mode` switch."""
338
  audio_files = _list_audio_files(data_dir)
339
  file_count = len(audio_files)
340
  if file_count == 0:
341
  return {
342
  "ok": False,
343
+ "error": (
344
+ f"No SA3-compatible audio in {data_dir}. SA3's loader accepts "
345
+ + ", ".join(SA3_AUDIO_EXTENSIONS) + "."
346
+ ),
347
  }
348
 
349
+ dur_stats = _duration_stats(audio_files)
 
350
  vram_gb = _detect_vram_gb()
351
 
352
+ suggestion = _heuristic(file_count, dur_stats, base_model, vram_gb)
353
  meta = suggestion.pop("_meta")
354
+ bullets, warnings = _compose_rationale(
355
+ file_count, dur_stats, base_model, vram_gb, suggestion, meta
356
+ )
357
+
358
+ # Caption coverage: SA3 trains on audio + matching .txt sidecars, and
359
+ # silently drops clips whose prompt is blank. Surface missing captions so
360
+ # the user isn't unknowingly training on a fraction of the dataset.
361
+ uncaptioned = sum(
362
+ 1 for p in audio_files
363
+ if not (p.with_suffix(".txt").exists()
364
+ and p.with_suffix(".txt").read_text(encoding="utf-8", errors="ignore").strip())
365
+ )
366
+ if uncaptioned:
367
+ warnings.insert(0, (
368
+ f"{uncaptioned} of {file_count} clip{'s' if file_count != 1 else ''} "
369
+ "have no annotation. SA3 silently skips un-captioned clips at train "
370
+ "time β€” annotate them first or they won't contribute to the LoRA."
371
+ ))
372
 
373
  return {
374
  "ok": True,
375
  "stats": {
376
  "file_count": file_count,
377
+ "duration_sec": dur_stats.get("total") or 0.0,
378
+ "duration_human": _format_duration(dur_stats.get("total") or 0.0),
379
+ "median_clip_sec": dur_stats.get("median"),
380
+ "p95_clip_sec": dur_stats.get("p95"),
381
+ "max_clip_sec": dur_stats.get("max"),
382
+ "min_clip_sec": dur_stats.get("min"),
383
  "vram_gb": round(vram_gb, 2) if vram_gb is not None else None,
384
  "bucket": meta["bucket"],
385
+ "total_steps": meta["target_steps"],
386
+ "base_model": base_model,
387
  },
388
  "config": suggestion,
389
+ "rationale": bullets,
390
+ "warnings": warnings,
391
  }
app/core/training/sa3_lora_runner.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers for the SA3 LoRA training pipeline.
2
+
3
+ Responsibilities:
4
+ * Pre-stage the base model in an app-folder HF cache so the training
5
+ subprocess finds it without falling back to ~/.cache/huggingface.
6
+ * Build the train_lora.py subprocess command + env.
7
+ * Convert PyTorch Lightning .ckpt LoRA outputs to SA3-native .safetensors
8
+ with the base_model and run name embedded in the metadata header.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import os
14
+ import sys
15
+ from pathlib import Path
16
+ from typing import Any, Dict, List, Optional, Tuple
17
+
18
+
19
+ # SA3 model_id β†’ (sa3_name passed to train_lora.py --model, HF repo id)
20
+ # Only `*-base` variants are valid LoRA targets β€” SA3 won't train against
21
+ # the post-trained / distilled checkpoints.
22
+ SA3_BASE_MODELS: Dict[str, Tuple[str, str]] = {
23
+ "sa3-small-music-base": ("small-music-base", "stabilityai/stable-audio-3-small-music-base"),
24
+ "sa3-small-sfx-base": ("small-sfx-base", "stabilityai/stable-audio-3-small-sfx-base"),
25
+ "sa3-medium-base": ("medium-base", "stabilityai/stable-audio-3-medium-base"),
26
+ }
27
+
28
+
29
+ # Each *-base config references its T5Gemma conditioner at a subfolder of the
30
+ # *post-trained sibling* repo (e.g., medium-base's t5gemma lives at
31
+ # stabilityai/stable-audio-3-medium / t5gemma-b-b-ul2/). Without that subtree
32
+ # in the cache, training crashes inside the conditioner constructor when SA3
33
+ # does `AutoTokenizer.from_pretrained(repo_id, subfolder=...)`.
34
+ # Keep in sync with model_config.json's `conditioning.configs[0].config.repo_id`.
35
+ SA3_T5GEMMA_SIBLINGS: Dict[str, Tuple[str, str]] = {
36
+ "sa3-small-music-base": ("stabilityai/stable-audio-3-small-music", "t5gemma-b-b-ul2"),
37
+ "sa3-small-sfx-base": ("stabilityai/stable-audio-3-small-sfx", "t5gemma-b-b-ul2"),
38
+ "sa3-medium-base": ("stabilityai/stable-audio-3-medium", "t5gemma-b-b-ul2"),
39
+ }
40
+
41
+
42
+ # Extensions SA3's training data loader actually accepts.
43
+ # Source: vendor/stable-audio-3/stable_audio_3/data/dataset.py:91.
44
+ # Single source of truth β€” both the health check and the hyperparam suggester
45
+ # use this so what we count matches what the loader will train on.
46
+ SA3_AUDIO_EXTENSIONS: Tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg", ".aif", ".opus")
47
+
48
+
49
+ # --- Base model pre-staging -------------------------------------------------
50
+
51
+ def prestage_base_model(
52
+ sa3_model_id: str,
53
+ hub_dir: Path,
54
+ token: Optional[str] = None,
55
+ progress_callback: Optional[Any] = None,
56
+ ) -> Path:
57
+ """Ensure the base model is in `hub_dir` (HF-cache layout, inside app folder).
58
+
59
+ train_lora.py calls `model_cfg.resolve()` which is hf_hub_download under
60
+ the hood β€” it reads from the HF cache root. We point it at hub_dir via
61
+ the HF_HUB_CACHE env var on the subprocess; for that to actually find
62
+ files we need to download into hub_dir using snapshot_download with
63
+ `cache_dir=hub_dir`.
64
+
65
+ Idempotent: if the model is already cached there, returns the cached
66
+ snapshot dir without re-downloading.
67
+ """
68
+ if sa3_model_id not in SA3_BASE_MODELS:
69
+ raise ValueError(
70
+ f"'{sa3_model_id}' is not a valid LoRA base. Pick one of "
71
+ f"{list(SA3_BASE_MODELS)} (only *-base variants are CFG-aware)."
72
+ )
73
+ sa3_name, repo_id = SA3_BASE_MODELS[sa3_model_id]
74
+ hub_dir.mkdir(parents=True, exist_ok=True)
75
+
76
+ from huggingface_hub import snapshot_download
77
+
78
+ allow_patterns = [
79
+ "*.safetensors", "*.json", "*.txt", "*.model",
80
+ "tokenizer*", "*.tiktoken",
81
+ ]
82
+
83
+ if progress_callback:
84
+ progress_callback(5, f"Staging {sa3_name} base model in {hub_dir.name}/...")
85
+
86
+ # Prefer cache. snapshot_download otherwise phones home on every run to
87
+ # check the model's revision β€” wasteful and noisy when the user just
88
+ # downloaded the weights through the Checkpoint Manager. If anything's
89
+ # missing, fall back to an online fetch.
90
+ try:
91
+ local_snap = snapshot_download(
92
+ repo_id=repo_id,
93
+ cache_dir=str(hub_dir),
94
+ token=token,
95
+ allow_patterns=allow_patterns,
96
+ local_files_only=True,
97
+ )
98
+ if progress_callback:
99
+ progress_callback(15, "Base model ready (from cache).")
100
+ except Exception:
101
+ if progress_callback:
102
+ progress_callback(8, "Cache miss β€” fetching from HuggingFace…")
103
+ local_snap = snapshot_download(
104
+ repo_id=repo_id,
105
+ cache_dir=str(hub_dir),
106
+ token=token,
107
+ allow_patterns=allow_patterns,
108
+ )
109
+ if progress_callback:
110
+ progress_callback(15, "Base model ready.")
111
+
112
+ # Pre-stage the T5Gemma conditioner from the post-trained sibling repo.
113
+ # SA3's *-base model_config.json points the prompt conditioner at
114
+ # e.g. stabilityai/stable-audio-3-medium / t5gemma-b-b-ul2/, NOT at the
115
+ # base repo. Without this subtree in the cache, the training subprocess
116
+ # (HF_HUB_OFFLINE=1) crashes when AutoTokenizer.from_pretrained tries
117
+ # to phone home.
118
+ sibling = SA3_T5GEMMA_SIBLINGS.get(sa3_model_id)
119
+ if sibling:
120
+ sib_repo, sib_subfolder = sibling
121
+ sib_patterns = [f"{sib_subfolder}/*"]
122
+ if progress_callback:
123
+ progress_callback(16, f"Staging T5Gemma conditioner from {sib_repo}…")
124
+ try:
125
+ snapshot_download(
126
+ repo_id=sib_repo,
127
+ cache_dir=str(hub_dir),
128
+ token=token,
129
+ allow_patterns=sib_patterns,
130
+ local_files_only=True,
131
+ )
132
+ if progress_callback:
133
+ progress_callback(18, "T5Gemma conditioner ready (from cache).")
134
+ except Exception:
135
+ if progress_callback:
136
+ progress_callback(17, f"T5Gemma cache miss β€” fetching from {sib_repo}…")
137
+ snapshot_download(
138
+ repo_id=sib_repo,
139
+ cache_dir=str(hub_dir),
140
+ token=token,
141
+ allow_patterns=sib_patterns,
142
+ )
143
+ if progress_callback:
144
+ progress_callback(18, "T5Gemma conditioner ready.")
145
+
146
+ return Path(local_snap)
147
+
148
+
149
+ # --- Subprocess command builder ---------------------------------------------
150
+
151
+ def build_train_command(
152
+ *,
153
+ venv_python: str,
154
+ sa3_vendor_dir: Path,
155
+ sa3_model_name: str,
156
+ data_dir: Path,
157
+ encoded_dir: Optional[Path] = None,
158
+ svd_bases_path: Optional[Path] = None,
159
+ save_dir: Path,
160
+ rank: int = 16,
161
+ lora_alpha: Optional[int] = None,
162
+ adapter_type: str = "dora-rows",
163
+ dropout: float = 0.0,
164
+ lr: float = 1e-4,
165
+ steps: int = 5000,
166
+ batch_size: int = 1,
167
+ duration: float = 30.0,
168
+ base_precision: str = "bf16",
169
+ include: Optional[List[str]] = None,
170
+ exclude: Optional[List[str]] = None,
171
+ seed: int = 42,
172
+ checkpoint_every: int = 500,
173
+ # `--log_every` controls how often DiffusionCondTrainingWrapper calls
174
+ # self.log(). 50 is SA3's example value and gives a much cleaner chart
175
+ # than per-step logging β€” diffusion loss is intrinsically noisy (each
176
+ # step samples a random timestep), so per-step values bounce wildly and
177
+ # the trend is hard to read. Sampling every 50 steps gives ~20 points
178
+ # for a 1000-step run, which the EMA smoother turns into a legible
179
+ # descent. First point arrives after step 49 (β‰ˆ15s on small, β‰ˆ50s on
180
+ # medium, dominated by first-step JIT warmup anyway).
181
+ log_every: int = 50,
182
+ num_workers: int = 2,
183
+ name: str = "fragmenta-lora",
184
+ ) -> List[str]:
185
+ """Construct the train_lora.py subprocess argv."""
186
+ cmd = [
187
+ venv_python,
188
+ str(sa3_vendor_dir / "scripts" / "train_lora.py"),
189
+ "--model", sa3_model_name,
190
+ "--data_dir", str(data_dir),
191
+ "--save_dir", str(save_dir),
192
+ "--rank", str(int(rank)),
193
+ "--adapter_type", adapter_type,
194
+ "--dropout", str(float(dropout)),
195
+ "--lr", str(float(lr)),
196
+ "--steps", str(int(steps)),
197
+ "--batch_size", str(int(batch_size)),
198
+ "--duration", str(float(duration)),
199
+ "--base_precision", base_precision,
200
+ "--seed", str(int(seed)),
201
+ "--checkpoint_every", str(int(checkpoint_every)),
202
+ "--log_every", str(int(log_every)),
203
+ "--num_workers", str(int(num_workers)),
204
+ "--name", name,
205
+ "--logger", "csv",
206
+ # demo_every set to a very large number β€” Fragmenta's training
207
+ # monitor doesn't surface demo audio, no need to spend cycles.
208
+ "--demo_every", "1000000",
209
+ ]
210
+ if encoded_dir is not None:
211
+ # Phase 6 β€” feed pre-encoded latents directory. SA3's train_lora.py
212
+ # then uses PreEncodedDataset instead of SampleDataset and skips
213
+ # the SAME autoencoder pass per step.
214
+ cmd += ["--encoded_dir", str(encoded_dir)]
215
+ if svd_bases_path is not None and adapter_type.endswith("-xs"):
216
+ # -XS adapters factor weights against precomputed SVD bases. SA3 only
217
+ # *loads* bases from this path (it doesn't write them), so we pass it
218
+ # only when a cached .pt already exists β€” otherwise SA3 recomputes the
219
+ # SVD per layer on device (slower, but correct). See SA3Trainer for the
220
+ # cache path convention.
221
+ cmd += ["--svd_bases_path", str(svd_bases_path)]
222
+ if lora_alpha is not None:
223
+ cmd += ["--lora_alpha", str(int(lora_alpha))]
224
+ if include:
225
+ cmd += ["--include", *include]
226
+ if exclude:
227
+ cmd += ["--exclude", *exclude]
228
+ return cmd
229
+
230
+
231
+ # --- Checkpoint conversion (.ckpt β†’ .safetensors with base_model metadata) ---
232
+
233
+ def convert_run_checkpoints_to_safetensors(
234
+ run_dir: Path,
235
+ base_model: str,
236
+ model_name: Optional[str] = None,
237
+ delete_originals: bool = True,
238
+ ) -> List[Path]:
239
+ """Convert PyTorch Lightning .ckpt files in a run's checkpoints/ directory
240
+ to SA3's native .safetensors LoRA format, with `base_model` injected into
241
+ the safetensors metadata header so /api/loras can filter by it.
242
+
243
+ Why: SA3's `train_lora.py` writes Lightning .ckpt files. The inference
244
+ LoRA picker (/api/loras) globs for *.safetensors only. Without this
245
+ conversion, every trained LoRA is functionally orphaned β€” saved
246
+ correctly to disk but invisible to the inference loader.
247
+
248
+ Idempotent: skips any .ckpt whose .safetensors sibling already exists
249
+ with a non-zero size.
250
+
251
+ Returns the list of paths to the produced .safetensors files (sorted).
252
+ """
253
+ ckpt_dir = run_dir / "checkpoints"
254
+ if not ckpt_dir.exists():
255
+ return []
256
+
257
+ # Imports deferred so this module can be imported without the SA3 vendor
258
+ # being on sys.path (e.g., during pure orchestrator construction).
259
+ from app.core.config import get_config
260
+ sa3_vendor = get_config().get_path("stable_audio_3")
261
+ pp = sys.path[:]
262
+ if str(sa3_vendor) not in pp:
263
+ sys.path.insert(0, str(sa3_vendor))
264
+ try:
265
+ from stable_audio_3.models.lora.utils import load_lora_checkpoint
266
+ from safetensors.torch import save_file as st_save_file
267
+ finally:
268
+ # Don't permanently mutate sys.path from a helper call.
269
+ if sys.path != pp:
270
+ sys.path[:] = pp
271
+
272
+ written: List[Path] = []
273
+ for ckpt_path in sorted(ckpt_dir.glob("*.ckpt")):
274
+ out_path = ckpt_path.with_suffix(".safetensors")
275
+ if out_path.exists() and out_path.stat().st_size > 0:
276
+ # Already converted (older artifact or a previous pass). Just
277
+ # bookkeep so the caller sees it in the return list.
278
+ written.append(out_path)
279
+ continue
280
+ try:
281
+ state_dict, lora_config = load_lora_checkpoint(ckpt_path)
282
+ except Exception:
283
+ # Corrupt or truncated ckpt β€” skip rather than crash the
284
+ # post-training pass.
285
+ continue
286
+
287
+ # Top-level metadata is what /api/loras' safetensors reader inspects
288
+ # directly. We also keep the canonical `lora_config` JSON blob so
289
+ # SA3's own load_lora_checkpoint() can parse the file as-is.
290
+ metadata = {
291
+ "lora_config": json.dumps(lora_config or {}),
292
+ "base_model": base_model,
293
+ }
294
+ if model_name:
295
+ metadata["model_name"] = model_name
296
+ # Cast fp16 to keep file sizes consistent with SA3's standard format.
297
+ fp16_dict = {k: (v.half() if v.is_floating_point() else v)
298
+ for k, v in state_dict.items()}
299
+ st_save_file(fp16_dict, str(out_path), metadata=metadata)
300
+ if delete_originals:
301
+ try:
302
+ ckpt_path.unlink()
303
+ except OSError:
304
+ pass
305
+ written.append(out_path)
306
+ return sorted(written)
307
+
308
+
309
+ def build_train_env(sa3_vendor_dir: Path, hub_dir: Path) -> Dict[str, str]:
310
+ """Subprocess env: redirect HF cache into the app folder + silence WANDB."""
311
+ env = os.environ.copy()
312
+ # Make `import stable_audio_3` work without pip-installing the package.
313
+ pp = env.get("PYTHONPATH", "")
314
+ env["PYTHONPATH"] = (
315
+ f"{sa3_vendor_dir}{os.pathsep}{pp}" if pp else str(sa3_vendor_dir)
316
+ )
317
+ # Pin the HF cache to our app-folder hub dir; otherwise train_lora.py's
318
+ # model_cfg.resolve() would write into ~/.cache/huggingface/hub. Cover
319
+ # the legacy + transformers env names too for defense-in-depth.
320
+ env["HF_HUB_CACHE"] = str(hub_dir)
321
+ env["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
322
+ env["TRANSFORMERS_CACHE"] = str(hub_dir)
323
+ env["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
324
+ env["WANDB_DISABLED"] = "1"
325
+ # Force the training subprocess into offline mode for HF β€” we already
326
+ # pre-staged the base model in prestage_base_model(), so any remaining
327
+ # network call from the SA3 internals would be a noisy revision check
328
+ # against a cache we know is current.
329
+ env["HF_HUB_OFFLINE"] = "1"
330
+ env["TRANSFORMERS_OFFLINE"] = "1"
331
+ return env
app/core/training/sa3_trainer.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SA3 LoRA training orchestrator β€” Phase 5.
2
+
3
+ Public surface (matches what app/backend/app.py imports):
4
+ start_training(config) -> dict
5
+ get_training_status() -> dict
6
+ stop_training() -> dict
7
+ preview_training_plan(config) -> dict
8
+ class SA3Trainer
9
+
10
+ Training is dispatched as a subprocess running
11
+ `vendor/stable-audio-3/scripts/train_lora.py`. Progress comes back through
12
+ two channels:
13
+ * stdout/stderr from the subprocess (parsed for tqdm "step X/Y" lines)
14
+ * metrics.csv that train_lora.py writes under --save_dir
15
+
16
+ Config shape (from the frontend training form):
17
+ {
18
+ "modelName": "my-lora", # used for run dir name
19
+ "baseModel": "sa3-medium-base", # must end in -base
20
+ "projectName": "my_first_track", # Dataset Workbench project name
21
+ "steps": 5000,
22
+ "checkpointSteps": 500, # checkpoint cadence
23
+ "batchSize": 1,
24
+ "learningRate": 1.0e-4,
25
+ "duration": 30.0, # max clip seconds per sample
26
+ "loraRank": 16,
27
+ "loraAlpha": 16, # null β†’ defaults to rank
28
+ "loraDropout": 0.0,
29
+ "adapterType": "dora-rows",
30
+ "precision": "bf16", # bf16|fp16
31
+ "seed": 42,
32
+ "include": null, # list[str] or null
33
+ "exclude": null
34
+ }
35
+ """
36
+ from __future__ import annotations
37
+
38
+ import csv
39
+ import json
40
+ import os
41
+ import re
42
+ import shlex
43
+ import signal
44
+ import subprocess
45
+ import sys
46
+ import threading
47
+ import time
48
+ from pathlib import Path
49
+ from typing import Any, Dict, List, Optional
50
+
51
+ from app.backend.data.projects import project_path
52
+ from app.core.config import get_config
53
+ from app.core.training.sa3_lora_runner import (
54
+ SA3_BASE_MODELS,
55
+ build_train_command,
56
+ build_train_env,
57
+ convert_run_checkpoints_to_safetensors,
58
+ prestage_base_model,
59
+ )
60
+ from utils.logger import get_logger
61
+
62
+ logger = get_logger("SA3Trainer")
63
+
64
+
65
+ # --- Defaults --------------------------------------------------------------
66
+
67
+ DEFAULT_STEPS = 5000
68
+ DEFAULT_CHECKPOINT_STEPS = 500
69
+ DEFAULT_BATCH_SIZE = 1
70
+ DEFAULT_LR = 1e-4
71
+ DEFAULT_DURATION = 30.0
72
+ DEFAULT_RANK = 16
73
+ DEFAULT_ADAPTER = "dora-rows"
74
+ DEFAULT_PRECISION = "bf16"
75
+
76
+
77
+ # --- SA3Trainer singleton --------------------------------------------------
78
+
79
+ class SA3Trainer:
80
+ def __init__(self, config: Dict[str, Any]) -> None:
81
+ self.config: Dict[str, Any] = config or {}
82
+ self.process: Optional[subprocess.Popen] = None
83
+ self.run_dir: Optional[Path] = None
84
+ self.metrics_csv: Optional[Path] = None
85
+ self._monitor_thread: Optional[threading.Thread] = None
86
+ self.status: Dict[str, Any] = {
87
+ "is_training": False,
88
+ "status": "idle",
89
+ "step": 0,
90
+ "total_steps": 0,
91
+ "loss": None,
92
+ "message": "",
93
+ "started_at": None,
94
+ "ended_at": None,
95
+ "log_tail": [], # last ~50 stdout lines
96
+ "checkpoints": [], # safetensors written so far
97
+ "error": None,
98
+ }
99
+
100
+ # --- Public API --------------------------------------------------------
101
+
102
+ def start(self) -> Dict[str, Any]:
103
+ # Fresh run on this trainer β€” clear any stop flag from a prior run.
104
+ self._stop_requested = False
105
+ # Mark training as in-flight BEFORE any blocking work. /api/start-training
106
+ # can block for tens of seconds (T5Gemma sibling fetch, base-model
107
+ # prestaging) β€” during that window the frontend polls
108
+ # /api/training-status and would otherwise see is_training=False from
109
+ # the __init__ default and interpret it as "training complete".
110
+ self.status.update({
111
+ "is_training": True,
112
+ "status": "staging",
113
+ "started_at": time.time(),
114
+ "ended_at": None,
115
+ "step": 0,
116
+ "total_steps": int(self.config.get("steps") or DEFAULT_STEPS),
117
+ "loss": None,
118
+ "error": None,
119
+ "checkpoints": [],
120
+ # Surface the concrete seed (the backend rolls a random one when the
121
+ # UI requests it) so the user can reproduce a run they liked.
122
+ "seed": (int(self.config["seed"]) if self.config.get("seed") is not None else None),
123
+ "message": "Preparing dataset and base model…",
124
+ })
125
+ try:
126
+ self._maybe_wipe_run_dir()
127
+ self._resolve_paths()
128
+ self._stage_dataset()
129
+ self._stage_base_model()
130
+ cmd, env = self._build_invocation()
131
+ self._spawn(cmd, env)
132
+ logger.info(
133
+ "Training started Β· project=%s Β· base=%s Β· adapter=%s Β· "
134
+ "rank=%s Β· steps=%s Β· batch=%s Β· lr=%s Β· duration=%ss",
135
+ self.config.get("projectName"),
136
+ self.config.get("baseModel"),
137
+ self.config.get("adapterType") or DEFAULT_ADAPTER,
138
+ self.config.get("loraRank") or DEFAULT_RANK,
139
+ self.config.get("steps") or DEFAULT_STEPS,
140
+ self.config.get("batchSize") or DEFAULT_BATCH_SIZE,
141
+ self.config.get("learningRate") or DEFAULT_LR,
142
+ self.config.get("duration") or DEFAULT_DURATION,
143
+ )
144
+ return {"success": True, "run_dir": str(self.run_dir)}
145
+ except Exception as e:
146
+ self.status["error"] = str(e)
147
+ self.status["status"] = "failed"
148
+ self.status["is_training"] = False
149
+ self.status["ended_at"] = time.time()
150
+ logger.error("Training failed to start: %s", e)
151
+ return {"error": str(e)}
152
+
153
+ def get_status(self) -> Dict[str, Any]:
154
+ # Snapshot + add a few derived fields the frontend already reads, so
155
+ # the polling loop in App.js doesn't have to know about both names.
156
+ # SA3 is step-based; we no longer expose `current_epoch`.
157
+ # If the on-disk checkpoint count looks stale (run finished, glob
158
+ # ran with the old filter, no live files surfaced), rescan once
159
+ # lazily so the UI catches up without needing a backend restart.
160
+ if not self.status.get("checkpoints") and self.run_dir is not None:
161
+ ckpt_dir = self.run_dir / "checkpoints"
162
+ if ckpt_dir.exists() and any(ckpt_dir.glob("*.ckpt")):
163
+ self._scan_checkpoints()
164
+ s = dict(self.status)
165
+ total = int(s.get("total_steps") or 0)
166
+ step = int(s.get("step") or 0)
167
+ s["current_step"] = step
168
+ s["progress"] = int(round(100 * step / total)) if total > 0 else 0
169
+ s["checkpoints_saved"] = len(s.get("checkpoints") or [])
170
+ return s
171
+
172
+ def stop(self) -> Dict[str, Any]:
173
+ if not self.process or self.process.poll() is not None:
174
+ return {"error": "Nothing to stop β€” no active training run."}
175
+ try:
176
+ # Flag the stop so the monitor thread labels the exit "stopped"
177
+ # rather than "failed" β€” SIGINT doesn't yield a stable rc==-2.
178
+ self._stop_requested = True
179
+ self.process.send_signal(signal.SIGINT)
180
+ try:
181
+ self.process.wait(timeout=10)
182
+ except subprocess.TimeoutExpired:
183
+ self.process.terminate()
184
+ try:
185
+ self.process.wait(timeout=5)
186
+ except subprocess.TimeoutExpired:
187
+ self.process.kill()
188
+ self.status["status"] = "stopped"
189
+ self.status["is_training"] = False
190
+ self.status["ended_at"] = time.time()
191
+ return {"success": True}
192
+ except Exception as e:
193
+ return {"error": str(e)}
194
+
195
+ def preview_plan(self) -> Dict[str, Any]:
196
+ try:
197
+ self._resolve_paths(create_dirs=False)
198
+ except FileNotFoundError as e:
199
+ return {"error": str(e)}
200
+ steps = int(self.config.get("steps") or DEFAULT_STEPS)
201
+ ckpt_every = int(self.config.get("checkpointSteps") or DEFAULT_CHECKPOINT_STEPS)
202
+ ckpts = max(1, steps // max(1, ckpt_every))
203
+ proj_name = self.config.get("projectName") or self.config.get("project_name")
204
+ data_dir = str(project_path(proj_name)) if proj_name else None
205
+ return {
206
+ "model_name": self.config.get("modelName", "fragmenta-lora"),
207
+ "base_model": self.config.get("baseModel"),
208
+ "project_name": proj_name,
209
+ "data_dir": data_dir,
210
+ "save_dir": str(self.run_dir / "checkpoints") if self.run_dir else None,
211
+ "steps": steps,
212
+ "checkpoint_every": ckpt_every,
213
+ "expected_checkpoints": ckpts,
214
+ "rank": int(self.config.get("loraRank") or DEFAULT_RANK),
215
+ "alpha": int(self.config.get("loraAlpha") or self.config.get("loraRank") or DEFAULT_RANK),
216
+ "adapter_type": self.config.get("adapterType") or DEFAULT_ADAPTER,
217
+ "batch_size": int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE),
218
+ "lr": float(self.config.get("learningRate") or DEFAULT_LR),
219
+ "duration": float(self.config.get("duration") or DEFAULT_DURATION),
220
+ "precision": self.config.get("precision") or DEFAULT_PRECISION,
221
+ }
222
+
223
+ # --- Internals ---------------------------------------------------------
224
+
225
+ def _resolve_paths(self, create_dirs: bool = True) -> None:
226
+ cfg = get_config()
227
+ run_name = self._safe_name(self.config.get("modelName") or "lora-run")
228
+ self.run_dir = cfg.get_path("models_fine_tuned") / run_name
229
+ # Lightning's CSVLogger writes metrics.csv under
230
+ # `<save_dir>/lightning_logs/version_X/metrics.csv`. We don't know X
231
+ # upfront, so leave this unset and let _scrape_loss_history /
232
+ # _scrape_csv_loss rglob for it the first time they're called.
233
+ self.metrics_csv = None
234
+ if create_dirs:
235
+ self.run_dir.mkdir(parents=True, exist_ok=True)
236
+ (self.run_dir / "checkpoints").mkdir(exist_ok=True)
237
+
238
+ @classmethod
239
+ def existing_run_info(cls, model_name: str) -> Optional[Dict[str, Any]]:
240
+ """Look up an existing run dir for a given LoRA name. Returns a dict
241
+ of countable artifacts if the dir exists with content, else None.
242
+
243
+ Used by /api/start-training to refuse a same-name run unless the
244
+ caller explicitly opts in to overwrite. Counts only *.ckpt and
245
+ *.safetensors so a half-set-up dir with only a metadata file
246
+ doesn't trip the prompt.
247
+ """
248
+ import shutil # noqa: F401 # ensures shutil resolves if user calls _maybe_wipe later
249
+ cfg = get_config()
250
+ run_name = cls._safe_name(model_name or "lora-run")
251
+ run_dir = cfg.get_path("models_fine_tuned") / run_name
252
+ if not run_dir.exists():
253
+ return None
254
+ ckpt_dir = run_dir / "checkpoints"
255
+ files = []
256
+ if ckpt_dir.exists():
257
+ for ext in ("*.safetensors", "*.ckpt"):
258
+ files.extend(ckpt_dir.glob(ext))
259
+ if not files and not (run_dir / "training.log").exists():
260
+ return None
261
+ return {
262
+ "run_dir": str(run_dir),
263
+ "run_name": run_name,
264
+ "checkpoint_count": len(files),
265
+ "has_log": (run_dir / "training.log").exists(),
266
+ }
267
+
268
+ def _maybe_wipe_run_dir(self) -> None:
269
+ """Honor the `overwrite` flag β€” wipe the run dir before staging."""
270
+ if not self.config.get("overwrite"):
271
+ return
272
+ cfg = get_config()
273
+ run_name = self._safe_name(self.config.get("modelName") or "lora-run")
274
+ run_dir = cfg.get_path("models_fine_tuned") / run_name
275
+ if run_dir.exists():
276
+ import shutil
277
+ shutil.rmtree(run_dir)
278
+ logger.info("Cleared existing run dir before restart: %s", run_dir)
279
+
280
+ def _stage_dataset(self) -> None:
281
+ """Resolve --data_dir from a Dataset Workbench project.
282
+
283
+ Training reads the committed `.txt` sidecars sitting next to each
284
+ audio file inside `<projects_dir>/<projectName>/`. The Workbench's
285
+ "Create Dataset" action materialised those sidecars; we don't
286
+ rewrite anything here.
287
+ """
288
+ project_name = self.config.get("projectName") or self.config.get("project_name")
289
+ if not project_name:
290
+ raise FileNotFoundError(
291
+ "projectName is required. Pick a project in the Training "
292
+ "tab's Dataset picker before starting a run."
293
+ )
294
+ proj_dir = project_path(project_name)
295
+ if not proj_dir.exists():
296
+ raise FileNotFoundError(f"project not found: {project_name}")
297
+
298
+ sidecars = list(proj_dir.glob("*.txt"))
299
+ if not sidecars:
300
+ raise RuntimeError(
301
+ f"project β€œ{project_name}” has no committed prompts yet β€” "
302
+ "annotate the clips and click Create Dataset, then retry."
303
+ )
304
+ # SA3's caption_metadata_fn rejects clips whose sidecar is empty,
305
+ # so they silently drop out of the training set. Count them upfront
306
+ # so the user knows what they're actually training on (and refuse
307
+ # to start if NONE have prompts β€” that would just waste GPU hours).
308
+ non_empty = [p for p in sidecars if p.read_text(encoding="utf-8").strip()]
309
+ if not non_empty:
310
+ raise RuntimeError(
311
+ f"project β€œ{project_name}” has {len(sidecars)} clip(s) but every "
312
+ "sidecar is empty β€” SA3 will reject all of them. Annotate at "
313
+ "least one clip and re-commit before training."
314
+ )
315
+ blank = len(sidecars) - len(non_empty)
316
+ if blank > 0:
317
+ logger.warning(
318
+ "%d of %d clip(s) in project '%s' have empty prompts β€” "
319
+ "SA3 will silently drop them. Training on %d clip(s).",
320
+ blank, len(sidecars), project_name, len(non_empty),
321
+ )
322
+ self.status["log_tail"].append(
323
+ f"Warning: {blank}/{len(sidecars)} clips have empty prompts and "
324
+ "will be dropped by SA3's data loader."
325
+ )
326
+ self.status["log_tail"].append(
327
+ f"Dataset: project '{project_name}' Β· {len(non_empty)} usable clip(s) Β· {proj_dir}"
328
+ )
329
+ self._data_dir = proj_dir
330
+
331
+ # Phase 6 β€” opt into pre-encoded latents if a compatible .latents/
332
+ # cache exists. SA3's `train_lora.py --encoded_dir` then skips the
333
+ # autoencoder pass per step. The cache is AE-bound (same-s vs
334
+ # same-l) so we verify the manifest matches the picked base before
335
+ # using it β€” otherwise we'd feed the DiT mis-shaped latents.
336
+ self._encoded_dir: Optional[Path] = None
337
+ try:
338
+ from app.backend.data.pre_encoder import (
339
+ latents_dir, latents_count, latents_match_base,
340
+ )
341
+ ldir = latents_dir(project_name)
342
+ base_model = self.config.get("baseModel")
343
+ if ldir.exists() and latents_count(project_name) > 0:
344
+ if latents_match_base(project_name, base_model):
345
+ self._encoded_dir = ldir
346
+ self.status["log_tail"].append(
347
+ f"Using pre-encoded latents: {latents_count(project_name)} "
348
+ f"file(s) Β· {ldir}"
349
+ )
350
+ logger.info(
351
+ "Pre-encoded latents detected for project '%s' (%d files) β€” "
352
+ "skipping SAME autoencoder per step.",
353
+ project_name, latents_count(project_name),
354
+ )
355
+ else:
356
+ logger.warning(
357
+ "Pre-encoded latents exist for project '%s' but were "
358
+ "produced by a different autoencoder than the chosen "
359
+ "base (%s) β€” falling back to live encoding.",
360
+ project_name, base_model,
361
+ )
362
+ self.status["log_tail"].append(
363
+ f"Note: project has cached latents but they're for a "
364
+ f"different autoencoder than {base_model}. Training "
365
+ "will re-encode audio per step."
366
+ )
367
+ except Exception as exc:
368
+ logger.warning("Pre-encoded latents probe failed: %s", exc)
369
+
370
+ def _stage_base_model(self) -> None:
371
+ cfg = get_config()
372
+ base_model = self.config.get("baseModel")
373
+ if base_model not in SA3_BASE_MODELS:
374
+ raise ValueError(
375
+ f"baseModel must be one of {list(SA3_BASE_MODELS)}. "
376
+ "Post-trained checkpoints (no -base suffix) can't be used "
377
+ "as a LoRA training base β€” CFG distillation has collapsed "
378
+ "the gradient signal LoRAs target."
379
+ )
380
+ hub_dir = cfg.get_path("models_pretrained") / "sa3" / "hub"
381
+ try:
382
+ from huggingface_hub import get_token
383
+ token = get_token()
384
+ except Exception:
385
+ token = None
386
+
387
+ def _cb(pct: int, msg: str) -> None:
388
+ self.status["message"] = msg
389
+ self.status["log_tail"].append(f"[stage] {msg}")
390
+ # Mirror to the project logger so the terminal shows what's
391
+ # happening during long blocking operations (e.g. first-time
392
+ # T5Gemma sibling fetch can take ~30s on medium-base).
393
+ logger.info("[stage] %s", msg)
394
+
395
+ prestage_base_model(base_model, hub_dir, token=token, progress_callback=_cb)
396
+ self._hub_dir = hub_dir
397
+
398
+ def _build_invocation(self):
399
+ cfg = get_config()
400
+ sa3_vendor = cfg.get_path("stable_audio_3")
401
+ sa3_name, _repo = SA3_BASE_MODELS[self.config["baseModel"]]
402
+
403
+ # Use the Fragmenta venv's python so we share installed packages.
404
+ venv_python = sys.executable
405
+
406
+ precision_raw = (self.config.get("precision") or DEFAULT_PRECISION).lower()
407
+ precision = "bf16" if precision_raw in ("bf16", "bfloat16", "auto", "") else "fp16"
408
+
409
+ include = self.config.get("include")
410
+ if include and isinstance(include, str):
411
+ include = shlex.split(include)
412
+ exclude = self.config.get("exclude")
413
+ if exclude and isinstance(exclude, str):
414
+ exclude = shlex.split(exclude)
415
+
416
+ adapter_type = self.config.get("adapterType") or DEFAULT_ADAPTER
417
+
418
+ # -XS adapters can reuse a precomputed SVD-bases cache keyed by base
419
+ # model, skipping the per-layer SVD at startup. SA3 only loads (never
420
+ # writes) this file, so we pass it only when present; population is a
421
+ # manual/precompute step. Ensure the dir exists so it's discoverable.
422
+ svd_bases_path = None
423
+ if adapter_type.endswith("-xs"):
424
+ svd_cache_dir = get_config().get_path("models_fine_tuned") / ".svd_cache"
425
+ svd_cache_dir.mkdir(parents=True, exist_ok=True)
426
+ candidate = svd_cache_dir / f"{self.config['baseModel']}.pt"
427
+ if candidate.exists():
428
+ svd_bases_path = candidate
429
+
430
+ cmd = build_train_command(
431
+ venv_python=venv_python,
432
+ sa3_vendor_dir=sa3_vendor,
433
+ sa3_model_name=sa3_name,
434
+ data_dir=self._data_dir,
435
+ encoded_dir=getattr(self, "_encoded_dir", None),
436
+ svd_bases_path=svd_bases_path,
437
+ save_dir=self.run_dir / "checkpoints",
438
+ rank=int(self.config.get("loraRank") or DEFAULT_RANK),
439
+ lora_alpha=self.config.get("loraAlpha"),
440
+ adapter_type=adapter_type,
441
+ dropout=float(self.config.get("loraDropout") or 0.0),
442
+ lr=float(self.config.get("learningRate") or DEFAULT_LR),
443
+ steps=int(self.config.get("steps") or DEFAULT_STEPS),
444
+ batch_size=int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE),
445
+ # Default to AND clamp at the base model's native training length
446
+ # (medium β‰ˆ380s, small β‰ˆ120s) β€” SA3's DiT tops out at 4096 latent
447
+ # tokens, so a longer window would exceed the model, not just cost
448
+ # VRAM. A missing duration defaults to the model max.
449
+ duration=min(
450
+ float(self.config.get("duration") or (380.0 if "medium" in sa3_name else 120.0)),
451
+ 380.0 if "medium" in sa3_name else 120.0,
452
+ ),
453
+ base_precision=precision,
454
+ include=include,
455
+ exclude=exclude,
456
+ seed=(int(self.config["seed"]) if self.config.get("seed") is not None else 42),
457
+ checkpoint_every=int(self.config.get("checkpointSteps") or DEFAULT_CHECKPOINT_STEPS),
458
+ name=self.config.get("modelName") or "fragmenta-lora",
459
+ )
460
+ env = build_train_env(sa3_vendor, self._hub_dir)
461
+ return cmd, env
462
+
463
+ def _spawn(self, cmd: List[str], env: Dict[str, str]) -> None:
464
+ log_path = self.run_dir / "training.log"
465
+ rank = int(self.config.get("loraRank") or DEFAULT_RANK)
466
+ alpha_cfg = self.config.get("loraAlpha")
467
+ alpha = int(alpha_cfg) if alpha_cfg not in (None, "") else rank
468
+ # Stamp training_metadata.json so /api/loras can find the base_model
469
+ # if the embedded safetensors metadata is missing it (legacy paths).
470
+ (self.run_dir / "training_metadata.json").write_text(json.dumps({
471
+ "mode": "lora",
472
+ "engine": "sa3",
473
+ "base_model": self.config.get("baseModel"),
474
+ "model_name": self.config.get("modelName"),
475
+ "started_at": time.time(),
476
+ "lora_config": {
477
+ "rank": rank,
478
+ "alpha": alpha,
479
+ "adapter_type": self.config.get("adapterType") or DEFAULT_ADAPTER,
480
+ "dropout": float(self.config.get("loraDropout") or 0.0),
481
+ },
482
+ "steps": int(self.config.get("steps") or DEFAULT_STEPS),
483
+ "lr": float(self.config.get("learningRate") or DEFAULT_LR),
484
+ "batch_size": int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE),
485
+ }, indent=2))
486
+
487
+ self.status.update({
488
+ "is_training": True,
489
+ "status": "running",
490
+ "step": 0,
491
+ "total_steps": int(self.config.get("steps") or DEFAULT_STEPS),
492
+ "loss": None,
493
+ "error": None,
494
+ "started_at": time.time(),
495
+ "ended_at": None,
496
+ "checkpoints": [],
497
+ "message": "Starting training subprocess...",
498
+ })
499
+
500
+ self.process = subprocess.Popen(
501
+ cmd,
502
+ cwd=str(get_config().project_root),
503
+ env=env,
504
+ stdout=subprocess.PIPE,
505
+ stderr=subprocess.STDOUT,
506
+ text=True,
507
+ bufsize=1,
508
+ )
509
+ self._monitor_thread = threading.Thread(
510
+ target=self._monitor,
511
+ args=(log_path,),
512
+ daemon=True,
513
+ name=f"sa3-train-monitor:{self.run_dir.name}",
514
+ )
515
+ self._monitor_thread.start()
516
+
517
+ def _monitor(self, log_path: Path) -> None:
518
+ """Pull stdout, parse PyTorch Lightning progress, scrape loss, watch checkpoints.
519
+
520
+ SA3 trains via PL whose default progress bar emits *per-epoch* step
521
+ counts ("Epoch 6: 50%|...| 25/50 [00:07<00:07, 3.36it/s, train/loss=0.559]").
522
+ We derive the global step as `epoch * batches_per_epoch + step_in_epoch`,
523
+ capture `batches_per_epoch` from the first such line (it's stable across
524
+ epochs since SampleDataset returns a fixed length), and clamp the
525
+ result to the configured max_steps so the percentage doesn't go past
526
+ 100 if the final epoch overruns.
527
+ """
528
+ epoch_pat = re.compile(r"Epoch\s+(\d+):")
529
+ in_epoch_pat = re.compile(r"\|\s*(\d+)/(\d+)\b") # tqdm's "current/total"
530
+ loss_pat = re.compile(r"train/loss=([\d.eE+\-]+)")
531
+ speed_pat = re.compile(r"([\d.]+)it/s")
532
+ last_log_flush = time.time()
533
+ last_ckpt_scan = 0.0
534
+ last_terminal_log = 0.0
535
+ last_logged_step = -1
536
+ prev_ckpt_count = 0
537
+ current_epoch = 0
538
+ batches_per_epoch = 0
539
+ try:
540
+ with open(log_path, "w") as logf:
541
+ if self.process and self.process.stdout:
542
+ for line in self.process.stdout:
543
+ line = line.rstrip()
544
+ logf.write(line + "\n")
545
+ if time.time() - last_log_flush > 1:
546
+ logf.flush()
547
+ last_log_flush = time.time()
548
+ self.status["log_tail"].append(line)
549
+ if len(self.status["log_tail"]) > 80:
550
+ self.status["log_tail"] = self.status["log_tail"][-50:]
551
+
552
+ # Only parse the step counter on lines that ARE the
553
+ # training progress bar (prefixed with "Epoch N:"),
554
+ # so unrelated tqdm bars during startup (e.g.
555
+ # "Loading checkpoint shards: 9/9") don't pollute
556
+ # batches_per_epoch.
557
+ m_epoch = epoch_pat.search(line)
558
+ if m_epoch:
559
+ current_epoch = int(m_epoch.group(1))
560
+ m_step = in_epoch_pat.search(line)
561
+ if m_step:
562
+ cur_in_epoch = int(m_step.group(1))
563
+ per_epoch = int(m_step.group(2))
564
+ if per_epoch > 0 and batches_per_epoch == 0:
565
+ batches_per_epoch = per_epoch
566
+ if batches_per_epoch > 0:
567
+ global_step = current_epoch * batches_per_epoch + cur_in_epoch
568
+ max_steps = self.status.get("total_steps") or 0
569
+ if max_steps > 0:
570
+ global_step = min(global_step, max_steps)
571
+ if global_step > self.status.get("step", 0):
572
+ self.status["step"] = global_step
573
+
574
+ m_loss = loss_pat.search(line)
575
+ if m_loss:
576
+ try:
577
+ self.status["loss"] = float(m_loss.group(1))
578
+ except ValueError:
579
+ pass
580
+
581
+ # Live checkpoint enumeration + loss history scrape.
582
+ # Lightning writes *.ckpt every N steps; we want the
583
+ # count to climb in the UI as files appear, not only
584
+ # at end-of-run. Bucketed to ~2s so we don't pound
585
+ # the FS. The loss history scrape backfills step
586
+ # 0..49 from metrics.csv since PL's stdout postfix
587
+ # doesn't show train/loss until end-of-epoch-0.
588
+ now = time.time()
589
+ if now - last_ckpt_scan > 2.0:
590
+ last_ckpt_scan = now
591
+ self._scan_checkpoints()
592
+ self._scrape_loss_history()
593
+ cur_ckpt_count = len(self.status.get("checkpoints") or [])
594
+ if cur_ckpt_count > prev_ckpt_count:
595
+ logger.info(
596
+ "Checkpoint saved Β· %d total Β· run=%s",
597
+ cur_ckpt_count, self.run_dir.name,
598
+ )
599
+ prev_ckpt_count = cur_ckpt_count
600
+
601
+ # Throttled progress to the backend terminal log.
602
+ # Lightning emits step lines ~3Γ— per second; we
603
+ # condense to one tidy summary every 5s. Omit the
604
+ # loss segment when we don't have a value yet (the
605
+ # CSV scrape runs every 2s but PL may not have
606
+ # logged anything during the very first second).
607
+ cur_step = self.status.get("step") or 0
608
+ if (cur_step > last_logged_step
609
+ and now - last_terminal_log >= 5.0):
610
+ total = self.status.get("total_steps") or 0
611
+ loss = self.status.get("loss")
612
+ pct = round(100 * cur_step / total) if total > 0 else 0
613
+ speed_m = speed_pat.search(line)
614
+ parts = [f"step {cur_step}/{total} ({pct}%)"]
615
+ if isinstance(loss, (int, float)):
616
+ parts.append(f"loss {loss:.4f}")
617
+ if speed_m:
618
+ parts.append(f"{speed_m.group(1)} it/s")
619
+ logger.info(" Β· ".join(parts))
620
+ last_terminal_log = now
621
+ last_logged_step = cur_step
622
+ rc = self.process.wait() if self.process else 1
623
+ except Exception as e:
624
+ self.status["error"] = str(e)
625
+ rc = -1
626
+
627
+ self.status["ended_at"] = time.time()
628
+ self.status["is_training"] = False
629
+ # A user-requested stop wins regardless of the exit code (SIGINT can
630
+ # surface as various negative/non-zero codes across platforms).
631
+ if getattr(self, "_stop_requested", False):
632
+ self.status["status"] = "stopped"
633
+ else:
634
+ self.status["status"] = "complete" if rc == 0 else "failed"
635
+ if self.status["status"] == "failed" and not self.status.get("error"):
636
+ self.status["error"] = f"train_lora.py exited with code {rc}"
637
+
638
+ # Convert PyTorch Lightning .ckpt files to SA3's native .safetensors
639
+ # LoRA format β€” the inference loader (/api/loras) only sees
640
+ # .safetensors, so unconverted .ckpt files would be functionally
641
+ # orphaned. We also inject `base_model` into the safetensors header
642
+ # so /api/loras' metadata filter passes without a JSON fallback.
643
+ # Best-effort: failure here doesn't fail the run.
644
+ if self.status["status"] in ("complete", "stopped") and self.run_dir:
645
+ try:
646
+ produced = convert_run_checkpoints_to_safetensors(
647
+ self.run_dir,
648
+ base_model=self.config.get("baseModel"),
649
+ model_name=self.config.get("modelName"),
650
+ )
651
+ if produced:
652
+ logger.info(
653
+ "Converted %d checkpoint(s) to .safetensors Β· run=%s",
654
+ len(produced), self.run_dir.name,
655
+ )
656
+ except Exception as exc:
657
+ logger.warning("Checkpoint conversion failed: %s", exc)
658
+
659
+ # Final pass: enumerate written checkpoints + full loss history +
660
+ # latest single-value loss.
661
+ self._scan_checkpoints()
662
+ self._scrape_loss_history()
663
+ self._scrape_csv_loss()
664
+
665
+ final_step = self.status.get("step") or 0
666
+ final_total = self.status.get("total_steps") or 0
667
+ final_loss = self.status.get("loss")
668
+ final_ckpts = len(self.status.get("checkpoints") or [])
669
+ loss_str = f"{final_loss:.4f}" if isinstance(final_loss, (int, float)) else "β€”"
670
+ if self.status["status"] == "complete":
671
+ logger.info(
672
+ "Training complete Β· %d/%d steps Β· final loss %s Β· %d checkpoint(s) Β· run=%s",
673
+ final_step, final_total, loss_str, final_ckpts, self.run_dir.name,
674
+ )
675
+ elif self.status["status"] == "stopped":
676
+ logger.info(
677
+ "Training stopped at step %d/%d Β· %d checkpoint(s) Β· run=%s",
678
+ final_step, final_total, final_ckpts, self.run_dir.name,
679
+ )
680
+ else:
681
+ logger.error(
682
+ "Training failed (exit %s) Β· %d/%d steps Β· error: %s Β· run=%s",
683
+ rc, final_step, final_total, self.status.get("error"), self.run_dir.name,
684
+ )
685
+
686
+ def _scrape_loss_history(self) -> None:
687
+ """Refresh self.status['loss_history'] from Lightning's metrics.csv.
688
+
689
+ PL's tqdm postfix only surfaces `train/loss=` *after* the first
690
+ metrics flush (typically end-of-epoch-0), so step 0..49 of a fresh
691
+ run never appear in stdout. metrics.csv, on the other hand, has
692
+ per-step rows from step 0 β€” we just need to read it.
693
+
694
+ Cheap: even at 10K steps a CSV scan is sub-10ms. Skipped silently
695
+ if the file hasn't been created yet (early in the run, before PL's
696
+ CSVLogger flushes anything).
697
+ """
698
+ if not self.metrics_csv or not self.metrics_csv.exists():
699
+ # CSVLogger writes under <save_dir>/lightning_logs/version_*/
700
+ if self.run_dir:
701
+ for p in (self.run_dir / "checkpoints").rglob("metrics.csv"):
702
+ self.metrics_csv = p
703
+ break
704
+ if not self.metrics_csv or not self.metrics_csv.exists():
705
+ return
706
+ try:
707
+ with open(self.metrics_csv) as f:
708
+ rows = list(csv.DictReader(f))
709
+ except Exception:
710
+ return
711
+ points: List[Dict[str, Any]] = []
712
+ loss_keys = ("train/loss", "loss", "train_loss")
713
+ for row in rows:
714
+ step_raw = row.get("step")
715
+ if step_raw in (None, ""):
716
+ continue
717
+ try:
718
+ step = int(step_raw)
719
+ except ValueError:
720
+ continue
721
+ for k in loss_keys:
722
+ v = row.get(k)
723
+ if v not in (None, ""):
724
+ try:
725
+ points.append({"step": step, "loss": float(v)})
726
+ except ValueError:
727
+ pass
728
+ break
729
+ # Dedupe: csv can have multiple rows per step (different metric flush
730
+ # boundaries) β€” keep the last loss seen for each step.
731
+ by_step: Dict[int, float] = {}
732
+ for p in points:
733
+ by_step[p["step"]] = p["loss"]
734
+ ordered = sorted(by_step.items())
735
+ self.status["loss_history"] = [{"step": s, "loss": l} for s, l in ordered]
736
+ # Also surface the most recent loss as the scalar so the terminal
737
+ # log and "Current Loss" field don't show "β€”" until end-of-epoch-0.
738
+ # PL's tqdm postfix is async; the CSV row lands a beat ahead.
739
+ if ordered:
740
+ self.status["loss"] = ordered[-1][1]
741
+
742
+ def _scan_checkpoints(self) -> None:
743
+ """Update self.status['checkpoints'] from on-disk artifacts.
744
+
745
+ SA3's train_lora.py uses PyTorch Lightning's ModelCheckpoint, which
746
+ writes `.ckpt` files (Lightning pickle format). The diffusion wrapper's
747
+ `on_save_checkpoint` hook strips the state_dict to LoRA-only weights
748
+ plus the embedded `lora_config`, so each .ckpt IS a LoRA checkpoint.
749
+ We also accept .safetensors for forward-compat with a future export
750
+ path or manual conversion.
751
+ """
752
+ if not self.run_dir:
753
+ return
754
+ ckpt_dir = self.run_dir / "checkpoints"
755
+ if not ckpt_dir.exists():
756
+ return
757
+ found = []
758
+ for ext in ("*.safetensors", "*.ckpt"):
759
+ found.extend(ckpt_dir.glob(ext))
760
+ # Lightning writes nested lightning_logs/version_X/* β€” those aren't
761
+ # the user-facing artifacts; skip recursion.
762
+ project_root = get_config().project_root
763
+ self.status["checkpoints"] = sorted(
764
+ str(p.relative_to(project_root)) for p in found
765
+ )
766
+
767
+ def _scrape_csv_loss(self) -> None:
768
+ if not self.metrics_csv or not self.metrics_csv.exists():
769
+ # train_lora.py writes its CSV under the lightning logger dir,
770
+ # which is `<save_dir>/<name>/version_*/metrics.csv`. Walk to
771
+ # find it.
772
+ ckpt_dir = self.run_dir / "checkpoints"
773
+ for p in ckpt_dir.rglob("metrics.csv"):
774
+ self.metrics_csv = p
775
+ break
776
+ if not self.metrics_csv or not self.metrics_csv.exists():
777
+ return
778
+ try:
779
+ with open(self.metrics_csv) as f:
780
+ rows = list(csv.DictReader(f))
781
+ for row in reversed(rows):
782
+ for k in ("train/loss", "loss", "train_loss"):
783
+ v = row.get(k)
784
+ if v not in (None, ""):
785
+ try:
786
+ self.status["loss"] = float(v)
787
+ return
788
+ except ValueError:
789
+ pass
790
+ except Exception:
791
+ pass
792
+
793
+ @staticmethod
794
+ def _safe_name(s: str) -> str:
795
+ return re.sub(r"[^a-zA-Z0-9_-]+", "_", s).strip("_") or "lora-run"
796
+
797
+
798
+ # --- Module-level singleton + public functions -----------------------------
799
+
800
+ _active: Optional[SA3Trainer] = None
801
+ _lock = threading.Lock()
802
+
803
+
804
+ def get_trainer() -> Optional[SA3Trainer]:
805
+ return _active
806
+
807
+
808
+ def start_training(config: Dict[str, Any]) -> Dict[str, Any]:
809
+ global _active
810
+ with _lock:
811
+ if _active and _active.status.get("is_training"):
812
+ return {"error": "A training run is already in progress."}
813
+ _active = SA3Trainer(config)
814
+ return _active.start()
815
+
816
+
817
+ def get_training_status() -> Dict[str, Any]:
818
+ if _active is None:
819
+ return {
820
+ "is_training": False,
821
+ "status": "idle",
822
+ "message": "No training run has been started yet.",
823
+ "progress": 0,
824
+ "current_step": 0,
825
+ "total_steps": 0,
826
+ "checkpoints_saved": 0,
827
+ "loss": None,
828
+ }
829
+ return _active.get_status()
830
+
831
+
832
+ def stop_training() -> Dict[str, Any]:
833
+ if _active is None:
834
+ return {"error": "No training run to stop."}
835
+ return _active.stop()
836
+
837
+
838
+ def preview_training_plan(config: Dict[str, Any]) -> Dict[str, Any]:
839
+ return SA3Trainer(config).preview_plan()
app/frontend/index.html CHANGED
@@ -7,15 +7,38 @@
7
  <meta name="theme-color" content="#000000" />
8
  <meta
9
  name="description"
10
- content="Fragmenta Desktop - Stable Audio Fine-Tuning Application"
11
  />
12
  <link rel="manifest" href="/manifest.json" />
13
 
14
- <link rel="preconnect" href="https://fonts.googleapis.com">
15
- <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
16
- <link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500;600&family=Space+Mono:wght@400;700&family=IBM+Plex+Mono:wght@300;400;500;600&display=swap" rel="stylesheet">
17
-
18
  <style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  @font-face {
20
  font-family: 'Bitcount Single';
21
  src: url('/BitcountSingle-VariableFont_CRSV,ELSH,ELXP,slnt,wght.ttf') format('truetype');
@@ -25,7 +48,7 @@
25
  }
26
  </style>
27
 
28
- <title>Fragmenta Desktop</title>
29
  </head>
30
  <body>
31
  <noscript>You need to enable JavaScript to run this app.</noscript>
 
7
  <meta name="theme-color" content="#000000" />
8
  <meta
9
  name="description"
10
+ content="Fragmenta β€” Stable Audio Fine-Tuning Application"
11
  />
12
  <link rel="manifest" href="/manifest.json" />
13
 
 
 
 
 
14
  <style>
15
+ /* Layout floor β€” channel grid + master strip need 1300px
16
+ horizontally to sit side-by-side, and the vertical layout (top
17
+ bar + channels + bottom bar) gets cramped below 830px. Below
18
+ either floor, scrollbars appear so the layout stays intact
19
+ instead of collapsing. The launcher (start.py) opens Chromium
20
+ at 1300Γ—830 so the fresh-launch experience lands exactly at
21
+ the floor. */
22
+ html, body {
23
+ min-width: 1300px;
24
+ min-height: 830px;
25
+ }
26
+
27
+ /* Local variable fonts β€” ship with the app, no network dependency. */
28
+ @font-face {
29
+ font-family: 'Bricolage Grotesque';
30
+ src: url('/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf') format('truetype');
31
+ font-weight: 200 800;
32
+ font-style: normal;
33
+ font-display: swap;
34
+ }
35
+ @font-face {
36
+ font-family: 'Inter Tight';
37
+ src: url('/InterTight-VariableFont_wght.ttf') format('truetype');
38
+ font-weight: 100 900;
39
+ font-style: normal;
40
+ font-display: swap;
41
+ }
42
  @font-face {
43
  font-family: 'Bitcount Single';
44
  src: url('/BitcountSingle-VariableFont_CRSV,ELSH,ELXP,slnt,wght.ttf') format('truetype');
 
48
  }
49
  </style>
50
 
51
+ <title>Fragmenta</title>
52
  </head>
53
  <body>
54
  <noscript>You need to enable JavaScript to run this app.</noscript>
app/frontend/logs/fragmenta_20260525.log ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ 2026-05-25 11:21:33 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
2
+ 2026-05-25 11:21:33 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
3
+ 2026-05-25 11:44:54 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
4
+ 2026-05-25 11:44:54 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
5
+ 2026-05-25 13:55:04 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
6
+ 2026-05-25 13:55:04 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
7
+ 2026-05-25 13:55:05 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO)
8
+ 2026-05-25 13:55:05 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log
app/frontend/package.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "name": "fragmenta-desktop",
3
- "version": "0.1.2",
4
- "description": "Fragmenta Desktop",
5
  "type": "module",
6
  "scripts": {
7
  "dev": "vite",
 
1
  {
2
  "name": "fragmenta-desktop",
3
+ "version": "0.2.0",
4
+ "description": "Fragmenta",
5
  "type": "module",
6
  "scripts": {
7
  "dev": "vite",
app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31b91d15aae398699fae58363dbc8ca1167faffe7d2cd62e68c716dcaa7d5fdd
3
+ size 407844
app/frontend/public/InterTight-VariableFont_wght.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b8ef9ed255ebe7341aa566554c0f3e87ee10ce06d2085f07ccf66f41ef96c28
3
+ size 580572
app/frontend/public/fragmenta_background.png CHANGED

Git LFS Details

  • SHA256: 048aea503935f9763e76db3f5d1fcd6d561d3db9aeac415605c46527a3d6631b
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB

Git LFS Details

  • SHA256: f7c5c50356c595570f790621b89da04b93680b2be43803810b33a165111e8600
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
app/frontend/public/interface.png CHANGED

Git LFS Details

  • SHA256: 00d2730e2f53440597b018538ed30200928e26d1034c51ec8ef7a95fc0477e98
  • Pointer size: 132 Bytes
  • Size of remote file: 1.81 MB

Git LFS Details

  • SHA256: a05704da0c9b7ea812b44d94186d81fe969a3963bf11cca6c79fbadf5d33f645
  • Pointer size: 132 Bytes
  • Size of remote file: 1.59 MB
app/frontend/src/App.js CHANGED
The diff for this file is too large to render. See raw diff
 
app/frontend/src/api.js CHANGED
@@ -37,6 +37,7 @@ const api = {
37
  get: (url, config) => request('GET', url, null, config),
38
  post: (url, body, config) => request('POST', url, body, config),
39
  put: (url, body, config) => request('PUT', url, body, config),
 
40
  delete: (url, config) => request('DELETE', url, null, config),
41
  };
42
 
 
37
  get: (url, config) => request('GET', url, null, config),
38
  post: (url, body, config) => request('POST', url, body, config),
39
  put: (url, body, config) => request('PUT', url, body, config),
40
+ patch: (url, body, config) => request('PATCH', url, body, config),
41
  delete: (url, config) => request('DELETE', url, null, config),
42
  };
43
 
app/frontend/src/components/AboutDialog.js ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react';
2
+ import {
3
+ Box,
4
+ Button,
5
+ Dialog,
6
+ DialogActions,
7
+ DialogContent,
8
+ DialogTitle,
9
+ Typography,
10
+ } from '@mui/material';
11
+ import {
12
+ Info as InfoIcon,
13
+ BookOpen as BookOpenIcon,
14
+ } from 'lucide-react';
15
+ import { appStyles } from '../theme';
16
+ import { APP_VERSION } from '../version';
17
+
18
+ /**
19
+ * "About Fragmenta" dialog β€” logo + title, short intro, three doc buttons
20
+ * (About / Documentation / Tutorials), and the Stability AI Community
21
+ * License attribution footer.
22
+ *
23
+ * Props:
24
+ * open: bool
25
+ * onClose: () => void
26
+ * onOpenDocumentation: ('about' | 'documentation') => void
27
+ * isOpeningDocumentation: bool β€” disables the doc buttons while a
28
+ * native open-file call is in flight
29
+ */
30
+ export default function AboutDialog({
31
+ open,
32
+ onClose,
33
+ onOpenDocumentation,
34
+ isOpeningDocumentation,
35
+ }) {
36
+ return (
37
+ <Dialog
38
+ open={open}
39
+ onClose={onClose}
40
+ aria-labelledby="about-documentation-dialog-title"
41
+ maxWidth="sm"
42
+ fullWidth
43
+ >
44
+ <DialogTitle id="about-documentation-dialog-title">
45
+ <Box sx={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: 1 }}>
46
+ <Box sx={{
47
+ ...appStyles.logo,
48
+ width: 52, height: 52,
49
+ border: 'none',
50
+ boxShadow: 'none',
51
+ filter: 'none',
52
+ }} />
53
+ <Typography variant="h5" component="span" sx={appStyles.title}>
54
+ Fragmenta
55
+ </Typography>
56
+ <Typography variant="caption" color="text.secondary" sx={{ fontSize: '0.7rem', letterSpacing: '0.04em' }}>
57
+ v{APP_VERSION}
58
+ </Typography>
59
+ </Box>
60
+ </DialogTitle>
61
+ <DialogContent>
62
+ <Typography sx={appStyles.infoDialogIntro}>
63
+ Fragmenta is an open source, local-first suit to prepare datasets, train, generate and perform with text-to-audio diffusion models.
64
+ Made by the composer and researcher Misagh Azimi.
65
+ </Typography>
66
+
67
+ <Box sx={appStyles.infoDialogActionStack}>
68
+ <Button
69
+ variant="contained"
70
+ size="small"
71
+ startIcon={<InfoIcon size={16} />}
72
+ onClick={() => onOpenDocumentation('about')}
73
+ disabled={isOpeningDocumentation}
74
+ sx={appStyles.infoDocButton}
75
+ >
76
+ About
77
+ </Button>
78
+ <Button
79
+ variant="outlined"
80
+ size="small"
81
+ startIcon={<BookOpenIcon size={16} />}
82
+ onClick={() => onOpenDocumentation('documentation')}
83
+ disabled={isOpeningDocumentation}
84
+ sx={appStyles.infoDocButton}
85
+ >
86
+ Documentation
87
+ </Button>
88
+ <Button
89
+ variant="outlined"
90
+ size="small"
91
+ disabled
92
+ sx={appStyles.infoDocButton}
93
+ >
94
+ Tutorials (Coming soon...)
95
+ </Button>
96
+ </Box>
97
+
98
+ <Box sx={{ mt: 3, pt: 1.5, borderTop: '1px solid', borderColor: 'divider', textAlign: 'center' }}>
99
+ <Typography variant="caption" color="textSecondary" sx={{ display: 'block', fontStyle: 'italic', fontSize: '0.6rem', lineHeight: 1.5 }}>
100
+ Powered by{' '}
101
+ <Typography
102
+ component="a"
103
+ variant="caption"
104
+ href="https://github.com/Stability-AI/stable-audio-3"
105
+ target="_blank"
106
+ rel="noopener noreferrer"
107
+ sx={{ color: 'primary.main', textDecoration: 'underline', fontStyle: 'italic', fontSize: '0.6rem' }}
108
+ >
109
+ Stable Audio 3
110
+ </Typography>{' '}by Stability AI. "This Stability AI Model is licensed under the{' '}
111
+ <Typography
112
+ component="a"
113
+ variant="caption"
114
+ href="https://stability.ai/license"
115
+ target="_blank"
116
+ rel="noopener noreferrer"
117
+ sx={{ color: 'primary.main', textDecoration: 'underline', fontStyle: 'italic', fontSize: '0.6rem' }}
118
+ >
119
+ Stability AI Community License
120
+ </Typography>,{' '}
121
+ Copyright Β© Stability AI Ltd. All Rights Reserved"
122
+ </Typography>
123
+ </Box>
124
+ </DialogContent>
125
+ <DialogActions>
126
+ <Button onClick={onClose}>Close</Button>
127
+ </DialogActions>
128
+ </Dialog>
129
+ );
130
+ }
app/frontend/src/components/AudioWaveform.js ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useEffect, useRef, useState, useCallback } from 'react';
2
+ import { Box, Typography } from '@mui/material';
3
+
4
+ /**
5
+ * Canvas waveform with a single draggable region (for SA3 inpaint UX).
6
+ *
7
+ * Decodes the supplied File via the Web Audio API (no network round-trip),
8
+ * computes per-pixel min/max peaks once per (file, width) pair, and renders
9
+ * a region overlay + two draggable handles. Region drag in three modes:
10
+ * - drag the left handle β†’ adjust start
11
+ * - drag the right handle β†’ adjust end
12
+ * - drag the body β†’ shift the whole region in place
13
+ *
14
+ * Region is controlled: parent owns `start` / `end` in seconds.
15
+ *
16
+ * Props:
17
+ * file: File | null β€” source audio
18
+ * duration: number β€” clip length in seconds (must be passed; we
19
+ * don't infer it from decoded length so the
20
+ * caller can drive a probe before decode
21
+ * finishes)
22
+ * start, end: number β€” region in seconds
23
+ * onRegionChange: (start, end) => void
24
+ * minRegionSec: number β€” default 0.1
25
+ * height: number β€” canvas height in px (default 96)
26
+ * color: CSS color β€” waveform peak color (default theme accent)
27
+ * regionColor: CSS color β€” fill for the region rect
28
+ */
29
+ export default function AudioWaveform({
30
+ file,
31
+ duration,
32
+ start,
33
+ end,
34
+ onRegionChange,
35
+ minRegionSec = 0.1,
36
+ height = 96,
37
+ color = '#279FBB',
38
+ regionColor = 'rgba(253, 162, 43, 0.28)',
39
+ }) {
40
+ const canvasRef = useRef(null);
41
+ const containerRef = useRef(null);
42
+ const [width, setWidth] = useState(0);
43
+ const [peaks, setPeaks] = useState(null);
44
+ const [decoding, setDecoding] = useState(false);
45
+ const [decodeError, setDecodeError] = useState(null);
46
+ // Drag state lives in a ref to avoid re-renders during pointer move.
47
+ const dragRef = useRef(null);
48
+
49
+ // --- responsive width via ResizeObserver -----------------------------
50
+ useEffect(() => {
51
+ const el = containerRef.current;
52
+ if (!el) return;
53
+ const ro = new ResizeObserver((entries) => {
54
+ const w = Math.max(1, Math.floor(entries[0].contentRect.width));
55
+ setWidth(w);
56
+ });
57
+ ro.observe(el);
58
+ return () => ro.disconnect();
59
+ }, []);
60
+
61
+ // --- decode + peak computation ---------------------------------------
62
+ useEffect(() => {
63
+ if (!file || !width) return;
64
+ let cancelled = false;
65
+ setDecoding(true);
66
+ setDecodeError(null);
67
+
68
+ (async () => {
69
+ try {
70
+ const buf = await file.arrayBuffer();
71
+ if (cancelled) return;
72
+ // Reuse one AudioContext where possible. Safari and Chrome both
73
+ // permit creating an offline one for pure decode without user
74
+ // gesture, which is what we want.
75
+ const Ctx = window.OfflineAudioContext || window.webkitOfflineAudioContext;
76
+ const tmpCtx = Ctx
77
+ ? new Ctx(1, 44100, 44100)
78
+ : new (window.AudioContext || window.webkitAudioContext)();
79
+ const audio = await tmpCtx.decodeAudioData(buf.slice(0));
80
+ if (cancelled) return;
81
+
82
+ // Average across channels into mono peaks, then bucket into
83
+ // `width` columns. Each column gets (min, max) in [-1, 1].
84
+ const ch0 = audio.getChannelData(0);
85
+ const ch1 = audio.numberOfChannels > 1 ? audio.getChannelData(1) : null;
86
+ const totalSamples = ch0.length;
87
+ const bucketSize = Math.max(1, Math.floor(totalSamples / width));
88
+ const out = new Float32Array(width * 2);
89
+ for (let i = 0; i < width; i++) {
90
+ const s = i * bucketSize;
91
+ const e = Math.min(totalSamples, s + bucketSize);
92
+ let mn = 0, mx = 0;
93
+ for (let j = s; j < e; j++) {
94
+ const v = ch1 ? (ch0[j] + ch1[j]) * 0.5 : ch0[j];
95
+ if (v < mn) mn = v;
96
+ if (v > mx) mx = v;
97
+ }
98
+ out[i * 2] = mn;
99
+ out[i * 2 + 1] = mx;
100
+ }
101
+ setPeaks(out);
102
+ } catch (err) {
103
+ setDecodeError(err.message || 'Failed to decode audio');
104
+ } finally {
105
+ if (!cancelled) setDecoding(false);
106
+ }
107
+ })();
108
+
109
+ return () => { cancelled = true; };
110
+ }, [file, width]);
111
+
112
+ // --- canvas drawing --------------------------------------------------
113
+ const draw = useCallback(() => {
114
+ const canvas = canvasRef.current;
115
+ if (!canvas || !width || !height) return;
116
+ const dpr = window.devicePixelRatio || 1;
117
+ canvas.width = width * dpr;
118
+ canvas.height = height * dpr;
119
+ const ctx = canvas.getContext('2d');
120
+ ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
121
+ ctx.clearRect(0, 0, width, height);
122
+
123
+ // Background: faint center line so empty audio still shows scale.
124
+ ctx.fillStyle = 'rgba(255, 255, 255, 0.05)';
125
+ ctx.fillRect(0, height / 2 - 0.5, width, 1);
126
+
127
+ // Peaks
128
+ if (peaks) {
129
+ ctx.fillStyle = color;
130
+ const mid = height / 2;
131
+ const scale = (height - 4) / 2;
132
+ for (let i = 0; i < width; i++) {
133
+ const mn = peaks[i * 2];
134
+ const mx = peaks[i * 2 + 1];
135
+ const y0 = mid - mx * scale;
136
+ const y1 = mid - mn * scale;
137
+ ctx.fillRect(i, y0, 1, Math.max(1, y1 - y0));
138
+ }
139
+ }
140
+
141
+ // Region overlay
142
+ if (duration > 0 && Number.isFinite(start) && Number.isFinite(end)) {
143
+ const sPx = Math.max(0, Math.min(width, (start / duration) * width));
144
+ const ePx = Math.max(0, Math.min(width, (end / duration) * width));
145
+ const rectW = Math.max(1, ePx - sPx);
146
+ ctx.fillStyle = regionColor;
147
+ ctx.fillRect(sPx, 0, rectW, height);
148
+ // Handles
149
+ ctx.fillStyle = '#FDA22B';
150
+ ctx.fillRect(sPx - 1, 0, 2, height);
151
+ ctx.fillRect(ePx - 1, 0, 2, height);
152
+ }
153
+ }, [width, height, peaks, color, regionColor, start, end, duration]);
154
+
155
+ useEffect(() => { draw(); }, [draw]);
156
+
157
+ // --- pointer interaction --------------------------------------------
158
+ const HIT_PX = 8;
159
+ const pxToSec = useCallback((px) => {
160
+ return Math.max(0, Math.min(duration, (px / width) * duration));
161
+ }, [width, duration]);
162
+
163
+ const onPointerDown = (e) => {
164
+ if (!duration || !width) return;
165
+ const rect = canvasRef.current.getBoundingClientRect();
166
+ const px = e.clientX - rect.left;
167
+ const sPx = (start / duration) * width;
168
+ const ePx = (end / duration) * width;
169
+ let mode;
170
+ if (Math.abs(px - sPx) <= HIT_PX) mode = 'start';
171
+ else if (Math.abs(px - ePx) <= HIT_PX) mode = 'end';
172
+ else if (px > sPx && px < ePx) mode = 'body';
173
+ else mode = 'new'; // start a new region by drag
174
+ dragRef.current = {
175
+ mode,
176
+ startPx: px,
177
+ origStart: start,
178
+ origEnd: end,
179
+ };
180
+ canvasRef.current.setPointerCapture(e.pointerId);
181
+ if (mode === 'new') {
182
+ const t = pxToSec(px);
183
+ onRegionChange?.(t, Math.min(duration, t + minRegionSec));
184
+ dragRef.current.mode = 'end';
185
+ dragRef.current.origStart = t;
186
+ dragRef.current.origEnd = t + minRegionSec;
187
+ }
188
+ };
189
+
190
+ const onPointerMove = (e) => {
191
+ const d = dragRef.current;
192
+ if (!d) return;
193
+ const rect = canvasRef.current.getBoundingClientRect();
194
+ const px = e.clientX - rect.left;
195
+ const delta = pxToSec(px) - pxToSec(d.startPx);
196
+ let s = d.origStart;
197
+ let en = d.origEnd;
198
+ if (d.mode === 'start') {
199
+ s = Math.max(0, Math.min(d.origEnd - minRegionSec, d.origStart + delta));
200
+ } else if (d.mode === 'end') {
201
+ en = Math.max(d.origStart + minRegionSec, Math.min(duration, d.origEnd + delta));
202
+ } else if (d.mode === 'body') {
203
+ const span = d.origEnd - d.origStart;
204
+ s = Math.max(0, Math.min(duration - span, d.origStart + delta));
205
+ en = s + span;
206
+ }
207
+ onRegionChange?.(s, en);
208
+ };
209
+
210
+ const onPointerUp = (e) => {
211
+ if (dragRef.current) {
212
+ canvasRef.current.releasePointerCapture(e.pointerId);
213
+ dragRef.current = null;
214
+ }
215
+ };
216
+
217
+ // --- render ----------------------------------------------------------
218
+ return (
219
+ <Box ref={containerRef} sx={{ width: '100%', position: 'relative' }}>
220
+ <canvas
221
+ ref={canvasRef}
222
+ style={{
223
+ width: '100%',
224
+ height,
225
+ display: 'block',
226
+ cursor: dragRef.current ? 'grabbing' : 'crosshair',
227
+ touchAction: 'none',
228
+ borderRadius: 4,
229
+ background: 'rgba(255,255,255,0.02)',
230
+ }}
231
+ onPointerDown={onPointerDown}
232
+ onPointerMove={onPointerMove}
233
+ onPointerUp={onPointerUp}
234
+ onPointerCancel={onPointerUp}
235
+ />
236
+ {(decoding || decodeError || !file) && (
237
+ <Box
238
+ sx={{
239
+ position: 'absolute',
240
+ inset: 0,
241
+ display: 'flex',
242
+ alignItems: 'center',
243
+ justifyContent: 'center',
244
+ pointerEvents: 'none',
245
+ }}
246
+ >
247
+ <Typography variant="caption" color="text.secondary">
248
+ {decodeError
249
+ ? `decode failed: ${decodeError}`
250
+ : !file
251
+ ? 'no source loaded'
252
+ : 'decoding…'}
253
+ </Typography>
254
+ </Box>
255
+ )}
256
+ </Box>
257
+ );
258
+ }
app/frontend/src/components/ChannelFragmentHistory.js ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState } from 'react';
2
+ import {
3
+ Box,
4
+ IconButton,
5
+ Dialog,
6
+ DialogTitle,
7
+ DialogContent,
8
+ DialogContentText,
9
+ DialogActions,
10
+ Button,
11
+ } from '@mui/material';
12
+ import { TIPS } from '../tooltips';
13
+ import Tooltip from './Tooltip';
14
+ import {
15
+ Play as PlayIcon,
16
+ Square as StopIcon,
17
+ Star as StarIcon,
18
+ Trash2 as DeleteIcon,
19
+ Check as CommitIcon,
20
+ Eraser as ClearAllIcon,
21
+ } from 'lucide-react';
22
+ import { performanceChannelStyles as styles } from '../theme';
23
+ import { MidiMappable } from './MidiContext';
24
+
25
+ /**
26
+ * Per-channel rolling fragment history. Always visible (empty-state included)
27
+ * so the user knows the strip exists. Chronological order β€” oldest at
28
+ * the top, newest at the bottom; scrolls vertically when the list grows
29
+ * past ~4 visible rows.
30
+ *
31
+ * Each row exposes four actions, all visible by default (no hover-reveal β€”
32
+ * Performance use is fast, can't afford the discoverability tax):
33
+ * β€’ Cue β–Ά/β–  β€” audition through the cue output (separate from main mix)
34
+ * β€’ Star β˜…/β˜† β€” mark as a keeper. Starred fragments survive the cap
35
+ * eviction; unstarred get dropped FIFO when over cap.
36
+ * β€’ Delete ⌫ β€” remove this fragment from history (cancellable confirm not
37
+ * shown for single deletes β€” the entry can be regenerated
38
+ * or audition can be retriggered after a quick re-tap).
39
+ * β€’ Load βœ“ β€” commit this fragment to the channel strip (becomes the
40
+ * audio the channel plays). Disabled while already loaded.
41
+ *
42
+ * Props:
43
+ * fragments: [{ id, audioUrl, blob, prompt, duration, createdAt,
44
+ * starred, number }]
45
+ * color: channel accent color
46
+ * auditioningId: the id currently playing through cue, or null
47
+ * committedId: the id currently loaded into the channel strip, or null
48
+ * maxFragments: cap, default 50 (informational; eviction lives in parent)
49
+ * on{Audition,Commit,ToggleStar,Delete}: (fragmentId) => void
50
+ * onClearAll: () => void (parent confirms separately β€” we still show
51
+ * a confirm dialog here for the trash-everything action)
52
+ */
53
+ export default function ChannelFragmentHistory({
54
+ fragments,
55
+ color,
56
+ channelIndex,
57
+ auditioningId,
58
+ committedId,
59
+ maxFragments = 50,
60
+ onAudition,
61
+ onCommit,
62
+ onToggleStar,
63
+ onDelete,
64
+ onClearAll,
65
+ }) {
66
+ const [clearConfirmOpen, setClearConfirmOpen] = useState(false);
67
+ // Channel-scoped MIME type for drag-and-drop. The waveform drop target on
68
+ // this same channel listens for this exact type β€” cross-channel drags
69
+ // won't highlight or accept because the mime won't match.
70
+ const dragMime = `application/x-fragmenta-fragment-ch${channelIndex}`;
71
+
72
+ return (
73
+ <Box sx={styles.fragmentHistoryPanel}>
74
+ <Box sx={styles.fragmentHistoryHeader}>
75
+ <Box component="span" sx={styles.fragmentHistoryHeaderText}>
76
+ Fragments
77
+ </Box>
78
+ {fragments.length > 0 && (
79
+ <IconButton
80
+ size="small"
81
+ onClick={() => setClearConfirmOpen(true)}
82
+ sx={styles.fragmentHistoryHeaderBtn}
83
+ aria-label="Clear all fragments"
84
+ >
85
+ <ClearAllIcon size={12} />
86
+ </IconButton>
87
+ )}
88
+ </Box>
89
+
90
+ {fragments.length === 0 ? (
91
+ <Box sx={styles.fragmentHistoryEmpty}>Empty</Box>
92
+ ) : (
93
+ <Box sx={styles.fragmentHistoryList}>
94
+ {fragments.map((fragment) => {
95
+ const isAuditioning = auditioningId === fragment.id;
96
+ const isCommitted = committedId === fragment.id;
97
+ return (
98
+ <Box
99
+ key={fragment.id}
100
+ draggable
101
+ onDragStart={(e) => {
102
+ e.dataTransfer.setData(dragMime, fragment.id);
103
+ e.dataTransfer.effectAllowed = 'copy';
104
+ }}
105
+ sx={{
106
+ ...styles.fragmentRow(color, isCommitted, isAuditioning),
107
+ cursor: 'grab',
108
+ '&:active': { cursor: 'grabbing' },
109
+ }}
110
+ >
111
+ <MidiMappable
112
+ id={`channel.${channelIndex}.fragment.${fragment.id}.audition`}
113
+ label={`Ch ${channelIndex + 1} Β· Fragment ${fragment.number} audition`}
114
+ kind="trigger"
115
+ onChange={() => onAudition(fragment.id)}
116
+ >
117
+ <Tooltip
118
+ title={TIPS.fragments.audition(isAuditioning)}
119
+ placement="top"
120
+ arrow
121
+ enterDelay={300}
122
+ >
123
+ <IconButton
124
+ size="small"
125
+ onClick={() => onAudition(fragment.id)}
126
+ sx={styles.fragmentIconBtn(color, isAuditioning, true)}
127
+ aria-label={isAuditioning ? 'Stop cue' : 'Audition'}
128
+ >
129
+ {isAuditioning
130
+ ? <StopIcon size={12} />
131
+ : <PlayIcon size={12} />}
132
+ </IconButton>
133
+ </Tooltip>
134
+ </MidiMappable>
135
+
136
+ <Box sx={styles.fragmentMeta}>
137
+ <Box component="span" sx={styles.fragmentOrdinal}>
138
+ F{fragment.number}
139
+ </Box>
140
+ </Box>
141
+
142
+ <Tooltip
143
+ title={TIPS.fragments.star(fragment.starred)}
144
+ placement="top"
145
+ arrow
146
+ enterDelay={300}
147
+ >
148
+ <IconButton
149
+ size="small"
150
+ onClick={() => onToggleStar(fragment.id)}
151
+ sx={styles.fragmentIconBtn(color, fragment.starred)}
152
+ aria-label={fragment.starred ? 'Unstar fragment' : 'Star fragment'}
153
+ >
154
+ <StarIcon
155
+ size={12}
156
+ fill={fragment.starred ? color : 'none'}
157
+ strokeWidth={2}
158
+ />
159
+ </IconButton>
160
+ </Tooltip>
161
+
162
+ <IconButton
163
+ size="small"
164
+ onClick={() => onDelete(fragment.id)}
165
+ sx={styles.fragmentDeleteBtn}
166
+ aria-label="Delete fragment"
167
+ >
168
+ <DeleteIcon size={12} />
169
+ </IconButton>
170
+
171
+ <Tooltip
172
+ title={TIPS.fragments.commit(isCommitted)}
173
+ placement="top"
174
+ arrow
175
+ enterDelay={300}
176
+ >
177
+ <span>
178
+ <IconButton
179
+ size="small"
180
+ onClick={() => onCommit(fragment.id)}
181
+ disabled={isCommitted}
182
+ sx={styles.fragmentIconBtn(color, isCommitted, true)}
183
+ aria-label="Load fragment into channel"
184
+ >
185
+ <CommitIcon size={12} strokeWidth={isCommitted ? 3 : 2} />
186
+ </IconButton>
187
+ </span>
188
+ </Tooltip>
189
+ </Box>
190
+ );
191
+ })}
192
+ </Box>
193
+ )}
194
+
195
+ <Dialog open={clearConfirmOpen} onClose={() => setClearConfirmOpen(false)}>
196
+ <DialogTitle>Clear fragment history?</DialogTitle>
197
+ <DialogContent>
198
+ <DialogContentText>
199
+ Removes all {fragments.length} fragments from this channel's history,
200
+ including starred ones. The currently loaded clip stays loaded
201
+ β€” only the history entries are dropped.
202
+ </DialogContentText>
203
+ </DialogContent>
204
+ <DialogActions>
205
+ <Button onClick={() => setClearConfirmOpen(false)}>Cancel</Button>
206
+ <Button
207
+ onClick={() => { setClearConfirmOpen(false); onClearAll?.(); }}
208
+ color="error"
209
+ variant="contained"
210
+ >
211
+ Clear all
212
+ </Button>
213
+ </DialogActions>
214
+ </Dialog>
215
+ </Box>
216
+ );
217
+ }
app/frontend/src/components/CheckpointManagerWindow.js ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useCallback, useEffect, useState } from 'react';
2
+ import {
3
+ Dialog,
4
+ DialogTitle,
5
+ DialogContent,
6
+ DialogActions,
7
+ Box,
8
+ Typography,
9
+ Button,
10
+ IconButton,
11
+ Stack,
12
+ Alert,
13
+ TextField,
14
+ LinearProgress,
15
+ } from '@mui/material';
16
+ import {
17
+ X as CloseIcon,
18
+ HardDrive as StorageIcon,
19
+ LogIn as LoginIcon,
20
+ LogOut as LogoutIcon,
21
+ } from 'lucide-react';
22
+ import api from '../api';
23
+ import CheckpointRow from './CheckpointRow';
24
+ import StorageDrilldown from './StorageDrilldown';
25
+
26
+ const fmtBytes = (n) => {
27
+ if (!n && n !== 0) return 'β€”';
28
+ const units = ['B', 'KB', 'MB', 'GB', 'TB'];
29
+ let v = n;
30
+ let u = 0;
31
+ while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; }
32
+ return `${v.toFixed(v < 10 ? 2 : 1)} ${units[u]}`;
33
+ };
34
+
35
+ export default function CheckpointManagerWindow({ open, onClose }) {
36
+ const [catalog, setCatalog] = useState([]);
37
+ const [storage, setStorage] = useState(null);
38
+ const [env, setEnv] = useState(null);
39
+ const [hfAuth, setHfAuth] = useState({ signed_in: false, username: null });
40
+ const [tokenDraft, setTokenDraft] = useState('');
41
+ const [showTokenInput, setShowTokenInput] = useState(false);
42
+ const [authError, setAuthError] = useState(null);
43
+ const [showStorage, setShowStorage] = useState(false);
44
+ const [loading, setLoading] = useState(false);
45
+ const [error, setError] = useState(null);
46
+
47
+ const refresh = useCallback(async () => {
48
+ setLoading(true);
49
+ setError(null);
50
+ try {
51
+ const [cat, store, auth, environment] = await Promise.all([
52
+ api.get('/api/checkpoints'),
53
+ api.get('/api/checkpoints/storage'),
54
+ api.get('/api/hf-auth/status'),
55
+ api.get('/api/environment'),
56
+ ]);
57
+ setCatalog(cat.data.checkpoints);
58
+ setStorage(store.data);
59
+ setHfAuth(auth.data);
60
+ setEnv(environment.data);
61
+ } catch (e) {
62
+ setError(e.response?.data?.error || e.message);
63
+ } finally {
64
+ setLoading(false);
65
+ }
66
+ }, []);
67
+
68
+ useEffect(() => {
69
+ if (open) refresh();
70
+ }, [open, refresh]);
71
+
72
+ const submitToken = async () => {
73
+ setAuthError(null);
74
+ try {
75
+ await api.post('/api/hf-auth', { token: tokenDraft.trim() });
76
+ setTokenDraft('');
77
+ setShowTokenInput(false);
78
+ refresh();
79
+ } catch (e) {
80
+ setAuthError(e.response?.data?.error || e.message);
81
+ }
82
+ };
83
+
84
+ const logout = async () => {
85
+ try {
86
+ await api.delete('/api/hf-auth');
87
+ refresh();
88
+ } catch (e) {
89
+ setAuthError(e.response?.data?.error || e.message);
90
+ }
91
+ };
92
+
93
+ const anyInstalled = catalog.some(c => c.downloaded);
94
+
95
+ return (
96
+ <>
97
+ <Dialog open={open} onClose={onClose} maxWidth="md" fullWidth scroll="paper">
98
+ <DialogTitle sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
99
+ <Box sx={{ flex: 1 }}>Checkpoint Manager</Box>
100
+ <IconButton size="small" onClick={onClose}><CloseIcon size={18} /></IconButton>
101
+ </DialogTitle>
102
+
103
+ <DialogContent dividers>
104
+ <Box sx={{ mb: 2 }}>
105
+ <Stack direction="row" alignItems="center" spacing={2} flexWrap="wrap">
106
+ <Button
107
+ size="small"
108
+ variant="text"
109
+ startIcon={<StorageIcon size={14} />}
110
+ onClick={() => setShowStorage(true)}
111
+ disabled={!storage}
112
+ >
113
+ {storage
114
+ ? `${fmtBytes(storage.total_used_bytes)} used Β· ${fmtBytes(storage.total_free_bytes)} free`
115
+ : 'β€”'}
116
+ </Button>
117
+
118
+ <Box sx={{ flex: 1 }} />
119
+
120
+ {hfAuth.signed_in ? (
121
+ <Stack direction="row" alignItems="center" spacing={1}>
122
+ <Typography variant="caption" color="text.secondary">
123
+ HuggingFace: signed in as <strong>{hfAuth.username}</strong>
124
+ </Typography>
125
+ <Button
126
+ size="small"
127
+ variant="text"
128
+ startIcon={<LogoutIcon size={14} />}
129
+ onClick={logout}
130
+ >
131
+ Sign out
132
+ </Button>
133
+ </Stack>
134
+ ) : showTokenInput ? (
135
+ <Stack direction="row" alignItems="center" spacing={1}>
136
+ <TextField
137
+ size="small"
138
+ placeholder="hf_..."
139
+ value={tokenDraft}
140
+ onChange={(e) => setTokenDraft(e.target.value)}
141
+ type="password"
142
+ sx={{ width: 240 }}
143
+ />
144
+ <Button size="small" variant="contained" onClick={submitToken}>
145
+ Sign in
146
+ </Button>
147
+ <Button size="small" onClick={() => { setShowTokenInput(false); setTokenDraft(''); }}>
148
+ Cancel
149
+ </Button>
150
+ </Stack>
151
+ ) : (
152
+ <Button
153
+ size="small"
154
+ variant="outlined"
155
+ startIcon={<LoginIcon size={14} />}
156
+ onClick={() => setShowTokenInput(true)}
157
+ >
158
+ Sign in to HuggingFace
159
+ </Button>
160
+ )}
161
+ </Stack>
162
+ {authError && <Alert severity="error" sx={{ mt: 1 }}>{authError}</Alert>}
163
+ </Box>
164
+
165
+ {!hfAuth.signed_in ? (
166
+ <Alert severity="info" sx={{ mb: 2 }}>
167
+ SA3 checkpoints are gated on HuggingFace. You need a{' '}
168
+ <a href="https://huggingface.co/join" target="_blank" rel="noreferrer">
169
+ HuggingFace account
170
+ </a>
171
+ {' '}to continue. Then{' '}
172
+ <a href="https://huggingface.co/settings/tokens" target="_blank" rel="noreferrer">
173
+ create a Read access token
174
+ </a>
175
+ {' '}and sign in above.
176
+ </Alert>
177
+ ) : (
178
+ <Alert severity="info" sx={{ mb: 2 }}>
179
+ You're signed in. Each model is gated β€” click its name below to open the
180
+ HuggingFace page and accept the model's terms before downloading.
181
+ </Alert>
182
+ )}
183
+
184
+ {error && <Alert severity="error" sx={{ mb: 2 }}>{error}</Alert>}
185
+ {loading && <LinearProgress sx={{ mb: 2 }} />}
186
+
187
+ {!loading && !anyInstalled && catalog.length > 0 && (
188
+ <Box sx={{
189
+ p: 2, mb: 2, borderRadius: 1, bgcolor: 'action.hover',
190
+ }}>
191
+ <Typography variant="body2" fontWeight={500}>
192
+ Pick a model to get started.
193
+ </Typography>
194
+ <Typography variant="caption" color="text.secondary">
195
+ Small - Music (1.2 GB) is a good first choice on a laptop or any GPU.
196
+ </Typography>
197
+ </Box>
198
+ )}
199
+
200
+ {[
201
+ { kind: 'post-trained', label: 'Distilled (fast)', hint: '8 steps, cfg locked at 1.0. Prompt, duration and seed only.' },
202
+ { kind: 'base', label: 'Base (full control)', hint: 'CFG-aware. ~50 steps, cfg ~7. Cfg-scale and steps are live controls.' },
203
+ { kind: 'tagger', label: 'Auto-annotation tools', hint: 'Optional helpers for dataset prep. CLAP scores audio against your vocabulary.' },
204
+ ].map(group => {
205
+ const rows = catalog.filter(c => c.kind === group.kind);
206
+ if (!rows.length) return null;
207
+ return (
208
+ <Box key={group.kind} sx={{ mb: 2 }}>
209
+ <Typography variant="subtitle2" sx={{ mb: 0.25 }}>{group.label}</Typography>
210
+ <Typography variant="caption" color="text.secondary" sx={{ display: 'block', mb: 0.75 }}>
211
+ {group.hint}
212
+ </Typography>
213
+ <Box sx={{ border: '1px solid', borderColor: 'divider', borderRadius: 1 }}>
214
+ {rows.map(c => (
215
+ <CheckpointRow
216
+ key={c.id}
217
+ checkpoint={c}
218
+ env={env}
219
+ onAuthRequired={() => setShowTokenInput(true)}
220
+ onChanged={refresh}
221
+ />
222
+ ))}
223
+ </Box>
224
+ </Box>
225
+ );
226
+ })}
227
+ </DialogContent>
228
+
229
+ <DialogActions>
230
+ <Button onClick={refresh} disabled={loading}>Refresh</Button>
231
+ <Button onClick={onClose} variant="contained">Close</Button>
232
+ </DialogActions>
233
+ </Dialog>
234
+
235
+ <StorageDrilldown
236
+ open={showStorage}
237
+ onClose={() => setShowStorage(false)}
238
+ storage={storage}
239
+ catalog={catalog}
240
+ />
241
+ </>
242
+ );
243
+ }
app/frontend/src/components/CheckpointRow.js ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useEffect, useRef, useState } from 'react';
2
+ import {
3
+ Box,
4
+ Typography,
5
+ Button,
6
+ Chip,
7
+ LinearProgress,
8
+ Stack,
9
+ IconButton,
10
+ } from '@mui/material';
11
+ import { TIPS } from '../tooltips';
12
+ import Tooltip from './Tooltip';
13
+ import {
14
+ CloudDownload as DownloadIcon,
15
+ Trash2 as DeleteIcon,
16
+ X as CancelIcon,
17
+ } from 'lucide-react';
18
+ import api from '../api';
19
+
20
+ const fmtBytes = (n) => {
21
+ if (!n && n !== 0) return 'β€”';
22
+ const units = ['B', 'KB', 'MB', 'GB', 'TB'];
23
+ let v = n;
24
+ let u = 0;
25
+ while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; }
26
+ return `${v.toFixed(v < 10 ? 2 : 1)} ${units[u]}`;
27
+ };
28
+
29
+ const hardwareLabel = (hw) => ({
30
+ 'cpu': 'CPU / GPU',
31
+ 'cuda': 'CUDA',
32
+ 'cuda+flash-attn': 'CUDA + Flash-Attn',
33
+ }[hw] || hw);
34
+
35
+ // Why this host can't run a given model, or null if it can. Mirrors the gate
36
+ // in audio_generator._ensure_model. `env` comes from GET /api/environment.
37
+ const hostIncompatReason = (hw, env) => {
38
+ if (!env) return null; // capabilities unknown β€” don't block
39
+ if (hw === 'cuda+flash-attn') {
40
+ if (!env.cuda_available) {
41
+ return 'Requires an NVIDIA CUDA GPU. Use a Small model β€” those run on CPU, Apple Silicon, or any GPU.';
42
+ }
43
+ // Gate on the real capability, not the platform: Windows works once a
44
+ // matching flash-attn wheel is installed (Blackwell/Ampere + cu12x).
45
+ // No wheel β†’ guide the user to install one (or use Docker on WSL2).
46
+ if (!env.flash_attn_available) {
47
+ return env.platform === 'Windows'
48
+ ? 'Requires Flash Attention 2 (flash-attn). No official Windows wheel β€” install a matching prebuilt/built wheel for your torch+CUDA, or run via Docker on WSL2.'
49
+ : 'Requires Flash Attention 2 (flash-attn) β€” not installed. Install it, or use a Small model.';
50
+ }
51
+ }
52
+ if (hw === 'cuda' && !env.cuda_available) {
53
+ return 'Recommended on an NVIDIA CUDA GPU; this host has none.';
54
+ }
55
+ return null;
56
+ };
57
+
58
+ export default function CheckpointRow({ checkpoint, env, onAuthRequired, onChanged }) {
59
+ const [jobId, setJobId] = useState(checkpoint.active_job?.job_id || null);
60
+ const [job, setJob] = useState(checkpoint.active_job || null);
61
+ const [error, setError] = useState(null);
62
+ const [busy, setBusy] = useState(false);
63
+ const pollTimer = useRef(null);
64
+
65
+ // If the parent's refresh tells us about an in-flight job and we don't
66
+ // already have one locally (typical case: dialog was closed mid-download
67
+ // and just got reopened), adopt it. Don't stomp a freshly-started local
68
+ // job_id with stale catalog data β€” only sync when the local state is empty
69
+ // or a *different* job is now active for this checkpoint.
70
+ useEffect(() => {
71
+ const incoming = checkpoint.active_job?.job_id || null;
72
+ if (incoming && incoming !== jobId) {
73
+ setJobId(incoming);
74
+ setJob(checkpoint.active_job);
75
+ }
76
+ }, [checkpoint.active_job, jobId]);
77
+
78
+ useEffect(() => {
79
+ if (!jobId) return undefined;
80
+ const tick = async () => {
81
+ try {
82
+ const r = await api.get(`/api/checkpoints/jobs/${jobId}`);
83
+ setJob(r.data);
84
+ if (['complete', 'failed', 'cancelled'].includes(r.data.status)) {
85
+ if (r.data.status === 'failed' && (r.data.error || '').startsWith('hf_auth_required')) {
86
+ onAuthRequired?.();
87
+ } else if (r.data.status === 'failed') {
88
+ setError(r.data.error);
89
+ }
90
+ setJobId(null);
91
+ onChanged?.();
92
+ }
93
+ } catch (e) {
94
+ setError(e.response?.data?.error || e.message);
95
+ setJobId(null);
96
+ }
97
+ };
98
+ tick();
99
+ pollTimer.current = setInterval(tick, 1500);
100
+ return () => clearInterval(pollTimer.current);
101
+ }, [jobId, onAuthRequired, onChanged]);
102
+
103
+ const startDownload = async () => {
104
+ setBusy(true);
105
+ setError(null);
106
+ try {
107
+ const r = await api.post(`/api/checkpoints/${checkpoint.id}/download`);
108
+ setJobId(r.data.job_id);
109
+ } catch (e) {
110
+ setError(e.response?.data?.error || e.message);
111
+ } finally {
112
+ setBusy(false);
113
+ }
114
+ };
115
+
116
+ const cancelDownload = async () => {
117
+ try {
118
+ await api.post(`/api/checkpoints/${checkpoint.id}/cancel-download`);
119
+ } catch (e) {
120
+ setError(e.response?.data?.error || e.message);
121
+ }
122
+ };
123
+
124
+ const deleteCheckpoint = async () => {
125
+ if (!window.confirm(`Delete ${checkpoint.name} (${fmtBytes(checkpoint.downloaded_bytes)})?`)) return;
126
+ setBusy(true);
127
+ try {
128
+ await api.delete(`/api/checkpoints/${checkpoint.id}`);
129
+ onChanged?.();
130
+ } catch (e) {
131
+ setError(e.response?.data?.error || e.message);
132
+ } finally {
133
+ setBusy(false);
134
+ }
135
+ };
136
+
137
+ const downloading = !!jobId && job?.status === 'running';
138
+ const queued = !!jobId && job?.status === 'queued';
139
+ const pct = job?.total_bytes ? (job.downloaded_bytes / job.total_bytes) * 100 : 0;
140
+ const incompatReason = hostIncompatReason(checkpoint.hardware, env);
141
+
142
+ const renderAction = () => {
143
+ if (downloading || queued) {
144
+ return (
145
+ <IconButton size="small" onClick={cancelDownload} aria-label="Cancel download"><CancelIcon size={16} /></IconButton>
146
+ );
147
+ }
148
+ if (checkpoint.downloaded) {
149
+ return (
150
+ <IconButton size="small" onClick={deleteCheckpoint} disabled={busy} aria-label="Delete from disk">
151
+ <DeleteIcon size={16} />
152
+ </IconButton>
153
+ );
154
+ }
155
+ if (incompatReason) {
156
+ return (
157
+ <Tooltip title={incompatReason}>
158
+ {/* span wrapper so the tooltip works on a disabled button */}
159
+ <span>
160
+ <Button
161
+ size="small"
162
+ variant="outlined"
163
+ startIcon={<DownloadIcon size={14} />}
164
+ disabled
165
+ >
166
+ Get
167
+ </Button>
168
+ </span>
169
+ </Tooltip>
170
+ );
171
+ }
172
+ return (
173
+ <Button
174
+ size="small"
175
+ variant="contained"
176
+ startIcon={<DownloadIcon size={14} />}
177
+ onClick={startDownload}
178
+ disabled={busy}
179
+ >
180
+ Get
181
+ </Button>
182
+ );
183
+ };
184
+
185
+ return (
186
+ <Box
187
+ sx={{
188
+ py: 1.25,
189
+ px: 1.5,
190
+ borderBottom: '1px solid',
191
+ borderColor: 'divider',
192
+ '&:last-child': { borderBottom: 'none' },
193
+ }}
194
+ >
195
+ <Stack direction="row" alignItems="center" spacing={2}>
196
+ <Box sx={{ flex: 1, minWidth: 0, opacity: (incompatReason && !checkpoint.downloaded) ? 0.55 : 1 }}>
197
+ <Stack direction="row" alignItems="center" spacing={1}>
198
+ <Tooltip title={TIPS.checkpoints.gatedAccess}>
199
+ <Typography
200
+ component="a"
201
+ href={`https://huggingface.co/${checkpoint.repo}`}
202
+ target="_blank"
203
+ rel="noreferrer"
204
+ variant="body2"
205
+ sx={{
206
+ fontWeight: 500,
207
+ color: 'inherit',
208
+ textDecoration: 'none',
209
+ borderBottom: '1px dashed',
210
+ borderColor: 'text.disabled',
211
+ '&:hover': { color: 'primary.main', borderColor: 'primary.main' },
212
+ }}
213
+ >
214
+ {checkpoint.name}
215
+ </Typography>
216
+ </Tooltip>
217
+ <Chip
218
+ size="small"
219
+ label={hardwareLabel(checkpoint.hardware)}
220
+ variant="outlined"
221
+ sx={{ height: 18, fontSize: 10 }}
222
+ />
223
+ {checkpoint.downloaded && (
224
+ <Chip
225
+ size="small"
226
+ label="installed"
227
+ sx={{
228
+ height: 18,
229
+ fontSize: 10,
230
+ fontWeight: 600,
231
+ bgcolor: 'success.main',
232
+ color: 'common.white',
233
+ }}
234
+ />
235
+ )}
236
+ </Stack>
237
+ <Typography variant="caption" color="text.secondary">
238
+ {fmtBytes(checkpoint.size_bytes)}
239
+ {checkpoint.max_duration_sec && ` Β· up to ${checkpoint.max_duration_sec}s`}
240
+ </Typography>
241
+ {incompatReason && !checkpoint.downloaded && (
242
+ <Typography variant="caption" color="warning.main" sx={{ display: 'block' }}>
243
+ Not supported on this machine
244
+ </Typography>
245
+ )}
246
+ </Box>
247
+ <Box>{renderAction()}</Box>
248
+ </Stack>
249
+
250
+ {(downloading || queued) && (
251
+ <Box sx={{ mt: 1 }}>
252
+ <LinearProgress
253
+ variant={queued ? 'indeterminate' : 'determinate'}
254
+ value={Math.min(100, pct)}
255
+ sx={{ height: 4, borderRadius: 2 }}
256
+ />
257
+ <Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
258
+ {queued ? 'Queued…' : `${fmtBytes(job?.downloaded_bytes)} / ${fmtBytes(job?.total_bytes)}`}
259
+ </Typography>
260
+ </Box>
261
+ )}
262
+
263
+ {error && (
264
+ <Typography variant="caption" color="error" sx={{ mt: 0.5, display: 'block' }}>
265
+ {error}
266
+ </Typography>
267
+ )}
268
+ </Box>
269
+ );
270
+ }
app/frontend/src/components/DatasetPrep.js ADDED
@@ -0,0 +1,1823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useCallback, useEffect, useRef, useState } from 'react';
2
+ import {
3
+ Accordion,
4
+ AccordionDetails,
5
+ AccordionSummary,
6
+ Alert,
7
+ Autocomplete,
8
+ Box,
9
+ Button,
10
+ Checkbox,
11
+ Chip,
12
+ Dialog,
13
+ DialogActions,
14
+ DialogContent,
15
+ DialogTitle,
16
+ FormControl,
17
+ FormControlLabel,
18
+ IconButton,
19
+ InputLabel,
20
+ LinearProgress,
21
+ MenuItem,
22
+ Paper,
23
+ Portal,
24
+ Radio,
25
+ RadioGroup,
26
+ Select,
27
+ Snackbar,
28
+ Stack,
29
+ Switch,
30
+ Table,
31
+ TableBody,
32
+ TableCell,
33
+ TableContainer,
34
+ TableHead,
35
+ TableRow,
36
+ TextField,
37
+ Typography,
38
+ useTheme,
39
+ } from '@mui/material';
40
+ import { TIPS } from '../tooltips';
41
+ import Tooltip from './Tooltip';
42
+ import {
43
+ ChevronDown as ChevronDownIcon,
44
+ FolderOpenIcon,
45
+ PlusIcon,
46
+ WandSparkles,
47
+ SaveIcon,
48
+ Database as Database,
49
+ DatabaseZap as DatasetIcon,
50
+ Square as StopIcon,
51
+ Trash2 as TrashIcon,
52
+ Play as PlayIcon,
53
+ Pause as PauseIcon,
54
+ Scissors as ScissorsIcon,
55
+ Music as MusicIcon,
56
+ Activity as HealthIcon,
57
+ } from 'lucide-react';
58
+ import api from '../api';
59
+ import { appStyles } from '../theme';
60
+
61
+ /**
62
+ * DatasetPrep β€” sidecar-native dataset surface with a buffered editing model.
63
+ *
64
+ * One page, no modes. Pick or create a project. The dataset folder on disk
65
+ * is the *committed* state. Edits, auto-annotate output, and just-ingested
66
+ * audio all live in an in-memory session until the user explicitly hits
67
+ * Save (writes a draft) or Commit (writes .txt sidecars).
68
+ */
69
+ export default function DatasetPrep({ onOpenCheckpointManager }) {
70
+ const [projects, setProjects] = useState([]);
71
+ const [selectedName, setSelectedName] = useState(() => {
72
+ try { return window.localStorage.getItem('fragmenta.datasetPrep.lastProject') || ''; }
73
+ catch { return ''; }
74
+ });
75
+ const [project, setProject] = useState(null);
76
+ const [createOpen, setCreateOpen] = useState(false);
77
+ const [loadOpen, setLoadOpen] = useState(false);
78
+ const [ingestOpen, setIngestOpen] = useState(false);
79
+ const [sliceTarget, setSliceTarget] = useState(null); // file_name or null
80
+ // Single confirm-dialog state powering destructive actions. Mirrors the
81
+ // Free GPU / Start Fresh confirm style from App.js β€” replaces the
82
+ // browser-native window.confirm() prompts so the UX is consistent.
83
+ const [confirm, setConfirm] = useState(null);
84
+ const [confirmBusy, setConfirmBusy] = useState(false);
85
+ const [error, setError] = useState('');
86
+
87
+ const [errorCode, setErrorCode] = useState('');
88
+ const [errorExtra, setErrorExtra] = useState(null);
89
+ const [annotateJob, setAnnotateJob] = useState(null);
90
+ const [notice, setNotice] = useState(null); // { severity, message } | null
91
+ // Phase 6 β€” pre-encoded latents
92
+ const [preEncodeJob, setPreEncodeJob] = useState(null);
93
+ const [preEncodeOffer, setPreEncodeOffer] = useState(false); // post-commit dialog
94
+ const [tier, setTier] = useState(() => {
95
+ try { return window.localStorage.getItem('fragmenta.datasetPrep.tier') || 'basic'; }
96
+ catch { return 'basic'; }
97
+ });
98
+ const [skipExisting, setSkipExisting] = useState(true);
99
+
100
+ const pollHandleRef = useRef(null);
101
+ const preEncodePollRef = useRef(null);
102
+ const isAnnotating = annotateJob?.state === 'running';
103
+ const isPreEncoding = preEncodeJob?.state === 'running' || preEncodeJob?.state === 'queued';
104
+
105
+ // --- Multi-row selection (for bulk Slice) -----------------------------
106
+ // Set<string> of clip file_names. Reset whenever the active project
107
+ // changes, since selections from a different project are meaningless.
108
+ const [selectedFiles, setSelectedFiles] = useState(() => new Set());
109
+ useEffect(() => { setSelectedFiles(new Set()); }, [selectedName]);
110
+
111
+ const toggleSelected = useCallback((fileName) => {
112
+ setSelectedFiles((prev) => {
113
+ const next = new Set(prev);
114
+ if (next.has(fileName)) next.delete(fileName);
115
+ else next.add(fileName);
116
+ return next;
117
+ });
118
+ }, []);
119
+ const toggleSelectAll = useCallback((clips) => {
120
+ setSelectedFiles((prev) => {
121
+ const allNames = clips.map((c) => c.file_name);
122
+ const allSelected = allNames.length > 0 && allNames.every((n) => prev.has(n));
123
+ return allSelected ? new Set() : new Set(allNames);
124
+ });
125
+ }, []);
126
+ const clearSelection = useCallback(() => setSelectedFiles(new Set()), []);
127
+
128
+ // --- Per-row audio preview --------------------------------------------
129
+ // One <audio> for the whole table. Rows just say "play me" / "pause";
130
+ // the parent reconciles which file is loaded and where the playhead is.
131
+ const audioRef = useRef(null);
132
+ const [playingFile, setPlayingFile] = useState(null);
133
+ const [playProgress, setPlayProgress] = useState(0); // 0..1
134
+
135
+ const stopPlayback = useCallback(() => {
136
+ const audio = audioRef.current;
137
+ if (audio) { audio.pause(); }
138
+ setPlayingFile(null);
139
+ setPlayProgress(0);
140
+ }, []);
141
+
142
+ const handlePlayToggle = useCallback((fileName) => {
143
+ if (!selectedName) return;
144
+ const audio = audioRef.current;
145
+ if (!audio) return;
146
+ if (playingFile === fileName) {
147
+ audio.pause();
148
+ setPlayingFile(null);
149
+ return;
150
+ }
151
+ const url = `/api/projects/${encodeURIComponent(selectedName)}/clip/${encodeURIComponent(fileName)}/audio`;
152
+ audio.src = url;
153
+ setPlayProgress(0);
154
+ setPlayingFile(fileName);
155
+ audio.play().catch(() => {
156
+ setPlayingFile(null);
157
+ });
158
+ }, [selectedName, playingFile]);
159
+
160
+ // Stop playback when the project changes β€” the audio element's src would
161
+ // suddenly refer to a different project's file.
162
+ useEffect(() => { stopPlayback(); }, [selectedName, stopPlayback]);
163
+
164
+ const refreshProjects = useCallback(async () => {
165
+ try {
166
+ const { data } = await api.get('/api/projects');
167
+ setProjects(data.projects || []);
168
+ } catch (e) { setError(extractError(e, 'Failed to list projects')); }
169
+ }, []);
170
+
171
+ const [health, setHealth] = useState(null);
172
+ const refreshHealth = useCallback(async (name) => {
173
+ if (!name) { setHealth(null); return; }
174
+ try {
175
+ const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}/health`);
176
+ setHealth(data);
177
+ } catch {
178
+ // Non-fatal β€” strip just hides until next refresh.
179
+ setHealth(null);
180
+ }
181
+ }, []);
182
+
183
+ const refreshProject = useCallback(async (name) => {
184
+ if (!name) { setProject(null); setHealth(null); return; }
185
+ try {
186
+ const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}`);
187
+ setProject(data);
188
+ refreshHealth(name);
189
+ } catch (e) {
190
+ if (e?.response?.status === 404) {
191
+ setSelectedName('');
192
+ setProject(null);
193
+ setHealth(null);
194
+ await refreshProjects();
195
+ return;
196
+ }
197
+ setError(extractError(e, 'Failed to load project'));
198
+ }
199
+ }, [refreshProjects, refreshHealth]);
200
+
201
+ useEffect(() => { refreshProjects(); }, [refreshProjects]);
202
+
203
+ const pollAnnotateStatus = useCallback(async function poll(name) {
204
+ try {
205
+ const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}/annotate/status`);
206
+ setAnnotateJob(data.job);
207
+ if (data.job.state === 'done') {
208
+ await refreshProject(name);
209
+ return;
210
+ }
211
+ if (data.job.state === 'error') {
212
+ setError(data.job.error || 'Annotation failed');
213
+ return;
214
+ }
215
+ // Only keep polling while the backend is actively annotating. Other
216
+ // states ('idle', 'cancelled', missing) terminate the loop so a
217
+ // freshly-mounted tab doesn't poll forever for a non-existent job.
218
+ if (data.job.state === 'running') {
219
+ pollHandleRef.current = window.setTimeout(() => poll(name), 500);
220
+ }
221
+ } catch (e) { setError(extractError(e, 'Status poll failed')); }
222
+ }, [refreshProject]);
223
+
224
+ // Phase 6 β€” pre-encode polling. Same survives-tab-switch shape as the
225
+ // annotate poller above.
226
+ const pollPreEncodeStatus = useCallback(async function poll(name) {
227
+ try {
228
+ const { data } = await api.get(`/api/projects/${encodeURIComponent(name)}/pre-encode/status`);
229
+ setPreEncodeJob(data.job);
230
+ if (data.job.state === 'complete') {
231
+ refreshProject(name);
232
+ return;
233
+ }
234
+ if (data.job.state === 'failed') {
235
+ setError(data.job.error || 'Pre-encoding failed');
236
+ return;
237
+ }
238
+ if (data.job.state === 'running' || data.job.state === 'queued') {
239
+ preEncodePollRef.current = window.setTimeout(() => poll(name), 750);
240
+ }
241
+ } catch (e) { /* non-fatal β€” bar just freezes */ }
242
+ }, [refreshProject]);
243
+
244
+ useEffect(() => {
245
+ if (selectedName) {
246
+ try { window.localStorage.setItem('fragmenta.datasetPrep.lastProject', selectedName); } catch {}
247
+ refreshProject(selectedName);
248
+ // Re-bootstrap progress polling on (re)mount or project switch, so
249
+ // the progress strip survives tab changes while a job runs.
250
+ pollAnnotateStatus(selectedName);
251
+ pollPreEncodeStatus(selectedName);
252
+ } else {
253
+ setProject(null);
254
+ setAnnotateJob(null);
255
+ setPreEncodeJob(null);
256
+ }
257
+ return () => {
258
+ if (pollHandleRef.current) {
259
+ window.clearTimeout(pollHandleRef.current);
260
+ pollHandleRef.current = null;
261
+ }
262
+ if (preEncodePollRef.current) {
263
+ window.clearTimeout(preEncodePollRef.current);
264
+ preEncodePollRef.current = null;
265
+ }
266
+ };
267
+ }, [selectedName, refreshProject, pollAnnotateStatus, pollPreEncodeStatus]);
268
+
269
+ function changeTier(value) {
270
+ setTier(value);
271
+ try { window.localStorage.setItem('fragmenta.datasetPrep.tier', value); } catch {}
272
+ }
273
+
274
+ function trySelectProject(nextName) {
275
+ // Confirm before switching if there are unsaved or uncommitted edits.
276
+ if (project && (project.dirty || project.has_unsaved_changes) && nextName !== project.name) {
277
+ const ok = window.confirm(
278
+ `β€œ${project.name}” has unsaved or uncommitted changes. Switch anyway? They'll stay in memory until you reload the project β€” but a backend restart will lose them.`,
279
+ );
280
+ if (!ok) return;
281
+ }
282
+ setSelectedName(nextName);
283
+ }
284
+
285
+ async function handleAnnotate(scope /* "all" | [file_names] */, opts = {}) {
286
+ if (!project) return;
287
+ setError(''); setErrorCode(''); setErrorExtra(null);
288
+ try {
289
+ await api.post(`/api/projects/${encodeURIComponent(project.name)}/annotate`, {
290
+ tier,
291
+ scope: scope ?? 'all',
292
+ skip_existing: opts.skip_existing ?? skipExisting,
293
+ });
294
+ pollAnnotateStatus(project.name);
295
+ } catch (e) {
296
+ const body = e?.response?.data || {};
297
+ setError(extractError(e, 'Failed to start annotation'));
298
+ setErrorCode(body.code || '');
299
+ setErrorExtra(body.install_command ? { install_command: body.install_command } : null);
300
+ }
301
+ }
302
+
303
+ async function handleCancelAnnotate() {
304
+ if (!project) return;
305
+ try {
306
+ await api.post(`/api/projects/${encodeURIComponent(project.name)}/annotate/cancel`);
307
+ } catch (e) { setError(extractError(e, 'Cancel failed')); }
308
+ }
309
+
310
+ async function handleSave() {
311
+ if (!project) return;
312
+ setError('');
313
+ try {
314
+ const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/save`);
315
+ setProject(data);
316
+ setNotice({ severity: 'success', message: `Draft saved Β· ${data.clip_count} clips` });
317
+ } catch (e) { setError(extractError(e, 'Save failed')); }
318
+ }
319
+
320
+ async function handleStartPreEncode() {
321
+ if (!project) return;
322
+ setError('');
323
+ try {
324
+ const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/pre-encode`);
325
+ setPreEncodeJob(data.job);
326
+ pollPreEncodeStatus(project.name);
327
+ } catch (e) { setError(extractError(e, 'Pre-encode failed to start')); }
328
+ }
329
+
330
+ async function handleCancelPreEncode() {
331
+ if (!project) return;
332
+ try {
333
+ await api.post(`/api/projects/${encodeURIComponent(project.name)}/pre-encode/cancel`);
334
+ } catch (e) { setError(extractError(e, 'Cancel failed')); }
335
+ }
336
+
337
+ async function persistPreEncodeSuppression(suppress) {
338
+ if (!project) return;
339
+ try {
340
+ const { data } = await api.patch(
341
+ `/api/projects/${encodeURIComponent(project.name)}/pre-encode/prompt`,
342
+ { suppress: !!suppress },
343
+ );
344
+ setProject(data);
345
+ } catch (e) { /* non-fatal β€” dialog still closes */ }
346
+ }
347
+
348
+ async function handleCommit() {
349
+ if (!project) return;
350
+ setError('');
351
+ try {
352
+ const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/commit`);
353
+ setProject(data);
354
+ await refreshProjects();
355
+ // Phase 6 β€” post-commit pre-encode prompt.
356
+ // Open the dialog unless: (a) latents already present (re-commit
357
+ // wiped them but we still avoid re-asking immediately), or
358
+ // (b) the user previously chose "Don't ask again".
359
+ if (!data.suppress_pre_encode_prompt && !data.latents_present && data.clip_count > 0) {
360
+ setPreEncodeOffer(true);
361
+ }
362
+ setNotice({
363
+ severity: 'success',
364
+ message: `Dataset created Β· ${data.clip_count} clips written to disk`,
365
+ });
366
+ } catch (e) { setError(extractError(e, 'Create Dataset failed')); }
367
+ }
368
+
369
+ function handleDiscard() {
370
+ if (!project) return;
371
+ setConfirm({
372
+ title: 'Delete unsaved changes',
373
+ body: `Delete all changes in β€œ${project.name}” since the last created dataset? Audio files added since then will be removed.`,
374
+ warning: 'This cannot be undone.',
375
+ confirmLabel: 'Delete',
376
+ busyLabel: 'Deleting…',
377
+ danger: true,
378
+ onConfirm: async () => {
379
+ setError('');
380
+ try {
381
+ const { data } = await api.post(`/api/projects/${encodeURIComponent(project.name)}/discard`);
382
+ setProject(data);
383
+ await refreshProjects();
384
+ setNotice({ severity: 'info', message: 'Unsaved changes discarded' });
385
+ } catch (e) { setError(extractError(e, 'Delete failed')); }
386
+ },
387
+ });
388
+ }
389
+
390
+ function handleDeleteProject(name) {
391
+ if (!name) return;
392
+ setConfirm({
393
+ title: 'Delete project',
394
+ body: `Permanently delete project β€œ${name}”? Audio files, sidecars, and any drafts will be removed from disk.`,
395
+ warning: 'This cannot be undone.',
396
+ confirmLabel: 'Delete',
397
+ busyLabel: 'Deleting…',
398
+ danger: true,
399
+ onConfirm: async () => {
400
+ setError('');
401
+ try {
402
+ await api.delete(`/api/projects/${encodeURIComponent(name)}`);
403
+ if (selectedName === name) {
404
+ stopPlayback();
405
+ setSelectedName('');
406
+ setProject(null);
407
+ try { window.localStorage.removeItem('fragmenta.datasetPrep.lastProject'); } catch {}
408
+ }
409
+ await refreshProjects();
410
+ } catch (e) { setError(extractError(e, 'Delete project failed')); }
411
+ },
412
+ });
413
+ }
414
+
415
+ async function handleChangeTemplatePreset(presetId) {
416
+ if (!project) return;
417
+ try {
418
+ const { data } = await api.patch(
419
+ `/api/projects/${encodeURIComponent(project.name)}/template`,
420
+ { preset: presetId },
421
+ );
422
+ setProject(data);
423
+ } catch (e) {
424
+ setError(extractError(e, 'Could not update annotation style'));
425
+ }
426
+ }
427
+
428
+ function handleClearSelectedAnnotations() {
429
+ if (!project || selectedFiles.size === 0) return;
430
+ const count = selectedFiles.size;
431
+ const files = Array.from(selectedFiles);
432
+ setConfirm({
433
+ title: 'Clear',
434
+ body: `Clear annotations on ${count} clip${count === 1 ? '' : 's'}? Buffered in memory until you Save or Create Dataset.`,
435
+ warning: 'Use the Delete button to revert; this action itself can’t be undone in place.',
436
+ confirmLabel: `Clear (${count})`,
437
+ busyLabel: 'Clearing…',
438
+ danger: true,
439
+ onConfirm: async () => {
440
+ setError('');
441
+ try {
442
+ for (const f of files) {
443
+ await api.patch(
444
+ `/api/projects/${encodeURIComponent(project.name)}/clip/${encodeURIComponent(f)}`,
445
+ { prompt: '' },
446
+ );
447
+ }
448
+ clearSelection();
449
+ await refreshProject(project.name);
450
+ } catch (e) { setError(extractError(e, 'Clear annotations failed')); }
451
+ },
452
+ });
453
+ }
454
+
455
+ async function handleClipPromptChange(fileName, newPrompt) {
456
+ if (!project) return;
457
+ try {
458
+ await api.patch(
459
+ `/api/projects/${encodeURIComponent(project.name)}/clip/${encodeURIComponent(fileName)}`,
460
+ { prompt: newPrompt },
461
+ );
462
+ // Reload to pick up dirty-state flip in the header.
463
+ await refreshProject(project.name);
464
+ } catch (e) { setError(extractError(e, 'Failed to save prompt')); }
465
+ }
466
+
467
+ async function handleClipDelete(fileName) {
468
+ if (!project) return;
469
+ if (!window.confirm(`Remove ${fileName} from this project? (Deletes the audio file from disk immediately β€” cannot be discarded back.)`)) return;
470
+ try {
471
+ await api.delete(
472
+ `/api/projects/${encodeURIComponent(project.name)}/clip/${encodeURIComponent(fileName)}`,
473
+ );
474
+ await refreshProject(project.name);
475
+ } catch (e) { setError(extractError(e, 'Failed to delete clip')); }
476
+ }
477
+
478
+ return (
479
+ <Paper variant="outlined" sx={{ p: { xs: 2.25, sm: 3 }, borderRadius: 2.5 }}>
480
+ <Stack spacing={2.5}>
481
+ <Box>
482
+ <Box sx={{ ...appStyles.sectionCardHeader, mb: 0.5 }}>
483
+ <Box component="span" sx={appStyles.sectionCardIcon}>
484
+ <Database size={20} />
485
+ </Box>
486
+ <Typography variant="h6" sx={appStyles.sectionCardTitle}>
487
+ Dataset Workbench
488
+ </Typography>
489
+ <Box sx={{ flex: 1 }} />
490
+ <Button
491
+ variant="outlined"
492
+ size="small"
493
+ startIcon={<FolderOpenIcon size={16} />}
494
+ onClick={() => setLoadOpen(true)}
495
+ disabled={projects.length === 0}
496
+ >
497
+ Load project
498
+ </Button>
499
+ <Button
500
+ variant="outlined"
501
+ size="small"
502
+ startIcon={<PlusIcon size={16} />}
503
+ onClick={() => setCreateOpen(true)}
504
+ >
505
+ New project
506
+ </Button>
507
+ </Box>
508
+ <Typography variant="body2" color="text.secondary">
509
+ Create a new dataset or load and edit one.
510
+ </Typography>
511
+ <Typography variant="body2" color="text.secondary" paddingBottom={2}>
512
+ You can auto-annotate using Librosa and CLAP or annotate everything manually.
513
+ </Typography>
514
+ </Box>
515
+
516
+ {error && (
517
+ <Alert
518
+ severity={(errorCode === 'clap_not_available' || errorCode === 'clap_package_missing') ? 'warning' : 'error'}
519
+ onClose={() => { setError(''); setErrorCode(''); setErrorExtra(null); }}
520
+ action={
521
+ errorCode === 'clap_not_available' && onOpenCheckpointManager ? (
522
+ <Button
523
+ color="inherit"
524
+ size="small"
525
+ onClick={() => { setError(''); setErrorCode(''); setErrorExtra(null); onOpenCheckpointManager(); }}
526
+ >
527
+ Open Model Management
528
+ </Button>
529
+ ) : null
530
+ }
531
+ >
532
+ <Box>
533
+ <Typography variant="body2">{error}</Typography>
534
+ {errorCode === 'clap_package_missing' && errorExtra?.install_command && (
535
+ <Box
536
+ component="pre"
537
+ sx={{
538
+ mt: 1,
539
+ mb: 0,
540
+ p: 1,
541
+ borderRadius: 1,
542
+ bgcolor: 'action.hover',
543
+ fontSize: '0.8rem',
544
+ fontFamily: 'monospace',
545
+ overflowX: 'auto',
546
+ }}
547
+ >
548
+ {errorExtra.install_command}
549
+ </Box>
550
+ )}
551
+ </Box>
552
+ </Alert>
553
+ )}
554
+
555
+ {project && (
556
+ <Stack spacing={2}>
557
+ <ProjectHeader
558
+ project={project}
559
+ onSave={handleSave}
560
+ onCommit={handleCommit}
561
+ onDiscard={handleDiscard}
562
+ onAddAudio={() => setIngestOpen(true)}
563
+ disabled={isAnnotating}
564
+ />
565
+
566
+ <HealthStrip
567
+ health={health}
568
+ onSelectFiles={(files) => setSelectedFiles(new Set(files))}
569
+ />
570
+
571
+ {isAnnotating && annotateJob && (
572
+ <Box>
573
+ <LinearProgress
574
+ variant={annotateJob.total > 0 ? 'determinate' : 'indeterminate'}
575
+ value={annotateJob.total > 0 ? (annotateJob.current / annotateJob.total) * 100 : undefined}
576
+ />
577
+ <Box sx={{ mt: 0.75, display: 'flex', alignItems: 'center', gap: 1.5 }}>
578
+ <Typography variant="caption" color="text.secondary" sx={{ flex: 1 }}>
579
+ Annotating {annotateJob.current} / {annotateJob.total}
580
+ {annotateJob.current_file ? ` Β· ${annotateJob.current_file}` : ''}
581
+ </Typography>
582
+ <Button
583
+ size="small"
584
+ variant="outlined"
585
+ color="error"
586
+ startIcon={<StopIcon size={14} />}
587
+ onClick={handleCancelAnnotate}
588
+ >
589
+ Stop
590
+ </Button>
591
+ </Box>
592
+ </Box>
593
+ )}
594
+
595
+ {isPreEncoding && preEncodeJob && (
596
+ <Box>
597
+ <LinearProgress
598
+ variant={preEncodeJob.total > 0 ? 'determinate' : 'indeterminate'}
599
+ value={preEncodeJob.total > 0 ? (preEncodeJob.current / preEncodeJob.total) * 100 : undefined}
600
+ />
601
+ <Box sx={{ mt: 0.75, display: 'flex', alignItems: 'center', gap: 1.5 }}>
602
+ <Typography variant="caption" color="text.secondary" sx={{ flex: 1 }}>
603
+ Pre-encoding latents Β· {preEncodeJob.current} / {preEncodeJob.total}
604
+ {preEncodeJob.autoencoder ? ` Β· ${preEncodeJob.autoencoder}` : ''}
605
+ </Typography>
606
+ <Button
607
+ size="small"
608
+ variant="outlined"
609
+ color="error"
610
+ startIcon={<StopIcon size={14} />}
611
+ onClick={handleCancelPreEncode}
612
+ >
613
+ Stop
614
+ </Button>
615
+ </Box>
616
+ </Box>
617
+ )}
618
+
619
+ <ClipTable
620
+ projectName={selectedName}
621
+ clips={project.clips}
622
+ playingFile={playingFile}
623
+ playProgress={playProgress}
624
+ onPlayToggle={handlePlayToggle}
625
+ onPromptChange={handleClipPromptChange}
626
+ onAnnotate={(fname) => handleAnnotate([fname], { skip_existing: false })}
627
+ onDelete={(fname) => {
628
+ if (playingFile === fname) stopPlayback();
629
+ return handleClipDelete(fname);
630
+ }}
631
+ onSlice={(fname) => {
632
+ if (playingFile === fname) stopPlayback();
633
+ setSliceTarget(fname);
634
+ }}
635
+ selectedFiles={selectedFiles}
636
+ onToggleSelected={toggleSelected}
637
+ onToggleSelectAll={() => toggleSelectAll(project.clips)}
638
+ disabled={isAnnotating}
639
+ toolbar={
640
+ <Stack spacing={1}>
641
+ <Box sx={{ display: 'flex', alignItems: 'center', flexWrap: 'wrap', gap: 1.5 }}>
642
+ <Button
643
+ variant="contained"
644
+ color="warm"
645
+ size="small"
646
+ startIcon={<WandSparkles size={16} />}
647
+ onClick={() => handleAnnotate('all')}
648
+ disabled={isAnnotating || project.clip_count === 0}
649
+ >
650
+ Auto-annotate all
651
+ </Button>
652
+ <FormControl size="small" sx={{ minWidth: 180 }}>
653
+ <Select
654
+ value={project.prompt_template_preset || 'music'}
655
+ onChange={(e) => handleChangeTemplatePreset(e.target.value)}
656
+ disabled={isAnnotating}
657
+ renderValue={(v) => {
658
+ const p = (project.prompt_template_presets || []).find((x) => x.id === v);
659
+ return p ? p.label : v;
660
+ }}
661
+ >
662
+ {(project.prompt_template_presets || []).map((p) => (
663
+ <MenuItem key={p.id} value={p.id}>
664
+ <Box>
665
+ <Typography variant="body2">{p.label}</Typography>
666
+ <Typography variant="caption" color="text.secondary">
667
+ {p.description}
668
+ </Typography>
669
+ </Box>
670
+ </MenuItem>
671
+ ))}
672
+ </Select>
673
+ </FormControl>
674
+ <Tooltip title={TIPS.dataset.richAnnotate}>
675
+ <FormControlLabel
676
+ control={
677
+ <Switch
678
+ size="small"
679
+ checked={tier === 'rich'}
680
+ onChange={(e) => changeTier(e.target.checked ? 'rich' : 'basic')}
681
+ disabled={isAnnotating}
682
+ />
683
+ }
684
+ label={<Typography variant="caption" color="text.secondary">Rich annotation</Typography>}
685
+ sx={{ mr: 0 }}
686
+ />
687
+ </Tooltip>
688
+ <Tooltip title={TIPS.dataset.skipAnnotated}>
689
+ <FormControlLabel
690
+ control={
691
+ <Switch
692
+ size="small"
693
+ checked={skipExisting}
694
+ onChange={(e) => setSkipExisting(e.target.checked)}
695
+ disabled={isAnnotating}
696
+ />
697
+ }
698
+ label={<Typography variant="caption" color="text.secondary">Skip already annotated</Typography>}
699
+ sx={{ mr: 0 }}
700
+ />
701
+ </Tooltip>
702
+ <Box sx={{ flex: 1 }} />
703
+ {selectedFiles.size > 0 && (
704
+ <Button
705
+ variant="outlined"
706
+ color="error"
707
+ size="small"
708
+ startIcon={<TrashIcon size={16} />}
709
+ onClick={handleClearSelectedAnnotations}
710
+ disabled={isAnnotating}
711
+ >
712
+ Clear annotations ({selectedFiles.size})
713
+ </Button>
714
+ )}
715
+ </Box>
716
+ {tier === 'rich' && (
717
+ <ClapVocabAccordion disabled={isAnnotating} />
718
+ )}
719
+ </Stack>
720
+ }
721
+ />
722
+ <audio
723
+ ref={audioRef}
724
+ style={{ display: 'none' }}
725
+ onTimeUpdate={(e) => {
726
+ const a = e.currentTarget;
727
+ if (a.duration && isFinite(a.duration)) {
728
+ setPlayProgress(a.currentTime / a.duration);
729
+ }
730
+ }}
731
+ onEnded={() => { setPlayingFile(null); setPlayProgress(0); }}
732
+ onError={() => { setPlayingFile(null); setPlayProgress(0); }}
733
+ />
734
+ </Stack>
735
+ )}
736
+
737
+ <CreateProjectDialog
738
+ open={createOpen}
739
+ existingNames={projects.map((p) => p.name)}
740
+ onClose={() => setCreateOpen(false)}
741
+ onCreated={async (name) => {
742
+ setCreateOpen(false);
743
+ await refreshProjects();
744
+ setSelectedName(name);
745
+ }}
746
+ />
747
+
748
+ <LoadProjectDialog
749
+ open={loadOpen}
750
+ projects={projects}
751
+ currentName={selectedName}
752
+ onClose={() => setLoadOpen(false)}
753
+ onLoad={(name) => {
754
+ setLoadOpen(false);
755
+ trySelectProject(name);
756
+ }}
757
+ onDeleteProject={handleDeleteProject}
758
+ />
759
+
760
+ <IngestDialog
761
+ open={ingestOpen}
762
+ projectName={project?.name}
763
+ onClose={() => setIngestOpen(false)}
764
+ onIngested={async () => {
765
+ setIngestOpen(false);
766
+ if (project) await refreshProject(project.name);
767
+ await refreshProjects();
768
+ }}
769
+ />
770
+
771
+ <SliceDialog
772
+ open={Boolean(sliceTarget)}
773
+ projectName={project?.name}
774
+ fileName={sliceTarget}
775
+ onClose={() => setSliceTarget(null)}
776
+ onSliced={async () => {
777
+ clearSelection();
778
+ if (project) await refreshProject(project.name);
779
+ await refreshProjects();
780
+ }}
781
+ />
782
+
783
+ <Dialog
784
+ open={Boolean(confirm)}
785
+ onClose={confirmBusy ? undefined : () => setConfirm(null)}
786
+ aria-labelledby="dataset-confirm-title"
787
+ >
788
+ <DialogTitle id="dataset-confirm-title">
789
+ {confirm?.title}
790
+ </DialogTitle>
791
+ <DialogContent>
792
+ <Typography sx={appStyles.dialogBodyText}>
793
+ {confirm?.body}
794
+ </Typography>
795
+ {confirm?.warning && (
796
+ <Typography variant="body2" color="warning.main" sx={appStyles.dialogErrorText}>
797
+ {confirm.warning}
798
+ </Typography>
799
+ )}
800
+ </DialogContent>
801
+ <DialogActions>
802
+ <Button onClick={() => setConfirm(null)} disabled={confirmBusy}>
803
+ Cancel
804
+ </Button>
805
+ <Button
806
+ onClick={async () => {
807
+ if (!confirm?.onConfirm) { setConfirm(null); return; }
808
+ setConfirmBusy(true);
809
+ try {
810
+ await confirm.onConfirm();
811
+ } finally {
812
+ setConfirmBusy(false);
813
+ setConfirm(null);
814
+ }
815
+ }}
816
+ color={confirm?.danger ? 'error' : 'primary'}
817
+ variant="contained"
818
+ disabled={confirmBusy}
819
+ >
820
+ {confirmBusy ? (confirm?.busyLabel || 'Working…') : (confirm?.confirmLabel || 'Confirm')}
821
+ </Button>
822
+ </DialogActions>
823
+ </Dialog>
824
+
825
+ {/* Phase 6 β€” post-commit pre-encode dialog. Surfaces after a
826
+ successful Create Dataset commit unless the user previously
827
+ chose "Don't ask again". */}
828
+ <Dialog
829
+ open={preEncodeOffer}
830
+ onClose={() => setPreEncodeOffer(false)}
831
+ maxWidth="xs"
832
+ fullWidth
833
+ >
834
+ <DialogTitle>Pre-encode latents?</DialogTitle>
835
+ <DialogContent>
836
+ <Typography variant="body2" sx={{ mb: 1 }}>
837
+ Encode your audio into SA3 latents now to speed up training. The
838
+ autoencoder runs once up-front instead of every training step.
839
+ </Typography>
840
+ <Typography variant="caption" color="text.secondary">
841
+ Takes a few minutes for ~50 clips. Latents live in
842
+ <code> {project?.name}/.latents/</code> and get wiped automatically
843
+ when you next commit or edit a clip.
844
+ </Typography>
845
+ </DialogContent>
846
+ <DialogActions sx={{ flexWrap: 'wrap' }}>
847
+ <Button
848
+ onClick={() => {
849
+ persistPreEncodeSuppression(true);
850
+ setPreEncodeOffer(false);
851
+ }}
852
+ sx={{ mr: 'auto' }}
853
+ >
854
+ Don't ask again
855
+ </Button>
856
+ <Button onClick={() => setPreEncodeOffer(false)}>Not now</Button>
857
+ <Button
858
+ variant="contained"
859
+ onClick={() => {
860
+ setPreEncodeOffer(false);
861
+ handleStartPreEncode();
862
+ }}
863
+ >
864
+ Pre-encode now
865
+ </Button>
866
+ </DialogActions>
867
+ </Dialog>
868
+
869
+ <Portal>
870
+ <Snackbar
871
+ open={Boolean(notice)}
872
+ autoHideDuration={4000}
873
+ onClose={() => setNotice(null)}
874
+ anchorOrigin={{ vertical: 'bottom', horizontal: 'right' }}
875
+ >
876
+ {notice ? (
877
+ <Alert
878
+ onClose={() => setNotice(null)}
879
+ severity={notice.severity}
880
+ variant="filled"
881
+ sx={{ width: '100%' }}
882
+ >
883
+ {notice.message}
884
+ </Alert>
885
+ ) : undefined}
886
+ </Snackbar>
887
+ </Portal>
888
+ </Stack>
889
+ </Paper>
890
+ );
891
+ }
892
+
893
+ // ---------- subcomponents --------------------------------------------------
894
+
895
+ function ClapVocabAccordion({ disabled }) {
896
+ const [labels, setLabels] = useState({ genre: [], mood: [], instruments: [] });
897
+ const [overridden, setOverridden] = useState(false);
898
+ const [dirty, setDirty] = useState(false);
899
+ const [busy, setBusy] = useState(false);
900
+ const [vocabError, setVocabError] = useState('');
901
+
902
+ useEffect(() => {
903
+ let cancelled = false;
904
+ (async () => {
905
+ try {
906
+ const { data } = await api.get('/api/annotator-labels');
907
+ if (cancelled) return;
908
+ setLabels(data.labels || { genre: [], mood: [], instruments: [] });
909
+ setOverridden(!!data.overridden);
910
+ setDirty(false);
911
+ } catch (e) {
912
+ if (!cancelled) setVocabError(extractError(e, 'Failed to load vocabulary'));
913
+ }
914
+ })();
915
+ return () => { cancelled = true; };
916
+ }, []);
917
+
918
+ function setCategory(cat, values) {
919
+ setLabels((prev) => ({ ...prev, [cat]: values }));
920
+ setDirty(true);
921
+ }
922
+
923
+ async function save() {
924
+ setBusy(true);
925
+ setVocabError('');
926
+ try {
927
+ await api.put('/api/annotator-labels', labels);
928
+ setDirty(false);
929
+ setOverridden(true);
930
+ } catch (e) {
931
+ setVocabError(extractError(e, 'Failed to save vocabulary'));
932
+ } finally {
933
+ setBusy(false);
934
+ }
935
+ }
936
+
937
+ async function reset() {
938
+ if (!window.confirm('Reset vocabulary to the built-in defaults? Your custom tags will be lost.')) return;
939
+ setBusy(true);
940
+ setVocabError('');
941
+ try {
942
+ await api.delete('/api/annotator-labels');
943
+ const { data } = await api.get('/api/annotator-labels');
944
+ setLabels(data.labels || { genre: [], mood: [], instruments: [] });
945
+ setOverridden(false);
946
+ setDirty(false);
947
+ } catch (e) {
948
+ setVocabError(extractError(e, 'Failed to reset vocabulary'));
949
+ } finally {
950
+ setBusy(false);
951
+ }
952
+ }
953
+
954
+ const tagCount = (labels.genre?.length || 0) + (labels.mood?.length || 0) + (labels.instruments?.length || 0);
955
+
956
+ return (
957
+ <Accordion
958
+ disableGutters
959
+ sx={{ '&, &.Mui-expanded': { mt: 0, mb: 0 } }}
960
+ >
961
+ <AccordionSummary
962
+ expandIcon={<ChevronDownIcon size={18} />}
963
+ sx={{
964
+ minHeight: 48,
965
+ '&.Mui-expanded': { minHeight: 48 },
966
+ '& .MuiAccordionSummary-content': {
967
+ margin: '12px 0',
968
+ '&.Mui-expanded': { margin: '12px 0' },
969
+ },
970
+ }}
971
+ >
972
+ <Typography variant="subtitle1">CLAP Vocabulary</Typography>
973
+ <Typography variant="caption" color="text.secondary" sx={{ ml: 1.5, alignSelf: 'center' }}>
974
+ {overridden ? 'custom' : 'defaults'} Β· {tagCount} tags
975
+ </Typography>
976
+ </AccordionSummary>
977
+ <AccordionDetails>
978
+ <Stack spacing={2}>
979
+ <Typography variant="body2" color="text.secondary">
980
+ Words CLAP scores each clip against. Empty categories are ignored. Tweak to match your dataset's territory.
981
+ </Typography>
982
+ <VocabCategory
983
+ label="Genre"
984
+ values={labels.genre || []}
985
+ onChange={(v) => setCategory('genre', v)}
986
+ disabled={disabled || busy}
987
+ />
988
+ <VocabCategory
989
+ label="Mood"
990
+ values={labels.mood || []}
991
+ onChange={(v) => setCategory('mood', v)}
992
+ disabled={disabled || busy}
993
+ />
994
+ <VocabCategory
995
+ label="Instruments"
996
+ values={labels.instruments || []}
997
+ onChange={(v) => setCategory('instruments', v)}
998
+ disabled={disabled || busy}
999
+ />
1000
+ {vocabError && <Alert severity="error" onClose={() => setVocabError('')}>{vocabError}</Alert>}
1001
+ <Box sx={{ display: 'flex', alignItems: 'center', gap: 1.5 }}>
1002
+ <Button
1003
+ variant="text"
1004
+ size="small"
1005
+ onClick={reset}
1006
+ disabled={disabled || busy || !overridden}
1007
+ >
1008
+ Reset to defaults
1009
+ </Button>
1010
+ <Box sx={{ flex: 1 }} />
1011
+ <Button
1012
+ variant="contained"
1013
+ size="small"
1014
+ onClick={save}
1015
+ disabled={disabled || busy || !dirty}
1016
+ >
1017
+ Save vocabulary
1018
+ </Button>
1019
+ </Box>
1020
+ </Stack>
1021
+ </AccordionDetails>
1022
+ </Accordion>
1023
+ );
1024
+ }
1025
+
1026
+ function VocabCategory({ label, values, onChange, disabled }) {
1027
+ return (
1028
+ <Autocomplete
1029
+ multiple
1030
+ freeSolo
1031
+ options={[]}
1032
+ value={values}
1033
+ onChange={(_e, newValues) => onChange(newValues)}
1034
+ disabled={disabled}
1035
+ renderTags={(value, getTagProps) =>
1036
+ value.map((option, index) => {
1037
+ const tagProps = getTagProps({ index });
1038
+ return (
1039
+ <Chip
1040
+ variant="outlined"
1041
+ size="small"
1042
+ label={option}
1043
+ {...tagProps}
1044
+ key={`${option}-${index}`}
1045
+ />
1046
+ );
1047
+ })
1048
+ }
1049
+ renderInput={(params) => (
1050
+ <TextField
1051
+ {...params}
1052
+ label={label}
1053
+ placeholder="Add tag, press Enter"
1054
+ size="small"
1055
+ />
1056
+ )}
1057
+ />
1058
+ );
1059
+ }
1060
+
1061
+ function LoadProjectDialog({ open, projects, currentName, onClose, onLoad, onDeleteProject }) {
1062
+ const [picked, setPicked] = useState(currentName || '');
1063
+
1064
+ useEffect(() => {
1065
+ if (open) setPicked(currentName || (projects[0]?.name ?? ''));
1066
+ }, [open, currentName, projects]);
1067
+
1068
+ return (
1069
+ <Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
1070
+ <DialogTitle>Load project</DialogTitle>
1071
+ <DialogContent>
1072
+ {projects.length === 0 ? (
1073
+ <Typography variant="body2" color="text.secondary" sx={{ py: 2 }}>
1074
+ No projects yet. Create one first.
1075
+ </Typography>
1076
+ ) : (
1077
+ <RadioGroup value={picked} onChange={(e) => setPicked(e.target.value)}>
1078
+ {projects.map((p) => (
1079
+ <Box
1080
+ key={p.name}
1081
+ sx={{ display: 'flex', alignItems: 'center', gap: 1, py: 0.25 }}
1082
+ >
1083
+ <FormControlLabel
1084
+ value={p.name}
1085
+ control={<Radio size="small" />}
1086
+ label={
1087
+ <Box>
1088
+ <Typography variant="body2">{p.name}</Typography>
1089
+ <Typography variant="caption" color="text.secondary">
1090
+ {p.clip_count} clip{p.clip_count === 1 ? '' : 's'}
1091
+ {p.has_draft ? ' Β· has unsaved draft' : ''}
1092
+ </Typography>
1093
+ </Box>
1094
+ }
1095
+ sx={{ alignItems: 'flex-start', flex: 1, mr: 0 }}
1096
+ />
1097
+ <Tooltip title={TIPS.dataset.deleteProject}>
1098
+ <span>
1099
+ <IconButton
1100
+ size="small"
1101
+ sx={{ color: 'text.disabled', '&:hover': { color: 'error.main', bgcolor: 'action.hover' } }}
1102
+ onClick={() => onDeleteProject(p.name)}
1103
+ >
1104
+ <TrashIcon size={16} />
1105
+ </IconButton>
1106
+ </span>
1107
+ </Tooltip>
1108
+ </Box>
1109
+ ))}
1110
+ </RadioGroup>
1111
+ )}
1112
+ </DialogContent>
1113
+ <DialogActions>
1114
+ <Button onClick={onClose}>Cancel</Button>
1115
+ <Button
1116
+ variant="contained"
1117
+ onClick={() => onLoad(picked)}
1118
+ disabled={!picked || projects.length === 0}
1119
+ >
1120
+ Load
1121
+ </Button>
1122
+ </DialogActions>
1123
+ </Dialog>
1124
+ );
1125
+ }
1126
+
1127
+ function ProjectHeader({ project, onSave, onCommit, onDiscard, onAddAudio, disabled }) {
1128
+ const stateLabel = (() => {
1129
+ if (project.dirty && project.has_unsaved_changes) return 'Unsaved changes';
1130
+ if (project.dirty && !project.has_unsaved_changes) return 'Draft saved Β· dataset not created';
1131
+ if (!project.dirty) return 'Dataset created';
1132
+ return '';
1133
+ })();
1134
+ return (
1135
+ <Box sx={{ display: 'flex', alignItems: 'center', gap: 2, flexWrap: 'wrap' }}>
1136
+ <Stack direction="row" spacing={2} alignItems="center" sx={{ flex: 1, minWidth: 240 }}>
1137
+ <Box>
1138
+ <Typography variant="h6">Project: &ldquo;{project.name}&rdquo;</Typography>
1139
+ <Typography variant="body2" color="text.secondary">
1140
+ {project.clip_count} clip{project.clip_count === 1 ? '' : 's'}
1141
+ {' Β· '}{stateLabel}
1142
+ </Typography>
1143
+ </Box>
1144
+ <Button
1145
+ variant="outlined"
1146
+ size="small"
1147
+ startIcon={<MusicIcon size={16} />}
1148
+ onClick={onAddAudio}
1149
+ disabled={disabled}
1150
+ >
1151
+ Add audio
1152
+ </Button>
1153
+ </Stack>
1154
+ <Stack direction="row" spacing={1}>
1155
+ <Tooltip title={TIPS.dataset.discardChanges}>
1156
+ <span>
1157
+ <Button
1158
+ variant="outlined"
1159
+ color="error"
1160
+ size="small"
1161
+ startIcon={<TrashIcon size={16} />}
1162
+ onClick={onDiscard}
1163
+ disabled={disabled || !project.dirty}
1164
+ >
1165
+ Delete
1166
+ </Button>
1167
+ </span>
1168
+ </Tooltip>
1169
+ <Tooltip title={TIPS.dataset.saveDraft}>
1170
+ <span>
1171
+ <Button
1172
+ variant="outlined"
1173
+ size="small"
1174
+ startIcon={<SaveIcon size={16} />}
1175
+ onClick={onSave}
1176
+ disabled={disabled || !project.has_unsaved_changes}
1177
+ >
1178
+ Save
1179
+ </Button>
1180
+ </span>
1181
+ </Tooltip>
1182
+ <Tooltip title={TIPS.dataset.createDataset}>
1183
+ <span>
1184
+ <Button
1185
+ variant="contained"
1186
+ size="small"
1187
+ startIcon={<DatasetIcon size={16} />}
1188
+ onClick={onCommit}
1189
+ disabled={disabled || !project.dirty}
1190
+ >
1191
+ Create Dataset
1192
+ </Button>
1193
+ </span>
1194
+ </Tooltip>
1195
+ </Stack>
1196
+ </Box>
1197
+ );
1198
+ }
1199
+
1200
+ function HealthStrip({ health, onSelectFiles }) {
1201
+ if (!health || health.total_clips === 0) return null;
1202
+ const empty = health.empty_prompts || { count: 0, files: [] };
1203
+ const tooShort = health.too_short || { count: 0, files: [] };
1204
+ const dups = health.duplicate_annotations || { count: 0, group_count: 0, files: [] };
1205
+ const unsupported = health.unsupported_format || { count: 0, accepted: [], files: [] };
1206
+ const issues = empty.count + tooShort.count
1207
+ + dups.count + unsupported.count;
1208
+
1209
+ // Three-tier status driven by the share of unique clips touched by any
1210
+ // health check. A single file showing up in multiple categories only
1211
+ // counts once.
1212
+ const affected = new Set([
1213
+ ...empty.files,
1214
+ ...tooShort.files,
1215
+ ...dups.files,
1216
+ ...unsupported.files,
1217
+ ]);
1218
+ const affectedRatio = health.total_clips > 0 ? affected.size / health.total_clips : 0;
1219
+ let status;
1220
+ if (affected.size === 0) status = 'ok';
1221
+ else if (affectedRatio > 0.5) status = 'bad';
1222
+ else status = 'warn';
1223
+
1224
+ const statusColor = (
1225
+ status === 'ok' ? 'success.main'
1226
+ : status === 'warn' ? 'warm.main'
1227
+ : 'error.main'
1228
+ );
1229
+ const statusText = (
1230
+ status === 'ok'
1231
+ ? `All clean Β· ${health.total_clips} clip${health.total_clips === 1 ? '' : 's'} ready`
1232
+ : `${affected.size} of ${health.total_clips} clip${health.total_clips === 1 ? '' : 's'} flagged`
1233
+ );
1234
+
1235
+ return (
1236
+ <Paper variant="outlined" sx={{ borderRadius: 2.5 }}>
1237
+ <Box sx={{ px: 2, py: 1.25, display: 'flex', alignItems: 'center', gap: 1 }}>
1238
+ <Box component="span" sx={appStyles.sectionCardIcon}>
1239
+ <HealthIcon size={18} />
1240
+ </Box>
1241
+ <Typography variant="subtitle1" sx={{ fontWeight: 500, flex: 1 }}>
1242
+ Dataset health
1243
+ </Typography>
1244
+ <Box
1245
+ sx={{
1246
+ width: 8,
1247
+ height: 8,
1248
+ borderRadius: '50%',
1249
+ bgcolor: statusColor,
1250
+ // Soft halo so the dot reads as a status indicator,
1251
+ // not stray decoration.
1252
+ boxShadow: (theme) =>
1253
+ `0 0 0 3px ${theme.palette.mode === 'dark' ? 'rgba(255,255,255,0.04)' : 'rgba(0,0,0,0.04)'}`,
1254
+ }}
1255
+ />
1256
+ <Typography variant="caption" sx={{ color: statusColor }}>
1257
+ {statusText}
1258
+ </Typography>
1259
+ </Box>
1260
+
1261
+ {issues > 0 && (
1262
+ <Box
1263
+ sx={{
1264
+ px: 2,
1265
+ py: 1.25,
1266
+ borderTop: 1,
1267
+ borderColor: 'divider',
1268
+ display: 'flex',
1269
+ alignItems: 'center',
1270
+ gap: 0.75,
1271
+ flexWrap: 'wrap',
1272
+ }}
1273
+ >
1274
+ {empty.count > 0 && (
1275
+ <Tooltip title={TIPS.dataset.selectClips}>
1276
+ <Chip
1277
+ size="small"
1278
+ variant="outlined"
1279
+ color="warning"
1280
+ label={`${empty.count} empty annotation${empty.count === 1 ? '' : 's'}`}
1281
+ onClick={() => onSelectFiles(empty.files)}
1282
+ />
1283
+ </Tooltip>
1284
+ )}
1285
+ {tooShort.count > 0 && (
1286
+ <Tooltip title={TIPS.dataset.tooShort(tooShort.threshold_sec)}>
1287
+ <Chip
1288
+ size="small"
1289
+ variant="outlined"
1290
+ color="error"
1291
+ label={`${tooShort.count} too short (< ${tooShort.threshold_sec}s)`}
1292
+ onClick={() => onSelectFiles(tooShort.files)}
1293
+ />
1294
+ </Tooltip>
1295
+ )}
1296
+ {dups.count > 0 && (
1297
+ <Tooltip title={TIPS.dataset.duplicates(dups.group_count)}>
1298
+ <Chip
1299
+ size="small"
1300
+ variant="outlined"
1301
+ color="warning"
1302
+ label={`${dups.count} duplicate annotation${dups.count === 1 ? '' : 's'}`}
1303
+ onClick={() => onSelectFiles(dups.files)}
1304
+ />
1305
+ </Tooltip>
1306
+ )}
1307
+ {unsupported.count > 0 && (
1308
+ <Tooltip title={TIPS.dataset.unsupported(unsupported.accepted)}>
1309
+ <Chip
1310
+ size="small"
1311
+ variant="outlined"
1312
+ color="error"
1313
+ label={`${unsupported.count} unsupported format${unsupported.count === 1 ? '' : 's'}`}
1314
+ onClick={() => onSelectFiles(unsupported.files)}
1315
+ />
1316
+ </Tooltip>
1317
+ )}
1318
+ </Box>
1319
+ )}
1320
+ </Paper>
1321
+ );
1322
+ }
1323
+
1324
+ function Waveform({ projectName, fileName, isActive, progress }) {
1325
+ const canvasRef = useRef(null);
1326
+ const theme = useTheme();
1327
+ const [peaks, setPeaks] = useState(null);
1328
+ const [failed, setFailed] = useState(false);
1329
+
1330
+ useEffect(() => {
1331
+ let cancelled = false;
1332
+ setPeaks(null);
1333
+ setFailed(false);
1334
+ if (!projectName || !fileName) return;
1335
+ const url = `/api/projects/${encodeURIComponent(projectName)}/clip/${encodeURIComponent(fileName)}/peaks?n=80`;
1336
+ api.get(url)
1337
+ .then(({ data }) => { if (!cancelled) setPeaks(data?.peaks || []); })
1338
+ .catch(() => { if (!cancelled) setFailed(true); });
1339
+ return () => { cancelled = true; };
1340
+ }, [projectName, fileName]);
1341
+
1342
+ useEffect(() => {
1343
+ const canvas = canvasRef.current;
1344
+ if (!canvas) return;
1345
+ const dpr = window.devicePixelRatio || 1;
1346
+ const w = canvas.clientWidth;
1347
+ const h = canvas.clientHeight;
1348
+ if (canvas.width !== w * dpr) canvas.width = w * dpr;
1349
+ if (canvas.height !== h * dpr) canvas.height = h * dpr;
1350
+ const ctx = canvas.getContext('2d');
1351
+ ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
1352
+ ctx.clearRect(0, 0, w, h);
1353
+
1354
+ if (!peaks || !peaks.length) {
1355
+ ctx.fillStyle = 'rgba(0,0,0,0.08)';
1356
+ const midY = h / 2;
1357
+ ctx.fillRect(0, midY - 0.5, w, 1);
1358
+ return;
1359
+ }
1360
+
1361
+ const barCount = peaks.length;
1362
+ const barWidth = Math.max(1, w / barCount - 1);
1363
+ const playedIdx = isActive ? Math.floor(progress * barCount) : -1;
1364
+ // Match the Generated-Fragments waveforms: teal accent for the played
1365
+ // portion, dimmed (35% alpha) for the rest.
1366
+ const playedColor = '#279FBB';
1367
+ const restColor = '#279FBB59';
1368
+
1369
+ for (let i = 0; i < barCount; i++) {
1370
+ const v = peaks[i];
1371
+ const barH = Math.max(1, v * (h - 2));
1372
+ const x = i * (w / barCount);
1373
+ const y = (h - barH) / 2;
1374
+ ctx.fillStyle = i <= playedIdx ? playedColor : restColor;
1375
+ ctx.fillRect(x, y, barWidth, barH);
1376
+ }
1377
+ }, [peaks, isActive, progress, theme]);
1378
+
1379
+ return (
1380
+ <Box sx={{ width: 120, height: 28, flexShrink: 0, opacity: failed ? 0.3 : 1 }}>
1381
+ <canvas
1382
+ ref={canvasRef}
1383
+ style={{ width: '100%', height: '100%', display: 'block' }}
1384
+ />
1385
+ </Box>
1386
+ );
1387
+ }
1388
+
1389
+ function ClipTable({ projectName, clips, playingFile, playProgress, onPlayToggle, onPromptChange, onAnnotate, onDelete, onSlice, selectedFiles, onToggleSelected, onToggleSelectAll, disabled, toolbar }) {
1390
+ const totalSelected = selectedFiles ? selectedFiles.size : 0;
1391
+ const allSelected = clips && clips.length > 0 && totalSelected === clips.length;
1392
+ const partiallySelected = totalSelected > 0 && !allSelected;
1393
+ if (!clips || clips.length === 0) {
1394
+ return (
1395
+ <Paper variant="outlined" sx={{ borderRadius: 2.5, overflow: 'hidden' }}>
1396
+ {toolbar && (
1397
+ <Box sx={{ px: 1.5, py: 1, borderBottom: 1, borderColor: 'divider' }}>
1398
+ {toolbar}
1399
+ </Box>
1400
+ )}
1401
+ <Box sx={{ py: 4, textAlign: 'center', color: 'text.secondary' }}>
1402
+ <Typography variant="body2">
1403
+ No clips yet. Use β€œAdd audio” to bring in a folder.
1404
+ </Typography>
1405
+ </Box>
1406
+ </Paper>
1407
+ );
1408
+ }
1409
+ return (
1410
+ <Paper variant="outlined" sx={{ borderRadius: 2.5, overflow: 'hidden' }}>
1411
+ {toolbar && (
1412
+ <Box sx={{ px: 1.5, py: 1, borderBottom: 1, borderColor: 'divider' }}>
1413
+ {toolbar}
1414
+ </Box>
1415
+ )}
1416
+ <TableContainer>
1417
+ <Table size="small">
1418
+ <TableHead>
1419
+ <TableRow>
1420
+ <TableCell padding="checkbox">
1421
+ <Checkbox
1422
+ size="small"
1423
+ checked={allSelected}
1424
+ indeterminate={partiallySelected}
1425
+ onChange={onToggleSelectAll}
1426
+ disabled={disabled || clips.length === 0}
1427
+ />
1428
+ </TableCell>
1429
+ <TableCell sx={{ width: '36%' }}>File</TableCell>
1430
+ <TableCell>Annotation</TableCell>
1431
+ <TableCell sx={{ width: 132, textAlign: 'right' }}>Actions</TableCell>
1432
+ </TableRow>
1433
+ </TableHead>
1434
+ <TableBody>
1435
+ {clips.map((c) => (
1436
+ <ClipRow
1437
+ key={c.file_name}
1438
+ projectName={projectName}
1439
+ clip={c}
1440
+ isPlaying={playingFile === c.file_name}
1441
+ playProgress={playingFile === c.file_name ? playProgress : 0}
1442
+ onPlayToggle={onPlayToggle}
1443
+ onPromptChange={onPromptChange}
1444
+ onAnnotate={onAnnotate}
1445
+ onDelete={onDelete}
1446
+ onSlice={onSlice}
1447
+ selected={selectedFiles ? selectedFiles.has(c.file_name) : false}
1448
+ onToggleSelected={onToggleSelected}
1449
+ disabled={disabled}
1450
+ />
1451
+ ))}
1452
+ </TableBody>
1453
+ </Table>
1454
+ </TableContainer>
1455
+ </Paper>
1456
+ );
1457
+ }
1458
+
1459
+ // React.memo so the 60Hz audio-playhead ticks don't reconcile every row in
1460
+ // the table. Custom comparator: skip if visual props didn't change. Callback
1461
+ // identity intentionally ignored β€” they're stable in behavior, just inline
1462
+ // arrows from the parent, and re-creating a row only to re-bind a click
1463
+ // handler isn't worth the work. playProgress only matters on the active row.
1464
+ const ClipRow = React.memo(function ClipRow({ projectName, clip, isPlaying, playProgress, onPlayToggle, onPromptChange, onAnnotate, onDelete, onSlice, selected, onToggleSelected, disabled }) {
1465
+ const [draft, setDraft] = useState(clip.prompt);
1466
+ useEffect(() => { setDraft(clip.prompt); }, [clip.prompt]);
1467
+
1468
+ const dirty = draft !== clip.prompt;
1469
+
1470
+ return (
1471
+ <TableRow hover selected={selected}>
1472
+ <TableCell padding="checkbox">
1473
+ <Checkbox
1474
+ size="small"
1475
+ checked={!!selected}
1476
+ onChange={() => onToggleSelected && onToggleSelected(clip.file_name)}
1477
+ />
1478
+ </TableCell>
1479
+ <TableCell sx={{ wordBreak: 'break-all' }}>
1480
+ <Stack direction="row" alignItems="center" spacing={1}>
1481
+ <IconButton
1482
+ size="small"
1483
+ onClick={() => onPlayToggle(clip.file_name)}
1484
+ sx={{ width: 28, height: 28 }}
1485
+ >
1486
+ {isPlaying ? <PauseIcon size={14} /> : <PlayIcon size={14} />}
1487
+ </IconButton>
1488
+ <Waveform
1489
+ projectName={projectName}
1490
+ fileName={clip.file_name}
1491
+ isActive={isPlaying}
1492
+ progress={playProgress}
1493
+ />
1494
+ <Typography variant="body2" sx={{ flex: 1, minWidth: 0, wordBreak: 'break-all' }}>
1495
+ {clip.file_name}
1496
+ </Typography>
1497
+ </Stack>
1498
+ </TableCell>
1499
+ <TableCell>
1500
+ <TextField
1501
+ fullWidth
1502
+ size="small"
1503
+ variant="standard"
1504
+ value={draft}
1505
+ onChange={(e) => setDraft(e.target.value)}
1506
+ onBlur={() => { if (dirty) onPromptChange(clip.file_name, draft); }}
1507
+ placeholder="(empty β€” write a prompt or auto-annotate)"
1508
+ disabled={disabled}
1509
+ />
1510
+ </TableCell>
1511
+ <TableCell sx={{ textAlign: 'right', whiteSpace: 'nowrap' }}>
1512
+ <Tooltip title={TIPS.dataset.autoAnnotateClip}>
1513
+ <span>
1514
+ <IconButton
1515
+ size="small"
1516
+ onClick={() => onAnnotate(clip.file_name)}
1517
+ disabled={disabled}
1518
+ sx={{ color: 'warm.main', '&:hover': { color: 'warm.light', bgcolor: 'action.hover' } }}
1519
+ >
1520
+ <WandSparkles size={16} />
1521
+ </IconButton>
1522
+ </span>
1523
+ </Tooltip>
1524
+ <Tooltip title={TIPS.dataset.sliceClip}>
1525
+ <span>
1526
+ <IconButton
1527
+ size="small"
1528
+ onClick={() => onSlice(clip.file_name)}
1529
+ disabled={disabled}
1530
+ >
1531
+ <ScissorsIcon size={16} />
1532
+ </IconButton>
1533
+ </span>
1534
+ </Tooltip>
1535
+ <Tooltip title={TIPS.dataset.removeClip}>
1536
+ <span>
1537
+ <IconButton
1538
+ size="small"
1539
+ onClick={() => onDelete(clip.file_name)}
1540
+ disabled={disabled}
1541
+ >
1542
+ <TrashIcon size={16} />
1543
+ </IconButton>
1544
+ </span>
1545
+ </Tooltip>
1546
+ </TableCell>
1547
+ </TableRow>
1548
+ );
1549
+ }, (prev, next) => {
1550
+ if (prev.clip !== next.clip) return false;
1551
+ if (prev.disabled !== next.disabled) return false;
1552
+ if (prev.projectName !== next.projectName) return false;
1553
+ if (prev.isPlaying !== next.isPlaying) return false;
1554
+ if (prev.selected !== next.selected) return false;
1555
+ // playProgress only matters when this row is the active one β€” inactive
1556
+ // rows always receive playProgress=0 from the parent, so they're skipped.
1557
+ if (next.isPlaying && prev.playProgress !== next.playProgress) return false;
1558
+ return true;
1559
+ });
1560
+
1561
+ function CreateProjectDialog({ open, existingNames, onClose, onCreated }) {
1562
+ const [name, setName] = useState('');
1563
+ const [busy, setBusy] = useState(false);
1564
+ const [dialogError, setDialogError] = useState('');
1565
+
1566
+ useEffect(() => {
1567
+ if (open) { setName(''); setDialogError(''); }
1568
+ }, [open]);
1569
+
1570
+ const duplicate = existingNames.includes(name.trim());
1571
+
1572
+ async function submit() {
1573
+ setDialogError('');
1574
+ setBusy(true);
1575
+ try {
1576
+ const { data } = await api.post('/api/projects', { name: name.trim() });
1577
+ await onCreated(data.name);
1578
+ } catch (e) {
1579
+ setDialogError(extractError(e, 'Failed to create project'));
1580
+ } finally {
1581
+ setBusy(false);
1582
+ }
1583
+ }
1584
+
1585
+ return (
1586
+ <Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
1587
+ <DialogTitle>New project</DialogTitle>
1588
+ <DialogContent>
1589
+ <Stack spacing={2} sx={{ pt: 1 }}>
1590
+ <TextField
1591
+ autoFocus
1592
+ label="Project name"
1593
+ value={name}
1594
+ onChange={(e) => setName(e.target.value)}
1595
+ helperText="Letters, digits, spaces, dashes, underscores, dots. Becomes a folder name on disk."
1596
+ error={duplicate}
1597
+ />
1598
+ {duplicate && (
1599
+ <Typography variant="caption" color="error">
1600
+ A project with this name already exists.
1601
+ </Typography>
1602
+ )}
1603
+ {dialogError && <Alert severity="error">{dialogError}</Alert>}
1604
+ </Stack>
1605
+ </DialogContent>
1606
+ <DialogActions>
1607
+ <Button onClick={onClose} disabled={busy}>Cancel</Button>
1608
+ <Button
1609
+ variant="contained"
1610
+ onClick={submit}
1611
+ disabled={busy || !name.trim() || duplicate}
1612
+ >
1613
+ Create
1614
+ </Button>
1615
+ </DialogActions>
1616
+ </Dialog>
1617
+ );
1618
+ }
1619
+
1620
+ function IngestDialog({ open, projectName, onClose, onIngested }) {
1621
+ const [folder, setFolder] = useState('');
1622
+ const [mode, setMode] = useState('copy');
1623
+ const [busy, setBusy] = useState(false);
1624
+ const [dialogError, setDialogError] = useState('');
1625
+
1626
+ useEffect(() => {
1627
+ if (open) { setFolder(''); setMode('copy'); setDialogError(''); }
1628
+ }, [open]);
1629
+
1630
+ async function pick() {
1631
+ try {
1632
+ const { data } = await api.post('/api/pick-folder', {});
1633
+ if (data?.path) setFolder(data.path);
1634
+ } catch (e) {
1635
+ setDialogError(extractError(e, 'Folder picker failed'));
1636
+ }
1637
+ }
1638
+
1639
+ async function submit() {
1640
+ if (!projectName) return;
1641
+ setBusy(true);
1642
+ setDialogError('');
1643
+ try {
1644
+ await api.post(
1645
+ `/api/projects/${encodeURIComponent(projectName)}/ingest`,
1646
+ { folder_path: folder, mode },
1647
+ );
1648
+ await onIngested();
1649
+ } catch (e) {
1650
+ setDialogError(extractError(e, 'Ingest failed'));
1651
+ } finally {
1652
+ setBusy(false);
1653
+ }
1654
+ }
1655
+
1656
+ return (
1657
+ <Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
1658
+ <DialogTitle>Add audio to {projectName}</DialogTitle>
1659
+ <DialogContent>
1660
+ <Stack spacing={2} sx={{ pt: 1 }}>
1661
+ <Stack direction="row" spacing={1.5} alignItems="center">
1662
+ <Button variant="outlined" startIcon={<FolderOpenIcon size={18} />} onClick={pick}>
1663
+ Pick folder
1664
+ </Button>
1665
+ <Typography variant="body2" color="text.secondary" sx={{ wordBreak: 'break-all' }}>
1666
+ {folder || 'No folder selected'}
1667
+ </Typography>
1668
+ </Stack>
1669
+
1670
+ <FormControl>
1671
+ <Typography variant="body2" gutterBottom>How to bring the audio in:</Typography>
1672
+ <RadioGroup value={mode} onChange={(e) => setMode(e.target.value)}>
1673
+ <FormControlLabel
1674
+ value="copy"
1675
+ control={<Radio size="small" />}
1676
+ label={<Typography variant="body2">Copy β€” duplicates audio into the project (safe, originals untouched)</Typography>}
1677
+ />
1678
+ <FormControlLabel
1679
+ value="symlink"
1680
+ control={<Radio size="small" />}
1681
+ label={<Typography variant="body2">Symlink β€” points at the originals (saves disk, breaks if you move them)</Typography>}
1682
+ />
1683
+ </RadioGroup>
1684
+ </FormControl>
1685
+
1686
+ {dialogError && <Alert severity="error">{dialogError}</Alert>}
1687
+ </Stack>
1688
+ </DialogContent>
1689
+ <DialogActions>
1690
+ <Button onClick={onClose} disabled={busy}>Cancel</Button>
1691
+ <Button variant="contained" onClick={submit} disabled={busy || !folder}>
1692
+ {busy ? 'Adding…' : 'Add'}
1693
+ </Button>
1694
+ </DialogActions>
1695
+ </Dialog>
1696
+ );
1697
+ }
1698
+
1699
+ function SliceDialog({ open, projectName, fileName, onClose, onSliced }) {
1700
+ const [target, setTarget] = useState(30);
1701
+ const [overlap, setOverlap] = useState(0);
1702
+ const [strategy, setStrategy] = useState('hard');
1703
+ const [duration, setDuration] = useState(null);
1704
+ const [busy, setBusy] = useState(false);
1705
+ const [dialogError, setDialogError] = useState('');
1706
+
1707
+ useEffect(() => {
1708
+ if (!open) return;
1709
+ setTarget(30);
1710
+ setOverlap(0);
1711
+ setStrategy('hard');
1712
+ setDialogError('');
1713
+ setDuration(null);
1714
+ if (!projectName || !fileName) return;
1715
+ // Reuse the peaks endpoint to pull duration cheaply (cached server-side).
1716
+ api.get(`/api/projects/${encodeURIComponent(projectName)}/clip/${encodeURIComponent(fileName)}/peaks?n=20`)
1717
+ .then(({ data }) => setDuration(data?.duration || null))
1718
+ .catch(() => setDuration(null));
1719
+ }, [open, projectName, fileName]);
1720
+
1721
+ const stepSec = Math.max(0.5, target - overlap);
1722
+ const estChildren = duration && target > 0 ? Math.max(1, Math.ceil(duration / stepSec)) : null;
1723
+ const tooShort = duration !== null && duration <= target;
1724
+
1725
+ async function submit() {
1726
+ setBusy(true);
1727
+ setDialogError('');
1728
+ try {
1729
+ await api.post(
1730
+ `/api/projects/${encodeURIComponent(projectName)}/clip/${encodeURIComponent(fileName)}/slice`,
1731
+ { target_duration: target, overlap_sec: overlap, strategy },
1732
+ );
1733
+ await onSliced();
1734
+ onClose();
1735
+ } catch (e) {
1736
+ setDialogError(extractError(e, 'Slice failed'));
1737
+ } finally {
1738
+ setBusy(false);
1739
+ }
1740
+ }
1741
+
1742
+ return (
1743
+ <Dialog open={open} onClose={busy ? undefined : onClose} maxWidth="sm" fullWidth>
1744
+ <DialogTitle>Slice {fileName || ''}</DialogTitle>
1745
+ <DialogContent>
1746
+ <Stack spacing={2.5} sx={{ pt: 1 }}>
1747
+ <Typography variant="body2" color="text.secondary">
1748
+ The original file will be replaced by the children on disk. Children inherit this clip's annotation. They stay in the project until you Create Dataset (Delete reverts them).
1749
+ </Typography>
1750
+
1751
+ <Stack direction="row" spacing={2}>
1752
+ <TextField
1753
+ label="Target duration (sec)"
1754
+ type="number"
1755
+ size="small"
1756
+ value={target}
1757
+ onChange={(e) => setTarget(Math.max(0.5, parseFloat(e.target.value) || 0))}
1758
+ inputProps={{ step: 0.5, min: 0.5, max: 60 }}
1759
+ fullWidth
1760
+ />
1761
+ <TextField
1762
+ label="Overlap (sec)"
1763
+ type="number"
1764
+ size="small"
1765
+ value={overlap}
1766
+ onChange={(e) => setOverlap(Math.max(0, parseFloat(e.target.value) || 0))}
1767
+ inputProps={{ step: 0.1, min: 0, max: Math.max(0, target - 0.5) }}
1768
+ fullWidth
1769
+ helperText="Head-overlap on every child after the first"
1770
+ />
1771
+ </Stack>
1772
+
1773
+ <FormControl>
1774
+ <Typography variant="body2" gutterBottom>Where each cut should land:</Typography>
1775
+ <RadioGroup value={strategy} onChange={(e) => setStrategy(e.target.value)}>
1776
+ <FormControlLabel
1777
+ value="hard"
1778
+ control={<Radio size="small" />}
1779
+ label={<Typography variant="body2">Hard cut β€” exact intervals; fastest, can split mid-note</Typography>}
1780
+ />
1781
+ <FormControlLabel
1782
+ value="transient"
1783
+ control={<Radio size="small" />}
1784
+ label={<Typography variant="body2">Transient-aware β€” snaps each cut to the nearest onset (good for drums / rhythmic)</Typography>}
1785
+ />
1786
+ <FormControlLabel
1787
+ value="silence"
1788
+ control={<Radio size="small" />}
1789
+ label={<Typography variant="body2">Silence-aware β€” snaps to the quietest moment in each window (good for melodic / phrased)</Typography>}
1790
+ />
1791
+ </RadioGroup>
1792
+ </FormControl>
1793
+
1794
+ {duration !== null && (
1795
+ <Typography variant="caption" color="text.secondary">
1796
+ Source: {duration.toFixed(1)}s
1797
+ {estChildren !== null && !tooShort && ` Β· ~${estChildren} children at this setting`}
1798
+ {tooShort && ' Β· already shorter than the target β€” nothing to slice'}
1799
+ </Typography>
1800
+ )}
1801
+
1802
+ {dialogError && <Alert severity="error">{dialogError}</Alert>}
1803
+ </Stack>
1804
+ </DialogContent>
1805
+ <DialogActions>
1806
+ <Button onClick={onClose} disabled={busy}>Cancel</Button>
1807
+ <Button
1808
+ variant="contained"
1809
+ onClick={submit}
1810
+ disabled={busy || tooShort || target <= 0 || overlap >= target}
1811
+ >
1812
+ {busy ? 'Slicing…' : 'Slice'}
1813
+ </Button>
1814
+ </DialogActions>
1815
+ </Dialog>
1816
+ );
1817
+ }
1818
+
1819
+ // ---------- utils ----------------------------------------------------------
1820
+
1821
+ function extractError(e, fallback) {
1822
+ return e?.response?.data?.error || e?.message || fallback;
1823
+ }
app/frontend/src/components/EditPanel.js ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useRef, useEffect } from 'react';
2
+ import {
3
+ Box,
4
+ Typography,
5
+ Button,
6
+ Stack,
7
+ TextField,
8
+ ToggleButton,
9
+ ToggleButtonGroup,
10
+ Slider,
11
+ Alert,
12
+ LinearProgress,
13
+ IconButton,
14
+ Switch,
15
+ FormControlLabel,
16
+ } from '@mui/material';
17
+ import { Upload as UploadIcon, X as ClearIcon, Play as PlayIcon, Square as StopIcon } from 'lucide-react';
18
+ import api from '../api';
19
+ import AudioWaveform from './AudioWaveform';
20
+ import { getFragmentDragPayload } from '../utils/fragmentDrag';
21
+
22
+ /**
23
+ * SA3 audio-to-audio + inpainting UI.
24
+ *
25
+ * Three modes:
26
+ * - Style transfer: feed a source clip + new prompt, init_noise_level
27
+ * controls how much character is preserved (0 = source-faithful,
28
+ * 1 = prompt-only).
29
+ * - Inpaint: regenerate a region of the source clip, keeping the rest.
30
+ * - Extend: append N seconds of new audio to the end of the source.
31
+ *
32
+ * All three send to /api/generate using SA3's init_audio / inpaint_audio
33
+ * params. The backend handles file resolution; this panel just uploads
34
+ * the source clip to /api/audio/upload and posts the returned path.
35
+ *
36
+ * Props:
37
+ * model_id: active SA3 model id
38
+ * negativePrompt: optional, passed through
39
+ * loraStack: [{path, strength, bypassed}] from the Generation panel β€”
40
+ * applied to the edit so style/inpaint/extend inherit the
41
+ * same LoRA character as plain generation.
42
+ * steps: sampler step count from the Generation panel.
43
+ * cfgScale: CFG from the Generation panel (only sent for *-base models;
44
+ * distilled models bake CFG at 1.0).
45
+ * onGenerated(blob, filename, params): called with the resulting WAV
46
+ */
47
+ export default function EditPanel({ model_id, negativePrompt, loraStack, steps, cfgScale, onGenerated }) {
48
+ const [mode, setMode] = useState('style'); // 'style' | 'inpaint' | 'extend'
49
+ const [sourcePath, setSourcePath] = useState('');
50
+ const [sourceName, setSourceName] = useState('');
51
+ const [sourceFile, setSourceFile] = useState(null); // kept for in-browser decode (waveform)
52
+ const [sourceUploading, setSourceUploading] = useState(false);
53
+ const [dropActive, setDropActive] = useState(false);
54
+ const [prompt, setPrompt] = useState('');
55
+ const [duration, setDuration] = useState(8);
56
+ // Seed: random by default, mirroring the rest of the app. When off, the
57
+ // numeric field is honoured (0 included β€” a legitimate seed).
58
+ const [randomSeed, setRandomSeed] = useState(true);
59
+ const [seedValue, setSeedValue] = useState('');
60
+
61
+ // sa3-medium generates up to 380s; small models cap at 120s. Matches the
62
+ // generator's _MODEL_INFO so the slider can't request past the model max.
63
+ const maxDuration = (model_id || '').includes('medium') ? 380 : 120;
64
+ // Distilled (post-trained) models bake CFG at 1.0 and ignore cfg_scale; only
65
+ // *-base variants honour it. Same rule the Generation panel uses.
66
+ const isDistilledBase =
67
+ !!model_id && model_id.startsWith('sa3-') && !model_id.endsWith('-base');
68
+
69
+ // style transfer
70
+ const [initNoiseLevel, setInitNoiseLevel] = useState(0.7);
71
+
72
+ // inpaint
73
+ const [maskStart, setMaskStart] = useState(2.0);
74
+ const [maskEnd, setMaskEnd] = useState(4.0);
75
+
76
+ // extend
77
+ const [extendSeconds, setExtendSeconds] = useState(4.0);
78
+ const [sourceDurationSec, setSourceDurationSec] = useState(null);
79
+
80
+ const [generating, setGenerating] = useState(false);
81
+ const [error, setError] = useState(null);
82
+ const fileInputRef = useRef(null);
83
+
84
+ // Inpaint region audition β€” a hidden <audio> set to the source clip, played
85
+ // from maskStart and auto-stopped at maskEnd, so users can hear the segment
86
+ // they're about to regenerate before committing.
87
+ const regionAudioRef = useRef(null);
88
+ const regionStopRef = useRef(null); // removes the active timeupdate guard
89
+ const [regionUrl, setRegionUrl] = useState(null);
90
+ const [regionPlaying, setRegionPlaying] = useState(false);
91
+
92
+ useEffect(() => {
93
+ if (!sourceFile) { setRegionUrl(null); return undefined; }
94
+ const url = URL.createObjectURL(sourceFile);
95
+ setRegionUrl(url);
96
+ return () => URL.revokeObjectURL(url);
97
+ }, [sourceFile]);
98
+
99
+ // Stop any in-flight preview when the source changes or the mode switches
100
+ // away from inpaint (don't auto-stop on every region drag β€” the end is
101
+ // captured per play, so dragging mid-play just runs to the old boundary).
102
+ useEffect(() => {
103
+ const a = regionAudioRef.current;
104
+ if (a) { try { a.pause(); } catch { /* ignore */ } }
105
+ regionStopRef.current?.();
106
+ regionStopRef.current = null;
107
+ setRegionPlaying(false);
108
+ }, [regionUrl, mode]);
109
+
110
+ const toggleRegionPreview = () => {
111
+ const a = regionAudioRef.current;
112
+ if (!a || !regionUrl) return;
113
+ if (regionPlaying) {
114
+ a.pause();
115
+ regionStopRef.current?.();
116
+ regionStopRef.current = null;
117
+ setRegionPlaying(false);
118
+ return;
119
+ }
120
+ const start = Math.max(0, Number(maskStart) || 0);
121
+ const end = Math.max(start + 0.05, Number(maskEnd) || 0);
122
+ const onTime = () => {
123
+ if (a.currentTime >= end) {
124
+ a.pause();
125
+ a.removeEventListener('timeupdate', onTime);
126
+ regionStopRef.current = null;
127
+ setRegionPlaying(false);
128
+ }
129
+ };
130
+ try { a.currentTime = start; } catch { /* ignore */ }
131
+ a.addEventListener('timeupdate', onTime);
132
+ regionStopRef.current = () => a.removeEventListener('timeupdate', onTime);
133
+ a.play()
134
+ .then(() => setRegionPlaying(true))
135
+ .catch(() => {
136
+ a.removeEventListener('timeupdate', onTime);
137
+ regionStopRef.current = null;
138
+ setRegionPlaying(false);
139
+ });
140
+ };
141
+
142
+ // --- source upload ---------------------------------------------------
143
+ const onPickFile = () => fileInputRef.current?.click();
144
+ const uploadFile = async (f) => {
145
+ if (!f) return;
146
+ setSourceUploading(true);
147
+ setError(null);
148
+ try {
149
+ const form = new FormData();
150
+ form.append('file', f);
151
+ const r = await api.post('/api/audio/upload', form);
152
+ setSourcePath(r.data.path);
153
+ setSourceName(r.data.name);
154
+ setSourceFile(f); // keep for in-browser waveform decode
155
+ // Probe duration via a temp object URL β†’ <audio>.
156
+ const url = URL.createObjectURL(f);
157
+ const a = new Audio(url);
158
+ a.addEventListener('loadedmetadata', () => {
159
+ if (Number.isFinite(a.duration)) {
160
+ setSourceDurationSec(a.duration);
161
+ // Default the output length to the source length (clamped to
162
+ // the model max). For inpaint this is mandatory β€” the mask is
163
+ // measured in source seconds, so the output must be the same
164
+ // length or the masked region drifts off the audio you see.
165
+ setDuration(Math.max(1, Math.min(maxDuration, Math.round(a.duration))));
166
+ // Seed inpaint region to the middle quarter so the
167
+ // waveform shows something sensible without a 4 s default
168
+ // landing past the end of short clips.
169
+ const q = a.duration / 4;
170
+ setMaskStart(Math.max(0, q));
171
+ setMaskEnd(Math.min(a.duration, q * 3));
172
+ }
173
+ URL.revokeObjectURL(url);
174
+ }, { once: true });
175
+ } catch (err) {
176
+ setError(err.response?.data?.error?.message || err.message || 'Upload failed');
177
+ } finally {
178
+ setSourceUploading(false);
179
+ }
180
+ };
181
+ const onFileChange = async (e) => {
182
+ const f = e.target.files?.[0];
183
+ e.target.value = '';
184
+ await uploadFile(f);
185
+ };
186
+ // Pull a fragment already on disk (dragged in from the Generated
187
+ // Fragments window) and run it through the same upload path so it gets a
188
+ // server path + waveform + duration probe, exactly like a picked file.
189
+ const loadFragmentByName = async (filename) => {
190
+ if (!filename) return;
191
+ setSourceUploading(true);
192
+ setError(null);
193
+ try {
194
+ const r = await api.get(`/api/fragments/${encodeURIComponent(filename)}`, { responseType: 'blob' });
195
+ const file = new File([r.data], filename, { type: r.data.type || 'audio/wav' });
196
+ await uploadFile(file);
197
+ } catch (err) {
198
+ setError(err.response?.data?.error?.message || err.message || 'Could not load fragment');
199
+ setSourceUploading(false);
200
+ }
201
+ };
202
+
203
+ const onDrop = async (e) => {
204
+ e.preventDefault();
205
+ setDropActive(false);
206
+ // In-app drag from the Generated Fragments window carries the
207
+ // fragment filename; OS file drags carry dataTransfer.files. Read the
208
+ // custom payload synchronously before any await.
209
+ const fragName = e.dataTransfer.getData('application/x-fragmenta-fragment');
210
+ if (fragName) {
211
+ // Prefer the in-memory blob handed off on dragStart β€” no disk
212
+ // round-trip, and immune to any in-memory vs on-disk name mismatch.
213
+ const payload = getFragmentDragPayload();
214
+ if (payload?.blob && payload.filename === fragName) {
215
+ const file = new File([payload.blob], fragName || 'fragment.wav', {
216
+ type: payload.blob.type || 'audio/wav',
217
+ });
218
+ await uploadFile(file);
219
+ } else {
220
+ // Fallback: blob wasn't preloaded β€” fetch it from disk by name.
221
+ await loadFragmentByName(fragName);
222
+ }
223
+ return;
224
+ }
225
+ const f = e.dataTransfer.files?.[0];
226
+ await uploadFile(f);
227
+ };
228
+ const onDragOver = (e) => { e.preventDefault(); setDropActive(true); };
229
+ const onDragLeave = (e) => { e.preventDefault(); setDropActive(false); };
230
+ const clearSource = () => {
231
+ setSourcePath('');
232
+ setSourceName('');
233
+ setSourceFile(null);
234
+ setSourceDurationSec(null);
235
+ };
236
+
237
+ // --- generate --------------------------------------------------------
238
+ const generate = async () => {
239
+ if (!model_id) {
240
+ setError('Pick a model in the Generation tab first.');
241
+ return;
242
+ }
243
+ if (!sourcePath) {
244
+ setError('Upload a source clip first.');
245
+ return;
246
+ }
247
+ if (!prompt.trim() && mode !== 'extend') {
248
+ setError('Enter a prompt describing the change.');
249
+ return;
250
+ }
251
+
252
+ setGenerating(true);
253
+ setError(null);
254
+ try {
255
+ // Seed: -1 lets the backend pick (and record) a random one; an
256
+ // explicit value is parsed with parseInt so 0 stays 0 rather than
257
+ // collapsing to random via `|| -1`.
258
+ let seedToSend = -1;
259
+ if (!randomSeed) {
260
+ const parsed = parseInt(seedValue, 10);
261
+ if (Number.isNaN(parsed) || parsed < 0) {
262
+ setError('Enter a non-negative integer seed, or switch Seed to Random.');
263
+ setGenerating(false);
264
+ return;
265
+ }
266
+ seedToSend = parsed;
267
+ }
268
+
269
+ const body = {
270
+ model_id,
271
+ prompt: prompt.trim() || 'continue',
272
+ duration,
273
+ seed: seedToSend,
274
+ steps,
275
+ };
276
+ if (negativePrompt) body.negative_prompt = negativePrompt;
277
+ // Only base models honour CFG; sending it on a distilled model is
278
+ // harmless (backend forces 1.0) but we keep the UI honest.
279
+ if (!isDistilledBase) body.cfg_scale = cfgScale;
280
+ // Inherit the Generation panel's LoRA stack. Bypassed slots stay in
281
+ // load order but contribute strength 0 (same as plain generation).
282
+ const activeLoras = (loraStack || [])
283
+ .filter((s) => s.path)
284
+ .map((s) => ({ path: s.path, strength: s.bypassed ? 0 : s.strength }));
285
+ if (activeLoras.length) body.loras = activeLoras;
286
+
287
+ if (mode === 'style') {
288
+ body.init_audio_path = sourcePath;
289
+ body.init_noise_level = initNoiseLevel;
290
+ } else if (mode === 'inpaint') {
291
+ // Pin output length to the source so the mask (measured in
292
+ // source seconds) maps onto the same timeline the user sees.
293
+ if (!Number.isFinite(sourceDurationSec)) {
294
+ setError("Couldn't read source duration β€” re-upload the file.");
295
+ setGenerating(false);
296
+ return;
297
+ }
298
+ body.duration = sourceDurationSec;
299
+ body.inpaint_audio_path = sourcePath;
300
+ body.inpaint_starts = [Number(maskStart)];
301
+ body.inpaint_ends = [Number(maskEnd)];
302
+ } else if (mode === 'extend') {
303
+ // Extend = inpaint where the mask is the new tail. Total clip
304
+ // duration = source length + extendSeconds; mask covers
305
+ // [source_length, source_length + extendSeconds].
306
+ if (!Number.isFinite(sourceDurationSec)) {
307
+ setError("Couldn't read source duration β€” re-upload the file.");
308
+ setGenerating(false);
309
+ return;
310
+ }
311
+ body.duration = sourceDurationSec + extendSeconds;
312
+ body.inpaint_audio_path = sourcePath;
313
+ body.inpaint_starts = [sourceDurationSec];
314
+ body.inpaint_ends = [sourceDurationSec + extendSeconds];
315
+ }
316
+
317
+ const resp = await api.post('/api/generate', body, { responseType: 'blob' });
318
+ // Use the backend's real on-disk name (header) so the fragment in
319
+ // the list resolves to an actual file for reveal/delete; only fall
320
+ // back to a synthetic name if the header is absent.
321
+ const fname = resp.headers?.['x-fragment-filename'] || `${mode}_${Date.now()}.wav`;
322
+ // Record the resolved seed (the backend picks a concrete one when we
323
+ // sent -1) so the fragment shows the real value, not "random".
324
+ const resolvedSeed = parseInt(resp.headers?.['x-fragment-seed'], 10);
325
+ const params = Number.isFinite(resolvedSeed) ? { ...body, seed: resolvedSeed } : body;
326
+ onGenerated?.(resp.data, fname, params);
327
+ } catch (err) {
328
+ setError(err.response?.data?.error?.message || err.message || 'Generation failed');
329
+ } finally {
330
+ setGenerating(false);
331
+ }
332
+ };
333
+
334
+ // --- render ----------------------------------------------------------
335
+ return (
336
+ <Box sx={{ p: 2 }}>
337
+ {/* Source picker (drag-and-drop or click) */}
338
+ <Box
339
+ sx={{ mb: 2 }}
340
+ onDragOver={onDragOver}
341
+ onDragLeave={onDragLeave}
342
+ onDrop={onDrop}
343
+ >
344
+ <Typography variant="caption" color="text.secondary" display="block" sx={{ mb: 0.5 }}>
345
+ Source clip
346
+ </Typography>
347
+ {sourcePath ? (
348
+ <Stack
349
+ direction="row"
350
+ alignItems="center"
351
+ spacing={1}
352
+ sx={{
353
+ p: 1,
354
+ border: '1px dashed',
355
+ borderColor: dropActive ? 'primary.main' : 'divider',
356
+ borderRadius: 1,
357
+ transition: 'border-color 120ms',
358
+ }}
359
+ >
360
+ <Typography variant="body2" sx={{ flex: 1, fontFamily: 'monospace', fontSize: 12, overflow: 'hidden', textOverflow: 'ellipsis' }}>
361
+ {sourceName}
362
+ {sourceDurationSec && ` Β· ${sourceDurationSec.toFixed(2)}s`}
363
+ </Typography>
364
+ <IconButton size="small" onClick={clearSource} aria-label="Remove source"><ClearIcon size={14} /></IconButton>
365
+ </Stack>
366
+ ) : (
367
+ <Button
368
+ variant="outlined"
369
+ startIcon={<UploadIcon size={14} />}
370
+ onClick={onPickFile}
371
+ disabled={sourceUploading}
372
+ fullWidth
373
+ sx={{
374
+ borderStyle: 'dashed',
375
+ borderColor: dropActive ? 'primary.main' : undefined,
376
+ bgcolor: dropActive ? 'action.hover' : undefined,
377
+ transition: 'border-color 120ms, background-color 120ms',
378
+ }}
379
+ >
380
+ {sourceUploading ? 'Uploading…' : 'Drop a clip here, or click to pick a file'}
381
+ </Button>
382
+ )}
383
+ <input
384
+ ref={fileInputRef}
385
+ type="file"
386
+ accept=".wav,.mp3,.flac,.m4a,.ogg,.opus,audio/*"
387
+ style={{ display: 'none' }}
388
+ onChange={onFileChange}
389
+ />
390
+ </Box>
391
+
392
+ {/* Mode selector */}
393
+ <ToggleButtonGroup
394
+ value={mode}
395
+ exclusive
396
+ size="small"
397
+ onChange={(_, v) => v && setMode(v)}
398
+ sx={{ mb: 2 }}
399
+ >
400
+ <ToggleButton value="style">Style transfer</ToggleButton>
401
+ <ToggleButton value="inpaint">Inpaint region</ToggleButton>
402
+ <ToggleButton value="extend">Extend</ToggleButton>
403
+ </ToggleButtonGroup>
404
+
405
+ {/* Mode-specific controls */}
406
+ {mode === 'style' && (
407
+ <Box sx={{ mb: 2 }}>
408
+ <Typography variant="caption" color="text.secondary">
409
+ Preserve source character ←→ follow prompt
410
+ </Typography>
411
+ <Stack direction="row" alignItems="center" spacing={2}>
412
+ <Slider
413
+ value={initNoiseLevel}
414
+ onChange={(_, v) => setInitNoiseLevel(v)}
415
+ min={0}
416
+ max={1}
417
+ step={0.05}
418
+ valueLabelDisplay="auto"
419
+ marks={[
420
+ { value: 0, label: '0' },
421
+ { value: 0.5, label: '0.5' },
422
+ { value: 1, label: '1' },
423
+ ]}
424
+ sx={{ flex: 1 }}
425
+ />
426
+ <Typography variant="body2" sx={{ width: 40, textAlign: 'right' }}>
427
+ {initNoiseLevel.toFixed(2)}
428
+ </Typography>
429
+ </Stack>
430
+ </Box>
431
+ )}
432
+
433
+ {mode === 'inpaint' && (
434
+ <Box sx={{ mb: 2 }}>
435
+ <Typography variant="caption" color="text.secondary" display="block" sx={{ mb: 0.5 }}>
436
+ Drag the highlighted region to inpaint
437
+ </Typography>
438
+ <AudioWaveform
439
+ file={sourceFile}
440
+ duration={sourceDurationSec || 0}
441
+ start={maskStart}
442
+ end={maskEnd}
443
+ onRegionChange={(s, e) => { setMaskStart(s); setMaskEnd(e); }}
444
+ />
445
+ <Stack direction="row" alignItems="center" spacing={2} sx={{ mt: 1 }}>
446
+ <TextField
447
+ label="Start (s)"
448
+ type="number"
449
+ size="small"
450
+ value={maskStart.toFixed(2)}
451
+ onChange={(e) => setMaskStart(parseFloat(e.target.value) || 0)}
452
+ inputProps={{ min: 0, max: sourceDurationSec || 999, step: 0.05 }}
453
+ sx={{ width: 96 }}
454
+ />
455
+ <TextField
456
+ label="End (s)"
457
+ type="number"
458
+ size="small"
459
+ value={maskEnd.toFixed(2)}
460
+ onChange={(e) => setMaskEnd(parseFloat(e.target.value) || 0)}
461
+ inputProps={{ min: 0, max: sourceDurationSec || 999, step: 0.05 }}
462
+ sx={{ width: 96 }}
463
+ />
464
+ <Box sx={{ flex: 1, display: 'flex', alignItems: 'center', gap: 1 }}>
465
+ <Button
466
+ size="small"
467
+ variant="outlined"
468
+ startIcon={regionPlaying ? <StopIcon size={14} /> : <PlayIcon size={14} />}
469
+ onClick={toggleRegionPreview}
470
+ disabled={!regionUrl || (maskEnd - maskStart) < 0.05}
471
+ // Fixed width so swapping "Preview" ↔ "Stop" doesn't
472
+ // resize the button. Sized to fit "Preview" + icon.
473
+ sx={{ width: 108, flexShrink: 0 }}
474
+ >
475
+ {regionPlaying ? 'Stop' : 'Preview'}
476
+ </Button>
477
+ <Typography variant="caption" color="text.secondary">
478
+ {(maskEnd - maskStart).toFixed(2)} s
479
+ </Typography>
480
+ </Box>
481
+ </Stack>
482
+ <Typography variant="caption" color="text.secondary" display="block" sx={{ mt: 1 }}>
483
+ Output is the same length as the source β€” only your selected region is replaced
484
+ </Typography>
485
+ </Box>
486
+ )}
487
+
488
+ {mode === 'extend' && (
489
+ <Box sx={{ mb: 2 }}>
490
+ <TextField
491
+ label="Seconds to add at the end"
492
+ type="number"
493
+ size="small"
494
+ value={extendSeconds}
495
+ onChange={(e) => setExtendSeconds(parseFloat(e.target.value) || 0)}
496
+ inputProps={{ min: 0.5, max: 60, step: 0.5 }}
497
+ fullWidth
498
+ />
499
+ <Typography variant="caption" color="text.secondary">
500
+ Source is {sourceDurationSec ? sourceDurationSec.toFixed(2) : 'β€”'} s; final clip will be{' '}
501
+ {sourceDurationSec ? (sourceDurationSec + Number(extendSeconds || 0)).toFixed(2) : 'β€”'} s.
502
+ </Typography>
503
+ </Box>
504
+ )}
505
+
506
+ {/* Shared inputs */}
507
+ <TextField
508
+ label={mode === 'inpaint' ? 'Prompt for the inpainting region' : 'Prompt for the edit'}
509
+ placeholder={
510
+ mode === 'style' ? 'How the source should sound now…' :
511
+ mode === 'inpaint' ? 'What goes in the gap…' :
512
+ 'What the continuation should sound like (optional)'
513
+ }
514
+ multiline
515
+ minRows={1}
516
+ maxRows={3}
517
+ value={prompt}
518
+ onChange={(e) => setPrompt(e.target.value)}
519
+ fullWidth
520
+ sx={{ mb: 2 }}
521
+ />
522
+
523
+ {mode === 'style' && (
524
+ <Stack direction="row" alignItems="center" spacing={2} sx={{ mb: 2 }}>
525
+ <Typography variant="body2" color="text.secondary" sx={{ minWidth: 80 }}>
526
+ Duration
527
+ </Typography>
528
+ <Slider
529
+ value={Math.min(duration, maxDuration)}
530
+ onChange={(_, v) => setDuration(v)}
531
+ min={1}
532
+ max={maxDuration}
533
+ step={1}
534
+ valueLabelDisplay="auto"
535
+ sx={{ flex: 1 }}
536
+ />
537
+ <Typography variant="body2" sx={{ width: 40, textAlign: 'right' }}>
538
+ {duration}s
539
+ </Typography>
540
+ </Stack>
541
+ )}
542
+
543
+ {/* Seed β€” random by default, mirrors the Generation panel */}
544
+ <Stack direction="row" alignItems="center" spacing={2} sx={{ mb: 2 }}>
545
+ <Typography variant="body2" color="text.secondary" sx={{ minWidth: 80 }}>
546
+ Seed
547
+ </Typography>
548
+ <FormControlLabel
549
+ control={
550
+ <Switch
551
+ size="small"
552
+ checked={randomSeed}
553
+ onChange={(e) => setRandomSeed(e.target.checked)}
554
+ />
555
+ }
556
+ label="Random"
557
+ sx={{ mr: 0 }}
558
+ />
559
+ <TextField
560
+ size="small"
561
+ type="number"
562
+ value={seedValue}
563
+ disabled={randomSeed}
564
+ onChange={(e) => setSeedValue(e.target.value)}
565
+ placeholder={randomSeed ? 'Randomized each run (recorded)' : 'e.g. 42'}
566
+ inputProps={{ min: 0, step: 1 }}
567
+ sx={{ flex: 1 }}
568
+ />
569
+ </Stack>
570
+
571
+ {/* Hidden element backing the inpaint region preview */}
572
+ <audio
573
+ ref={regionAudioRef}
574
+ src={regionUrl || undefined}
575
+ preload="auto"
576
+ style={{ display: 'none' }}
577
+ onEnded={() => setRegionPlaying(false)}
578
+ />
579
+
580
+ {error && <Alert severity="error" sx={{ mb: 2 }}>{error}</Alert>}
581
+ {generating && <LinearProgress sx={{ mb: 2 }} />}
582
+
583
+ <Button
584
+ variant="contained"
585
+ fullWidth
586
+ onClick={generate}
587
+ disabled={generating || !sourcePath}
588
+ >
589
+ {generating
590
+ ? 'Generating…'
591
+ : mode === 'style' ? 'Apply style'
592
+ : mode === 'inpaint' ? 'Inpaint region'
593
+ : 'Extend clip'}
594
+ </Button>
595
+ </Box>
596
+ );
597
+ }
app/frontend/src/components/GeneratedFragmentsWindow.js CHANGED
@@ -1,27 +1,242 @@
1
- import React, { useState, useRef, useCallback } from 'react';
2
- import { Paper, Box, Typography, Button, List, ListItem, IconButton } from '@mui/material';
3
- import { Square as StopIcon, Play as PlayIcon, Download as DownloadIcon } from 'lucide-react';
4
- import api from '../api';
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import { generatedFragmentsWindowStyles } from '../theme';
 
 
 
6
 
7
- export default function GeneratedFragmentsWindow({ fragments, onDownload }) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  const [playingFragment, setPlayingFragment] = useState(null);
 
 
9
  const audioRefs = useRef({});
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  const handlePlayPause = (fragment) => {
12
  const audio = audioRefs.current[fragment.id];
13
  if (!audio) return;
14
 
15
- if (playingFragment === fragment.id) {
 
 
 
16
  audio.pause();
 
17
  setPlayingFragment(null);
18
- } else {
19
- if (playingFragment && audioRefs.current[playingFragment]) {
20
- audioRefs.current[playingFragment].pause();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  }
22
- audio.play();
23
- setPlayingFragment(fragment.id);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  };
26
 
27
  const setAudioRef = useCallback((fragmentId, audioElement) => {
@@ -31,90 +246,225 @@ export default function GeneratedFragmentsWindow({ fragments, onDownload }) {
31
  }, []);
32
 
33
  return (
34
- <Paper
35
- variant="outlined"
36
- sx={generatedFragmentsWindowStyles.rootPaper}
37
- >
38
  <Box sx={generatedFragmentsWindowStyles.headerRow}>
39
  <Box sx={generatedFragmentsWindowStyles.titleRow}>
40
  <Box component="span" sx={generatedFragmentsWindowStyles.titleIcon}>
41
- <DownloadIcon size={20} />
42
  </Box>
43
  <Typography variant="h6" sx={generatedFragmentsWindowStyles.titleText}>
44
  Generated Fragments
45
  </Typography>
46
  </Box>
47
- <Typography variant="caption" color="textSecondary" sx={generatedFragmentsWindowStyles.countText}>
48
- {fragments.length}
49
- </Typography>
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  </Box>
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  {fragments.length === 0 ? (
53
- <Box
54
- sx={generatedFragmentsWindowStyles.emptyState}
55
- >
 
 
 
 
 
 
 
 
 
56
  <Typography variant="body2">
57
- No fragments generated yet
 
 
58
  </Typography>
59
  </Box>
60
  ) : (
61
- <List
62
- sx={generatedFragmentsWindowStyles.listRoot}
63
- >
64
- {fragments.slice().reverse().map((fragment) => (
65
- <ListItem
66
- key={fragment.id}
67
- sx={generatedFragmentsWindowStyles.listItem}
68
- >
69
- <Box sx={generatedFragmentsWindowStyles.fragmentRow}>
70
- <Box sx={generatedFragmentsWindowStyles.fragmentMeta}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  <Typography
72
- variant="subtitle2"
73
  sx={generatedFragmentsWindowStyles.fragmentPrompt}
 
74
  >
75
  {fragment.batchTotal > 1 && (
76
- <Box component="span" sx={{ fontWeight: 700, mr: 0.75 }}>
77
- [{fragment.batchIndex}/{fragment.batchTotal}]
78
  </Box>
79
  )}
80
  {fragment.prompt}
81
  </Typography>
82
- <Typography variant="caption" color="textSecondary">
83
- {fragment.duration}s
84
- {fragment.cfgScale !== undefined && ` β€’ CFG ${fragment.cfgScale}`}
85
- {fragment.seed !== undefined && ` β€’ Seed ${fragment.seed}`}
86
- {' β€’ '}{fragment.timestamp}
87
- </Typography>
88
  </Box>
89
- <Box sx={generatedFragmentsWindowStyles.fragmentActions}>
90
- <IconButton
91
- size="small"
92
- onClick={() => handlePlayPause(fragment)}
93
- color={playingFragment === fragment.id ? "primary" : "default"}
94
- sx={generatedFragmentsWindowStyles.playPauseButton(playingFragment === fragment.id)}
95
- >
96
- {playingFragment === fragment.id ? <StopIcon /> : <PlayIcon />}
97
- </IconButton>
98
- <Button
99
- size="small"
100
- variant="outlined"
101
- startIcon={<DownloadIcon />}
102
- onClick={() => onDownload(fragment)}
 
 
 
 
 
 
 
103
  >
104
- Download
105
- </Button>
106
- </Box>
107
- </Box>
108
-
109
- <audio
110
- ref={el => setAudioRef(fragment.id, el)}
111
- src={fragment.audioUrl}
112
- onEnded={() => setPlayingFragment(null)}
113
- onPause={() => setPlayingFragment(null)}
114
- style={generatedFragmentsWindowStyles.hiddenAudio}
115
- />
116
- </ListItem>
117
- ))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  </List>
119
  )}
120
  </Paper>
 
1
+ import React, { useState, useRef, useCallback, useEffect } from 'react';
2
+ import {
3
+ Paper, Box, Typography, List, ListItem, IconButton,
4
+ Dialog, DialogTitle, DialogContent, DialogContentText, DialogActions, Button,
5
+ CircularProgress,
6
+ } from '@mui/material';
7
+ import { TIPS } from '../tooltips';
8
+ import Tooltip from './Tooltip';
9
+ import {
10
+ Square as StopIcon,
11
+ Play as PlayIcon,
12
+ AudioLines as TitleIcon,
13
+ Info as InfoIcon,
14
+ Trash2 as DeleteIcon,
15
+ Eraser as ClearAllIcon,
16
+ FolderOpen as RevealIcon,
17
+ } from 'lucide-react';
18
  import { generatedFragmentsWindowStyles } from '../theme';
19
+ import GenerationWaveform from './GenerationWaveform';
20
+ import api from '../api';
21
+ import { setFragmentDragPayload, clearFragmentDragPayload } from '../utils/fragmentDrag';
22
 
23
+ // Compact human-readable "X ago" with absolute fallback for stale items.
24
+ function relativeTime(createdAt) {
25
+ if (!createdAt) return '';
26
+ const sec = Math.max(0, (Date.now() - createdAt) / 1000);
27
+ if (sec < 10) return 'just now';
28
+ if (sec < 60) return `${Math.floor(sec)}s ago`;
29
+ const min = sec / 60;
30
+ if (min < 60) return `${Math.floor(min)}m ago`;
31
+ const hr = min / 60;
32
+ if (hr < 24) return `${Math.floor(hr)}h ago`;
33
+ const day = hr / 24;
34
+ if (day < 7) return `${Math.floor(day)}d ago`;
35
+ // Older than a week β€” show absolute date, no time
36
+ return new Date(createdAt).toLocaleDateString();
37
+ }
38
+
39
+ export default function GeneratedFragmentsWindow({ fragments, onDelete, onClearAll }) {
40
  const [playingFragment, setPlayingFragment] = useState(null);
41
+ const [playingTime, setPlayingTime] = useState(0);
42
+ const [clearConfirmOpen, setClearConfirmOpen] = useState(false);
43
  const audioRefs = useRef({});
44
+ // Tracks a play request that's between "user clicked Play" and "audio
45
+ // actually started". If the user clicks again during this window we
46
+ // need to either no-op (same fragment) or cleanly cancel (different
47
+ // fragment) β€” re-entering load() would abort the first play() and
48
+ // both attempts would fail with AbortError.
49
+ const playInFlightRef = useRef(null);
50
+
51
+ // Background-preload of disk-hydrated fragments. On app reload the parent
52
+ // gives us fragment metadata + the backend URL (/api/fragments/...) but
53
+ // no in-memory Blob. The first Play click on those would HTTP-fetch the
54
+ // file synchronously through the <audio> element and freeze briefly. We
55
+ // pre-fetch them in parallel on mount and gate the UI behind a single
56
+ // loading screen β€” once everything is ready, plays + waveform decodes
57
+ // are instant because they work off blob: URLs.
58
+ const fetchingIdsRef = useRef(new Set());
59
+ const loadedRef = useRef({}); // { [id]: { blob, blobUrl } }
60
+ const [loadedTick, setLoadedTick] = useState(0);
61
+
62
+ useEffect(() => {
63
+ let cancelled = false;
64
+ fragments.forEach((frag) => {
65
+ if (frag.audioBlob) return; // already in memory
66
+ if (loadedRef.current[frag.id]) return; // already preloaded
67
+ if (fetchingIdsRef.current.has(frag.id)) return;
68
+ if (!frag.audioUrl) return;
69
+ fetchingIdsRef.current.add(frag.id);
70
+ fetch(frag.audioUrl)
71
+ .then((r) => {
72
+ if (!r.ok) throw new Error(`HTTP ${r.status}`);
73
+ return r.blob();
74
+ })
75
+ .then((blob) => {
76
+ if (cancelled) return;
77
+ const blobUrl = URL.createObjectURL(blob);
78
+ loadedRef.current[frag.id] = { blob, blobUrl };
79
+ setLoadedTick((t) => t + 1);
80
+ })
81
+ .catch((err) => {
82
+ console.warn(`Fragment preload failed (${frag.filename || frag.id}):`, err);
83
+ })
84
+ .finally(() => {
85
+ fetchingIdsRef.current.delete(frag.id);
86
+ });
87
+ });
88
+ return () => { cancelled = true; };
89
+ }, [fragments]);
90
+
91
+ // Revoke all preload blob URLs on unmount so we don't leak.
92
+ useEffect(() => () => {
93
+ Object.values(loadedRef.current).forEach(({ blobUrl }) => {
94
+ try { URL.revokeObjectURL(blobUrl); } catch { /* ignore */ }
95
+ });
96
+ }, []);
97
+
98
+ // Per-fragment helpers that prefer the in-memory blob (immediate) over
99
+ // the HTTP URL. Defined after loadedTick is read so React knows to
100
+ // re-render when a new fragment finishes preloading.
101
+ void loadedTick;
102
+ const effectiveBlob = (frag) => frag.audioBlob || loadedRef.current[frag.id]?.blob || null;
103
+ const effectiveUrl = (frag) => loadedRef.current[frag.id]?.blobUrl || frag.audioUrl;
104
+ const isFragmentReady = (frag) => !!frag.audioBlob || !!loadedRef.current[frag.id];
105
 
106
+ const readyCount = fragments.filter(isFragmentReady).length;
107
+ const allReady = fragments.length === 0 || readyCount === fragments.length;
108
+
109
+ // Safety buffer: once everything reports ready, keep the loading overlay
110
+ // up for an extra 5s before revealing the list. Audio decodes that are
111
+ // still settling in the background can't be poked (and can't crash the
112
+ // list) while the user is gated behind the spinner.
113
+ const GRACE_MS = 5000;
114
+ const [graceDone, setGraceDone] = useState(false);
115
+ useEffect(() => {
116
+ if (fragments.length === 0) { setGraceDone(true); return undefined; }
117
+ if (!allReady) { setGraceDone(false); return undefined; }
118
+ const t = setTimeout(() => setGraceDone(true), GRACE_MS);
119
+ return () => clearTimeout(t);
120
+ }, [allReady, fragments.length]);
121
+ const showLoading = fragments.length > 0 && (!allReady || !graceDone);
122
+
123
+ // Strict single-play with first-click readiness gate.
124
+ //
125
+ // Race-fixes the old version had:
126
+ // 1. Iterate audioRefs.current and pause everything that isn't the
127
+ // new target β€” avoids losing the race when two play clicks land
128
+ // before React state settles.
129
+ // 2. For blob URLs, Chromium often doesn't actually pull bytes until
130
+ // the first play() call, and play() rejects/hangs if readyState
131
+ // is too low. If we're not ready, call load() and wait for
132
+ // `canplay` (with a 1500 ms safety timeout) before play().
133
+ // 3. Guard against the user clicking Play twice during loading. A
134
+ // second load() while the first play() is still pending aborts
135
+ // the first with AbortError. playInFlightRef tracks the active
136
+ // request: same-fragment second click is a no-op; different
137
+ // fragment cleanly cancels the prior load timer/listener.
138
  const handlePlayPause = (fragment) => {
139
  const audio = audioRefs.current[fragment.id];
140
  if (!audio) return;
141
 
142
+ // Stop case: this fragment is currently playing β†’ pause it.
143
+ if (!audio.paused) {
144
+ playInFlightRef.current?.cleanup?.();
145
+ playInFlightRef.current = null;
146
  audio.pause();
147
+ audio.currentTime = 0;
148
  setPlayingFragment(null);
149
+ setPlayingTime(0);
150
+ return;
151
+ }
152
+
153
+ // Click during loading of the SAME fragment β†’ ignore.
154
+ if (playInFlightRef.current?.fragmentId === fragment.id) {
155
+ return;
156
+ }
157
+ // Click during loading of a DIFFERENT fragment β†’ cancel that.
158
+ if (playInFlightRef.current) {
159
+ playInFlightRef.current.cleanup?.();
160
+ playInFlightRef.current = null;
161
+ }
162
+
163
+ Object.values(audioRefs.current).forEach((el) => {
164
+ if (el && el !== audio) {
165
+ el.pause();
166
+ el.currentTime = 0;
167
  }
168
+ });
169
+
170
+ const startedFor = fragment.id;
171
+ setPlayingFragment(startedFor);
172
+ setPlayingTime(0);
173
+
174
+ const startPlayback = () => {
175
+ audio.currentTime = 0;
176
+ Promise.resolve(audio.play())
177
+ .then(() => {
178
+ // Successfully playing β€” clear the in-flight marker so
179
+ // the next Play click can fire a fresh request.
180
+ if (playInFlightRef.current?.fragmentId === startedFor) {
181
+ playInFlightRef.current = null;
182
+ }
183
+ })
184
+ .catch((err) => {
185
+ // AbortError is expected when the user cancels (clicks
186
+ // Stop or switches fragments) β€” don't noise the log.
187
+ if (err && err.name !== 'AbortError') {
188
+ console.warn(`Fragment play failed (${fragment.filename || fragment.id}):`, err);
189
+ }
190
+ setPlayingFragment((prev) => (prev === startedFor ? null : prev));
191
+ setPlayingTime(0);
192
+ if (playInFlightRef.current?.fragmentId === startedFor) {
193
+ playInFlightRef.current = null;
194
+ }
195
+ });
196
+ };
197
+
198
+ if (audio.readyState >= 2) {
199
+ playInFlightRef.current = { fragmentId: startedFor, cleanup: null };
200
+ startPlayback();
201
+ return;
202
  }
203
+
204
+ // Not ready yet β€” load and wait for canplay (or 1.5 s timeout).
205
+ try { audio.load(); } catch { /* ignore */ }
206
+ let cancelled = false;
207
+ const onReady = () => {
208
+ audio.removeEventListener('canplay', onReady);
209
+ clearTimeout(timer);
210
+ if (cancelled) return;
211
+ startPlayback();
212
+ };
213
+ audio.addEventListener('canplay', onReady, { once: true });
214
+ // 5 s β€” disk-hydrated fragments fetch from /api/fragments/...
215
+ // over HTTP, which can take a couple of seconds on first request.
216
+ // Blob-URL fragments (in-memory) hit canplay almost instantly.
217
+ const timer = setTimeout(() => {
218
+ audio.removeEventListener('canplay', onReady);
219
+ if (!cancelled) startPlayback();
220
+ }, 5000);
221
+ playInFlightRef.current = {
222
+ fragmentId: startedFor,
223
+ cleanup: () => {
224
+ cancelled = true;
225
+ audio.removeEventListener('canplay', onReady);
226
+ clearTimeout(timer);
227
+ },
228
+ };
229
+ };
230
+
231
+ // Reveal a fragment in the OS file manager (folder opens with the file
232
+ // highlighted where the platform supports it). Disk-hydrated fragments
233
+ // always have a filename; in-memory-only ones (not yet flushed) won't.
234
+ const revealInFolder = (fragment) => {
235
+ if (!fragment.filename) return;
236
+ api.post('/api/reveal-fragment', { filename: fragment.filename })
237
+ .catch((err) => {
238
+ console.warn(`Reveal failed (${fragment.filename}):`, err);
239
+ });
240
  };
241
 
242
  const setAudioRef = useCallback((fragmentId, audioElement) => {
 
246
  }, []);
247
 
248
  return (
249
+ <Paper variant="outlined" sx={generatedFragmentsWindowStyles.rootPaper}>
 
 
 
250
  <Box sx={generatedFragmentsWindowStyles.headerRow}>
251
  <Box sx={generatedFragmentsWindowStyles.titleRow}>
252
  <Box component="span" sx={generatedFragmentsWindowStyles.titleIcon}>
253
+ <TitleIcon size={20} />
254
  </Box>
255
  <Typography variant="h6" sx={generatedFragmentsWindowStyles.titleText}>
256
  Generated Fragments
257
  </Typography>
258
  </Box>
259
+ <Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5 }}>
260
+ <Typography variant="caption" color="textSecondary" sx={generatedFragmentsWindowStyles.countText}>
261
+ {fragments.length}
262
+ </Typography>
263
+ {fragments.length > 0 && onClearAll && (
264
+ <Tooltip title={TIPS.fragments.clearAll} placement="top" arrow>
265
+ <IconButton
266
+ size="small"
267
+ onClick={() => setClearConfirmOpen(true)}
268
+ sx={{ color: 'text.disabled', '&:hover': { color: 'error.main' } }}
269
+ >
270
+ <ClearAllIcon size={14} />
271
+ </IconButton>
272
+ </Tooltip>
273
+ )}
274
+ </Box>
275
  </Box>
276
 
277
+ <Dialog open={clearConfirmOpen} onClose={() => setClearConfirmOpen(false)}>
278
+ <DialogTitle>Clear all generated fragments?</DialogTitle>
279
+ <DialogContent>
280
+ <DialogContentText>
281
+ Permanently delete all {fragments.length} fragment{fragments.length === 1 ? '' : 's'} from disk.
282
+ Uploaded source clips (used by Edit mode) are not affected.
283
+ </DialogContentText>
284
+ </DialogContent>
285
+ <DialogActions>
286
+ <Button onClick={() => setClearConfirmOpen(false)}>Cancel</Button>
287
+ <Button
288
+ onClick={() => { setClearConfirmOpen(false); onClearAll?.(); }}
289
+ color="error"
290
+ variant="contained"
291
+ >
292
+ Delete all
293
+ </Button>
294
+ </DialogActions>
295
+ </Dialog>
296
+
297
  {fragments.length === 0 ? (
298
+ <Box sx={generatedFragmentsWindowStyles.emptyState}>
299
+ <Typography variant="body2">No fragments generated yet</Typography>
300
+ </Box>
301
+ ) : showLoading ? (
302
+ <Box sx={{
303
+ ...generatedFragmentsWindowStyles.emptyState,
304
+ display: 'flex',
305
+ flexDirection: 'column',
306
+ alignItems: 'center',
307
+ gap: 1.5,
308
+ }}>
309
+ <CircularProgress size={28} />
310
  <Typography variant="body2">
311
+ {allReady
312
+ ? 'Finishing up…'
313
+ : `Loading fragments… ${readyCount} / ${fragments.length}`}
314
  </Typography>
315
  </Box>
316
  ) : (
317
+ <List sx={generatedFragmentsWindowStyles.listRoot}>
318
+ {fragments.slice().reverse().map((fragment) => {
319
+ const isPlaying = playingFragment === fragment.id;
320
+ const ago = relativeTime(fragment.createdAt);
321
+ // CFG, seed, full timestamp, and model go in the info
322
+ // tooltip β€” accessible but not pushing the row out.
323
+ const tooltipLines = [
324
+ // Pre-fix fragments stored -1 for a random seed;
325
+ // show that as "random" rather than a bare -1.
326
+ `Seed: ${(fragment.seed != null && fragment.seed >= 0) ? fragment.seed : 'random'}`,
327
+ // Distilled SA3 models have CFG distilled away β€” it's
328
+ // genuinely not applicable, not missing.
329
+ `CFG: ${fragment.cfgScale ?? 'n/a'}`,
330
+ fragment.steps != null ? `Steps: ${fragment.steps}` : null,
331
+ fragment.modelId ? `Model: ${fragment.modelId}` : null,
332
+ fragment.editMode ? `Mode: ${fragment.editMode}` : null,
333
+ `Duration: ${fragment.duration}s`,
334
+ ago ? `Generated: ${ago}` : null,
335
+ fragment.timestamp ? fragment.timestamp : null,
336
+ ].filter(Boolean).join('\n');
337
+
338
+ return (
339
+ <ListItem
340
+ key={fragment.id}
341
+ sx={generatedFragmentsWindowStyles.listItem}
342
+ >
343
+ <IconButton
344
+ size="small"
345
+ onClick={() => handlePlayPause(fragment)}
346
+ aria-label={isPlaying ? 'Stop' : 'Play'}
347
+ sx={generatedFragmentsWindowStyles.playPauseButton(isPlaying)}
348
+ >
349
+ {isPlaying ? <StopIcon size={16} /> : <PlayIcon size={16} />}
350
+ </IconButton>
351
+
352
+ <Box
353
+ sx={{ ...generatedFragmentsWindowStyles.fragmentMeta, cursor: 'grab' }}
354
+ draggable
355
+ onDragStart={(e) => {
356
+ // In-app payload consumed by EditPanel's drop zone
357
+ // ("drag a clip into the Edit tab"). Keeps the
358
+ // waveform's separate OS drag-out untouched.
359
+ e.dataTransfer.setData(
360
+ 'application/x-fragmenta-fragment',
361
+ fragment.filename || '',
362
+ );
363
+ e.dataTransfer.effectAllowed = 'copy';
364
+ // Hand off the in-memory blob too so the drop can
365
+ // use it directly β€” no disk fetch, immune to any
366
+ // name mismatch. Falls back to the filename when
367
+ // the blob isn't preloaded yet.
368
+ const blob = effectiveBlob(fragment);
369
+ if (blob) {
370
+ setFragmentDragPayload({
371
+ filename: fragment.filename || '',
372
+ blob,
373
+ });
374
+ }
375
+ }}
376
+ onDragEnd={() => clearFragmentDragPayload()}
377
+ title="Drag into the Edit tab to use as a source clip"
378
+ >
379
  <Typography
380
+ variant="body2"
381
  sx={generatedFragmentsWindowStyles.fragmentPrompt}
382
+ title={fragment.prompt}
383
  >
384
  {fragment.batchTotal > 1 && (
385
+ <Box component="span" sx={generatedFragmentsWindowStyles.batchTag}>
386
+ {fragment.batchIndex}/{fragment.batchTotal}
387
  </Box>
388
  )}
389
  {fragment.prompt}
390
  </Typography>
 
 
 
 
 
 
391
  </Box>
392
+
393
+ <GenerationWaveform
394
+ blob={effectiveBlob(fragment)}
395
+ audioUrl={effectiveUrl(fragment)}
396
+ filename={fragment.filename || 'fragment.wav'}
397
+ currentTime={isPlaying ? playingTime : 0}
398
+ duration={fragment.duration || 0}
399
+ />
400
+
401
+ <Tooltip
402
+ title={
403
+ <Box component="span" sx={{ whiteSpace: 'pre-line' }}>
404
+ {tooltipLines}
405
+ </Box>
406
+ }
407
+ arrow
408
+ placement="top"
409
+ >
410
+ <Box
411
+ component="span"
412
+ sx={generatedFragmentsWindowStyles.fragmentInfoIcon}
413
  >
414
+ <InfoIcon size={14} />
415
+ </Box>
416
+ </Tooltip>
417
+
418
+ {fragment.filename && (
419
+ <Tooltip title={TIPS.fragments.revealInFolder} placement="top" arrow>
420
+ <IconButton
421
+ size="small"
422
+ onClick={() => revealInFolder(fragment)}
423
+ aria-label="Show in folder"
424
+ sx={{ color: 'text.disabled', '&:hover': { color: 'primary.main', bgcolor: 'action.hover' } }}
425
+ >
426
+ <RevealIcon size={16} />
427
+ </IconButton>
428
+ </Tooltip>
429
+ )}
430
+
431
+ {onDelete && (
432
+ <Tooltip title={TIPS.fragments.deleteFromDisk} placement="top" arrow>
433
+ <IconButton
434
+ size="small"
435
+ onClick={() => onDelete(fragment)}
436
+ sx={{ color: 'text.disabled', '&:hover': { color: 'error.main', bgcolor: 'action.hover' } }}
437
+ >
438
+ <DeleteIcon size={16} />
439
+ </IconButton>
440
+ </Tooltip>
441
+ )}
442
+
443
+ <audio
444
+ ref={el => setAudioRef(fragment.id, el)}
445
+ src={effectiveUrl(fragment)}
446
+ preload="auto"
447
+ onTimeUpdate={(e) => {
448
+ if (playingFragment === fragment.id) {
449
+ setPlayingTime(e.target.currentTime);
450
+ }
451
+ }}
452
+ onEnded={() => {
453
+ if (playingFragment === fragment.id) {
454
+ setPlayingFragment(null);
455
+ setPlayingTime(0);
456
+ }
457
+ }}
458
+ onPause={() => {
459
+ if (playingFragment === fragment.id) {
460
+ setPlayingFragment(null);
461
+ }
462
+ }}
463
+ style={generatedFragmentsWindowStyles.hiddenAudio}
464
+ />
465
+ </ListItem>
466
+ );
467
+ })}
468
  </List>
469
  )}
470
  </Paper>
app/frontend/src/components/GenerationWaveform.js ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useEffect, useLayoutEffect, useRef, useState, useCallback } from 'react';
2
+ import { Box } from '@mui/material';
3
+
4
+ const DEFAULT_COLOR = '#279FBB';
5
+ // Fixed, low waveform resolution β€” matches the dataset-page waveforms
6
+ // (/peaks?n=80). Decoding to a constant bucket count (instead of one pair per
7
+ // pixel) means the decode runs once per clip rather than re-running on every
8
+ // resize, so fragments render faster.
9
+ const PEAK_COUNT = 80;
10
+
11
+ /**
12
+ * Compact waveform indicator for a single generated fragment.
13
+ *
14
+ * Decodes `blob` once per width and renders min/max peaks on a canvas.
15
+ * Played portion is rendered in `color`; unplayed in a dim version of it,
16
+ * with a thin playhead line at the current position. The whole element is
17
+ * draggable: dragstart sets a DownloadURL on the dataTransfer so the user
18
+ * can drag the fragment onto their desktop or into a DAW as a .wav file.
19
+ *
20
+ * Props:
21
+ * blob: Blob | null β€” audio source (Blob is required for the
22
+ * native drag-to-OS file write).
23
+ * audioUrl: string β€” blob: URL for the same audio. Used in
24
+ * the dataTransfer; we fall back to
25
+ * createObjectURL(blob) if it's missing.
26
+ * filename: string β€” file name the OS sees when the drag
27
+ * resolves.
28
+ * currentTime: number β€” playback head position in seconds.
29
+ * duration: number β€” total length in seconds.
30
+ * height: number β€” canvas height in px (default 28).
31
+ * color: string β€” accent color (default theme amber).
32
+ */
33
+ export default function GenerationWaveform({
34
+ blob,
35
+ audioUrl,
36
+ filename = 'fragment.wav',
37
+ currentTime = 0,
38
+ duration = 0,
39
+ height = 28,
40
+ color = DEFAULT_COLOR,
41
+ }) {
42
+ const containerRef = useRef(null);
43
+ const canvasRef = useRef(null);
44
+ // Start at a sensible non-zero width so the decode useEffect (gated on
45
+ // width > 0) runs on first mount instead of waiting for the async
46
+ // ResizeObserver callback β€” which is what was leaving the canvas blank.
47
+ const [width, setWidth] = useState(200);
48
+ const [peaks, setPeaks] = useState(null);
49
+
50
+ // Measure synchronously on mount via useLayoutEffect so we never paint
51
+ // at the placeholder width; ResizeObserver then keeps it in sync with
52
+ // sidebar collapses / window resizes.
53
+ useLayoutEffect(() => {
54
+ const el = containerRef.current;
55
+ if (!el) return;
56
+ const rect = el.getBoundingClientRect();
57
+ if (rect.width > 0) {
58
+ setWidth(Math.max(1, Math.floor(rect.width)));
59
+ }
60
+ const ro = new ResizeObserver((entries) => {
61
+ const w = Math.max(1, Math.floor(entries[0].contentRect.width));
62
+ setWidth(w);
63
+ });
64
+ ro.observe(el);
65
+ return () => ro.disconnect();
66
+ }, []);
67
+
68
+ // Decode into PEAK_COUNT mono min/max pairs β€” a fixed low resolution,
69
+ // independent of pixel width, so the (expensive) decode runs once per clip
70
+ // and not again on every resize.
71
+ //
72
+ // Audio source can be either a Blob (in-memory, fresh generations) or
73
+ // an HTTP audioUrl (fragments hydrated from disk on app load have
74
+ // audioBlob=null and audioUrl=/api/fragments/...). The blob path is
75
+ // preferred when available; otherwise fetch the URL.
76
+ useEffect(() => {
77
+ if (!blob && !audioUrl) return;
78
+ let cancelled = false;
79
+ (async () => {
80
+ try {
81
+ let buf;
82
+ if (blob) {
83
+ buf = await blob.arrayBuffer();
84
+ } else {
85
+ const r = await fetch(audioUrl);
86
+ if (!r.ok) {
87
+ console.warn(`GenerationWaveform fetch failed (${r.status}): ${audioUrl}`);
88
+ return;
89
+ }
90
+ buf = await r.arrayBuffer();
91
+ }
92
+ if (cancelled) return;
93
+ if (!buf || buf.byteLength === 0) {
94
+ console.warn('GenerationWaveform: empty audio source');
95
+ return;
96
+ }
97
+ const Ctx = window.OfflineAudioContext || window.webkitOfflineAudioContext;
98
+ const tmpCtx = Ctx
99
+ ? new Ctx(1, 44100, 44100)
100
+ : new (window.AudioContext || window.webkitAudioContext)();
101
+ const audio = await tmpCtx.decodeAudioData(buf.slice(0));
102
+ if (cancelled) return;
103
+ const ch0 = audio.getChannelData(0);
104
+ const ch1 = audio.numberOfChannels > 1 ? audio.getChannelData(1) : null;
105
+ const totalSamples = ch0.length;
106
+ const bucketSize = Math.max(1, Math.floor(totalSamples / PEAK_COUNT));
107
+ const out = new Float32Array(PEAK_COUNT * 2);
108
+ for (let i = 0; i < PEAK_COUNT; i++) {
109
+ const s = i * bucketSize;
110
+ const e = Math.min(totalSamples, s + bucketSize);
111
+ let mn = 0, mx = 0;
112
+ for (let j = s; j < e; j++) {
113
+ const v = ch1 ? (ch0[j] + ch1[j]) * 0.5 : ch0[j];
114
+ if (v < mn) mn = v;
115
+ if (v > mx) mx = v;
116
+ }
117
+ out[i * 2] = mn;
118
+ out[i * 2 + 1] = mx;
119
+ }
120
+ if (!cancelled) setPeaks(out);
121
+ } catch (err) {
122
+ console.warn('GenerationWaveform decode failed:', err);
123
+ }
124
+ })();
125
+ return () => { cancelled = true; };
126
+ }, [blob, audioUrl]);
127
+
128
+ // Draw β€” re-runs on every currentTime tick so the playhead moves.
129
+ const draw = useCallback(() => {
130
+ const canvas = canvasRef.current;
131
+ if (!canvas || !width || !height) return;
132
+ const dpr = window.devicePixelRatio || 1;
133
+ canvas.width = width * dpr;
134
+ canvas.height = height * dpr;
135
+ const ctx = canvas.getContext('2d');
136
+ ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
137
+ ctx.clearRect(0, 0, width, height);
138
+
139
+ // Always draw a faint center line so the row has a visible "this is
140
+ // a waveform area" cue even while decode is in flight or has failed.
141
+ ctx.fillStyle = `${color}33`;
142
+ ctx.fillRect(0, height / 2 - 0.5, width, 1);
143
+
144
+ if (!peaks) return;
145
+
146
+ const mid = height / 2;
147
+ const scale = (height - 2) / 2;
148
+ // Stretch the fixed PEAK_COUNT buckets across the canvas width as bars
149
+ // (with a 1px gap), matching the dataset-page waveform look.
150
+ const n = peaks.length / 2;
151
+ const step = width / n;
152
+ const barW = Math.max(1, step - 1);
153
+ const progressFrac = duration > 0
154
+ ? Math.max(0, Math.min(1, currentTime / duration))
155
+ : 0;
156
+ const progressPx = progressFrac * width;
157
+ const splitIdx = Math.floor(progressFrac * n);
158
+
159
+ for (let i = 0; i < n; i++) {
160
+ const mn = peaks[i * 2];
161
+ const mx = peaks[i * 2 + 1];
162
+ const y0 = mid - mx * scale;
163
+ const y1 = mid - mn * scale;
164
+ // Played bars: full color; unplayed: dimmed (35% alpha of accent).
165
+ ctx.fillStyle = i < splitIdx ? color : `${color}59`;
166
+ ctx.fillRect(i * step, y0, barW, Math.max(1, y1 - y0));
167
+ }
168
+ // Thin playhead at the split.
169
+ if (progressPx > 0 && progressPx < width) {
170
+ ctx.fillStyle = color;
171
+ ctx.fillRect(progressPx - 0.5, 0, 1, height);
172
+ }
173
+ }, [width, height, peaks, color, currentTime, duration]);
174
+
175
+ useEffect(() => { draw(); }, [draw]);
176
+
177
+ // Native drag-to-OS as a file. The DownloadURL mime type is a Chromium
178
+ // extension the OS interprets as "this drag is a file the browser can
179
+ // serve from URL X with mime/name Y". Source is whichever URL we have:
180
+ // a blob: URL for in-memory fragments, or the backend /api/fragments/
181
+ // path for disk-hydrated ones. The OS needs an ABSOLUTE URL, so we
182
+ // resolve relative paths against window.location.origin.
183
+ const canDrag = !!(audioUrl || blob);
184
+ const handleDragStart = (e) => {
185
+ if (!canDrag) return;
186
+ const raw = audioUrl || URL.createObjectURL(blob);
187
+ const absolute = (raw.startsWith('http') || raw.startsWith('blob:'))
188
+ ? raw
189
+ : `${window.location.origin}${raw.startsWith('/') ? '' : '/'}${raw}`;
190
+ e.dataTransfer.setData('DownloadURL', `audio/wav:${filename}:${absolute}`);
191
+ e.dataTransfer.effectAllowed = 'copy';
192
+ };
193
+
194
+ return (
195
+ <Box
196
+ ref={containerRef}
197
+ draggable={canDrag}
198
+ onDragStart={handleDragStart}
199
+ title={canDrag ? 'Drag to save or drop into a DAW' : undefined}
200
+ sx={{
201
+ // Floor the width so the container is never zero β€” without
202
+ // this, a tight flex row could collapse it before
203
+ // ResizeObserver fires, leaving the canvas un-sized.
204
+ flex: 1,
205
+ minWidth: 120,
206
+ height,
207
+ cursor: canDrag ? 'grab' : 'default',
208
+ '&:active': { cursor: canDrag ? 'grabbing' : 'default' },
209
+ }}
210
+ >
211
+ <canvas
212
+ ref={canvasRef}
213
+ style={{ display: 'block', width: '100%', height }}
214
+ />
215
+ </Box>
216
+ );
217
+ }
app/frontend/src/components/InfoView.js ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { createContext, useCallback, useContext, useMemo, useState } from 'react';
2
+ import { Box, Typography } from '@mui/material';
3
+ import { Info as InfoIcon } from 'lucide-react';
4
+
5
+ /**
6
+ * Ableton-style "Info View".
7
+ *
8
+ * A toggleable strip pinned to the bottom of the window that shows the help
9
+ * text for whatever control the pointer (or keyboard focus) is over, instead
10
+ * of popping a tooltip on the control itself. The shared <Tooltip> feeds this
11
+ * panel when the view is enabled (see components/Tooltip.js).
12
+ *
13
+ * State design: `enabled` is owned by App (changes rarely, persisted). The
14
+ * *hint* β€” which changes on every hover β€” lives inside the provider and is
15
+ * read only by the bar, so updating it never re-renders the app tree (the app
16
+ * is passed as `children`, whose element identity is stable across the
17
+ * provider's internal state changes).
18
+ */
19
+ export const InfoViewContext = createContext({ enabled: false, setHint: () => {} });
20
+
21
+ export const useInfoView = () => useContext(InfoViewContext);
22
+
23
+ export function InfoViewProvider({ enabled, children }) {
24
+ const [hint, setHint] = useState(null);
25
+ // Stable setter so the context value only changes when `enabled` flips β€”
26
+ // hover-driven hint updates don't churn every tooltip consumer.
27
+ const update = useCallback((value) => setHint(value ?? null), []);
28
+ const value = useMemo(() => ({ enabled, setHint: update }), [enabled, update]);
29
+
30
+ return (
31
+ <InfoViewContext.Provider value={value}>
32
+ {children}
33
+ {enabled && <InfoViewBar hint={hint} />}
34
+ </InfoViewContext.Provider>
35
+ );
36
+ }
37
+
38
+ function InfoViewBar({ hint }) {
39
+ // Only present when there's something to say β€” no placeholder.
40
+ if (!hint) return null;
41
+
42
+ return (
43
+ // Full-width fixed row that centers the pill at the bottom of the page.
44
+ <Box
45
+ sx={{
46
+ position: 'fixed',
47
+ left: 0,
48
+ right: 0,
49
+ bottom: { xs: 16, md: 24 },
50
+ zIndex: 1340, // under the bottom dock (1350)
51
+ px: 2,
52
+ display: 'flex',
53
+ justifyContent: 'center',
54
+ pointerEvents: 'none', // pure overlay β€” never intercepts clicks
55
+ }}
56
+ >
57
+ <Box
58
+ role="status"
59
+ aria-live="polite"
60
+ sx={(theme) => ({
61
+ display: 'inline-flex',
62
+ alignItems: 'center',
63
+ gap: 1,
64
+ maxWidth: 'min(680px, 90vw)',
65
+ px: 1.75,
66
+ py: 0.9,
67
+ borderRadius: 999,
68
+ // Blurred translucent pill β€” just enough backing for the
69
+ // text to stay readable over any content behind it.
70
+ backgroundColor: theme.palette.mode === 'dark'
71
+ ? 'rgba(20, 22, 24, 0.55)'
72
+ : 'rgba(248, 243, 234, 0.62)',
73
+ backdropFilter: 'blur(16px) saturate(160%)',
74
+ WebkitBackdropFilter: 'blur(16px) saturate(160%)',
75
+ border: `1px solid ${theme.palette.divider}`,
76
+ boxShadow: theme.palette.mode === 'dark'
77
+ ? '0 8px 28px rgba(0,0,0,0.5)'
78
+ : '0 8px 28px rgba(43,31,18,0.16)',
79
+ animation: 'fragmenta-fade-up 240ms cubic-bezier(0.16, 1, 0.3, 1)',
80
+ })}
81
+ >
82
+ <Box component="span" sx={{ flexShrink: 0, display: 'inline-flex', color: 'primary.main' }}>
83
+ <InfoIcon size={15} />
84
+ </Box>
85
+ <Typography variant="body2" sx={{ color: 'text.primary', lineHeight: 1.3 }}>
86
+ {hint}
87
+ </Typography>
88
+ </Box>
89
+ </Box>
90
+ );
91
+ }
app/frontend/src/components/LoraStack.js ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useEffect, useState } from 'react';
2
+ import {
3
+ Box,
4
+ Accordion,
5
+ AccordionSummary,
6
+ AccordionDetails,
7
+ Button,
8
+ Typography,
9
+ Stack,
10
+ MenuItem,
11
+ Select,
12
+ Slider,
13
+ IconButton,
14
+ Chip,
15
+ Alert,
16
+ } from '@mui/material';
17
+ import { TIPS } from '../tooltips';
18
+ import Tooltip from './Tooltip';
19
+ import {
20
+ Plus as AddIcon,
21
+ Trash2 as RemoveIcon,
22
+ GripVertical as DragIcon,
23
+ Power as BypassIcon,
24
+ ChevronDown as ChevronDownIcon,
25
+ } from 'lucide-react';
26
+ import api from '../api';
27
+ import { isLoraCompatible } from '../utils/loraMatch';
28
+
29
+ const MAX_SLOTS = 4;
30
+
31
+ /**
32
+ * Multi-LoRA stack for the Generation panel.
33
+ *
34
+ * Props:
35
+ * selectedModel: the currently-selected base model id (e.g. "sa3-medium-base")
36
+ * value: array of { path, strength, bypassed } slots
37
+ * onChange: (newSlots) => void
38
+ *
39
+ * The picker filters available LoRAs by base-model compatibility (a `*-base`
40
+ * LoRA also runs on its distilled sibling β€” see utils/loraMatch). Slot order
41
+ * is the load order (slot 0 first); drag the handle to reorder. Bypass keeps
42
+ * a slot in the stack but sends strength 0.
43
+ */
44
+ export default function LoraStack({ selectedModel, value, onChange }) {
45
+ const [available, setAvailable] = useState([]);
46
+ const [loading, setLoading] = useState(false);
47
+ const [error, setError] = useState(null);
48
+ const [dragIndex, setDragIndex] = useState(null);
49
+
50
+ useEffect(() => {
51
+ let cancelled = false;
52
+ setLoading(true);
53
+ api.get('/api/loras')
54
+ .then(r => { if (!cancelled) setAvailable(r.data.loras || []); })
55
+ .catch(e => { if (!cancelled) setError(e.response?.data?.error || e.message); })
56
+ .finally(() => { if (!cancelled) setLoading(false); });
57
+ return () => { cancelled = true; };
58
+ }, []);
59
+
60
+ // LoRAs compatible with the current generation model. A LoRA trained
61
+ // against `*-base` is compatible with both that base and its distilled
62
+ // sibling (same backbone, differ only in CFG state) β€” loraMatch strips
63
+ // the trailing `-base` before comparing.
64
+ const compatible = available.filter(l =>
65
+ isLoraCompatible(l.base_model, selectedModel)
66
+ );
67
+
68
+ // The single-LoRA case stays one click: when no slots are populated AND
69
+ // there's a compatible LoRA, surface one empty slot so the user sees a
70
+ // "Pick a LoRA" dropdown immediately.
71
+ const slots = (value && value.length > 0)
72
+ ? value
73
+ : (compatible.length ? [{ path: '', strength: 1.0, bypassed: false }] : []);
74
+
75
+ const addSlot = () => {
76
+ if (slots.length >= MAX_SLOTS) return;
77
+ onChange([...slots, { path: '', strength: 1.0, bypassed: false }]);
78
+ };
79
+
80
+ const removeSlot = (idx) => onChange(slots.filter((_, i) => i !== idx));
81
+
82
+ const setSlot = (idx, patch) => {
83
+ onChange(slots.map((s, i) => i === idx ? { ...s, ...patch } : s));
84
+ };
85
+
86
+ // --- drag-to-reorder (slot 0 is loaded first) ---------------------------
87
+ const onDrop = (target) => {
88
+ if (dragIndex === null || dragIndex === target) { setDragIndex(null); return; }
89
+ const next = [...slots];
90
+ const [moved] = next.splice(dragIndex, 1);
91
+ next.splice(target, 0, moved);
92
+ setDragIndex(null);
93
+ onChange(next);
94
+ };
95
+
96
+ const hint = (() => {
97
+ if (!selectedModel) return 'Pick a model first.';
98
+ if (!selectedModel.endsWith('-base')) {
99
+ return 'LoRAs need a Base model. Switch to a *-base checkpoint to use LoRAs.';
100
+ }
101
+ if (loading) return 'Loading LoRAs…';
102
+ if (!compatible.length) {
103
+ return `No LoRAs trained against ${selectedModel} yet. Train one in the Training tab.`;
104
+ }
105
+ return null;
106
+ })();
107
+
108
+ return (
109
+ <Accordion
110
+ disableGutters
111
+ defaultExpanded={Boolean(value && value.some((s) => s.path))}
112
+ >
113
+ <AccordionSummary expandIcon={<ChevronDownIcon size={18} />}>
114
+ {/* Hover the title to surface the help in the Info View pill
115
+ (when it's on) β€” no inline "i", matching the rest of the app. */}
116
+ <Tooltip title={TIPS.lora.stackInfo(MAX_SLOTS)}>
117
+ <Typography variant="subtitle1">LoRA Stack</Typography>
118
+ </Tooltip>
119
+ </AccordionSummary>
120
+ <AccordionDetails>
121
+ {error && <Alert severity="error" sx={{ mb: 1 }}>{error}</Alert>}
122
+ {hint && (
123
+ <Typography variant="caption" color="text.secondary" sx={{ display: 'block', mb: 1 }}>
124
+ {hint}
125
+ </Typography>
126
+ )}
127
+
128
+ {slots.length > 0 && (
129
+ <Box sx={{ border: '1px solid', borderColor: 'divider', borderRadius: 1 }}>
130
+ {slots.map((slot, idx) => {
131
+ const choice = available.find(l => l.path === slot.path);
132
+ const bypassed = !!slot.bypassed;
133
+ return (
134
+ <Box
135
+ key={idx}
136
+ onDragOver={(e) => { if (dragIndex !== null) e.preventDefault(); }}
137
+ onDrop={() => onDrop(idx)}
138
+ sx={{
139
+ p: 1.5,
140
+ borderBottom: '1px solid',
141
+ borderColor: 'divider',
142
+ '&:last-child': { borderBottom: 'none' },
143
+ bgcolor: dragIndex === idx ? 'action.hover' : 'transparent',
144
+ opacity: bypassed ? 0.5 : 1,
145
+ }}
146
+ >
147
+ <Stack direction="row" alignItems="center" spacing={1}>
148
+ <Tooltip title={TIPS.lora.dragReorder}>
149
+ <Box
150
+ draggable={slots.length > 1}
151
+ onDragStart={() => setDragIndex(idx)}
152
+ onDragEnd={() => setDragIndex(null)}
153
+ sx={{
154
+ display: 'flex',
155
+ cursor: slots.length > 1 ? 'grab' : 'default',
156
+ color: 'text.disabled',
157
+ }}
158
+ >
159
+ <DragIcon size={16} />
160
+ </Box>
161
+ </Tooltip>
162
+ <Typography variant="caption" color="text.disabled" sx={{ width: 14 }}>
163
+ {idx}
164
+ </Typography>
165
+ <Select
166
+ size="small"
167
+ value={slot.path}
168
+ displayEmpty
169
+ onChange={(e) => setSlot(idx, { path: String(e.target.value) })}
170
+ sx={{ flex: 1, minWidth: 0 }}
171
+ >
172
+ <MenuItem value="" disabled>
173
+ <em>Pick a LoRA</em>
174
+ </MenuItem>
175
+ {compatible.map(l => (
176
+ <MenuItem key={l.id} value={l.path}>
177
+ <Box>
178
+ <Typography variant="body2">
179
+ {l.name} Β· {l.checkpoint}
180
+ </Typography>
181
+ <Stack direction="row" spacing={0.5} sx={{ mt: 0.25 }}>
182
+ <Chip size="small" label={l.adapter_type || 'lora'} sx={{ height: 16, fontSize: 9 }} />
183
+ {l.rank && <Chip size="small" label={`r=${l.rank}`} sx={{ height: 16, fontSize: 9 }} />}
184
+ </Stack>
185
+ </Box>
186
+ </MenuItem>
187
+ ))}
188
+ </Select>
189
+ <Tooltip title={TIPS.lora.bypass(bypassed)}>
190
+ <IconButton
191
+ size="small"
192
+ color={bypassed ? 'default' : 'primary'}
193
+ onClick={() => setSlot(idx, { bypassed: !bypassed })}
194
+ >
195
+ <BypassIcon size={14} />
196
+ </IconButton>
197
+ </Tooltip>
198
+ <IconButton size="small" onClick={() => removeSlot(idx)} aria-label="Remove slot">
199
+ <RemoveIcon size={14} />
200
+ </IconButton>
201
+ </Stack>
202
+
203
+ <Stack direction="row" alignItems="center" spacing={1.5} sx={{ mt: 1, mb: 2 }}>
204
+ <Typography variant="caption" color="text.secondary" sx={{ width: 60 }}>
205
+ Strength
206
+ </Typography>
207
+ <Slider
208
+ size="small"
209
+ value={slot.strength}
210
+ disabled={bypassed}
211
+ onChange={(e, v) => setSlot(idx, { strength: v })}
212
+ min={-2}
213
+ max={2}
214
+ step={0.05}
215
+ valueLabelDisplay="auto"
216
+ marks={[
217
+ { value: 0, label: '0' },
218
+ { value: 1, label: '1' },
219
+ ]}
220
+ sx={{ flex: 1 }}
221
+ />
222
+ <Typography variant="body2" sx={{ width: 40, textAlign: 'right' }}>
223
+ {bypassed ? 'β€”' : slot.strength.toFixed(2)}
224
+ </Typography>
225
+ </Stack>
226
+
227
+ {choice && choice.base_model && (
228
+ <Typography variant="caption" color="text.secondary" sx={{ display: 'block', mt: 0.25 }}>
229
+ Trained on {choice.base_model}
230
+ </Typography>
231
+ )}
232
+ </Box>
233
+ );
234
+ })}
235
+ </Box>
236
+ )}
237
+
238
+ <Stack direction="row" sx={{ mt: 1 }}>
239
+ <Button
240
+ size="small"
241
+ variant="outlined"
242
+ startIcon={<AddIcon size={14} />}
243
+ disabled={slots.length >= MAX_SLOTS || !compatible.length}
244
+ onClick={addSlot}
245
+ >
246
+ Add LoRA
247
+ </Button>
248
+ </Stack>
249
+ </AccordionDetails>
250
+ </Accordion>
251
+ );
252
+ }
app/frontend/src/components/LossChart.js CHANGED
@@ -1,19 +1,35 @@
1
  import React, { useState } from 'react';
2
  import { lossChartStyles } from '../theme';
3
 
4
- // Exponential moving average. alpha controls smoothness:
5
- // alpha β†’ 1 = no smoothing (output equals input)
6
- // alpha β†’ 0 = heavy smoothing (output flat-ish line)
7
- // Diffusion loss is intrinsically noisy because each step samples a random
8
- // timestep with different difficulty, so a small alpha (heavy smoothing) is
9
- // what makes the underlying trend visible.
10
- const EMA_ALPHA = 0.06;
 
 
 
 
 
 
 
 
 
 
11
 
12
- function smoothEMA(values, alpha = EMA_ALPHA) {
13
  if (values.length === 0) return [];
14
- const out = [values[0]];
15
- for (let i = 1; i < values.length; i++) {
16
- out.push(alpha * values[i] + (1 - alpha) * out[i - 1]);
 
 
 
 
 
 
17
  }
18
  return out;
19
  }
 
1
  import React, { useState } from 'react';
2
  import { lossChartStyles } from '../theme';
3
 
4
+ // Bias-corrected exponential moving average β€” same math as the EMA used in
5
+ // TensorBoard's loss curves and Adam's bias-corrected moments. Standard EMA
6
+ // (out[0] = values[0]) makes the smoothed line lag the data for the first
7
+ // 1/alpha steps; diffusion loss spikes high on step 0 (random init), so a
8
+ // naive EMA spends ~17 steps "catching down". The 1/(1-(1-Ξ±)^(i+1)) factor
9
+ // cancels that startup bias: by construction out[0] equals values[0], and
10
+ // out[i] converges to the plain EMA at steady state.
11
+ //
12
+ // alpha is *adaptive* to the run length: more data β†’ more smoothing is OK
13
+ // because the underlying trend has more support; short runs need a tighter
14
+ // window so the smoothed line still resembles the data.
15
+ function pickAlpha(n) {
16
+ if (n < 50) return 0.25;
17
+ if (n < 200) return 0.15;
18
+ if (n < 1000) return 0.08;
19
+ return 0.05;
20
+ }
21
 
22
+ function smoothEMA(values, alpha) {
23
  if (values.length === 0) return [];
24
+ const a = alpha ?? pickAlpha(values.length);
25
+ const w = 1 - a;
26
+ const out = [];
27
+ let ema = 0;
28
+ for (let i = 0; i < values.length; i++) {
29
+ ema = w * ema + a * values[i];
30
+ const correction = 1 - Math.pow(w, i + 1);
31
+ // correction β†’ a at i=0 (so out[0] = values[0]) and β†’ 1 at large i.
32
+ out.push(ema / Math.max(correction, 1e-9));
33
  }
34
  return out;
35
  }
app/frontend/src/components/MidiConfigMenu.js CHANGED
@@ -8,7 +8,6 @@ import {
8
  MenuItem,
9
  Button,
10
  IconButton,
11
- Tooltip,
12
  Divider,
13
  ToggleButton,
14
  ToggleButtonGroup,
@@ -16,7 +15,7 @@ import {
16
  } from '@mui/material';
17
  import { Trash2 as DeleteIcon, X as CloseIcon } from 'lucide-react';
18
  import { useMidi, formatMidi } from './MidiContext';
19
- import { perfTokens } from '../theme';
20
 
21
  const CHANNEL_OPTIONS = [
22
  { value: 0, label: 'Any' },
@@ -46,39 +45,60 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
46
  anchorEl={anchorEl}
47
  open={open}
48
  onClose={onClose}
49
- anchorOrigin={{ vertical: 'bottom', horizontal: 'right' }}
50
- transformOrigin={{ vertical: 'top', horizontal: 'right' }}
51
  slotProps={{
52
  paper: {
53
  sx: {
54
- width: 380,
55
  maxHeight: '70vh',
56
- p: 2,
57
  borderRadius: 2,
58
  border: '1px solid',
59
  borderColor: 'divider',
 
60
  },
61
  },
62
  }}
63
  >
64
- <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', mb: 1.5 }}>
65
- <Typography variant="subtitle2" sx={{ letterSpacing: '0.08em', textTransform: 'uppercase', color: 'text.secondary' }}>
 
 
 
 
 
 
 
 
66
  MIDI Settings
67
  </Typography>
68
- <IconButton size="small" onClick={onClose}>
69
- <CloseIcon size={14} />
70
  </IconButton>
71
  </Box>
72
 
 
 
73
  {!supported && (
74
- <Alert severity="warning" sx={{ mb: 1.5 }}>
75
- {permissionError || 'Web MIDI is not available in this browser. Try Chrome / Edge / Electron.'}
76
- </Alert>
 
 
77
  )}
78
 
79
- <Box sx={{ display: 'flex', flexDirection: 'column', gap: 1.5 }}>
 
 
 
 
 
 
 
 
80
  <Box>
81
- <Typography variant="caption" sx={{ color: 'text.secondary', display: 'block', mb: 0.5 }}>
82
  Input device
83
  </Typography>
84
  <FormControl size="small" fullWidth>
@@ -92,19 +112,26 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
92
  const found = inputs.find(i => i.id === value);
93
  return found ? found.name : 'Disconnected';
94
  }}
 
95
  >
96
- <MenuItem value="">
97
  <em>None</em>
98
  </MenuItem>
99
  {inputs.map((input) => (
100
- <MenuItem key={input.id} value={input.id}>
101
  {input.name}
102
  </MenuItem>
103
  ))}
104
  </Select>
105
  </FormControl>
106
  {config.deviceName && !inputs.some(i => i.name === config.deviceName) && (
107
- <Typography variant="caption" sx={{ color: 'warning.main', display: 'block', mt: 0.5 }}>
 
 
 
 
 
 
108
  Saved device "{config.deviceName}" not connected
109
  </Typography>
110
  )}
@@ -112,7 +139,7 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
112
 
113
  <Box sx={{ display: 'flex', gap: 1 }}>
114
  <Box sx={{ flex: 1 }}>
115
- <Typography variant="caption" sx={{ color: 'text.secondary', display: 'block', mb: 0.5 }}>
116
  Channel filter
117
  </Typography>
118
  <FormControl size="small" fullWidth>
@@ -120,16 +147,17 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
120
  value={config.channelFilter}
121
  onChange={(e) => setChannelFilter(Number(e.target.value))}
122
  disabled={!supported}
 
123
  >
124
  {CHANNEL_OPTIONS.map(opt => (
125
- <MenuItem key={opt.value} value={opt.value}>{opt.label}</MenuItem>
126
  ))}
127
  </Select>
128
  </FormControl>
129
  </Box>
130
 
131
  <Box sx={{ flex: 1 }}>
132
- <Typography variant="caption" sx={{ color: 'text.secondary', display: 'block', mb: 0.5 }}>
133
  Takeover
134
  </Typography>
135
  <ToggleButtonGroup
@@ -138,25 +166,42 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
138
  exclusive
139
  onChange={(_, v) => { if (v) setTakeover(v); }}
140
  fullWidth
141
- sx={{ height: 40 }}
142
  >
143
- <ToggleButton value="jump" sx={{ fontSize: perfTokens.fontSize.body }}>Jump</ToggleButton>
144
- <ToggleButton value="pickup" sx={{ fontSize: perfTokens.fontSize.body }}>Pickup</ToggleButton>
145
  </ToggleButtonGroup>
146
  </Box>
147
  </Box>
 
148
 
149
- <Divider sx={{ my: 0.5 }} />
150
 
151
- <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
152
- <Typography variant="caption" sx={{ color: 'text.secondary', letterSpacing: '0.08em', textTransform: 'uppercase' }}>
 
 
 
 
 
 
 
153
  Mappings ({config.mappings.length})
154
  </Typography>
155
  <Button
156
  size="small"
157
  onClick={clearAll}
158
  disabled={config.mappings.length === 0}
159
- sx={{ fontSize: perfTokens.fontSize.small, textTransform: 'none' }}
 
 
 
 
 
 
 
 
 
160
  >
161
  Clear all
162
  </Button>
@@ -167,15 +212,21 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
167
  border: '1px solid',
168
  borderColor: 'divider',
169
  borderRadius: 1,
170
- maxHeight: 280,
171
  overflowY: 'auto',
172
  bgcolor: 'background.default',
173
  }}
174
  >
175
  {sortedMappings.length === 0 ? (
176
- <Box sx={{ p: 2, textAlign: 'center' }}>
177
- <Typography variant="caption" sx={{ color: 'text.disabled', fontStyle: 'italic' }}>
178
- No mappings yet. Enable MIDI mode (the MIDI button), click a control, then move a hardware knob, fader, or button.
 
 
 
 
 
 
179
  </Typography>
180
  </Box>
181
  ) : (
@@ -191,33 +242,54 @@ export default function MidiConfigMenu({ anchorEl, open, onClose }) {
191
  borderBottom: '1px solid',
192
  borderColor: 'divider',
193
  '&:last-child': { borderBottom: 'none' },
 
 
194
  }}
195
  >
196
  <Box sx={{ flex: 1, minWidth: 0 }}>
197
- <Typography variant="body2" sx={{ fontSize: perfTokens.fontSize.body, overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}>
 
 
 
 
 
 
198
  {m.label}
199
  </Typography>
200
- <Typography variant="caption" sx={{ color: 'text.secondary', fontSize: perfTokens.fontSize.small, fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Consolas, monospace' }}>
 
 
 
 
201
  {formatMidi(m.midi)}
202
  </Typography>
203
  </Box>
204
- <Tooltip title="Remove mapping">
205
- <IconButton
206
- size="small"
207
- onClick={() => clearMapping(m.controlId)}
208
- sx={{ color: 'text.disabled', '&:hover': { color: 'error.main' } }}
209
- >
210
- <DeleteIcon size={13} />
211
- </IconButton>
212
- </Tooltip>
213
  </Box>
214
  ))
215
  )}
216
  </Box>
 
 
 
217
 
218
- <Typography variant="caption" sx={{ color: 'text.disabled', fontSize: perfTokens.fontSize.small, lineHeight: 1.4 }}>
219
- Pickup = ignore the hardware until its position matches the on-screen value (no jumps).
220
- Right-click a control while in MIDI mode to clear its mapping.
 
 
 
 
 
 
 
221
  </Typography>
222
  </Box>
223
  </Popover>
 
8
  MenuItem,
9
  Button,
10
  IconButton,
 
11
  Divider,
12
  ToggleButton,
13
  ToggleButtonGroup,
 
15
  } from '@mui/material';
16
  import { Trash2 as DeleteIcon, X as CloseIcon } from 'lucide-react';
17
  import { useMidi, formatMidi } from './MidiContext';
18
+ import { perfTokens, performancePanelStyles as panelStyles } from '../theme';
19
 
20
  const CHANNEL_OPTIONS = [
21
  { value: 0, label: 'Any' },
 
45
  anchorEl={anchorEl}
46
  open={open}
47
  onClose={onClose}
48
+ anchorOrigin={{ vertical: 'bottom', horizontal: 'left' }}
49
+ transformOrigin={{ vertical: 'top', horizontal: 'left' }}
50
  slotProps={{
51
  paper: {
52
  sx: {
53
+ width: 360,
54
  maxHeight: '70vh',
55
+ p: 0,
56
  borderRadius: 2,
57
  border: '1px solid',
58
  borderColor: 'divider',
59
+ overflow: 'hidden',
60
  },
61
  },
62
  }}
63
  >
64
+ {/* Title bar β€” same pattern as Presets / Audio menus. */}
65
+ <Box sx={{
66
+ display: 'flex',
67
+ alignItems: 'center',
68
+ justifyContent: 'space-between',
69
+ px: 1.5,
70
+ pt: 1.25,
71
+ pb: 1,
72
+ }}>
73
+ <Typography sx={{ ...perfTokens.caps, color: 'text.secondary' }}>
74
  MIDI Settings
75
  </Typography>
76
+ <IconButton onClick={onClose} sx={panelStyles.compactIconBtn('md')}>
77
+ <CloseIcon size={perfTokens.icon.sm} />
78
  </IconButton>
79
  </Box>
80
 
81
+ <Divider />
82
+
83
  {!supported && (
84
+ <Box sx={{ px: 1.5, pt: 1.25 }}>
85
+ <Alert severity="warning" sx={{ py: 0.5 }}>
86
+ {permissionError || 'Web MIDI is not available in this browser. Try Chrome / Edge / Electron.'}
87
+ </Alert>
88
+ </Box>
89
  )}
90
 
91
+ {/* SETTINGS β€” input device + channel filter + takeover. */}
92
+ <Box sx={{
93
+ px: 1.5,
94
+ pt: 1.25,
95
+ pb: 1.25,
96
+ display: 'flex',
97
+ flexDirection: 'column',
98
+ gap: 1.25,
99
+ }}>
100
  <Box>
101
+ <Typography sx={{ ...perfTokens.labelMuted, display: 'block', mb: 0.5 }}>
102
  Input device
103
  </Typography>
104
  <FormControl size="small" fullWidth>
 
112
  const found = inputs.find(i => i.id === value);
113
  return found ? found.name : 'Disconnected';
114
  }}
115
+ sx={{ fontSize: perfTokens.fontSize.sm }}
116
  >
117
+ <MenuItem value="" sx={{ fontSize: perfTokens.fontSize.sm }}>
118
  <em>None</em>
119
  </MenuItem>
120
  {inputs.map((input) => (
121
+ <MenuItem key={input.id} value={input.id} sx={{ fontSize: perfTokens.fontSize.sm }}>
122
  {input.name}
123
  </MenuItem>
124
  ))}
125
  </Select>
126
  </FormControl>
127
  {config.deviceName && !inputs.some(i => i.name === config.deviceName) && (
128
+ <Typography sx={{
129
+ fontSize: perfTokens.fontSize.xs,
130
+ color: 'warning.main',
131
+ fontStyle: 'italic',
132
+ display: 'block',
133
+ mt: 0.5,
134
+ }}>
135
  Saved device "{config.deviceName}" not connected
136
  </Typography>
137
  )}
 
139
 
140
  <Box sx={{ display: 'flex', gap: 1 }}>
141
  <Box sx={{ flex: 1 }}>
142
+ <Typography sx={{ ...perfTokens.labelMuted, display: 'block', mb: 0.5 }}>
143
  Channel filter
144
  </Typography>
145
  <FormControl size="small" fullWidth>
 
147
  value={config.channelFilter}
148
  onChange={(e) => setChannelFilter(Number(e.target.value))}
149
  disabled={!supported}
150
+ sx={{ fontSize: perfTokens.fontSize.sm }}
151
  >
152
  {CHANNEL_OPTIONS.map(opt => (
153
+ <MenuItem key={opt.value} value={opt.value} sx={{ fontSize: perfTokens.fontSize.sm }}>{opt.label}</MenuItem>
154
  ))}
155
  </Select>
156
  </FormControl>
157
  </Box>
158
 
159
  <Box sx={{ flex: 1 }}>
160
+ <Typography sx={{ ...perfTokens.labelMuted, display: 'block', mb: 0.5 }}>
161
  Takeover
162
  </Typography>
163
  <ToggleButtonGroup
 
166
  exclusive
167
  onChange={(_, v) => { if (v) setTakeover(v); }}
168
  fullWidth
169
+ sx={{ height: perfTokens.height.compact }}
170
  >
171
+ <ToggleButton value="jump" sx={{ fontSize: perfTokens.fontSize.sm, textTransform: 'none' }}>Jump</ToggleButton>
172
+ <ToggleButton value="pickup" sx={{ fontSize: perfTokens.fontSize.sm, textTransform: 'none' }}>Pickup</ToggleButton>
173
  </ToggleButtonGroup>
174
  </Box>
175
  </Box>
176
+ </Box>
177
 
178
+ <Divider />
179
 
180
+ {/* MAPPINGS β€” header row + bordered scrollable list. */}
181
+ <Box sx={{ px: 1.5, pt: 1.25, pb: 1.25 }}>
182
+ <Box sx={{
183
+ display: 'flex',
184
+ alignItems: 'center',
185
+ justifyContent: 'space-between',
186
+ mb: 0.75,
187
+ }}>
188
+ <Typography sx={{ ...perfTokens.labelMuted, display: 'block' }}>
189
  Mappings ({config.mappings.length})
190
  </Typography>
191
  <Button
192
  size="small"
193
  onClick={clearAll}
194
  disabled={config.mappings.length === 0}
195
+ sx={{
196
+ fontSize: perfTokens.fontSize.xs,
197
+ color: 'error.main',
198
+ textTransform: 'none',
199
+ py: 0,
200
+ px: 0.75,
201
+ minWidth: 0,
202
+ '&:hover': { bgcolor: 'action.hover' },
203
+ '&.Mui-disabled': { color: 'text.disabled' },
204
+ }}
205
  >
206
  Clear all
207
  </Button>
 
212
  border: '1px solid',
213
  borderColor: 'divider',
214
  borderRadius: 1,
215
+ maxHeight: 240,
216
  overflowY: 'auto',
217
  bgcolor: 'background.default',
218
  }}
219
  >
220
  {sortedMappings.length === 0 ? (
221
+ <Box sx={{ px: 1.5, py: 1.5, textAlign: 'center' }}>
222
+ <Typography sx={{
223
+ fontSize: perfTokens.fontSize.xs,
224
+ color: 'text.disabled',
225
+ fontStyle: 'italic',
226
+ lineHeight: 1.4,
227
+ }}>
228
+ No mappings yet. Enable MIDI mode, click a control,
229
+ then move a hardware knob, fader, or button.
230
  </Typography>
231
  </Box>
232
  ) : (
 
242
  borderBottom: '1px solid',
243
  borderColor: 'divider',
244
  '&:last-child': { borderBottom: 'none' },
245
+ '&:hover': { bgcolor: 'action.hover' },
246
+ transition: 'background-color 120ms',
247
  }}
248
  >
249
  <Box sx={{ flex: 1, minWidth: 0 }}>
250
+ <Typography sx={{
251
+ fontSize: perfTokens.fontSize.sm,
252
+ fontWeight: 500,
253
+ overflow: 'hidden',
254
+ textOverflow: 'ellipsis',
255
+ whiteSpace: 'nowrap',
256
+ }}>
257
  {m.label}
258
  </Typography>
259
+ <Typography sx={{
260
+ color: 'text.secondary',
261
+ fontSize: perfTokens.fontSize.xs,
262
+ fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Consolas, monospace',
263
+ }}>
264
  {formatMidi(m.midi)}
265
  </Typography>
266
  </Box>
267
+ <IconButton
268
+ size="small"
269
+ onClick={() => clearMapping(m.controlId)}
270
+ sx={panelStyles.compactIconBtn('sm', 'danger')}
271
+ aria-label="Remove mapping"
272
+ >
273
+ <DeleteIcon size={perfTokens.icon.sm} />
274
+ </IconButton>
 
275
  </Box>
276
  ))
277
  )}
278
  </Box>
279
+ </Box>
280
+
281
+ <Divider />
282
 
283
+ {/* Footer help text. */}
284
+ <Box sx={{ px: 1.5, pt: 1, pb: 1.25 }}>
285
+ <Typography sx={{
286
+ color: 'text.disabled',
287
+ fontSize: perfTokens.fontSize.xs,
288
+ fontStyle: 'italic',
289
+ lineHeight: 1.4,
290
+ }}>
291
+ Pickup ignores the hardware until its position matches the on-screen
292
+ value. Right-click a control while in MIDI mode to clear its mapping.
293
  </Typography>
294
  </Box>
295
  </Popover>
app/frontend/src/components/MidiContext.js CHANGED
@@ -8,6 +8,7 @@ import React, {
8
  useState,
9
  } from 'react';
10
  import { Box } from '@mui/material';
 
11
 
12
  const STORAGE_KEY = 'fragmenta.midi.config.v1';
13
 
@@ -73,7 +74,6 @@ export function MidiProvider({ children }) {
73
  const [learnMode, setLearnMode] = useState(false);
74
  const [learnTarget, setLearnTarget] = useState(null);
75
 
76
- const accessRef = useRef(null);
77
  const subscribersRef = useRef(new Map());
78
  const pickupArmedRef = useRef(new Map());
79
  const configRef = useRef(config);
@@ -86,39 +86,23 @@ export function MidiProvider({ children }) {
86
 
87
  useEffect(() => { learnTargetRef.current = learnTarget; }, [learnTarget]);
88
 
89
- const refreshInputs = useCallback(() => {
90
- const access = accessRef.current;
91
- if (!access) return;
92
- const list = [];
93
- access.inputs.forEach((input) => {
94
- list.push({
95
- id: input.id,
96
- name: input.name || 'Unknown device',
97
- manufacturer: input.manufacturer || '',
98
- });
99
- });
100
- setInputs(list);
101
- }, []);
102
-
103
- useEffect(() => {
104
- if (typeof navigator === 'undefined' || !navigator.requestMIDIAccess) {
105
  setSupported(false);
106
- return undefined;
107
  }
108
- let cancelled = false;
109
- navigator.requestMIDIAccess({ sysex: false })
110
- .then((access) => {
111
- if (cancelled) return;
112
- accessRef.current = access;
113
- refreshInputs();
114
- access.onstatechange = refreshInputs;
115
- })
116
- .catch((err) => {
117
- setPermissionError(err?.message || 'MIDI permission denied');
118
- setSupported(false);
119
- });
120
- return () => { cancelled = true; };
121
- }, [refreshInputs]);
122
 
123
  useEffect(() => {
124
  if (!inputs.length || !config.deviceName) return;
@@ -192,24 +176,30 @@ export function MidiProvider({ children }) {
192
  }
193
  }, [captureLearn]);
194
 
 
 
 
195
  useEffect(() => {
196
- const access = accessRef.current;
197
- if (!access) return undefined;
198
- const bound = [];
199
- access.inputs.forEach((input) => {
200
- if (config.deviceId && input.id === config.deviceId) {
201
- input.onmidimessage = dispatchMessage;
202
- bound.push(input);
203
- } else {
204
- input.onmidimessage = null;
205
- }
206
- });
207
-
208
- pickupArmedRef.current = new Map();
209
- return () => {
210
- bound.forEach((i) => { i.onmidimessage = null; });
211
  };
212
- }, [config.deviceId, inputs, dispatchMessage]);
 
 
 
 
 
 
 
 
 
 
213
 
214
  function applyContinuous(sub, mapping, midiValue, takeover) {
215
  const norm = midiValue / 127;
 
8
  useState,
9
  } from 'react';
10
  import { Box } from '@mui/material';
11
+ import api from '../api';
12
 
13
  const STORAGE_KEY = 'fragmenta.midi.config.v1';
14
 
 
74
  const [learnMode, setLearnMode] = useState(false);
75
  const [learnTarget, setLearnTarget] = useState(null);
76
 
 
77
  const subscribersRef = useRef(new Map());
78
  const pickupArmedRef = useRef(new Map());
79
  const configRef = useRef(config);
 
86
 
87
  useEffect(() => { learnTargetRef.current = learnTarget; }, [learnTarget]);
88
 
89
+ // Device list comes from the native backend (python-rtmidi) instead of
90
+ // Web MIDI, so it works in every web engine.
91
+ const refreshInputs = useCallback(async () => {
92
+ try {
93
+ const { data } = await api.get('/api/midi/devices');
94
+ setSupported(!!data.available);
95
+ setInputs(Array.isArray(data.inputs) ? data.inputs : []);
96
+ setPermissionError(data.available
97
+ ? null
98
+ : 'Native MIDI is unavailable (python-rtmidi not installed).');
99
+ } catch (err) {
 
 
 
 
 
100
  setSupported(false);
101
+ setPermissionError(err?.message || 'Could not reach the MIDI backend.');
102
  }
103
+ }, []);
104
+
105
+ useEffect(() => { refreshInputs(); }, [refreshInputs]);
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  useEffect(() => {
108
  if (!inputs.length || !config.deviceName) return;
 
176
  }
177
  }, [captureLearn]);
178
 
179
+ // Stream incoming MIDI from the backend over SSE. Each event is the same
180
+ // {data:[status,d1,d2]} shape Web MIDI gave us, so dispatchMessage is
181
+ // unchanged. EventSource auto-reconnects on drop.
182
  useEffect(() => {
183
+ if (typeof EventSource === 'undefined') {
184
+ setSupported(false);
185
+ return undefined;
186
+ }
187
+ const es = new EventSource('/api/midi/stream');
188
+ es.onmessage = (e) => {
189
+ try { dispatchMessage(JSON.parse(e.data)); }
190
+ catch { /* malformed line β€” ignore */ }
 
 
 
 
 
 
 
191
  };
192
+ pickupArmedRef.current = new Map();
193
+ return () => es.close();
194
+ }, [dispatchMessage]);
195
+
196
+ // Tell the backend which port to open. The stream only carries the open
197
+ // port's events, so device selection happens server-side.
198
+ useEffect(() => {
199
+ if (!supported) return;
200
+ api.post('/api/midi/select', { port_id: config.deviceId || null })
201
+ .catch(() => { /* non-fatal */ });
202
+ }, [config.deviceId, supported]);
203
 
204
  function applyContinuous(sub, mapping, midiValue, takeover) {
205
  const norm = midiValue / 127;
app/frontend/src/components/PerformanceChannel.js CHANGED
@@ -5,27 +5,37 @@ import {
5
  TextField,
6
  IconButton,
7
  Slider,
8
- CircularProgress,
9
- Tooltip,
10
  Select,
11
  MenuItem,
12
  ButtonBase,
13
  } from '@mui/material';
 
 
14
  import {
15
  Play as PlayIcon,
16
  Square as StopIcon,
17
- Repeat as LoopIcon,
18
- Sparkles as GenerateIcon,
19
  Volume2 as VolumeIcon,
20
  VolumeX as MuteIcon,
21
- Headphones as CueIcon,
22
- Check as CommitIcon,
23
  } from 'lucide-react';
24
- import { performanceChannelStyles as styles, perfTokens } from '../theme';
25
  import { MidiMappable } from './MidiContext';
26
  import { playBlob as playCueBlob, stopCue, isCueSupported } from '../utils/cueAudio';
 
 
 
 
 
 
 
 
 
27
 
28
  const CHANNEL_COLORS = [
 
 
 
29
  '#35C2D4', '#9F8AE6', '#53C18A', '#E3A34B',
30
  '#E36C61', '#F08AD2', '#5BA0F0', '#A8D86B',
31
  ];
@@ -40,19 +50,31 @@ const gainDbToLinear = (db) => (db <= GAIN_DB_MIN ? 0 : Math.pow(10, db / 20));
40
 
41
  const KNOB_DEFS = [
42
  { key: 'gain', label: 'GAIN', min: GAIN_DB_MIN, max: GAIN_DB_MAX, step: 0.5, default: GAIN_DB_DEFAULT },
43
- // LPF range goes from 20 Hz (full kill) to 20 kHz (bypass). We render the
44
- // slider on a log axis so each octave gets equal travel β€” without this
45
- // the bottom 5% of the knob does all the audible work.
46
- { key: 'filter', label: 'LPF', min: 20, max: 20000, step: 1, default: 20000, scale: 'log' },
 
47
  { key: 'delay', label: 'DLY', min: 0, max: 1.0, step: 0.01, default: 0.0 },
48
  { key: 'reverb', label: 'REV', min: 0, max: 1.0, step: 0.01, default: 0.0 },
49
  ];
50
 
 
 
 
 
 
 
 
 
 
51
  const PAN_CENTER_SNAP = 0.06;
52
 
53
  const BARS_OPTIONS = [1, 2, 4, 8, 16];
54
  const BEATS_PER_BAR = 4;
55
  const BATCH_OPTIONS = [1, 2, 3, 4];
 
 
56
 
57
  export default function PerformanceChannel({
58
  index,
@@ -65,13 +87,16 @@ export default function PerformanceChannel({
65
  onStateChange,
66
  onFormStateChange,
67
  initialFormState,
68
- maxDuration = 47,
69
  bpm = 120,
70
  }) {
71
  const color = CHANNEL_COLORS[index % CHANNEL_COLORS.length];
72
  const canvasRef = useRef(null);
73
  const meterRef = useRef(null);
74
  const meterRafRef = useRef(null);
 
 
 
75
 
76
  const init = initialFormState || {};
77
  const initKnobs = init.knobs || {};
@@ -83,7 +108,7 @@ export default function PerformanceChannel({
83
 
84
  const [prompt, setPrompt] = useState(init.prompt ?? '');
85
  const [duration, setDuration] = useState(init.duration ?? 8);
86
- const [durationMode, setDurationMode] = useState(init.durationMode ?? 'seconds');
87
  const [bars, setBars] = useState(init.bars ?? 4);
88
  const [generating, setGenerating] = useState(false);
89
  const [loaded, setLoaded] = useState(false);
@@ -91,31 +116,147 @@ export default function PerformanceChannel({
91
  const [muted, setMuted] = useState(init.muted ?? false);
92
  const [soloed, setSoloed] = useState(init.soloed ?? false);
93
  const [batchSize, setBatchSize] = useState(init.batchSize ?? 1);
94
- const [knobs, setKnobs] = useState(() => ({ ...defaultKnobs, ...initKnobs }));
95
-
96
- // Candidates from the latest batch generation. Held in component state
97
- // because they don't survive a page reload β€” the blob URLs would be dead.
98
- // `committedIndex` tracks which one is currently loaded into the strip.
99
- const [candidates, setCandidates] = useState([]);
100
- const [auditioningIndex, setAuditioningIndex] = useState(null);
101
- const [committedIndex, setCommittedIndex] = useState(null);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  const cueSupported = useMemo(() => isCueSupported(), []);
103
 
104
  // Stop any active cue audition when the channel unmounts.
105
  useEffect(() => () => stopCue(), []);
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  // Mirror form state up to the panel so it can persist the session. Skip the
108
  // first render so we don't re-write what we just loaded from localStorage.
 
 
109
  const initialReportSkippedRef = useRef(false);
110
  useEffect(() => {
111
  if (!initialReportSkippedRef.current) {
112
  initialReportSkippedRef.current = true;
113
  return;
114
  }
 
115
  onFormStateChange?.(index, {
116
  prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs,
 
 
117
  });
118
- }, [prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs, index, onFormStateChange]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  const secondsFromBars = useMemo(
121
  () => bars * (60 / Math.max(bpm, 1)) * BEATS_PER_BAR,
@@ -145,12 +286,31 @@ export default function PerformanceChannel({
145
 
146
  const drawWave = useCallback(() => {
147
  if (strip && canvasRef.current) {
 
148
  strip.drawWaveform(canvasRef.current, color);
149
  }
150
  }, [strip, color]);
151
 
152
  useEffect(() => { drawWave(); }, [drawWave, loaded]);
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  // One-shot: push restored knob/loop values into the audio strip when it
155
  // first becomes available, so the persisted session matches what's heard.
156
  // Mute/solo applies through the parent's mix handler so the panel can
@@ -161,7 +321,11 @@ export default function PerformanceChannel({
161
  stripStateAppliedRef.current = true;
162
  strip.setUserGain(gainDbToLinear(knobs.gain));
163
  strip.setPan(knobs.pan);
164
- strip.setFilter(knobs.filter);
 
 
 
 
165
  strip.setDelayMix(knobs.delay);
166
  strip.setReverbMix(knobs.reverb);
167
  strip.setLoop(looping);
@@ -183,34 +347,89 @@ export default function PerformanceChannel({
183
  }
184
  }, [availableBars, bars]);
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  const handleGenerate = async () => {
187
  if (!prompt.trim() || generating) return;
188
  const inBarsMode = durationMode === 'bars';
189
  const effectiveDuration = inBarsMode ? secondsFromBars : duration;
190
  setGenerating(true);
191
- // Stop any in-flight cue audition and clear stale candidate state so
192
- // the audition strip doesn't keep playing the old generation.
193
  stopCue();
194
- setAuditioningIndex(null);
 
 
 
195
  try {
196
- const result = await onGenerate({
197
  prompt,
198
  duration: effectiveDuration,
199
  batchSize,
200
  // Only forward alignment params in bars mode β€” seconds mode
201
  // generates raw audio with no post-processing.
202
  ...(inBarsMode ? { alignBars: bars, alignBpm: bpm } : {}),
 
 
 
 
203
  });
204
- const blobs = Array.isArray(result) ? result : [result];
205
- const next = blobs.map((b, i) => ({ index: i, blob: b }));
206
- setCandidates(next);
207
- // First candidate auto-loads into the channel strip; the rest sit
208
- // in the audition row until the user commits a different one.
209
- await strip.loadBlob(blobs[0]);
210
- setCommittedIndex(0);
211
- setLoaded(true);
212
- onStateChange?.(index, { loaded: true });
213
- requestAnimationFrame(drawWave);
214
  } catch (err) {
215
  console.error(`Channel ${index + 1} generate failed:`, err);
216
  } finally {
@@ -218,37 +437,166 @@ export default function PerformanceChannel({
218
  }
219
  };
220
 
221
- const handleAudition = async (i) => {
222
- const candidate = candidates[i];
223
- if (!candidate) return;
224
- if (auditioningIndex === i) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  stopCue();
226
- setAuditioningIndex(null);
227
  return;
228
  }
229
- setAuditioningIndex(i);
230
  try {
231
- await playCueBlob(candidate.blob, {
232
- onEnded: () => setAuditioningIndex(prev => (prev === i ? null : prev)),
233
  });
234
  } catch (err) {
235
  console.warn(`Channel ${index + 1} audition failed:`, err);
236
- setAuditioningIndex(null);
237
  }
238
  };
239
 
240
- const handleCommit = async (i) => {
241
- const candidate = candidates[i];
242
- if (!candidate || committedIndex === i) return;
243
- // Stop the live channel before swapping the buffer so we don't get a
244
- // glitch in the middle of a loop iteration.
245
- try { strip.stop(); } catch { /* not playing */ }
246
- onStateChange?.(index, { playing: false });
247
- await strip.loadBlob(candidate.blob);
248
- setCommittedIndex(i);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  requestAnimationFrame(drawWave);
250
  };
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  const handlePlay = () => {
253
  if (!loaded) return;
254
  if (engine) engine.playChannel(index, looping);
@@ -285,7 +633,11 @@ export default function PerformanceChannel({
285
  setKnobs(prev => ({ ...prev, [key]: value }));
286
  if (key === 'gain') strip.setUserGain(gainDbToLinear(value));
287
  else if (key === 'pan') strip.setPan(value);
288
- else if (key === 'filter') strip.setFilter(value);
 
 
 
 
289
  else if (key === 'delay') strip.setDelayMix(value);
290
  else if (key === 'reverb') strip.setReverbMix(value);
291
  };
@@ -307,15 +659,45 @@ export default function PerformanceChannel({
307
  return (
308
  <Box sx={styles.strip(color, playing)}>
309
  <Box sx={styles.stripHeader(color)}>
310
- <Box sx={styles.channelBadge(color)}>{String(index + 1).padStart(2, '0')}</Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  <Box sx={styles.muteSoloRow}>
312
  <MidiMappable id={ctrlId('mute')} label={ctrlLabel('Mute')} kind="trigger" onChange={handleMuteToggle}>
313
- <Tooltip title="Mute">
314
  <IconButton size="small" onClick={handleMuteToggle} sx={styles.muteBtn(muted)}>M</IconButton>
315
  </Tooltip>
316
  </MidiMappable>
317
  <MidiMappable id={ctrlId('solo')} label={ctrlLabel('Solo')} kind="trigger" onChange={handleSoloToggle}>
318
- <Tooltip title="Solo">
319
  <IconButton size="small" onClick={handleSoloToggle} sx={styles.soloBtn(soloed)}>S</IconButton>
320
  </Tooltip>
321
  </MidiMappable>
@@ -324,18 +706,50 @@ export default function PerformanceChannel({
324
 
325
  <Box sx={styles.promptBox}>
326
  <TextField
327
- placeholder="prompt…"
328
  value={prompt}
329
  onChange={(e) => setPrompt(e.target.value)}
330
  multiline
331
  minRows={2}
332
- maxRows={3}
333
  size="small"
334
  fullWidth
335
  sx={styles.promptField}
336
  disabled={generating}
337
  />
338
  <Box sx={{ ...styles.durationRow, minHeight: 26, height: 26 }}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  <Box
340
  sx={{
341
  display: 'inline-flex',
@@ -344,9 +758,13 @@ export default function PerformanceChannel({
344
  borderRadius: 0.75,
345
  overflow: 'hidden',
346
  height: '100%',
 
347
  }}
348
  >
349
- {['sec', 'bars'].map((mode) => {
 
 
 
350
  const value = mode === 'sec' ? 'seconds' : 'bars';
351
  const active = durationMode === value;
352
  return (
@@ -354,215 +772,170 @@ export default function PerformanceChannel({
354
  key={mode}
355
  onClick={() => setDurationMode(value)}
356
  sx={{
357
- fontSize: perfTokens.fontSize.small,
358
- letterSpacing: perfTokens.letterSpacing.wide,
359
- textTransform: 'uppercase',
360
- fontFamily: 'inherit',
361
  px: 0.7,
362
- minWidth: 30,
363
  bgcolor: active ? color : 'transparent',
 
 
364
  color: active ? 'rgba(0,0,0,0.88)' : 'text.disabled',
365
- fontWeight: active ? 600 : 400,
366
- transition: 'background-color 120ms, color 120ms',
367
  '&:hover': {
368
  bgcolor: active ? color : 'action.hover',
369
  color: active ? 'rgba(0,0,0,0.88)' : 'text.secondary',
370
  },
371
  }}
372
  >
373
- {mode}
374
  </ButtonBase>
375
  );
376
  })}
377
  </Box>
378
-
379
- {durationMode === 'seconds' ? (
380
- <>
381
- <Typography variant="caption" sx={styles.durationLabel}>{duration.toFixed(0)}s</Typography>
382
- <Slider
383
- value={duration}
384
- onChange={(_, v) => setDuration(v)}
385
- min={2}
386
- max={maxDuration}
387
- step={1}
388
- size="small"
389
- sx={styles.durationSlider(color)}
390
- />
391
- </>
392
- ) : (
393
- <Select
394
- value={availableBars.includes(bars) ? bars : availableBars[availableBars.length - 1]}
395
- onChange={(e) => setBars(Number(e.target.value))}
396
- size="small"
397
- sx={{
398
- flex: 1,
399
- fontSize: perfTokens.fontSize.body,
400
- height: '100%',
401
- '& .MuiOutlinedInput-input': {
402
- py: 0,
403
- pl: 1,
404
- minHeight: 'unset',
405
- },
406
- '& .MuiSelect-select': {
407
- py: 0,
408
- pl: 1,
409
- minHeight: 'unset',
410
- },
411
- }}
412
- >
413
- {availableBars.map(b => (
414
- <MenuItem key={b} value={b} sx={{ fontSize: perfTokens.fontSize.body }}>
415
- {b} {b === 1 ? 'bar' : 'bars'}
416
- </MenuItem>
417
- ))}
418
- </Select>
419
- )}
420
  </Box>
421
  <Box sx={{
422
  display: 'flex',
423
  alignItems: 'center',
424
- justifyContent: 'center',
425
- gap: 1.5,
426
  mt: 0.5,
427
  width: '100%',
428
  }}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  <Tooltip
430
- title="Batch generation: produce N candidates and audition them through the cue output before committing one to this channel."
431
  placement="top"
432
- disableFocusListener
433
- disableTouchListener
434
  enterDelay={500}
435
  >
436
  <Select
437
  value={batchSize}
438
  onChange={(e) => setBatchSize(Number(e.target.value))}
439
- size="small"
440
  disabled={generating}
441
- sx={{
442
- fontSize: perfTokens.fontSize.body,
443
- height: 32,
444
- minWidth: 64,
445
- '& .MuiOutlinedInput-input': { py: 0, pl: 1.25, pr: '28px !important', minHeight: 'unset' },
446
- '& .MuiSelect-select': { py: 0, pl: 1.25, pr: '28px !important', minHeight: 'unset' },
447
- }}
448
  >
449
- {BATCH_OPTIONS.map(n => (
450
- <MenuItem key={n} value={n} sx={{ fontSize: perfTokens.fontSize.body }}>
 
 
 
 
451
  Γ—{n}
452
  </MenuItem>
453
  ))}
454
  </Select>
455
  </Tooltip>
456
- <MidiMappable id={ctrlId('generate')} label={ctrlLabel('Generate')} kind="trigger" onChange={handleGenerate}>
457
- <IconButton
458
- onClick={handleGenerate}
459
- disabled={!canGenerate || !prompt.trim() || generating}
460
- sx={styles.generateBtn(color)}
461
- size="small"
 
462
  >
463
- {generating ? <CircularProgress size={16} sx={{ color }} /> : <GenerateIcon size={16} />}
464
- </IconButton>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  </MidiMappable>
466
  </Box>
467
  </Box>
468
 
469
- <Box sx={styles.waveformWrap}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  <canvas
471
  ref={canvasRef}
472
  width={140}
473
  height={42}
474
- style={{ width: '100%', height: 42, display: 'block' }}
475
  />
476
  {!loaded && (
477
- <Typography variant="caption" sx={styles.waveformPlaceholder}>
478
- empty
479
  </Typography>
480
  )}
481
  </Box>
482
 
483
- {candidates.length > 1 && (
484
- <Box
485
- sx={{
486
- display: 'flex',
487
- alignItems: 'center',
488
- gap: 0.5,
489
- px: 1,
490
- py: 0.5,
491
- flexWrap: 'wrap',
492
- }}
493
- >
494
- {candidates.map((c, i) => {
495
- const isAuditioning = auditioningIndex === i;
496
- const isCommitted = committedIndex === i;
497
- return (
498
- <Box
499
- key={c.index}
500
- sx={{
501
- display: 'inline-flex',
502
- alignItems: 'center',
503
- border: '1px solid',
504
- borderColor: isCommitted ? color : 'divider',
505
- borderRadius: 0.75,
506
- overflow: 'hidden',
507
- bgcolor: isCommitted ? `${color}1a` : 'transparent',
508
- }}
509
- >
510
- <Tooltip
511
- title={
512
- cueSupported
513
- ? (isAuditioning ? 'Stop cue audition' : 'Audition this take through cue output')
514
- : 'Cue audition requires Chrome/Edge. Plays through main output.'
515
- }
516
- >
517
- <IconButton
518
- onClick={() => handleAudition(i)}
519
- size="small"
520
- sx={{
521
- color: isAuditioning ? color : 'text.secondary',
522
- px: 0.5,
523
- borderRadius: 0,
524
- }}
525
- >
526
- <CueIcon size={12} />
527
- <Box
528
- component="span"
529
- sx={{
530
- ml: 0.4,
531
- fontSize: perfTokens.fontSize.small,
532
- fontWeight: isAuditioning ? 700 : 500,
533
- }}
534
- >
535
- {i + 1}
536
- </Box>
537
- </IconButton>
538
- </Tooltip>
539
- <Tooltip title={isCommitted ? 'Currently in channel' : 'Use this take in the channel'}>
540
- <span>
541
- <IconButton
542
- onClick={() => handleCommit(i)}
543
- size="small"
544
- disabled={isCommitted}
545
- sx={{
546
- color: isCommitted ? color : 'text.disabled',
547
- px: 0.4,
548
- borderRadius: 0,
549
- borderLeft: '1px solid',
550
- borderColor: 'divider',
551
- }}
552
- >
553
- <CommitIcon size={12} />
554
- </IconButton>
555
- </span>
556
- </Tooltip>
557
- </Box>
558
- );
559
- })}
560
- </Box>
561
- )}
562
 
563
  <Box sx={{ px: 1, py: 1 }}>
564
  <Box sx={{ display: 'flex', alignItems: 'center', gap: 1, mb: 1 }}>
565
- <Box component="span" sx={{ fontSize: perfTokens.fontSize.knob, color: 'text.secondary', letterSpacing: perfTokens.letterSpacing.wide, minWidth: 28 }}>PAN</Box>
566
  <MidiMappable
567
  id={ctrlId('pan')}
568
  label={ctrlLabel('Pan')}
@@ -584,6 +957,10 @@ export default function PerformanceChannel({
584
  marks={[{ value: 0 }]}
585
  sx={{
586
  flex: 1,
 
 
 
 
587
  '& .MuiSlider-mark': {
588
  width: 2,
589
  height: 10,
@@ -604,6 +981,7 @@ export default function PerformanceChannel({
604
  <Box sx={styles.knobsGrid}>
605
  {KNOB_DEFS.map((k) => {
606
  const isLog = k.scale === 'log';
 
607
  // For log knobs, the slider drives a 0..1 position and we
608
  // convert to/from the underlying value (Hz) on the audio
609
  // boundary. The knob value stored in state stays in the
@@ -635,7 +1013,24 @@ export default function PerformanceChannel({
635
  max={isLog ? 1 : k.max}
636
  step={isLog ? 0.001 : k.step}
637
  size="small"
638
- sx={styles.knobSlider(color, k.key === 'gain')}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
  />
640
  </MidiMappable>
641
  <Box component="span" sx={styles.knobLabel}>{k.label}</Box>
@@ -644,26 +1039,10 @@ export default function PerformanceChannel({
644
  })}
645
  </Box>
646
 
 
 
 
647
  <Box sx={styles.transportRow}>
648
- <MidiMappable id={ctrlId('transport')} label={ctrlLabel('Play/Stop')} kind="trigger" onChange={handleTransportToggle}>
649
- <IconButton
650
- onClick={playing ? handleStop : handlePlay}
651
- disabled={!loaded}
652
- sx={styles.transportBtn(color, playing)}
653
- size="small"
654
- >
655
- {playing ? <StopIcon size={16} /> : <PlayIcon size={16} />}
656
- </IconButton>
657
- </MidiMappable>
658
- <MidiMappable id={ctrlId('loop')} label={ctrlLabel('Loop')} kind="trigger" onChange={handleLoopToggle}>
659
- <IconButton
660
- onClick={handleLoopToggle}
661
- sx={styles.loopBtn(color, looping)}
662
- size="small"
663
- >
664
- <LoopIcon size={14} />
665
- </IconButton>
666
- </MidiMappable>
667
  <Box sx={styles.meterTrack}>
668
  <Box ref={meterRef} sx={styles.meterFill(color)} />
669
  </Box>
 
5
  TextField,
6
  IconButton,
7
  Slider,
 
 
8
  Select,
9
  MenuItem,
10
  ButtonBase,
11
  } from '@mui/material';
12
+ import { TIPS } from '../tooltips';
13
+ import Tooltip from './Tooltip';
14
  import {
15
  Play as PlayIcon,
16
  Square as StopIcon,
17
+ ArrowRight as GenerateArrowIcon,
 
18
  Volume2 as VolumeIcon,
19
  VolumeX as MuteIcon,
20
+ Shuffle as VariationIcon,
 
21
  } from 'lucide-react';
22
+ import { performanceChannelStyles as styles, performancePanelStyles as panelStyles, perfTokens, SHEEN_DARK, RAISE_DARK } from '../theme';
23
  import { MidiMappable } from './MidiContext';
24
  import { playBlob as playCueBlob, stopCue, isCueSupported } from '../utils/cueAudio';
25
+ import {
26
+ channelScope,
27
+ putFragmentBlob,
28
+ getFragmentBlob,
29
+ deleteFragmentBlob,
30
+ clearScope as clearFragmentScope,
31
+ } from '../utils/fragmentStorage';
32
+ import api from '../api';
33
+ import ChannelFragmentHistory from './ChannelFragmentHistory';
34
 
35
  const CHANNEL_COLORS = [
36
+ // Original introduction palette. Ch1 teal, ch2 violet, ch3 green, ch4 amber.
37
+ // Light enough that the black label text on active toggles stays legible.
38
+ // Slots 4–7 are spares (only 4 channels render).
39
  '#35C2D4', '#9F8AE6', '#53C18A', '#E3A34B',
40
  '#E36C61', '#F08AD2', '#5BA0F0', '#A8D86B',
41
  ];
 
50
 
51
  const KNOB_DEFS = [
52
  { key: 'gain', label: 'GAIN', min: GAIN_DB_MIN, max: GAIN_DB_MAX, step: 0.5, default: GAIN_DB_DEFAULT },
53
+ // Bipolar "DJ-filter" knob. -1..+1 with 0 = bypass. Negative side drives
54
+ // the LPF cutoff down from 20 kHz β†’ 20 Hz (kills highs). Positive side
55
+ // drives the HPF cutoff up from 20 Hz β†’ 20 kHz (kills lows). The two
56
+ // biquads sit in series in the engine; only one side ever cuts at a time.
57
+ { key: 'filter', label: 'FLT', min: -1, max: 1, step: 0.001, default: 0, scale: 'bipolar' },
58
  { key: 'delay', label: 'DLY', min: 0, max: 1.0, step: 0.01, default: 0.0 },
59
  { key: 'reverb', label: 'REV', min: 0, max: 1.0, step: 0.01, default: 0.0 },
60
  ];
61
 
62
+ // Map a bipolar filter position (-1..+1) to the (LPF, HPF) frequencies that
63
+ // the engine's two biquads need. 20 Hz / 20 kHz are the bypass anchors on
64
+ // each side; log-scaled so each octave gets equal slider travel.
65
+ function bipolarToFilterFreqs(pos) {
66
+ const lpf = pos <= 0 ? 20 * Math.pow(1000, 1 + pos) : 20000;
67
+ const hpf = pos >= 0 ? 20 * Math.pow(1000, pos) : 20;
68
+ return { lpf, hpf };
69
+ }
70
+
71
  const PAN_CENTER_SNAP = 0.06;
72
 
73
  const BARS_OPTIONS = [1, 2, 4, 8, 16];
74
  const BEATS_PER_BAR = 4;
75
  const BATCH_OPTIONS = [1, 2, 3, 4];
76
+ // Per-channel rolling fragment history cap. Starred fragments survive eviction.
77
+ const FRAGMENT_CAP = 200;
78
 
79
  export default function PerformanceChannel({
80
  index,
 
87
  onStateChange,
88
  onFormStateChange,
89
  initialFormState,
90
+ maxDuration = 380,
91
  bpm = 120,
92
  }) {
93
  const color = CHANNEL_COLORS[index % CHANNEL_COLORS.length];
94
  const canvasRef = useRef(null);
95
  const meterRef = useRef(null);
96
  const meterRafRef = useRef(null);
97
+ // IDB scope key for this channel's fragment blobs. Stable across the
98
+ // component's lifetime since the channel index doesn't change.
99
+ const scope = channelScope(index);
100
 
101
  const init = initialFormState || {};
102
  const initKnobs = init.knobs || {};
 
108
 
109
  const [prompt, setPrompt] = useState(init.prompt ?? '');
110
  const [duration, setDuration] = useState(init.duration ?? 8);
111
+ const [durationMode, setDurationMode] = useState(init.durationMode ?? 'bars');
112
  const [bars, setBars] = useState(init.bars ?? 4);
113
  const [generating, setGenerating] = useState(false);
114
  const [loaded, setLoaded] = useState(false);
 
116
  const [muted, setMuted] = useState(init.muted ?? false);
117
  const [soloed, setSoloed] = useState(init.soloed ?? false);
118
  const [batchSize, setBatchSize] = useState(init.batchSize ?? 1);
119
+ // Live progress for the Generate pill while a generation is in flight.
120
+ // 0–100; polled from /api/generation-progress. Resets on each new run.
121
+ const [progress, setProgress] = useState(0);
122
+ const [knobs, setKnobs] = useState(() => {
123
+ const merged = { ...defaultKnobs, ...initKnobs };
124
+ // Migration: pre-bipolar `filter` was a raw Hz value (20..20000).
125
+ // Anything outside the new -1..+1 range is a legacy save β€” reset
126
+ // to bypass (0) rather than feeding nonsense into the engine.
127
+ if (merged.filter < -1 || merged.filter > 1) merged.filter = 0;
128
+ return merged;
129
+ });
130
+
131
+ // Per-channel rolling fragment history. Each fragment:
132
+ // { id, blob, audioUrl, prompt, duration, createdAt, starred, number }
133
+ // Oldest-first. Capped at FRAGMENT_CAP via FIFO eviction with star
134
+ // priority (starred fragments survive until everything is starred, then
135
+ // oldest go first regardless). `nextFragmentNumberRef` provides a stable
136
+ // F# even after deletes β€” so F1 stays F1.
137
+ const [fragments, setFragments] = useState([]);
138
+ const [auditioningFragmentId, setAuditioningFragmentId] = useState(null);
139
+ const [committedFragmentId, setCommittedFragmentId] = useState(null);
140
+ const nextFragmentNumberRef = useRef(1);
141
  const cueSupported = useMemo(() => isCueSupported(), []);
142
 
143
  // Stop any active cue audition when the channel unmounts.
144
  useEffect(() => () => stopCue(), []);
145
 
146
+ // Poll /api/generation-progress while a generation is in flight so the
147
+ // Generate pill renders a real fill bar instead of a vague spinner. The
148
+ // backend exposes a single in-flight state; performance generations are
149
+ // sequential (the backend serves one at a time), so this naturally
150
+ // reflects whichever channel is currently busy.
151
+ useEffect(() => {
152
+ if (!generating) {
153
+ setProgress(0);
154
+ return;
155
+ }
156
+ let cancelled = false;
157
+ const tick = async () => {
158
+ if (cancelled) return;
159
+ try {
160
+ const r = await api.get('/api/generation-progress');
161
+ const pct = Number(r.data?.progress) || 0;
162
+ if (!cancelled) {
163
+ // Cap at 95 until handleGenerate resolves so the bar
164
+ // doesn't sit at 100 while waiting for the WAV blob.
165
+ setProgress((prev) => Math.max(prev, Math.min(95, pct)));
166
+ }
167
+ } catch { /* non-fatal β€” bar just freezes briefly */ }
168
+ };
169
+ tick();
170
+ const id = window.setInterval(tick, 250);
171
+ return () => { cancelled = true; window.clearInterval(id); };
172
+ }, [generating]);
173
+
174
  // Mirror form state up to the panel so it can persist the session. Skip the
175
  // first render so we don't re-write what we just loaded from localStorage.
176
+ // Fragments mirror as metadata only β€” the Blob bodies live in IndexedDB
177
+ // and get rehydrated on mount by the effect below.
178
  const initialReportSkippedRef = useRef(false);
179
  useEffect(() => {
180
  if (!initialReportSkippedRef.current) {
181
  initialReportSkippedRef.current = true;
182
  return;
183
  }
184
+ const fragmentsMeta = fragments.map(({ blob, audioUrl, ...rest }) => rest);
185
  onFormStateChange?.(index, {
186
  prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs,
187
+ fragments: fragmentsMeta,
188
+ committedFragmentId,
189
  });
190
+ }, [prompt, duration, durationMode, bars, looping, muted, soloed, batchSize, knobs,
191
+ fragments, committedFragmentId, index, onFormStateChange]);
192
+
193
+ // Hydrate fragments on mount from the session metadata + IDB blobs. Runs
194
+ // once, tolerates missing blobs (skips the entry), and rewinds the
195
+ // fragment numbering counter so newly generated fragments don't collide
196
+ // with the restored ones.
197
+ const hydrationRef = useRef(false);
198
+ useEffect(() => {
199
+ if (hydrationRef.current) return;
200
+ hydrationRef.current = true;
201
+ // Backward compat: pre-rename saves used `takes`/`committedTakeId`.
202
+ // The session loader migrates them into `fragments`/`committedFragmentId`,
203
+ // but we also fall back here defensively in case `initialFormState`
204
+ // came from somewhere unmigrated.
205
+ const meta = initialFormState?.fragments
206
+ ?? initialFormState?.takes
207
+ ?? [];
208
+ const persistedCommittedId = initialFormState?.committedFragmentId
209
+ ?? initialFormState?.committedTakeId
210
+ ?? null;
211
+ if (meta.length === 0) {
212
+ if (persistedCommittedId) setCommittedFragmentId(null);
213
+ return;
214
+ }
215
+
216
+ let cancelled = false;
217
+ (async () => {
218
+ const hydrated = [];
219
+ for (const m of meta) {
220
+ try {
221
+ const blob = await getFragmentBlob(scope, m.id);
222
+ if (cancelled) {
223
+ hydrated.forEach(t => URL.revokeObjectURL(t.audioUrl));
224
+ return;
225
+ }
226
+ if (!blob) continue;
227
+ hydrated.push({
228
+ ...m,
229
+ blob,
230
+ audioUrl: URL.createObjectURL(blob),
231
+ });
232
+ } catch {
233
+ /* one bad fetch β€” keep going */
234
+ }
235
+ }
236
+ if (cancelled) {
237
+ hydrated.forEach(t => URL.revokeObjectURL(t.audioUrl));
238
+ return;
239
+ }
240
+ const maxNumber = hydrated.reduce((a, t) => Math.max(a, t.number || 0), 0);
241
+ nextFragmentNumberRef.current = maxNumber + 1;
242
+ setFragments(hydrated);
243
+ if (persistedCommittedId && hydrated.some(t => t.id === persistedCommittedId)) {
244
+ setCommittedFragmentId(persistedCommittedId);
245
+ setLoaded(true);
246
+ onStateChange?.(index, { loaded: true });
247
+ }
248
+ })();
249
+
250
+ return () => { cancelled = true; };
251
+ // eslint-disable-next-line react-hooks/exhaustive-deps
252
+ }, []);
253
+
254
+ // When the audio strip becomes available AND we have a hydrated committed
255
+ // fragment, load that fragment's blob into the strip so the channel comes back
256
+ // ready to play after reload. Declared here as a ref so the effect that
257
+ // actually does the work (after drawWave is defined below) can guard
258
+ // against multiple loads.
259
+ const autoLoadDoneRef = useRef(false);
260
 
261
  const secondsFromBars = useMemo(
262
  () => bars * (60 / Math.max(bpm, 1)) * BEATS_PER_BAR,
 
286
 
287
  const drawWave = useCallback(() => {
288
  if (strip && canvasRef.current) {
289
+ // Each channel's waveform is drawn in that channel's own color.
290
  strip.drawWaveform(canvasRef.current, color);
291
  }
292
  }, [strip, color]);
293
 
294
  useEffect(() => { drawWave(); }, [drawWave, loaded]);
295
 
296
+ // Auto-load the persisted committed fragment into the strip once Tone.js
297
+ // is ready. Runs at most once per mount; the ref guards against re-trigger
298
+ // when the user later commits a different fragment (handled by
299
+ // handleCommitFragment).
300
+ useEffect(() => {
301
+ if (autoLoadDoneRef.current) return;
302
+ if (!strip || !committedFragmentId) return;
303
+ const fragment = fragments.find(f => f.id === committedFragmentId);
304
+ if (!fragment) return;
305
+ autoLoadDoneRef.current = true;
306
+ strip.loadBlob(fragment.blob).then(() => {
307
+ requestAnimationFrame(drawWave);
308
+ }).catch(err => {
309
+ console.warn(`Channel ${index + 1} auto-load failed:`, err);
310
+ autoLoadDoneRef.current = false;
311
+ });
312
+ }, [strip, committedFragmentId, fragments, drawWave, index]);
313
+
314
  // One-shot: push restored knob/loop values into the audio strip when it
315
  // first becomes available, so the persisted session matches what's heard.
316
  // Mute/solo applies through the parent's mix handler so the panel can
 
321
  stripStateAppliedRef.current = true;
322
  strip.setUserGain(gainDbToLinear(knobs.gain));
323
  strip.setPan(knobs.pan);
324
+ {
325
+ const { lpf, hpf } = bipolarToFilterFreqs(knobs.filter);
326
+ strip.setFilter(lpf);
327
+ strip.setHighpass(hpf);
328
+ }
329
  strip.setDelayMix(knobs.delay);
330
  strip.setReverbMix(knobs.reverb);
331
  strip.setLoop(looping);
 
347
  }
348
  }, [availableBars, bars]);
349
 
350
+ // Per-fragment handler factory β€” fires as each blob returns. Fragment #0
351
+ // auto-loads into the strip so the user can audition while #1..N render.
352
+ // Shared by Generate and Variation so both feed channel history identically.
353
+ const makeOnBlob = (promptSnap, effectiveDuration) => async (blob, i) => {
354
+ const fragmentNumber = nextFragmentNumberRef.current;
355
+ nextFragmentNumberRef.current = fragmentNumber + 1;
356
+ const fragment = {
357
+ id: `${Date.now()}_${i}`,
358
+ blob,
359
+ audioUrl: URL.createObjectURL(blob),
360
+ prompt: promptSnap,
361
+ duration: effectiveDuration,
362
+ createdAt: Date.now(),
363
+ starred: false,
364
+ number: fragmentNumber,
365
+ };
366
+
367
+ // Persist the blob to IndexedDB so it survives reload. Fire-and-forget.
368
+ putFragmentBlob(scope, fragment.id, blob).catch((err) => {
369
+ console.warn(`Channel ${index + 1} fragment persist failed:`, err);
370
+ });
371
+
372
+ // Append to history with FRAGMENT_CAP eviction (oldest unstarred first).
373
+ setFragments((prev) => {
374
+ const combined = [...prev, fragment];
375
+ if (combined.length <= FRAGMENT_CAP) return combined;
376
+ const trimmed = combined.slice();
377
+ while (trimmed.length > FRAGMENT_CAP) {
378
+ let idx = -1;
379
+ for (let j = 0; j < trimmed.length; j++) {
380
+ if (!trimmed[j].starred) { idx = j; break; }
381
+ }
382
+ if (idx < 0) idx = 0; // all starred β†’ drop oldest
383
+ const dying = trimmed[idx];
384
+ if (dying.audioUrl?.startsWith('blob:')) {
385
+ try { URL.revokeObjectURL(dying.audioUrl); } catch { /* ignore */ }
386
+ }
387
+ deleteFragmentBlob(scope, dying.id).catch(() => { /* ignore */ });
388
+ trimmed.splice(idx, 1);
389
+ }
390
+ return trimmed;
391
+ });
392
+
393
+ // Generating must never disturb playback: a playing channel keeps
394
+ // looping its current clip while new fragments just pile into the
395
+ // history list. Only auto-load when the channel has nothing loaded yet
396
+ // (first-ever fragment) β€” harmless since nothing is playing β€” so the
397
+ // user still gets a ready-to-play clip on a fresh channel. To start a
398
+ // newly generated fragment, pick it from the list (handleCommitFragment).
399
+ if (i === 0 && !loaded) {
400
+ await strip.loadBlob(blob);
401
+ setCommittedFragmentId(fragment.id);
402
+ setLoaded(true);
403
+ onStateChange?.(index, { loaded: true });
404
+ requestAnimationFrame(drawWave);
405
+ }
406
+ };
407
+
408
  const handleGenerate = async () => {
409
  if (!prompt.trim() || generating) return;
410
  const inBarsMode = durationMode === 'bars';
411
  const effectiveDuration = inBarsMode ? secondsFromBars : duration;
412
  setGenerating(true);
413
+ // Stop any in-flight cue audition so the old preview doesn't keep
414
+ // playing while we generate the new fragment.
415
  stopCue();
416
+ setAuditioningFragmentId(null);
417
+
418
+ const promptSnap = prompt.trim();
419
+
420
  try {
421
+ await onGenerate({
422
  prompt,
423
  duration: effectiveDuration,
424
  batchSize,
425
  // Only forward alignment params in bars mode β€” seconds mode
426
  // generates raw audio with no post-processing.
427
  ...(inBarsMode ? { alignBars: bars, alignBpm: bpm } : {}),
428
+ // Phase 7: bars-mode + channel-looping β‡’ ask the backend
429
+ // to wrap-inpaint the seam so the clip loops seamlessly.
430
+ ...(inBarsMode && looping ? { loopStitch: 'inpaint' } : {}),
431
+ onBlob: makeOnBlob(promptSnap, effectiveDuration),
432
  });
 
 
 
 
 
 
 
 
 
 
433
  } catch (err) {
434
  console.error(`Channel ${index + 1} generate failed:`, err);
435
  } finally {
 
437
  }
438
  };
439
 
440
+ // Phase 8 "Variation": re-roll the channel using its current fragment as
441
+ // init_audio at a high noise level β€” gives a related-but-different take
442
+ // (A/B/A/C/A live sets). Uploads the source blob to get a server path,
443
+ // then routes through the same generate flow.
444
+ const handleVariation = async () => {
445
+ if (generating) return;
446
+ const src = fragments.find((f) => f.id === committedFragmentId)
447
+ || fragments[fragments.length - 1];
448
+ if (!src?.blob) return;
449
+ const inBarsMode = durationMode === 'bars';
450
+ const effectiveDuration = inBarsMode ? secondsFromBars : duration;
451
+ const promptSnap = (prompt || '').trim() || src.prompt || 'variation';
452
+ setGenerating(true);
453
+ stopCue();
454
+ setAuditioningFragmentId(null);
455
+ try {
456
+ const form = new FormData();
457
+ form.append('file', new File([src.blob], `${scope}_variation_src.wav`, { type: 'audio/wav' }));
458
+ const up = await api.post('/api/audio/upload', form);
459
+ await onGenerate({
460
+ prompt: promptSnap,
461
+ duration: effectiveDuration,
462
+ batchSize: 1,
463
+ initAudioPath: up.data.path,
464
+ initNoiseLevel: 0.9,
465
+ onBlob: makeOnBlob(promptSnap, effectiveDuration),
466
+ });
467
+ } catch (err) {
468
+ console.error(`Channel ${index + 1} variation failed:`, err);
469
+ } finally {
470
+ setGenerating(false);
471
+ }
472
+ };
473
+
474
+ // Fragment history actions β€” toggle audition through cue, commit a
475
+ // fragment to the channel buffer, star/unstar, delete one, or clear
476
+ // the whole list.
477
+ const handleAuditionFragment = async (fragmentId) => {
478
+ const fragment = fragments.find((f) => f.id === fragmentId);
479
+ if (!fragment) return;
480
+ if (auditioningFragmentId === fragmentId) {
481
  stopCue();
482
+ setAuditioningFragmentId(null);
483
  return;
484
  }
485
+ setAuditioningFragmentId(fragmentId);
486
  try {
487
+ await playCueBlob(fragment.blob, {
488
+ onEnded: () => setAuditioningFragmentId((prev) => (prev === fragmentId ? null : prev)),
489
  });
490
  } catch (err) {
491
  console.warn(`Channel ${index + 1} audition failed:`, err);
492
+ setAuditioningFragmentId(null);
493
  }
494
  };
495
 
496
+ // Choosing a fragment launches it from the beginning. The currently
497
+ // playing clip (if any) keeps sounding until the launch point: immediately
498
+ // in seconds mode or when launch quantization is None, otherwise at the
499
+ // next launch-quantization bar. The buffer is decoded WITHOUT stopping the
500
+ // live source, so the swap is gapless (the engine schedules the handoff).
501
+ const handleCommitFragment = async (fragmentId) => {
502
+ const fragment = fragments.find((f) => f.id === fragmentId);
503
+ if (!fragment) return;
504
+ const sameFragment = committedFragmentId === fragmentId;
505
+ // Already looping this exact clip β†’ nothing to (re)launch.
506
+ if (sameFragment && playing) return;
507
+
508
+ // Decode the new clip without cutting the live source; skip the decode
509
+ // when this fragment's buffer is already loaded.
510
+ if (!sameFragment) {
511
+ await strip.loadBufferFromBlob(fragment.blob);
512
+ setCommittedFragmentId(fragmentId);
513
+ }
514
+ // Mark loaded so the play button enables (covers preset/hydrated flows
515
+ // where the first commit happens here rather than via generate).
516
+ if (!loaded) {
517
+ setLoaded(true);
518
+ onStateChange?.(index, { loaded: true });
519
+ }
520
+
521
+ // Launch from the top. Seconds mode is always immediate; bars mode
522
+ // defers to the engine's launch-quantization (ASAP when quantum=None).
523
+ const immediate = durationMode === 'seconds';
524
+ if (engine) engine.relaunchChannel(index, looping, immediate);
525
+ else strip.playAt(looping, 0);
526
+ onStateChange?.(index, { playing: true });
527
  requestAnimationFrame(drawWave);
528
  };
529
 
530
+ // Drag-and-drop: a fragment row from this channel's history can be
531
+ // dropped onto the waveform monitor to load it (same effect as the row's
532
+ // commit βœ“ button). The MIME type is channel-scoped, so a row from
533
+ // channel 1 won't even highlight channel 2's waveform β€” the browser
534
+ // filters at dragOver level via dataTransfer.types matching.
535
+ const dragMime = `application/x-fragmenta-fragment-ch${index}`;
536
+ const [dropActive, setDropActive] = useState(false);
537
+ // Counter pattern β€” dragenter/leave also fire when the cursor crosses
538
+ // into child elements (canvas, overlay). Without the counter, dropActive
539
+ // would flicker false whenever the cursor moved over a child.
540
+ const dragCounterRef = useRef(0);
541
+
542
+ const handleWaveDragEnter = (e) => {
543
+ if (!e.dataTransfer.types.includes(dragMime)) return;
544
+ e.preventDefault();
545
+ dragCounterRef.current += 1;
546
+ if (dragCounterRef.current === 1) setDropActive(true);
547
+ };
548
+ const handleWaveDragOver = (e) => {
549
+ if (!e.dataTransfer.types.includes(dragMime)) return;
550
+ e.preventDefault();
551
+ e.dataTransfer.dropEffect = 'copy';
552
+ };
553
+ const handleWaveDragLeave = () => {
554
+ dragCounterRef.current = Math.max(0, dragCounterRef.current - 1);
555
+ if (dragCounterRef.current === 0) setDropActive(false);
556
+ };
557
+ const handleWaveDrop = (e) => {
558
+ e.preventDefault();
559
+ dragCounterRef.current = 0;
560
+ setDropActive(false);
561
+ const fragmentId = e.dataTransfer.getData(dragMime);
562
+ if (fragmentId) handleCommitFragment(fragmentId);
563
+ };
564
+
565
+ const handleToggleStar = (fragmentId) => {
566
+ setFragments((prev) => prev.map((f) =>
567
+ f.id === fragmentId ? { ...f, starred: !f.starred } : f,
568
+ ));
569
+ };
570
+
571
+ const handleDeleteFragment = (fragmentId) => {
572
+ const target = fragments.find((f) => f.id === fragmentId);
573
+ if (target?.audioUrl?.startsWith('blob:')) {
574
+ try { URL.revokeObjectURL(target.audioUrl); } catch { /* ignore */ }
575
+ }
576
+ deleteFragmentBlob(scope, fragmentId).catch(() => { /* ignore */ });
577
+ setFragments((prev) => prev.filter((f) => f.id !== fragmentId));
578
+ if (committedFragmentId === fragmentId) setCommittedFragmentId(null);
579
+ if (auditioningFragmentId === fragmentId) {
580
+ stopCue();
581
+ setAuditioningFragmentId(null);
582
+ }
583
+ };
584
+
585
+ const handleClearFragments = () => {
586
+ // Stop any in-flight audition and revoke every blob URL before
587
+ // dropping references β€” otherwise the URLs leak until reload.
588
+ stopCue();
589
+ setAuditioningFragmentId(null);
590
+ fragments.forEach((f) => {
591
+ if (f.audioUrl?.startsWith('blob:')) {
592
+ try { URL.revokeObjectURL(f.audioUrl); } catch { /* ignore */ }
593
+ }
594
+ });
595
+ clearFragmentScope(scope).catch(() => { /* ignore */ });
596
+ setFragments([]);
597
+ setCommittedFragmentId(null);
598
+ };
599
+
600
  const handlePlay = () => {
601
  if (!loaded) return;
602
  if (engine) engine.playChannel(index, looping);
 
633
  setKnobs(prev => ({ ...prev, [key]: value }));
634
  if (key === 'gain') strip.setUserGain(gainDbToLinear(value));
635
  else if (key === 'pan') strip.setPan(value);
636
+ else if (key === 'filter') {
637
+ const { lpf, hpf } = bipolarToFilterFreqs(value);
638
+ strip.setFilter(lpf);
639
+ strip.setHighpass(hpf);
640
+ }
641
  else if (key === 'delay') strip.setDelayMix(value);
642
  else if (key === 'reverb') strip.setReverbMix(value);
643
  };
 
659
  return (
660
  <Box sx={styles.strip(color, playing)}>
661
  <Box sx={styles.stripHeader(color)}>
662
+ {/* Transport (Play / Loop) on the left, Mute / Solo on the
663
+ right β€” replaces the old "01" channel badge so the channel
664
+ number isn't using up that slot. */}
665
+ <Box sx={styles.muteSoloRow}>
666
+ <MidiMappable id={ctrlId('transport')} label={ctrlLabel('Play/Stop')} kind="trigger" onChange={handleTransportToggle}>
667
+ <IconButton
668
+ onClick={playing ? handleStop : handlePlay}
669
+ disabled={!loaded}
670
+ sx={styles.transportBtn(color, playing)}
671
+ size="small"
672
+ >
673
+ {playing ? <StopIcon size={16} /> : <PlayIcon size={16} />}
674
+ </IconButton>
675
+ </MidiMappable>
676
+ <MidiMappable id={ctrlId('loop')} label={ctrlLabel('Loop')} kind="trigger" onChange={handleLoopToggle}>
677
+ <Tooltip
678
+ title={TIPS.channel.loop(looping, durationMode)}
679
+ placement="top"
680
+ enterDelay={400}
681
+ >
682
+ <IconButton
683
+ onClick={handleLoopToggle}
684
+ sx={styles.loopBtn(color, looping)}
685
+ size="small"
686
+ aria-label={looping ? 'Loop on' : 'Loop off'}
687
+ >
688
+ L
689
+ </IconButton>
690
+ </Tooltip>
691
+ </MidiMappable>
692
+ </Box>
693
  <Box sx={styles.muteSoloRow}>
694
  <MidiMappable id={ctrlId('mute')} label={ctrlLabel('Mute')} kind="trigger" onChange={handleMuteToggle}>
695
+ <Tooltip title={TIPS.channel.mute}>
696
  <IconButton size="small" onClick={handleMuteToggle} sx={styles.muteBtn(muted)}>M</IconButton>
697
  </Tooltip>
698
  </MidiMappable>
699
  <MidiMappable id={ctrlId('solo')} label={ctrlLabel('Solo')} kind="trigger" onChange={handleSoloToggle}>
700
+ <Tooltip title={TIPS.channel.solo}>
701
  <IconButton size="small" onClick={handleSoloToggle} sx={styles.soloBtn(soloed)}>S</IconButton>
702
  </Tooltip>
703
  </MidiMappable>
 
706
 
707
  <Box sx={styles.promptBox}>
708
  <TextField
709
+ placeholder="Prompt…"
710
  value={prompt}
711
  onChange={(e) => setPrompt(e.target.value)}
712
  multiline
713
  minRows={2}
714
+ maxRows={2}
715
  size="small"
716
  fullWidth
717
  sx={styles.promptField}
718
  disabled={generating}
719
  />
720
  <Box sx={{ ...styles.durationRow, minHeight: 26, height: 26 }}>
721
+ {durationMode === 'seconds' ? (
722
+ <>
723
+ <Typography sx={styles.durationLabel}>{duration.toFixed(0)}s</Typography>
724
+ <Slider
725
+ value={duration}
726
+ onChange={(_, v) => setDuration(v)}
727
+ min={2}
728
+ max={maxDuration}
729
+ step={1}
730
+ size="small"
731
+ sx={styles.durationSlider(color)}
732
+ />
733
+ </>
734
+ ) : (
735
+ <Select
736
+ value={availableBars.includes(bars) ? bars : availableBars[availableBars.length - 1]}
737
+ onChange={(e) => setBars(Number(e.target.value))}
738
+ size="small"
739
+ sx={{ ...panelStyles.pillControl, flex: 1 }}
740
+ >
741
+ {availableBars.map(b => (
742
+ <MenuItem key={b} value={b} sx={{ fontSize: perfTokens.fontSize.sm }}>
743
+ {b} {b === 1 ? 'bar' : 'bars'}
744
+ </MenuItem>
745
+ ))}
746
+ </Select>
747
+ )}
748
+
749
+ {/* Sec/Bars mode toggle β€” moved to the right of the row so
750
+ it mirrors the Generate row layout (content fills left,
751
+ modifier sits right). Width matches the Γ—N selector so
752
+ the right column reads as a uniform stack. */}
753
  <Box
754
  sx={{
755
  display: 'inline-flex',
 
758
  borderRadius: 0.75,
759
  overflow: 'hidden',
760
  height: '100%',
761
+ flexShrink: 0,
762
  }}
763
  >
764
+ {[
765
+ { mode: 'sec', label: 'Sec' },
766
+ { mode: 'bars', label: 'Bars' },
767
+ ].map(({ mode, label }) => {
768
  const value = mode === 'sec' ? 'seconds' : 'bars';
769
  const active = durationMode === value;
770
  return (
 
772
  key={mode}
773
  onClick={() => setDurationMode(value)}
774
  sx={{
775
+ fontSize: perfTokens.fontSize.sm,
 
 
 
776
  px: 0.7,
777
+ minWidth: 36,
778
  bgcolor: active ? color : 'transparent',
779
+ backgroundImage: active ? SHEEN_DARK : 'none',
780
+ boxShadow: active ? RAISE_DARK : 'none',
781
  color: active ? 'rgba(0,0,0,0.88)' : 'text.disabled',
782
+ fontWeight: active ? perfTokens.weight.bold : perfTokens.weight.regular,
783
+ transition: 'background-color 120ms, color 120ms, box-shadow 120ms',
784
  '&:hover': {
785
  bgcolor: active ? color : 'action.hover',
786
  color: active ? 'rgba(0,0,0,0.88)' : 'text.secondary',
787
  },
788
  }}
789
  >
790
+ {label}
791
  </ButtonBase>
792
  );
793
  })}
794
  </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795
  </Box>
796
  <Box sx={{
797
  display: 'flex',
798
  alignItems: 'center',
799
+ gap: 1,
 
800
  mt: 0.5,
801
  width: '100%',
802
  }}>
803
+ {/* Generate pill β€” wide CTA on the left so the eye lands
804
+ on the primary action first. Fills left-to-right with
805
+ live progress while generating; resets when complete. */}
806
+ <MidiMappable id={ctrlId('generate')} label={ctrlLabel('Generate')} kind="trigger" onChange={handleGenerate}>
807
+ <Tooltip
808
+ title={TIPS.channel.generateDisabled(generating, canGenerate, prompt.trim())}
809
+ placement="top"
810
+ >
811
+ <span style={{ display: 'inline-flex', flex: 1, minWidth: 0 }}>
812
+ <ButtonBase
813
+ onClick={handleGenerate}
814
+ disabled={!canGenerate || !prompt.trim() || generating}
815
+ sx={styles.generatePill(color, {
816
+ generating,
817
+ disabled: !canGenerate || !prompt.trim(),
818
+ })}
819
+ >
820
+ {generating && (
821
+ <Box sx={styles.generatePillFill(color, progress)} />
822
+ )}
823
+ <Box component="span" sx={styles.generatePillLabel}>
824
+ {generating
825
+ ? `Generating Β· ${Math.round(progress)}%`
826
+ : 'Generate'}
827
+ {!generating && <GenerateArrowIcon size={14} strokeWidth={2.25} />}
828
+ </Box>
829
+ </ButtonBase>
830
+ </span>
831
+ </Tooltip>
832
+ </MidiMappable>
833
+
834
+ {/* Batch selector β€” sits right of Generate so the row
835
+ reads "Generate Γ— 4" (action then modifier). Sized to
836
+ its content (Γ—1οΏ½οΏ½Γ—8 + dropdown arrow); no need to match
837
+ the wider Sec/Bars toggle above. */}
838
  <Tooltip
839
+ title={TIPS.channel.batch}
840
  placement="top"
 
 
841
  enterDelay={500}
842
  >
843
  <Select
844
  value={batchSize}
845
  onChange={(e) => setBatchSize(Number(e.target.value))}
 
846
  disabled={generating}
847
+ size="small"
848
+ sx={{ ...styles.channelPillControl, width: 54, flexShrink: 0 }}
849
+ renderValue={(v) => `Γ—${v}`}
 
 
 
 
850
  >
851
+ {BATCH_OPTIONS.map((n) => (
852
+ <MenuItem
853
+ key={n}
854
+ value={n}
855
+ sx={{ fontSize: perfTokens.fontSize.sm, fontVariantNumeric: 'tabular-nums' }}
856
+ >
857
  Γ—{n}
858
  </MenuItem>
859
  ))}
860
  </Select>
861
  </Tooltip>
862
+
863
+ {/* Variation β€” re-roll from the current fragment as
864
+ init_audio (Phase 8). Disabled until a fragment exists. */}
865
+ <MidiMappable id={ctrlId('variation')} label={ctrlLabel('Variation')} kind="trigger" onChange={handleVariation}>
866
+ <Tooltip
867
+ title={TIPS.channel.variation(loaded)}
868
+ placement="top"
869
  >
870
+ <span style={{ display: 'inline-flex', flexShrink: 0 }}>
871
+ <ButtonBase
872
+ onClick={handleVariation}
873
+ disabled={!loaded || generating}
874
+ sx={{
875
+ ...styles.channelPillControl,
876
+ width: 40,
877
+ justifyContent: 'center',
878
+ '&.Mui-disabled': { opacity: 0.4, color: 'text.disabled' },
879
+ }}
880
+ aria-label="Variation"
881
+ >
882
+ <VariationIcon size={15} strokeWidth={2.25} />
883
+ </ButtonBase>
884
+ </span>
885
+ </Tooltip>
886
  </MidiMappable>
887
  </Box>
888
  </Box>
889
 
890
+ <Box
891
+ onDragEnter={handleWaveDragEnter}
892
+ onDragOver={handleWaveDragOver}
893
+ onDragLeave={handleWaveDragLeave}
894
+ onDrop={handleWaveDrop}
895
+ sx={[
896
+ styles.waveformWrap,
897
+ dropActive && {
898
+ borderColor: color,
899
+ boxShadow: `inset 0 0 0 2px ${color}`,
900
+ backgroundColor: `${color}1F`,
901
+ transition: 'border-color 120ms, box-shadow 120ms, background-color 120ms',
902
+ },
903
+ ]}
904
+ >
905
  <canvas
906
  ref={canvasRef}
907
  width={140}
908
  height={42}
909
+ style={{ width: '100%', height: 42, display: 'block', pointerEvents: 'none' }}
910
  />
911
  {!loaded && (
912
+ <Typography sx={styles.waveformPlaceholder}>
913
+ {dropActive ? 'Drop to load' : 'Waveform'}
914
  </Typography>
915
  )}
916
  </Box>
917
 
918
+ {/* Per-channel rolling fragment history. Always rendered (empty
919
+ state included). Star/keep, delete, audition, load β€” all
920
+ inline per row. Capped at FRAGMENT_CAP via FIFO with star
921
+ priority. */}
922
+ <ChannelFragmentHistory
923
+ fragments={fragments}
924
+ color={color}
925
+ channelIndex={index}
926
+ auditioningId={auditioningFragmentId}
927
+ committedId={committedFragmentId}
928
+ maxFragments={FRAGMENT_CAP}
929
+ onAudition={handleAuditionFragment}
930
+ onCommit={handleCommitFragment}
931
+ onToggleStar={handleToggleStar}
932
+ onDelete={handleDeleteFragment}
933
+ onClearAll={handleClearFragments}
934
+ />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935
 
936
  <Box sx={{ px: 1, py: 1 }}>
937
  <Box sx={{ display: 'flex', alignItems: 'center', gap: 1, mb: 1 }}>
938
+ <Box component="span" sx={{ ...perfTokens.caps, color: 'text.secondary', minWidth: 28 }}>PAN</Box>
939
  <MidiMappable
940
  id={ctrlId('pan')}
941
  label={ctrlLabel('Pan')}
 
957
  marks={[{ value: 0 }]}
958
  sx={{
959
  flex: 1,
960
+ // Match the channel main color (the global
961
+ // MuiSlider override is amber; the vertical
962
+ // knobs already pass `color` via knobSlider).
963
+ color,
964
  '& .MuiSlider-mark': {
965
  width: 2,
966
  height: 10,
 
981
  <Box sx={styles.knobsGrid}>
982
  {KNOB_DEFS.map((k) => {
983
  const isLog = k.scale === 'log';
984
+ const isBipolar = k.scale === 'bipolar';
985
  // For log knobs, the slider drives a 0..1 position and we
986
  // convert to/from the underlying value (Hz) on the audio
987
  // boundary. The knob value stored in state stays in the
 
1013
  max={isLog ? 1 : k.max}
1014
  step={isLog ? 0.001 : k.step}
1015
  size="small"
1016
+ track={isBipolar ? false : undefined}
1017
+ marks={isBipolar ? [{ value: 0 }] : undefined}
1018
+ sx={{
1019
+ ...styles.knobSlider(color, k.key === 'gain'),
1020
+ ...(isBipolar && {
1021
+ '& .MuiSlider-mark': {
1022
+ width: 10,
1023
+ height: 2,
1024
+ borderRadius: 1,
1025
+ backgroundColor: 'text.secondary',
1026
+ opacity: 0.7,
1027
+ },
1028
+ '& .MuiSlider-markActive': {
1029
+ backgroundColor: 'text.secondary',
1030
+ opacity: 0.7,
1031
+ },
1032
+ }),
1033
+ }}
1034
  />
1035
  </MidiMappable>
1036
  <Box component="span" sx={styles.knobLabel}>{k.label}</Box>
 
1039
  })}
1040
  </Box>
1041
 
1042
+ {/* Bottom row is now just the channel level meter β€” Play and Loop
1043
+ moved to the top header so the channel reads "controls on top,
1044
+ signal flow below". */}
1045
  <Box sx={styles.transportRow}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1046
  <Box sx={styles.meterTrack}>
1047
  <Box ref={meterRef} sx={styles.meterFill(color)} />
1048
  </Box>
app/frontend/src/components/PerformancePanel.js CHANGED
The diff for this file is too large to render. See raw diff
 
app/frontend/src/components/StorageDrilldown.js ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react';
2
+ import {
3
+ Dialog,
4
+ DialogTitle,
5
+ DialogContent,
6
+ DialogActions,
7
+ Button,
8
+ Typography,
9
+ Box,
10
+ LinearProgress,
11
+ Table,
12
+ TableBody,
13
+ TableCell,
14
+ TableHead,
15
+ TableRow,
16
+ } from '@mui/material';
17
+
18
+ const fmtBytes = (n) => {
19
+ if (!n) return 'β€”';
20
+ // Decimal (SI) units β€” matches what HuggingFace shows next to safetensors files.
21
+ const units = ['B', 'KB', 'MB', 'GB', 'TB'];
22
+ let v = n;
23
+ let u = 0;
24
+ while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; }
25
+ return `${v.toFixed(v < 10 ? 2 : 1)} ${units[u]}`;
26
+ };
27
+
28
+ export default function StorageDrilldown({ open, onClose, storage, catalog }) {
29
+ if (!storage) return null;
30
+
31
+ const usedPct = storage.total_used_bytes && (storage.total_used_bytes + storage.total_free_bytes)
32
+ ? (storage.total_used_bytes / (storage.total_used_bytes + storage.total_free_bytes)) * 100
33
+ : 0;
34
+
35
+ const nameFor = (id) => catalog?.find(c => c.id === id)?.name || id;
36
+
37
+ const rows = (storage.per_model || [])
38
+ .filter(m => m.downloaded)
39
+ .sort((a, b) => b.bytes - a.bytes);
40
+
41
+ return (
42
+ <Dialog open={open} onClose={onClose} maxWidth="sm" fullWidth>
43
+ <DialogTitle>Storage</DialogTitle>
44
+ <DialogContent dividers>
45
+ <Box sx={{ mb: 2 }}>
46
+ <Typography variant="body2" color="text.secondary">
47
+ {fmtBytes(storage.total_used_bytes)} used Β· {fmtBytes(storage.total_free_bytes)} free
48
+ </Typography>
49
+ <LinearProgress
50
+ variant="determinate"
51
+ value={Math.min(100, usedPct)}
52
+ sx={{ mt: 1, height: 6, borderRadius: 3 }}
53
+ />
54
+ </Box>
55
+
56
+ {rows.length === 0 ? (
57
+ <Typography variant="body2" color="text.secondary" sx={{ py: 2 }}>
58
+ Nothing downloaded yet.
59
+ </Typography>
60
+ ) : (
61
+ <Table size="small">
62
+ <TableHead>
63
+ <TableRow>
64
+ <TableCell>Checkpoint</TableCell>
65
+ <TableCell align="right">Size</TableCell>
66
+ </TableRow>
67
+ </TableHead>
68
+ <TableBody>
69
+ {rows.map(m => (
70
+ <TableRow key={m.id}>
71
+ <TableCell>{nameFor(m.id)}</TableCell>
72
+ <TableCell align="right">{fmtBytes(m.bytes)}</TableCell>
73
+ </TableRow>
74
+ ))}
75
+ </TableBody>
76
+ </Table>
77
+ )}
78
+ </DialogContent>
79
+ <DialogActions>
80
+ <Button onClick={onClose}>Close</Button>
81
+ </DialogActions>
82
+ </Dialog>
83
+ );
84
+ }
app/frontend/src/components/Tooltip.js ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { cloneElement } from 'react';
2
+ import { useInfoView } from './InfoView';
3
+
4
+ /**
5
+ * App-wide Tooltip.
6
+ *
7
+ * Help text is shown exclusively through the Info View (see
8
+ * components/InfoView.js) β€” there are no popup tooltips on the controls:
9
+ *
10
+ * β€’ Info View ON β€” on hover/focus the `title` is reported to the bottom
11
+ * Info View pill, so help shows in one fixed place rather than over the
12
+ * control itself.
13
+ * β€’ Info View OFF β€” no hover help at all; the child renders untouched.
14
+ *
15
+ * The API matches MUI's Tooltip (drop-in for the existing call sites): pass a
16
+ * `title` plus a single child element. Placement/arrow/delay props are
17
+ * accepted but ignored, since there's no popup.
18
+ */
19
+ export default function Tooltip({ children, title }) {
20
+ const { enabled, setHint } = useInfoView();
21
+ const child = React.Children.only(children);
22
+
23
+ if (!enabled) {
24
+ // No popup tooltips β€” the Info View is the only help surface.
25
+ return child;
26
+ }
27
+
28
+ // Info View mode: route the tip to the bottom pill on hover/focus.
29
+ return cloneElement(child, {
30
+ onMouseEnter: (e) => { setHint(title); child.props?.onMouseEnter?.(e); },
31
+ onMouseLeave: (e) => { setHint(null); child.props?.onMouseLeave?.(e); },
32
+ onFocus: (e) => { setHint(title); child.props?.onFocus?.(e); },
33
+ onBlur: (e) => { setHint(null); child.props?.onBlur?.(e); },
34
+ });
35
+ }
app/frontend/src/components/TrainingMonitor.js CHANGED
@@ -1,14 +1,26 @@
1
- import React from 'react';
2
  import { Paper, Box, Typography, LinearProgress, Grid, Alert } from '@mui/material';
3
  import { Activity as ActivityIcon } from 'lucide-react';
4
  import LossChart from './LossChart';
5
  import { trainingMonitorStyles } from '../theme';
6
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  export default function TrainingMonitor({
8
  trainingProgress,
9
  trainingStatus,
 
10
  trainingError,
11
- trainingConfig,
12
  indicatorState,
13
  }) {
14
  const getProgressColor = () => {
@@ -21,6 +33,39 @@ export default function TrainingMonitor({
21
  const label = indicatorState?.label || 'Idle';
22
  const animate = indicatorState?.animate || false;
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  return (
25
  <Paper sx={trainingMonitorStyles.rootPaper}>
26
  <Box sx={trainingMonitorStyles.headerRow}>
@@ -53,60 +98,56 @@ export default function TrainingMonitor({
53
  />
54
  </Box>
55
 
56
- {trainingStatus?.device_info && (
57
- <Box sx={trainingMonitorStyles.deviceSection}>
58
- <Typography variant="body2" color="textSecondary">
59
- <strong>{
60
- trainingStatus.device_info.type === 'cuda' ? 'CUDA' :
61
- trainingStatus.device_info.type === 'mps' ? 'MPS' : 'CPU'
62
- }</strong>
63
- {' Β· '}{trainingStatus.device_info.device}
64
- {trainingStatus.device_info.memory_gb
65
- ? ` Β· ${trainingStatus.device_info.memory_gb.toFixed(1)} GB`
66
- : ''}
67
- </Typography>
68
- </Box>
69
- )}
70
-
71
  <Grid container spacing={2} sx={trainingMonitorStyles.metricsGrid}>
72
  <Grid item xs={12} sm={6}>
73
- <Typography variant="body2" color="textSecondary">Current Epoch</Typography>
74
- <Typography variant="body1">
75
- {trainingStatus?.current_epoch !== undefined ?
76
- `${trainingStatus.current_epoch + 1} / ${trainingConfig.epochs}` :
77
- '0 / ' + trainingConfig.epochs}
78
  </Typography>
79
  </Grid>
80
  <Grid item xs={12} sm={6}>
81
- <Typography variant="body2" color="textSecondary">Global Step / Total Steps</Typography>
82
- <Typography variant="body1" color="primary">
83
- {trainingStatus?.global_step !== undefined && trainingStatus?.total_steps !== undefined ?
84
- `${trainingStatus.global_step} / ${trainingStatus.total_steps}` :
85
- 'N/A'}
86
  </Typography>
87
  </Grid>
88
  <Grid item xs={12} sm={6}>
89
  <Typography variant="body2" color="textSecondary">Checkpoints Saved</Typography>
90
- <Typography variant="body1">
91
- {trainingStatus?.checkpoints_saved || 0}
92
- </Typography>
93
  </Grid>
94
  <Grid item xs={12} sm={6}>
95
- <Typography variant="body2" color="textSecondary">Current Loss</Typography>
96
- <Typography variant="body1">
97
- {trainingStatus?.loss ? parseFloat(trainingStatus.loss).toFixed(4) : 'N/A'}
98
  </Typography>
99
  </Grid>
 
 
 
 
 
 
 
 
100
  </Grid>
101
 
102
- {trainingStatus?.loss_history && trainingStatus.loss_history.length > 0 && (
103
  <Box sx={trainingMonitorStyles.lossSection}>
104
  <Typography variant="body2" color="textSecondary" gutterBottom>
105
  <strong>Loss History</strong>
106
  </Typography>
107
  <Box sx={trainingMonitorStyles.lossChartBox}>
108
- <LossChart data={trainingStatus.loss_history} />
109
  </Box>
 
 
 
 
 
 
 
 
 
110
  </Box>
111
  )}
112
 
 
1
+ import React, { useMemo } from 'react';
2
  import { Paper, Box, Typography, LinearProgress, Grid, Alert } from '@mui/material';
3
  import { Activity as ActivityIcon } from 'lucide-react';
4
  import LossChart from './LossChart';
5
  import { trainingMonitorStyles } from '../theme';
6
 
7
+ /**
8
+ * TrainingMonitor β€” right-pane status card for the Training tab.
9
+ *
10
+ * Reads SA3-shaped status from the backend:
11
+ * { is_training, status, step, total_steps, current_step, progress,
12
+ * loss, checkpoints, checkpoints_saved, error, ... }
13
+ *
14
+ * SA3 trains by step count, not epochs β€” the panel surfaces step / total
15
+ * directly. Loss curve is built frontend-side from successive poll snapshots
16
+ * (trainingHistory) so we don't depend on the backend emitting a history
17
+ * array.
18
+ */
19
  export default function TrainingMonitor({
20
  trainingProgress,
21
  trainingStatus,
22
+ trainingHistory,
23
  trainingError,
 
24
  indicatorState,
25
  }) {
26
  const getProgressColor = () => {
 
33
  const label = indicatorState?.label || 'Idle';
34
  const animate = indicatorState?.animate || false;
35
 
36
+ // Loss points for the chart. We prefer the backend's loss_history
37
+ // (built from Lightning's metrics.csv, which records per-step loss
38
+ // from step 0) so the chart shows the full curve even before PL's
39
+ // tqdm postfix surfaces train/loss (which only appears after the
40
+ // first metrics flush, typically end of epoch 0). Falls back to the
41
+ // frontend-built trainingHistory if the backend hasn't populated
42
+ // loss_history yet (very early in the run, before PL writes CSV).
43
+ const lossPoints = useMemo(() => {
44
+ const fromBackend = trainingStatus?.loss_history;
45
+ if (Array.isArray(fromBackend) && fromBackend.length > 0) {
46
+ return fromBackend
47
+ .filter(p => Number.isFinite(p?.step) && Number.isFinite(p?.loss))
48
+ .sort((a, b) => a.step - b.step);
49
+ }
50
+ if (!trainingHistory || trainingHistory.length === 0) return [];
51
+ const byStep = new Map();
52
+ for (const h of trainingHistory) {
53
+ const step = h.current_step ?? h.step;
54
+ const loss = typeof h.loss === 'number' ? h.loss : parseFloat(h.loss);
55
+ if (Number.isFinite(step) && Number.isFinite(loss)) {
56
+ byStep.set(step, { step, loss });
57
+ }
58
+ }
59
+ return Array.from(byStep.values()).sort((a, b) => a.step - b.step);
60
+ }, [trainingHistory, trainingStatus?.loss_history]);
61
+
62
+ const step = trainingStatus?.current_step ?? trainingStatus?.step ?? 0;
63
+ const totalSteps = trainingStatus?.total_steps ?? 0;
64
+ const checkpointsSaved = trainingStatus?.checkpoints_saved
65
+ ?? trainingStatus?.checkpoints?.length
66
+ ?? 0;
67
+ const currentLoss = trainingStatus?.loss;
68
+
69
  return (
70
  <Paper sx={trainingMonitorStyles.rootPaper}>
71
  <Box sx={trainingMonitorStyles.headerRow}>
 
98
  />
99
  </Box>
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  <Grid container spacing={2} sx={trainingMonitorStyles.metricsGrid}>
102
  <Grid item xs={12} sm={6}>
103
+ <Typography variant="body2" color="textSecondary">Step</Typography>
104
+ <Typography variant="body1" color="primary">
105
+ {totalSteps > 0 ? `${step} / ${totalSteps}` : `${step}`}
 
 
106
  </Typography>
107
  </Grid>
108
  <Grid item xs={12} sm={6}>
109
+ <Typography variant="body2" color="textSecondary">Current Loss</Typography>
110
+ <Typography variant="body1">
111
+ {Number.isFinite(currentLoss) ? parseFloat(currentLoss).toFixed(4) : 'N/A'}
 
 
112
  </Typography>
113
  </Grid>
114
  <Grid item xs={12} sm={6}>
115
  <Typography variant="body2" color="textSecondary">Checkpoints Saved</Typography>
116
+ <Typography variant="body1">{checkpointsSaved}</Typography>
 
 
117
  </Grid>
118
  <Grid item xs={12} sm={6}>
119
+ <Typography variant="body2" color="textSecondary">Phase</Typography>
120
+ <Typography variant="body1" sx={{ textTransform: 'capitalize' }}>
121
+ {trainingStatus?.status || 'idle'}
122
  </Typography>
123
  </Grid>
124
+ {Number.isFinite(trainingStatus?.seed) && (
125
+ <Grid item xs={12} sm={6}>
126
+ <Typography variant="body2" color="textSecondary">Seed</Typography>
127
+ <Typography variant="body1" sx={{ fontVariantNumeric: 'tabular-nums' }}>
128
+ {trainingStatus.seed}
129
+ </Typography>
130
+ </Grid>
131
+ )}
132
  </Grid>
133
 
134
+ {lossPoints.length > 1 && (
135
  <Box sx={trainingMonitorStyles.lossSection}>
136
  <Typography variant="body2" color="textSecondary" gutterBottom>
137
  <strong>Loss History</strong>
138
  </Typography>
139
  <Box sx={trainingMonitorStyles.lossChartBox}>
140
+ <LossChart data={lossPoints} />
141
  </Box>
142
+ <Typography
143
+ variant="caption"
144
+ color="textSecondary"
145
+ sx={trainingMonitorStyles.lossDisclaimer}
146
+ >
147
+ LoRA diffusion loss is noisy by design β€” each step samples
148
+ a random noise level. Judge the result with your ears, not
149
+ only with this chart.
150
+ </Typography>
151
  </Box>
152
  )}
153
 
app/frontend/src/components/WelcomePage.js CHANGED
@@ -1,6 +1,7 @@
1
  import React, { useState, useEffect } from 'react';
2
- import { Backdrop, Box, Fade, Typography, Button, Checkbox, FormControlLabel } from '@mui/material';
3
  import { welcomePageStyles } from '../theme';
 
4
 
5
  export default function WelcomePage({ open, onClose }) {
6
  const [titleVisible, setTitleVisible] = useState(false);
@@ -48,53 +49,41 @@ export default function WelcomePage({ open, onClose }) {
48
  </Fade>
49
 
50
  <Fade in={textVisible} timeout={1000}>
51
- <Box>
52
- <Typography
53
- variant="overline"
54
- sx={welcomePageStyles.overline}
55
- >
56
- An End-to-End Pipeline to Fine-Tune and Use Text-to-Audio Models.
57
- </Typography>
 
58
 
59
-
60
- <Typography
61
- variant="body2"
62
- sx={welcomePageStyles.footer}
63
- >
64
- @2025-2026 Misagh Azimi
65
- </Typography>
66
- <Typography
67
- variant="body2"
68
- sx={welcomePageStyles.version}
69
- >
70
- Version 0.1.1
71
- </Typography>
72
- <Button
73
- variant="contained"
74
- onClick={() => onClose(dontShowAgain)}
75
- sx={welcomePageStyles.ctaButton}
76
- >
77
- Get Started
78
- </Button>
79
- <Box sx={{ mt: 1.5 }}>
80
  <FormControlLabel
81
  control={
82
  <Checkbox
83
  checked={dontShowAgain}
84
  onChange={(e) => setDontShowAgain(e.target.checked)}
85
  size="small"
86
- sx={{ color: 'text.secondary' }}
87
  />
88
  }
89
  label={
90
- <Typography variant="caption" sx={{ color: 'text.secondary' }}>
91
  Don't show this again
92
  </Typography>
93
  }
94
  />
95
  </Box>
96
-
97
- </Box>
98
  </Fade>
99
  </Box>
100
  </Backdrop>
 
1
  import React, { useState, useEffect } from 'react';
2
+ import { Backdrop, Box, Fade, Typography, Button, Checkbox, FormControlLabel, Stack } from '@mui/material';
3
  import { welcomePageStyles } from '../theme';
4
+ import { APP_VERSION } from '../version';
5
 
6
  export default function WelcomePage({ open, onClose }) {
7
  const [titleVisible, setTitleVisible] = useState(false);
 
49
  </Fade>
50
 
51
  <Fade in={textVisible} timeout={1000}>
52
+ <Stack alignItems="center">
53
+ <Stack alignItems="center">
54
+ <Typography variant="body3" color="text.secondary">
55
+ Β©2025-2026 Misagh Azimi
56
+ </Typography>
57
+ <Typography variant="body3" color="text.secondary">
58
+ Version {APP_VERSION}
59
+ </Typography>
60
 
61
+ </Stack>
62
+ <Box mt={5}>
63
+ <Button
64
+ variant="contained"
65
+ onClick={() => onClose(dontShowAgain)}
66
+ >
67
+ Get Started
68
+ </Button>
69
+ </Box>
70
+ <Box mt={6}>
 
 
 
 
 
 
 
 
 
 
 
71
  <FormControlLabel
72
  control={
73
  <Checkbox
74
  checked={dontShowAgain}
75
  onChange={(e) => setDontShowAgain(e.target.checked)}
76
  size="small"
 
77
  />
78
  }
79
  label={
80
+ <Typography variant="caption" color="text.secondary">
81
  Don't show this again
82
  </Typography>
83
  }
84
  />
85
  </Box>
86
+ </Stack>
 
87
  </Fade>
88
  </Box>
89
  </Backdrop>
app/frontend/src/components/usePerformanceSession.js CHANGED
@@ -66,13 +66,21 @@ export function loadPresetIntoSession(name) {
66
  const CHANNEL_DEFAULT = {
67
  prompt: '',
68
  duration: 8,
69
- durationMode: 'seconds',
70
  bars: 4,
71
  looping: true,
72
  muted: false,
73
  soloed: false,
74
  batchSize: 1,
75
- knobs: { gain: -6, pan: 0, filter: 18000, delay: 0, reverb: 0 },
 
 
 
 
 
 
 
 
76
  };
77
 
78
  function defaultSession(channelCount) {
@@ -88,6 +96,18 @@ function defaultSession(channelCount) {
88
  randomSeed: true,
89
  seedValue: '',
90
  cueDeviceId: '',
 
 
 
 
 
 
 
 
 
 
 
 
91
  channels: Array.from({ length: channelCount }, () => ({
92
  ...CHANNEL_DEFAULT,
93
  knobs: { ...CHANNEL_DEFAULT.knobs },
@@ -104,11 +124,21 @@ function loadSession(channelCount) {
104
  // Merge against defaults so older saves don't crash on missing fields.
105
  // Length shifts (channel count change between releases) are absorbed
106
  // by always producing exactly `channelCount` channels.
107
- const channels = Array.from({ length: channelCount }, (_, i) => ({
108
- ...CHANNEL_DEFAULT,
109
- ...(parsed.channels?.[i] || {}),
110
- knobs: { ...CHANNEL_DEFAULT.knobs, ...(parsed.channels?.[i]?.knobs || {}) },
111
- }));
 
 
 
 
 
 
 
 
 
 
112
  return { ...fallback, ...parsed, channels };
113
  } catch {
114
  return fallback;
 
66
  const CHANNEL_DEFAULT = {
67
  prompt: '',
68
  duration: 8,
69
+ durationMode: 'bars',
70
  bars: 4,
71
  looping: true,
72
  muted: false,
73
  soloed: false,
74
  batchSize: 1,
75
+ knobs: { gain: -6, pan: 0, filter: 0, delay: 0, reverb: 0 },
76
+ // Fragment history metadata (id, prompt, duration, createdAt, starred,
77
+ // number). The Blob audio bodies live in IndexedDB under the
78
+ // `session-ch{N}` scope β€” see utils/fragmentStorage.js. Cleared on
79
+ // Fresh Start and overwritten on preset load.
80
+ fragments: [],
81
+ // Which fragment was loaded into the channel strip last; restored on
82
+ // reload so the channel comes back ready to play instead of empty.
83
+ committedFragmentId: null,
84
  };
85
 
86
  function defaultSession(channelCount) {
 
96
  randomSeed: true,
97
  seedValue: '',
98
  cueDeviceId: '',
99
+ // Master FX defaults β€” the FX are always-on; the wet level on the
100
+ // master bus is determined entirely by per-channel DLY/REV send
101
+ // levels. We only persist the IR choice and the delay division.
102
+ masterReverbIR: 'hall',
103
+ masterDelayDivision: '1/4',
104
+ // Prompt auto-inject fields. Each is appended (comma-separated) to
105
+ // every generated prompt when set. Key and Time accept any text;
106
+ // empty = no injection. BPM is a toggle that, when on, grabs the
107
+ // live master BPM (top-bar value) at generation time.
108
+ promptKey: '',
109
+ promptInjectBpm: false,
110
+ promptTimeSig: '',
111
  channels: Array.from({ length: channelCount }, () => ({
112
  ...CHANNEL_DEFAULT,
113
  knobs: { ...CHANNEL_DEFAULT.knobs },
 
124
  // Merge against defaults so older saves don't crash on missing fields.
125
  // Length shifts (channel count change between releases) are absorbed
126
  // by always producing exactly `channelCount` channels.
127
+ //
128
+ // Migration: pre-rename saves used `takes` / `committedTakeId`. Copy
129
+ // those into the new `fragments` / `committedFragmentId` slots when
130
+ // present, so users' existing generations carry over after the
131
+ // "Takes β†’ Fragments" rename. Old fields are left in place but unused.
132
+ const channels = Array.from({ length: channelCount }, (_, i) => {
133
+ const ch = parsed.channels?.[i] || {};
134
+ return {
135
+ ...CHANNEL_DEFAULT,
136
+ ...ch,
137
+ fragments: ch.fragments ?? ch.takes ?? [],
138
+ committedFragmentId: ch.committedFragmentId ?? ch.committedTakeId ?? null,
139
+ knobs: { ...CHANNEL_DEFAULT.knobs, ...(ch.knobs || {}) },
140
+ };
141
+ });
142
  return { ...fallback, ...parsed, channels };
143
  } catch {
144
  return fallback;
app/frontend/src/theme.js CHANGED
The diff for this file is too large to render. See raw diff
 
app/frontend/src/tooltips.js ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export const TIPS = {
2
+ // App.js β€” LoRA training hyperparameters + model row actions.
3
+ training: {
4
+ downloadModel: 'Download this model',
5
+ deleteFineTuned: 'Delete fine-tuned model',
6
+ steps: "SA3's documented quick-start is 1,000 steps.",
7
+ adapter: "DoRA-rows is SA3's upstream default and works best for most stylistic LoRAs. The -xs variants freeze SVD bases and only train a tiny core matrix β€” far fewer parameters, useful when VRAM is tight. BoRA scales both rows and columns independently (more expressive, more parameters).",
8
+ checkpointEvery: 'How often a LoRA .safetensors snapshot gets written. Auto picks ~10 checkpoints per run (capped 250–1 000 steps). Lower = more granular but more disk; higher = fewer files to compare.',
9
+ batchSize: 'SA3 examples use 1. Each extra sample adds ~1–2 GB of activations. Raise only on roomy GPUs (β‰₯24 GB); medium-base activations are heavy. Lower if you hit CUDA OOM.',
10
+ precision: 'Cast applied to the frozen base weights only; LoRA parameters stay in fp32 for the optimizer. bf16 halves the VRAM used by the base with negligible quality cost on Ampere and newer cards.',
11
+ rank: "Capacity of the LoRA update β€” rank-k matrices A (kΓ—in) and B (outΓ—k) are trained. Higher rank = more expressive but larger file and more VRAM. r=16 fits comfortably on 16 GB and is SA3's default.",
12
+ alpha: 'Scaling factor for the LoRA update. Effective scaling is alpha / rank β€” setting alpha = rank gives a scaling of 1.0. Conventional choice: alpha = rank.',
13
+ dropout: 'Regularization probability applied to LoRA inputs during training. 0 is fine for most cases β€” raise to ~0.05 if you see overfitting on small datasets.',
14
+ seed: 'Random seed for reproducibility β€” same dataset + same hyperparameters + same seed produces the same LoRA. Change it to re-roll with different sampling behaviour.',
15
+ learningRate: "AdamW step size for the LoRA weights (base stays frozen). SA3's default is 1e-4, which works for most runs. Too high destabilizes training (loss spikes, artifacts); too low barely moves the adapter. Halve it if loss is erratic.",
16
+ sampleLength: 'Audio fed to the model per training step. Long clips get random-cropped to this length each step; short clips get silence-padded. Capped at the base model\'s native length (~120s small, ~380s medium) β€” longer windows cost markedly more VRAM and step time, so raise it only for long-form material (pre-encoding helps).',
17
+ includeLayers: 'Space-separated substrings β€” only layers whose fully-qualified name contains one of these get LoRA. Empty = all matching Linear/Conv1d layers. Example: transformer.layers.',
18
+ excludeLayers: 'Space-separated substrings β€” matching layers are skipped, even if they also match Include. SA3-docs default (seconds_total to_local_embed) prevents conditioner-hijacking on small datasets.',
19
+ },
20
+
21
+ // PerformancePanel.js β€” top transport bar + bottom controls.
22
+ perf: {
23
+ notDownloaded: 'Not downloaded β€” open Checkpoint Manager',
24
+ midiSettings: 'MIDI settings & mappings',
25
+ presets: 'Save / load presets',
26
+ deletePreset: 'Delete preset',
27
+ launchQuant: "Launch quantization β€” match Ableton's",
28
+ deleteFineTuned: 'Delete fine-tuned model',
29
+ deleteLora: 'Delete LoRA',
30
+ promptKey: 'Auto-inject Key. Leave empty to skip.',
31
+ timeSig: 'Auto-inject Time signature. Leave empty to skip.',
32
+ link: ({ installing, available, enabled, peers }) =>
33
+ installing
34
+ ? 'Installing LinkPython-extern…'
35
+ : !available
36
+ ? 'Click to install Ableton Link script'
37
+ : enabled
38
+ ? `Link on β€” ${peers} peer${peers === 1 ? '' : 's'} (click to disable)`
39
+ : 'Click to sync BPM with Ableton Link',
40
+ midiMode: ({ supported, permissionError, learnMode }) =>
41
+ !supported
42
+ ? (permissionError || 'Web MIDI is not available')
43
+ : learnMode
44
+ ? 'Exit MIDI mode (Esc)'
45
+ : 'Enter MIDI mode β€” click a control then move a hardware knob/button to bind',
46
+ audioSetup: (cueSupported) =>
47
+ cueSupported
48
+ ? 'Audio setup β€” choose output device'
49
+ : 'Audio device selection requires Chrome/Edge (AudioContext.setSinkId). Output falls back to system default.',
50
+ restoreDefaults: (armed) =>
51
+ armed
52
+ ? 'Click again within 3s to confirm β€” clears session, fragments, and MIDI mappings'
53
+ : 'Reset all panel settings, clear fragments, and clear MIDI mappings',
54
+ steps: (isDistilled) =>
55
+ isDistilled
56
+ ? 'Locked at 8 steps for distilled SA3 models β€” pick a *-base checkpoint to override'
57
+ : 'Diffusion steps per generation (more = higher quality, slower)',
58
+ bpmInject: (on, bpm) =>
59
+ on
60
+ ? `Injecting master BPM (${Math.round(bpm)}) into prompts β€” click to disable`
61
+ : 'Click to auto-inject the master BPM (top bar) into every prompt',
62
+ },
63
+
64
+ // PerformanceChannel.js β€” per-channel strip.
65
+ channel: {
66
+ mute: 'Mute',
67
+ solo: 'Solo',
68
+ batch: "Batch generate Fragments and cue below.",
69
+ loop: (looping, durationMode) =>
70
+ looping
71
+ ? (durationMode === 'bars'
72
+ ? 'Loop'
73
+ : 'Playback loop on')
74
+ : 'Loop off',
75
+ generateDisabled: (generating, canGenerate, hasPrompt) =>
76
+ generating
77
+ ? ''
78
+ : !canGenerate
79
+ ? 'Pick a model in the Generation tab first'
80
+ : !hasPrompt
81
+ ? 'Enter a prompt to generate'
82
+ : '',
83
+ variation: (loaded) =>
84
+ loaded
85
+ ? 'Variation from the current fragment'
86
+ : 'Generate a fragment first, then create variations of it',
87
+ },
88
+
89
+ // DatasetPrep.js β€” dataset workbench.
90
+ dataset: {
91
+ richAnnotate: 'Adds genre / mood / instrument tags using LAION-CLAP. Requires the CLAP weights β€” downloadable from the Checkpoint Manager.',
92
+ skipAnnotated: 'When on, Auto-annotate skips clips that already have an annotation. Off means every run overwrites existing prompts.',
93
+ deleteProject: 'Delete this project (folder, audio, sidecars, drafts) β€” irreversible',
94
+ discardChanges: 'Delete unsaved changes β€” reverts to the last created dataset (removes any audio added since)',
95
+ saveDraft: "Save a draft β€” persists across app restarts but isn't the SA3 sidecar form",
96
+ createDataset: 'Create Dataset β€” writes the .txt sidecars (overwrites the previous dataset)',
97
+ selectClips: 'Click to select these clips β€” then Auto-annotate them.',
98
+ autoAnnotateClip: 'Auto-annotate this clip (overwrites any current prompt)',
99
+ sliceClip: 'Slice this clip into shorter children (immediate)',
100
+ removeClip: 'Remove this clip from the project (immediate)',
101
+ tooShort: (thresholdSec) =>
102
+ `Shorter than ${thresholdSec}s β€” gets silence-padded into each batch. Consider deleting. Click to select.`,
103
+ duplicates: (count) =>
104
+ `${count} group${count === 1 ? '' : 's'} of clips share the same annotation. Bad for training diversity β€” click to select all of them.`,
105
+ unsupported: (accepted) =>
106
+ `SA3 only trains on ${(accepted || []).join(', ')}. These clips will be silently skipped at train time β€” re-export them as .wav (or another accepted format) before committing. Click to select.`,
107
+ },
108
+
109
+ // LoraStack.js β€” LoRA slot stack.
110
+ lora: {
111
+ stackInfo: (max) => `Blend up to ${max} LoRAs at any strength`,
112
+ dragReorder: 'Drag to reorder (slot 0 loads first)',
113
+ bypass: (bypassed) =>
114
+ bypassed ? 'Bypassed (strength 0) β€” click to enable' : 'Bypass this slot',
115
+ },
116
+
117
+ // Fragment lists β€” ChannelFragmentHistory.js + GeneratedFragmentsWindow.js.
118
+ fragments: {
119
+ clearAll: 'Clear all (delete every fragment from disk)',
120
+ deleteFromDisk: 'Delete from disk',
121
+ revealInFolder: 'Show in folder (reveal this file on disk)',
122
+ audition: (isAuditioning) =>
123
+ isAuditioning ? 'Stop cue' : 'Audition through cue output',
124
+ star: (starred) =>
125
+ starred ? 'Unstar' : 'Star (keep through eviction)',
126
+ commit: (committed) =>
127
+ committed ? 'Currently loaded' : 'Load into channel',
128
+ },
129
+
130
+ // CheckpointRow.js β€” checkpoint catalog rows.
131
+ checkpoints: {
132
+ gatedAccess: "Open on HuggingFace to accept the model's gated-access terms",
133
+ },
134
+ };
app/frontend/src/utils/cueAudio.js CHANGED
@@ -11,6 +11,7 @@
11
 
12
  let ctx = null;
13
  let currentSource = null;
 
14
  let currentEndedHandler = null;
15
  let currentSinkId = '';
16
 
@@ -132,19 +133,26 @@ export async function playBlob(blob, { onEnded } = {}) {
132
 
133
  const src = c.createBufferSource();
134
  src.buffer = buf;
135
- // Connect into the splitter, NOT directly to destination β€” that's how
136
- // the channel-pair routing applies.
137
- src.connect(cueSplitter);
 
 
 
 
 
138
 
139
  const handler = () => {
140
  if (currentSource === src) {
141
  currentSource = null;
 
142
  currentEndedHandler = null;
143
  }
144
  onEnded?.();
145
  };
146
  src.addEventListener('ended', handler);
147
  currentSource = src;
 
148
  currentEndedHandler = handler;
149
  src.start();
150
 
@@ -157,12 +165,27 @@ export async function playBlob(blob, { onEnded } = {}) {
157
 
158
  export function stopCue() {
159
  if (currentSource) {
 
 
160
  if (currentEndedHandler) {
161
- currentSource.removeEventListener('ended', currentEndedHandler);
162
  }
163
- try { currentSource.stop(); } catch { /* already stopped */ }
164
- try { currentSource.disconnect(); } catch { /* already disconnected */ }
 
 
 
 
 
 
 
 
 
 
 
 
165
  currentSource = null;
 
166
  currentEndedHandler = null;
167
  }
168
  }
 
11
 
12
  let ctx = null;
13
  let currentSource = null;
14
+ let currentSourceFade = null; // per-source gain used by stopCue() to ramp out
15
  let currentEndedHandler = null;
16
  let currentSinkId = '';
17
 
 
133
 
134
  const src = c.createBufferSource();
135
  src.buffer = buf;
136
+ // Per-source fade gain so stopCue() can ramp out instead of hard-cut
137
+ // (a hard cut at non-zero samples is what produces the click /
138
+ // crackle when switching fragments rapidly). The fade graph is:
139
+ // source β†’ fadeGain β†’ cueSplitter β†’ cueMerger β†’ destination
140
+ const fadeGain = c.createGain();
141
+ fadeGain.gain.value = 1;
142
+ src.connect(fadeGain);
143
+ fadeGain.connect(cueSplitter);
144
 
145
  const handler = () => {
146
  if (currentSource === src) {
147
  currentSource = null;
148
+ currentSourceFade = null;
149
  currentEndedHandler = null;
150
  }
151
  onEnded?.();
152
  };
153
  src.addEventListener('ended', handler);
154
  currentSource = src;
155
+ currentSourceFade = fadeGain;
156
  currentEndedHandler = handler;
157
  src.start();
158
 
 
165
 
166
  export function stopCue() {
167
  if (currentSource) {
168
+ const src = currentSource;
169
+ const fade = currentSourceFade;
170
  if (currentEndedHandler) {
171
+ src.removeEventListener('ended', currentEndedHandler);
172
  }
173
+ const now = ctx ? ctx.currentTime : 0;
174
+ const FADE = 0.012;
175
+ try {
176
+ if (fade) {
177
+ fade.gain.cancelScheduledValues(now);
178
+ fade.gain.setValueAtTime(fade.gain.value, now);
179
+ fade.gain.linearRampToValueAtTime(0, now + FADE);
180
+ }
181
+ src.stop(now + FADE + 0.005);
182
+ } catch { /* already stopped */ }
183
+ window.setTimeout(() => {
184
+ try { src.disconnect(); } catch { /* ok */ }
185
+ try { fade && fade.disconnect(); } catch { /* ok */ }
186
+ }, Math.ceil((FADE + 0.02) * 1000));
187
  currentSource = null;
188
+ currentSourceFade = null;
189
  currentEndedHandler = null;
190
  }
191
  }
app/frontend/src/utils/fragmentDrag.js ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // In-app drag handoff for generated fragments.
2
+ //
3
+ // The HTML drag-and-drop dataTransfer can only carry strings, so when a
4
+ // fragment is dragged from the Generated Fragments window into the Edit tab's
5
+ // source dropzone we stash its in-memory audio Blob here on dragStart and read
6
+ // it back on drop. This lets EditPanel use the blob directly instead of
7
+ // re-fetching by filename β€” which is immune to any divergence between the
8
+ // fragment's in-memory name and what actually exists on disk.
9
+ //
10
+ // Falls back gracefully: if no blob was stashed (e.g. a not-yet-preloaded
11
+ // disk fragment), the consumer drops back to the filename-based fetch.
12
+
13
+ let _payload = null; // { filename: string, blob: Blob } | null
14
+
15
+ export function setFragmentDragPayload(payload) {
16
+ _payload = payload;
17
+ }
18
+
19
+ export function getFragmentDragPayload() {
20
+ return _payload;
21
+ }
22
+
23
+ export function clearFragmentDragPayload() {
24
+ _payload = null;
25
+ }