diff --git a/.gitattributes b/.gitattributes index 0b12894ddbf771abd5a004c7ddc7b59e8c7f8515..f83cc71f0c003acf31b515bf4ca53c336aa015f5 100644 --- a/.gitattributes +++ b/.gitattributes @@ -47,3 +47,5 @@ utils/vendor/wheels/antlr4_python3_runtime-4.9.3-py3-none-any.whl filter=lfs dif vendor/stable-audio-tools/demo_cfg_3_00000001.wav filter=lfs diff=lfs merge=lfs -text vendor/stable-audio-tools/demo_cfg_6_00000001.wav filter=lfs diff=lfs merge=lfs -text vendor/stable-audio-tools/demo_cfg_9_00000001.wav filter=lfs diff=lfs merge=lfs -text +app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf filter=lfs diff=lfs merge=lfs -text +app/frontend/public/InterTight-VariableFont_wght.ttf filter=lfs diff=lfs merge=lfs -text diff --git a/Dockerfile b/Dockerfile index 3672fe9ff653fb2bc4208230ea859e6d120294ad..4802ea9f5665417ec88e02c7e749616dddf8f376 100644 --- a/Dockerfile +++ b/Dockerfile @@ -60,8 +60,9 @@ RUN grep -ivE 'flash-attn|extra-index-url|pycairo|pygobject|pywebview' requireme COPY . . COPY --from=frontend-builder /build/frontend/build ./app/frontend/build -# Install stable-audio-tools in-tree -RUN pip install --no-cache-dir --root-user-action=ignore -e ./vendor/stable-audio-tools/ +# Install vendored Stable Audio 3 in-tree (--no-deps: runtime deps come from +# requirements.txt). Makes `import stable_audio_3` resolve. 
+RUN pip install --no-cache-dir --root-user-action=ignore --no-deps -e ./vendor/stable-audio-3/ # Create writable directories RUN mkdir -p /app/models/pretrained \ @@ -105,6 +106,7 @@ ENV FLASK_HOST=0.0.0.0 ENV FLASK_PORT=7860 ENV FRAGMENTA_LOG_LEVEL=INFO ENV FRAGMENTA_DOCKER=1 +ENV PYTHONPATH=/app/vendor/stable-audio-3 ENV FRAGMENTA_USE_CUSTOM_MODELS=true ENV HOME=/home/user ENV PATH="/home/user/.local/bin:${PATH}" diff --git a/README.md b/README.md index 8fdedb0eac373cbd3989177247c285ab4f6f34f1..22e73ef5ddda59783a9d1c288c1f381b715b68dd 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,10 @@ Generate and fine-tune audio from text prompts using Stable Audio Open. ## Getting Started -1. Upload your model weights (`.safetensors`) to `models/pretrained/` in the Space Files tab. - - `stable-audio-open-small-model.safetensors` (recommended for CPU) - - `stable-audio-open-model.safetensors` (full model, recommended for GPU) +1. Download an SA3 checkpoint via the in-app Checkpoint Manager, or place one + under `models/pretrained/sa3/hub/` in the Space Files tab. + - `sa3-small-music` (recommended for CPU Spaces) + - `sa3-medium` (recommended for GPU Spaces with Flash Attention 2) 2. The Space will auto-rebuild after the upload. 3. Use the **Data Processing** tab to upload audio + prompts. 4. Use the **Training** tab to fine-tune. 
diff --git a/app/backend/app.py b/app/backend/app.py index b2e2f2e3d833991fe1b922a52ba6639d4bc73b60..2580bd3598384c62740ed2b542cb7867531e96f4 100644 --- a/app/backend/app.py +++ b/app/backend/app.py @@ -3,19 +3,21 @@ from utils.exceptions import ModelNotFoundError, ValidationError, GenerationErro from utils.api_responses import APIResponse, handle_api_error from utils.logger import setup_logging, get_logger from app.core.generation.audio_generator import AudioGenerator -from app.core.training.fine_tuner import start_training as start_training_func, get_training_status, stop_training, preview_training_plan -from app.backend.data.simple_audio_processor import SimpleAudioProcessor +from app.core.training.sa3_trainer import start_training as start_training_func, get_training_status, stop_training, preview_training_plan from app.core.config import get_config -from flask import Flask, request, jsonify, send_file, send_from_directory +from flask import Flask, request, jsonify, send_file, send_from_directory, Response from flask_cors import CORS import os +import re +import random +import queue from pathlib import Path import sys import threading import time import json import logging -from typing import Dict, Optional +from typing import Any, Dict, Optional from werkzeug.serving import WSGIRequestHandler sys.path.append(os.path.abspath( @@ -46,6 +48,10 @@ app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0 CORS(app, resources={r"/api/*": {"origins": "*"}}, supports_credentials=True, + # /api/generate returns the WAV as the body, so the canonical on-disk + # filename (and resolved seed) ride back in custom headers. They must be + # whitelisted here or the browser hides them from axios cross-origin. 
+ expose_headers=["X-Fragment-Filename", "X-Fragment-Seed"], methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"]) @@ -60,7 +66,6 @@ def request_entity_too_large(error): DEBUG_MODE = os.environ.get('FRAGMENTA_DEBUG', 'false').lower() == 'true' config = None -audio_processor = None generator = None model_manager = None _components_initialised = False @@ -121,7 +126,7 @@ def _resolve_finetuned_config_path(model_dir: Path, config) -> str: def _ensure_components(): - global config, audio_processor, generator, model_manager + global config, generator, model_manager global _components_initialised, _init_error if _components_initialised: @@ -132,9 +137,6 @@ def _ensure_components(): try: logger.info("Initializing Backend API components (lazy)…") config = get_config() - audio_processor = SimpleAudioProcessor( - model_config_path=config.get_path("models_config") / "model_config.json" - ) generator = AudioGenerator(config) from app.core.model_manager import ModelManager @@ -193,112 +195,9 @@ def serve_static_files(path): return response -@app.route('/api/process-files', methods=['POST']) -def process_files(): - try: - content_length = request.content_length - if content_length: - print(f"Upload request size: {content_length / (1024*1024):.2f} MB") - else: - print("Upload request size: Unknown") - - max_size = app.config.get('MAX_CONTENT_LENGTH', 1000 * 1024 * 1024) - if content_length and content_length > max_size: - return jsonify({ - 'error': f'Upload too large: {content_length / (1024*1024):.2f} MB exceeds {max_size / (1024*1024):.0f} MB limit', - 'max_size_mb': max_size // (1024*1024) - }), 413 - - config = get_config() - data_dir = config.get_path("data") - data_dir.mkdir(exist_ok=True, parents=True) - - files_and_prompts = [] - for key in request.files: - if key.startswith('file_'): - index = key.split('_')[1] - prompt_key = f'prompt_{index}' - - if prompt_key in request.form: - file_obj = request.files[key] - prompt = request.form[prompt_key] - - if file_obj and prompt 
and prompt.strip(): - files_and_prompts.append((file_obj, prompt.strip())) - - if not files_and_prompts: - return jsonify({'error': 'No files uploaded or no prompts provided'}), 400 - - target_length = int(request.form.get('target_length', 30)) - sample_rate = int(request.form.get('sample_rate', 44100)) - channels = int(request.form.get('channels', 2)) - - saved_files = [] - prompts_data = [] - - for file_obj, prompt in files_and_prompts: - try: - file_path = data_dir / file_obj.filename - file_obj.save(file_path) - saved_files.append(file_obj.filename) - - prompts_data.append((file_obj.filename, prompt)) - - except Exception as e: - print(f"Error saving {file_obj.filename}: {e}") - continue - - chunks_preview_data = [] - for filename, prompt in prompts_data: - chunks_preview_data.append([filename, filename, prompt, "original"]) - - # Merge into existing metadata instead of overwriting, so repeated - # uploads accumulate into one dataset. - json_path = Path(config.get_metadata_json_path()) - existing_metadata = [] - - if json_path.exists(): - try: - with open(json_path, 'r', encoding='utf-8') as f: - import json - existing_metadata = json.load(f) - except Exception as e: - print(f"Warning: Could not load existing metadata: {e}") - existing_metadata = [] - - existing_files = {item['file_name']: item for item in existing_metadata} - - for filename, prompt in prompts_data: - existing_files[filename] = { - "file_name": filename, - "prompt": prompt, - "path": str(data_dir / filename) - } - - # Convert back to list and save - final_metadata = list(existing_files.values()) - - with open(json_path, 'w', encoding='utf-8') as f: - import json - json.dump(final_metadata, f, indent=2) - - try: - config.update_dataset_config() - except Exception as exc: - print(f"Warning: failed to refresh dataset-config.json: {exc}") - return jsonify({ - 'message': f'Files saved successfully! 
{len(saved_files)} original files saved to data folder', - 'saved_files': saved_files, - 'processed_count': len(saved_files), - 'chunks_preview': chunks_preview_data, - 'data_folder': str(data_dir), - 'metadata_json': str(json_path), - 'approach': 'original_files_only' - }) - except Exception as e: - return jsonify({'error': str(e)}), 500 +_SA3_LORA_BASES = ['sa3-small-music-base', 'sa3-small-sfx-base', 'sa3-medium-base'] @app.route('/api/start-training', methods=['POST']) @@ -308,78 +207,83 @@ def start_training(): if training_config is None: raise ValueError("No training configuration provided") - print("\n" + "="*80) - print("API TRAINING REQUEST RECEIVED") - print("="*80) - print(f"RECEIVED CONFIG FROM FRONTEND:") - print(f" - Mode: {training_config.get('mode', 'full')}") - print(f" - Model Name: {training_config.get('modelName', 'untitled')}") - print(f" - Base Model: {training_config.get('baseModel', 'NOT SET')}") - print(f" - Epochs: {training_config.get('epochs', 'NOT SET')}") - print(f" - Checkpoint Steps: {training_config.get('checkpointSteps', 'NOT SET')}") - print(f" - Batch Size: {training_config.get('batchSize', 'NOT SET')}") - print(f" - Learning Rate: {training_config.get('learningRate', 'NOT SET')}") - if training_config.get('mode') == 'lora': - print(f" - LoRA rank: {training_config.get('loraRank', 16)}") - print(f" - LoRA alpha: {training_config.get('loraAlpha', 16)}") - print(f" - LoRA dropout: {training_config.get('loraDropout', 0)}") - else: - print(f" - Save Wrapped Checkpoint: {training_config.get('saveWrappedCheckpoint', False)}") + # Required fields. 
+ required_fields = ['modelName', 'baseModel', 'projectName'] + missing = [f for f in required_fields if not training_config.get(f)] + if missing: + return jsonify({'error': f"Missing required fields: {missing}"}), 400 - required_fields = ['modelName', 'baseModel'] - missing_fields = [field for field in required_fields if field not in training_config] - if missing_fields: - error_msg = f"Missing required fields: {missing_fields}" - print(f"API ERROR: {error_msg}") - return jsonify({'error': error_msg}), 400 + # Project must exist on disk (Dataset Workbench created it). + from app.backend.data.projects import project_path + proj_dir = project_path(training_config['projectName']) + if not proj_dir.exists(): + return jsonify({ + 'error': f"Project not found: {training_config['projectName']}. " + "Create or load it in the Dataset tab first.", + }), 400 - valid_models = ['stable-audio-open-small', 'stable-audio-open-1.0'] + # SA3 base validation. LoRA training requires a CFG-aware *-base + # checkpoint; the post-trained / distilled checkpoints have had + # the gradient signal LoRAs target collapsed away. base_model = training_config.get('baseModel') - if base_model not in valid_models: - error_msg = f"Invalid base model '{base_model}'. Must be one of: {valid_models}" - print(f"API ERROR: {error_msg}") - return jsonify({'error': error_msg}), 400 - - # Mode validation: 'full' (existing SAO fine-tune) or 'lora' (LoRAW). - mode = training_config.get('mode', 'full') - if mode not in ('full', 'lora'): - return jsonify({'error': f"Invalid mode '{mode}'. 
Must be 'full' or 'lora'."}), 400 - training_config['mode'] = mode - - if 'epochs' not in training_config: - training_config['epochs'] = 30 - print(f" Setting default epochs: 30") - if 'checkpointSteps' not in training_config: - training_config['checkpointSteps'] = 50 - print(f" Setting default checkpointSteps: 50") - if 'batchSize' not in training_config: - training_config['batchSize'] = 1 - print(f" Setting default batch size: 1") - if 'learningRate' not in training_config: - training_config['learningRate'] = 1e-4 - print(f" Setting default learning rate: 1e-4") - if 'saveWrappedCheckpoint' not in training_config: - training_config['saveWrappedCheckpoint'] = False - print(f" Setting default saveWrappedCheckpoint: False") - if 'precision' not in training_config or not training_config['precision']: - training_config['precision'] = 'auto' - print(f" Setting default precision: auto") - # LoRA-specific defaults - if mode == 'lora': - training_config.setdefault('loraRank', 16) - training_config.setdefault('loraAlpha', training_config['loraRank']) - training_config.setdefault('loraDropout', 0) - training_config.setdefault('loraMultiplier', 1.0) - - print(f"\nVALIDATED CONFIG:") - print(f" - Model Name: {training_config['modelName']}") - print(f" - Base Model: {training_config['baseModel']}") - print(f" - Epochs: {training_config['epochs']}") - print(f" - Checkpoint Steps: {training_config['checkpointSteps']}") - print(f" - Batch Size: {training_config['batchSize']}") - print(f" - Learning Rate: {training_config['learningRate']}") - print(f" - Save Wrapped Checkpoint: {training_config['saveWrappedCheckpoint']}") - print(f" - Precision: {training_config['precision']}") + if base_model not in _SA3_LORA_BASES: + return jsonify({ + 'error': ( + f"baseModel '{base_model}' is not a valid LoRA target. " + f"Pick one of: {_SA3_LORA_BASES}. SA2 models are gone in 0.2.0; " + f"post-trained SA3 checkpoints (no -base suffix) can't be used " + f"as a training base." 
+ ) + }), 400 + + # Same-name collision check. If a previous run for this modelName + # already wrote checkpoints, refuse unless the caller passes + # overwrite=true. Stops a re-train from quietly co-mingling with + # stale artifacts from the previous run. + if not training_config.get('overwrite'): + from app.core.training.sa3_trainer import SA3Trainer + existing = SA3Trainer.existing_run_info(training_config['modelName']) + if existing: + return jsonify({ + 'error': 'run_exists', + 'code': 'run_exists', + 'message': ( + f"A run named “{existing['run_name']}” already exists " + f"with {existing['checkpoint_count']} checkpoint(s). " + "Confirm overwrite to replace it." + ), + **existing, + }), 409 + + # SA3-aligned defaults. Phase 5 ships LoRA-only — no `mode` switch. + training_config['mode'] = 'lora' + training_config.setdefault('steps', 5000) + training_config.setdefault('checkpointSteps', 500) + training_config.setdefault('batchSize', 1) + training_config.setdefault('learningRate', 1e-4) + # Window defaults to the base model's native length (medium ≈380s, + # small ≈120s); sa3_trainer clamps to the same ceiling. + training_config.setdefault( + 'duration', + 380.0 if 'medium' in str(training_config.get('baseModel', '')) else 120.0, + ) + training_config.setdefault('precision', 'bf16') + training_config.setdefault('loraRank', 16) + training_config.setdefault('loraAlpha', training_config['loraRank']) + training_config.setdefault('loraDropout', 0.0) + training_config.setdefault('adapterType', 'dora-rows') + # Random seed by default: the UI sends seed=null when "Random" is on. + # Roll a concrete seed here so it's recorded (training_metadata.json + + # status) and the run stays reproducible if the user liked the result. 
+ if training_config.get('seed') is None: + training_config['seed'] = random.randint(0, 2**31 - 1) + + logger.info( + f"Training request: base={base_model}, name={training_config['modelName']}, " + f"rank={training_config['loraRank']}, adapter={training_config['adapterType']}, " + f"steps={training_config['steps']}, batch={training_config['batchSize']}, " + f"lr={training_config['learningRate']}" + ) result = start_training_func(training_config) @@ -398,8 +302,44 @@ def start_training(): return jsonify({'error': error_msg}), 500 +_SA3_MODEL_IDS = { + "sa3-small-music", "sa3-small-sfx", "sa3-medium", + "sa3-small-music-base", "sa3-small-sfx-base", "sa3-medium-base", +} + + @app.route('/api/generate', methods=['POST']) def generate_audio(): + """SA3 inference. Phase 3 + Phase 7 of SA3_INTEGRATION_PLAN.md. + + Schema: + { + "model_id": "sa3-small-music", # required, SA3 IDs only + "prompt": "techno kick", + "duration": 5.0, + "steps": 8, # optional; default by model + "cfg_scale": 1.0, # optional; default by model + "seed": -1, + "negative_prompt": null, + "batch_size": 1, + "align_bars": null, # bars-mode passthrough + "align_bpm": null, + "chunked_decode": null, + "loras": [{"path":"…","strength":1.0}, …], + "init_audio_path": null, # audio-to-audio source + "init_noise_level": 1.0, + "inpaint_audio_path": null, # inpainting source + "inpaint_starts": [4.0], # list or single float + "inpaint_ends": [8.0], + "loop_stitch": null # "inpaint" | "crossfade" | null + } + + `loop_stitch`, `align_bars`, and `align_bpm` are accepted for API + compatibility but currently ignored — raw model output is returned + for all modes. The post-processing pipeline (grid alignment, head- + trim, seam-smoothing inpaint) was removed after listening tests + showed it degraded every prompt class. 
+ """ if not request.json: return jsonify(APIResponse.error("No JSON data provided", status_code=400)), 400 @@ -407,26 +347,56 @@ def generate_audio(): try: prompt = Validator.string( data.get('prompt', ''), 'prompt', min_length=1, max_length=500) + + # Legacy callers send model_name; new callers send model_id. Honour either. + model_id = (data.get('model_id') or data.get('model_name') or '').strip() + if not model_id: + return jsonify(APIResponse.error( + "model_id is required (e.g. 'sa3-small-music').", + status_code=400)), 400 + if model_id not in _SA3_MODEL_IDS: + return jsonify(APIResponse.error( + f"'{model_id}' is not a SA3 model. The SA2/SAO engine was " + f"removed in 0.2.0 (see v0.1.x-legacy tag for legacy use). " + f"Pick one of: {sorted(_SA3_MODEL_IDS)}.", + status_code=400)), 400 + duration = Validator.number( - data.get('duration', 10.0), 'duration', min_value=1, max_value=60) - cfg_scale = Validator.number( - data.get('cfg_scale', 7.0), 'cfg_scale', min_value=0.1, max_value=20.0) - steps = Validator.number( - data.get('steps', 250), 'steps', min_value=1, max_value=500, integer_only=True) + data.get('duration', 10.0), 'duration', min_value=1, max_value=380) seed = Validator.number( - data.get('seed', -1), 'seed', min_value=-1, max_value=2**32 - 1, integer_only=True) - batch_index = Validator.number( - data.get('batch_index', 1), 'batch_index', min_value=1, max_value=10, integer_only=True) - batch_total = Validator.number( - data.get('batch_total', 1), 'batch_total', min_value=1, max_value=10, integer_only=True) - model_name = data.get('model_name', data.get('model', 'default')) - model_path = data.get('model_path') - unwrapped_model_path = data.get('unwrapped_model_path') - lora_path = data.get('lora_path') or None - lora_multiplier = Validator.number( - data.get('lora_multiplier', 1.0), 'lora_multiplier', - min_value=0.0, max_value=2.0, - ) + data.get('seed', -1), 'seed', + min_value=-1, max_value=2**32 - 1, integer_only=True) + # Resolve a random 
request (-1) to a concrete seed up front, so the + # actual seed used is reproducible AND recorded in the sidecar — an + # unresolved -1 would leave the fragment with no usable seed info. + seed = random.randint(0, 2**32 - 1) if (seed is None or int(seed) < 0) else int(seed) + # `/api/generate` returns exactly one WAV (single-file response), and + # the engine's _finalize writes one clip. Batching is done client-side + # — the UI loops this endpoint with distinct seeds (see App.js + # batchCount / PerformancePanel generateForChannel). So a server-side + # batch_size>1 would silently drop all but the first member; reject it + # with a clear message instead of failing quietly. + batch_size = Validator.number( + data.get('batch_size', 1), 'batch_size', + min_value=1, max_value=1, integer_only=True) + + steps_raw = data.get('steps') + # 250 matches the frontend slider max. Past ~80–100 the marginal + # quality gain from more steps is negligible on SA3 base models — + # the cap is a sanity boundary, not a recommendation. 
+ steps = Validator.number( + steps_raw, 'steps', min_value=1, max_value=250, integer_only=True + ) if steps_raw is not None else None + + cfg_raw = data.get('cfg_scale') + cfg_scale = Validator.number( + cfg_raw, 'cfg_scale', min_value=0.1, max_value=20.0 + ) if cfg_raw is not None else None + + negative_prompt_raw = data.get('negative_prompt') + negative_prompt = Validator.string( + negative_prompt_raw, 'negative_prompt', min_length=0, max_length=500 + ) if negative_prompt_raw else None align_bars_raw = data.get('align_bars') align_bpm_raw = data.get('align_bpm') @@ -438,196 +408,230 @@ def generate_audio(): ) if align_bpm_raw is not None else None do_align = align_bars is not None and align_bpm is not None - except ValidationError as e: - field = e.details.get('field', 'unknown') if e.details else 'unknown' - logger.warning(f"/api/generate validation failed on '{field}': {e}") - return jsonify(APIResponse.validation_error({field: [str(e)]})), 400 + chunked_decode = data.get('chunked_decode') # tri-state: True / False / None - logger.info(f"Audio generation request received") - logger.debug(f"Request details: prompt='{prompt[:50]}...', duration={duration}s, model={model_name}") - if DEBUG_MODE: - logger.debug(f"Model paths: model_path={model_path}, unwrapped_model_path={unwrapped_model_path}") - - def determine_model_config(model_name, model_path, unwrapped_model_path): - config_file = None - model_file_path = None - - # Priority: unwrapped_model_path > model_path > base model. 
- if unwrapped_model_path: - model_file_path = Path(unwrapped_model_path) - if not model_file_path.exists(): - raise ModelNotFoundError( - f"unwrapped_model:{model_name}", str(model_file_path)) - logger.debug(f"Using unwrapped model: {model_file_path}") - - elif model_path: - model_file_path = Path(model_path) - if not model_file_path.exists(): - raise ModelNotFoundError( - f"model_path:{model_name}", str(model_file_path)) - logger.debug(f"Using model path: {model_file_path}") - - # Small and full models use different configs; pick by file size when the name is ambiguous. - if model_file_path: - file_size_gb = model_file_path.stat().st_size / (1024**3) - config_file = "model_config_small.json" if file_size_gb < 2.0 else "model_config.json" - logger.debug( - f"Model file size: {file_size_gb:.2f} GB, using {'small' if file_size_gb < 2.0 else 'large'} config") - - elif model_name in ['stable-audio-open-small', 'stable-audio-open-1.0']: - config_file = "model_config_small.json" if 'small' in model_name else "model_config.json" - logger.debug(f"Using base model config for {model_name}") - else: - logger.warning(f"No config determined for model: {model_name}") - config_file = "model_config_small.json" - - return config_file, model_file_path - - config_file, determined_model_path = determine_model_config( - model_name, model_path, unwrapped_model_path) - logger.info(f"Starting generation with config: {config_file}") - - # In bars mode we need a little extra audio so the post-processor can - # onset-trim and tempo-warp without running short of the requested length. - # The generator caps duration to model.sample_size internally, so this - # never overshoots the model's natural length. 
- ALIGN_HEADROOM_SECONDS = 1.5 - if do_align: - duration = duration + ALIGN_HEADROOM_SECONDS - logger.debug( - f"Bars-mode alignment requested: bars={align_bars}, bpm={align_bpm}; " - f"requesting {duration:.2f}s with headroom" + # Phase 7: audio-to-audio + inpainting -------------------------- + def _resolve_src(p): + if not p: + return None + ap = Path(str(p)) + if not ap.is_absolute(): + ap = config.project_root / ap + if not ap.exists(): + raise FileNotFoundError(f"Source audio not found: {p}") + return str(ap) + + try: + init_audio_path = _resolve_src(data.get('init_audio_path')) + inpaint_audio_path = _resolve_src(data.get('inpaint_audio_path')) + except FileNotFoundError as e: + return jsonify(APIResponse.error(str(e), status_code=400)), 400 + + init_noise_level = Validator.number( + data.get('init_noise_level', 1.0), 'init_noise_level', + min_value=0.0, max_value=1.0, ) - # Resolve LoRA: read the lora_config section out of the adjacent - # training_metadata.json so the inference wrapper knows the rank/alpha/etc. 
- lora_kwargs = {} - if lora_path: - lora_p = Path(lora_path) - if not lora_p.is_absolute(): - lora_p = config.project_root / lora_p - if not lora_p.exists(): + def _normalize_seconds(raw): + if raw is None: + return None + if isinstance(raw, (int, float)): + return [float(raw)] + if isinstance(raw, list): + return [float(x) for x in raw] + raise ValueError("must be a number or list of numbers") + try: + inpaint_starts = _normalize_seconds(data.get('inpaint_starts')) + inpaint_ends = _normalize_seconds(data.get('inpaint_ends')) + except (TypeError, ValueError) as e: return jsonify(APIResponse.error( - f"LoRA file not found: {lora_p}", status_code=400)), 400 - metadata_path = None - for ancestor in [lora_p.parent, *lora_p.parents]: - candidate = ancestor / "training_metadata.json" - if candidate.exists(): - metadata_path = candidate - break - if ancestor == config.project_root: - break - if metadata_path is None: + f"inpaint_starts/inpaint_ends invalid: {e}", status_code=400)), 400 + if (inpaint_starts is None) != (inpaint_ends is None): return jsonify(APIResponse.error( - f"No training_metadata.json found near {lora_p}", status_code=400)), 400 - try: - metadata = json.loads(metadata_path.read_text()) - except Exception as exc: + "inpaint_starts and inpaint_ends must both be set or both omitted.", + status_code=400)), 400 + if inpaint_starts and inpaint_ends and len(inpaint_starts) != len(inpaint_ends): return jsonify(APIResponse.error( - f"Failed to read LoRA metadata at {metadata_path}: {exc}", + "inpaint_starts and inpaint_ends must be the same length.", status_code=400)), 400 - if metadata.get("mode") != "lora": + + # Phase 7: seamless looping. Bars/BPM are required so the seam- + # smoothing pass knows how much audio to regenerate at the join. 
+ loop_stitch = data.get('loop_stitch') + if loop_stitch is not None: + if loop_stitch not in ("inpaint", "crossfade"): + return jsonify(APIResponse.error( + f"loop_stitch must be 'inpaint', 'crossfade', or null; got {loop_stitch!r}.", + status_code=400)), 400 + if not do_align: + return jsonify(APIResponse.error( + "loop_stitch requires align_bars and align_bpm " + "(seamless loops are tempo-aware).", + status_code=400)), 400 + if loop_stitch == "inpaint" and inpaint_audio_path: + return jsonify(APIResponse.error( + "loop_stitch='inpaint' is incompatible with inpaint_audio_path " + "(the loop algorithm itself uses inpainting as a second pass).", + status_code=400)), 400 + + # LoRA stack: list of { path, strength }. Validate each entry. + loras_raw = data.get('loras') or [] + if not isinstance(loras_raw, list): return jsonify(APIResponse.error( - f"{metadata_path} is not a LoRA training metadata file", + "loras must be an array of {path, strength} entries.", status_code=400)), 400 - lora_kwargs = { - "lora_path": str(lora_p), - "lora_config": metadata.get("lora_config", {}), - "lora_multiplier": float(lora_multiplier), - } - logger.info( - f"LoRA selected: {lora_p.name} (rank={lora_kwargs['lora_config'].get('rank')}, " - f"multiplier={lora_multiplier})" - ) + loras = [] + for i, item in enumerate(loras_raw): + if not isinstance(item, dict) or 'path' not in item: + return jsonify(APIResponse.error( + f"loras[{i}] missing 'path'.", status_code=400)), 400 + lora_path = str(item['path']).strip() + if not lora_path: + continue + strength = float(item.get('strength', 1.0)) + strength = max(-2.0, min(2.0, strength)) + # Resolve relative paths against the project root. 
+ lora_abs = Path(lora_path) + if not lora_abs.is_absolute(): + lora_abs = config.project_root / lora_abs + if not lora_abs.exists(): + return jsonify(APIResponse.error( + f"LoRA not found: {lora_path}", status_code=400)), 400 + # Compatibility gate (Phase 4 contract #5): a LoRA's embedded + # base_model must share a backbone with the active model. A + # `*-base` LoRA also runs on its distilled sibling (same + # architecture, differ only in CFG state), so compare with a + # trailing `-base` stripped from both sides. Unknown/missing + # base metadata is allowed through (legacy LoRAs) rather than + # blocking generation on a metadata gap. + try: + from safetensors import safe_open + with safe_open(str(lora_abs), framework="pt") as _f: + _lora_meta = _f.metadata() or {} + except Exception: + _lora_meta = {} + _lora_base = _lora_meta.get('base_model') or _lora_meta.get('base_model_id') + if _lora_base and str(_lora_base).startswith('sa3-'): + _strip = lambda m: m[:-5] if m.endswith('-base') else m + if _strip(str(_lora_base)) != _strip(model_id): + return jsonify(APIResponse.error( + f"LoRA base mismatch: '{Path(lora_path).name}' was trained " + f"against {_lora_base}, which is incompatible with {model_id}.", + status_code=400, + details={'error_code': 'lora_base_mismatch', + 'lora_id': lora_path, + 'expected': model_id, + 'actual': _lora_base})), 400 + loras.append({'path': str(lora_abs), 'strength': strength}) + + except ValidationError as e: + field = e.details.get('field', 'unknown') if e.details else 'unknown' + logger.warning(f"/api/generate validation failed on '{field}': {e}") + return jsonify(APIResponse.validation_error({field: [str(e)]})), 400 + + logger.info( + f"Audio generation request: model={model_id} duration={duration}s " + f"prompt='{prompt[:50]}{'…' if len(prompt) > 50 else ''}'" + ) + + # Variable-length lets us ask SA3 for exactly the bars-mode target duration + # up front — no time-stretch needed in the common path. 
Headroom used to give the + # post-processor room for head-trim + drift correction without running short + # (proportional: 8% of target, clamped to [0.5s, 2.0s]). That is history: + # bars-mode alignment was removed; deliver raw model output at the + # requested duration. No headroom is needed because we no longer + # head-trim or stretch the result. + effective_duration = duration try: - if determined_model_path and determined_model_path.exists(): - output_path = generator.generate_audio( - prompt, - unwrapped_model_path=unwrapped_model_path if unwrapped_model_path else None, - model_path=determined_model_path if not unwrapped_model_path else None, - config_file=config_file, - duration=duration, - cfg_scale=cfg_scale, - steps=steps, - seed=seed, - batch_index=batch_index, - batch_total=batch_total, - loop_mode=do_align, - **lora_kwargs, - ) - elif model_name in ['stable-audio-open-small', 'stable-audio-open-1.0']: - model_file_mapping = { - 'stable-audio-open-small': 'stable-audio-open-small-model.safetensors', - 'stable-audio-open-1.0': 'stable-audio-open-model.safetensors' - } - model_file_name = model_file_mapping.get( - model_name, f"{model_name}-model.safetensors") - model_file_path = config.project_root / \ - "models" / "pretrained" / model_file_name - - if not model_file_path.exists(): - raise ModelNotFoundError(model_name, str(model_file_path)) - - output_path = generator.generate_audio( - prompt, - model_path=model_file_path, - config_file=config_file, - duration=duration, - cfg_scale=cfg_scale, - steps=steps, - seed=seed, - batch_index=batch_index, - batch_total=batch_total, - loop_mode=do_align, - **lora_kwargs, - ) - elif model_name and model_name != 'default': - fine_tuned_path = config.get_path("models_fine_tuned") / model_name - if not fine_tuned_path.exists(): - raise ModelNotFoundError(model_name, str(fine_tuned_path)) - - output_path = generator.generate_audio( - prompt, fine_tuned_path, duration=duration, - cfg_scale=cfg_scale, steps=steps, 
seed=seed, - batch_index=batch_index, batch_total=batch_total, - loop_mode=do_align, **lora_kwargs) - else: - logger.debug("Using default model") - output_path = generator.generate_audio( - prompt, duration=duration, cfg_scale=cfg_scale, steps=steps, - seed=seed, batch_index=batch_index, batch_total=batch_total, - loop_mode=do_align, **lora_kwargs) + output_path = generator.generate_audio( + prompt, + model_id=model_id, + duration=float(effective_duration), + steps=int(steps) if steps is not None else None, + cfg_scale=float(cfg_scale) if cfg_scale is not None else None, + seed=int(seed), + negative_prompt=negative_prompt, + batch_size=int(batch_size), + chunked_decode=chunked_decode, + loop_mode=do_align, + loras=loras, + init_audio_path=init_audio_path, + init_noise_level=float(init_noise_level), + inpaint_audio_path=inpaint_audio_path, + inpaint_starts=inpaint_starts, + inpaint_ends=inpaint_ends, + loop_stitch=loop_stitch, + loop_bars=int(align_bars) if (loop_stitch and align_bars) else None, + loop_bpm=float(align_bpm) if (loop_stitch and align_bpm) else None, + ) if not output_path.exists(): - raise GenerationError(prompt, model_name, "Generated audio file not found") + raise GenerationError(prompt, model_id, "Generated audio file not found") - if do_align: - try: - from app.core.generation.audio_post_process import align_to_grid - align_to_grid( - output_path, - target_bpm=float(align_bpm), - target_bars=int(align_bars), - ) - logger.info( - f"Aligned to grid: bars={align_bars}, bpm={align_bpm}" - ) - except Exception as exc: - # Never fail the request because alignment failed — the user - # would rather have the raw clip than an error toast. - logger.warning(f"Grid alignment skipped after error: {exc}") + # Grid alignment and seamless-loop stitching were removed after + # user A/B testing showed every post-processing variant degraded + # the model output. We now serve raw SA3 audio for all modes. 
+ + logger.info( + f"Audio generation completed: {output_path.name} " + f"({output_path.stat().st_size} bytes)" + ) - logger.info(f"Audio generation completed: {output_path.name} ({output_path.stat().st_size} bytes)") - return send_file( + # Sidecar metadata — lets the frontend restore the "Generated + # Fragments" panel across page reloads. Failure to write is non- + # fatal (the WAV is the only mandatory artifact). + sidecar_path = output_path.with_suffix(output_path.suffix + ".json") + try: + edit_mode = None + if init_audio_path: + edit_mode = 'style' + elif inpaint_audio_path: + edit_mode = 'inpaint/extend' + sidecar = { + "filename": output_path.name, + "created_at": time.time(), + "prompt": prompt, + "model_id": model_id, + "duration": float(duration), + "seed": int(seed), + "negative_prompt": negative_prompt, + "cfg_scale": float(cfg_scale) if cfg_scale is not None else None, + "steps": int(steps) if steps is not None else None, + "batch_size": int(batch_size), + "align_bars": int(align_bars) if align_bars else None, + "align_bpm": float(align_bpm) if align_bpm else None, + "loop_stitch": loop_stitch, + "loras": loras or [], + "init_audio_path": init_audio_path, + "init_noise_level": float(init_noise_level) if init_audio_path else None, + "inpaint_audio_path": inpaint_audio_path, + "inpaint_starts": list(inpaint_starts) if inpaint_starts else None, + "inpaint_ends": list(inpaint_ends) if inpaint_ends else None, + "edit_mode": edit_mode, + } + with open(sidecar_path, "w") as f: + json.dump(sidecar, f, indent=2) + except Exception as exc: + logger.warning(f"Failed to write fragment sidecar at {sidecar_path}: {exc}") + + resp = send_file( str(output_path), mimetype='audio/wav', as_attachment=True, - download_name='generated_audio.wav' + download_name=output_path.name, ) + # Backend is the single source of truth for the on-disk name. 
The + # frontend used to invent its own filename, which never matched what + # _finalize wrote — so reveal-in-folder and delete both 404'd on + # freshly generated fragments. Hand the real name (and resolved seed) + # back so the UI's fragment.filename always points at a real file. + resp.headers['X-Fragment-Filename'] = output_path.name + resp.headers['X-Fragment-Seed'] = str(int(seed)) + return resp except (ModelNotFoundError, GenerationError, ValidationError) as e: - logger.error(f"Generation error: {str(e)}") + logger.error(f"Generation error: {e}") return jsonify(APIResponse.error(str(e), status_code=400)), 400 except Exception as e: from app.core.generation.audio_generator import GenerationStopped @@ -635,162 +639,524 @@ def generate_audio(): logger.info("Generation stopped by user request") return jsonify({'stopped': True, 'message': 'Generation stopped'}), 499 logger.exception("Unexpected error during audio generation") - return jsonify(APIResponse.error(f"Unexpected error: {str(e)}", status_code=500)), 500 + return jsonify(APIResponse.error(f"Unexpected error: {e}", status_code=500)), 500 @app.route('/api/loras', methods=['GET']) def list_loras(): - """Enumerate all trained LoRAs under models/fine_tuned/. - - For each model directory that has a `training_metadata.json` with - `mode == "lora"`, returns the metadata plus the latest checkpoint file - (.ckpt) discovered under the dir. LoRAW writes the LoRA state_dict via - pl.callbacks.ModelCheckpoint, so the exact filename depends on - Lightning's pattern and the (disabled) wandb logger's run id — we just - glob for *.ckpt and return them sorted by mtime. + """Enumerate SA3 LoRAs under models/fine_tuned//checkpoints/. + + SA3 LoRAs are .safetensors files with config (rank, adapter_type, + base_model, etc.) embedded in the safetensors metadata header. + `train_lora.py` from vendor/stable-audio-3/scripts/ writes one per + checkpoint step. 
We surface every step so the user can A/B-test + checkpoints inside a single run. + + Lightning .ckpt files from prior runs get lazily converted to + .safetensors on demand so the picker can see them. + + Query: + base_model (optional) — filter to LoRAs compatible with this base + (e.g. ?base_model=sa3-small-music or ?base_model=sa3-small-music-base). + A small-music LoRA is compatible with both small-music and + small-music-base (same backbone, different CFG distillation + state); the matcher strips a trailing `-base` from both sides + before comparing. """ + requested_base = (request.args.get('base_model') or '').strip() + + def _base_root(model_id: str) -> str: + """Strip the `-base` suffix so LoRAs trained against `*-base` filter + as compatible with their post-trained sibling (same architecture).""" + if not model_id: + return '' + return model_id[:-5] if model_id.endswith('-base') else model_id + + requested_root = _base_root(requested_base) + try: config = get_config() fine_tuned_dir = config.get_path("models_fine_tuned") - loras = [] + # Grouping: one entry per LoRA *run* (modelName), with `all_checkpoints` + # listing every snapshot oldest→latest. `path` defaults to the latest + # checkpoint so picking the LoRA without changing the sub-picker uses + # the final state of training. 
+ loras_by_name: Dict[str, Dict[str, Any]] = {} if fine_tuned_dir.exists(): - for model_dir in sorted(fine_tuned_dir.iterdir()): - if not model_dir.is_dir(): - continue - metadata_path = model_dir / "training_metadata.json" - if not metadata_path.exists(): - continue - try: - metadata = json.loads(metadata_path.read_text()) - except Exception: - continue - if metadata.get("mode") != "lora": + try: + from safetensors import safe_open + except ImportError: + return jsonify({"loras": []}) + + # Lazy migration: any run dir that still has Lightning .ckpt + # files (from a training run that completed before the + # auto-convert step landed) gets a one-time conversion here so + # the picker can see it. Cheap: no-op if .safetensors already + # exists for each .ckpt. + from app.core.training.sa3_lora_runner import convert_run_checkpoints_to_safetensors + + for run_dir in sorted(fine_tuned_dir.iterdir()): + if not run_dir.is_dir(): continue - ckpt_files = sorted( - model_dir.rglob("*.ckpt"), - key=lambda p: p.stat().st_mtime, - ) + ckpt_dir = run_dir / "checkpoints" + if ckpt_dir.is_dir() and any(ckpt_dir.glob("*.ckpt")): + # Read the run's base_model from its training_metadata.json + # so the converted .safetensors carries the right tag. + meta_path = run_dir / "training_metadata.json" + base_model = None + model_name = run_dir.name + if meta_path.exists(): + try: + rm = json.loads(meta_path.read_text()) + base_model = rm.get("base_model") + model_name = rm.get("model_name") or model_name + except Exception: + pass + if base_model: + try: + convert_run_checkpoints_to_safetensors( + run_dir, base_model=base_model, model_name=model_name, + ) + except Exception as conv_err: + logger.warning("Could not auto-convert %s: %s", run_dir.name, conv_err) + + # SA3 checkpoints live under /checkpoints/. Fall back + # to the run dir itself for runs that don't follow that + # convention. 
+ search_dirs = [run_dir / "checkpoints", run_dir] + ckpt_files = [] + for d in search_dirs: + if d.is_dir(): + ckpt_files = sorted( + d.glob("*.safetensors"), + key=lambda p: p.stat().st_mtime, + ) + if ckpt_files: + break if not ckpt_files: continue - latest = ckpt_files[-1] - lora_cfg = metadata.get("lora_config", {}) - loras.append({ - "name": model_dir.name, - "path": str(latest), - "base_model": metadata.get("base_model"), - "rank": lora_cfg.get("rank"), - "alpha": lora_cfg.get("alpha"), - "all_checkpoints": [str(p) for p in ckpt_files], - }) + + for ckpt in ckpt_files: + try: + with safe_open(str(ckpt), framework="pt") as f: + meta = f.metadata() or {} + except Exception: + meta = {} + # SA3 canonically nests rank/alpha/adapter_type inside a + # `lora_config` JSON metadata key (see save_lora_safetensors + # in vendor/.../models/lora/utils.py). Parse it so the + # picker can surface those values; fall back to top-level + # for forward-compat with any future shape. + lora_config = {} + if meta.get("lora_config"): + try: + lora_config = json.loads(meta["lora_config"]) + except Exception: + lora_config = {} + + # `train_lora.py` doesn't embed base_model itself — we add + # it during the .ckpt→.safetensors conversion step. + # Fall back to training_metadata.json for legacy runs. + base_model = meta.get("base_model") or meta.get("base_model_id") + if not base_model: + run_meta_path = run_dir / "training_metadata.json" + if run_meta_path.exists(): + try: + rm = json.loads(run_meta_path.read_text()) + base_model = rm.get("base_model") + except Exception: + pass + if not base_model or not str(base_model).startswith("sa3-"): + continue # not a SA3 LoRA + + # Filter by requested base if the caller specified one. + # Treat `sa3-small-music` and `sa3-small-music-base` as + # compatible: same backbone, only differ in CFG state. 
+ if requested_root and _base_root(base_model) != requested_root: + continue + + rel_path = str(ckpt.relative_to(config.project_root)) + rank = _safe_int(lora_config.get("rank") or meta.get("rank")) + alpha = _safe_int( + lora_config.get("alpha") + or meta.get("lora_alpha") + or meta.get("alpha") + ) + adapter_type = ( + lora_config.get("adapter_type") + or meta.get("adapter_type") + or "lora" + ) + + entry = loras_by_name.get(run_dir.name) + if entry is None: + entry = { + "id": run_dir.name, + "name": run_dir.name, + "base_model": base_model, + "rank": rank, + "alpha": alpha, + "adapter_type": adapter_type, + "all_checkpoints": [], + } + loras_by_name[run_dir.name] = entry + + entry["all_checkpoints"].append({ + "path": rel_path, + "checkpoint": ckpt.stem, + "size_bytes": ckpt.stat().st_size, + "mtime": ckpt.stat().st_mtime, + }) + + # Finalize each LoRA entry — sort checkpoints by training step + # extracted from the filename (Lightning writes "epoch=X-step=Y.ckpt" + # which converts to "epoch=X-step=Y.safetensors"). mtime is unreliable + # because the lazy .ckpt→.safetensors converter can rewrite files in + # alphabetical (not training-step) order. 
+ import re as _re + _step_pat = _re.compile(r"step=(\d+)") + + def _step_of(checkpoint_stem: str) -> int: + m = _step_pat.search(checkpoint_stem) + return int(m.group(1)) if m else -1 + + loras = [] + for entry in loras_by_name.values(): + ckpts = sorted(entry["all_checkpoints"], key=lambda c: _step_of(c["checkpoint"])) + entry["all_checkpoints"] = [c["path"] for c in ckpts] + latest = ckpts[-1] + entry["path"] = latest["path"] + entry["checkpoint"] = latest["checkpoint"] + entry["size_bytes"] = latest["size_bytes"] + entry["mtime"] = latest["mtime"] + loras.append(entry) + loras.sort(key=lambda e: e["mtime"], reverse=True) + return jsonify({"loras": loras}) except Exception as e: logger.exception("Failed to enumerate LoRAs") return jsonify(APIResponse.error(f"Failed to list LoRAs: {e}", status_code=500)), 500 -@app.route('/api/status', methods=['GET']) -def get_status(): - _log_api_call('status') +def _safe_int(v): try: - config = get_config() - data_dir = config.get_path("data") - metadata_json = Path(config.get_metadata_json_path()) - custom_metadata = Path(config.get_custom_metadata_path()) + return int(v) if v is not None else None + except (TypeError, ValueError): + return None + + +@app.route('/api/audio/upload', methods=['POST']) +def upload_source_audio(): + """Accept a user-uploaded audio file to use as init_audio / inpaint_audio. + + Stores under /uploads/_ so the file is + inside the project tree. Returns {path, name} where `path` is relative + to project_root so /api/generate can resolve it. + """ + if 'file' not in request.files: + return jsonify(APIResponse.error("No file provided.", status_code=400)), 400 + fileobj = request.files['file'] + if not fileobj.filename: + return jsonify(APIResponse.error("Empty filename.", status_code=400)), 400 + + name = Path(fileobj.filename).name + # Strip path components + restrict to known extensions. 
+ ext = Path(name).suffix.lower() + if ext not in {".wav", ".mp3", ".flac", ".m4a", ".ogg", ".opus"}: + return jsonify(APIResponse.error( + f"Unsupported audio format '{ext}'. Use wav/mp3/flac/m4a/ogg/opus.", + status_code=400)), 400 + safe = re.sub(r"[^a-zA-Z0-9._-]", "_", Path(name).stem)[:60] or "upload" + + cfg = get_config() + uploads_dir = cfg.get_path("output") / "uploads" + uploads_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d_%H%M%S") + dest = uploads_dir / f"{ts}_{safe}{ext}" + fileobj.save(str(dest)) + + rel = dest.relative_to(cfg.project_root) + return jsonify({"path": str(rel), "name": dest.name, "size_bytes": dest.stat().st_size}) + + +@app.route('/api/performance/recording', methods=['POST']) +def save_performance_recording(): + """Persist a master-bus performance capture (WAV) to the output folder. + + Accepts multipart: `file` (audio/wav) + `name` (user-supplied label) and + an optional `duration` (seconds). Sanitizes the name, ensures a unique + `.wav` filename, and writes a sidecar so the capture lists cleanly in the + Fragments window alongside generated audio. + """ + if 'file' not in request.files: + return jsonify(APIResponse.error("No recording file provided.", status_code=400)), 400 + fileobj = request.files['file'] + if not fileobj.filename: + return jsonify(APIResponse.error("Empty recording.", status_code=400)), 400 + + raw_name = (request.form.get('name') or '').strip() + safe = re.sub(r"[^a-zA-Z0-9._ -]", "_", raw_name).strip().replace(" ", "_")[:80] + if not safe: + safe = time.strftime("performance_%Y%m%d_%H%M%S") + + cfg = get_config() + output_dir = cfg.get_path("output") + output_dir.mkdir(parents=True, exist_ok=True) + + dest = output_dir / f"{safe}.wav" + # Don't clobber an existing capture — bump a numeric suffix until free. 
+ counter = 2 + while dest.exists(): + dest = output_dir / f"{safe}_{counter}.wav" + counter += 1 + fileobj.save(str(dest)) + + try: + duration = float(request.form["duration"]) if request.form.get("duration") else None + except (TypeError, ValueError): + duration = None + + # Sidecar so /api/fragments shows a friendly label instead of parsing the + # bare filename. Mirrors the generation sidecar schema. + sidecar = { + "filename": dest.name, + "created_at": time.time(), + "prompt": raw_name or dest.stem, + "model_id": "", + "duration": duration, + "seed": None, + "cfg_scale": None, + "steps": None, + "source": "performance", + } + try: + (output_dir / f"{dest.name}.json").write_text(json.dumps(sidecar, indent=2)) + except Exception as exc: + logger.warning(f"Failed to write recording sidecar for {dest.name}: {exc}") + + return jsonify({ + "filename": dest.name, + "size_bytes": dest.stat().st_size, + "duration": duration, + }) + + +@app.route('/api/fragments', methods=['GET']) +def list_fragments(): + """List previously-generated audio fragments (latest first). + + Returns the union of: + • Generations with a sidecar JSON (full metadata: prompt, seed, etc.) + • Orphan WAVs in output/ (no sidecar — happens for clips made by + older versions of Fragmenta; metadata is recovered from the + filename + mtime). 
+ + Query: ?limit= (default 100, capped at 500) + """ + cfg = get_config() + output_dir = cfg.get_path("output") + try: + limit = max(1, min(500, int(request.args.get('limit', 100)))) + except (TypeError, ValueError): + limit = 100 + + if not output_dir.exists(): + return jsonify({"fragments": []}) + fragments = [] + seen_wavs = set() + + # Sidecared generations + for sidecar_path in output_dir.glob("*.wav.json"): try: - import torchaudio - except ImportError: - torchaudio = None + with open(sidecar_path) as f: + meta = json.load(f) + wav_name = meta.get("filename") or sidecar_path.name[:-len(".json")] + wav_path = output_dir / wav_name + if not wav_path.exists(): + continue # sidecar without its WAV — skip silently + seen_wavs.add(wav_name) + meta["filename"] = wav_name + meta["size_bytes"] = wav_path.stat().st_size + fragments.append(meta) + except Exception as exc: + logger.warning(f"Failed to read fragment sidecar {sidecar_path}: {exc}") + + # Orphan WAVs (no sidecar) — recover what we can. + # Filename format from _finalize: __.wav + for wav_path in output_dir.glob("*.wav"): + if wav_path.name in seen_wavs: + continue try: - import soundfile as sf - except ImportError: - sf = None + stem = wav_path.stem + parts = stem.split("_") + # Try to parse "YYYYMMDD_HHMMSS" as the first two tokens. + created_at = wav_path.stat().st_mtime + model_id = "" + prompt = stem + if len(parts) >= 2: + try: + created_at = time.mktime( + time.strptime(f"{parts[0]}_{parts[1]}", "%Y%m%d_%H%M%S") + ) + rest = "_".join(parts[2:]) + # rest = "_". model_id is + # contiguous hyphenated; we use the longest known prefix + # by matching against SA3 model ids in _MODEL_INFO. 
+ rest_low = rest.lower() + for mid in sorted(_SA3_MODEL_IDS, key=len, reverse=True): + if rest_low.startswith(mid): + model_id = mid + prompt = rest[len(mid):].lstrip("_").replace("_", " ") + break + else: + prompt = rest.replace("_", " ") + except ValueError: + pass + fragments.append({ + "filename": wav_path.name, + "created_at": created_at, + "prompt": prompt or "(unknown)", + "model_id": model_id, + "duration": None, + "seed": None, + "cfg_scale": None, + "steps": None, + "size_bytes": wav_path.stat().st_size, + "_orphan": True, + }) + except Exception as exc: + logger.warning(f"Failed to list orphan fragment {wav_path}: {exc}") + + fragments.sort(key=lambda x: x.get("created_at") or 0, reverse=True) + return jsonify({"fragments": fragments[:limit]}) + + +@app.route('/api/fragments/', methods=['GET']) +def serve_fragment(filename): + """Serve a WAV from output/ by name. Path traversal is rejected.""" + if "/" in filename or "\\" in filename or ".." in filename: + return jsonify(APIResponse.error("Invalid filename.", status_code=400)), 400 + if not filename.endswith(".wav"): + return jsonify(APIResponse.error("Only .wav files are served.", status_code=400)), 400 + cfg = get_config() + full = cfg.get_path("output") / filename + if not full.exists() or not full.is_file(): + return jsonify(APIResponse.error("File not found.", status_code=404)), 404 + return send_file(str(full), mimetype="audio/wav") + + +@app.route('/api/fragments/', methods=['DELETE']) +def delete_fragment(filename): + """Delete a single fragment (WAV + its sidecar JSON) from output/.""" + if "/" in filename or "\\" in filename or ".." 
in filename: + return jsonify(APIResponse.error("Invalid filename.", status_code=400)), 400 + if not filename.endswith(".wav"): + return jsonify(APIResponse.error("Only .wav files are deletable.", status_code=400)), 400 + cfg = get_config() + output_dir = cfg.get_path("output") + wav_path = output_dir / filename + sidecar_path = output_dir / f"{filename}.json" + if not wav_path.exists(): + return jsonify(APIResponse.error("File not found.", status_code=404)), 404 + removed = [] + try: + wav_path.unlink() + removed.append(wav_path.name) + if sidecar_path.exists(): + sidecar_path.unlink() + removed.append(sidecar_path.name) + except Exception as exc: + logger.error(f"Failed to delete fragment {filename}: {exc}") + return jsonify(APIResponse.error(f"Delete failed: {exc}", status_code=500)), 500 + logger.info(f"Deleted fragment: {', '.join(removed)}") + return jsonify({"deleted": removed}) - def _duration(path: Path) -> float: - import warnings - try: - with warnings.catch_warnings(): - # /api/status polls every few seconds; without this, every - # torchaudio.info call spams pages of deprecation noise. - warnings.simplefilter("ignore") - if torchaudio is not None: - info = torchaudio.info(str(path)) - return info.num_frames / info.sample_rate - if sf is not None: - f = sf.SoundFile(str(path)) - return len(f) / f.samplerate - except Exception as exc: - print(f"Error reading {path}: {exc}") - return 0.0 - - # Prefer metadata.json as the source of truth for "the dataset" — it spans - # files that may live outside data/ (e.g. CSV import with copy_files=false). - # Fall back to scanning data/ if metadata is absent. - file_names = [] - total_duration = 0.0 - if metadata_json.exists(): + +@app.route('/api/fragments', methods=['DELETE']) +def clear_fragments(): + """Delete EVERY .wav + .wav.json directly under output/. 
+ + Does NOT recurse — uploaded source clips live under output/uploads/ + and are intentionally left alone (they may still be referenced by + in-flight Edit-mode work, and the user uploaded them deliberately). + """ + cfg = get_config() + output_dir = cfg.get_path("output") + if not output_dir.exists(): + return jsonify({"deleted": 0}) + removed = 0 + errors = [] + for pattern in ("*.wav", "*.wav.json"): + for p in output_dir.glob(pattern): + if not p.is_file(): + continue try: - with open(metadata_json, 'r', encoding='utf-8') as f: - entries = json.load(f) or [] + p.unlink() + removed += 1 except Exception as exc: - print(f"Could not read metadata.json: {exc}") - entries = [] - for item in entries: - if not isinstance(item, dict): - continue - name = item.get('file_name') - if not name: - continue - file_names.append(name) - # Files are staged into data/ (copy or symlink) at commit time, so - # resolve by basename rather than trusting the stored `path` field - # (legacy entries may use an incorrect prefix). 
- p = data_dir / name - if not p.exists(): - stored = item.get('path') or '' - candidate = Path(stored) - if not candidate.is_absolute(): - candidate = config.project_root / stored - if candidate.is_file(): - p = candidate - if p.is_file(): - total_duration += _duration(p) - else: - audio_files = list(data_dir.glob("*.wav")) + \ - list(data_dir.glob("*.mp3")) + list(data_dir.glob("*.flac")) - for audio_file in audio_files: - total_duration += _duration(audio_file) - file_names = [f.name for f in audio_files] - - status_response = { - 'status': 'running', - 'raw_files': len(file_names), - 'processed_segments': len(file_names), - 'raw_file_names': file_names[:10], - 'total_duration': total_duration, - 'has_metadata_json': metadata_json.exists(), - 'has_custom_metadata': custom_metadata.exists(), - 'trained_models': len(list(config.get_path("models_fine_tuned").glob("*"))) if config.get_path("models_fine_tuned").exists() else 0, - 'training': get_training_status() - } + errors.append(f"{p.name}: {exc}") + if errors: + logger.warning(f"clear_fragments: removed {removed}, errors: {errors}") + else: + logger.info(f"clear_fragments: removed {removed} file(s)") + return jsonify({"deleted": removed, "errors": errors}) + + +@app.route('/api/lora-strength', methods=['POST']) +def update_lora_strength(): + """Live-update a loaded LoRA's strength without regenerating. + + Performance Mode uses this when the user drags a strength slider — + the next generate() picks up the new value, but the model itself + doesn't need to be reloaded. Returns 409 if no LoRAs are loaded yet + or the index is out of range. 
+ """ + data = request.json or {} + try: + index = int(data.get('index', -1)) + strength = float(data.get('strength', 1.0)) + except (TypeError, ValueError): + return jsonify(APIResponse.error("index and strength are required.", status_code=400)), 400 - return jsonify(status_response) + try: + if not getattr(generator, 'model', None): + return jsonify(APIResponse.error("No model loaded.", status_code=409)), 409 + ok = generator.set_lora_strength(index, strength) + if not ok: + return jsonify(APIResponse.error( + f"LoRA index {index} not loaded.", status_code=409)), 409 + return jsonify({'success': True, 'index': index, 'strength': strength}) except Exception as e: - return jsonify({'error': str(e)}), 500 + logger.exception("set_lora_strength failed") + return jsonify(APIResponse.error(str(e), status_code=500)), 500 @app.route('/api/training/suggest-hyperparams', methods=['GET']) def training_suggest_hyperparams(): """Heuristic hyperparameter suggester for the Training tab's Suggest button. - Query: mode=lora|full (default lora). - Returns: {ok, stats, config, rationale} — see hyperparam_suggester.suggest. + Query: + project_name (required) — Dataset Workbench project to analyse + base_model (optional) — picked SA3 base, e.g. sa3-medium-base. Used to + pick a -XS adapter when VRAM is tight and to + emit base-model-aware warnings. + Returns: {ok, stats, config, rationale, warnings} — see hyperparam_suggester.suggest. 
""" try: + from app.backend.data.projects import project_path from app.core.training.hyperparam_suggester import suggest - mode = request.args.get('mode', 'lora') - config = get_config() - result = suggest(config.get_path('data'), mode=mode) + project_name = request.args.get('project_name', '').strip() + base_model = request.args.get('base_model', '').strip() or None + if not project_name: + return jsonify({'ok': False, 'error': "project_name is required."}), 400 + proj_dir = project_path(project_name) + if not proj_dir.exists(): + return jsonify({ + 'ok': False, + 'error': f"Project not found: {project_name}", + }), 404 + result = suggest(proj_dir, base_model=base_model) return jsonify(result) except Exception as exc: logger.exception("hyperparam suggestion failed") @@ -889,20 +1255,6 @@ def get_models(): latest_config = max( config_files, key=lambda x: x.stat().st_mtime) if config_files else None - unwrapped_dir = model_dir / "unwrapped" - unwrapped_models = [] - if unwrapped_dir.exists(): - for unwrapped_file in unwrapped_dir.glob("*.safetensors"): - unwrapped_models.append({ - 'name': unwrapped_file.stem, - 'path': str(unwrapped_file.relative_to(config.project_root)), - 'size_mb': round(unwrapped_file.stat().st_size / (1024 * 1024), 1), - 'created': unwrapped_file.stat().st_mtime - }) - - unwrapped_models.sort( - key=lambda x: x['created'], reverse=True) - # Resolve the architecture config + base-model identity for # this fine-tuned model. 
Order: per-run copy in the model # folder, then training_metadata breadcrumb, then legacy @@ -918,7 +1270,6 @@ def get_models(): 'config_path': resolved['config_path'], 'base_model': resolved['base_model'], 'checkpoints': checkpoints, - 'unwrapped_models': unwrapped_models, 'created': model_dir.stat().st_mtime if model_dir.exists() else None }) @@ -927,128 +1278,125 @@ def get_models(): return jsonify({'error': str(e)}), 500 -@app.route('/api/models/available', methods=['GET']) -def get_available_models(): +# ============================================================================ +# Checkpoint Manager — SA3 catalog endpoints (Phase 2a of SA3_INTEGRATION_PLAN) +# ============================================================================ + +@app.route('/api/checkpoints', methods=['GET']) +def list_checkpoints(): try: - models = model_manager.get_available_models() - return jsonify({'models': models}) + # ?include=all → also returns base + standalone AE entries (training + # subprocess uses this; the manager UI relies on the default). 
+ include_hidden = request.args.get('include') == 'all' + return jsonify({ + 'checkpoints': model_manager.get_catalog(include_hidden=include_hidden), + }) except Exception as e: return jsonify({'error': str(e)}), 500 -@app.route('/api/models//info', methods=['GET']) -def get_model_info(model_id): +@app.route('/api/checkpoints/storage', methods=['GET']) +def checkpoints_storage(): try: - model_info = model_manager.get_model_info(model_id) - if not model_info: - return jsonify({'error': 'Model not found'}), 404 - return jsonify(model_info) + return jsonify(model_manager.get_storage_info()) except Exception as e: return jsonify({'error': str(e)}), 500 -@app.route('/api/models//accept-terms', methods=['POST']) -def accept_model_terms(model_id): +@app.route('/api/checkpoints/', methods=['GET']) +def get_checkpoint(model_id): try: - success = model_manager.accept_terms(model_id) - if success: - return jsonify({'success': True, 'message': f'Terms accepted for {model_id}'}) - else: - return jsonify({'error': 'Failed to accept terms'}), 400 + info = model_manager.get_model_info(model_id) + if not info: + return jsonify({'error': 'Unknown checkpoint'}), 404 + return jsonify(info) except Exception as e: return jsonify({'error': str(e)}), 500 -@app.route('/api/models//download', methods=['POST']) -def download_model(model_id): +@app.route('/api/checkpoints//download', methods=['POST']) +def start_checkpoint_download(model_id): try: - if not model_manager.is_terms_accepted(model_id): - return jsonify({'error': 'Terms not accepted for this model'}), 400 + result = model_manager.start_download(model_id) + # _DownloadJob.to_dict() always includes the "error" key (None when + # ok); use a truthy check so a successful job doesn't 400. 
+ if result.get('error'): + return jsonify(result), 400 + return jsonify(result) + except Exception as e: + return jsonify({'error': str(e)}), 500 - success = model_manager.download_model(model_id) - if success: - return jsonify({ - 'success': True, - 'message': f'Model {model_id} downloaded successfully' - }) - else: - return jsonify({'error': f'Failed to download {model_id}'}), 500 + +@app.route('/api/checkpoints//cancel-download', methods=['POST']) +def cancel_checkpoint_download(model_id): + try: + # Cancel every in-flight job for this checkpoint (usually one). + jobs = [j for j in model_manager.list_jobs() + if j['model_id'] == model_id and j['status'] in ('queued', 'running')] + cancelled = [j['job_id'] for j in jobs if model_manager.cancel_job(j['job_id'])] + return jsonify({'cancelled': cancelled}) except Exception as e: return jsonify({'error': str(e)}), 500 -@app.route('/api/hf-login', methods=['POST']) -def hf_login(): +@app.route('/api/checkpoints/', methods=['DELETE']) +def delete_checkpoint_download(model_id): try: - data = request.json - token = data.get('token') - if not token: - return jsonify({'error': 'Token is required'}), 400 - - import huggingface_hub - try: - huggingface_hub.login(token=token, add_to_git_credential=False) - user_info = huggingface_hub.whoami(token=token) - return jsonify({'success': True, 'user': user_info.get('name', 'User')}) - except Exception as e: - return jsonify({'error': f'Invalid token or connection error: {str(e)}'}), 401 + if model_manager.delete_model(model_id): + return jsonify({'success': True}) + return jsonify({'error': 'Nothing to delete'}), 404 except Exception as e: return jsonify({'error': str(e)}), 500 -@app.route('/api/base-models/status', methods=['GET']) -def get_base_models_status(): +@app.route('/api/checkpoints/jobs/', methods=['GET']) +def get_checkpoint_job(job_id): try: - import os - from pathlib import Path - - base_models = { - 'stable-audio-open-1.0': { - 'name': 'Stable Audio Open 1.0', - 
'path': 'models/pretrained', - 'file': 'stable-audio-open-model.safetensors', - 'downloaded': False - }, - 'stable-audio-open-small': { - 'name': 'Stable Audio Open Small', - 'path': 'models/pretrained', - 'file': 'stable-audio-open-small-model.safetensors', - 'downloaded': False - } - } + job = model_manager.get_job(job_id) + if not job: + return jsonify({'error': 'Unknown job'}), 404 + return jsonify(job) + except Exception as e: + return jsonify({'error': str(e)}), 500 - for model_id, info in base_models.items(): - model_dir = Path(info['path']) - model_file = model_dir / info['file'] - if model_file.exists() and model_file.is_file(): - info['downloaded'] = True - else: - # Legacy layout: model stored in a subdirectory. - old_path = model_dir / model_id - if old_path.exists() and old_path.is_dir(): - has_files = any([ - (old_path / 'model.safetensors').exists(), - (old_path / 'pytorch_model.bin').exists(), - (old_path / 'model.ckpt').exists(), - len(list(old_path.glob('*.safetensors'))) > 0, - len(list(old_path.glob('*.bin'))) > 0 - ]) - info['downloaded'] = has_files - - return jsonify({'base_models': base_models}) +# --- HuggingFace auth ------------------------------------------------------- + +@app.route('/api/hf-auth/status', methods=['GET']) +def hf_auth_status(): + try: + return jsonify(model_manager.hf_auth_status()) except Exception as e: return jsonify({'error': str(e)}), 500 -@app.route('/api/models//delete', methods=['DELETE']) -def delete_model(model_id): +@app.route('/api/hf-auth', methods=['POST']) +def hf_auth_login(): try: - success = model_manager.delete_model(model_id) - if success: - return jsonify({'success': True, 'message': f'Model {model_id} deleted'}) - else: - return jsonify({'error': f'Failed to delete {model_id}'}), 400 + data = request.json or {} + token = (data.get('token') or '').strip() + if not token: + return jsonify({'error': 'Token is required'}), 400 + import huggingface_hub + try: + huggingface_hub.login(token=token, 
add_to_git_credential=False) + info = huggingface_hub.whoami(token=token) + return jsonify({ + 'success': True, + 'username': info.get('name') or info.get('fullname'), + }) + except Exception as e: + return jsonify({'error': f'Invalid token: {e}'}), 401 + except Exception as e: + return jsonify({'error': str(e)}), 500 + + +@app.route('/api/hf-auth', methods=['DELETE']) +def hf_auth_logout(): + try: + import huggingface_hub + huggingface_hub.logout() + return jsonify({'success': True}) except Exception as e: return jsonify({'error': str(e)}), 500 @@ -1174,208 +1522,19 @@ def delete_fine_tuned_model(model_name): return jsonify({'error': str(e)}), 500 -@app.route('/api/models/storage', methods=['GET']) -def get_model_storage(): - try: - storage_info = model_manager.get_storage_info() - return jsonify(storage_info) - except Exception as e: - return jsonify({'error': str(e)}), 500 - +@app.route('/api/generation-progress', methods=['GET']) +def get_generation_progress_route(): + """Live progress for the in-flight `/api/generate` call. -@app.route('/api/start-fresh', methods=['POST']) -def start_fresh(): - try: - config = get_config() - data_dir = config.get_path("data") - config_dir = config.get_path("models_config") - - data_files_deleted = 0 - if data_dir.exists(): - # iterdir() (vs glob("*")) catches dotfiles too — e.g. the - # hyperparam suggester's .duration_cache.json, which should - # absolutely be wiped on Fresh Start. 
- for file_path in data_dir.iterdir(): - if file_path.is_file() and not file_path.name.endswith('.py'): - file_path.unlink() - data_files_deleted += 1 - - config_files_deleted = 0 - if config_dir.exists(): - for file_path in config_dir.glob("custom_metadata.py"): - if file_path.is_file(): - file_path.unlink() - config_files_deleted += 1 - - labels_reset = False - user_labels_path = _annotator_labels_user_path() - if user_labels_path.exists(): - user_labels_path.unlink() - labels_reset = True - - data_dir.mkdir(exist_ok=True, parents=True) + Returns the same dict as audio_generator.get_generation_progress(): + is_generating, phase ("idle"|"loading"|"sampling"|"decoding"| + "complete"|"failed"), step, total_steps, progress (0-100), + batch_index, batch_total, started_at, ended_at, error. - return jsonify({ - 'message': f'Fresh start completed! Deleted {data_files_deleted} data files and {config_files_deleted} config metadata files.', - 'data_files_deleted': data_files_deleted, - 'config_files_deleted': config_files_deleted, - 'annotator_labels_reset': labels_reset - }) - - except Exception as e: - return jsonify({'error': str(e)}), 500 - - -@app.route('/api/unwrap-model', methods=['POST']) -def unwrap_model(): - try: - data = request.json - model_config = data.get('model_config') - ckpt_path = data.get('ckpt_path') - name = data.get('name', 'model_unwrap') - - if not model_config or not ckpt_path: - return jsonify({'error': 'model_config and ckpt_path are required'}), 400 - - import subprocess - from pathlib import Path - - config = get_config() - repo_root = config.project_root - - model_config_path = repo_root / \ - model_config if not Path( - model_config).is_absolute() else Path(model_config) - ckpt_path_resolved = repo_root / \ - ckpt_path if not Path(ckpt_path).is_absolute() else Path(ckpt_path) - - if not model_config_path.exists(): - return jsonify({'error': f'Model config not found: {model_config_path}'}), 400 - if not ckpt_path_resolved.exists(): - return 
jsonify({'error': f'Checkpoint not found: {ckpt_path_resolved}'}), 400 - - model_dir = ckpt_path_resolved.parent - unwrapped_dir = model_dir / "unwrapped" - unwrapped_dir.mkdir(exist_ok=True) - - cmd = [ - sys.executable, 'unwrap_model.py', - '--model-config', str(model_config_path), - '--ckpt-path', str(ckpt_path_resolved), - '--name', name, - '--use-safetensors' - ] - - # unwrap_model.py writes next to its CWD, so run from vendor/stable-audio-tools/. - stable_audio_dir = repo_root / "vendor" / "stable-audio-tools" - - proc = subprocess.run(cmd, cwd=stable_audio_dir, - capture_output=True, text=True) - - if proc.returncode == 0: - - import glob - pattern = str(stable_audio_dir / f"{name}*.safetensors") - created_files = glob.glob(pattern) - - moved_files = [] - for created_file in created_files: - created_path = Path(created_file) - target_path = unwrapped_dir / created_path.name - - try: - created_path.rename(target_path) - moved_files.append(str(target_path)) - print(f"Moved {created_path.name} to {target_path}") - except Exception as e: - print(f"Error moving {created_path}: {e}") - - unwrapped_files = list(unwrapped_dir.glob("*.safetensors")) - - return jsonify({ - 'status': 'success', - 'output': proc.stdout, - 'unwrapped_path': moved_files[0] if moved_files else None, - 'unwrapped_files': [str(f) for f in unwrapped_files], - 'moved_files': moved_files - }) - else: - return jsonify({'status': 'error', 'error': proc.stderr, 'output': proc.stdout}), 500 - except Exception as e: - return jsonify({'error': str(e)}), 500 - - -@app.route('/api/delete-checkpoint', methods=['POST']) -def delete_checkpoint(): - try: - data = request.json - checkpoint_path = data.get('checkpoint_path') - - if not checkpoint_path: - return jsonify({'error': 'checkpoint_path is required'}), 400 - - config = get_config() - repo_root = config.project_root - - ckpt_path_resolved = repo_root / \ - checkpoint_path if not Path( - checkpoint_path).is_absolute() else Path(checkpoint_path) - - if 
not ckpt_path_resolved.exists(): - return jsonify({'error': f'Checkpoint file not found: {ckpt_path_resolved}'}), 404 - - # Restrict deletion to .ckpt to avoid accidental loss of unwrapped models. - if not ckpt_path_resolved.suffix == '.ckpt': - return jsonify({'error': f'Only .ckpt files can be deleted: {ckpt_path_resolved}'}), 400 - - try: - ckpt_path_resolved.unlink() - return jsonify({ - 'status': 'success', - 'message': f'Checkpoint deleted successfully', - 'deleted_file': str(ckpt_path_resolved.name) - }) - except Exception as e: - return jsonify({'error': f'Failed to delete checkpoint: {str(e)}'}), 500 - - except Exception as e: - return jsonify({'error': str(e)}), 500 - - -@app.route('/api/delete-wrapped-checkpoint', methods=['POST']) -def delete_wrapped_checkpoint(): - try: - data = request.json - model_name = data.get('model_name') - - if not model_name: - return jsonify({'error': 'model_name is required'}), 400 - - config = get_config() - models_dir = config.get_path("models_fine_tuned") - model_dir = models_dir / model_name - - if not model_dir.exists(): - return jsonify({'error': f'Model directory not found: {model_dir}'}), 404 - - deleted_files = [] - for ckpt_file in model_dir.glob("*.ckpt"): - try: - ckpt_file.unlink() - deleted_files.append(str(ckpt_file.name)) - except Exception as e: - return jsonify({'error': f'Failed to delete {ckpt_file.name}: {str(e)}'}), 500 - - if not deleted_files: - return jsonify({'message': 'No wrapped checkpoint files found to delete'}) - - return jsonify({ - 'status': 'success', - 'message': f'Deleted {len(deleted_files)} wrapped checkpoint file(s)', - 'deleted_files': deleted_files - }) - except Exception as e: - return jsonify({'error': str(e)}), 500 + Cheap (just a dict copy under a lock); safe to poll at ~200ms. 
+ """ + from app.core.generation.audio_generator import get_generation_progress + return jsonify(get_generation_progress()) @app.route('/api/stop-generation', methods=['POST']) @@ -1636,10 +1795,13 @@ def open_output_folder(): try: import subprocess import platform - - output_path = Path("output") - output_path.mkdir(exist_ok=True) - + + # Use the configured output dir, not Path("output") relative to the + # process cwd — the launcher may start the backend from a different + # working directory, which would open (or create) the wrong folder. + output_path = get_config().get_path("output") + output_path.mkdir(parents=True, exist_ok=True) + system = platform.system() if system == "Windows": subprocess.run(["explorer", str(output_path.absolute())]) @@ -1647,12 +1809,80 @@ def open_output_folder(): subprocess.run(["open", str(output_path.absolute())]) else: # Linux subprocess.run(["xdg-open", str(output_path.absolute())]) - + return jsonify({"success": True, "message": "Output folder opened"}) except Exception as e: logger.error(f"Error opening output folder: {e}") return jsonify({"success": False, "error": str(e)}), 500 + +@app.route('/api/reveal-fragment', methods=['POST']) +def reveal_fragment(): + """Reveal a single generated fragment in the OS file manager, with the + file itself selected/highlighted where the platform supports it. + + Body: { "filename": "20260601_120000_sa3-small-music_kick.wav" } + + Platform behaviour: + * Windows : explorer /select, → folder opens, file highlighted + * macOS : open -R → Finder opens, file highlighted + * Linux : org.freedesktop.FileManager1 ShowItems (D-Bus) highlights + the file in Nautilus/Dolphin/etc.; falls back to + `xdg-open ` (opens the folder, no highlight) when no + FileManager1 provider is on the bus. 
+ """ + import subprocess + import platform + + data = request.json or {} + filename = str(data.get('filename', '')).strip() + if not filename: + return jsonify(APIResponse.error("filename is required.", status_code=400)), 400 + # Guard against path traversal — fragments live flat in the output dir. + if '/' in filename or '\\' in filename or filename.startswith('.'): + return jsonify(APIResponse.error("Invalid filename.", status_code=400)), 400 + + output_dir = get_config().get_path("output") + target = output_dir / filename + if not target.exists(): + return jsonify(APIResponse.error( + f"Fragment not found on disk: {filename}", status_code=404)), 404 + + try: + system = platform.system() + abs_path = str(target.absolute()) + if system == "Windows": + # /select needs the path glued to the flag (no space after comma). + subprocess.run(["explorer", f"/select,{abs_path}"]) + elif system == "Darwin": + subprocess.run(["open", "-R", abs_path]) + else: # Linux + revealed = False + try: + subprocess.run( + ["dbus-send", "--session", "--print-reply", + "--dest=org.freedesktop.FileManager1", + "--type=method_call", + "/org/freedesktop/FileManager1", + "org.freedesktop.FileManager1.ShowItems", + f"array:string:file://{abs_path}", + "string:"], + check=True, + timeout=5, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + revealed = True + except Exception: + revealed = False + if not revealed: + # No FileManager1 provider — open the containing folder. 
+ subprocess.run(["xdg-open", str(output_dir.absolute())]) + return jsonify({"success": True, "message": "Fragment revealed"}) + except Exception as e: + logger.error(f"Error revealing fragment {filename}: {e}") + return jsonify(APIResponse.error(str(e), status_code=500)), 500 + @app.route('/api/open-documentation', methods=['POST']) def open_documentation(): try: @@ -1736,50 +1966,6 @@ def get_license_info(): "error": str(e) }), 500 -@app.route('/api/models-status', methods=['GET']) -def get_models_status(): - try: - required_models = ['stable-audio-open-small', 'stable-audio-open-1.0'] - downloaded_models = [ - model_id for model_id in required_models if model_manager.is_model_downloaded(model_id) - ] - models_exist = len(downloaded_models) > 0 - models_message = ( - "Required base models are available." - if models_exist - else "No required base model is downloaded yet." - ) - - hf_authenticated = False - try: - from huggingface_hub import HfApi - HfApi().whoami() - hf_authenticated = True - except Exception: - hf_authenticated = False - - should_show = (not models_exist) and (not hf_authenticated) - auth_reason = ( - "Hugging Face authentication is required to download gated models." - if should_show - else "Authentication already available or models already downloaded." 
- ) - - return jsonify({ - "models_exist": models_exist, - "models_message": models_message, - "should_show_auth_dialog": should_show, - "auth_reason": auth_reason - }) - except Exception as e: - logger.error(f"Error checking models status: {e}") - return jsonify({ - "error": str(e), - "models_exist": False, - "should_show_auth_dialog": True, - "auth_reason": f"Error checking models: {str(e)}" - }), 500 - @app.route('/api/gpu-memory-status', methods=['GET']) def get_gpu_memory_status(): _log_api_call('gpu_memory_status') @@ -1812,7 +1998,7 @@ def get_gpu_memory_status(): if allocated_memory == 0: try: result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,noheader,nounits'], - capture_output=True, text=True, timeout=1) + capture_output=True, text=True, timeout=5) if result.stdout.strip(): used_mb, total_mb = result.stdout.strip().split(', ') nvidia_used_gb = float(used_mb) / 1024 @@ -1877,23 +2063,6 @@ def get_gpu_memory_status(): return jsonify({'error': str(e)}), 500 -_annotate_job_lock = threading.Lock() -_annotate_job = { - 'state': 'idle', # idle | running | done | error - 'current': 0, - 'total': 0, - 'current_file': '', - 'tier': None, - 'folder': None, - 'results': [], - 'error': None, -} -_clap_download_job = { - 'state': 'idle', # idle | running | done | error - 'message': '', - 'error': None, -} - def _annotator_labels_default_path(): return Path(get_config().project_root) / 'config' / 'annotator_labels.json' @@ -1974,8 +2143,28 @@ def _clap_ckpt_path(): @app.route('/api/environment', methods=['GET']) def environment(): + # Host capability flags so the Checkpoint Manager can grey out models this + # machine can't run (e.g. sa3-medium needs CUDA + Flash-Attn 2; no Windows + # wheels). Mirrors the gate in audio_generator._ensure_model. 
+ import platform as _platform + cuda = mps = flash = False + try: + import torch + cuda = bool(torch.cuda.is_available()) + mps = bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()) + except Exception: + pass + try: + import flash_attn # noqa: F401 + flash = True + except Exception: + flash = False return jsonify({ 'docker': os.environ.get('FRAGMENTA_DOCKER', '0') == '1', + 'platform': _platform.system(), # 'Windows' | 'Linux' | 'Darwin' + 'cuda_available': cuda, + 'mps_available': mps, + 'flash_attn_available': flash, }) @@ -1997,7 +2186,7 @@ def upload_folder(): folder_name = first_rel.split('/', 1)[0] if '/' in first_rel else 'folder' safe_folder = ''.join(c for c in folder_name if c.isalnum() or c in '-_') or 'folder' - staging_root = get_config().get_path('data') / 'uploads' + staging_root = get_config().get_path('uploads') staging_root.mkdir(parents=True, exist_ok=True) target_dir = staging_root / f"{int(time.time())}-{safe_folder}" target_dir.mkdir(parents=True, exist_ok=True) @@ -2082,438 +2271,636 @@ def pick_folder(): return jsonify({'path': chosen}) -@app.route('/api/bulk-annotate/status', methods=['GET']) -def bulk_annotate_status(): - from app.backend.data.auto_annotator import clap_checkpoint_available - with _annotate_job_lock: - snapshot = {k: v for k, v in _annotate_job.items() if k != 'results'} - snapshot['result_count'] = len(_annotate_job['results']) - snapshot['clap_available'] = clap_checkpoint_available(get_config().get_path('models_pretrained')) - snapshot['clap_download'] = dict(_clap_download_job) - return jsonify(snapshot) + +# --- SA3 sidecar-native dataset prep ----------------------------------------- +# Projects are folders under /projects//. Editing happens +# against an in-memory session per loaded project; persistence is explicit via +# Save (writes .draft.json) and Commit (writes .txt sidecars + marks audio +# committed). See DATASET_PREP_REDESIGN.md. 
+ +_project_annotate_jobs: Dict[str, dict] = {} +_project_annotate_jobs_lock = threading.Lock() + + +def _get_project_annotate_job(project_name: str) -> dict: + with _project_annotate_jobs_lock: + job = _project_annotate_jobs.get(project_name) + if job is None: + job = { + 'state': 'idle', + 'current': 0, + 'total': 0, + 'current_file': '', + 'tier': None, + 'annotated': 0, + 'skipped_existing': 0, + 'errors': 0, + 'error': None, + 'started_at': None, + 'finished_at': None, + 'cancelled': False, + } + _project_annotate_jobs[project_name] = job + return job -@app.route('/api/bulk-annotate/results', methods=['GET']) -def bulk_annotate_results(): - with _annotate_job_lock: - return jsonify({'results': list(_annotate_job['results']), 'state': _annotate_job['state']}) +@app.route('/api/projects', methods=['GET']) +def list_projects_route(): + from app.backend.data.projects import list_projects + try: + return jsonify({'projects': list_projects()}) + except Exception as exc: + logger.exception("Failed to list projects") + return jsonify({'error': str(exc)}), 500 -@app.route('/api/bulk-annotate', methods=['POST']) -def bulk_annotate(): +@app.route('/api/projects', methods=['POST']) +def create_project_route(): + from app.backend.data.projects import create_project, sanitize_project_name payload = request.json or {} - folder = payload.get('folder_path', '').strip() - tier = payload.get('tier', 'basic') - if tier not in ('basic', 'rich'): - return jsonify({'error': f"Invalid tier: {tier}"}), 400 - if not folder: - return jsonify({'error': 'folder_path is required'}), 400 + raw_name = payload.get('name', '') + try: + name = sanitize_project_name(raw_name) + except ValueError as exc: + return jsonify({'error': str(exc)}), 400 + try: + project = create_project(name) + except FileExistsError as exc: + return jsonify({'error': str(exc)}), 409 + except Exception as exc: + logger.exception("Failed to create project %s", name) + return jsonify({'error': str(exc)}), 500 + return 
jsonify(project), 201 - folder_path = Path(folder).expanduser() - if not folder_path.exists() or not folder_path.is_dir(): - return jsonify({'error': f'Folder not found: {folder_path}'}), 400 - from app.backend.data.auto_annotator import ( - annotate_folder, load_label_sets, clap_checkpoint_available, - ) +@app.route('/api/projects/', methods=['GET']) +def get_project_route(name): + from app.backend.data.projects import get_project + try: + return jsonify(get_project(name)) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except Exception as exc: + logger.exception("Failed to get project %s", name) + return jsonify({'error': str(exc)}), 500 - if tier == 'rich' and not clap_checkpoint_available(get_config().get_path('models_pretrained')): - return jsonify({'error': 'CLAP checkpoint not downloaded yet.'}), 409 - - with _annotate_job_lock: - if _annotate_job['state'] == 'running': - return jsonify({'error': 'An annotation job is already running.'}), 409 - _annotate_job.update({ - 'state': 'running', 'current': 0, 'total': 0, 'current_file': '', - 'tier': tier, 'folder': str(folder_path), 'results': [], 'error': None, - }) - labels = load_label_sets(_annotator_labels_path()) +@app.route('/api/projects//template', methods=['PATCH']) +def patch_project_template_route(name): + """Update the project's annotation-template preset. 
- def progress_cb(i, total, name): - with _annotate_job_lock: - _annotate_job['current'] = i - _annotate_job['total'] = total - _annotate_job['current_file'] = name + Body: { "preset": "music" | "instrument" | "sfx" } + """ + from app.backend.data.projects import update_project_template_preset + payload = request.json or {} + if 'preset' not in payload: + return jsonify({'error': 'preset is required'}), 400 + try: + return jsonify(update_project_template_preset(name, payload['preset'])) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except ValueError as exc: + return jsonify({'error': str(exc)}), 400 + except Exception as exc: + logger.exception("Failed to update template preset for %s", name) + return jsonify({'error': str(exc)}), 500 - def runner(): - try: - results = annotate_folder( - folder_path, tier=tier, label_sets=labels, - clap_ckpt_path=_clap_ckpt_path() if tier == 'rich' else None, - progress_cb=progress_cb, - ) - with _annotate_job_lock: - _annotate_job['results'] = results - _annotate_job['state'] = 'done' - except Exception as exc: - logger.exception("Bulk annotation failed") - with _annotate_job_lock: - _annotate_job['state'] = 'error' - _annotate_job['error'] = str(exc) - threading.Thread(target=runner, daemon=True).start() - return jsonify({'message': 'Annotation started', 'tier': tier, 'folder': str(folder_path)}) +@app.route('/api/projects//health', methods=['GET']) +def project_health_route(name): + """Per-clip health checks. 
Returns counts + file lists the UI can route + into the existing selection model.""" + from app.backend.data.projects import compute_health + try: + short_th = float(request.args.get('short_threshold_sec', 1.0)) + except (TypeError, ValueError): + short_th = 1.0 + try: + return jsonify(compute_health(name, short_th)) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except Exception as exc: + logger.exception("Health check failed for %s", name) + return jsonify({'error': str(exc)}), 500 + +@app.route('/api/projects/', methods=['DELETE']) +def delete_project_route(name): + """Nuke the project folder + drop the in-memory session. Irreversible.""" + from app.backend.data.projects import delete_project + try: + delete_project(name) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except Exception as exc: + logger.exception("Failed to delete project %s", name) + return jsonify({'error': str(exc)}), 500 + return jsonify({'name': name, 'deleted': True}) -@app.route('/api/bulk-annotate/commit', methods=['POST']) -def bulk_annotate_commit(): - """Merge user-reviewed annotation results into metadata.json. 
- Body: { entries: [{ file_name, prompt, path }, ...], copy_files: bool } - """ +@app.route('/api/projects//ingest', methods=['POST']) +def ingest_into_project_route(name): + """Body: { folder_path: string, mode: "copy" | "symlink" }""" + from app.backend.data.projects import ingest_folder, INGEST_MODES, get_project payload = request.json or {} - entries = payload.get('entries') or [] - copy_files = bool(payload.get('copy_files', True)) - if not entries: - return jsonify({'error': 'No entries to commit.'}), 400 - - config = get_config() - data_dir = config.get_path('data') - data_dir.mkdir(exist_ok=True, parents=True) - - json_path = Path(config.get_metadata_json_path()) - existing_metadata = [] - if json_path.exists(): - try: - with open(json_path, 'r', encoding='utf-8') as f: - existing_metadata = json.load(f) - except Exception as exc: - logger.warning("Could not load existing metadata: %s", exc) - existing_metadata = [] - existing_files = {item['file_name']: item for item in existing_metadata} - - committed = 0 - for entry in entries: - file_name = entry.get('file_name') - prompt = (entry.get('prompt') or '').strip() - src_path = entry.get('path') - if not file_name or not prompt or not src_path: - continue + folder = (payload.get('folder_path') or '').strip() + mode = payload.get('mode', 'copy') + if mode not in INGEST_MODES: + return jsonify({'error': f'Invalid ingest mode: {mode}'}), 400 + if not folder: + return jsonify({'error': 'folder_path is required'}), 400 + folder_path = Path(folder).expanduser() + try: + result = ingest_folder(name, folder_path, mode) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except ValueError as exc: + return jsonify({'error': str(exc)}), 400 + except Exception as exc: + logger.exception("Ingest failed for project %s", name) + return jsonify({'error': str(exc)}), 500 + logger.info( + "Ingest into project=%s mode=%s added=%d (copied=%d symlinked=%d skipped=%d)", + name, mode, result['added'], 
result['copied'], result['symlinked'], result['skipped'], + ) + return jsonify({**result, 'project': get_project(name)}) - src = Path(src_path) - if not src.exists(): - logger.warning("Source missing for %s: %s", file_name, src) - continue - if src.parent.resolve() != data_dir.resolve(): - dst = data_dir / file_name - try: - _stage_into_data_dir(src, dst, copy_files=copy_files) - except Exception as exc: - logger.warning("Stage failed for %s: %s", src, exc) - continue +@app.route('/api/projects//clip/', methods=['PATCH']) +def patch_clip_route(name, file_name): + """In-memory prompt edit. Persists only on Save or Commit.""" + from app.backend.data.projects import update_clip_prompt + payload = request.json or {} + if 'prompt' not in payload: + return jsonify({'error': 'prompt is required'}), 400 + try: + clip = update_clip_prompt(name, file_name, payload['prompt']) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except Exception as exc: + logger.exception("Failed to update clip %s in project %s", file_name, name) + return jsonify({'error': str(exc)}), 500 + return jsonify(clip) - stored_path = str(data_dir / file_name) - existing_files[file_name] = { - 'file_name': file_name, - 'prompt': prompt, - 'path': stored_path, - } - committed += 1 +@app.route('/api/projects//clip/', methods=['DELETE']) +def delete_clip_route(name, file_name): + """Immediate delete — cannot be discarded back.""" + from app.backend.data.projects import delete_clip, get_project + try: + delete_clip(name, file_name) + except Exception as exc: + logger.exception("Failed to delete clip %s in project %s", file_name, name) + return jsonify({'error': str(exc)}), 500 + return jsonify({'name': name, 'file_name': file_name, 'deleted': True, 'project': get_project(name)}) - final_metadata = list(existing_files.values()) - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(final_metadata, f, indent=2) +@app.route('/api/projects//clip//audio', methods=['GET']) 
+def clip_audio_route(name, file_name): + """Stream raw audio bytes for a clip. Range requests work via send_file.""" + from app.backend.data.projects import get_session_handle + try: + session = get_session_handle(name) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + with session.lock: + clip = session.clips.get(file_name) + if clip is None: + return jsonify({'error': f"Clip not found: {file_name}"}), 404 + audio_path = Path(clip.path) + if not audio_path.exists(): + return jsonify({'error': 'Audio file missing on disk'}), 404 + return send_file(str(audio_path), conditional=True) + + +@app.route('/api/projects//clip//slice', methods=['POST']) +def clip_slice_route(name, file_name): + """Split one clip into N children. Body: { target_duration, overlap_sec, strategy }.""" + from app.backend.data.projects import slice_clip + from app.backend.data.slicing import VALID_STRATEGIES + payload = request.json or {} + try: + target = float(payload.get('target_duration', 0)) + overlap = float(payload.get('overlap_sec', 0)) + strategy = str(payload.get('strategy', 'hard')).lower() + except (TypeError, ValueError): + return jsonify({'error': 'target_duration / overlap_sec must be numeric'}), 400 + if target <= 0 or target > 600: + return jsonify({'error': 'target_duration must be in (0, 600]'}), 400 + if overlap < 0 or overlap >= target: + return jsonify({'error': 'overlap_sec must be >= 0 and < target_duration'}), 400 + if strategy not in VALID_STRATEGIES: + return jsonify({'error': f"strategy must be one of {VALID_STRATEGIES}"}), 400 try: - config.update_dataset_config() + result = slice_clip(name, file_name, target, overlap, strategy) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except ValueError as exc: + return jsonify({'error': str(exc)}), 400 except Exception as exc: - logger.warning("Failed to refresh dataset-config.json: %s", exc) + logger.exception("Slice failed for %s/%s", name, file_name) + return 
jsonify({'error': str(exc)}), 500 + return jsonify(result) - return jsonify({ - 'message': f'Committed {committed} annotations.', - 'committed': committed, - 'metadata_json': str(json_path), - }) +@app.route('/api/projects//clip//peaks', methods=['GET']) +def clip_peaks_route(name, file_name): + """Return waveform peaks + duration JSON for a clip. Cached per session.""" + from app.backend.data.projects import get_session_handle, get_or_compute_peaks + try: + n = int(request.args.get('n', 200)) + except (TypeError, ValueError): + n = 200 + n = max(20, min(n, 500)) + try: + session = get_session_handle(name) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + with session.lock: + clip = session.clips.get(file_name) + if clip is None: + return jsonify({'error': f"Clip not found: {file_name}"}), 404 + audio_path = Path(clip.path) + if not audio_path.exists(): + return jsonify({'error': 'Audio file missing on disk'}), 404 + try: + peaks, duration = get_or_compute_peaks(session, file_name, audio_path, n) + except Exception as exc: + logger.exception("Peak computation failed for %s/%s", name, file_name) + return jsonify({'error': f'Peak computation failed: {exc}'}), 500 + return jsonify({'peaks': peaks, 'duration': duration}) -def _stage_into_data_dir(src: Path, dst: Path, copy_files: bool) -> str: - """Place an audio file into the data dir so the trainer's scan picks it up. - Trainer reads `dataset-config.json` -> a single `audio_dir` path (data/), - so files referenced only by absolute paths in metadata.json are invisible. - Copying duplicates disk; symlinking is the cheap alternative. 
+@app.route('/api/projects//save', methods=['POST']) +def save_project_route(name): + """Persist in-memory diffs as a hidden draft (not the SA3 sidecars).""" + from app.backend.data.projects import save_project + try: + return jsonify(save_project(name)) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except Exception as exc: + logger.exception("Failed to save project %s", name) + return jsonify({'error': str(exc)}), 500 - Returns 'copy' | 'symlink' to describe what was done. - """ - import shutil - if dst.is_symlink() or dst.exists(): - try: - dst.unlink() - except IsADirectoryError: - shutil.rmtree(dst) +@app.route('/api/projects//commit', methods=['POST']) +def commit_project_route(name): + """Flush in-memory state to .txt sidecars; overwrites the previous commit.""" + from app.backend.data.projects import commit_project + try: + return jsonify(commit_project(name)) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except Exception as exc: + logger.exception("Failed to commit project %s", name) + return jsonify({'error': str(exc)}), 500 - if not copy_files: - try: - dst.symlink_to(src.resolve()) - return 'symlink' - except (OSError, NotImplementedError) as exc: - # Windows without developer mode disallows symlinks; fall back. 
- logger.info("Symlink failed for %s, falling back to copy: %s", src, exc) - shutil.copy2(src, dst) - return 'copy' +@app.route('/api/projects//discard', methods=['POST']) +def discard_project_route(name): + """Drop uncommitted state and delete audio files added since the last commit.""" + from app.backend.data.projects import discard_project + try: + return jsonify(discard_project(name)) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + except Exception as exc: + logger.exception("Failed to discard project %s", name) + return jsonify({'error': str(exc)}), 500 -@app.route('/api/import-csv/preview', methods=['POST']) -def import_csv_preview(): - """Parse a CSV upload + audio folder, return rows with conflict status. +@app.route('/api/projects//annotate', methods=['POST']) +def annotate_project_route(name): + """Kick off auto-annotation. Updates the in-memory session; sidecars on + disk are not touched until the user commits. - Form fields: - - csv: file upload (text/csv) with at least file_name, prompt columns - - audio_folder: server-side path to a folder containing the audio files + Body: { + tier: "basic" | "rich", + scope?: "all" | ["file_name1", ...], # default: "all" + skip_existing?: bool # default: true + } """ - import csv as _csv - from io import StringIO - - csv_file = request.files.get('csv') - audio_folder = (request.form.get('audio_folder') or '').strip() - - if not csv_file: - return jsonify({'error': 'CSV file is required.'}), 400 - if not audio_folder: - return jsonify({'error': 'Audio folder path is required.'}), 400 + from app.backend.data.projects import get_session_handle, reset_cancel, project_path + from app.backend.data.auto_annotator import ( + annotate_file, load_label_sets, get_clap_tagger, clap_checkpoint_available, + ) - audio_dir = Path(audio_folder).expanduser() - if not audio_dir.exists() or not audio_dir.is_dir(): - return jsonify({'error': f'Audio folder does not exist: {audio_folder}'}), 400 - 
audio_dir_resolved = audio_dir.resolve() + payload = request.json or {} + tier = payload.get('tier', 'basic') + scope = payload.get('scope', 'all') + skip_existing = bool(payload.get('skip_existing', True)) + if tier not in ('basic', 'rich'): + return jsonify({'error': f'Invalid tier: {tier}'}), 400 try: - text = csv_file.read().decode('utf-8-sig') - except UnicodeDecodeError: - return jsonify({'error': 'CSV must be UTF-8 encoded.'}), 400 + session = get_session_handle(name) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 - reader = _csv.DictReader(StringIO(text)) - fieldnames = reader.fieldnames or [] - if 'file_name' not in fieldnames or 'prompt' not in fieldnames: + if tier == 'rich' and not clap_checkpoint_available(get_config().get_path('models_pretrained')): return jsonify({ - 'error': "CSV must include 'file_name' and 'prompt' columns.", - 'found_columns': fieldnames, - }), 400 + 'error': 'CLAP checkpoint not downloaded yet. Open Model Management to download it.', + 'code': 'clap_not_available', + }), 409 + + with session.lock: + all_clips = sorted(session.clips.values(), key=lambda c: c.file_name) + if scope == 'all': + target = list(all_clips) + elif isinstance(scope, list): + wanted = set(scope) + target = [c for c in all_clips if c.file_name in wanted] + missing = wanted - {c.file_name for c in target} + if missing: + return jsonify({'error': f'Clips not in project: {sorted(missing)}'}), 404 + else: + return jsonify({'error': 'scope must be "all" or a list of file names'}), 400 - config = get_config() - json_path = Path(config.get_metadata_json_path()) - existing_files = set() - if json_path.exists(): + if skip_existing: + run_targets = [c for c in target if not (c.prompt or '').strip()] + skipped_existing = len(target) - len(run_targets) + else: + run_targets = list(target) + skipped_existing = 0 + target_names = [c.file_name for c in run_targets] + + job = _get_project_annotate_job(name) + with _project_annotate_jobs_lock: 
+ if job['state'] == 'running': + return jsonify({'error': f'Annotation already running for project {name}.'}), 409 + job.update({ + 'state': 'running', + 'current': 0, + 'total': len(target_names), + 'current_file': '', + 'tier': tier, + 'annotated': 0, + 'skipped_existing': skipped_existing, + 'errors': 0, + 'error': None, + 'started_at': time.time(), + 'finished_at': None, + 'cancelled': False, + }) + + reset_cancel(session) + labels = load_label_sets(_annotator_labels_path()) + clap_tagger = None + if tier == 'rich': + clap_tagger = get_clap_tagger(_clap_ckpt_path()) try: - with open(json_path, 'r', encoding='utf-8') as f: - existing = json.load(f) - existing_files = {item['file_name'] for item in existing if isinstance(item, dict)} + clap_tagger.ensure_loaded() + except FileNotFoundError as exc: + # File-existence check passed earlier but the actual load failed. + with _project_annotate_jobs_lock: + job['state'] = 'idle' + return jsonify({ + 'error': str(exc), + 'code': 'clap_not_available', + }), 409 + except ImportError as exc: + # The .pt weights are on disk but one of CLAP's Python deps isn't + # installed in the venv. Could be laion_clap itself or anything it + # imports transitively (e.g. torchvision). Model Manager can't fix + # this — the user has to pip install in their environment. + with _project_annotate_jobs_lock: + job['state'] = 'idle' + missing = getattr(exc, 'name', None) or 'laion_clap' + # Module name → PyPI name when they differ; default identical. + _PYPI_NAME = {'laion_clap': 'laion-clap'} + pip_name = _PYPI_NAME.get(missing, missing) + install_command = f'pip install {pip_name}' + # torch-family packages need the CUDA index URL to match the + # pinned torch build; otherwise pip installs the CPU-only wheel. 
+ if missing in {'torchvision', 'torchaudio'}: + install_command += ' --extra-index-url https://download.pytorch.org/whl/cu128' + return jsonify({ + 'error': ( + f"The '{missing}' Python package is required for Rich-tier annotation " + "but isn't installed. Install it in Fragmenta's venv, then restart the app:" + ), + 'code': 'clap_package_missing', + 'install_command': install_command, + }), 409 except Exception as exc: - logger.warning("Could not load existing metadata: %s", exc) - - rows = [] - seen_keys = set() - for line_no, raw in enumerate(reader, start=2): - file_name_csv = (raw.get('file_name') or '').strip() - prompt = (raw.get('prompt') or '').strip() - if not file_name_csv and not prompt: - continue + with _project_annotate_jobs_lock: + job['state'] = 'idle' + logger.exception("CLAP load failed for project %s", name) + return jsonify({ + 'error': f'CLAP failed to load: {exc}', + 'code': 'clap_load_failed', + }), 500 - row_errors = [] - if not file_name_csv: - row_errors.append('missing file_name') - if not prompt: - row_errors.append('missing prompt') - - src_path = None - audio_found = False - if file_name_csv: - rel = file_name_csv.replace('\\', '/').lstrip('/') - if '..' in rel.split('/'): - row_errors.append('file_name contains ".."') - else: - candidate = (audio_dir / rel).resolve() + proj_path = project_path(name) + # Resolve the active preset to a template string once per job. 
+ from app.backend.data.projects import resolve_prompt_template + active_template = resolve_prompt_template(session) + + def runner(): + try: + logger.info( + "Project annotate started: name=%s tier=%s targets=%d skip_existing=%s", + name, tier, len(target_names), skip_existing, + ) + for i, file_name in enumerate(target_names, start=1): + if session.cancel_event.is_set(): + logger.info("Project annotate cancelled mid-run: name=%s", name) + with _project_annotate_jobs_lock: + job['cancelled'] = True + break + with _project_annotate_jobs_lock: + job['current_file'] = file_name + logger.info(" annotating %d/%d: %s", i, len(target_names), file_name) + audio_path = proj_path / file_name try: - candidate.relative_to(audio_dir_resolved) - if candidate.is_file(): - src_path = str(candidate) - audio_found = True - else: - row_errors.append('audio file not found in folder') - except ValueError: - row_errors.append('file_name resolves outside audio folder') - - key_name = Path(file_name_csv).name if file_name_csv else '' - duplicate_in_csv = key_name and key_name in seen_keys - if duplicate_in_csv: - row_errors.append('duplicate file_name within this CSV') - if key_name: - seen_keys.add(key_name) - - rows.append({ - 'line': line_no, - 'file_name': key_name, - 'csv_path': file_name_csv, - 'prompt': prompt, - 'src_path': src_path, - 'audio_found': audio_found, - 'conflict': bool(key_name) and key_name in existing_files, - 'errors': row_errors, - }) + result = annotate_file( + audio_path, tier, clap_tagger, labels, + prompt_template=active_template, + ) + except Exception as exc: + logger.warning("annotate_file failed for %s: %s", file_name, exc) + with _project_annotate_jobs_lock: + job['errors'] += 1 + job['current'] += 1 + continue + if result.get('error'): + with _project_annotate_jobs_lock: + job['errors'] += 1 + job['current'] += 1 + continue + prompt = result.get('prompt', '') or '' + with session.lock: + clip = session.clips.get(file_name) + if clip is not None: + 
clip.prompt = prompt + with _project_annotate_jobs_lock: + job['annotated'] += 1 + job['current'] += 1 + with _project_annotate_jobs_lock: + job['state'] = 'done' + job['finished_at'] = time.time() + job['current_file'] = '' + logger.info( + "Project annotate done: name=%s annotated=%d errors=%d skipped_existing=%d cancelled=%s", + name, job['annotated'], job['errors'], job['skipped_existing'], job['cancelled'], + ) + except Exception as exc: + logger.exception("Project annotate failed: name=%s", name) + with _project_annotate_jobs_lock: + job['state'] = 'error' + job['error'] = str(exc) + job['finished_at'] = time.time() - return jsonify({ - 'rows': rows, - 'total': len(rows), - 'conflicts': sum(1 for r in rows if r['conflict']), - 'missing_audio': sum(1 for r in rows if not r['audio_found']), - 'existing_count': len(existing_files), - }) + threading.Thread(target=runner, daemon=True).start() + with _project_annotate_jobs_lock: + snapshot = dict(job) + return jsonify({'name': name, 'job': snapshot}), 202 -@app.route('/api/import-csv/commit', methods=['POST']) -def import_csv_commit(): - """Merge reviewed CSV import entries into metadata.json with a conflict policy. 
+@app.route('/api/projects//annotate/cancel', methods=['POST']) +def annotate_project_cancel_route(name): + """Stop a running annotate job after the in-flight clip completes.""" + from app.backend.data.projects import get_session_handle + try: + session = get_session_handle(name) + except FileNotFoundError as exc: + return jsonify({'error': str(exc)}), 404 + session.cancel_event.set() + return jsonify({'name': name, 'cancel_signal_set': True}) + + +@app.route('/api/projects//annotate/status', methods=['GET']) +def annotate_project_status_route(name): + from app.backend.data.projects import project_path + if not project_path(name).exists(): + return jsonify({'error': f'Project not found: {name}'}), 404 + with _project_annotate_jobs_lock: + job = _project_annotate_jobs.get(name) + snapshot = dict(job) if job else {'state': 'idle'} + return jsonify({'name': name, 'job': snapshot}) + + +# --- Phase 6: pre-encoded latents ----------------------------------------- + +@app.route('/api/projects//pre-encode', methods=['POST']) +def pre_encode_project_route(name): + """Kick off SA3 pre-encoding for a project. Returns the job state (202).""" + from app.backend.data.projects import project_path + from app.backend.data.pre_encoder import start_pre_encode + if not project_path(name).exists(): + return jsonify({'error': f'Project not found: {name}'}), 404 + # silent=True so an empty/no-Content-Type body is treated as {} instead of + # Flask returning 415 — callers usually fire-and-forget without a payload. 
+ body = request.get_json(silent=True) or {} + autoencoder = (body.get('autoencoder') or '').strip() or None + try: + job = start_pre_encode(name, autoencoder=autoencoder) + except (FileNotFoundError, ValueError) as exc: + return jsonify({'error': str(exc)}), 400 + except Exception as exc: + logger.exception("Failed to start pre-encode") + return jsonify({'error': str(exc)}), 500 + return jsonify({'name': name, 'job': job}), 202 - Body: { - entries: [{ file_name, prompt, src_path }, ...], - conflict_policy: 'skip' | 'overwrite' | 'rename', - copy_files: bool - } - """ - payload = request.json or {} - entries = payload.get('entries') or [] - policy = payload.get('conflict_policy') or 'skip' - copy_files = bool(payload.get('copy_files', True)) - - if policy not in {'skip', 'overwrite', 'rename'}: - return jsonify({'error': "conflict_policy must be one of 'skip', 'overwrite', 'rename'."}), 400 - if not entries: - return jsonify({'error': 'No entries to commit.'}), 400 - - config = get_config() - data_dir = config.get_path('data') - data_dir.mkdir(exist_ok=True, parents=True) - data_dir_resolved = data_dir.resolve() - - json_path = Path(config.get_metadata_json_path()) - existing_metadata = [] - if json_path.exists(): - try: - with open(json_path, 'r', encoding='utf-8') as f: - existing_metadata = json.load(f) - except Exception as exc: - logger.warning("Could not load existing metadata: %s", exc) - existing_metadata = [] - existing_files = {item['file_name']: item for item in existing_metadata if isinstance(item, dict)} - - def unique_name(base_name: str) -> str: - stem = Path(base_name).stem - suffix = Path(base_name).suffix - n = 2 - while True: - candidate = f"{stem}_{n}{suffix}" - if candidate not in existing_files and not (data_dir / candidate).exists(): - return candidate - n += 1 - - committed = 0 - skipped = 0 - renamed = 0 - overwritten = 0 - errors = [] - for entry in entries: - file_name = Path((entry.get('file_name') or '').strip()).name - prompt = 
(entry.get('prompt') or '').strip() - src_path = entry.get('src_path') or entry.get('path') +@app.route('/api/projects//pre-encode/status', methods=['GET']) +def pre_encode_status_route(name): + """Poll the current job state. Cheap (dict copy under lock).""" + from app.backend.data.projects import project_path + from app.backend.data.pre_encoder import get_pre_encode_job + if not project_path(name).exists(): + return jsonify({'error': f'Project not found: {name}'}), 404 + return jsonify({'name': name, 'job': get_pre_encode_job(name)}) - if not file_name or not prompt or not src_path: - errors.append({'file_name': file_name, 'reason': 'missing required field'}) - continue - src = Path(src_path) - if not src.is_file(): - errors.append({'file_name': file_name, 'reason': 'source audio missing'}) - continue +@app.route('/api/projects//pre-encode/cancel', methods=['POST']) +def pre_encode_cancel_route(name): + """Cooperative cancel. Signals SIGINT → SIGTERM → SIGKILL on the subprocess.""" + from app.backend.data.projects import project_path + from app.backend.data.pre_encoder import cancel_pre_encode + if not project_path(name).exists(): + return jsonify({'error': f'Project not found: {name}'}), 404 + cancelled = cancel_pre_encode(name) + return jsonify({'name': name, 'cancelled': cancelled}) - target_name = file_name - had_conflict = file_name in existing_files - force_copy = False - if had_conflict: - if policy == 'skip': - skipped += 1 - continue - if policy == 'rename': - target_name = unique_name(file_name) - renamed += 1 - # Rename only makes sense if a renamed file actually lands in data/. 
- force_copy = True - elif policy == 'overwrite': - overwritten += 1 - - if src.parent.resolve() != data_dir_resolved: - dst = data_dir / target_name - try: - _stage_into_data_dir(src, dst, copy_files=(copy_files or force_copy)) - except Exception as exc: - errors.append({'file_name': file_name, 'reason': f'stage failed: {exc}'}) - continue - stored_path = str(data_dir / target_name) - existing_files[target_name] = { - 'file_name': target_name, - 'prompt': prompt, - 'path': stored_path, - } - committed += 1 +@app.route('/api/projects//pre-encode/prompt', methods=['PATCH']) +def pre_encode_prompt_route(name): + """Persist the 'Don't ask again' choice from the post-commit dialog. - final_metadata = list(existing_files.values()) - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(final_metadata, f, indent=2) + Body: { "suppress": bool } + """ + from app.backend.data.projects import project_path, update_pre_encode_suppression + if not project_path(name).exists(): + return jsonify({'error': f'Project not found: {name}'}), 404 + body = request.get_json(silent=True) or {} + if 'suppress' not in body: + return jsonify({'error': "Body must contain 'suppress': bool."}), 400 + updated = update_pre_encode_suppression(name, bool(body['suppress'])) + return jsonify(updated) + + +@app.route('/api/clap/unload', methods=['POST']) +def clap_unload_route(): + """Free CLAP weights from VRAM (e.g. before starting training or generation).""" + from app.backend.data.auto_annotator import unload_clap + unload_clap() + return jsonify({'message': 'CLAP unloaded from memory.'}) - try: - config.update_dataset_config() - except Exception as exc: - logger.warning("Failed to refresh dataset-config.json: %s", exc) +# --- Native MIDI input ----------------------------------------------------- +# python-rtmidi reads hardware MIDI natively (CoreMIDI / WinMM / ALSA), so MIDI +# works regardless of the web engine. 
The frontend keeps all mapping/learn +# logic; it just consumes /api/midi/stream instead of Web MIDI. +@app.route('/api/midi/devices', methods=['GET']) +def midi_devices(): + from app.core.audio import midi_input return jsonify({ - 'message': ( - f'Imported {committed} entries ' - f'(skipped: {skipped}, renamed: {renamed}, overwritten: {overwritten}).' - ), - 'committed': committed, - 'skipped': skipped, - 'renamed': renamed, - 'overwritten': overwritten, - 'errors': errors, - 'metadata_json': str(json_path), + "available": midi_input.is_available(), + "inputs": midi_input.list_inputs(), + "current": midi_input.current_port(), }) -@app.route('/api/bulk-annotate/download-clap', methods=['POST']) -def bulk_annotate_download_clap(): - from app.backend.data.auto_annotator import download_clap_checkpoint - - with _annotate_job_lock: - if _clap_download_job['state'] == 'running': - return jsonify({'error': 'CLAP download already in progress.'}), 409 - _clap_download_job.update({'state': 'running', 'message': 'Starting download…', 'error': None}) - - def runner(): +@app.route('/api/midi/select', methods=['POST']) +def midi_select(): + from app.core.audio import midi_input + data = request.get_json(silent=True) or {} + port_id = data.get('port_id') + ok = midi_input.open_input(port_id) + if not ok and port_id: + return jsonify(APIResponse.error( + "Could not open that MIDI input port.", status_code=400)), 400 + return jsonify({"current": midi_input.current_port()}) + + +@app.route('/api/midi/stream', methods=['GET']) +def midi_stream(): + """Server-Sent Events stream of incoming MIDI from the open port. 
Each + event is {"data": [status, d1, d2]} — the same shape the frontend's + Web-MIDI dispatcher already expects.""" + from app.core.audio import midi_input + + def gen(): + q = midi_input.subscribe() try: - target = download_clap_checkpoint( - get_config().get_path('models_pretrained'), - progress_cb=lambda m: _clap_download_job.update({'message': m}), - ) - _clap_download_job.update({'state': 'done', 'message': f'Downloaded to {target}'}) - except Exception as exc: - logger.exception("CLAP download failed") - _clap_download_job.update({'state': 'error', 'error': str(exc)}) - - threading.Thread(target=runner, daemon=True).start() - return jsonify({'message': 'CLAP download started'}) - + yield ": connected\n\n" + while True: + try: + payload = q.get(timeout=15) + yield f"data: {json.dumps(payload)}\n\n" + except queue.Empty: + yield ": keepalive\n\n" # keep idle proxies/connections alive + except GeneratorExit: + pass + finally: + midi_input.unsubscribe(q) -@app.route('/api/bulk-annotate/unload-clap', methods=['POST']) -def bulk_annotate_unload_clap(): - from app.backend.data.auto_annotator import unload_clap - unload_clap() - return jsonify({'message': 'CLAP unloaded from memory.'}) + resp = Response(gen(), mimetype='text/event-stream') + resp.headers['Cache-Control'] = 'no-cache' + resp.headers['X-Accel-Buffering'] = 'no' + return resp @app.route('/shutdown', methods=['POST']) @@ -2532,4 +2919,10 @@ def shutdown(): if __name__ == '__main__': host = os.environ.get('FLASK_HOST', '0.0.0.0') port = int(os.environ.get('FLASK_PORT', '5001')) - app.run(debug=True, host=host, port=port) + # threaded=True so the long-lived MIDI SSE stream (/api/midi/stream) doesn't + # block other requests on the single dev-server worker. + # use_reloader=False: Fragmenta is launched as a packaged desktop app via + # start.py, not a hot-reload dev loop. 
The reloader would fork a second + # process that re-imports the module (re-running init and doubling backend + # processes); we don't want that here. + app.run(debug=True, host=host, port=port, threaded=True, use_reloader=False) diff --git a/app/backend/data/auto_annotator.py b/app/backend/data/auto_annotator.py index ead04e1cdb1635ce56563ca2becb74b273370704..f0cae9e268baf771768a9f2ebfdf60a4b65e6f05 100644 --- a/app/backend/data/auto_annotator.py +++ b/app/backend/data/auto_annotator.py @@ -22,6 +22,11 @@ AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".m4a", ".ogg", ".aac") CLAP_CKPT_FILENAME = "music_audioset_epoch_15_esc_90.14.pt" CLAP_REPO = "lukewys/laion_clap" +# Text-side dependencies laion_clap pulls from HF on construction. +# We stage these into models/pretrained/clap/hub/ so the rich tier is +# fully offline after a single download and nothing leaks to ~/.cache. +CLAP_TEXT_DEPS = ("roberta-base", "bert-base-uncased", "facebook/bart-base") + KEY_NAMES_SHARP = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] KEY_NAMES_FLAT = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"] @@ -44,9 +49,18 @@ def _iter_audio_files(folder: Path) -> List[Path]: def _estimate_tempo(y, sr) -> Optional[int]: import librosa + import numpy as np try: tempo, _ = librosa.beat.beat_track(y=y, sr=sr) - bpm = float(tempo if hasattr(tempo, "__float__") else tempo[0]) + # librosa 0.10+ returns tempo as np.ndarray (shape (1,) typically). + # numpy 2.x removed implicit float() conversion of N-d arrays — + # `float(np.array([120.]))` now raises TypeError instead of returning + # 120.0 like numpy 1.x did. Normalize via .flat[0] which handles + # scalar, 0-d, 1-d, and N-d uniformly. + arr = np.atleast_1d(np.asarray(tempo)) + if arr.size == 0: + return None + bpm = float(arr.flat[0]) if bpm <= 0: return None return int(round(bpm)) @@ -178,12 +192,50 @@ class _ClapTagger: f"CLAP checkpoint not found at {self.ckpt_path}. 
" "Download it first via /api/bulk-annotate/download-clap." ) - import laion_clap - import torch logging.getLogger("transformers").setLevel(logging.ERROR) - device = "cuda" if torch.cuda.is_available() else "cpu" - model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device=device) + # Point HF resolution at our project-local cache and disable the + # HEAD-revalidation traffic. After download_clap_checkpoint() has + # staged the text deps under /clap/hub/, CLAP_Module + # loads them offline with zero HF hub requests. + # + # Two reasons env vars alone aren't enough: + # 1. huggingface_hub.constants.HF_HUB_OFFLINE is captured at + # module-import time (constants.py:185). model_manager.py + # imports huggingface_hub at app startup, so the constant is + # already False by the time we set the env var here. + # transformers.utils.hub.is_offline_mode reads that same + # constant — patching the attribute makes both libraries see + # offline mode. + # 2. laion_clap/training/data.py:44-46 runs three from_pretrained + # calls at MODULE LEVEL — those fire the first time we do + # `import laion_clap` and predate any patch we do after the + # import. So we patch BEFORE the import, not after. 
+ hub_dir = self.ckpt_path.parent / "hub" + env_keys = ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE", "TRANSFORMERS_CACHE", + "HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE") + prev_env = {k: os.environ.get(k) for k in env_keys} + os.environ["HF_HUB_CACHE"] = str(hub_dir) + os.environ["HUGGINGFACE_HUB_CACHE"] = str(hub_dir) + os.environ["TRANSFORMERS_CACHE"] = str(hub_dir) + os.environ["HF_HUB_OFFLINE"] = "1" + os.environ["TRANSFORMERS_OFFLINE"] = "1" + + import huggingface_hub.constants as _hhc + prev_offline_attr = _hhc.HF_HUB_OFFLINE + _hhc.HF_HUB_OFFLINE = True + try: + import laion_clap # noqa: E402 — must follow the offline patch + import torch + device = "cuda" if torch.cuda.is_available() else "cpu" + model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device=device) + finally: + _hhc.HF_HUB_OFFLINE = prev_offline_attr + for k, v in prev_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v # torch >= 2.6 flipped torch.load(weights_only=True) and newer # transformers dropped the roberta position_ids buffer, so @@ -268,44 +320,97 @@ def clap_checkpoint_path(models_pretrained_dir: Path) -> Path: return models_pretrained_dir / "clap" / CLAP_CKPT_FILENAME +def clap_hub_dir(models_pretrained_dir: Path) -> Path: + """HF cache for laion_clap's text-side deps. 
Sibling of the .pt.""" + return models_pretrained_dir / "clap" / "hub" + + def clap_checkpoint_available(models_pretrained_dir: Path) -> bool: return clap_checkpoint_path(models_pretrained_dir).exists() +def _text_dep_snapshot_present(hub_dir: Path, repo_id: str) -> bool: + safe = "models--" + repo_id.replace("/", "--") + snap_root = hub_dir / safe / "snapshots" + if not snap_root.exists(): + return False + return any(snap_root.iterdir()) + + def download_clap_checkpoint( models_pretrained_dir: Path, progress_cb: Optional[Callable[[str], None]] = None, + phase_cb: Optional[Callable[[int, int, str], None]] = None, ) -> Path: + """Download the CLAP audio .pt plus laion_clap's text-side HF snapshots. + + Four sequential phases — emit a phase update (current, total, label) at the + start of each so a multi-phase progress UI can show real context. Skips + phases whose artifacts are already on disk. + + `progress_cb` (str-only) is kept for the bulk-annotate API. + `phase_cb` (current, total, label) is the structured channel. 
+ """ target = clap_checkpoint_path(models_pretrained_dir) target.parent.mkdir(parents=True, exist_ok=True) - if target.exists(): - return target + hub_dir = clap_hub_dir(models_pretrained_dir) + hub_dir.mkdir(parents=True, exist_ok=True) - from huggingface_hub import hf_hub_download + from huggingface_hub import hf_hub_download, snapshot_download import os - if progress_cb: - progress_cb("Downloading CLAP checkpoint (~630 MB)…") + total_phases = 1 + len(CLAP_TEXT_DEPS) + + def _emit(phase_index: int, label: str) -> None: + if phase_cb: + phase_cb(phase_index, total_phases, label) + if progress_cb: + progress_cb(f"[{phase_index}/{total_phases}] {label}") + + if not target.exists(): + _emit(1, "CLAP audio model (~2.35 GB)") + + # Use custom CLAP from fragmenta-models on HF Spaces + use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true' + if use_custom_repo: + repo_id = "MazCodes/fragmenta-models" + else: + repo_id = CLAP_REPO + + downloaded = hf_hub_download( + repo_id=repo_id, + filename=CLAP_CKPT_FILENAME, + local_dir=str(target.parent), + ) + downloaded_path = Path(downloaded) + if downloaded_path != target: + try: + downloaded_path.replace(target) + except OSError: + import shutil + shutil.copy2(downloaded_path, target) + + # laion_clap's CLAP_Module(...) constructor instantiates a Roberta text + # branch plus bert/bart tokenizers at import time. Pre-stage them into + # our own cache so the rich tier is fully offline after this step. + # safetensors only — pytorch_model.bin is a redundant copy. 
+ for i, repo_id in enumerate(CLAP_TEXT_DEPS, start=2): + if _text_dep_snapshot_present(hub_dir, repo_id): + continue + _emit(i, f"Text encoder: {repo_id}") + snapshot_download( + repo_id=repo_id, + cache_dir=str(hub_dir), + allow_patterns=[ + "config.json", + "tokenizer*", + "vocab*", + "merges.txt", + "special_tokens_map.json", + "model.safetensors", + ], + ) - # Use custom CLAP from fragmenta-models on HF Spaces - use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true' - if use_custom_repo: - repo_id = "MazCodes/fragmenta-models" - else: - repo_id = CLAP_REPO - - downloaded = hf_hub_download( - repo_id=repo_id, - filename=CLAP_CKPT_FILENAME, - local_dir=str(target.parent), - ) - downloaded_path = Path(downloaded) - if downloaded_path != target: - try: - downloaded_path.replace(target) - except OSError: - import shutil - shutil.copy2(downloaded_path, target) return target @@ -328,8 +433,10 @@ def annotate_file( label_sets: Dict[str, List[str]], sr: int = 22050, max_seconds: float = 60.0, + prompt_template: Optional[str] = None, ) -> Dict[str, Any]: import librosa + import warnings parts: Dict[str, Any] = {} try: @@ -343,10 +450,19 @@ def annotate_file( "error": f"load failed: {exc}", } - parts["bpm"] = _estimate_tempo(y, loaded_sr) - parts["key"] = _estimate_key(y, loaded_sr) - parts["brightness"] = _estimate_brightness(y, loaded_sr) - parts["character"] = _estimate_character(y, loaded_sr) + # Silent / harmonically flat clips trip librosa's "Trying to estimate + # tuning from empty frequency set" warning during chroma extraction. + # The warning is benign — the analysis returns sensible defaults — but + # it spams stderr on every silent file, so we mute it here. 
+ with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Trying to estimate tuning from empty frequency set", + ) + parts["bpm"] = _estimate_tempo(y, loaded_sr) + parts["key"] = _estimate_key(y, loaded_sr) + parts["brightness"] = _estimate_brightness(y, loaded_sr) + parts["character"] = _estimate_character(y, loaded_sr) if tier == "rich" and clap_tagger is not None: try: @@ -355,7 +471,14 @@ def annotate_file( except Exception as exc: logger.warning("CLAP tagging failed for %s: %s", audio_path.name, exc) - prompt = _compose_prompt(parts) + # Template-driven prompt assembly. Falls back to the legacy descriptive + # prose if no template is supplied (call sites that haven't been + # threaded with project metadata yet). + if prompt_template is not None and prompt_template.strip(): + from app.backend.data.projects import apply_template + prompt = apply_template(prompt_template, parts) + else: + prompt = _compose_prompt(parts) return { "file_name": audio_path.name, "prompt": prompt, diff --git a/app/backend/data/pre_encoder.py b/app/backend/data/pre_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d6f8daca1cddf3b2dc5a6e01fb3a7338721579 --- /dev/null +++ b/app/backend/data/pre_encoder.py @@ -0,0 +1,354 @@ +"""SA3 pre-encoding job runner — Phase 6. + +Encodes every audio clip in a Dataset Workbench project into SA3 latents +once, ahead of training, so the training subprocess can skip the SAME +autoencoder pass per step. Mirrors the shape of `_project_annotate_jobs` +in app.py (background thread, per-project state, cooperative cancel). + +Latents land in `/.latents/` — a hidden subdirectory inside the +project folder. Disk layout matches SA3's `pre_encode_dataset.py`: + + /.latents/ + 000000000000.npy # latent tensor (shape (256, T_lat)) + 000000000000.json # {"prompt": "...", "padding_mask": [...], ...} + 000001000000.npy + 000001000000.json + ... 
+ silence.npy # padding latent (auto-generated) + _meta.json # Fragmenta-specific: AE used, source clip count + +SA3's `train_lora.py --encoded_dir /.latents` consumes this layout +directly. `SA3Trainer._stage_dataset` auto-detects the directory and feeds +`--encoded_dir` to the subprocess when latents are present. + +Cache invalidation lives in projects.py — any project mutation that could +desync the latents (commit, delete_clip, slice_clip) wipes the directory. +""" + +from __future__ import annotations + +import json +import os +import re +import signal +import subprocess +import sys +import threading +import time +from pathlib import Path +from typing import Any, Dict, Optional + +from app.backend.data.projects import project_path +from app.core.config import get_config +from utils.logger import get_logger + +logger = get_logger("PreEncoder") + + +# --- Per-project job registry ---------------------------------------------- + +_pre_encode_jobs: Dict[str, Dict[str, Any]] = {} +_pre_encode_jobs_lock = threading.Lock() +_pre_encode_processes: Dict[str, subprocess.Popen] = {} + + +def get_pre_encode_job(project_name: str) -> Dict[str, Any]: + """Snapshot of the current job state for a project. Always returns a + well-formed dict so the frontend can render against it without guards.""" + with _pre_encode_jobs_lock: + job = _pre_encode_jobs.get(project_name) + if job is None: + return _idle_job() + return dict(job) + + +def _idle_job() -> Dict[str, Any]: + return { + "state": "idle", # idle | queued | running | complete | failed | cancelled + "current": 0, # batch index (0-based) + "total": 0, # total batches (derived from clip count) + "current_file": "", + "error": None, + "started_at": None, + "finished_at": None, + "autoencoder": None, + } + + +# --- Autoencoder selection ------------------------------------------------- + +# Bind latents to a specific SA3 autoencoder. 
Latents from same-s only work +# with small-music / small-sfx DiTs; same-l latents only work with medium. +# For v1 we default to same-s (covers the most common base) and leave a +# manifest in .latents/_meta.json that training reads to verify +# compatibility. If a user trains against medium with same-s latents, +# SA3Trainer falls back to non-encoded training and logs a warning. +DEFAULT_AUTOENCODER = "same-s" + +# Audio length (samples per channel) the dataset pads/crops to before +# encoding. SA3's pre_encode_dataset.py defaults to ~285s at 44.1 kHz, which +# covers any training-time --duration up to that limit (and SA3 small caps +# at 120s anyway). Longer clips in the project will be cropped to this +# length during encoding — a documented limitation for v1. +DEFAULT_SAMPLE_SIZE = 12_582_912 + + +# --- Job lifecycle --------------------------------------------------------- + +def latents_dir(project_name: str) -> Path: + return project_path(project_name) / ".latents" + + +def latents_count(project_name: str) -> int: + d = latents_dir(project_name) + if not d.exists(): + return 0 + return sum( + 1 for p in d.glob("*.npy") + if p.name != "silence.npy" + ) + + +def latents_meta(project_name: str) -> Optional[Dict[str, Any]]: + """Read the manifest we drop alongside the .npy files.""" + p = latents_dir(project_name) / "_meta.json" + if not p.exists(): + return None + try: + return json.loads(p.read_text(encoding="utf-8")) + except Exception: + return None + + +def latents_match_base(project_name: str, base_model: str) -> bool: + """Whether the cached latents are compatible with the chosen base. + + same-s ↔ small-music / small-sfx (and their *-base variants). + same-l ↔ medium (and medium-base). 
def cancel_pre_encode(project_name: str) -> bool:
    """Cancel the in-flight pre-encode job for ``project_name``.

    The job record is flagged as cancelled first (under the registry lock),
    then the subprocess, if one is registered and still alive, is stopped
    with an escalating sequence: SIGINT (5s grace), ``terminate()`` (3s
    grace), ``kill()``.

    Returns:
        ``True`` if a queued/running job was flagged as cancelled, ``False``
        if there was no such job.
    """
    with _pre_encode_jobs_lock:
        job = _pre_encode_jobs.get(project_name)
        if not job or job.get("state") not in ("queued", "running"):
            return False
        job["state"] = "cancelled"
        job["cancelled"] = True

    proc = _pre_encode_processes.get(project_name)
    if proc is None:
        return True

    # Each stage is tried independently. A failure to deliver one signal
    # (for example SIGINT, which Popen.send_signal rejects on Windows) must
    # not prevent the harder stop that follows it.
    escalation = (
        (lambda: proc.send_signal(signal.SIGINT), 5),
        (proc.terminate, 3),
        (proc.kill, None),
    )
    for deliver, grace in escalation:
        if proc.poll() is not None:
            break  # already exited
        try:
            deliver()
        except Exception as exc:
            logger.warning("Failed to signal pre-encode subprocess: %s", exc)
            continue
        if grace is None:
            break
        try:
            proc.wait(timeout=grace)
        except subprocess.TimeoutExpired:
            continue  # still alive: escalate
        break
    return True
def _run_pre_encode(project_name: str, ae: str, sample_size: int) -> None:
    """Background-thread target that runs SA3's ``pre_encode_dataset.py``.

    Spawns the script as a subprocess, parses "Processing batch N" lines
    from its output to update job progress, and writes a ``_meta.json``
    manifest next to the latents on success. The job record always ends in
    one of ``complete``, ``failed`` or ``cancelled``.

    Args:
        project_name: Project whose committed clips are encoded.
        ae: Autoencoder id passed to the script via ``--model``.
        sample_size: Value passed to the script via ``--sample_size``.
    """
    batch_pat = re.compile(r"Processing batch (\d+)")
    process: Optional[subprocess.Popen] = None
    try:
        # Setup lives inside the try block so that a configuration or
        # filesystem error marks the job as failed instead of leaving it
        # stuck in "queued" forever.
        cfg = get_config()
        proj_dir = project_path(project_name)
        out_dir = latents_dir(project_name)
        out_dir.mkdir(parents=True, exist_ok=True)

        sa3_vendor = cfg.get_path("stable_audio_3")
        cmd = [
            sys.executable,
            str(sa3_vendor / "scripts" / "pre_encode_dataset.py"),
            "--model", ae,
            "--data_dir", str(proj_dir),
            "--output_path", str(out_dir),
            "--batch_size", "1",
            "--sample_size", str(int(sample_size)),
        ]

        env = os.environ.copy()
        pp = env.get("PYTHONPATH", "")
        env["PYTHONPATH"] = (
            f"{sa3_vendor}{os.pathsep}{pp}" if pp else str(sa3_vendor)
        )
        hub_dir = cfg.get_path("models_pretrained") / "sa3" / "hub"
        env["HF_HUB_CACHE"] = str(hub_dir)
        env["HUGGINGFACE_HUB_CACHE"] = str(hub_dir)
        env["TRANSFORMERS_CACHE"] = str(hub_dir)
        env["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        env["HF_HUB_OFFLINE"] = "1"
        env["TRANSFORMERS_OFFLINE"] = "1"

        # Move queued -> running atomically, but never overwrite a cancel
        # that arrived while the job was still queued.
        with _pre_encode_jobs_lock:
            job = _pre_encode_jobs.get(project_name)
            cancelled_early = bool(job and job.get("cancelled"))
            if job is not None and not cancelled_early:
                job["state"] = "running"
        if cancelled_early:
            _update_job(
                project_name,
                state="cancelled",
                finished_at=time.time(),
            )
            logger.info("Pre-encoding cancelled · project=%s", project_name)
            return

        logger.info(
            "Pre-encoding started · project=%s · autoencoder=%s · clips=%d · sample_size=%d",
            project_name, ae, get_pre_encode_job(project_name)["total"], sample_size,
        )

        process = subprocess.Popen(
            cmd,
            cwd=str(cfg.project_root),
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )
        _pre_encode_processes[project_name] = process

        # A cancel that landed between the check above and the registration
        # found no process to signal, so stop the subprocess here.
        if get_pre_encode_job(project_name).get("cancelled"):
            process.terminate()

        if process.stdout is not None:
            for line in process.stdout:
                line = line.rstrip()
                m = batch_pat.search(line)
                if m:
                    # Subprocess prints "Processing batch N" once per batch
                    # (and batch_size=1 → one batch per clip). N starts at 0.
                    _update_job(project_name, current=int(m.group(1)) + 1)

        rc = process.wait()

        # Check whether we got cancelled mid-flight.
        snapshot = get_pre_encode_job(project_name)
        if snapshot.get("cancelled"):
            _update_job(
                project_name,
                state="cancelled",
                finished_at=time.time(),
            )
            logger.info("Pre-encoding cancelled · project=%s", project_name)
            return

        if rc != 0:
            _update_job(
                project_name,
                state="failed",
                error=f"pre_encode_dataset.py exited with code {rc}",
                finished_at=time.time(),
            )
            logger.error(
                "Pre-encoding failed (exit %s) · project=%s",
                rc, project_name,
            )
            return

        # Success — write manifest so SA3Trainer can verify AE compatibility.
        manifest = {
            "autoencoder": ae,
            "sample_size": sample_size,
            "created_at": time.time(),
            "source_clip_count": snapshot.get("total", 0),
            "encoded_count": latents_count(project_name),
        }
        try:
            (out_dir / "_meta.json").write_text(
                json.dumps(manifest, indent=2), encoding="utf-8",
            )
        except Exception as exc:
            logger.warning("Failed to write latents manifest: %s", exc)

        _update_job(
            project_name,
            state="complete",
            current=manifest["encoded_count"],
            total=manifest["encoded_count"] or snapshot.get("total", 0),
            finished_at=time.time(),
        )
        logger.info(
            "Pre-encoding complete · project=%s · %d latent(s) · ae=%s",
            project_name, manifest["encoded_count"], ae,
        )

    except Exception as exc:
        _update_job(
            project_name,
            state="failed",
            error=str(exc),
            finished_at=time.time(),
        )
        logger.exception("Pre-encoding crashed for project=%s", project_name)
    finally:
        if process is not None:
            # Do not leave an orphaned encoder behind if we bailed out while
            # it was still running.
            if process.poll() is None:
                process.kill()
                process.wait()
            if process.stdout is not None:
                process.stdout.close()
            # Only drop our own handle. A newer job for the same project may
            # already have registered its process under this key.
            if _pre_encode_processes.get(project_name) is process:
                _pre_encode_processes.pop(project_name, None)
+ +A *project* is a folder under `/projects//` (or wherever +`FRAGMENTA_PROJECTS_DIR` points) holding audio + `.txt` sidecar pairs plus a +hidden `.project.json` with Fragmenta metadata. The on-disk folder is the +**committed** dataset — what training reads, what survives across app +restarts. + +The UI works against an **in-memory session** per loaded project. Prompt +edits, auto-annotate output, and just-ingested audio all live in memory +until the user explicitly persists them via: + + Save → write `.draft.json` (transient, hidden). Survives app restart + but is not the SA3 deliverable. + Commit → flush prompts to `.txt` sidecars, mark current audio as + committed in `.project.json`, delete `.draft.json`. + Discard → drop the in-memory session, delete `.draft.json`, remove any + audio files added since the last commit. + +See DATASET_PREP_REDESIGN.md for the full design and rationale. +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import shutil +import threading +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from app.backend.data.auto_annotator import AUDIO_EXTENSIONS, _iter_audio_files + +logger = logging.getLogger(__name__) + +PROJECT_METADATA_FILENAME = ".project.json" +PROJECT_DRAFT_FILENAME = ".draft.json" +DEFAULT_INGEST_MODE = "copy" # copy | symlink +INGEST_MODES = ("copy", "symlink") + +# SA3's prompting guide (vendor/stable-audio-3/docs/guides/prompting.md) +# distinguishes three generation modes — music, stems / solo instruments, +# and audio samples / SFX — each with its own AudioSparx-tag convention. +# We ship one preset per mode and let the user pick a single id; the rest +# is opinionated defaults. Each segment is rendered by apply_template's +# segment-drop semantics, so missing CLAP attributes never leave dangling +# punctuation. 
def get_projects_dir() -> Path:
    """Return the projects root directory, creating it when absent.

    A non-empty ``FRAGMENTA_PROJECTS_DIR`` environment variable takes
    precedence (with ``~`` expanded). Otherwise the root is the ``projects``
    folder under the configured ``user_data_dir``.
    """
    custom_root = os.environ.get("FRAGMENTA_PROJECTS_DIR")
    if not custom_root:
        # Imported lazily, and only on the path that needs the app config.
        from app.core.config import get_config
        projects_root = get_config().user_data_dir / "projects"
    else:
        projects_root = Path(custom_root).expanduser()
    projects_root.mkdir(parents=True, exist_ok=True)
    return projects_root
def _sidecar_for(audio_path: Path) -> Path:
    """Return the ``.txt`` sidecar path that pairs with ``audio_path``."""
    return audio_path.with_suffix(".txt")


def _read_sidecar(audio_path: Path) -> str:
    """Return the stripped prompt stored in the clip's sidecar.

    Returns ``""`` when the sidecar is missing or cannot be read. The text
    is decoded leniently: a leading UTF-8 byte-order mark is dropped rather
    than leaking into the prompt, and undecodable bytes become U+FFFD
    instead of raising, so one badly encoded sidecar cannot abort loading
    the whole project.
    """
    txt = _sidecar_for(audio_path)
    if not txt.exists():
        return ""
    try:
        return txt.read_text(encoding="utf-8-sig", errors="replace").strip()
    except OSError:
        return ""


def _write_sidecar(audio_path: Path, prompt: str) -> None:
    """Write ``prompt`` (or an empty string when falsy) to the sidecar as UTF-8."""
    _sidecar_for(audio_path).write_text(prompt or "", encoding="utf-8")
@dataclass
class ClipState:
    """A single clip inside an active project session.

    ``prompt`` holds the live in-memory text, while ``committed_prompt``
    mirrors the on-disk sidecar; the two are compared to derive dirtiness.
    ``parent`` names the clip this one was sliced from during the current
    session (in-memory only).
    """
    file_name: str
    path: str
    prompt: str = ""
    committed_prompt: str = ""
    committed: bool = True  # False if audio was added since last commit
    parent: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the clip, adding the derived ``dirty`` flag."""
        prompt_changed = self.prompt != self.committed_prompt
        return dict(
            file_name=self.file_name,
            path=self.path,
            prompt=self.prompt,
            committed_prompt=self.committed_prompt,
            committed=self.committed,
            dirty=prompt_changed,
            parent=self.parent,
        )


@dataclass
class ProjectSession:
    """In-memory working copy of one project, keyed by project name."""
    name: str
    clips: Dict[str, ClipState] = field(default_factory=dict)  # by file_name
    saved_at: Optional[float] = None  # last time .draft.json was written
    last_save_snapshot: Dict[str, str] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
    cancel_event: threading.Event = field(default_factory=threading.Event)
    lock: threading.Lock = field(default_factory=threading.Lock)
    # Memoised (peaks, duration) tuples, filled lazily by get_or_compute_peaks.
    peaks_cache: Dict[str, Tuple[List[float], float]] = field(default_factory=dict)
    # file_name -> duration in seconds.
    duration_cache: Dict[str, float] = field(default_factory=dict)

    def _draft_snapshot(self) -> Dict[str, str]:
        """Collect ``file_name -> prompt`` for clips whose prompt differs
        from the committed sidecar value."""
        pending: Dict[str, str] = {}
        for clip in self.clips.values():
            if clip.prompt != clip.committed_prompt:
                pending[clip.file_name] = clip.prompt
        return pending

    def has_dirty_prompts(self) -> bool:
        """True when at least one clip's prompt differs from its sidecar."""
        return bool(self._draft_snapshot())

    def has_uncommitted_files(self) -> bool:
        """True when at least one clip is not yet marked as committed."""
        for clip in self.clips.values():
            if not clip.committed:
                return True
        return False

    def has_unsaved_changes(self) -> bool:
        """True when the current prompt diffs differ from the last saved draft."""
        return not (self._draft_snapshot() == self.last_save_snapshot)

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the session for the UI.

        Besides the metadata and clip list, this reports whether the
        project's ``.latents`` folder currently holds any ``.npy`` files
        (``silence.npy`` is not counted) and the stored
        ``suppress_pre_encode_prompt`` flag.
        """
        by_name = sorted(self.clips.values(), key=lambda c: c.file_name)

        # Count the pre-encoded latents stored inside the project folder.
        cache_dir = project_path(self.name) / ".latents"
        encoded = 0
        if cache_dir.exists():
            for entry in cache_dir.glob("*.npy"):
                if entry.name != "silence.npy":
                    encoded += 1

        meta = self.metadata
        preset_id = meta.get("prompt_template_preset") or DEFAULT_PROMPT_TEMPLATE_PRESET
        preset_catalog = []
        for key, spec in PROMPT_TEMPLATE_PRESETS.items():
            preset_catalog.append({
                "id": key,
                "label": spec["label"],
                "description": spec["description"],
                "template": spec["template"],
            })
        is_dirty = self.has_dirty_prompts() or self.has_uncommitted_files()
        pending_files = [clip.file_name for clip in by_name if not clip.committed]

        return {
            "name": self.name,
            "created_at": meta.get("created_at"),
            "modified_at": meta.get("modified_at"),
            "committed_at": meta.get("committed_at"),
            "ingest_mode": meta.get("ingest_mode", DEFAULT_INGEST_MODE),
            "prompt_template_preset": preset_id,
            "prompt_template_presets": preset_catalog,
            "source_folders": list(meta.get("source_folders", [])),
            "saved_at": self.saved_at,
            "dirty": is_dirty,
            "has_unsaved_changes": self.has_unsaved_changes(),
            "uncommitted_files": pending_files,
            "clips": [clip.to_dict() for clip in by_name],
            "clip_count": len(self.clips),
            "latents_present": encoded > 0,
            "latents_count": encoded,
            "suppress_pre_encode_prompt": bool(meta.get("suppress_pre_encode_prompt")),
        }
`committed_prompt` is whatever's + # in the .txt sidecar today. + clips: Dict[str, ClipState] = {} + for audio_path in sorted(path.iterdir()): + if not audio_path.is_file(): + continue + if audio_path.suffix.lower() not in AUDIO_EXTENSIONS: + continue + committed_prompt = _read_sidecar(audio_path) + is_committed = audio_path.name in committed_files + clips[audio_path.name] = ClipState( + file_name=audio_path.name, + path=str(audio_path), + prompt=committed_prompt, + committed_prompt=committed_prompt, + committed=is_committed, + ) + + session = ProjectSession(name=name, clips=clips, metadata=metadata) + + # Overlay any draft prompts on top of committed values. + draft = _read_draft(name) + if draft: + for file_name, prompt in (draft.get("prompts") or {}).items(): + clip = session.clips.get(file_name) + if clip is not None: + clip.prompt = prompt + session.saved_at = draft.get("saved_at") + session.last_save_snapshot = dict(draft.get("prompts") or {}) + + with _sessions_lock: + # Race: another thread may have loaded concurrently. Use whichever + # got in first. 
def _stage_file(src: Path, dst: Path, mode: str) -> str:
    """Materialise ``src`` at ``dst`` according to the ingest ``mode``.

    Returns one of ``"skipped"`` (something, including a dangling symlink,
    already occupies ``dst``), ``"symlinked"`` or ``"copied"``. A failed
    symlink attempt is logged and falls back to a copy. Any mode other than
    ``"symlink"`` copies.
    """
    occupied = True if dst.exists() else dst.is_symlink()
    if occupied:
        return "skipped"

    if mode == "symlink":
        try:
            dst.symlink_to(src.resolve())
        except OSError as exc:
            logger.warning("Symlink failed for %s -> %s: %s; falling back to copy.", src, dst, exc)
        else:
            return "symlinked"

    shutil.copy2(src, dst)
    return "copied"
def ingest_folder(name: str, source_folder: Path, mode: str) -> Dict[str, Any]:
    """Stage the audio files found in ``source_folder`` into project ``name``.

    Files are placed on disk immediately via ``_stage_file``. Each newly
    staged file that the session does not already track is registered as an
    uncommitted clip with an empty prompt. The session metadata records the
    ingest mode and the resolved source folder.

    Raises:
        ValueError: ``mode`` is not in ``INGEST_MODES``, or the folder
            yields no audio files.
        FileNotFoundError: ``source_folder`` is missing or not a directory.

    Returns:
        Counts under ``copied``, ``symlinked``, ``skipped`` and ``added``
        (copied + symlinked).
    """
    if mode not in INGEST_MODES:
        raise ValueError(f"Invalid ingest mode: {mode}")
    if not (source_folder.exists() and source_folder.is_dir()):
        raise FileNotFoundError(f"Source folder not found: {source_folder}")

    session = _get_or_load_session(name)
    target_root = project_path(name)

    candidates = _iter_audio_files(source_folder)
    if not candidates:
        raise ValueError(f"No audio files found in {source_folder}")

    tally = {"copied": 0, "symlinked": 0, "skipped": 0}
    with session.lock:
        for origin in candidates:
            destination = target_root / origin.name
            outcome = _stage_file(origin, destination, mode)
            # Anything that is neither a copy nor a symlink counts as skipped.
            bucket = outcome if outcome in ("copied", "symlinked") else "skipped"
            tally[bucket] += 1

            if outcome == "skipped" or origin.name in session.clips:
                continue
            # Newly added file — uncommitted.
            session.clips[origin.name] = ClipState(
                file_name=origin.name,
                path=str(destination),
                prompt="",
                committed_prompt="",
                committed=False,
            )

        session.metadata["ingest_mode"] = mode
        resolved_source = str(source_folder.resolve())
        known_sources = session.metadata.setdefault("source_folders", [])
        if resolved_source not in known_sources:
            session.metadata["source_folders"].append(resolved_source)

    return {
        "copied": tally["copied"],
        "symlinked": tally["symlinked"],
        "skipped": tally["skipped"],
        "added": tally["copied"] + tally["symlinked"],
    }
def delete_clip(name: str, file_name: str) -> None:
    """Remove a clip immediately (audio + sidecar + session entry).

    The disk change happens right away rather than being buffered, so
    Discard cannot recover a deleted file. Cached peaks and duration for the
    clip are evicted, the in-memory ``committed_files`` list is updated, and
    the project's pre-encoded latents are invalidated.

    ``file_name`` must be a bare file name. Clips live directly in the
    project folder, so a name carrying path components can never refer to a
    clip; such names are ignored (with a warning) rather than joined onto
    the project path, where they could unlink files outside the project.
    """
    session = _get_or_load_session(name)
    proj_path = project_path(name)

    if file_name in ("", ".", "..") or Path(file_name).name != file_name:
        logger.warning(
            "Ignoring delete of clip with unsafe name %r in project '%s'",
            file_name, name,
        )
        return

    with session.lock:
        audio_path = proj_path / file_name
        for target in (audio_path, _sidecar_for(audio_path)):
            # `exists()` is False for a dangling symlink (a symlink-ingested
            # clip whose source vanished). Check `is_symlink()` too so the
            # stale link is removed, matching the test used by _stage_file.
            if target.exists() or target.is_symlink():
                target.unlink()
        session.clips.pop(file_name, None)
        # Evict any cached peaks for this file (regardless of N).
        cache_prefix = f"{file_name}:"
        for key in [k for k in session.peaks_cache if k.startswith(cache_prefix)]:
            del session.peaks_cache[key]
        session.duration_cache.pop(file_name, None)
        committed = session.metadata.get("committed_files") or []
        if file_name in committed:
            session.metadata["committed_files"] = [f for f in committed if f != file_name]
    # Invalidate latents — outside the lock so we don't block under FS I/O.
    _invalidate_latents(name)
def delete_project(name: str) -> None:
    """Permanently remove a project — folder, sidecars, drafts, session.

    Destructive: there is no recovery path. Caller should confirm with the
    user before invoking.

    Raises:
        FileNotFoundError: the project folder does not exist, or ``name`` is
            not a single folder name. Names such as ``""``, ``".."`` or
            anything containing a path separator would resolve to the
            projects root or to a location outside it, and must never reach
            ``rmtree``.
    """
    if (
        not isinstance(name, str)
        or name in ("", ".", "..")
        or Path(name).name != name
    ):
        raise FileNotFoundError(f"Project not found: {name}")

    proj_path = project_path(name)
    if not proj_path.exists():
        raise FileNotFoundError(f"Project not found: {name}")

    # Drop the session and signal its cancel event before removing the
    # folder, so in-flight work tied to the session is told to stop and no
    # new work starts against it.
    with _sessions_lock:
        existing = _sessions.pop(name, None)
    if existing is not None:
        existing.cancel_event.set()
    shutil.rmtree(proj_path, ignore_errors=True)
_TEMPLATE_VAR_RE = re.compile(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}")


def _render_value(name: str, raw: Any) -> str:
    """Stringify one template variable value.

    ``None`` renders as ``""``. Lists and tuples are joined with ``", "``
    after dropping ``None`` items and items that are blank once stripped.
    Anything else is ``str()``-ed and stripped, so ``0`` renders as ``"0"``.
    ``name`` is accepted for signature compatibility and is not used.
    """
    if raw is None:
        return ""
    if isinstance(raw, (list, tuple)):
        # Skip None explicitly: str(None) would otherwise put the literal
        # text "None" into the prompt.
        pieces = (str(item).strip() for item in raw if item is not None)
        return ", ".join(piece for piece in pieces if piece)
    return str(raw).strip()


def apply_template(template: str, attributes: Dict[str, Any]) -> str:
    """Render ``template`` segment by segment, dropping incomplete segments.

    The template is split on ``,``. A segment containing ``{var}``
    placeholders is kept only if every placeholder resolves to a non-empty
    value in ``attributes``; otherwise the whole segment is dropped so no
    dangling label or punctuation remains. Segments without placeholders are
    always kept. Surviving segments are joined with ``", "``.

    Args:
        template: Comma-separated template string; falsy yields ``""``.
        attributes: Variable values; ``None`` is treated as empty.
    """
    if not template:
        return ""
    values = attributes or {}
    out_segments: List[str] = []
    for raw_segment in template.split(","):
        segment = raw_segment.strip()
        if not segment:
            continue
        var_names = _TEMPLATE_VAR_RE.findall(segment)
        if var_names:
            resolved = {n: _render_value(n, values.get(n)) for n in var_names}
            if not all(resolved.values()):
                continue  # drop the segment — one of its vars is missing
            segment = _TEMPLATE_VAR_RE.sub(
                lambda m: resolved[m.group(1)],
                segment,
            )
        out_segments.append(segment)
    return ", ".join(out_segments)
+ session.metadata.pop("prompt_template", None) + _write_metadata(name, session.metadata) + return get_project(name) + + +# ---------- Waveform peaks -------------------------------------------------- + + +def _compute_peaks(audio_path: Path, n: int) -> Tuple[List[float], float]: + """Return N normalized peak amplitudes + duration in seconds. + + Reads N short blocks at evenly spaced offsets via soundfile.seek instead + of decoding the whole file. ~40x faster than librosa.load on a typical + 30s clip; bounded I/O regardless of file length (a 5-minute clip costs + the same as a 30s one). + + Falls back to a librosa-based decode for formats soundfile can't open + on this build (typically m4a/aac without ffmpeg-libsndfile). + """ + import numpy as np + try: + import soundfile as sf + with sf.SoundFile(str(audio_path)) as src: + total = src.frames + sr = src.samplerate + if total == 0: + return ([0.0] * n, 0.0) + duration = float(total / sr) + # ~6 buckets-worth of samples per probe gives stable peaks without + # devolving into "read the whole file." + block = max(256, total // (n * 6)) + peaks = np.zeros(n, dtype="float32") + for i in range(n): + center = int((i + 0.5) * total / n) + start = max(0, center - block // 2) + src.seek(start) + data = src.read(block, dtype="float32", always_2d=False) + if data.ndim > 1: + data = data.max(axis=1) + if len(data): + peaks[i] = float(np.abs(data).max()) + max_peak = float(peaks.max()) + if max_peak > 0: + peaks = peaks / max_peak + return (peaks.tolist(), duration) + except Exception as exc: + logger.debug("soundfile peak path failed for %s (%s); falling back to librosa", audio_path.name, exc) + + # Fallback: librosa.load handles every codec we register, at the cost of + # a full-file decode + resample. Slower but bulletproof. 
+ import librosa + y, sr = librosa.load(str(audio_path), sr=8000, mono=True) + if len(y) == 0: + return ([0.0] * n, 0.0) + duration = float(len(y) / sr) + chunks = np.array_split(y, n) + peaks = np.array([float(np.abs(c).max()) if len(c) else 0.0 for c in chunks]) + max_peak = peaks.max() + if max_peak > 0: + peaks = peaks / max_peak + return (peaks.tolist(), duration) + + +def get_or_compute_peaks( + session: ProjectSession, + file_name: str, + audio_path: Path, + n: int = 200, +) -> Tuple[List[float], float]: + """Memoized per-session peak computation. Cache key is `file_name:N`.""" + cache_key = f"{file_name}:{n}" + cached = session.peaks_cache.get(cache_key) + if cached is not None: + return cached + result = _compute_peaks(audio_path, n) + session.peaks_cache[cache_key] = result + return result + + +# ---------- Health checks --------------------------------------------------- + + +def _clip_duration_sec(audio_path: Path) -> Optional[float]: + """Cheap duration probe via soundfile.info() — header read, no decode.""" + try: + import soundfile as sf + info = sf.info(str(audio_path)) + if info.samplerate <= 0: + return None + return float(info.frames / info.samplerate) + except Exception: + return None + + +def compute_health( + name: str, + short_threshold_sec: float = 1.0, +) -> Dict[str, Any]: + """Per-clip checks that surface dataset problems before training. + + Note: we don't flag "too long" clips. The SA3 dataloader handles them + via random-crop per __getitem__ — long files just get sampled at + different windows across epochs. Slicing remains useful for annotation + granularity and CLAP's 10s window, but it's not a correctness issue. + + We also don't flag mixed sample rates or loudness: SA3 resamples every + file to its model rate (T.Resample in its dataset loader) and Fragmenta + enables SA3's built-in -16 LUFS VolumeNorm at train/pre-encode time, so + both are handled automatically downstream. 
+ + short_threshold_sec defaults to 1s — clips below this end up mostly + silence-padded into the training window. + """ + from collections import defaultdict + + # Single source of truth for what SA3's loader actually accepts. Fragmenta + # ingest accepts a wider set (.m4a, .aac) — those files would be silently + # skipped at train time, so we surface them here. + from app.core.training.sa3_lora_runner import SA3_AUDIO_EXTENSIONS + + session = _get_or_load_session(name) + with session.lock: + clips = list(session.clips.values()) + + empty_prompts: List[str] = [] + too_short: List[str] = [] + unsupported_format: List[str] = [] + prompt_groups: Dict[str, List[str]] = defaultdict(list) + + for c in clips: + if not (c.prompt or "").strip(): + empty_prompts.append(c.file_name) + else: + prompt_groups[c.prompt.strip().lower()].append(c.file_name) + + ext = Path(c.file_name).suffix.lower() + if ext not in SA3_AUDIO_EXTENSIONS: + unsupported_format.append(c.file_name) + + # Duration (header-only, ~free) — only used for the too-short check now. + dur = session.duration_cache.get(c.file_name) + if dur is None: + dur = _clip_duration_sec(Path(c.path)) + if dur is not None: + session.duration_cache[c.file_name] = dur + if dur is not None and dur < short_threshold_sec: + too_short.append(c.file_name) + + # --- Duplicate annotations: any non-empty prompt shared by 2+ clips. 
+ dup_groups = [files for files in prompt_groups.values() if len(files) > 1] + dup_files = sorted({f for group in dup_groups for f in group}) + + empty_prompts.sort() + too_short.sort() + unsupported_format.sort() + + return { + "total_clips": len(clips), + "empty_prompts": {"count": len(empty_prompts), "files": empty_prompts}, + "too_short": { + "count": len(too_short), + "threshold_sec": short_threshold_sec, + "files": too_short, + }, + "unsupported_format": { + "count": len(unsupported_format), + "accepted": sorted(SA3_AUDIO_EXTENSIONS), + "files": unsupported_format, + }, + "duplicate_annotations": { + "count": len(dup_files), + "group_count": len(dup_groups), + "files": dup_files, + }, + } + + +# ---------- Slicing --------------------------------------------------------- + + +def slice_clip( + name: str, + file_name: str, + target_sec: float, + overlap_sec: float, + strategy: str, +) -> Dict[str, Any]: + """Split one clip into N children. Disk-level — happens immediately. + + The parent file (and its sidecar) is deleted. Each child: + - lives in the project folder as `__NNN.wav` + - inherits the parent's in-memory prompt verbatim + - is uncommitted (so Discard rolls it back) + - keeps `parent=` in its session state + + Discard cannot recover the parent file from children — same rule as + delete_clip. Commit makes the slice permanent. + """ + from app.backend.data.slicing import plan_slices, write_slices + + session = _get_or_load_session(name) + proj_path = project_path(name) + audio_path = proj_path / file_name + + if not audio_path.exists(): + raise FileNotFoundError(f"Clip not on disk: {file_name}") + + plans = plan_slices(audio_path, target_sec, overlap_sec, strategy) + if len(plans) <= 1: + raise ValueError( + f"{file_name} is shorter than the target duration " + f"({target_sec:.1f}s); nothing to slice." 
+ ) + + stem = audio_path.stem + children = write_slices(audio_path, plans, proj_path, stem) + if not children: + raise RuntimeError("Slice produced no children — check the audio file.") + + with session.lock: + parent_clip = session.clips.get(file_name) + inherited_prompt = parent_clip.prompt if parent_clip else "" + + # Remove the parent from session + disk. + session.clips.pop(file_name, None) + for key in list(session.peaks_cache): + if key.startswith(f"{file_name}:"): + del session.peaks_cache[key] + session.duration_cache.pop(file_name, None) + sidecar = _sidecar_for(audio_path) + if audio_path.exists(): + audio_path.unlink() + if sidecar.exists(): + sidecar.unlink() + committed = session.metadata.get("committed_files") or [] + if file_name in committed: + session.metadata["committed_files"] = [f for f in committed if f != file_name] + + # Register children as uncommitted clips with parent linkage. + for child_path in children: + session.clips[child_path.name] = ClipState( + file_name=child_path.name, + path=str(child_path), + prompt=inherited_prompt, + committed_prompt="", + committed=False, + parent=file_name, + ) + + # Slicing replaces the parent's audio with N children → any cached + # latents reference the deleted parent and are now misaligned. + _invalidate_latents(name) + + return { + "parent": file_name, + "children": [ + {"file_name": p.name, "start_sec": pl.start_sec, "end_sec": pl.end_sec} + for p, pl in zip(children, plans) + ], + "project": get_project(name), + } diff --git a/app/backend/data/slicing.py b/app/backend/data/slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..706b0108c45f389ca21cd4096729d8fbbcf572de --- /dev/null +++ b/app/backend/data/slicing.py @@ -0,0 +1,183 @@ +"""Audio slicing for the Dataset Workbench. + +Splits one audio file into N children. Three strategies: + + hard — uniform cuts every `target_duration` seconds. 
+ transient — uniform anchor points, each snapped to the nearest onset + (librosa.onset.onset_detect). + silence — uniform anchor points, each snapped to the nearest low-RMS + window (cleanest splice between phrases). + +All three honor `overlap_sec`, applied as a head-overlap on every child +after the first: child i starts at (end of child i-1) - overlap_sec. + +Writes WAV regardless of source format (lossless, no codec deps). Parent +prompt is inherited verbatim; the user edits children individually after. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import List, Literal, Tuple + +logger = logging.getLogger(__name__) + +SliceStrategy = Literal["hard", "transient", "silence"] +VALID_STRATEGIES = ("hard", "transient", "silence") + +# How far a snap is allowed to move from the uniform anchor. Beyond this we +# just take the anchor — better a tidy cut than a wildly off-target chunk. +SNAP_WINDOW_FRAC = 0.35 + + +@dataclass +class SlicePlan: + """One child's location inside the parent. Times are in seconds.""" + index: int # 1-based + start_sec: float + end_sec: float + + +def _uniform_anchors(duration_sec: float, target_sec: float, overlap_sec: float) -> List[Tuple[float, float]]: + """Return [(start, end), ...] 
for uniform cuts, before any snapping.""" + if target_sec <= 0: + raise ValueError("target_duration must be positive") + if overlap_sec < 0 or overlap_sec >= target_sec: + raise ValueError("overlap_sec must be >= 0 and < target_duration") + step = target_sec - overlap_sec + anchors: List[Tuple[float, float]] = [] + start = 0.0 + while start < duration_sec - 0.05: # don't emit a sub-50ms tail + end = min(start + target_sec, duration_sec) + anchors.append((start, end)) + if end >= duration_sec: + break + start += step + return anchors + + +def _snap_to_onsets(anchors: List[Tuple[float, float]], y, sr: int, target_sec: float) -> List[Tuple[float, float]]: + """Snap each cut boundary to the nearest detected onset within a window.""" + import librosa + import numpy as np + onsets = librosa.onset.onset_detect(y=y, sr=sr, units="time", backtrack=True) + if len(onsets) == 0: + return anchors + snap_window = target_sec * SNAP_WINDOW_FRAC + out: List[Tuple[float, float]] = [] + for i, (s, e) in enumerate(anchors): + if i > 0: + # Snap the start (= previous end) to nearest onset within window. + candidates = onsets[(onsets >= s - snap_window) & (onsets <= s + snap_window)] + if len(candidates): + s = float(min(candidates, key=lambda t: abs(t - s))) + out.append((s, e)) + # Stitch ends to match next start so no gap/overlap drift creeps in. + for i in range(len(out) - 1): + s, _ = out[i] + next_s, _ = out[i + 1] + out[i] = (s, next_s + (target_sec * 0.0)) # next_s alone — overlap is in next_s already from caller + return out + + +def _snap_to_silence(anchors: List[Tuple[float, float]], y, sr: int, target_sec: float) -> List[Tuple[float, float]]: + """Snap each cut boundary to the lowest-RMS frame within a window.""" + import librosa + import numpy as np + # Frame-level RMS at ~20ms hop. 
+ hop = max(1, sr // 50) + rms = librosa.feature.rms(y=y, frame_length=hop * 2, hop_length=hop)[0] + if len(rms) == 0: + return anchors + frame_times = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop) + snap_window = target_sec * SNAP_WINDOW_FRAC + out: List[Tuple[float, float]] = [] + for i, (s, e) in enumerate(anchors): + if i > 0: + mask = (frame_times >= s - snap_window) & (frame_times <= s + snap_window) + if mask.any(): + local_idx = int(np.argmin(rms[mask])) + # Map masked-index back to absolute time. + masked_times = frame_times[mask] + s = float(masked_times[local_idx]) + out.append((s, e)) + return out + + +def plan_slices( + audio_path: Path, + target_sec: float, + overlap_sec: float, + strategy: SliceStrategy, +) -> List[SlicePlan]: + """Compute the (start, end) for each child without writing anything yet.""" + if strategy not in VALID_STRATEGIES: + raise ValueError(f"Unknown strategy: {strategy}") + import librosa + # Use mono for boundary detection only; final write uses the original. + y, sr = librosa.load(str(audio_path), sr=22050, mono=True) + duration = float(len(y) / sr) if len(y) else 0.0 + if duration <= 0: + raise ValueError(f"{audio_path.name} has zero duration") + if duration < target_sec: + # Single child = the whole file. Skip the slice loop entirely. + return [SlicePlan(index=1, start_sec=0.0, end_sec=duration)] + + anchors = _uniform_anchors(duration, target_sec, overlap_sec) + if strategy == "transient": + anchors = _snap_to_onsets(anchors, y, sr, target_sec) + elif strategy == "silence": + anchors = _snap_to_silence(anchors, y, sr, target_sec) + + return [ + SlicePlan(index=i + 1, start_sec=s, end_sec=e) + for i, (s, e) in enumerate(anchors) + ] + + +def write_slices( + audio_path: Path, + plans: List[SlicePlan], + out_dir: Path, + stem: str, +) -> List[Path]: + """Write children as `__001.wav`, `__002.wav`, ... in `out_dir`. + + Uses soundfile for lossless WAV write at the source's native sample rate. 
+ Skips names that already exist on disk to avoid clobbering. + """ + import soundfile as sf + import numpy as np + + info = sf.info(str(audio_path)) + sr = info.samplerate + total_frames = info.frames + written: List[Path] = [] + width = max(3, len(str(len(plans)))) + + with sf.SoundFile(str(audio_path)) as src: + for plan in plans: + start_frame = max(0, int(plan.start_sec * sr)) + end_frame = min(total_frames, int(plan.end_sec * sr)) + if end_frame <= start_frame: + logger.warning("Skipping empty slice %s [%.2f-%.2f]", plan.index, plan.start_sec, plan.end_sec) + continue + src.seek(start_frame) + data = src.read(end_frame - start_frame, dtype="float32", always_2d=True) + + child_name = f"{stem}__{plan.index:0{width}d}.wav" + child_path = out_dir / child_name + if child_path.exists(): + # Don't silently overwrite; bump the suffix until free. + k = 2 + while True: + candidate = out_dir / f"{stem}__{plan.index:0{width}d}_{k}.wav" + if not candidate.exists(): + child_path = candidate + break + k += 1 + sf.write(str(child_path), data, sr, subtype="PCM_16") + written.append(child_path) + return written diff --git a/app/core/audio/midi_input.py b/app/core/audio/midi_input.py new file mode 100644 index 0000000000000000000000000000000000000000..fce221f5222afd73c0f706f3146bc9424aed41d1 --- /dev/null +++ b/app/core/audio/midi_input.py @@ -0,0 +1,172 @@ +"""Native MIDI input. + +Reads hardware MIDI via python-rtmidi (CoreMIDI on macOS, WinMM on Windows, +ALSA on Linux) so MIDI works regardless of the web engine the OS gives us — +WKWebView has no Web MIDI, WebView2's is flaky. Same pattern as the native +Ableton Link binding in link_sync.py: wrap an optional native lib and no-op +gracefully if it isn't importable. + +The backend owns the *transport* only: it enumerates input ports, opens one, +and broadcasts incoming messages to subscribers (drained by the SSE endpoint +in app.py). 
All mapping / learn / takeover logic stays in the frontend +MidiContext — it just consumes these events instead of Web MIDI. +""" +from __future__ import annotations + +import ctypes +import glob +import os +import queue +import sys +import threading +from typing import Any, Dict, List, Optional + + +def _preload_bundled_jack() -> None: + """Work around a broken RPATH in python-rtmidi's manylinux wheel. + + The wheel bundles libjack as `python_rtmidi/libjack-.so.*`, but the + `_rtmidi` extension's RPATH points at a directory that doesn't exist + (`$ORIGIN/../python_rtmidi.` — note the stray trailing dot), so the loader + can't find it and `import rtmidi` dies with + `ImportError: libjack-.so...: cannot open shared object file`. + + The bundled lib's soname matches the extension's DT_NEEDED exactly, so + dlopen'ing it with RTLD_GLOBAL first lets the loader satisfy the dependency + from the already-loaded object. Doing it here (rather than patching the + venv) survives a pip reinstall and needs no patchelf/root. Linux-only; a + no-op everywhere the glob finds nothing. + """ + if not sys.platform.startswith("linux"): + return + for base in sys.path: + if not base or not os.path.isdir(base): + continue + for lib in glob.glob(os.path.join(base, "python_rtmidi*", "libjack-*.so*")): + try: + ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL) + except OSError: + pass + + +try: + import rtmidi # python-rtmidi + _RTMIDI_OK = True +except Exception: # pragma: no cover - import guard + # Most likely the bundled-libjack RPATH bug — preload it and retry once. 
+ try: + _preload_bundled_jack() + import rtmidi + _RTMIDI_OK = True + except Exception: + rtmidi = None + _RTMIDI_OK = False + +_lock = threading.Lock() +_midi_in: Any = None # the open rtmidi.MidiIn, or None +_current_port: Optional[str] = None # name of the open port, or None +_subscribers: List["queue.Queue"] = [] + + +def is_available() -> bool: + """True if the native MIDI backend is importable.""" + return _RTMIDI_OK + + +def list_inputs() -> List[Dict[str, Any]]: + """Enumerate input ports. `id` is the port name (stable across index + shuffles); `index` is its current rtmidi index.""" + if not _RTMIDI_OK: + return [] + mi = rtmidi.MidiIn() + try: + names = mi.get_ports() + finally: + mi.delete() + return [{"id": name, "name": name, "index": i} for i, name in enumerate(names)] + + +def current_port() -> Optional[str]: + with _lock: + return _current_port + + +def _on_message(event, _data=None) -> None: + """rtmidi callback (runs on its own thread). `event` is (message, delta). + Broadcast the raw status/data bytes so the frontend can reuse its existing + Web-MIDI-shaped dispatcher unchanged.""" + message, _delta = event + payload = {"data": list(message)} + with _lock: + subs = list(_subscribers) + for q in subs: + try: + q.put_nowait(payload) + except queue.Full: + pass # slow consumer — drop rather than block the MIDI thread + + +def close_input() -> None: + global _midi_in, _current_port + with _lock: + mi = _midi_in + _midi_in = None + _current_port = None + if mi is not None: + try: + mi.cancel_callback() + except Exception: + pass + try: + mi.close_port() + except Exception: + pass + try: + mi.delete() + except Exception: + pass + + +def open_input(port_id: Optional[str]) -> bool: + """Open the input port whose name == port_id. A falsy port_id just closes + the current port. 
Returns True on success (or on a pure close).""" + if not _RTMIDI_OK: + return False + close_input() + if not port_id: + return True + + mi = rtmidi.MidiIn() + idx = None + for i, name in enumerate(mi.get_ports()): + if name == port_id: + idx = i + break + if idx is None: + mi.delete() + return False + + mi.open_port(idx) + # Drop sysex / timing-clock / active-sensing so the stream stays to the + # control messages the mapper cares about (CC + notes). + mi.ignore_types(sysex=True, timing=True, active_sense=True) + mi.set_callback(_on_message) + + global _midi_in, _current_port + with _lock: + _midi_in = mi + _current_port = port_id + return True + + +def subscribe() -> "queue.Queue": + q: "queue.Queue" = queue.Queue(maxsize=512) + with _lock: + _subscribers.append(q) + return q + + +def unsubscribe(q: "queue.Queue") -> None: + with _lock: + if q in _subscribers: + _subscribers.remove(q) diff --git a/app/core/config.py b/app/core/config.py index 603c6367e357e32a8a50f54bf4f99036b7b45561..23b91c7a2653bbaedec061b95ed482ed20f6c463 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Dict, Any, Optional import json + class ProjectConfig: def __init__(self, project_root: Optional[Path] = None) -> None: @@ -18,11 +19,11 @@ class ProjectConfig: self.user_data_dir = Path.home() / "Library" / "Application Support" / "FragmentaDesktop" else: self.user_data_dir = Path.home() / ".local" / "share" / "FragmentaDesktop" - + self.user_data_dir.mkdir(parents=True, exist_ok=True) print(f"Running in frozen mode. 
Project root: {self.project_root}") print(f"User data directory: {self.user_data_dir}") - + else: self.frozen = False if project_root is None: @@ -37,123 +38,52 @@ class ProjectConfig: break else: project_root = config_file_dir - + self.project_root: Path = Path(project_root).resolve() self.user_data_dir = self.project_root fine_tuned_override = os.environ.get("FRAGMENTA_FINE_TUNED_DIR") fine_tuned_dir = Path(fine_tuned_override) if fine_tuned_override else self.user_data_dir / "models" / "fine_tuned" - data_override = os.environ.get("FRAGMENTA_DATA_DIR") - data_dir = Path(data_override) if data_override else self.user_data_dir / "data" + # Scratch area for browser folder uploads (/api/upload-folder). The + # SA2-era "data" dataset directory is gone in 0.2.0 — datasets are now + # Dataset Workbench projects under projects/. + uploads_override = os.environ.get("FRAGMENTA_UPLOADS_DIR") + uploads_dir = Path(uploads_override) if uploads_override else self.user_data_dir / "uploads" self.paths: Dict[str, Path] = { "models": self.user_data_dir / "models", "models_config": self.user_data_dir / "models" / "config", "models_pretrained": self.user_data_dir / "models" / "pretrained", "models_fine_tuned": fine_tuned_dir, - "data": data_dir, + "uploads": uploads_dir, "logs": self.user_data_dir / "logs", "output": self.user_data_dir / "output", "application": self.project_root, "backend": self.project_root / "app" / "backend", "frontend": self.project_root / "app" / "frontend", - "stable_audio_tools": self.project_root / "vendor" / "stable-audio-tools", - "loraw_vendor": self.project_root / "vendor" / "loraw_vendor", + "stable_audio_3": self.project_root / "vendor" / "stable-audio-3", "venv": self.project_root / "venv", } self._ensure_directories() - self.model_configs: Dict[str, Dict[str, str] - ] = self._load_model_configs() + # The SA3 catalog lives in app/core/model_manager.py. 
This dict stays + # empty; it's retained only because to_dict()/print_paths() and the + # config validator still reference it. + self.model_configs: Dict[str, Dict[str, str]] = {} def _ensure_directories(self) -> None: for path_name, path in self.paths.items(): - if path_name.endswith(('_fine_tuned', 'data')): + if path_name.endswith(('_fine_tuned', 'uploads')): path.mkdir(parents=True, exist_ok=True) - def _load_model_configs(self) -> Dict[str, Dict[str, str]]: - - return { - "stable-audio-open-1.0": { - "config": str(self.paths["models_config"] / "model_config.json"), - "ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-model.safetensors") - }, - "stable-audio-open-small": { - "config": str(self.paths["models_config"] / "model_config_small.json"), - "ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-small-model.safetensors") - }, - "custom": { - "config": str(self.paths["models_config"] / "model_config_small.json"), - "ckpt": str(self.paths["models_pretrained"] / "stable-audio-open-small-model.safetensors") - } - } - def get_path(self, path_name: str) -> Path: if path_name not in self.paths: raise ValueError(f"Unknown path name: {path_name}") return self.paths[path_name] - def get_model_config(self, model_name: str) -> Dict[str, str]: - if model_name not in self.model_configs: - raise ValueError(f"Unknown model: {model_name}") - return self.model_configs[model_name] - - def get_dataset_config_path(self) -> str: - return str(self.paths["models_config"] / "dataset-config.json") - - def get_custom_metadata_path(self) -> str: - return str(self.project_root / "vendor" / "stable-audio-tools" / "custom_metadata.py") - - def get_metadata_json_path(self) -> str: - return str(self.paths["data"] / "metadata.json") - - def update_dataset_config(self) -> None: - from app.backend.data.simple_audio_processor import SimpleAudioProcessor - - try: - processor = SimpleAudioProcessor( - model_config_path=self.paths["models_config"] / "model_config.json" - 
) - - result = processor.create_dataset_config( - input_dir=self.paths["data"], - output_dir=self.paths["data"] - ) - - target_config = self.paths["models_config"] / "dataset-config.json" - with open(target_config, 'w') as f: - json.dump(result["dataset_config"], f, indent=4) - - print(f"Updated dataset config: {target_config}") - print(f"Points to {result['file_count']} original audio files") - print(f"Sample size: {result['sample_size']} samples ({result['sample_size']/result['sample_rate']:.1f}s)") - print(f"Random cropping during training (correct!)") - - except Exception as e: - print(f"Failed to update dataset config: {e}") - print("Falling back to basic dataset config...") - - dataset_config: Dict[str, Any] = { - "dataset_type": "audio_dir", - "datasets": [ - { - "id": "fine_tune_data", - "path": str(self.paths["data"]), - "custom_metadata_module": "custom_metadata" - } - ], - "random_crop": True - } - - config_path = self.paths["models_config"] / "dataset-config.json" - with open(config_path, 'w') as f: - json.dump(dataset_config, f, indent=4) - - print(f"Updated fallback dataset config: {config_path}") - def to_dict(self) -> Dict[str, Any]: return { "project_root": str(self.project_root), diff --git a/app/core/generation/audio_generator.py b/app/core/generation/audio_generator.py index 947d562596c5039e3d89db57e31faad393fbc7a0..21c0fa7f55d3e7a91ceaa6dd410b39f9e8508842 100644 --- a/app/core/generation/audio_generator.py +++ b/app/core/generation/audio_generator.py @@ -1,519 +1,536 @@ -import torch -import soundfile as sf -import numpy as np -from pathlib import Path -from typing import Dict, Any, Optional, List, Tuple -import logging +"""SA3 inference engine. + +Thin wrapper around stable_audio_3.StableAudioModel.from_pretrained() that +caches the loaded model between requests (eviction on model_id change), +auto-detects the device, and writes 44.1 kHz stereo int16 WAV. 
+ +Cancellation is wired via `request_stop()` for API parity, but SA3's +generate() doesn't expose a per-step callback yet — the flag is checked +between calls, not inside them. A finer-grained cancel hook is a Phase +3.1 follow-up. +""" +import os +import platform import re import sys import threading import time import warnings -from datetime import datetime - +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple -class GenerationStopped(Exception): - """Raised by the per-step callback when a stop has been requested.""" - pass +import numpy as np +import soundfile as sf +import torch +from utils.logger import get_logger + +logger = get_logger("AudioGenerator") + + +# Live progress from the SA3 sampler. SA3's `model.generate(**sampler_kwargs)` +# forwards `callback=fn` into the sampler, which fires it per ODE step with +# `{'i': step_index, ...}`. We mirror that into this dict so the frontend can +# poll real progress instead of a fake ticker. Reset on each new generation. +_generation_state: Dict[str, Any] = { + "is_generating": False, + # idle | loading | sampling | decoding | complete | failed + "phase": "idle", + "step": 0, + "total_steps": 0, + "progress": 0, # 0-100, derived + "batch_index": 0, + "batch_total": 0, + "started_at": None, + "ended_at": None, + "error": None, +} +_generation_state_lock = threading.Lock() + + +def get_generation_progress() -> Dict[str, Any]: + """Snapshot of the current generation's live progress. Cheap to call.""" + with _generation_state_lock: + return dict(_generation_state) + + +def _set_progress(**kwargs: Any) -> None: + """Merge fields into _generation_state under the lock. 
Recomputes + `progress` automatically when step/total_steps land in the same update.""" + with _generation_state_lock: + _generation_state.update(kwargs) + total = int(_generation_state.get("total_steps") or 0) + step = int(_generation_state.get("step") or 0) + _generation_state["progress"] = ( + int(round(100 * step / total)) if total > 0 else 0 + ) + + +def _reset_progress() -> None: + with _generation_state_lock: + _generation_state.update({ + "is_generating": False, "phase": "idle", + "step": 0, "total_steps": 0, "progress": 0, + "batch_index": 0, "batch_total": 0, + "started_at": None, "ended_at": None, "error": None, + }) + +# Vendored SA3 lives at /vendor/stable-audio-3 — put it on sys.path so +# `import stable_audio_3` resolves without a global pip install. +_SA3_VENDOR = Path(__file__).resolve().parents[3] / "vendor" / "stable-audio-3" +if str(_SA3_VENDOR) not in sys.path: + sys.path.insert(0, str(_SA3_VENDOR)) + + +# model_id -> (sa3_name passed to StableAudioModel.from_pretrained, +# "user-visible or base" tag, max duration seconds). +# Kept in sync manually with _SA3_CATALOG in app/core/model_manager.py. +_MODEL_INFO: Dict[str, Tuple[str, str, int]] = { + "sa3-small-music": ("small-music", "post", 120), + "sa3-small-sfx": ("small-sfx", "post", 120), + "sa3-medium": ("medium", "post", 380), + "sa3-small-music-base": ("small-music-base", "base", 120), + "sa3-small-sfx-base": ("small-sfx-base", "base", 120), + "sa3-medium-base": ("medium-base", "base", 380), +} -def _slugify_prompt(text: str, max_len: int = 40) -> str: - s = re.sub(r'[^a-zA-Z0-9]+', '_', text.strip().lower()) - s = re.sub(r'_+', '_', s).strip('_') - return s[:max_len] or 'untitled' -sys.path.append( - str(Path(__file__).parent.parent.parent.parent / "vendor" / "stable-audio-tools")) -# LoRAW lives at /vendor/loraw_vendor; expose its `loraw` package for inference. 
-sys.path.append( - str(Path(__file__).parent.parent.parent.parent / "vendor" / "loraw_vendor")) +class GenerationStopped(Exception): + """Raised when an in-flight generation is interrupted by a stop request.""" -warnings.filterwarnings( - "ignore", - message=r"pkg_resources is deprecated as an API.*", - category=UserWarning, -) +def _slugify(text: str, max_len: int = 40) -> str: + s = re.sub(r"[^a-zA-Z0-9_-]+", "_", text or "") + return s[:max_len].strip("_").lower() or "audio" -from stable_audio_tools.models.utils import load_ckpt_state_dict -from stable_audio_tools.inference.generation import generate_diffusion_cond -from stable_audio_tools.models import create_model_from_config -from loraw.network import create_lora_from_config -logger = logging.getLogger(__name__) +def _autodetect_device() -> str: + """cuda → mps → cpu, with FRAGMENTA_FORCE_DEVICE override.""" + override = os.environ.get("FRAGMENTA_FORCE_DEVICE") + if override: + return override + if torch.cuda.is_available(): + return "cuda" + if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + return "mps" + return "cpu" class AudioGenerator: - def __init__(self, config): - self.config = config - self.model = None - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.current_model_name = None - self.current_model_path = None - self.current_model_key = None - self.is_distilled_small = False - self.is_fine_tuned = False - # LoRA state. `lora` holds the LoRAWrapper instance when one is active; - # `_active_lora_path` / `_active_lora_multiplier` are used (along with - # the base-model identifier) in `current_model_key` so the cache - # invalidates whenever the LoRA selection changes — forcing a fresh - # base reload because LoRAW's `activate()` is not reversible in-place. 
- self.lora = None - self._active_lora_path = None - self._active_lora_multiplier = 1.0 - self._stop_event = threading.Event() - logger.info(f"Using device: {self.device}") - - def _apply_lora(self, lora_path: str, lora_config: Dict[str, Any], multiplier: float = 1.0): - """Wrap the currently-loaded base model with a LoRA from LoRAW. - - Caller is responsible for ensuring the base model is fresh (no prior - LoRA injected) — typically by routing through `generate_audio`'s cache - invalidation, which reloads the base when the LoRA selection changes. - """ - if self.model is None: - raise RuntimeError("Base model must be loaded before applying a LoRA") - # torch.compile wraps in OptimizedModule, which prefixes named_modules() - # with `_orig_mod/`. LoRAW's saved state has no such prefix (training - # didn't compile). Operate on the underlying module so scan_model keys - # match the checkpoint exactly. The compiled wrapper still dispatches - # forward through this same module, so the LoRA stays active. - target = getattr(self.model, "_orig_mod", self.model) - full_config = { - "model_type": getattr(target, "model_type", "diffusion_cond"), - "lora": lora_config, - } - self.lora = create_lora_from_config(full_config, target) - state = torch.load(lora_path, map_location=self.device) - self.lora.load_weights(state, multiplier=multiplier) - self.lora.activate() - self._active_lora_path = lora_path - self._active_lora_multiplier = multiplier - logger.info(f"LoRA applied: {Path(lora_path).name} (multiplier={multiplier})") + """One-model warm cache. Reload only when model_id changes.""" + def __init__(self, config: Any) -> None: + self.config = config + self.model: Any = None + self._model_id: Optional[str] = None + self._device: Optional[str] = None + self._stop_requested: bool = False + # Tracks LoRAs currently injected into self.model. List of + # {"path": str, "strength": float}. Empty when no LoRAs are active. 
+ self._loaded_loras: list = [] + + # --- cooperative cancel --------------------------------------------------- def request_stop(self) -> bool: - """Signal the in-flight diffusion loop (if any) to abort at the next step.""" - already_set = self._stop_event.is_set() - self._stop_event.set() - return not already_set - - def load_local_base_model(self, model_name: str = "stable-audio-open-small") -> bool: - try: - logger.info(f"Loading local base model: {model_name}") - - self.current_model_name = model_name - - from stable_audio_tools.models.factory import create_model_from_config - from stable_audio_tools.models.utils import load_ckpt_state_dict - if "small" in model_name: - config_file = "model_config_small.json" - else: - config_file = "model_config.json" - self.is_distilled_small = "small" in model_name.lower() - self.is_fine_tuned = False - - config_path = Path(__file__).parent.parent.parent.parent / "models" / "config" / config_file - logger.info(f"Using config file: {config_path}") - - with open(config_path, 'r') as f: - import json - model_config = json.load(f) - - self.model = create_model_from_config(model_config) - if model_name == 'stable-audio-open-small': - model_file_name = 'stable-audio-open-small-model.safetensors' - elif model_name == 'stable-audio-open-1.0': - model_file_name = 'stable-audio-open-model.safetensors' - else: - model_file_name = f"{model_name}-model.safetensors" - - model_file = Path(__file__).parent.parent.parent.parent / "models" / "pretrained" / model_file_name - self.current_model_path = str(model_file) - logger.info(f"Loading weights from: {model_file}") - - if not model_file.exists(): - raise FileNotFoundError(f"Local model file not found: {model_file}") - - state_dict = load_ckpt_state_dict(str(model_file)) - self.model.load_state_dict(state_dict, strict=False) - - self.model = self.model.to(self.device) - self.model.eval() - self.model.requires_grad_(False) - if self.device.startswith("cuda"): - self.model = 
torch.compile(self.model, mode="reduce-overhead") - - logger.info("Local base model loaded successfully") - return True - - except Exception as e: - logger.error(f"Failed to load local base model: {e}") + if self._stop_requested: return False + self._stop_requested = True + return True - def load_model(self, model_path: Optional[Path] = None) -> bool: - try: - print(f"Loading model from {model_path}") - - if model_path is None: - return self.load_local_base_model("stable-audio-open-small") - else: - safetensors_files = list(model_path.glob("*.safetensors")) - if safetensors_files: - unwrapped_path = str(safetensors_files[0]) - print(f"Found safetensors file: {unwrapped_path}") - return self.load_unwrapped_model(unwrapped_path) - else: - print(f"No safetensors files found in {model_path}, using local base model") - return self.load_local_base_model("stable-audio-open-small") - - except Exception as e: - print(f"Failed to load model: {e}") - return False - - def load_unwrapped_model(self, unwrapped_model_path: str, config_file: str = None) -> bool: - try: - print(f"Loading unwrapped model from {unwrapped_model_path}") - - self.current_model_path = unwrapped_model_path - - from stable_audio_tools.models.factory import create_model_from_config - from stable_audio_tools.models.utils import load_ckpt_state_dict - if config_file is None: - config_file = "model_config_small.json" - self.is_distilled_small = "small" in config_file.lower() - - - metadata_path = Path(unwrapped_model_path).parent.parent / "training_metadata.json" - self.is_fine_tuned = metadata_path.exists() - if self.is_fine_tuned: - logger.info( - f"Detected fine-tuned model via {metadata_path}; " - f"using full diffusion sampler recipe instead of distilled 8-step pingpong" + # --- model load ----------------------------------------------------------- + def _ensure_model( + self, + model_id: str, + device: Optional[str] = None, + half: bool = True, + ) -> None: + if model_id not in _MODEL_INFO: + raise 
ValueError(f"Unknown SA3 model_id: {model_id}") + sa3_name, _kind, _max_dur = _MODEL_INFO[model_id] + + if model_id in ("sa3-medium", "sa3-medium-base"): + # Medium normally requires Flash Attention 2 for its long-form (up + # to 380s) sliding-window attention. FRAGMENTA_MEDIUM_NO_FLASH=1 is + # the Path-B validation switch: it lets medium load WITHOUT + # flash_attn and fall back to PyTorch-native attention + # (flex_attention -> chunked-halo SDPA -> masked SDPA; see + # transformer.apply_attn). Output is math-equivalent, but VRAM is + # higher and sampling slower at long durations. Off by default, so + # the shipped behaviour is unchanged until the fallback is validated. + allow_no_flash = os.environ.get("FRAGMENTA_MEDIUM_NO_FLASH") == "1" + try: + import flash_attn # noqa: F401 + have_flash = True + except ImportError as err: + have_flash = False + _flash_err = err + + if not have_flash and not allow_no_flash: + if platform.system() == "Windows": + raise RuntimeError( + "sa3-medium requires Flash Attention 2, which doesn't " + "have Windows wheels. Use sa3-small-music / sa3-small-sfx, " + "run Fragmenta via Docker on WSL2, or set " + "FRAGMENTA_MEDIUM_NO_FLASH=1 to run on the (slower, " + "higher-memory) PyTorch attention fallback." + ) from _flash_err + raise RuntimeError( + "sa3-medium needs Flash Attention 2 (flash_attn) but the " + f"current install is unusable: {_flash_err}.\n" + "Pick the wheel matching your torch+ABI+Python+CUDA from\n" + " https://github.com/Dao-AILab/flash-attention/releases\n" + "and install with `pip install --no-deps `. " + "See the note next to flash-attn in requirements.txt for an example.\n" + "Or set FRAGMENTA_MEDIUM_NO_FLASH=1 to use the PyTorch " + "attention fallback." + ) from _flash_err + + if not have_flash: + logger.warning( + "sa3-medium loading WITHOUT Flash Attention 2 " + "(FRAGMENTA_MEDIUM_NO_FLASH=1). Using the PyTorch-native " + "attention fallback — expect higher VRAM and slower sampling " + "at long durations. 
Validate memory headroom before " + "generating long-form (up to 380s) clips." ) - config_path = Path(__file__).parent.parent.parent.parent / \ - "models" / "config" / config_file - print(f"Using config file: {config_path}") - - with open(config_path, 'r') as f: + device = device or _autodetect_device() + if ( + self.model is not None + and self._model_id == model_id + and self._device == device + ): + return # warm cache hit + + if self.model is not None: + del self.model + self.model = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Two layouts to support during the unification transition: + # 1. Canonical (post-Phase 5c): HF cache layout rooted at + # /models/pretrained/sa3/hub/. model_manager sets + # HF_HUB_CACHE to that path, so StableAudioModel.from_pretrained + # finds files there without going to ~/.cache/huggingface. + # 2. Legacy: /models/pretrained/sa3// flat layout + # from earlier downloads. We fall back to direct load so + # pre-existing users don't have to re-download. + # + # Defense-in-depth: re-force the HF cache vars here too. model_manager + # sets them at construction, but if generation is reached via an + # alternate code path or the env was clobbered later, we still + # guarantee resolution into /sa3/hub/. + hub_dir = self.config.get_path("models_pretrained") / "sa3" / "hub" + hf_env_keys = ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE", + "TRANSFORMERS_CACHE", "HF_HUB_OFFLINE") + prev_env = {k: os.environ.get(k) for k in hf_env_keys} + os.environ["HF_HUB_CACHE"] = str(hub_dir) + os.environ["HUGGINGFACE_HUB_CACHE"] = str(hub_dir) + os.environ["TRANSFORMERS_CACHE"] = str(hub_dir) + os.environ["HF_HUB_OFFLINE"] = "1" + # huggingface_hub captures HF_HUB_CACHE and HF_HUB_OFFLINE as + # module-level constants AT IMPORT TIME. The Flask backend imports + # huggingface_hub (transitively, via model_manager.py) before we ever + # set these env vars, so the constants point at ~/.cache/huggingface/ + # and offline=False. 
Setting os.environ now has no effect on already- + # captured constants. We have to monkey-patch them directly. + # Same trick we used for the CLAP loader. + prev_hub_constants = {} + try: + import huggingface_hub.constants as _hf_const + prev_hub_constants = { + "HF_HUB_CACHE": _hf_const.HF_HUB_CACHE, + "HF_HUB_OFFLINE": _hf_const.HF_HUB_OFFLINE, + } + _hf_const.HF_HUB_CACHE = str(hub_dir) + _hf_const.HF_HUB_OFFLINE = True + except Exception: + _hf_const = None + try: + try: + from stable_audio_3 import StableAudioModel + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.model = StableAudioModel.from_pretrained( + sa3_name, device=device, model_half=half, + ) + except (FileNotFoundError, OSError) as primary_err: + # HF cache miss — fall back to flat layout. + legacy_dir = self.config.get_path("models_pretrained") / "sa3" / model_id + config_path = legacy_dir / "model_config.json" + ckpt_path = legacy_dir / "model.safetensors" + if not (config_path.exists() and ckpt_path.exists()): + raise FileNotFoundError( + f"Checkpoint '{model_id}' not found in HF cache " + f"({os.environ.get('HF_HUB_CACHE')}) or legacy flat " + f"layout ({legacy_dir}). Download it from the " + f"Checkpoint Manager." + ) from primary_err import json - model_config = json.load(f) + with open(config_path) as fh: + model_config = json.load(fh) + from stable_audio_3.loading_utils import load_diffusion_cond + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + inner = load_diffusion_cond( + model_config, str(ckpt_path), + device=device, model_half=half, + ) + inner.use_lora = False + inner.lora_names = [] + self.model = StableAudioModel(inner, model_config, device, half) + finally: + for k, v in prev_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + # Restore the patched constants so we don't permanently alter + # global huggingface_hub state for anything else in-process. 
+ if _hf_const is not None and prev_hub_constants: + _hf_const.HF_HUB_CACHE = prev_hub_constants["HF_HUB_CACHE"] + _hf_const.HF_HUB_OFFLINE = prev_hub_constants["HF_HUB_OFFLINE"] + self._model_id = model_id + self._device = device + + # --- LoRA stack ----------------------------------------------------------- + def _apply_loras(self, loras: list) -> None: + """Inject the given LoRA stack into self.model (idempotent). + + loras: [{"path": str, "strength": float}, ...] + + Strategy: + * Same paths in same order → just update strengths in place. + * Different paths → remove all, load fresh. + """ + if self.model is None: + return + + new_paths = [l["path"] for l in loras] + cur_paths = [l["path"] for l in self._loaded_loras] - self.model = create_model_from_config(model_config) + if new_paths == cur_paths: + # Path-set unchanged; only strengths may have moved. + for i, l in enumerate(loras): + self.model.set_lora_strength(l["strength"], lora_index=i) + self._loaded_loras = list(loras) + return - state_dict = load_ckpt_state_dict(unwrapped_model_path) - self.model.load_state_dict(state_dict, strict=False) + # Path-set changed. Remove any currently loaded, then load the new set. + if cur_paths: + try: + from stable_audio_3.models.lora import remove_lora + # SA3 applies LoRA to the DiffusionCond's DiT (.model) and + # conditioner (.conditioner) — mirror StableAudioModel's own + # set_lora_strength which iterates both submodules. + # `self.model` is StableAudioModel; `self.model.model` is the + # inner DiffusionCond. + # + # remove_lora() strips *every* LoRA parametrization in one + # pass. We use it instead of remove_lora_by_index(..., 0) in a + # loop: removal does NOT renumber the remaining adapters, so + # repeatedly popping index 0 only ever clears the first LoRA + # and leaves indices 1..n-1 stranded — stale adapters then + # contaminate every later generation with a different stack. 
+ inner = self.model.model + remove_lora(inner.model) + remove_lora(inner.conditioner) + except Exception as exc: + # If removal fails (e.g. an upstream API change), force a + # base-model reload so we don't carry stale adapters. KEEP + # _model_id intact — _ensure_model needs it to know what to + # reload. (Previous code zeroed it; the reload then raised + # "Unknown SA3 model_id: None".) + logger.warning( + "LoRA removal failed (%s); reloading base model %s", + exc, self._model_id, + ) + self.model = None - self.model = self.model.to(self.device) - self.model.eval() - self.model.requires_grad_(False) + if self.model is None and self._model_id is not None: + # Forced full reload (only if remove failed above). + self._ensure_model(self._model_id, device=self._device, half=True) - if self.device.startswith("cuda"): - self.model = torch.compile(self.model, mode="reduce-overhead") + if loras: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.model.load_lora(new_paths) + for i, l in enumerate(loras): + self.model.set_lora_strength(l["strength"], lora_index=i) - print(f"AUDIO GENERATOR: Unwrapped model loaded successfully") - return True + self._loaded_loras = list(loras) - except Exception as e: - print(f"Failed to load unwrapped model: {e}") + def set_lora_strength(self, index: int, strength: float) -> bool: + """Live-update one slot's strength. 
Returns False if index invalid.""" + if not self.model or index < 0 or index >= len(self._loaded_loras): return False + self.model.set_lora_strength(float(strength), lora_index=index) + self._loaded_loras[index]["strength"] = float(strength) + return True + # --- public entry --------------------------------------------------------- def generate_audio( self, prompt: str, - model_path: Optional[Path] = None, - unwrapped_model_path: Optional[str] = None, - config_file: Optional[str] = None, + *, + model_id: str, duration: float = 10.0, - cfg_scale: float = 7.0, - steps: int = 250, + steps: Optional[int] = None, + cfg_scale: Optional[float] = None, seed: int = -1, - output_path: Optional[Path] = None, - batch_index: int = 1, - batch_total: int = 1, - loop_mode: bool = False, - lora_path: Optional[str] = None, - lora_config: Optional[Dict[str, Any]] = None, - lora_multiplier: float = 1.0, + negative_prompt: Optional[str] = None, + batch_size: int = 1, + device: Optional[str] = None, + half: bool = True, + chunked_decode: Optional[bool] = None, + loop_mode: bool = False, # bars-mode passthrough + loras: Optional[list] = None, # [{path, strength}, ...] 
+ # Phase 7: audio-to-audio + inpainting ----------------------------- + init_audio_path: Optional[str] = None, + init_noise_level: float = 1.0, + inpaint_audio_path: Optional[str] = None, + inpaint_starts: Optional[list] = None, # list[float], seconds + inpaint_ends: Optional[list] = None, + # Phase 7: seamless looping ---------------------------------------- + loop_stitch: Optional[str] = None, # "inpaint" | "crossfade" | None + loop_bars: Optional[int] = None, + loop_bpm: Optional[float] = None, + **_ignored_legacy_kwargs: Any, ) -> Path: - print(f"\nAUDIO GENERATOR: generate_audio called") - print(f" - Prompt: '{prompt}'") - print(f" - Duration: {duration}s") - if lora_path: - print(f" - LoRA: {lora_path} (×{lora_multiplier})") - - # The cache key includes LoRA selection so the base reloads whenever - # the LoRA changes (LoRAW's activate() is not reversible in-place; - # the only safe way to drop or swap a LoRA is to reload the base). - lora_signature = (lora_path, lora_multiplier) if lora_path else (None, 1.0) - if unwrapped_model_path: - target_key = ('unwrapped', str(unwrapped_model_path), lora_signature) - elif model_path: - target_key = ('path', str(model_path), lora_signature) - else: - target_key = ('default', 'stable-audio-open-small', lora_signature) - - if self.model is not None and self.current_model_key == target_key: - print(f"AUDIO GENERATOR: Reusing already-loaded model") - else: - print(f"AUDIO GENERATOR: Loading new model") - # Reset any prior LoRA state — load_*_model rebuilds self.model fresh. 
- self.lora = None - self._active_lora_path = None - self._active_lora_multiplier = 1.0 - - if unwrapped_model_path: - print(f"AUDIO GENERATOR: Loading unwrapped model from {unwrapped_model_path}") - if not self.load_unwrapped_model(unwrapped_model_path, config_file): - raise ValueError(f"Failed to load unwrapped model from {unwrapped_model_path}") - elif model_path: - model_path_str = str(model_path) - print(f"AUDIO GENERATOR: Checking model path: {model_path_str}") - - if "stable-audio-open-small" in model_path_str: - print(f"AUDIO GENERATOR: Loading local small base model") - if not self.load_local_base_model("stable-audio-open-small"): - raise ValueError("Failed to load local small base model") - elif "stable-audio-open-model" in model_path_str: - print(f"AUDIO GENERATOR: Loading local large base model") - if not self.load_local_base_model("stable-audio-open-1.0"): - raise ValueError("Failed to load local large base model") - else: - print(f"AUDIO GENERATOR: Loading fine-tuned model from {model_path}") - if not self.load_model(model_path): - raise ValueError(f"Failed to load model from {model_path}") - else: - print(f"AUDIO GENERATOR: Loading default local small base model") - if not self.load_local_base_model("stable-audio-open-small"): - raise ValueError("Failed to load default local base model") - - # Attach the LoRA (if requested) onto the freshly loaded base. - if lora_path: - if not lora_config: - raise ValueError("lora_config required when lora_path is set") - self._apply_lora(lora_path, lora_config, lora_multiplier) - - self.current_model_key = target_key - - print(f"AUDIO GENERATOR: Model loaded successfully") - - self._stop_event.clear() - - def _stop_callback(state): - if self._stop_event.is_set(): - raise GenerationStopped("Stop requested mid-diffusion") - - try: - # Three recipes, picked by what the loaded weights actually are: - # 1. Original distilled small — rectified-flow + CFG distillation - # baked in. Requires pingpong / 8 steps / CFG 1.0. 
- # 2. Fine-tuned small — distillation destroyed by SFT but the - # objective is still rectified-flow, so the sampler name must - # come from the rectified-flow family (euler|rk4|dpmpp|pingpong), - # NOT from the v-diffusion family. Use external CFG. - # 3. Large model — standard v-diffusion, accepts dpmpp-3m-sde. - use_distilled_recipe = self.is_distilled_small and not self.is_fine_tuned - if use_distilled_recipe: - effective_sampler = "pingpong" - effective_steps = 8 - effective_cfg = 1.0 - sigma_kwargs = {} - elif self.is_distilled_small: - effective_sampler = "dpmpp" - effective_steps = steps - effective_cfg = cfg_scale - sigma_kwargs = {"sigma_max": 1.0} - else: - effective_sampler = "dpmpp-3m-sde" - effective_steps = steps - effective_cfg = cfg_scale - sigma_kwargs = {"sigma_min": 0.03, "sigma_max": 1000} - - print(f"Generating audio for prompt: '{prompt}'") - recipe_note = "" - if use_distilled_recipe: - recipe_note = " (distilled small overrides applied)" - elif self.is_fine_tuned and self.is_distilled_small: - recipe_note = " (fine-tuned small: rectified-flow dpmpp + external CFG)" - print( - f"Duration: {duration}s, CFG scale: {effective_cfg}, " - f"Steps: {effective_steps}, Sampler: {effective_sampler}" - + recipe_note - ) - requested_sample_size = int(duration * self.model.sample_rate) - max_sample_size = None - try: - max_sample_size = self.model.sample_size - except AttributeError: - if hasattr(self.model, 'model') and hasattr(self.model.model, 'sample_size'): - max_sample_size = self.model.model.sample_size - else: - config_path = Path(__file__).parent.parent.parent.parent / "models" / "config" - if hasattr(self, 'current_model_name') and self.current_model_name: - if 'small' in self.current_model_name: - config_file = config_path / "model_config_small.json" - else: - config_file = config_path / "model_config.json" - else: - if hasattr(self, 'current_model_path') and self.current_model_path: - model_file = Path(self.current_model_path) - if 
model_file.exists(): - file_size_gb = model_file.stat().st_size / (1024**3) - if file_size_gb < 2.0: - config_file = config_path / "model_config_small.json" - else: - config_file = config_path / "model_config.json" - else: - config_file = config_path / "model_config_small.json" - else: - config_file = config_path / "model_config_small.json" - - if config_file.exists(): - with open(config_file, 'r') as f: - import json - config_data = json.load(f) - max_sample_size = config_data.get('sample_size', 44100 * 10) - else: - max_sample_size = 44100 * 10 - if max_sample_size and requested_sample_size > max_sample_size: - print(f"Requested duration {duration}s exceeds model maximum. Truncating.") - requested_sample_size = max_sample_size - duration = requested_sample_size / self.model.sample_rate - - if seed == -1: - import numpy as np - seed = np.random.randint(0, 2**32 - 1, dtype=np.int64) - - print(f"Using seed: {seed}") - - if loop_mode and max_sample_size: - song_seconds = max(int(duration), - int(max_sample_size / self.model.sample_rate)) - else: - song_seconds = int(duration) - - conditioning = [{ - "prompt": prompt, - "seconds_start": 0, - "seconds_total": song_seconds, - }] - - device = next(self.model.parameters()).device - print(f"Using device: {device}") - - with warnings.catch_warnings(): - # Known torchsde float-boundary chatter from dpmpp-3m-sde. - warnings.filterwarnings( - "ignore", - message=r"Should have tb<=t1 but got tb=.*", - category=UserWarning, - module=r"torchsde\._brownian\.brownian_interval", + self._stop_requested = False + if self._stop_requested: # honour pre-call stop + raise GenerationStopped() + + # `loop_stitch` / `loop_bars` / `loop_bpm` are accepted for API + # compatibility but ignored — the seamless-loop pipeline was + # removed because user A/B testing showed it degraded audio + # quality on every prompt class. We deliver raw model output. 
+ + _set_progress( + is_generating=True, phase="loading", + step=0, total_steps=0, error=None, + started_at=time.time(), ended_at=None, + ) + + self._ensure_model(model_id, device=device, half=half) + self._apply_loras(loras or []) + + init_audio = self._load_audio(init_audio_path) if init_audio_path else None + inpaint_audio = self._load_audio(inpaint_audio_path) if inpaint_audio_path else None + + _, kind, max_dur = _MODEL_INFO[model_id] + is_base = (kind == "base") + + # Defaults differ by model kind. Post-trained models distilled CFG + # away; we force cfg=1.0 there even if the caller overrides. + effective_steps = int(steps) if steps else (50 if is_base else 8) + effective_cfg = float(cfg_scale) if (cfg_scale is not None and is_base) else ( + 7.0 if is_base else 1.0 + ) + + duration = float(min(max(1.0, float(duration)), float(max_dur))) + + target_samples = int(round(duration * 44100)) + gen_duration = duration + total_steps_logical = effective_steps + + if self._stop_requested: # one more check before the heavy call + raise GenerationStopped() + + # Sampler callback — fires per ODE step. Also gives us a cheap + # cancellation hook: raising mid-callback aborts the sampler. 
+ def _cb(info: Dict[str, Any]) -> None: + if self._stop_requested: + raise GenerationStopped() + i = info.get("i") + if isinstance(i, int): + _set_progress(step=min(i + 1, total_steps_logical)) + + _set_progress(phase="sampling", total_steps=int(total_steps_logical), step=0) + + gen_kwargs = dict( + prompt=prompt, + negative_prompt=negative_prompt or None, + duration=gen_duration, + steps=effective_steps, + cfg_scale=effective_cfg, + seed=int(seed), + batch_size=int(batch_size), + chunked_decode=chunked_decode, + callback=_cb, + ) + if init_audio is not None: + gen_kwargs["init_audio"] = init_audio + gen_kwargs["init_noise_level"] = float(init_noise_level) + if inpaint_audio is not None: + gen_kwargs["inpaint_audio"] = inpaint_audio + if inpaint_starts is not None and len(inpaint_starts) > 0: + # SA3 accepts a single float or a list for multi-region. + gen_kwargs["inpaint_mask_start_seconds"] = ( + list(inpaint_starts) if len(inpaint_starts) > 1 else float(inpaint_starts[0]) ) - warnings.filterwarnings( - "ignore", - message=r"Should have ta>=t0 but got ta=.*", - category=UserWarning, - module=r"torchsde\._brownian\.brownian_interval", - ) - - audio = generate_diffusion_cond( - model=self.model, - steps=effective_steps, - cfg_scale=effective_cfg, - conditioning=conditioning, - batch_size=1, - sample_size=requested_sample_size, - seed=seed, - device=str(device), - sampler_type=effective_sampler, - callback=_stop_callback, - **sigma_kwargs, + if inpaint_ends is not None and len(inpaint_ends) > 0: + gen_kwargs["inpaint_mask_end_seconds"] = ( + list(inpaint_ends) if len(inpaint_ends) > 1 else float(inpaint_ends[0]) ) - print(f"Generation complete, audio shape: {audio.shape}") - - from einops import rearrange - audio = rearrange(audio, "b d n -> d (b n)").to(torch.float32) - audio = audio / audio.abs().max() - audio_int16 = (audio.clamp(-1, 1) * 32767).to(torch.int16).cpu() - - if output_path is None: - output_dir = Path(__file__).parent.parent.parent.parent / "output" 
- output_dir.mkdir(exist_ok=True) - ts = datetime.now().strftime('%Y%m%d_%H%M%S') - slug = _slugify_prompt(prompt) - suffix = f"_{batch_index}" if batch_total > 1 else "" - output_path = output_dir / f"fragmenta_{ts}_{slug}{suffix}.wav" - - self.save_audio(audio_int16, output_path, self.model.sample_rate) - - print(f"AUDIO GENERATOR: Generation complete") - print(f" - Output file: {output_path}") - print(f" - Output file size: {output_path.stat().st_size} bytes") - - return output_path - + try: + audio = self.model.generate(**gen_kwargs) + # audio: torch.Tensor[B, channels=2, samples] in [-1, 1] @ 44.1 kHz except GenerationStopped: - print("AUDIO GENERATOR: Generation stopped by user request") + _set_progress(phase="idle", is_generating=False, ended_at=time.time()) raise - except Exception as e: - print(f"AUDIO GENERATOR: Error during generation: {str(e)}") - import traceback - traceback.print_exc() + except Exception as exc: + _set_progress(phase="failed", is_generating=False, + error=str(exc), ended_at=time.time()) raise - finally: - self._stop_event.clear() - - def generate_batch( - self, - prompts: List[str], - duration: float = 10.0, - cfg_scale: float = 6.0, - steps: int = 250, - seed: int = -1, - output_dir: Optional[Path] = None - ) -> List[Path]: - results = [] - - for i, prompt in enumerate(prompts): - print(f"Generating audio {i+1}/{len(prompts)}") - current_seed = seed if seed != -1 else seed + i - output_path = None - if output_dir: - output_dir.mkdir(exist_ok=True, parents=True) - output_path = output_dir / f"generated_{i+1:03d}.wav" - - try: - output_path = self.generate_audio( - prompt=prompt, - duration=duration, - cfg_scale=cfg_scale, - steps=steps, - seed=current_seed, - output_path=output_path - ) - results.append(output_path) - - except Exception as e: - print(f"Failed to generate audio for prompt {i+1}: {e}") - results.append(None) - - return results + # Seamless-loop processing (quantize, inpaint, crossfade) was + # removed: the user 
A/B-compared raw SA3 output against the full + # pipeline and confirmed the post-processing made every prompt + # worse — silence-at-start on percussion, smeared transients, + # off-grid anchoring. We now deliver the raw model output. The + # `loop_stitch` / `loop_bars` / `loop_bpm` parameters are still + # accepted from the frontend for API compatibility but are + # ignored. Performance-Bars looping will have an audible click + # at the wrap point and multi-channel stacks will not be + # sample-aligned — both acceptable trade-offs vs. the artifacts + # the quantizer was introducing. + _set_progress(phase="decoding", step=total_steps_logical) + try: + return self._finalize(audio, prompt=prompt, model_id=model_id) + finally: + _set_progress(phase="complete", is_generating=False, + step=total_steps_logical, ended_at=time.time()) - def save_audio(self, audio: torch.Tensor, output_path: Path, sample_rate: int): - output_path.parent.mkdir(exist_ok=True, parents=True) - audio_np = audio.detach().cpu().transpose(0, 1).numpy() - sf.write(str(output_path), audio_np, sample_rate, subtype="PCM_16") + # --- audio loader (a2a + inpaint inputs) ---------------------------------- + @staticmethod + def _load_audio(path: str): + """Load a wav/mp3/flac into the (sample_rate, tensor) tuple SA3 expects. - def get_model_info(self) -> Dict[str, Any]: - if self.model is None: - return {"status": "no_model_loaded"} - - return { - "status": "loaded", - "sample_rate": self.model.sample_rate, - "device": str(self.device), - "model_type": getattr(self.model, 'model_type', 'unknown'), - "io_channels": getattr(self.model, 'io_channels', 'unknown') - } + Returns a stereo float32 tensor of shape (channels, samples). Mono + inputs are duplicated to stereo (SA3 expects 2 channels); ≥3-channel + inputs are truncated to the first 2. 
+ """ + import torchaudio + wav, sr = torchaudio.load(str(path)) # (channels, samples), float32 + if wav.shape[0] == 1: + wav = wav.repeat(2, 1) + elif wav.shape[0] > 2: + wav = wav[:2] + return int(sr), wav + + # --- output -------------------------------------------------------------- + def _finalize(self, audio: torch.Tensor, *, prompt: str, model_id: str) -> Path: + audio = audio.detach().clamp_(-1.0, 1.0).cpu() + if audio.ndim != 3: + raise RuntimeError(f"Unexpected SA3 output shape {tuple(audio.shape)}") + first = audio[0] # [C, samples] + pcm = (first.numpy() * 32767.0).astype(np.int16).T # → [samples, C] + + out_dir = self.config.get_path("output") + out_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d_%H%M%S") + out_path = out_dir / f"{ts}_{model_id}_{_slugify(prompt)}.wav" + sf.write(str(out_path), pcm, 44100, subtype="PCM_16") + return out_path diff --git a/app/core/generation/audio_post_process.py b/app/core/generation/audio_post_process.py index bb6980766341d24e39688a6625269f5e0042233e..075c5545e655d0031ab3945b79351f9e6b0cfc34 100644 --- a/app/core/generation/audio_post_process.py +++ b/app/core/generation/audio_post_process.py @@ -1,9 +1,40 @@ """Beat-align and tempo-conform a generated WAV to a target BPM and bar count. + +DEPRECATED — this entire module is superseded by ``app/core/loop_quantizer/`` +(see ``task_1.md`` and ``AUDIT.md`` §9 "Scheduled for removal"). The legacy +``align_to_grid`` / ``align_for_loop`` path and the gated ``_stage_a_v2`` path +both live here until the new module passes acceptance; once it does, every +public symbol below is removed and the file deletes itself. Do NOT add new +callers, do NOT extend the v1 helpers, and prefer adding work directly under +``app/core/loop_quantizer/`` for any new behaviour. 
+ +SA3 generates at the exact requested duration via variable-length flow +matching, so the post-processor's role is **drift correction**, not length +control: it only nudges the audio when librosa detects that the realised +tempo has drifted from the target. The tempo-conform gate is intentionally +tight — `|rate - 1| > 5%` AND `rate in [0.85, 1.15]` — so we never warp +audibly when SA3 was already close. + +Pipeline (in order): + 1. Detect tempo + beat grid via librosa (with target BPM as prior). + 2. Head-trim to the first detected beat (or first onset as fallback), + followed by a 3 ms equal-power fade-in to mask the trim seam. + 3. Tempo-conform via phase-vocoder time-stretch, ONLY when the detected + tempo drifts >5% from target AND the resulting stretch lies inside + the safe range [0.85, 1.15]. Outside this window we leave the audio + alone and let the user re-roll. + 4. End-anchored truncation: snap the cut to the nearest detected beat + within ±½ beat of the mathematical target sample count, so loops + don't end mid-note. Followed by an 8 ms equal-power fade-out so the + loop seam doesn't click. + 5. Zero-pad if the audio came out shorter than the target. """ from __future__ import annotations import logging +import os +import warnings from pathlib import Path from typing import Optional, Tuple @@ -14,35 +45,429 @@ import soundfile as sf logger = logging.getLogger(__name__) -# Safe range for phase-vocoder time-stretching. Wider than the previous -# [0.7, 1.4] so we actually warp in more cases — librosa's vocoder produces -# acceptable audio across this range for music, and the alternative -# (no warp at all) drifts off the grid completely on loop. +# DEPRECATED: flag goes away with the v1/v2 split (AUDIT.md §9d). +def beatsync_v2_enabled() -> bool: + """Feature gate for the hardened Stage A pipeline (sample-exact length, + first-transient-to-zero alignment, transient-preserving stretch). 
+ + Off by default: with the flag unset, every Stage A function takes its + legacy code path, so Bars-mode output is byte-identical to pre-flag + builds and Seconds mode (which never enters Stage A at all) is unaffected. + Enable with ``FRAGMENTA_BEATSYNC_V2=1``. + """ + return os.environ.get("FRAGMENTA_BEATSYNC_V2", "0").strip().lower() in ( + "1", "true", "yes", "on", + ) + + +# DEPRECATED: flag goes away with the v1/v2 split (AUDIT.md §9d). +def _warp_enabled() -> bool: + """Per-beat (Ableton 'Beats'-style) warp gate — OFF by default. + + The warp is only as reliable as librosa's per-beat detection; on real audio + a single mis-detected beat scrambles the groove. Anchor + exact-crop already + lands real loops at ~3 ms, so the warp is opt-in for experimentation only. + Enable with ``FRAGMENTA_BEATSYNC_WARP=1``.""" + return os.environ.get("FRAGMENTA_BEATSYNC_WARP", "0").strip().lower() in ( + "1", "true", "yes", "on", + ) + + +# Liberal module-default range for `_best_stretch_rate`. Kept wide so any +# future force-warp caller has room; the bars-mode drift-correction path +# (`align_to_grid`) overrides with tighter bounds below. _STRETCH_SAFE_MIN = 0.6 _STRETCH_SAFE_MAX = 1.7 +# Bars-mode drift correction. SA3 hits the requested duration exactly via +# variable-length generation, so the post-processor only kicks in when the +# detected tempo of the generated audio drifts from the requested target. +# Tight gates avoid audible vocoder artifacts when SA3 was already close. +_BARS_MODE_STRETCH_MIN = 0.85 +_BARS_MODE_STRETCH_MAX = 1.15 +_BARS_MODE_DEADBAND = 0.05 + +# Loop-mode (Phase 7) is stricter — a 5% tempo slack compounds visibly when +# multiple loop channels run side-by-side, even though loop iteration +# lengths are sample-exact. 0.5% is below librosa's noise floor for beat +# detection on rhythmic content, so we won't be acting on noise, but we +# WILL correct anything detectable that the looser bars-mode would skip. 
+_LOOP_MODE_DEADBAND = 0.005 + +# Fade durations applied at trim points. Kept very short — the fade is +# click-prevention, not a perceptible ramp. Performance Mode loops these +# clips, and longer fades audibly "duck" the loop seam. +_HEAD_FADE_SEC = 0.003 # mask click at the trimmed head +_TAIL_FADE_SEC = 0.003 # mask click at a mid-note truncation; skipped on beats + +# Trailing-silence detection. SA3 occasionally pads a generation with low- +# level tail; the post-processor used to keep that and fade over it, which +# produced perceptible "silence + duck" at the loop point. +_SILENCE_THRESHOLD_DB = -50.0 # anything below is silence +_SILENCE_WINDOW_SEC = 0.05 # RMS window granularity +_SILENCE_TAIL_KEEP_SEC = 0.010 # leave a tiny natural decay + +# v2 first-transient search: a downbeat lands within the first bar or two of +# generated content, so we never hunt past this window for the musical "1". +_V2_TRANSIENT_SEARCH_SEC = 1.5 +_V2_STRONG_RATIO = 0.30 # candidate must reach 30% of peak +_V2_RISE_RATIO = 0.15 # rising-edge threshold for refinement +_V2_REFINE_WIN_SEC = 0.03 # +/- window for sample-accurate refine + +# Grid confidence. librosa's beat tracker emits a tempo for ANY input — on +# ambient/textural content it is essentially noise (measured: 49-161 BPM on a +# 120-BPM target, 130+ ms intra-beat drift). Warping toward a wrong detected +# tempo is worse than not warping, so we only tempo-conform when the detected +# grid is trustworthy: beats evenly spaced (low interval CV) AND a clear pulse +# in the onset envelope. Below the threshold we trust the *requested* grid and +# skip the stretch (still doing the safe, tempo-independent transient@0 + crop). +# Calibrated on real fixtures: clean drum/bass loops score 0.76-0.88, pure +# pads 0.00 (no trackable beat), and ambiguous textures 0.44-0.57 — often with +# a wrong detected tempo. 0.65 sits in that gap. 
(The safe-range gate in +# _best_stretch_rate independently rejects octave-wrong tempos like 49/161 BPM.) +_GRID_CONFIDENCE_MIN = 0.65 +_CV_MAX = 0.20 # interval CV at which regularity -> 0 + +# Beat-synchronous warp (Ableton "Beats"-style). Measured: real drum loops are +# already coherent to ~3-6 ms, where anchor+exact-crop alone lands single-digit +# ms — so a global/elastic warp there only adds phase-vocoder jitter for no gain. +# We therefore warp ONLY when a confident grid still drifts past this threshold, +# and need enough beats to define segments. +_WARP_DRIFT_MIN_MS = 15.0 +_WARP_MIN_BEATS = 6 + + +# === Stage A v2 (FRAGMENTA_BEATSYNC_V2) ==================================== +# DEPRECATED: every symbol in this section is scheduled for relocation into +# `app/core/loop_quantizer` (see AUDIT.md §9c). Port the logic, then delete +# the originals here. Do NOT add new callers to anything below. +# A single hardened core shared by both align entry points. It enforces the +# locked invariants directly instead of relying on librosa's beat[0] for +# phase and on end-snap/silence-trim for length: +# * tempo conform with a bounded phase-vocoder stretch (_conform_stretch) — +# gen-time warp only, no live tracking (decision: v1); +# * align the first STRONG transient to sample 0 (rotate-free head trim) so +# two independently-correct clips share a downbeat with zero per-clip code; +# * crop to the exact target sample count — overgenerate-then-trim, never +# zero-pad in the common path (pad only as a logged last resort). + +def _stage_a_v2( + audio: np.ndarray, + sr: int, + *, + target_samples: int, + target_bpm: float, + deadband: float, +) -> np.ndarray: + """Hardened Stage A core. Input/return: float32 ``[T, C]``. 
+ + Decides per clip how to land it on the grid: + * low grid confidence -> place as-is (trust the requested grid; no warp, + no trim — Ableton likewise won't warp a pulse-less texture); + * confident + non-uniform drift -> beat-synchronous warp (each inter-beat + segment stretched onto the exact grid, Ableton "Beats" warp); + * confident + already coherent -> anchor + (optional) whole-loop tempo + nudge; the measured workhorse path (single-digit ms on real loops). + Always finishes with: first-strong-transient -> sample 0, then exact crop. + """ + mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0] + detected_bpm, beats = _detect_grid(mono, sr, start_bpm=target_bpm) + confidence = _grid_confidence(mono, sr, beats) + spb = sr * 60.0 / target_bpm + + trusted = ( + detected_bpm is not None + and confidence >= _GRID_CONFIDENCE_MIN + and beats is not None + and len(beats) >= _WARP_MIN_BEATS + ) + + if not trusted: + logger.info( + "stage_a_v2: %s; trusting requested %.2f BPM grid, exact-length only", + "low grid confidence (%.2f < %.2f)" % (confidence, _GRID_CONFIDENCE_MIN) + if detected_bpm is not None else "no usable grid", + target_bpm, + ) + return _exact_len(audio, target_samples, sr) + + # --- anchor the musical "1" to sample 0 (INV#4, enables INV#9) -------- + # Anchor to the first TRACKED beat, not the "first loud onset": the tracked + # beat is the same metrical position across clips, so two loops coincide; + # "first loud onset" lands on whatever transient happens to be loudest and + # differs per clip (measured: 200+ ms apart). Refine beats[0] to the exact + # rising edge for sample accuracy. 
+ anchor = _refine_to_transient(mono, int(beats[0]), sr) + if anchor > 0: + audio = audio[anchor:] + mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0] + beats = np.asarray(beats, dtype=np.int64) - anchor + beats = beats[beats >= 0] + + drift = _grid_drift_samples(beats) + if (_warp_enabled() and drift > _WARP_DRIFT_MIN_MS * sr / 1000.0 + and len(beats) >= 2): + # OFF BY DEFAULT (FRAGMENTA_BEATSYNC_WARP). Per-beat warp is only as good + # as librosa's beat detection — when detection is even slightly off it + # warps the wrong points onto the grid and SCRAMBLES the groove on real + # audio. Measured gain on clean drift was marginal (it merely halved it + # and added jitter), while anchor + exact-crop already lands real loops + # at ~3 ms. So it's opt-in for experiments, not the default path. + audio = _beat_sync_warp(audio, beats, spb) + logger.info("stage_a_v2: anchored + beat-sync warp (intra-loop drift " + "%.1f ms)", drift / sr * 1000) + else: + # Already coherent: a single global stretch is sufficient (and cleaner + # than per-segment warping) when the overall tempo is off; otherwise + # the anchor + exact crop is all that's needed. + rate, eff = _best_stretch_rate( + detected_bpm, target_bpm, + safe_min=_BARS_MODE_STRETCH_MIN, safe_max=_BARS_MODE_STRETCH_MAX, + ) + if rate is not None and abs(rate - 1.0) > deadband: + audio = _conform_stretch(audio, rate, sr) + logger.info("stage_a_v2: anchored + global tempo conform x%.4f " + "(detected %.2f -> %.2f)", rate, detected_bpm, target_bpm) + else: + logger.info("stage_a_v2: anchored only (low drift, on-tempo)") + + return _exact_len(audio, target_samples, sr) + + +def _exact_len(audio: np.ndarray, target_samples: int, sr: int) -> np.ndarray: + """Crop to exactly target_samples (INV#2/#3). 
Pads only as a logged last + resort — the generation overshoots duration so trimming is the norm.""" + if audio.shape[0] >= target_samples: + return np.ascontiguousarray(audio[:target_samples], dtype=np.float32) + pad = target_samples - audio.shape[0] + logger.warning( + "stage_a_v2: content short by %d samp (%.0f ms) — padding as a last " + "resort; raise generation headroom or re-roll", pad, pad / sr * 1000, + ) + return np.ascontiguousarray( + np.concatenate([audio, np.zeros((pad, audio.shape[1]), np.float32)], 0), + dtype=np.float32, + ) + + +def _grid_drift_samples(beats: Optional[np.ndarray]) -> float: + """Std of detected-beat residuals vs a uniform least-squares grid (samples). + A coherent loop sits near 0; tempo wobble shows up as a large residual.""" + if beats is None or len(beats) < 4: + return 0.0 + idx = np.arange(len(beats)) + A = np.vstack([idx, np.ones_like(idx)]).T + slope, icpt = np.linalg.lstsq(A, beats.astype(float), rcond=None)[0] + resid = beats.astype(float) - (slope * idx + icpt) + return float(np.std(resid)) + + +def _refine_to_transient(mono: np.ndarray, approx: int, sr: int, + win_sec: float = 0.015) -> int: + """Snap a frame-resolution beat sample to the exact rising edge of the + transient AT that beat. librosa picks WHICH transient is the beat (good); + this gives it sample accuracy (INV#4). 
The window is deliberately tight + (~15 ms): wide enough to cover beat-tracker frame jitter, narrow enough not + to jump to a neighbouring transient (which would desync clips, INV#9).""" + n = len(mono) + if n == 0: + return 0 + approx = int(max(0, min(approx, n - 1))) + lo = max(0, approx - int(sr * win_sec)) + hi = min(n, approx + int(sr * win_sec)) + if hi - lo < 2: + return approx + seg = np.abs(mono[lo:hi]) + pk = float(seg.max()) + if pk <= 1e-6: + return approx + above = np.flatnonzero(seg >= _V2_RISE_RATIO * pk) + return int(lo + above[0]) if len(above) else approx + + +def _beat_sync_warp(audio: np.ndarray, beats: np.ndarray, spb: float) -> np.ndarray: + """Ableton 'Beats'-style warp: stretch each inter-beat segment to exactly + round(spb) samples. Output starts at the first detected beat and has a + perfectly uniform grid, so two clips at the same tempo become sample-for- + sample periodic (INV#9). Phase-vocoder per segment; only invoked when drift + is high enough to be worth the boundary jitter.""" + beats = np.asarray(beats, dtype=np.int64) + beats = beats[(beats >= 0) & (beats < audio.shape[0])] + if len(beats) < 2: + return audio + target_spb = int(round(spb)) + segs = [] + for i in range(len(beats) - 1): + s, e = int(beats[i]), int(beats[i + 1]) + seg = audio[s:e] + if seg.shape[0] < 16: + continue + rate = float(np.clip(seg.shape[0] / spb, 0.5, 2.0)) + w = librosa.effects.time_stretch(seg.T, rate=rate).T + if w.shape[0] >= target_spb: + w = w[:target_spb] + else: + w = np.concatenate( + [w, np.zeros((target_spb - w.shape[0], w.shape[1]), np.float32)], 0) + segs.append(np.ascontiguousarray(w, dtype=np.float32)) + return np.concatenate(segs, 0) if segs else audio + + +def _grid_confidence( + mono: np.ndarray, sr: int, beats: Optional[np.ndarray] +) -> float: + """Trustworthiness of the detected beat grid, in [0, 1]. 
+ + Two evidence sources, averaged: + * regularity — how evenly spaced the detected beats are (1 - interval + coefficient of variation, clamped); a locked tracker gives near-even + intervals, ambient content gives erratic ones; + * pulse clarity — the strongest off-zero peak of the onset-envelope + autocorrelation relative to lag 0; high when there is a real periodic + pulse, low for drones/pads. + """ + if beats is None or len(beats) < 4: + return 0.0 + intervals = np.diff(beats.astype(np.float64)) + mean_i = float(np.mean(intervals)) if len(intervals) else 0.0 + if mean_i <= 0: + return 0.0 + cv = float(np.std(intervals) / mean_i) + regularity = max(0.0, min(1.0, 1.0 - cv / _CV_MAX)) + + clarity = 0.0 + try: + oenv = librosa.onset.onset_strength(y=mono, sr=sr) + oenv = oenv - float(np.mean(oenv)) + ac = librosa.autocorrelate(oenv) + if len(ac) > 4 and ac[0] > 0: + clarity = float(np.max(ac[4:]) / ac[0]) + clarity = max(0.0, min(1.0, clarity)) + except Exception as exc: + logger.warning("grid-confidence clarity failed: %s", exc) + + return 0.5 * regularity + 0.5 * clarity + + +def _first_strong_transient(mono: np.ndarray, sr: int) -> int: + """Sample index of the first STRONG transient, refined to the rising edge. + + Two-stage so we neither latch onto low-level noise nor lose sample + accuracy to librosa's 512-sample hop: + 1. librosa onset candidates; take the first whose local peak reaches + ``_V2_STRONG_RATIO`` of the search-window peak; + 2. refine within a small window to the first sample crossing + ``_V2_RISE_RATIO`` of that local peak — the attack's true start. + Returns 0 when the clip is silent or no strong transient is found. 
+ """ + n = len(mono) + search = min(n, int(sr * _V2_TRANSIENT_SEARCH_SEC)) + if search <= 0: + return 0 + peak = float(np.max(np.abs(mono[:search]))) + if peak <= 1e-6: + return 0 + + try: + onsets = librosa.onset.onset_detect( + y=mono, sr=sr, units="samples", backtrack=True + ) + except Exception as exc: + logger.warning("v2 onset detection failed: %s", exc) + onsets = None + + cand: Optional[int] = None + if onsets is not None and len(onsets) > 0: + look = int(sr * 0.05) + for o in np.asarray(onsets, dtype=np.int64): + if o >= search: + break + lo, hi = int(o), min(n, int(o) + look) + if float(np.max(np.abs(mono[lo:hi]))) >= _V2_STRONG_RATIO * peak: + cand = int(o) + break + + if cand is None: + # No qualifying onset — fall back to the first sample that crosses a + # fraction of the window peak (handles smooth/pad content). + idx = np.flatnonzero(np.abs(mono[:search]) >= _V2_STRONG_RATIO * peak) + return int(idx[0]) if len(idx) else 0 + win = int(sr * _V2_REFINE_WIN_SEC) + lo = max(0, cand - win) + hi = min(n, cand + win) + local_peak = float(np.max(np.abs(mono[lo:hi]))) or peak + seg = np.abs(mono[lo:hi]) + above = np.flatnonzero(seg >= _V2_RISE_RATIO * local_peak) + return int(lo + above[0]) if len(above) else cand + + +def _conform_stretch(audio: np.ndarray, rate: float, sr: int) -> np.ndarray: + """Tempo-conform time-stretch — the INV#5 "justified equivalent". + + We use the librosa phase vocoder (no external binary to ship) rather than + RubberBand's transient mode, justified by three properties that keep + transient smearing perceptually negligible here: + + 1. Bounded rate. This only runs inside the safe range [0.85, 1.15] — at + most a 15% stretch — where phase-vocoder transient blur is minor. + 2. Rare path. It fires only on high grid-confidence, off-by->0.5%-tempo + loops; SA3 usually hits the target at gen-time and skips it entirely. + 3. 
The perceptually critical transient — the downbeat — is positioned by + the sample-accurate trim in `_stage_a_v2`, NOT by this stretch, so the + musical "1" is never vocoded. + + `sr` is accepted for call-site symmetry (the phase vocoder is rate-only).""" + if abs(rate - 1.0) < 1e-9: + return audio + return _time_stretch_multichannel(audio, rate) + + +# DEPRECATED: superseded by app/core/loop_quantizer (see task_1.md / AUDIT.md §9a). +# Public entry; emits DeprecationWarning at runtime. Scheduled for removal once +# the new module passes acceptance. def align_to_grid( input_path: Path, target_bpm: float, target_bars: int, beats_per_bar: int = 4, ) -> Path: + warnings.warn( + "align_to_grid is deprecated and will be removed once " + "app/core/loop_quantizer ships (see task_1.md / AUDIT.md §9a).", + DeprecationWarning, + stacklevel=2, + ) audio, sr = sf.read(str(input_path), always_2d=True) audio = audio.astype(np.float32, copy=False) - target_samples = int(round(target_bars * beats_per_bar * 60.0 / target_bpm * sr)) + samples_per_beat = sr * 60.0 / float(target_bpm) + target_samples = int(round(target_bars * beats_per_bar * samples_per_beat)) + + if beatsync_v2_enabled(): + out = _stage_a_v2( + np.ascontiguousarray(audio), sr, + target_samples=target_samples, target_bpm=float(target_bpm), + deadband=_BARS_MODE_DEADBAND, + ) + # 3 ms head fade-in masks any click at the new sample-0 transient. + _apply_fade(out, _HEAD_FADE_SEC, sr, fade_in=True) + sf.write(str(input_path), out, sr, subtype="PCM_16") + logger.info("align_to_grid[v2]: %d samples (exact target %d)", + out.shape[0], target_samples) + return input_path mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0] - # Pass target_bpm as a prior to librosa — biases the beat tracker away - # from half-time / double-time interpretations of the same grid. 
- detected_bpm, first_beat = _detect_grid_anchor(mono, sr, start_bpm=target_bpm) + detected_bpm, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm) + # --- Head trim --------------------------------------------------------- head_offset = 0 - if first_beat is not None and 0 < first_beat < sr * 1.5: - head_offset = first_beat - logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms to first beat") - elif first_beat is None: + if beat_samples is not None and len(beat_samples) > 0: + first_beat = int(beat_samples[0]) + if 0 < first_beat < sr * 1.5: + head_offset = first_beat + logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms to first beat") + elif beat_samples is None: head_offset = _detect_first_onset_sample(mono, sr) if head_offset > 0: logger.info(f"align_to_grid: trimmed {head_offset / sr * 1000:.1f} ms (onset fallback)") @@ -50,12 +475,26 @@ def align_to_grid( if head_offset > 0: audio = audio[head_offset:] mono = mono[head_offset:] + if beat_samples is not None: + shifted = np.asarray(beat_samples, dtype=np.int64) - head_offset + beat_samples = shifted[shifted > 0] + # Head fade-in: 3 ms equal-power so the trim seam doesn't click. + _apply_fade(audio, _HEAD_FADE_SEC, sr, fade_in=True) + # --- Tempo conform ----------------------------------------------------- if detected_bpm is not None: - rate, effective_bpm = _best_stretch_rate(detected_bpm, target_bpm) - if rate is not None: - if abs(rate - 1.0) > 1e-3: - audio = _time_stretch_multichannel(audio, rate) + rate, effective_bpm = _best_stretch_rate( + detected_bpm, + target_bpm, + safe_min=_BARS_MODE_STRETCH_MIN, + safe_max=_BARS_MODE_STRETCH_MAX, + ) + if rate is not None and abs(rate - 1.0) > _BARS_MODE_DEADBAND: + audio = _time_stretch_multichannel(audio, rate) + # Beats have moved — re-detect from the warped audio so the + # end-snap step below sees current beat positions. 
+ mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0] + _, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm) interp_note = ( f" (interpreted as {effective_bpm:.2f} BPM, " f"octave={effective_bpm / detected_bpm:.2f}×)" @@ -66,33 +505,267 @@ def align_to_grid( f"align_to_grid: detected {detected_bpm:.2f} BPM{interp_note}, " f"stretched by {rate:.4f} to match target {target_bpm:.2f} BPM" ) + elif rate is not None: + logger.info( + f"align_to_grid: detected {detected_bpm:.2f} BPM is within " + f"{_BARS_MODE_DEADBAND * 100:.0f}% of target {target_bpm:.2f}; " + f"skipping stretch to preserve transients" + ) else: logger.info( f"align_to_grid: detected {detected_bpm:.2f} BPM has no safe " - f"interpretation vs target {target_bpm:.2f}; skipping warp" + f"interpretation vs target {target_bpm:.2f} within " + f"[{_BARS_MODE_STRETCH_MIN:.2f}, {_BARS_MODE_STRETCH_MAX:.2f}]; " + f"skipping warp (user re-roll recommended)" ) else: logger.info("align_to_grid: no usable tempo detected; skipping warp") + # --- Trim trailing silence -------------------------------------------- + # Done before end-snap so the snap operates on real audio, not on + # beats that happen to fall inside a quiet tail. + new_len = _trailing_audio_end(audio, sr) + if new_len < audio.shape[0]: + trimmed_ms = (audio.shape[0] - new_len) / sr * 1000 + logger.info(f"align_to_grid: trimmed {trimmed_ms:.0f} ms trailing silence") + audio = audio[:new_len] + if beat_samples is not None: + beat_samples = beat_samples[beat_samples < new_len] + + # --- End-anchored truncation ------------------------------------------ if audio.shape[0] > target_samples: - audio = audio[:target_samples] - # 8ms tail fade prevents the click at the loop boundary when the - # truncation point lands mid-waveform. 
- fade_samples = min(int(0.008 * sr), audio.shape[0]) - if fade_samples > 1: - fade = np.linspace(1.0, 0.0, fade_samples, dtype=audio.dtype) - audio[-fade_samples:] *= fade[:, np.newaxis] if audio.ndim > 1 else fade - elif audio.shape[0] < target_samples: - pad = np.zeros((target_samples - audio.shape[0], audio.shape[1]), dtype=audio.dtype) - audio = np.concatenate([audio, pad], axis=0) + end = _snap_to_beat(target_samples, beat_samples, samples_per_beat, audio.shape[0]) + cut_on_beat = beat_samples is not None and end in beat_samples.tolist() + audio = audio[:end] + if not cut_on_beat: + # Mid-note cut — short fade hides the click. On a clean beat + # boundary the cut is on a natural transient edge, so the fade + # would only "duck" the start of the next beat at the loop + # seam without preventing any audible click. + _apply_fade(audio, _TAIL_FADE_SEC, sr, fade_in=False) + # If we came in shorter than target, return the actual audio without + # zero-padding. A 7.5-bar clip that loops cleanly beats an 8-bar clip + # with 0.5 bars of silence at the loop seam. sf.write(str(input_path), audio, sr, subtype="PCM_16") return input_path +# --- Phase 7 loop alignment ----------------------------------------------- + +# DEPRECATED: superseded by app/core/loop_quantizer (see task_1.md / AUDIT.md §9a). +# Public entry; emits DeprecationWarning at runtime. Scheduled for removal once +# the new module passes acceptance. +def align_for_loop( + audio: np.ndarray, + sr: int, + *, + target_samples: int, + target_bpm: float, +) -> np.ndarray: + """Align a baseline clip for seamless looping at an exact length. + + DEPRECATED — superseded by ``app/core/loop_quantizer`` (see ``task_1.md`` / + ``AUDIT.md`` §9a). Scheduled for removal once the new module ships. + + Pipeline (in-memory, no disk I/O): + 1. Detect tempo + beat grid via librosa. + 2. Time-stretch (uniformly) if detected BPM drifts past the bars-mode + deadband AND the required rate is in the safe range. 
Drift + beyond the safe range is left alone (caller can re-roll). + 3. Head-trim to the first detected beat (or first onset as fallback), + within the first ~1.5 s. This is the phase-alignment step — it + puts the loop's "downbeat" at sample 0 so multiple channels' + beats coincide when launched on a bar boundary. + 4. Crop or zero-pad to exactly `target_samples`. No end-snap: the + loop iteration length is sample-exact so it stays phase-locked + to the master clock across iterations. + + Returns a `np.ndarray` of shape `(target_samples, channels)` (or 1-D + if input was 1-D). The caller is expected to wrap-and-inpaint the + output to smooth the seam — `align_for_loop` does no fade. + """ + warnings.warn( + "align_for_loop is deprecated and will be removed once " + "app/core/loop_quantizer ships (see task_1.md / AUDIT.md §9a).", + DeprecationWarning, + stacklevel=2, + ) + if audio.ndim == 1: + audio = audio[:, np.newaxis] + squeeze_out = True + else: + squeeze_out = False + audio = np.ascontiguousarray(audio, dtype=np.float32) + + if beatsync_v2_enabled(): + out = _stage_a_v2( + audio, sr, + target_samples=target_samples, target_bpm=float(target_bpm), + deadband=_LOOP_MODE_DEADBAND, + ) + return out.squeeze(1) if squeeze_out else out + + mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0] + detected_bpm, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm) + + # --- 1+2: tempo conform --------------------------------------------- + if detected_bpm is not None: + rate, effective_bpm = _best_stretch_rate( + detected_bpm, + target_bpm, + safe_min=_BARS_MODE_STRETCH_MIN, + safe_max=_BARS_MODE_STRETCH_MAX, + ) + if rate is not None and abs(rate - 1.0) > _LOOP_MODE_DEADBAND: + audio = _time_stretch_multichannel(audio, rate) + mono = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0] + _, beat_samples = _detect_grid(mono, sr, start_bpm=target_bpm) + interp = ( + f" (interpreted as {effective_bpm:.2f} BPM)" + if abs(effective_bpm - detected_bpm) 
> 1e-2 else "" + ) + logger.info( + "align_for_loop: detected %.2f BPM%s, stretched by %.4f to " + "match %.2f target", + detected_bpm, interp, rate, target_bpm, + ) + elif rate is not None: + logger.info( + "align_for_loop: detected %.2f BPM within %.2f%% of %.2f target; " + "no stretch", + detected_bpm, _LOOP_MODE_DEADBAND * 100, target_bpm, + ) + else: + logger.info( + "align_for_loop: detected %.2f BPM has no safe stretch to " + "%.2f target within [%.2f, %.2f]; leaving tempo as-is", + detected_bpm, target_bpm, + _BARS_MODE_STRETCH_MIN, _BARS_MODE_STRETCH_MAX, + ) + else: + logger.info("align_for_loop: no usable tempo detected; skipping stretch") + + # --- 3: head-trim to first beat / onset (phase alignment) ----------- + head_offset = 0 + if beat_samples is not None and len(beat_samples) > 0: + first_beat = int(beat_samples[0]) + if 0 < first_beat < sr * 1.5: + head_offset = first_beat + if head_offset == 0: + # Onset fallback when beat tracking didn't lock — gives at least + # a transient-aligned start instead of mid-attack on sample 0. + head_offset = _detect_first_onset_sample(mono, sr) + if head_offset >= sr * 1.5: + head_offset = 0 + if head_offset > 0: + audio = audio[head_offset:] + logger.info( + "align_for_loop: head-trimmed %.1f ms to first beat/onset", + head_offset / sr * 1000, + ) + + # --- 4: crop or pad to exact target_samples ------------------------- + if audio.shape[0] > target_samples: + audio = audio[:target_samples] + elif audio.shape[0] < target_samples: + pad = target_samples - audio.shape[0] + audio = np.concatenate( + [audio, np.zeros((pad, audio.shape[1]), dtype=audio.dtype)], + axis=0, + ) + + return audio.squeeze(1) if squeeze_out else audio + + +# --- helpers --------------------------------------------------------------- + +# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md §9b). +def _trailing_audio_end(audio: np.ndarray, sr: int) -> int: + """Return the sample index just past the last audible content. 
+ + Walks backwards in non-overlapping windows of `_SILENCE_WINDOW_SEC` and + finds the last window whose RMS exceeds `_SILENCE_THRESHOLD_DB`. Returns + the end of that window plus a small natural-decay tail. + + Falls back to the original audio length when the entire clip is below + threshold (silent input) or shorter than one window. + """ + n = audio.shape[0] + window = int(sr * _SILENCE_WINDOW_SEC) + if n <= window: + return n + mono = audio.mean(axis=1) if audio.ndim > 1 else audio + # Squared amplitudes — comparing to threshold² is equivalent to RMS vs + # threshold but avoids a sqrt per window. + sq = (mono ** 2) + thresh_sq = (10.0 ** (_SILENCE_THRESHOLD_DB / 20.0)) ** 2 + tail_keep = int(sr * _SILENCE_TAIL_KEEP_SEC) + end = n + while end > 0: + start = max(0, end - window) + if float(sq[start:end].mean()) > thresh_sq: + return min(n, end + tail_keep) + end = start + # Whole clip is below threshold — leave as-is rather than truncate to 0. + return n + + +# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md §9b). +def _snap_to_beat( + target_samples: int, + beat_samples: Optional[np.ndarray], + samples_per_beat: float, + audio_len: int, +) -> int: + """Return the cut point: the nearest detected beat within ±½ beat of + target_samples, falling back to target_samples itself if no beat is in + range. Never overshoots audio length.""" + fallback = min(target_samples, audio_len) + if beat_samples is None or len(beat_samples) == 0: + return fallback + tol = samples_per_beat * 0.5 + valid = beat_samples[(beat_samples > 0) & (beat_samples <= audio_len)] + if len(valid) == 0: + return fallback + diffs = np.abs(valid - target_samples) + idx = int(np.argmin(diffs)) + if diffs[idx] <= tol: + return int(valid[idx]) + return fallback + + +# DEPRECATED: superseded by loop_quantizer (AUDIT.md §9b); may be ported if reused. 
+def _apply_fade(audio: np.ndarray, duration_sec: float, sr: int, *, fade_in: bool) -> None: + """In-place equal-power fade on the head (fade_in=True) or tail.""" + n = min(int(duration_sec * sr), audio.shape[0]) + if n <= 1: + return + ramp = _equal_power_ramp(n, fade_in=fade_in, dtype=audio.dtype) + if audio.ndim > 1: + ramp = ramp[:, np.newaxis] + if fade_in: + audio[:n] *= ramp + else: + audio[-n:] *= ramp + + +# DEPRECATED: superseded by loop_quantizer (AUDIT.md §9b); may be ported if reused. +def _equal_power_ramp(n: int, *, fade_in: bool, dtype) -> np.ndarray: + """Cosine-shaped equal-power fade. Energy at the midpoint is preserved + when summing fade-out + fade-in of complementary segments, avoiding the + perceptible 'duck' that linear ramps produce at loop seams.""" + t = np.linspace(0.0, np.pi / 2.0, n).astype(dtype, copy=False) + return np.sin(t) if fade_in else np.cos(t) + + +# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md §9b). def _best_stretch_rate( detected_bpm: float, target_bpm: float, + *, + safe_min: float = _STRETCH_SAFE_MIN, + safe_max: float = _STRETCH_SAFE_MAX, ) -> Tuple[Optional[float], float]: """Pick the time-stretch rate that maps detected → target, considering half-time and double-time interpretations of the detected tempo. Returns @@ -101,27 +774,20 @@ def _best_stretch_rate( nothing safe is available. Order of preference: - 1. Detected as-is, if it lands inside the safe stretch range. + 1. Detected as-is, if it lands inside [safe_min, safe_max]. 2. Octave-corrected (detected × 0.5 or × 2.0), only when the as-is interpretation is out of range. This is the librosa half-/double- time error recovery path. - - This biases the algorithm toward honesty: only re-interpret the - detector's reading when it can't otherwise produce a usable stretch. """ - # First, try the detector's reading at face value. 
rate_asis = target_bpm / detected_bpm - if _STRETCH_SAFE_MIN <= rate_asis <= _STRETCH_SAFE_MAX: + if safe_min <= rate_asis <= safe_max: return rate_asis, detected_bpm - # As-is is out of safe range — almost certainly a librosa octave error. - # Try the half-time and double-time reinterpretations and pick whichever - # is closest to a no-op stretch. candidates = [] for octave_factor in (0.5, 2.0): interpreted = detected_bpm * octave_factor rate = target_bpm / interpreted - if _STRETCH_SAFE_MIN <= rate <= _STRETCH_SAFE_MAX: + if safe_min <= rate <= safe_max: candidates.append((abs(rate - 1.0), rate, interpreted)) if not candidates: return None, detected_bpm @@ -130,6 +796,7 @@ def _best_stretch_rate( return best_rate, best_interp +# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md §9b). def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int: """Return the sample index of the first detected onset, or 0 if none found.""" try: @@ -147,15 +814,16 @@ def _detect_first_onset_sample(mono: np.ndarray, sr: int) -> int: return first -def _detect_grid_anchor( +# DEPRECATED: superseded by loop_quantizer detector (AUDIT.md §9c); port or replace. +def _detect_grid( mono: np.ndarray, sr: int, start_bpm: Optional[float] = None, -) -> Tuple[Optional[float], Optional[int]]: - """Run librosa beat tracking with the target tempo as a prior. Passing - start_bpm reduces (but doesn't eliminate) half-time / double-time errors. - The octave-correction in _best_stretch_rate handles whatever librosa - still gets wrong.""" +) -> Tuple[Optional[float], Optional[np.ndarray]]: + """Run librosa beat tracking with the target tempo as a prior. Returns + (bpm, beat_samples_array). 
Passing start_bpm reduces (but doesn't + eliminate) half-time / double-time errors; the octave-correction in + _best_stretch_rate handles whatever librosa still gets wrong.""" try: kwargs = {"y": mono, "sr": sr, "units": "samples"} if start_bpm is not None and start_bpm > 0: @@ -169,9 +837,10 @@ def _detect_grid_anchor( bpm = float(np.atleast_1d(tempo).flatten()[0]) if not (40.0 <= bpm <= 240.0): return None, None - return bpm, int(beats[0]) + return bpm, np.asarray(beats, dtype=np.int64) +# DEPRECATED: legacy v1 helper; delete with this module (AUDIT.md §9b). def _time_stretch_multichannel(audio: np.ndarray, rate: float) -> np.ndarray: """Phase-vocoder time stretch, applied per channel and re-stacked.""" stretched = librosa.effects.time_stretch(audio.T, rate=rate) diff --git a/app/core/model_manager.py b/app/core/model_manager.py index 5d9aba76eeab84b5f9a4ed06d48ef24b5570b5db..962a85e6204fec28f143ecf477c4c2f346f99233 100644 --- a/app/core/model_manager.py +++ b/app/core/model_manager.py @@ -1,478 +1,669 @@ -import os +"""Checkpoint Manager — SA3 catalog, HF downloads, license + auth. + +Phase 2a in SA3_INTEGRATION_PLAN.md. Replaces the SA2-era SAO catalog. +Eight downloadable artifacts (3 post-trained + 3 base + 2 autoencoders); +each is fetched via `huggingface_hub.snapshot_download` with cooperative +cancel + progress reporting. + +The Phase 2b frontend (CheckpointManagerWindow.js) consumes the JSON shapes +returned by the `/api/checkpoints/*` endpoints in `app/backend/app.py`. 
+""" import json +import os import shutil -from pathlib import Path -from typing import Dict, List, Optional, Callable +import threading +import uuid +from dataclasses import dataclass, field from datetime import datetime -import requests -from huggingface_hub import snapshot_download, hf_hub_download -import hashlib +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +from huggingface_hub import get_token, snapshot_download, whoami +from huggingface_hub.errors import GatedRepoError, RepositoryNotFoundError + + +# --- Catalog ------------------------------------------------------------------ + +# Approximate sizes; the frontend can refine these by hitting +# `huggingface_hub.HfApi().model_info(repo_id)` lazily. Numbers come from the +# HF model cards (paragraph parameter counts × bytes/param, rounded). +_SA3_CATALOG: Dict[str, Dict[str, Any]] = { + # --- Generation models (post-trained) ---------------------------------- + "sa3-small-music": { + "user_visible": True, + "kind": "post-trained", + "name": "Small - Music", + "sa3_name": "small-music", + "repo": "stabilityai/stable-audio-3-small-music", + "size_bytes": 2_270_000_000, + "hardware": "cpu", # CPU / MPS / CUDA all work + "max_duration_sec": 120, + "description": "Fast distilled music generation. Locked to 8 steps, cfg 1.0.", + }, + "sa3-small-sfx": { + "user_visible": True, + "kind": "post-trained", + "name": "Small - SFX", + "sa3_name": "small-sfx", + "repo": "stabilityai/stable-audio-3-small-sfx", + "size_bytes": 2_270_000_000, + "hardware": "cpu", + "max_duration_sec": 120, + "description": "Fast distilled SFX/foley generation. Locked to 8 steps, cfg 1.0.", + }, + "sa3-medium": { + "user_visible": True, + "kind": "post-trained", + "name": "Medium", + "sa3_name": "medium", + "repo": "stabilityai/stable-audio-3-medium", + "size_bytes": 9_220_000_000, + "hardware": "cuda+flash-attn", + "max_duration_sec": 380, + "description": "Fast distilled hi-fi generation, up to 380s. 
Locked to 8 steps, cfg 1.0.", + }, + # --- Base checkpoints (full artist control) ---------------------------- + # These are the CFG-aware pre-distillation models. Slower (~50 steps, + # cfg ~7), but the user controls cfg_scale, steps, and the inference + # trajectory. Also the canonical targets for LoRA training. + "sa3-small-music-base": { + "user_visible": True, + "kind": "base", + "name": "Small - Music (Base)", + "sa3_name": "small-music-base", + "repo": "stabilityai/stable-audio-3-small-music-base", + "size_bytes": 2_270_000_000, + "hardware": "cpu", + "max_duration_sec": 120, + "description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.", + }, + "sa3-small-sfx-base": { + "user_visible": True, + "kind": "base", + "name": "Small - SFX (Base)", + "sa3_name": "small-sfx-base", + "repo": "stabilityai/stable-audio-3-small-sfx-base", + "size_bytes": 2_270_000_000, + "hardware": "cpu", + "max_duration_sec": 120, + "description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.", + }, + "sa3-medium-base": { + "user_visible": True, + "kind": "base", + "name": "Medium (Base)", + "sa3_name": "medium-base", + "repo": "stabilityai/stable-audio-3-medium-base", + "size_bytes": 9_220_000_000, + "hardware": "cuda+flash-attn", + "max_duration_sec": 380, + "description": "CFG-aware base. Full control over cfg_scale, steps. Slower than distilled.", + }, + # Standalone autoencoders: the AE is bundled INSIDE each DiT repo + # already (StableAudioModel.from_pretrained loads it from there), so + # we don't surface SAME-S / SAME-L in the manager. They remain + # downloadable via /api/checkpoints?include=all for advanced uses + # (autoencoder-only workflows, pre-encoding datasets for training). 
+ "sa3-same-s": { + "user_visible": False, + "kind": "autoencoder", + "name": "SAME-S", + "sa3_name": "same-s", + "repo": "stabilityai/SAME-S", + "size_bytes": 530_000_000, + "hardware": "cpu", + "description": "Standalone autoencoder (266M). Already bundled with the small-* DiTs.", + }, + "sa3-same-l": { + "user_visible": False, + "kind": "autoencoder", + "name": "SAME-L", + "sa3_name": "same-l", + "repo": "stabilityai/SAME-L", + "size_bytes": 3_400_000_000, + "hardware": "cuda", + "description": "Standalone autoencoder (1.7B). Already bundled with medium.", + }, + # --- Auto-annotation tools --------------------------------------------- + # Single-file HF download, lives under /clap/. + # `is_model_downloaded` and `_run_download` special-case kind=="tagger". + "clap-music": { + "user_visible": True, + "kind": "tagger", + "name": "LAION-CLAP (music)", + "sa3_name": "clap-music", + "repo": "lukewys/laion_clap", + "filename": "music_audioset_epoch_15_esc_90.14.pt", + # ~2.35 GB .pt + ~1.4 GB of text-encoder snapshots (roberta-base, + # bert-base-uncased, facebook/bart-base) that laion_clap loads at + # construction. download_clap_checkpoint pulls all of them. + "size_bytes": 3_800_000_000, + "hardware": "cpu", + "description": ( + "Zero-shot tagger used by the dataset prep's rich-tier annotation. " + "Scores each clip against your genre / mood / instrument vocabulary." 
+ ), + }, +} + +# --- Job state for in-flight downloads ---------------------------------------- + +@dataclass +class _DownloadJob: + """In-memory record of one download attempt.""" + job_id: str + model_id: str + status: str = "queued" # queued | running | complete | failed | cancelled + downloaded_bytes: int = 0 + total_bytes: int = 0 + error: Optional[str] = None + started_at: Optional[str] = None + finished_at: Optional[str] = None + _cancel_flag: threading.Event = field(default_factory=threading.Event) + _thread: Optional[threading.Thread] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "job_id": self.job_id, + "model_id": self.model_id, + "status": self.status, + "downloaded_bytes": self.downloaded_bytes, + "total_bytes": self.total_bytes, + "error": self.error, + "started_at": self.started_at, + "finished_at": self.finished_at, + } + + +class _DownloadCancelled(Exception): + """Raised inside the tqdm hook when a job's cancel flag fires.""" +# --- ModelManager ------------------------------------------------------------- + class ModelManager: + """Owns the SA3 catalog and the on-disk pretrained directory.""" - def __init__(self, config): + def __init__(self, config: Any) -> None: self.config = config - self.models_dir = config.get_path("models_pretrained") + self.models_dir: Path = config.get_path("models_pretrained") self.models_dir.mkdir(exist_ok=True, parents=True) - # Use fragmenta-models repo on HF Spaces, Stability AI models elsewhere - use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true' - - if use_custom_repo: - models_repo = 'MazCodes/fragmenta-models' - small_file = 'stable-audio-open-small-model.safetensors' - large_file = 'stable-audio-open-model.safetensors' - else: - models_repo_small = 'stabilityai/stable-audio-open-small' - models_repo_large = 'stabilityai/stable-audio-open-1.0' - small_file = 'model.safetensors' - large_file = 'model.safetensors' - - self.available_models = { - 
'stable-audio-open-small': { - 'name': 'Stable Audio Open Small', - 'repo': models_repo if use_custom_repo else models_repo_small, - 'files': [small_file], - 'size': '2.1 GB', - 'description': 'Fast generation, good quality, lower memory usage', - 'best_for': 'Beginners, quick experiments, limited GPU', - 'license': 'Stability AI License', - 'checksum': 'sha256:abc123...' - }, - 'stable-audio-open-1.0': { - 'name': 'Stable Audio Open 1.0', - 'repo': models_repo if use_custom_repo else models_repo_large, - 'files': [large_file], - 'size': '8.2 GB', - 'description': 'Highest quality, more detailed audio', - 'best_for': 'Professional use, high-end GPUs', - 'license': 'Stability AI License', - 'checksum': 'sha256:def456...' - } + # Project-wide policy: every HF download lands inside + # /models/pretrained/. SA3 generation + training uses + # /sa3/hub/; CLAP text deps use /clap/hub/. + # Both are HF cache layout so snapshot_download / hf_hub_download / + # from_pretrained resolve there transparently. + self.hub_dir: Path = self.models_dir / "sa3" / "hub" + self.hub_dir.mkdir(exist_ok=True, parents=True) + # Hard-force the resolution vars — never let an external env leak + # downloads into ~/.cache/huggingface or anywhere else outside the + # app folder. Covers huggingface_hub (current + legacy name) and + # transformers (which still consults TRANSFORMERS_CACHE). + os.environ["HF_HUB_CACHE"] = str(self.hub_dir) + os.environ["HUGGINGFACE_HUB_CACHE"] = str(self.hub_dir) + os.environ["TRANSFORMERS_CACHE"] = str(self.hub_dir) + + # available_models is exposed for backwards compat with the existing + # /api/models/available endpoint. New code should use get_catalog(). 
+ self.available_models: Dict[str, Dict] = { + mid: dict(meta) for mid, meta in _SA3_CATALOG.items() } - self.terms_file = Path("config/terms_accepted.json") - self.terms_file.parent.mkdir(exist_ok=True) - - def get_available_models(self) -> List[Dict]: - - models = [] - - for model_id, info in self.available_models.items(): - is_downloaded = self.is_model_downloaded(model_id) - - downloaded_size = None - if is_downloaded: - if model_id == 'stable-audio-open-small': - model_file = self.models_dir / 'stable-audio-open-small-model.safetensors' - downloaded_size = self._get_file_size( - model_file) if model_file.exists() else None - elif model_id == 'stable-audio-open-1.0': - model_file = self.models_dir / 'stable-audio-open-model.safetensors' - downloaded_size = self._get_file_size( - model_file) if model_file.exists() else None - else: - model_path = self.models_dir / model_id - downloaded_size = self._get_downloaded_size( - model_path) if model_path.exists() else None - - models.append({ - 'id': model_id, - 'name': info['name'], - 'size': info['size'], - 'description': info['description'], - 'best_for': info['best_for'], - 'license': info['license'], - 'downloaded': is_downloaded, - 'downloaded_size': downloaded_size, - 'terms_accepted': self.is_terms_accepted(model_id) - }) - - return models - - def _get_file_size(self, file_path: Path) -> str: - - if not file_path.exists() or not file_path.is_file(): - return "0 B" + self._jobs: Dict[str, _DownloadJob] = {} + self._jobs_lock = threading.Lock() + + # --- Catalog -------------------------------------------------------------- + + def get_catalog(self, include_hidden: bool = False) -> List[Dict[str, Any]]: + """Checkpoint Manager catalog with per-item state. + + Default returns only user-visible entries (the three generation + models). `include_hidden=True` also returns base + standalone-AE + entries — used by the Phase 5 training subprocess to ensure the + right base variant is on disk before kicking train_lora.py. 
+ """ + return [ + self._catalog_entry(mid) + for mid, info in _SA3_CATALOG.items() + if include_hidden or info.get("user_visible") + ] + + def _catalog_entry(self, model_id: str) -> Dict[str, Any]: + info = _SA3_CATALOG[model_id] + downloaded = self.is_model_downloaded(model_id) + bytes_total = 0 + if downloaded: + for d in (self._hub_cache_dir_for(model_id), self._legacy_flat_dir_for(model_id)): + if d.exists(): + bytes_total += self._dir_size(d) + + # Surface the most recent in-flight job for this model so the + # frontend can resume the progress bar after the Checkpoint Manager + # dialog is closed and reopened. The job lives on the backend; only + # the polling died with the dismissed UI. + active_job = None + with self._jobs_lock: + in_flight = [ + j for j in self._jobs.values() + if j.model_id == model_id and j.status in ("queued", "running") + ] + if in_flight: + in_flight.sort(key=lambda j: j.started_at or "", reverse=True) + active_job = in_flight[0].to_dict() - size = file_path.stat().st_size - return self._bytes_to_human(size) - - def _get_downloaded_size(self, model_path: Path) -> str: - - if not model_path.exists(): - return "0 B" - - total_size = 0 - for file_path in model_path.rglob("*"): - if file_path.is_file(): - total_size += file_path.stat().st_size + return { + "id": model_id, + "kind": info.get("kind"), + "name": info["name"], + "sa3_name": info["sa3_name"], + "repo": info["repo"], + "size_bytes": info["size_bytes"], + "hardware": info["hardware"], + "max_duration_sec": info.get("max_duration_sec"), + "description": info["description"], + "user_visible": info.get("user_visible", False), + "downloaded": downloaded, + "downloaded_bytes": bytes_total, + "active_job": active_job, + } - for unit in ['B', 'KB', 'MB', 'GB']: - if total_size < 1024.0: - return f"{total_size:.1f} {unit}" - total_size /= 1024.0 - return f"{total_size:.1f} TB" + def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]: + if model_id not in _SA3_CATALOG: + 
return None + return self._catalog_entry(model_id) - def get_model_info(self, model_id: str) -> Optional[Dict]: + # --- Filesystem layout ---------------------------------------------------- - if model_id not in self.available_models: - return None + def _hub_cache_dir_for(self, model_id: str) -> Path: + """HF-cache-shaped directory inside the app folder.""" + info = _SA3_CATALOG.get(model_id) + if info is None: + return self.hub_dir / "_unknown" + safe = "models--" + info["repo"].replace("/", "--") + return self.hub_dir / safe - info = self.available_models[model_id].copy() - info['id'] = model_id - info['downloaded'] = self.is_model_downloaded(model_id) - info['terms_accepted'] = self.is_terms_accepted(model_id) + def _legacy_flat_dir_for(self, model_id: str) -> Path: + """Pre-unification per-model dir. Read-only fallback for migration.""" + return self.models_dir / "sa3" / model_id - return info + def _local_dir_for(self, model_id: str) -> Path: + """Public: returns the canonical (HF cache) directory for a model.""" + return self._hub_cache_dir_for(model_id) def is_model_downloaded(self, model_id: str) -> bool: - - if model_id == 'stable-audio-open-small': - model_file = self.models_dir / 'stable-audio-open-small-model.safetensors' - return model_file.exists() and model_file.is_file() - elif model_id == 'stable-audio-open-1.0': - model_file = self.models_dir / 'stable-audio-open-model.safetensors' - return model_file.exists() and model_file.is_file() - else: - model_path = self.models_dir / model_id - if model_path.exists() and model_path.is_dir(): - return any(model_path.iterdir()) - pattern = f"*{model_id}*.safetensors" - matching_files = list(self.models_dir.glob(pattern)) - return len(matching_files) > 0 - - def is_terms_accepted(self, model_id: str) -> bool: - - if not self.terms_file.exists(): + if model_id not in _SA3_CATALOG: return False - - try: - with open(self.terms_file, 'r') as f: - terms_data = json.load(f) - return terms_data.get(model_id, 
{}).get('accepted', False) - except: + info = _SA3_CATALOG[model_id] + if info.get("kind") == "tagger": + # Single-file artifacts live in //. + # auto_annotator owns the exact path for CLAP, so we delegate. + from app.backend.data.auto_annotator import clap_checkpoint_available + return clap_checkpoint_available(self.models_dir) + # Canonical: HF cache layout under /models/pretrained/sa3/hub/. + # Look for the *top-level* model.safetensors only — NOT recursive — + # because a sibling repo may have only its conditioner subfolder + # downloaded (e.g. via the eager T5Gemma companion fetch when the + # user installed the matching *-base), and that doesn't make the + # post-trained model "downloaded". + main_present = False + hub = self._hub_cache_dir_for(model_id) + if hub.is_dir(): + snaps = hub / "snapshots" + if snaps.is_dir(): + for sub in snaps.iterdir(): + if any(sub.glob("*.safetensors")): + main_present = True + break + if not main_present: + # Fallback: legacy flat layout (predates the unification). Counts + # as downloaded for inference purposes; trainer will re-stage into + # hub. + legacy = self._legacy_flat_dir_for(model_id) + if legacy.is_dir() and any(legacy.glob("*.safetensors")): + main_present = True + if not main_present: return False - def accept_terms(self, model_id: str) -> bool: - - if model_id not in self.available_models: + # Base models need a T5Gemma conditioner that lives in a subfolder + # of the *post-trained sibling* repo. "Installed" must mean "ready + # to train / generate" — without the companion the first run blocks + # for 30s+ on an HF fetch. 
+ return self._is_companion_present(model_id) + + def _is_companion_present(self, model_id: str) -> bool: + from app.core.training.sa3_lora_runner import SA3_T5GEMMA_SIBLINGS + sibling = SA3_T5GEMMA_SIBLINGS.get(model_id) + if not sibling: + return True # nothing to check (post-trained / autoencoder / tagger) + sib_repo, sib_subfolder = sibling + safe = "models--" + sib_repo.replace("/", "--") + sib_hub = self.hub_dir / safe + snaps = sib_hub / "snapshots" + if not snaps.is_dir(): return False - - terms_data = {} - if self.terms_file.exists(): - try: - with open(self.terms_file, 'r') as f: - terms_data = json.load(f) - except: - terms_data = {} - - terms_data[model_id] = { - 'accepted': True, - 'accepted_at': datetime.now().isoformat(), - 'model_name': self.available_models[model_id]['name'], - 'license': self.available_models[model_id]['license'] - } - + for sub in snaps.iterdir(): + if (sub / sib_subfolder).is_dir(): + # Any non-empty file presence is good enough — the eager + # fetch always pulls the tokenizer + config + safetensors. + if any((sub / sib_subfolder).iterdir()): + return True + return False + + # --- HF auth -------------------------------------------------------------- + + @staticmethod + def hf_auth_status() -> Dict[str, Any]: + token = get_token() + if not token: + return {"signed_in": False, "username": None} try: - with open(self.terms_file, 'w') as f: - json.dump(terms_data, f, indent=2) - return True - except Exception as e: - print(f"Error saving terms acceptance: {e}") + user = whoami(token=token) + return {"signed_in": True, "username": user.get("name") or user.get("fullname")} + except Exception as err: + return {"signed_in": False, "username": None, "error": str(err)} + + # --- Downloads ------------------------------------------------------------ + + def start_download( + self, + model_id: str, + progress_callback: Optional[Callable[[int, str], None]] = None, + ) -> Dict[str, Any]: + """Spawn a background download job. 
Returns the job descriptor.""" + if model_id not in _SA3_CATALOG: + return {"error": f"Unknown checkpoint: {model_id}"} + + job = _DownloadJob( + job_id=str(uuid.uuid4()), + model_id=model_id, + total_bytes=_SA3_CATALOG[model_id]["size_bytes"], + ) + with self._jobs_lock: + self._jobs[job.job_id] = job + + thread = threading.Thread( + target=self._run_download, + args=(job, progress_callback), + daemon=True, + name=f"sa3-download:{model_id}", + ) + job._thread = thread + thread.start() + return job.to_dict() + + def get_job(self, job_id: str) -> Optional[Dict[str, Any]]: + with self._jobs_lock: + job = self._jobs.get(job_id) + return job.to_dict() if job else None + + def list_jobs(self) -> List[Dict[str, Any]]: + with self._jobs_lock: + return [j.to_dict() for j in self._jobs.values()] + + def cancel_job(self, job_id: str) -> bool: + with self._jobs_lock: + job = self._jobs.get(job_id) + if not job: return False - - def download_model(self, model_id: str, progress_callback: Optional[Callable] = None) -> bool: - - if model_id not in self.available_models: + if job.status not in ("queued", "running"): return False - - if not self.is_terms_accepted(model_id): - print(f"Terms not accepted for {model_id}") - self.accept_terms(model_id) - print(f"Automatically accepted terms for {model_id}") - - model_info = self.available_models[model_id] - target_dir = self.models_dir - target_dir.mkdir(exist_ok=True, parents=True) - - try: - print(f"Downloading {model_info['name']} to {target_dir}") - - if progress_callback: - progress_callback( - 0, f"Starting download of {model_info['name']}...") - - from huggingface_hub import HfApi - api = HfApi() - + job._cancel_flag.set() + return True + + def _run_download( + self, + job: _DownloadJob, + progress_callback: Optional[Callable[[int, str], None]], + ) -> None: + info = _SA3_CATALOG[job.model_id] + job.status = "running" + job.started_at = datetime.now().isoformat() + + # Tagger kind (e.g. 
CLAP) is a .pt file plus auxiliary HF snapshots + # living outside the sa3/hub layout. Multi-phase: 1 hf_hub_download + # for the audio .pt, then N sequential snapshot_downloads for the + # text encoders. Each spawns its own tqdm bars, so we use the + # cumulative hook to accumulate bytes across phases, and a phase_cb + # to prefix the message with which step the user is on. + if info.get("kind") == "tagger": try: - user = api.whoami() - print(f"Authenticated as: {user}") + from app.backend.data.auto_annotator import download_clap_checkpoint if progress_callback: - progress_callback(10, "Authentication verified...") - except Exception as auth_error: - print(f"Not authenticated with Hugging Face: {auth_error}") - if progress_callback: - progress_callback(0, "Authentication required...") - print("To download models, you need to:") - print( - "1. Visit https://huggingface.co/stabilityai/stable-audio-open-small") - print("2. Accept the terms and conditions") - print("3. Log in to your Hugging Face account") - print( - "4. Get your access token from https://huggingface.co/settings/tokens") - print("5. Use the in-app Hugging Face login dialog") + progress_callback(0, f"Downloading {info['name']}…") + # Pin total to the catalog estimate so the % stays anchored + # even before tqdm reports any file's size. 
+ job.total_bytes = info["size_bytes"] + current_phase = {"label": ""} + + def phase_cb(idx: int, total: int, label: str) -> None: + current_phase["label"] = f"[{idx}/{total}] {label}" + if progress_callback: + pct = (int(job.downloaded_bytes / job.total_bytes * 100) + if job.total_bytes else 0) + progress_callback(pct, current_phase["label"]) + + with _cumulative_tqdm_hook(job, progress_callback, current_phase): + download_clap_checkpoint(self.models_dir, phase_cb=phase_cb) + job.downloaded_bytes = job.total_bytes + job.status = "complete" + job.finished_at = datetime.now().isoformat() if progress_callback: - progress_callback(0, "Please authenticate in the app first") - return False + progress_callback(100, f"Downloaded {info['name']}") + except _DownloadCancelled: + job.status = "cancelled" + job.error = "Cancelled by user" + job.finished_at = datetime.now().isoformat() + except Exception as err: + job.status = "failed" + job.error = f"{type(err).__name__}: {err}" + job.finished_at = datetime.now().isoformat() + return + + cache_dir = self._hub_cache_dir_for(job.model_id).parent # = self.hub_dir + cache_dir.mkdir(exist_ok=True, parents=True) + target = self._hub_cache_dir_for(job.model_id) + + token = get_token() - if progress_callback: - progress_callback(20, "Starting file download...") - - try: - from huggingface_hub import hf_hub_download - import shutil - from tqdm import tqdm - import sys - - class TqdmToCallback: - def __init__(self, callback, file_index, total_files): - self.callback = callback - self.file_index = file_index - self.total_files = total_files - self.last_percent = 0 - - def __call__(self, t): - def inner(bytes_amount=1): - if t.total: - file_progress = (t.n / t.total) - overall_progress = (self.file_index + file_progress) / self.total_files - percent = 20 + int(overall_progress * 70) - - if percent != self.last_percent: - self.last_percent = percent - downloaded_mb = t.n / (1024 * 1024) - total_mb = t.total / (1024 * 1024) - if 
self.callback: - self.callback( - percent, - f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB" - ) - return inner - - downloaded_files = [] - total_files = len(model_info['files']) - - for i, file_pattern in enumerate(model_info['files']): + try: + with _tqdm_progress_hook(job, progress_callback): + # Write into hub/ in HF cache layout. snapshot_download in + # hf-hub 1.x populates `/models----/` + # with the blobs/refs/snapshots structure that + # hf_hub_download() and StableAudioModel.from_pretrained() + # both consume. + snapshot_download( + repo_id=info["repo"], + cache_dir=str(cache_dir), + token=token, + allow_patterns=[ + "*.safetensors", "*.json", "*.txt", "*.model", + "tokenizer*", "*.tiktoken", + ], + ) + + # Companion fetch: base models reference their T5Gemma + # conditioner in a subfolder of the *post-trained sibling* + # repo. Without it the training subprocess crashes at + # AutoTokenizer.from_pretrained, and inference can't build + # the conditioner either. Pull it eagerly so "Installed" + # actually means "ready to use". + from app.core.training.sa3_lora_runner import SA3_T5GEMMA_SIBLINGS + sibling = SA3_T5GEMMA_SIBLINGS.get(job.model_id) + if sibling: + sib_repo, sib_subfolder = sibling if progress_callback: progress_callback( - 20 + int((i / total_files) * 70), - f"Starting download of {file_pattern}..." 
+ min(99, int(job.downloaded_bytes / max(1, job.total_bytes) * 100)), + f"Fetching T5Gemma conditioner from {sib_repo}…", ) - - try: - if file_pattern == 'model.safetensors': - if model_id == 'stable-audio-open-small': - final_filename = 'stable-audio-open-small-model.safetensors' - elif model_id == 'stable-audio-open-1.0': - final_filename = 'stable-audio-open-model.safetensors' - else: - final_filename = f"{model_id}-model.safetensors" - else: - final_filename = f"{model_id}-{file_pattern}" - - tqdm_callback = TqdmToCallback(progress_callback, i, total_files) - - # hf_hub_download drives its own tqdm — monkey-patch its init/update so we - # forward byte progress to progress_callback without a second progress bar. - original_tqdm_init = tqdm.__init__ - - def patched_tqdm_init(self, *args, **kwargs): - original_tqdm_init(self, *args, **kwargs) - original_update = self.update - def new_update(n=1): - result = original_update(n) - if progress_callback and self.total: - file_progress = (self.n / self.total) - overall_progress = (i + file_progress) / total_files - percent = 20 + int(overall_progress * 70) - downloaded_mb = self.n / (1024 * 1024) - total_mb = self.total / (1024 * 1024) - progress_callback( - percent, - f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB" - ) - return result - self.update = new_update - - tqdm.__init__ = patched_tqdm_init - - try: - downloaded_file = hf_hub_download( - repo_id=model_info['repo'], - filename=file_pattern, - resume_download=True - ) - finally: - tqdm.__init__ = original_tqdm_init - - downloaded_path = Path(downloaded_file) - final_path = target_dir / final_filename - - final_path.parent.mkdir(parents=True, exist_ok=True) - - shutil.copy2(str(downloaded_path), str(final_path)) - print(f"Saved as {final_filename}") - - downloaded_files.append(str(final_path)) - - if progress_callback: - progress_callback( - 20 + int(((i + 1) / total_files) * 70), - f"Completed {file_pattern}" - ) - - except Exception as file_error: - 
print( - f"Failed to download {file_pattern}: {file_error}") - if progress_callback: - progress_callback( - 0, f"Failed to download {file_pattern}") - continue - - print(f"Downloaded {len(downloaded_files)} files") - - if progress_callback: - progress_callback( - 95, "Download completed, verifying files...") - - except Exception as download_error: - print(f"Error during download: {download_error}") - if progress_callback: - progress_callback( - 0, f"Download failed: {str(download_error)}") - return False - + snapshot_download( + repo_id=sib_repo, + cache_dir=str(cache_dir), + token=token, + allow_patterns=[f"{sib_subfolder}/*"], + ) + job.status = "complete" + job.downloaded_bytes = self._dir_size(target) if progress_callback: - progress_callback(95, "Verifying download...") - - expected_files = [] - if model_id == 'stable-audio-open-small': - expected_files.append( - 'stable-audio-open-small-model.safetensors') - elif model_id == 'stable-audio-open-1.0': - expected_files.append('stable-audio-open-model.safetensors') - else: - expected_files.append(f"{model_id}-model.safetensors") - - files_exist = any((target_dir / expected_file).exists() - for expected_file in expected_files) - - if files_exist: - if progress_callback: - progress_callback(100, "Download complete!") - print(f"Successfully downloaded {model_info['name']}") - return True - else: - if progress_callback: - progress_callback(0, "Download verification failed") - print(f"Expected files not found: {expected_files}") - return False - - except Exception as e: - print(f"Error downloading {model_info['name']}: {e}") - if progress_callback: - progress_callback(0, f"Error: {str(e)}") - - if "403" in str(e) and "gated repositories" in str(e).lower(): - print("Token permission issue detected!") - print( - "Your Hugging Face token needs 'Read access to public gated repositories'") - print("Please:") - print("1. Go to https://huggingface.co/settings/tokens") - print("2. 
Edit your token or create a new one") - print("3. Enable 'Read access to public gated repositories'") - print("4. Try the download again") - elif "401" in str(e) or "restricted" in str(e).lower(): - print("This model requires Hugging Face authentication.") - print("Please visit the model page and accept terms first:") - print(f"https://huggingface.co/{model_info['repo']}") - return False + progress_callback(100, f"Downloaded {info['name']}") + except _DownloadCancelled: + job.status = "cancelled" + job.error = "Cancelled by user" + shutil.rmtree(target, ignore_errors=True) + except GatedRepoError as err: + job.status = "failed" + job.error = f"hf_auth_required: {err}" + except RepositoryNotFoundError as err: + job.status = "failed" + job.error = f"Repository not found: {err}" + except Exception as err: + job.status = "failed" + job.error = str(err) + finally: + job.finished_at = datetime.now().isoformat() + + # --- Delete --------------------------------------------------------------- def delete_model(self, model_id: str) -> bool: - - deleted_something = False - - if model_id == 'stable-audio-open-small': - model_file = self.models_dir / 'stable-audio-open-small-model.safetensors' - config_file = self.models_dir / 'stable-audio-open-small-config.json' - elif model_id == 'stable-audio-open-1.0': - model_file = self.models_dir / 'stable-audio-open-model.safetensors' - config_file = self.models_dir / 'stable-audio-open-1.0-config.json' - else: - model_file = self.models_dir / f"{model_id}-model.safetensors" - config_file = self.models_dir / f"{model_id}-config.json" - - for file_path in [model_file, config_file]: - if file_path.exists(): - try: - file_path.unlink() - print(f"Deleted {file_path.name}") - deleted_something = True - except Exception as e: - print(f"Error deleting {file_path.name}: {e}") - - model_path = self.models_dir / model_id - if model_path.exists() and model_path.is_dir(): - try: - shutil.rmtree(model_path) - print(f"Deleted {model_id} directory") 
- deleted_something = True - except Exception as e: - print(f"Error deleting {model_id} directory: {e}") - - if deleted_something: - print(f"Deleted {model_id}") - return True - else: - print(f"No files found for {model_id}") + if model_id not in _SA3_CATALOG: return False - - def get_download_progress(self, model_id: str) -> Dict: - - return { - 'model_id': model_id, - 'downloaded': self.is_model_downloaded(model_id), - 'size': self.available_models.get(model_id, {}).get('size', 'Unknown') - } - - def get_storage_info(self) -> Dict: - - total_size = 0 - model_count = 0 - - if self.models_dir.exists(): - for model_id in self.available_models.keys(): - if self.is_model_downloaded(model_id): - model_count += 1 - - for file_path in self.models_dir.rglob("*"): - if file_path.is_file(): - total_size += file_path.stat().st_size - + # Remove both the canonical hub copy and the legacy flat copy if + # they exist. Either being present is enough to consider the + # model "downloaded", so both must be cleaned for the row to + # flip back to "Get". 
+ hub = self._hub_cache_dir_for(model_id) + legacy = self._legacy_flat_dir_for(model_id) + any_existed = hub.exists() or legacy.exists() + if hub.exists(): + shutil.rmtree(hub, ignore_errors=True) + if legacy.exists(): + shutil.rmtree(legacy, ignore_errors=True) + return any_existed and not (hub.exists() or legacy.exists()) + + # --- Storage -------------------------------------------------------------- + + def get_storage_info(self) -> Dict[str, Any]: + per_model: List[Dict[str, Any]] = [] + total_used = 0 + for mid in _SA3_CATALOG: + bytes_ = 0 + for d in (self._hub_cache_dir_for(mid), self._legacy_flat_dir_for(mid)): + if d.exists(): + bytes_ += self._dir_size(d) + per_model.append({ + "id": mid, + "downloaded": self.is_model_downloaded(mid), + "bytes": bytes_, + }) + total_used += bytes_ return { - 'total_size_bytes': total_size, - 'total_size_human': self._bytes_to_human(total_size), - 'model_count': model_count, - 'models_dir': str(self.models_dir) + "total_used_bytes": total_used, + "total_free_bytes": shutil.disk_usage(self.models_dir).free, + "per_model": per_model, } - def _bytes_to_human(self, bytes_value: int) -> str: - - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_value < 1024.0: - return f"{bytes_value:.1f} {unit}" - bytes_value /= 1024.0 - return f"{bytes_value:.1f} TB" + # --- Helpers -------------------------------------------------------------- + + @staticmethod + def _dir_size(path: Path) -> int: + if not path.exists(): + return 0 + return sum(p.stat().st_size for p in path.rglob("*") if p.is_file()) + +# --- tqdm hook ---------------------------------------------------------------- + +import contextlib + +@contextlib.contextmanager +def _tqdm_progress_hook( + job: _DownloadJob, + progress_callback: Optional[Callable[[int, str], None]], +): + """Monkey-patch tqdm so snapshot_download updates flow into the job state. + + `snapshot_download` doesn't expose a progress callback. 
tqdm is its + internal progress bar — we wrap `update` to update job state and raise + `_DownloadCancelled` when the job's cancel flag fires. + """ + from tqdm.auto import tqdm + original_init = tqdm.__init__ + + def patched_init(self, *args: Any, **kwargs: Any) -> None: + original_init(self, *args, **kwargs) + original_update = self.update + + def new_update(n: int = 1) -> Any: + if job._cancel_flag.is_set(): + raise _DownloadCancelled() + result = original_update(n) + if self.total: + job.downloaded_bytes = max(job.downloaded_bytes, self.n) + if job.total_bytes < self.total: + job.total_bytes = self.total + if progress_callback: + pct = int(self.n / self.total * 100) if self.total else 0 + mb_done = self.n / (1024 * 1024) + mb_total = self.total / (1024 * 1024) + progress_callback(pct, f"Downloading: {mb_done:.1f}MB / {mb_total:.1f}MB") + return result + + self.update = new_update # type: ignore[method-assign] + + tqdm.__init__ = patched_init # type: ignore[method-assign] + try: + yield + finally: + tqdm.__init__ = original_init # type: ignore[method-assign] + + +@contextlib.contextmanager +def _cumulative_tqdm_hook( + job: _DownloadJob, + progress_callback: Optional[Callable[[int, str], None]], + current_phase: Dict[str, str], +): + """Like _tqdm_progress_hook, but sums bytes across sequential bars. + + Each tqdm bar reports `self.n` cumulative within ITS file. The single-bar + hook uses max() which freezes the UI when a fresh bar starts smaller than + the previous bar's total. Here we track the previous `self.n` per bar id + and add only the delta to job.downloaded_bytes — so progress climbs + monotonically across all phases. 
+ """ + from tqdm.auto import tqdm + original_init = tqdm.__init__ + prev_n: Dict[int, int] = {} + + def patched_init(self, *args: Any, **kwargs: Any) -> None: + original_init(self, *args, **kwargs) + original_update = self.update + prev_n[id(self)] = 0 + + def new_update(n: int = 1) -> Any: + if job._cancel_flag.is_set(): + raise _DownloadCancelled() + result = original_update(n) + prev = prev_n.get(id(self), 0) + delta = self.n - prev + prev_n[id(self)] = self.n + if delta > 0: + job.downloaded_bytes += delta + if progress_callback and job.total_bytes: + pct = min(int(job.downloaded_bytes / job.total_bytes * 100), 99) + mb_done = job.downloaded_bytes / (1024 * 1024) + mb_total = job.total_bytes / (1024 * 1024) + label = current_phase.get("label", "") + msg = (f"{label} · {mb_done:.0f} MB / {mb_total:.0f} MB" + if label else f"{mb_done:.0f} MB / {mb_total:.0f} MB") + progress_callback(pct, msg) + return result + + self.update = new_update # type: ignore[method-assign] + + tqdm.__init__ = patched_init # type: ignore[method-assign] + try: + yield + finally: + tqdm.__init__ = original_init # type: ignore[method-assign] diff --git a/app/core/training/hyperparam_suggester.py b/app/core/training/hyperparam_suggester.py index ef1653735c50860a01e34c2676ed9c522cae389b..653c7a7223fcd7b22889f3de2c664722b171fdf1 100644 --- a/app/core/training/hyperparam_suggester.py +++ b/app/core/training/hyperparam_suggester.py @@ -1,76 +1,67 @@ -"""Heuristic hyperparameter suggester for the Training tab's "Suggest" button. - -Given the dataset on disk and the current hardware, returns a config that -trades off "small dataset, needs more updates per epoch" vs "big dataset, -batch up for throughput", plus the practical VRAM ceilings of the LoRA path -on Stable Audio Open 1.0. Returns the same shape the frontend `trainingConfig` -uses, so Apply can spread the result into state directly. +"""SA3 LoRA hyperparameter suggester for the Training tab's "Suggest" button. 
+ +Reads a Dataset Workbench project directly — counts SA3-compatible audio +files, measures their durations via the same `soundfile.info()` header-only +probe used elsewhere in the app, factors in the user's picked base model +and detected GPU VRAM, and returns a config that: + + * matches the upstream SA3 LoRA docs as the starting point + (see vendor/stable-audio-3/docs/workflows/lora.md) + * sets `--include transformer.layers` and `--exclude seconds_total + to_local_embed` by default (documented best practices, prevents the + "conditioner hijacking" failure mode on small datasets) + * picks a `-XS` adapter family when VRAM is tight for the chosen base + * proposes a `duration` derived from the actual clip lengths in the + project — not a hardcoded 30s + * warns when the dataset is below SA3's documented minimum (~20 clips) + or when clips are too short to learn from + +Returns the same shape the frontend `trainingConfig` uses, so Apply can +spread the result into state directly. """ from __future__ import annotations -import json -import os -import subprocess +import math from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple + +from app.backend.data.projects import _clip_duration_sec +from app.core.training.sa3_lora_runner import SA3_AUDIO_EXTENSIONS, SA3_BASE_MODELS -AUDIO_EXTS = {".wav", ".mp3", ".flac", ".m4a"} -# Cache file for total-duration measurement. ffprobe across 500 files takes -# 10-30s; we don't want to pay that on every button click. Cache key is the -# (file_count, max_mtime_int) pair — invalidates automatically when files -# are added/removed/touched. -_DURATION_CACHE_NAME = ".duration_cache.json" +# --- Discovery ------------------------------------------------------------- def _list_audio_files(data_dir: Path) -> List[Path]: + """Files SA3's loader would actually train on. 
Mirrors the loader's filter.""" if not data_dir.exists(): return [] return [ p for p in data_dir.iterdir() - if p.is_file() and p.suffix.lower() in AUDIO_EXTS + if p.is_file() and p.suffix.lower() in SA3_AUDIO_EXTENSIONS ] -def _measure_total_duration(audio_files: List[Path], cache_path: Path) -> float: - if not audio_files: - return 0.0 - - file_count = len(audio_files) - max_mtime = int(max(p.stat().st_mtime for p in audio_files)) - cache_key = f"{file_count}:{max_mtime}" - - if cache_path.exists(): - try: - cached = json.loads(cache_path.read_text()) - if cached.get("key") == cache_key: - return float(cached["duration_sec"]) - except Exception: - pass - - total = 0.0 +def _duration_stats(audio_files: List[Path]) -> Dict[str, Optional[float]]: + """Header-only duration probe + summary stats. None-safe for unreadable files.""" + durations: List[float] = [] for f in audio_files: - try: - out = subprocess.check_output( - ["ffprobe", "-v", "error", "-show_entries", "format=duration", - "-of", "default=noprint_wrappers=1:nokey=1", str(f)], - text=True, timeout=10, - ).strip() - total += float(out) - except Exception: - # Skip files ffprobe can't read; better to under-report than crash. 
- continue - - try: - cache_path.write_text(json.dumps({ - "key": cache_key, - "duration_sec": total, - })) - except Exception: - pass - - return total + d = _clip_duration_sec(f) + if d is not None and d > 0: + durations.append(d) + if not durations: + return {"count": 0, "total": 0.0, "median": None, "p95": None, "max": None, "min": None} + durations.sort() + n = len(durations) + return { + "count": n, + "total": float(sum(durations)), + "median": float(durations[n // 2]), + "p95": float(durations[min(n - 1, int(math.ceil(0.95 * n)) - 1)]), + "max": float(durations[-1]), + "min": float(durations[0]), + } def _detect_vram_gb() -> Optional[float]: @@ -83,6 +74,9 @@ def _detect_vram_gb() -> Optional[float]: return None +# --- Bucketing & sizing ---------------------------------------------------- + + def _bucket(file_count: int) -> str: if file_count < 20: return "tiny" @@ -93,66 +87,136 @@ def _bucket(file_count: int) -> str: return "large" -def _heuristic(file_count: int, vram_gb: Optional[float], mode: str) -> Dict[str, Any]: - """The rules-of-thumb. Same shape regardless of mode; the frontend ignores - LoRA-specific keys when mode='full'.""" - +# SA3's documented quick-start: --steps 1000, with no dataset-size caveat. +# (vendor/stable-audio-3/docs/workflows/lora.md, "Standard (recommended starting point)".) +# SA3 trains by *windows seen*, not epochs, so a 5h dataset doesn't need more +# steps than a 30min one — it just produces more diverse sampling per step. +# We keep the SA3 default for tiny/small, and bump modestly only when a +# dataset is large enough that 1000 steps won't see all unique windows. +_STEPS_BY_BUCKET: Dict[str, int] = { + "tiny": 1000, + "small": 1000, + "medium": 2000, + "large": 4000, +} + + +# Per-base-model VRAM table from SA3 docs. (standard_gb, xs_bf16_gb) +# Source: docs/workflows/lora.md memory table. 
+_VRAM_REQ: Dict[str, Tuple[float, float]] = { + "sa3-small-music-base": (2.5, 2.0), + "sa3-small-sfx-base": (2.5, 2.0), + "sa3-medium-base": (6.5, 5.5), +} + + +def _pick_adapter(base_model: Optional[str], vram_gb: Optional[float]) -> Tuple[str, bool]: + """Choose adapter family. Returns (adapter_type, vram_constrained_flag). + + SA3 docs recommend the `-xs` family + bf16 base precision for VRAM-limited + hosts. Headroom rule: standard_gb + 4 GB activations is the comfort target; + below that we pick the xs family. + """ + default = "dora-rows" + if base_model is None or vram_gb is None: + return default, False + std_gb, _xs_gb = _VRAM_REQ.get(base_model, (2.5, 2.0)) + comfort = std_gb + 4.0 + constrained = vram_gb < comfort + return ("dora-rows-xs" if constrained else default), constrained + + +def _model_max_window_sec(base_model: Optional[str]) -> float: + """SA3's native training length for the base, from its model config + sample_size / sample_rate: medium-base ≈380s, small bases ≈120s. The + `seconds_total` conditioner caps at 384s, so 380 is the safe medium ceiling. + Longer windows aren't a model limit below these — they're VRAM/time bound. + """ + if base_model and "medium" in base_model: + return 380.0 + return 120.0 + + +def _pick_duration(p95_clip_sec: Optional[float], base_model: Optional[str]) -> float: + """Set training window from the project's actual p95 clip length. + + Floors at 5s; caps at — and defaults to — the model's native length + (≈120s small / ≈380s medium) rather than an arbitrary 30s. SA3 random-crops + longer files, so the only real limits are the model's sequence length and + VRAM. Rounds up p95 with 2s headroom so the window isn't cropping the tails + of typical clips. With no duration data, defaults to the model max. 
+ """ + model_max = _model_max_window_sec(base_model) + if p95_clip_sec is None or p95_clip_sec <= 0: + return model_max + suggested = math.ceil(p95_clip_sec + 2.0) + return float(max(5, min(model_max, suggested))) + + +def _pick_batch_size(bucket: str, vram_gb: Optional[float]) -> int: + """SA3 examples all use batch 1. Only go higher on roomy hardware + big data. + + 24 GB threshold for batch 2 leaves enough headroom for medium-base + bf16 + activations across two samples. Going beyond batch 2 hits diminishing + returns and risks OOM mid-run. + """ + if vram_gb is None or vram_gb < 24: + return 1 + if bucket in ("medium", "large"): + return 2 + return 1 + + +# Filter pattern straight from SA3 docs: +# --include transformer.layers --exclude seconds_total to_local_embed +# "Everything except local embedding and seconds_total conditioner" — prevents +# the conditioner-hijacking failure mode that bites small datasets hardest. +_INCLUDE_DEFAULT: List[str] = ["transformer.layers"] +_EXCLUDE_DEFAULT: List[str] = ["seconds_total", "to_local_embed"] + + +# --- Suggestion + rationale ------------------------------------------------ + + +def _heuristic( + file_count: int, + dur_stats: Dict[str, Optional[float]], + base_model: Optional[str], + vram_gb: Optional[float], +) -> Dict[str, Any]: bucket = _bucket(file_count) - has_vram = vram_gb is not None - constrained = (has_vram and vram_gb < 12) - - # Target total weight updates. Sublinear with dataset size so tiny sets - # still get enough gradient steps, while large sets don't run forever. - target_steps_by_bucket = { - "tiny": 2500, - "small": 2000, - "medium": 1500, - "large": 3000, - } - target_steps = target_steps_by_bucket[bucket] - - # Rank/LR/alpha scale with how much "capacity per data point" the run needs. 
- # Small dataset trick: keep rank moderate (16) and conservative LR (1e-4 — - # 2e-4 caused overshoot/flat loss in testing), but boost alpha so the - # LoRA delta trains at higher effective voltage (scaling = alpha/rank). - # This produces a stronger imprint without the parameter bloat of rank=32 - # or the instability of higher LR. - if bucket in ("tiny", "small"): - rank, alpha, lr = 16, 32, 1e-4 - else: - rank, alpha, lr = 16, 16, 1e-4 - - # Batch size: smaller on small datasets (more updates per epoch + better - # gradient noise); larger on medium/large for throughput. VRAM caps the top. - if bucket == "tiny": - batch = 1 if constrained else 2 - elif bucket == "small": - # Hold batch=2 even on roomy VRAM — the noise benefit on a small - # dataset outweighs the throughput win, and it keeps the epoch - # count to a reasonable display number. - batch = 2 - elif bucket == "medium": - batch = 2 if constrained else 4 - else: - batch = 4 if constrained else 8 + steps = _STEPS_BY_BUCKET[bucket] + adapter, constrained = _pick_adapter(base_model, vram_gb) + duration = _pick_duration(dur_stats.get("p95"), base_model) + batch = _pick_batch_size(bucket, vram_gb) + + # Mild dropout for tiny datasets only — extra regularization where overfit + # is most likely. SA3 default is 0.0; we deviate intentionally. + dropout = 0.05 if bucket == "tiny" else 0.0 - steps_per_epoch = max(1, file_count // batch) - epochs = max(20, round(target_steps / steps_per_epoch)) + # Checkpoint cadence: ~10 checkpoints per run, but keep within sane bounds + # so we don't write a checkpoint every 50 steps on tiny runs or sit on a + # 2K-step gap on long ones. 
+ checkpoint_every = max(250, min(1000, steps // 10)) return { + "steps": steps, "batchSize": batch, - "learningRate": lr, - "epochs": epochs, - "loraRank": rank, - "loraAlpha": alpha, - "loraDropout": 0, - "loraMultiplier": 1.0, + "learningRate": 1e-4, + "loraRank": 16, + "loraAlpha": 16, + "loraDropout": dropout, + "adapterType": adapter, + "precision": "bf16", + "duration": duration, + "checkpointSteps": checkpoint_every, + "include": list(_INCLUDE_DEFAULT), + "exclude": list(_EXCLUDE_DEFAULT), "_meta": { "bucket": bucket, - "target_steps": target_steps, - "steps_per_epoch": steps_per_epoch, - "total_steps": steps_per_epoch * epochs, + "target_steps": steps, "vram_constrained": constrained, + "picked_adapter_for_vram": constrained, }, } @@ -166,68 +230,162 @@ def _format_duration(seconds: float) -> str: return f"{m}m {s}s" -def _compose_rationale(file_count: int, duration_sec: float, vram_gb: Optional[float], - mode: str, meta: Dict[str, Any]) -> List[str]: - """Human-readable explanation, returned as a list of bullet strings.""" - bullets = [] +def _compose_rationale( + file_count: int, + dur_stats: Dict[str, Optional[float]], + base_model: Optional[str], + vram_gb: Optional[float], + config: Dict[str, Any], + meta: Dict[str, Any], +) -> Tuple[List[str], List[str]]: + """Return (bullets, warnings). Warnings are surfaced separately in the UI.""" + bullets: List[str] = [] + warnings: List[str] = [] + + total = dur_stats.get("total") or 0.0 bullets.append( - f"Dataset: {file_count} audio file{'s' if file_count != 1 else ''}, " - f"total {_format_duration(duration_sec)} → " - f"\"{meta['bucket']}\" bucket." + f"Dataset: {file_count} clip{'s' if file_count != 1 else ''}, " + f"total {_format_duration(total)} → \"{meta['bucket']}\" bucket." ) + + p95 = dur_stats.get("p95") + median = dur_stats.get("median") + if p95 is not None and median is not None: + bullets.append( + f"Clip durations: median {median:.1f}s, p95 {p95:.1f}s. 
" + f"Training window set to {config['duration']:.0f}s." + ) + if vram_gb is not None: - constraint = "VRAM-constrained" if meta["vram_constrained"] else "comfortable VRAM headroom" - bullets.append(f"Detected GPU with {vram_gb:.1f} GB ({constraint}).") + bullets.append( + f"Detected GPU: {vram_gb:.1f} GB" + + (" (tight for the chosen base — switched adapter to a -XS variant)." + if meta["vram_constrained"] else " (comfortable headroom).") + ) else: - bullets.append("No GPU detected — assuming consumer-class constraints.") - bullets.append( - f"Targeting ~{meta['target_steps']} weight updates total; with batch_size " - f"the dataset gives {meta['steps_per_epoch']} steps/epoch, so " - f"{meta['total_steps']} steps over the recommended epoch count." - ) - if meta["bucket"] in ("tiny", "small"): + bullets.append("No CUDA GPU detected — adapter defaults to dora-rows; " + "training will run on CPU/MPS where supported.") + + if meta["target_steps"] == 1000: bullets.append( - "Small dataset → conservative 1e-4 LR + rank=16 for stability, " - "but alpha=32 (alpha/rank = 2.0) so the LoRA delta trains at " - "double voltage. Stronger imprint without overshoot risk." + "Target 1 000 optimizer steps — SA3's documented quick-start. " + "LoRAs typically overfit well before this; watch the loss curve." ) else: bullets.append( - "Larger dataset → moderate batch + standard 1e-4 LR. Rank=16 has " - "plenty of capacity for the prompt distribution this size implies." + f"Target {meta['target_steps']:,} optimizer steps — modest bump " + f"above SA3's 1 000-step default for larger datasets to see more " + "unique sampling windows." ) - return bullets + bullets.append( + f"Layer filter: include `{config['include'][0]}`, exclude " + f"`{' '.join(config['exclude'])}`. " + "Documented SA3 default — prevents conditioner-hijacking on small sets." + ) -def suggest(data_dir: Path, mode: str = "lora") -> Dict[str, Any]: - """Public entry point. 
Returns the suggestion + a rationale + raw stats.""" + bullets.append( + f"Adapter `{config['adapterType']}` · rank 16 · α 16 · " + f"dropout {config['loraDropout']} · {config['precision']} base." + ) + + # --- Warnings (separate channel) --------------------------------------- + + if file_count < 20: + warnings.append( + f"{file_count} clips is below SA3's documented minimum of ~20. " + "Expect heavy overfit and poor generalization — add more data if you can." + ) + if median is not None and median < 2.0: + warnings.append( + f"Median clip is only {median:.1f}s — most of the training window " + f"({config['duration']:.0f}s) will be silence-padded. " + "Re-slice the source material to longer chunks for better signal." + ) + if config["duration"] > 45: + warnings.append( + f"Training window is {config['duration']:.0f}s. Longer windows use " + "markedly more VRAM and step time (DiT attention scales with length). " + "If you hit OOM, lower the window or pre-encode the dataset first." + ) + + # VRAM × base model crosscheck + if base_model in _VRAM_REQ: + std_gb, xs_gb = _VRAM_REQ[base_model] + if vram_gb is None: + if base_model == "sa3-medium-base": + warnings.append( + "No CUDA GPU detected, but you picked Medium-Base. " + "Medium-base needs CUDA + Flash-Attn 2 (Linux) and ≥5.5 GB VRAM. " + "Consider Small-Music-Base or Small-SFX-Base for CPU/MPS hosts." + ) + elif vram_gb < xs_gb: + warnings.append( + f"GPU has {vram_gb:.1f} GB; even {base_model} with bf16+lora-xs needs " + f"~{xs_gb:.1f} GB. Training will likely OOM. Pick a smaller base." + ) + elif vram_gb < std_gb: + warnings.append( + f"GPU has {vram_gb:.1f} GB; {base_model} standard config needs " + f"~{std_gb:.1f} GB. The -XS adapter (selected) brings it to ~{xs_gb:.1f} GB." + ) + + return bullets, warnings + + +def suggest(data_dir: Path, base_model: Optional[str] = None) -> Dict[str, Any]: + """Public entry point. 
SA3 is LoRA-only; no `mode` switch.""" audio_files = _list_audio_files(data_dir) file_count = len(audio_files) if file_count == 0: return { "ok": False, - "error": f"No audio files found in {data_dir}", + "error": ( + f"No SA3-compatible audio in {data_dir}. SA3's loader accepts " + + ", ".join(SA3_AUDIO_EXTENSIONS) + "." + ), } - cache_path = data_dir / _DURATION_CACHE_NAME - duration_sec = _measure_total_duration(audio_files, cache_path) + dur_stats = _duration_stats(audio_files) vram_gb = _detect_vram_gb() - suggestion = _heuristic(file_count, vram_gb, mode) + suggestion = _heuristic(file_count, dur_stats, base_model, vram_gb) meta = suggestion.pop("_meta") - rationale = _compose_rationale(file_count, duration_sec, vram_gb, mode, meta) + bullets, warnings = _compose_rationale( + file_count, dur_stats, base_model, vram_gb, suggestion, meta + ) + + # Caption coverage: SA3 trains on audio + matching .txt sidecars, and + # silently drops clips whose prompt is blank. Surface missing captions so + # the user isn't unknowingly training on a fraction of the dataset. + uncaptioned = sum( + 1 for p in audio_files + if not (p.with_suffix(".txt").exists() + and p.with_suffix(".txt").read_text(encoding="utf-8", errors="ignore").strip()) + ) + if uncaptioned: + warnings.insert(0, ( + f"{uncaptioned} of {file_count} clip{'s' if file_count != 1 else ''} " + "have no annotation. SA3 silently skips un-captioned clips at train " + "time — annotate them first or they won't contribute to the LoRA." 
+ )) return { "ok": True, "stats": { "file_count": file_count, - "duration_sec": duration_sec, - "duration_human": _format_duration(duration_sec), + "duration_sec": dur_stats.get("total") or 0.0, + "duration_human": _format_duration(dur_stats.get("total") or 0.0), + "median_clip_sec": dur_stats.get("median"), + "p95_clip_sec": dur_stats.get("p95"), + "max_clip_sec": dur_stats.get("max"), + "min_clip_sec": dur_stats.get("min"), "vram_gb": round(vram_gb, 2) if vram_gb is not None else None, "bucket": meta["bucket"], - "steps_per_epoch": meta["steps_per_epoch"], - "total_steps": meta["total_steps"], + "total_steps": meta["target_steps"], + "base_model": base_model, }, "config": suggestion, - "rationale": rationale, + "rationale": bullets, + "warnings": warnings, } diff --git a/app/core/training/sa3_lora_runner.py b/app/core/training/sa3_lora_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..b967dff4835ff4d948c7d016315396bac0471a82 --- /dev/null +++ b/app/core/training/sa3_lora_runner.py @@ -0,0 +1,331 @@ +"""Helpers for the SA3 LoRA training pipeline. + +Responsibilities: + * Pre-stage the base model in an app-folder HF cache so the training + subprocess finds it without falling back to ~/.cache/huggingface. + * Build the train_lora.py subprocess command + env. + * Convert PyTorch Lightning .ckpt LoRA outputs to SA3-native .safetensors + with the base_model and run name embedded in the metadata header. +""" +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +# SA3 model_id → (sa3_name passed to train_lora.py --model, HF repo id) +# Only `*-base` variants are valid LoRA targets — SA3 won't train against +# the post-trained / distilled checkpoints. 
+SA3_BASE_MODELS: Dict[str, Tuple[str, str]] = { + "sa3-small-music-base": ("small-music-base", "stabilityai/stable-audio-3-small-music-base"), + "sa3-small-sfx-base": ("small-sfx-base", "stabilityai/stable-audio-3-small-sfx-base"), + "sa3-medium-base": ("medium-base", "stabilityai/stable-audio-3-medium-base"), +} + + +# Each *-base config references its T5Gemma conditioner at a subfolder of the +# *post-trained sibling* repo (e.g., medium-base's t5gemma lives at +# stabilityai/stable-audio-3-medium / t5gemma-b-b-ul2/). Without that subtree +# in the cache, training crashes inside the conditioner constructor when SA3 +# does `AutoTokenizer.from_pretrained(repo_id, subfolder=...)`. +# Keep in sync with model_config.json's `conditioning.configs[0].config.repo_id`. +SA3_T5GEMMA_SIBLINGS: Dict[str, Tuple[str, str]] = { + "sa3-small-music-base": ("stabilityai/stable-audio-3-small-music", "t5gemma-b-b-ul2"), + "sa3-small-sfx-base": ("stabilityai/stable-audio-3-small-sfx", "t5gemma-b-b-ul2"), + "sa3-medium-base": ("stabilityai/stable-audio-3-medium", "t5gemma-b-b-ul2"), +} + + +# Extensions SA3's training data loader actually accepts. +# Source: vendor/stable-audio-3/stable_audio_3/data/dataset.py:91. +# Single source of truth — both the health check and the hyperparam suggester +# use this so what we count matches what the loader will train on. +SA3_AUDIO_EXTENSIONS: Tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg", ".aif", ".opus") + + +# --- Base model pre-staging ------------------------------------------------- + +def prestage_base_model( + sa3_model_id: str, + hub_dir: Path, + token: Optional[str] = None, + progress_callback: Optional[Any] = None, +) -> Path: + """Ensure the base model is in `hub_dir` (HF-cache layout, inside app folder). + + train_lora.py calls `model_cfg.resolve()` which is hf_hub_download under + the hood — it reads from the HF cache root. 
We point it at hub_dir via + the HF_HUB_CACHE env var on the subprocess; for that to actually find + files we need to download into hub_dir using snapshot_download with + `cache_dir=hub_dir`. + + Idempotent: if the model is already cached there, returns the cached + snapshot dir without re-downloading. + """ + if sa3_model_id not in SA3_BASE_MODELS: + raise ValueError( + f"'{sa3_model_id}' is not a valid LoRA base. Pick one of " + f"{list(SA3_BASE_MODELS)} (only *-base variants are CFG-aware)." + ) + sa3_name, repo_id = SA3_BASE_MODELS[sa3_model_id] + hub_dir.mkdir(parents=True, exist_ok=True) + + from huggingface_hub import snapshot_download + + allow_patterns = [ + "*.safetensors", "*.json", "*.txt", "*.model", + "tokenizer*", "*.tiktoken", + ] + + if progress_callback: + progress_callback(5, f"Staging {sa3_name} base model in {hub_dir.name}/...") + + # Prefer cache. snapshot_download otherwise phones home on every run to + # check the model's revision — wasteful and noisy when the user just + # downloaded the weights through the Checkpoint Manager. If anything's + # missing, fall back to an online fetch. + try: + local_snap = snapshot_download( + repo_id=repo_id, + cache_dir=str(hub_dir), + token=token, + allow_patterns=allow_patterns, + local_files_only=True, + ) + if progress_callback: + progress_callback(15, "Base model ready (from cache).") + except Exception: + if progress_callback: + progress_callback(8, "Cache miss — fetching from HuggingFace…") + local_snap = snapshot_download( + repo_id=repo_id, + cache_dir=str(hub_dir), + token=token, + allow_patterns=allow_patterns, + ) + if progress_callback: + progress_callback(15, "Base model ready.") + + # Pre-stage the T5Gemma conditioner from the post-trained sibling repo. + # SA3's *-base model_config.json points the prompt conditioner at + # e.g. stabilityai/stable-audio-3-medium / t5gemma-b-b-ul2/, NOT at the + # base repo. 
Without this subtree in the cache, the training subprocess + # (HF_HUB_OFFLINE=1) crashes when AutoTokenizer.from_pretrained tries + # to phone home. + sibling = SA3_T5GEMMA_SIBLINGS.get(sa3_model_id) + if sibling: + sib_repo, sib_subfolder = sibling + sib_patterns = [f"{sib_subfolder}/*"] + if progress_callback: + progress_callback(16, f"Staging T5Gemma conditioner from {sib_repo}…") + try: + snapshot_download( + repo_id=sib_repo, + cache_dir=str(hub_dir), + token=token, + allow_patterns=sib_patterns, + local_files_only=True, + ) + if progress_callback: + progress_callback(18, "T5Gemma conditioner ready (from cache).") + except Exception: + if progress_callback: + progress_callback(17, f"T5Gemma cache miss — fetching from {sib_repo}…") + snapshot_download( + repo_id=sib_repo, + cache_dir=str(hub_dir), + token=token, + allow_patterns=sib_patterns, + ) + if progress_callback: + progress_callback(18, "T5Gemma conditioner ready.") + + return Path(local_snap) + + +# --- Subprocess command builder --------------------------------------------- + +def build_train_command( + *, + venv_python: str, + sa3_vendor_dir: Path, + sa3_model_name: str, + data_dir: Path, + encoded_dir: Optional[Path] = None, + svd_bases_path: Optional[Path] = None, + save_dir: Path, + rank: int = 16, + lora_alpha: Optional[int] = None, + adapter_type: str = "dora-rows", + dropout: float = 0.0, + lr: float = 1e-4, + steps: int = 5000, + batch_size: int = 1, + duration: float = 30.0, + base_precision: str = "bf16", + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + seed: int = 42, + checkpoint_every: int = 500, + # `--log_every` controls how often DiffusionCondTrainingWrapper calls + # self.log(). 50 is SA3's example value and gives a much cleaner chart + # than per-step logging — diffusion loss is intrinsically noisy (each + # step samples a random timestep), so per-step values bounce wildly and + # the trend is hard to read. 
Sampling every 50 steps gives ~20 points + # for a 1000-step run, which the EMA smoother turns into a legible + # descent. First point arrives after step 49 (≈15s on small, ≈50s on + # medium, dominated by first-step JIT warmup anyway). + log_every: int = 50, + num_workers: int = 2, + name: str = "fragmenta-lora", +) -> List[str]: + """Construct the train_lora.py subprocess argv.""" + cmd = [ + venv_python, + str(sa3_vendor_dir / "scripts" / "train_lora.py"), + "--model", sa3_model_name, + "--data_dir", str(data_dir), + "--save_dir", str(save_dir), + "--rank", str(int(rank)), + "--adapter_type", adapter_type, + "--dropout", str(float(dropout)), + "--lr", str(float(lr)), + "--steps", str(int(steps)), + "--batch_size", str(int(batch_size)), + "--duration", str(float(duration)), + "--base_precision", base_precision, + "--seed", str(int(seed)), + "--checkpoint_every", str(int(checkpoint_every)), + "--log_every", str(int(log_every)), + "--num_workers", str(int(num_workers)), + "--name", name, + "--logger", "csv", + # demo_every set to a very large number — Fragmenta's training + # monitor doesn't surface demo audio, no need to spend cycles. + "--demo_every", "1000000", + ] + if encoded_dir is not None: + # Phase 6 — feed pre-encoded latents directory. SA3's train_lora.py + # then uses PreEncodedDataset instead of SampleDataset and skips + # the SAME autoencoder pass per step. + cmd += ["--encoded_dir", str(encoded_dir)] + if svd_bases_path is not None and adapter_type.endswith("-xs"): + # -XS adapters factor weights against precomputed SVD bases. SA3 only + # *loads* bases from this path (it doesn't write them), so we pass it + # only when a cached .pt already exists — otherwise SA3 recomputes the + # SVD per layer on device (slower, but correct). See SA3Trainer for the + # cache path convention. 
+ cmd += ["--svd_bases_path", str(svd_bases_path)] + if lora_alpha is not None: + cmd += ["--lora_alpha", str(int(lora_alpha))] + if include: + cmd += ["--include", *include] + if exclude: + cmd += ["--exclude", *exclude] + return cmd + + +# --- Checkpoint conversion (.ckpt → .safetensors with base_model metadata) --- + +def convert_run_checkpoints_to_safetensors( + run_dir: Path, + base_model: str, + model_name: Optional[str] = None, + delete_originals: bool = True, +) -> List[Path]: + """Convert PyTorch Lightning .ckpt files in a run's checkpoints/ directory + to SA3's native .safetensors LoRA format, with `base_model` injected into + the safetensors metadata header so /api/loras can filter by it. + + Why: SA3's `train_lora.py` writes Lightning .ckpt files. The inference + LoRA picker (/api/loras) globs for *.safetensors only. Without this + conversion, every trained LoRA is functionally orphaned — saved + correctly to disk but invisible to the inference loader. + + Idempotent: skips any .ckpt whose .safetensors sibling already exists + with a non-zero size. + + Returns the list of paths to the produced .safetensors files (sorted). + """ + ckpt_dir = run_dir / "checkpoints" + if not ckpt_dir.exists(): + return [] + + # Imports deferred so this module can be imported without the SA3 vendor + # being on sys.path (e.g., during pure orchestrator construction). + from app.core.config import get_config + sa3_vendor = get_config().get_path("stable_audio_3") + pp = sys.path[:] + if str(sa3_vendor) not in pp: + sys.path.insert(0, str(sa3_vendor)) + try: + from stable_audio_3.models.lora.utils import load_lora_checkpoint + from safetensors.torch import save_file as st_save_file + finally: + # Don't permanently mutate sys.path from a helper call. 
+ if sys.path != pp: + sys.path[:] = pp + + written: List[Path] = [] + for ckpt_path in sorted(ckpt_dir.glob("*.ckpt")): + out_path = ckpt_path.with_suffix(".safetensors") + if out_path.exists() and out_path.stat().st_size > 0: + # Already converted (older artifact or a previous pass). Just + # bookkeep so the caller sees it in the return list. + written.append(out_path) + continue + try: + state_dict, lora_config = load_lora_checkpoint(ckpt_path) + except Exception: + # Corrupt or truncated ckpt — skip rather than crash the + # post-training pass. + continue + + # Top-level metadata is what /api/loras' safetensors reader inspects + # directly. We also keep the canonical `lora_config` JSON blob so + # SA3's own load_lora_checkpoint() can parse the file as-is. + metadata = { + "lora_config": json.dumps(lora_config or {}), + "base_model": base_model, + } + if model_name: + metadata["model_name"] = model_name + # Cast fp16 to keep file sizes consistent with SA3's standard format. + fp16_dict = {k: (v.half() if v.is_floating_point() else v) + for k, v in state_dict.items()} + st_save_file(fp16_dict, str(out_path), metadata=metadata) + if delete_originals: + try: + ckpt_path.unlink() + except OSError: + pass + written.append(out_path) + return sorted(written) + + +def build_train_env(sa3_vendor_dir: Path, hub_dir: Path) -> Dict[str, str]: + """Subprocess env: redirect HF cache into the app folder + silence WANDB.""" + env = os.environ.copy() + # Make `import stable_audio_3` work without pip-installing the package. + pp = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = ( + f"{sa3_vendor_dir}{os.pathsep}{pp}" if pp else str(sa3_vendor_dir) + ) + # Pin the HF cache to our app-folder hub dir; otherwise train_lora.py's + # model_cfg.resolve() would write into ~/.cache/huggingface/hub. Cover + # the legacy + transformers env names too for defense-in-depth. 
+ env["HF_HUB_CACHE"] = str(hub_dir) + env["HUGGINGFACE_HUB_CACHE"] = str(hub_dir) + env["TRANSFORMERS_CACHE"] = str(hub_dir) + env["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" + env["WANDB_DISABLED"] = "1" + # Force the training subprocess into offline mode for HF — we already + # pre-staged the base model in prestage_base_model(), so any remaining + # network call from the SA3 internals would be a noisy revision check + # against a cache we know is current. + env["HF_HUB_OFFLINE"] = "1" + env["TRANSFORMERS_OFFLINE"] = "1" + return env diff --git a/app/core/training/sa3_trainer.py b/app/core/training/sa3_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c712167cbf0bb4b14f922ac89a45ba5bcce0b79 --- /dev/null +++ b/app/core/training/sa3_trainer.py @@ -0,0 +1,839 @@ +"""SA3 LoRA training orchestrator — Phase 5. + +Public surface (matches what app/backend/app.py imports): + start_training(config) -> dict + get_training_status() -> dict + stop_training() -> dict + preview_training_plan(config) -> dict + class SA3Trainer + +Training is dispatched as a subprocess running +`vendor/stable-audio-3/scripts/train_lora.py`. 
Progress comes back through +two channels: + * stdout/stderr from the subprocess (parsed for tqdm "step X/Y" lines) + * metrics.csv that train_lora.py writes under --save_dir + +Config shape (from the frontend training form): +{ + "modelName": "my-lora", # used for run dir name + "baseModel": "sa3-medium-base", # must end in -base + "projectName": "my_first_track", # Dataset Workbench project name + "steps": 5000, + "checkpointSteps": 500, # checkpoint cadence + "batchSize": 1, + "learningRate": 1.0e-4, + "duration": 30.0, # max clip seconds per sample + "loraRank": 16, + "loraAlpha": 16, # null → defaults to rank + "loraDropout": 0.0, + "adapterType": "dora-rows", + "precision": "bf16", # bf16|fp16 + "seed": 42, + "include": null, # list[str] or null + "exclude": null +} +""" +from __future__ import annotations + +import csv +import json +import os +import re +import shlex +import signal +import subprocess +import sys +import threading +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + +from app.backend.data.projects import project_path +from app.core.config import get_config +from app.core.training.sa3_lora_runner import ( + SA3_BASE_MODELS, + build_train_command, + build_train_env, + convert_run_checkpoints_to_safetensors, + prestage_base_model, +) +from utils.logger import get_logger + +logger = get_logger("SA3Trainer") + + +# --- Defaults -------------------------------------------------------------- + +DEFAULT_STEPS = 5000 +DEFAULT_CHECKPOINT_STEPS = 500 +DEFAULT_BATCH_SIZE = 1 +DEFAULT_LR = 1e-4 +DEFAULT_DURATION = 30.0 +DEFAULT_RANK = 16 +DEFAULT_ADAPTER = "dora-rows" +DEFAULT_PRECISION = "bf16" + + +# --- SA3Trainer singleton -------------------------------------------------- + +class SA3Trainer: + def __init__(self, config: Dict[str, Any]) -> None: + self.config: Dict[str, Any] = config or {} + self.process: Optional[subprocess.Popen] = None + self.run_dir: Optional[Path] = None + self.metrics_csv: Optional[Path] = None 
+ self._monitor_thread: Optional[threading.Thread] = None + self.status: Dict[str, Any] = { + "is_training": False, + "status": "idle", + "step": 0, + "total_steps": 0, + "loss": None, + "message": "", + "started_at": None, + "ended_at": None, + "log_tail": [], # last ~50 stdout lines + "checkpoints": [], # safetensors written so far + "error": None, + } + + # --- Public API -------------------------------------------------------- + + def start(self) -> Dict[str, Any]: + # Fresh run on this trainer — clear any stop flag from a prior run. + self._stop_requested = False + # Mark training as in-flight BEFORE any blocking work. /api/start-training + # can block for tens of seconds (T5Gemma sibling fetch, base-model + # prestaging) — during that window the frontend polls + # /api/training-status and would otherwise see is_training=False from + # the __init__ default and interpret it as "training complete". + self.status.update({ + "is_training": True, + "status": "staging", + "started_at": time.time(), + "ended_at": None, + "step": 0, + "total_steps": int(self.config.get("steps") or DEFAULT_STEPS), + "loss": None, + "error": None, + "checkpoints": [], + # Surface the concrete seed (the backend rolls a random one when the + # UI requests it) so the user can reproduce a run they liked. 
+ "seed": (int(self.config["seed"]) if self.config.get("seed") is not None else None), + "message": "Preparing dataset and base model…", + }) + try: + self._maybe_wipe_run_dir() + self._resolve_paths() + self._stage_dataset() + self._stage_base_model() + cmd, env = self._build_invocation() + self._spawn(cmd, env) + logger.info( + "Training started · project=%s · base=%s · adapter=%s · " + "rank=%s · steps=%s · batch=%s · lr=%s · duration=%ss", + self.config.get("projectName"), + self.config.get("baseModel"), + self.config.get("adapterType") or DEFAULT_ADAPTER, + self.config.get("loraRank") or DEFAULT_RANK, + self.config.get("steps") or DEFAULT_STEPS, + self.config.get("batchSize") or DEFAULT_BATCH_SIZE, + self.config.get("learningRate") or DEFAULT_LR, + self.config.get("duration") or DEFAULT_DURATION, + ) + return {"success": True, "run_dir": str(self.run_dir)} + except Exception as e: + self.status["error"] = str(e) + self.status["status"] = "failed" + self.status["is_training"] = False + self.status["ended_at"] = time.time() + logger.error("Training failed to start: %s", e) + return {"error": str(e)} + + def get_status(self) -> Dict[str, Any]: + # Snapshot + add a few derived fields the frontend already reads, so + # the polling loop in App.js doesn't have to know about both names. + # SA3 is step-based; we no longer expose `current_epoch`. + # If the on-disk checkpoint count looks stale (run finished, glob + # ran with the old filter, no live files surfaced), rescan once + # lazily so the UI catches up without needing a backend restart. 
+ if not self.status.get("checkpoints") and self.run_dir is not None: + ckpt_dir = self.run_dir / "checkpoints" + if ckpt_dir.exists() and any(ckpt_dir.glob("*.ckpt")): + self._scan_checkpoints() + s = dict(self.status) + total = int(s.get("total_steps") or 0) + step = int(s.get("step") or 0) + s["current_step"] = step + s["progress"] = int(round(100 * step / total)) if total > 0 else 0 + s["checkpoints_saved"] = len(s.get("checkpoints") or []) + return s + + def stop(self) -> Dict[str, Any]: + if not self.process or self.process.poll() is not None: + return {"error": "Nothing to stop — no active training run."} + try: + # Flag the stop so the monitor thread labels the exit "stopped" + # rather than "failed" — SIGINT doesn't yield a stable rc==-2. + self._stop_requested = True + self.process.send_signal(signal.SIGINT) + try: + self.process.wait(timeout=10) + except subprocess.TimeoutExpired: + self.process.terminate() + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + self.status["status"] = "stopped" + self.status["is_training"] = False + self.status["ended_at"] = time.time() + return {"success": True} + except Exception as e: + return {"error": str(e)} + + def preview_plan(self) -> Dict[str, Any]: + try: + self._resolve_paths(create_dirs=False) + except FileNotFoundError as e: + return {"error": str(e)} + steps = int(self.config.get("steps") or DEFAULT_STEPS) + ckpt_every = int(self.config.get("checkpointSteps") or DEFAULT_CHECKPOINT_STEPS) + ckpts = max(1, steps // max(1, ckpt_every)) + proj_name = self.config.get("projectName") or self.config.get("project_name") + data_dir = str(project_path(proj_name)) if proj_name else None + return { + "model_name": self.config.get("modelName", "fragmenta-lora"), + "base_model": self.config.get("baseModel"), + "project_name": proj_name, + "data_dir": data_dir, + "save_dir": str(self.run_dir / "checkpoints") if self.run_dir else None, + "steps": steps, + "checkpoint_every": 
ckpt_every, + "expected_checkpoints": ckpts, + "rank": int(self.config.get("loraRank") or DEFAULT_RANK), + "alpha": int(self.config.get("loraAlpha") or self.config.get("loraRank") or DEFAULT_RANK), + "adapter_type": self.config.get("adapterType") or DEFAULT_ADAPTER, + "batch_size": int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE), + "lr": float(self.config.get("learningRate") or DEFAULT_LR), + "duration": float(self.config.get("duration") or DEFAULT_DURATION), + "precision": self.config.get("precision") or DEFAULT_PRECISION, + } + + # --- Internals --------------------------------------------------------- + + def _resolve_paths(self, create_dirs: bool = True) -> None: + cfg = get_config() + run_name = self._safe_name(self.config.get("modelName") or "lora-run") + self.run_dir = cfg.get_path("models_fine_tuned") / run_name + # Lightning's CSVLogger writes metrics.csv under + # `/lightning_logs/version_X/metrics.csv`. We don't know X + # upfront, so leave this unset and let _scrape_loss_history / + # _scrape_csv_loss rglob for it the first time they're called. + self.metrics_csv = None + if create_dirs: + self.run_dir.mkdir(parents=True, exist_ok=True) + (self.run_dir / "checkpoints").mkdir(exist_ok=True) + + @classmethod + def existing_run_info(cls, model_name: str) -> Optional[Dict[str, Any]]: + """Look up an existing run dir for a given LoRA name. Returns a dict + of countable artifacts if the dir exists with content, else None. + + Used by /api/start-training to refuse a same-name run unless the + caller explicitly opts in to overwrite. Counts only *.ckpt and + *.safetensors so a half-set-up dir with only a metadata file + doesn't trip the prompt. 
+ """ + import shutil # noqa: F401 # ensures shutil resolves if user calls _maybe_wipe later + cfg = get_config() + run_name = cls._safe_name(model_name or "lora-run") + run_dir = cfg.get_path("models_fine_tuned") / run_name + if not run_dir.exists(): + return None + ckpt_dir = run_dir / "checkpoints" + files = [] + if ckpt_dir.exists(): + for ext in ("*.safetensors", "*.ckpt"): + files.extend(ckpt_dir.glob(ext)) + if not files and not (run_dir / "training.log").exists(): + return None + return { + "run_dir": str(run_dir), + "run_name": run_name, + "checkpoint_count": len(files), + "has_log": (run_dir / "training.log").exists(), + } + + def _maybe_wipe_run_dir(self) -> None: + """Honor the `overwrite` flag — wipe the run dir before staging.""" + if not self.config.get("overwrite"): + return + cfg = get_config() + run_name = self._safe_name(self.config.get("modelName") or "lora-run") + run_dir = cfg.get_path("models_fine_tuned") / run_name + if run_dir.exists(): + import shutil + shutil.rmtree(run_dir) + logger.info("Cleared existing run dir before restart: %s", run_dir) + + def _stage_dataset(self) -> None: + """Resolve --data_dir from a Dataset Workbench project. + + Training reads the committed `.txt` sidecars sitting next to each + audio file inside `//`. The Workbench's + "Create Dataset" action materialised those sidecars; we don't + rewrite anything here. + """ + project_name = self.config.get("projectName") or self.config.get("project_name") + if not project_name: + raise FileNotFoundError( + "projectName is required. Pick a project in the Training " + "tab's Dataset picker before starting a run." + ) + proj_dir = project_path(project_name) + if not proj_dir.exists(): + raise FileNotFoundError(f"project not found: {project_name}") + + sidecars = list(proj_dir.glob("*.txt")) + if not sidecars: + raise RuntimeError( + f"project “{project_name}” has no committed prompts yet — " + "annotate the clips and click Create Dataset, then retry." 
+ ) + # SA3's caption_metadata_fn rejects clips whose sidecar is empty, + # so they silently drop out of the training set. Count them upfront + # so the user knows what they're actually training on (and refuse + # to start if NONE have prompts — that would just waste GPU hours). + non_empty = [p for p in sidecars if p.read_text(encoding="utf-8").strip()] + if not non_empty: + raise RuntimeError( + f"project “{project_name}” has {len(sidecars)} clip(s) but every " + "sidecar is empty — SA3 will reject all of them. Annotate at " + "least one clip and re-commit before training." + ) + blank = len(sidecars) - len(non_empty) + if blank > 0: + logger.warning( + "%d of %d clip(s) in project '%s' have empty prompts — " + "SA3 will silently drop them. Training on %d clip(s).", + blank, len(sidecars), project_name, len(non_empty), + ) + self.status["log_tail"].append( + f"Warning: {blank}/{len(sidecars)} clips have empty prompts and " + "will be dropped by SA3's data loader." + ) + self.status["log_tail"].append( + f"Dataset: project '{project_name}' · {len(non_empty)} usable clip(s) · {proj_dir}" + ) + self._data_dir = proj_dir + + # Phase 6 — opt into pre-encoded latents if a compatible .latents/ + # cache exists. SA3's `train_lora.py --encoded_dir` then skips the + # autoencoder pass per step. The cache is AE-bound (same-s vs + # same-l) so we verify the manifest matches the picked base before + # using it — otherwise we'd feed the DiT mis-shaped latents. 
+ self._encoded_dir: Optional[Path] = None + try: + from app.backend.data.pre_encoder import ( + latents_dir, latents_count, latents_match_base, + ) + ldir = latents_dir(project_name) + base_model = self.config.get("baseModel") + if ldir.exists() and latents_count(project_name) > 0: + if latents_match_base(project_name, base_model): + self._encoded_dir = ldir + self.status["log_tail"].append( + f"Using pre-encoded latents: {latents_count(project_name)} " + f"file(s) · {ldir}" + ) + logger.info( + "Pre-encoded latents detected for project '%s' (%d files) — " + "skipping SAME autoencoder per step.", + project_name, latents_count(project_name), + ) + else: + logger.warning( + "Pre-encoded latents exist for project '%s' but were " + "produced by a different autoencoder than the chosen " + "base (%s) — falling back to live encoding.", + project_name, base_model, + ) + self.status["log_tail"].append( + f"Note: project has cached latents but they're for a " + f"different autoencoder than {base_model}. Training " + "will re-encode audio per step." + ) + except Exception as exc: + logger.warning("Pre-encoded latents probe failed: %s", exc) + + def _stage_base_model(self) -> None: + cfg = get_config() + base_model = self.config.get("baseModel") + if base_model not in SA3_BASE_MODELS: + raise ValueError( + f"baseModel must be one of {list(SA3_BASE_MODELS)}. " + "Post-trained checkpoints (no -base suffix) can't be used " + "as a LoRA training base — CFG distillation has collapsed " + "the gradient signal LoRAs target." + ) + hub_dir = cfg.get_path("models_pretrained") / "sa3" / "hub" + try: + from huggingface_hub import get_token + token = get_token() + except Exception: + token = None + + def _cb(pct: int, msg: str) -> None: + self.status["message"] = msg + self.status["log_tail"].append(f"[stage] {msg}") + # Mirror to the project logger so the terminal shows what's + # happening during long blocking operations (e.g. 
first-time + # T5Gemma sibling fetch can take ~30s on medium-base). + logger.info("[stage] %s", msg) + + prestage_base_model(base_model, hub_dir, token=token, progress_callback=_cb) + self._hub_dir = hub_dir + + def _build_invocation(self): + cfg = get_config() + sa3_vendor = cfg.get_path("stable_audio_3") + sa3_name, _repo = SA3_BASE_MODELS[self.config["baseModel"]] + + # Use the Fragmenta venv's python so we share installed packages. + venv_python = sys.executable + + precision_raw = (self.config.get("precision") or DEFAULT_PRECISION).lower() + precision = "bf16" if precision_raw in ("bf16", "bfloat16", "auto", "") else "fp16" + + include = self.config.get("include") + if include and isinstance(include, str): + include = shlex.split(include) + exclude = self.config.get("exclude") + if exclude and isinstance(exclude, str): + exclude = shlex.split(exclude) + + adapter_type = self.config.get("adapterType") or DEFAULT_ADAPTER + + # -XS adapters can reuse a precomputed SVD-bases cache keyed by base + # model, skipping the per-layer SVD at startup. SA3 only loads (never + # writes) this file, so we pass it only when present; population is a + # manual/precompute step. Ensure the dir exists so it's discoverable. 
+ svd_bases_path = None + if adapter_type.endswith("-xs"): + svd_cache_dir = get_config().get_path("models_fine_tuned") / ".svd_cache" + svd_cache_dir.mkdir(parents=True, exist_ok=True) + candidate = svd_cache_dir / f"{self.config['baseModel']}.pt" + if candidate.exists(): + svd_bases_path = candidate + + cmd = build_train_command( + venv_python=venv_python, + sa3_vendor_dir=sa3_vendor, + sa3_model_name=sa3_name, + data_dir=self._data_dir, + encoded_dir=getattr(self, "_encoded_dir", None), + svd_bases_path=svd_bases_path, + save_dir=self.run_dir / "checkpoints", + rank=int(self.config.get("loraRank") or DEFAULT_RANK), + lora_alpha=self.config.get("loraAlpha"), + adapter_type=adapter_type, + dropout=float(self.config.get("loraDropout") or 0.0), + lr=float(self.config.get("learningRate") or DEFAULT_LR), + steps=int(self.config.get("steps") or DEFAULT_STEPS), + batch_size=int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE), + # Default to AND clamp at the base model's native training length + # (medium ≈380s, small ≈120s) — SA3's DiT tops out at 4096 latent + # tokens, so a longer window would exceed the model, not just cost + # VRAM. A missing duration defaults to the model max. 
+ duration=min( + float(self.config.get("duration") or (380.0 if "medium" in sa3_name else 120.0)), + 380.0 if "medium" in sa3_name else 120.0, + ), + base_precision=precision, + include=include, + exclude=exclude, + seed=(int(self.config["seed"]) if self.config.get("seed") is not None else 42), + checkpoint_every=int(self.config.get("checkpointSteps") or DEFAULT_CHECKPOINT_STEPS), + name=self.config.get("modelName") or "fragmenta-lora", + ) + env = build_train_env(sa3_vendor, self._hub_dir) + return cmd, env + + def _spawn(self, cmd: List[str], env: Dict[str, str]) -> None: + log_path = self.run_dir / "training.log" + rank = int(self.config.get("loraRank") or DEFAULT_RANK) + alpha_cfg = self.config.get("loraAlpha") + alpha = int(alpha_cfg) if alpha_cfg not in (None, "") else rank + # Stamp training_metadata.json so /api/loras can find the base_model + # if the embedded safetensors metadata is missing it (legacy paths). + (self.run_dir / "training_metadata.json").write_text(json.dumps({ + "mode": "lora", + "engine": "sa3", + "base_model": self.config.get("baseModel"), + "model_name": self.config.get("modelName"), + "started_at": time.time(), + "lora_config": { + "rank": rank, + "alpha": alpha, + "adapter_type": self.config.get("adapterType") or DEFAULT_ADAPTER, + "dropout": float(self.config.get("loraDropout") or 0.0), + }, + "steps": int(self.config.get("steps") or DEFAULT_STEPS), + "lr": float(self.config.get("learningRate") or DEFAULT_LR), + "batch_size": int(self.config.get("batchSize") or DEFAULT_BATCH_SIZE), + }, indent=2)) + + self.status.update({ + "is_training": True, + "status": "running", + "step": 0, + "total_steps": int(self.config.get("steps") or DEFAULT_STEPS), + "loss": None, + "error": None, + "started_at": time.time(), + "ended_at": None, + "checkpoints": [], + "message": "Starting training subprocess...", + }) + + self.process = subprocess.Popen( + cmd, + cwd=str(get_config().project_root), + env=env, + stdout=subprocess.PIPE, + 
stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + self._monitor_thread = threading.Thread( + target=self._monitor, + args=(log_path,), + daemon=True, + name=f"sa3-train-monitor:{self.run_dir.name}", + ) + self._monitor_thread.start() + + def _monitor(self, log_path: Path) -> None: + """Pull stdout, parse PyTorch Lightning progress, scrape loss, watch checkpoints. + + SA3 trains via PL whose default progress bar emits *per-epoch* step + counts ("Epoch 6: 50%|...| 25/50 [00:07<00:07, 3.36it/s, train/loss=0.559]"). + We derive the global step as `epoch * batches_per_epoch + step_in_epoch`, + capture `batches_per_epoch` from the first such line (it's stable across + epochs since SampleDataset returns a fixed length), and clamp the + result to the configured max_steps so the percentage doesn't go past + 100 if the final epoch overruns. + """ + epoch_pat = re.compile(r"Epoch\s+(\d+):") + in_epoch_pat = re.compile(r"\|\s*(\d+)/(\d+)\b") # tqdm's "current/total" + loss_pat = re.compile(r"train/loss=([\d.eE+\-]+)") + speed_pat = re.compile(r"([\d.]+)it/s") + last_log_flush = time.time() + last_ckpt_scan = 0.0 + last_terminal_log = 0.0 + last_logged_step = -1 + prev_ckpt_count = 0 + current_epoch = 0 + batches_per_epoch = 0 + try: + with open(log_path, "w") as logf: + if self.process and self.process.stdout: + for line in self.process.stdout: + line = line.rstrip() + logf.write(line + "\n") + if time.time() - last_log_flush > 1: + logf.flush() + last_log_flush = time.time() + self.status["log_tail"].append(line) + if len(self.status["log_tail"]) > 80: + self.status["log_tail"] = self.status["log_tail"][-50:] + + # Only parse the step counter on lines that ARE the + # training progress bar (prefixed with "Epoch N:"), + # so unrelated tqdm bars during startup (e.g. + # "Loading checkpoint shards: 9/9") don't pollute + # batches_per_epoch. 
+ m_epoch = epoch_pat.search(line) + if m_epoch: + current_epoch = int(m_epoch.group(1)) + m_step = in_epoch_pat.search(line) + if m_step: + cur_in_epoch = int(m_step.group(1)) + per_epoch = int(m_step.group(2)) + if per_epoch > 0 and batches_per_epoch == 0: + batches_per_epoch = per_epoch + if batches_per_epoch > 0: + global_step = current_epoch * batches_per_epoch + cur_in_epoch + max_steps = self.status.get("total_steps") or 0 + if max_steps > 0: + global_step = min(global_step, max_steps) + if global_step > self.status.get("step", 0): + self.status["step"] = global_step + + m_loss = loss_pat.search(line) + if m_loss: + try: + self.status["loss"] = float(m_loss.group(1)) + except ValueError: + pass + + # Live checkpoint enumeration + loss history scrape. + # Lightning writes *.ckpt every N steps; we want the + # count to climb in the UI as files appear, not only + # at end-of-run. Bucketed to ~2s so we don't pound + # the FS. The loss history scrape backfills step + # 0..49 from metrics.csv since PL's stdout postfix + # doesn't show train/loss until end-of-epoch-0. + now = time.time() + if now - last_ckpt_scan > 2.0: + last_ckpt_scan = now + self._scan_checkpoints() + self._scrape_loss_history() + cur_ckpt_count = len(self.status.get("checkpoints") or []) + if cur_ckpt_count > prev_ckpt_count: + logger.info( + "Checkpoint saved · %d total · run=%s", + cur_ckpt_count, self.run_dir.name, + ) + prev_ckpt_count = cur_ckpt_count + + # Throttled progress to the backend terminal log. + # Lightning emits step lines ~3× per second; we + # condense to one tidy summary every 5s. Omit the + # loss segment when we don't have a value yet (the + # CSV scrape runs every 2s but PL may not have + # logged anything during the very first second). 
+ cur_step = self.status.get("step") or 0 + if (cur_step > last_logged_step + and now - last_terminal_log >= 5.0): + total = self.status.get("total_steps") or 0 + loss = self.status.get("loss") + pct = round(100 * cur_step / total) if total > 0 else 0 + speed_m = speed_pat.search(line) + parts = [f"step {cur_step}/{total} ({pct}%)"] + if isinstance(loss, (int, float)): + parts.append(f"loss {loss:.4f}") + if speed_m: + parts.append(f"{speed_m.group(1)} it/s") + logger.info(" · ".join(parts)) + last_terminal_log = now + last_logged_step = cur_step + rc = self.process.wait() if self.process else 1 + except Exception as e: + self.status["error"] = str(e) + rc = -1 + + self.status["ended_at"] = time.time() + self.status["is_training"] = False + # A user-requested stop wins regardless of the exit code (SIGINT can + # surface as various negative/non-zero codes across platforms). + if getattr(self, "_stop_requested", False): + self.status["status"] = "stopped" + else: + self.status["status"] = "complete" if rc == 0 else "failed" + if self.status["status"] == "failed" and not self.status.get("error"): + self.status["error"] = f"train_lora.py exited with code {rc}" + + # Convert PyTorch Lightning .ckpt files to SA3's native .safetensors + # LoRA format — the inference loader (/api/loras) only sees + # .safetensors, so unconverted .ckpt files would be functionally + # orphaned. We also inject `base_model` into the safetensors header + # so /api/loras' metadata filter passes without a JSON fallback. + # Best-effort: failure here doesn't fail the run. 
+ if self.status["status"] in ("complete", "stopped") and self.run_dir: + try: + produced = convert_run_checkpoints_to_safetensors( + self.run_dir, + base_model=self.config.get("baseModel"), + model_name=self.config.get("modelName"), + ) + if produced: + logger.info( + "Converted %d checkpoint(s) to .safetensors · run=%s", + len(produced), self.run_dir.name, + ) + except Exception as exc: + logger.warning("Checkpoint conversion failed: %s", exc) + + # Final pass: enumerate written checkpoints + full loss history + + # latest single-value loss. + self._scan_checkpoints() + self._scrape_loss_history() + self._scrape_csv_loss() + + final_step = self.status.get("step") or 0 + final_total = self.status.get("total_steps") or 0 + final_loss = self.status.get("loss") + final_ckpts = len(self.status.get("checkpoints") or []) + loss_str = f"{final_loss:.4f}" if isinstance(final_loss, (int, float)) else "—" + if self.status["status"] == "complete": + logger.info( + "Training complete · %d/%d steps · final loss %s · %d checkpoint(s) · run=%s", + final_step, final_total, loss_str, final_ckpts, self.run_dir.name, + ) + elif self.status["status"] == "stopped": + logger.info( + "Training stopped at step %d/%d · %d checkpoint(s) · run=%s", + final_step, final_total, final_ckpts, self.run_dir.name, + ) + else: + logger.error( + "Training failed (exit %s) · %d/%d steps · error: %s · run=%s", + rc, final_step, final_total, self.status.get("error"), self.run_dir.name, + ) + + def _scrape_loss_history(self) -> None: + """Refresh self.status['loss_history'] from Lightning's metrics.csv. + + PL's tqdm postfix only surfaces `train/loss=` *after* the first + metrics flush (typically end-of-epoch-0), so step 0..49 of a fresh + run never appear in stdout. metrics.csv, on the other hand, has + per-step rows from step 0 — we just need to read it. + + Cheap: even at 10K steps a CSV scan is sub-10ms. 
Skipped silently + if the file hasn't been created yet (early in the run, before PL's + CSVLogger flushes anything). + """ + if not self.metrics_csv or not self.metrics_csv.exists(): + # CSVLogger writes under /lightning_logs/version_*/ + if self.run_dir: + for p in (self.run_dir / "checkpoints").rglob("metrics.csv"): + self.metrics_csv = p + break + if not self.metrics_csv or not self.metrics_csv.exists(): + return + try: + with open(self.metrics_csv) as f: + rows = list(csv.DictReader(f)) + except Exception: + return + points: List[Dict[str, Any]] = [] + loss_keys = ("train/loss", "loss", "train_loss") + for row in rows: + step_raw = row.get("step") + if step_raw in (None, ""): + continue + try: + step = int(step_raw) + except ValueError: + continue + for k in loss_keys: + v = row.get(k) + if v not in (None, ""): + try: + points.append({"step": step, "loss": float(v)}) + except ValueError: + pass + break + # Dedupe: csv can have multiple rows per step (different metric flush + # boundaries) — keep the last loss seen for each step. + by_step: Dict[int, float] = {} + for p in points: + by_step[p["step"]] = p["loss"] + ordered = sorted(by_step.items()) + self.status["loss_history"] = [{"step": s, "loss": l} for s, l in ordered] + # Also surface the most recent loss as the scalar so the terminal + # log and "Current Loss" field don't show "—" until end-of-epoch-0. + # PL's tqdm postfix is async; the CSV row lands a beat ahead. + if ordered: + self.status["loss"] = ordered[-1][1] + + def _scan_checkpoints(self) -> None: + """Update self.status['checkpoints'] from on-disk artifacts. + + SA3's train_lora.py uses PyTorch Lightning's ModelCheckpoint, which + writes `.ckpt` files (Lightning pickle format). The diffusion wrapper's + `on_save_checkpoint` hook strips the state_dict to LoRA-only weights + plus the embedded `lora_config`, so each .ckpt IS a LoRA checkpoint. + We also accept .safetensors for forward-compat with a future export + path or manual conversion. 
+ """ + if not self.run_dir: + return + ckpt_dir = self.run_dir / "checkpoints" + if not ckpt_dir.exists(): + return + found = [] + for ext in ("*.safetensors", "*.ckpt"): + found.extend(ckpt_dir.glob(ext)) + # Lightning writes nested lightning_logs/version_X/* — those aren't + # the user-facing artifacts; skip recursion. + project_root = get_config().project_root + self.status["checkpoints"] = sorted( + str(p.relative_to(project_root)) for p in found + ) + + def _scrape_csv_loss(self) -> None: + if not self.metrics_csv or not self.metrics_csv.exists(): + # train_lora.py writes its CSV under the lightning logger dir, + # which is `//version_*/metrics.csv`. Walk to + # find it. + ckpt_dir = self.run_dir / "checkpoints" + for p in ckpt_dir.rglob("metrics.csv"): + self.metrics_csv = p + break + if not self.metrics_csv or not self.metrics_csv.exists(): + return + try: + with open(self.metrics_csv) as f: + rows = list(csv.DictReader(f)) + for row in reversed(rows): + for k in ("train/loss", "loss", "train_loss"): + v = row.get(k) + if v not in (None, ""): + try: + self.status["loss"] = float(v) + return + except ValueError: + pass + except Exception: + pass + + @staticmethod + def _safe_name(s: str) -> str: + return re.sub(r"[^a-zA-Z0-9_-]+", "_", s).strip("_") or "lora-run" + + +# --- Module-level singleton + public functions ----------------------------- + +_active: Optional[SA3Trainer] = None +_lock = threading.Lock() + + +def get_trainer() -> Optional[SA3Trainer]: + return _active + + +def start_training(config: Dict[str, Any]) -> Dict[str, Any]: + global _active + with _lock: + if _active and _active.status.get("is_training"): + return {"error": "A training run is already in progress."} + _active = SA3Trainer(config) + return _active.start() + + +def get_training_status() -> Dict[str, Any]: + if _active is None: + return { + "is_training": False, + "status": "idle", + "message": "No training run has been started yet.", + "progress": 0, + "current_step": 0, + 
"total_steps": 0, + "checkpoints_saved": 0, + "loss": None, + } + return _active.get_status() + + +def stop_training() -> Dict[str, Any]: + if _active is None: + return {"error": "No training run to stop."} + return _active.stop() + + +def preview_training_plan(config: Dict[str, Any]) -> Dict[str, Any]: + return SA3Trainer(config).preview_plan() diff --git a/app/frontend/index.html b/app/frontend/index.html index c81debbcf3d6acfdd9f47a92eb74ef97c2c7fbec..028bcce935cb7fab848e94a1b91cbe6db7380927 100644 --- a/app/frontend/index.html +++ b/app/frontend/index.html @@ -7,15 +7,38 @@ - - - - - Fragmenta Desktop + Fragmenta diff --git a/app/frontend/logs/fragmenta_20260525.log b/app/frontend/logs/fragmenta_20260525.log new file mode 100644 index 0000000000000000000000000000000000000000..f49dc23509eb5f3041182b7869955314af79c1b5 --- /dev/null +++ b/app/frontend/logs/fragmenta_20260525.log @@ -0,0 +1,8 @@ +2026-05-25 11:21:33 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO) +2026-05-25 11:21:33 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log +2026-05-25 11:44:54 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO) +2026-05-25 11:44:54 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log +2026-05-25 13:55:04 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO) +2026-05-25 13:55:04 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log +2026-05-25 13:55:05 | INFO | FragmentaLogger | setup_logging:105 | Logging system initialized (Level: INFO) +2026-05-25 13:55:05 | INFO | FragmentaLogger | setup_logging:107 | Log file: logs/fragmenta_20260525.log diff --git a/app/frontend/package.json b/app/frontend/package.json index d187ad3f526b80dea9d32479591201b7562d9502..20d43bf7ea34103950373fa0ab1fa5d97e59763e 100644 --- a/app/frontend/package.json +++ b/app/frontend/package.json @@ 
-1,7 +1,7 @@ { "name": "fragmenta-desktop", - "version": "0.1.2", - "description": "Fragmenta Desktop", + "version": "0.2.0", + "description": "Fragmenta", "type": "module", "scripts": { "dev": "vite", diff --git a/app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf b/app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf new file mode 100644 index 0000000000000000000000000000000000000000..cd1051ba41edf62f23f5b1d5abb30f03b7606784 --- /dev/null +++ b/app/frontend/public/BricolageGrotesque-VariableFont_opsz,wdth,wght.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31b91d15aae398699fae58363dbc8ca1167faffe7d2cd62e68c716dcaa7d5fdd +size 407844 diff --git a/app/frontend/public/InterTight-VariableFont_wght.ttf b/app/frontend/public/InterTight-VariableFont_wght.ttf new file mode 100644 index 0000000000000000000000000000000000000000..7e573255693a79cd51220ff186141afa5fda5f02 --- /dev/null +++ b/app/frontend/public/InterTight-VariableFont_wght.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b8ef9ed255ebe7341aa566554c0f3e87ee10ce06d2085f07ccf66f41ef96c28 +size 580572 diff --git a/app/frontend/public/fragmenta_background.png b/app/frontend/public/fragmenta_background.png index fca8c573341554aada2fe1fd094ad52436cf2444..40d95005341ee2ce331dc2ba5dabcc35a87d7a33 100644 --- a/app/frontend/public/fragmenta_background.png +++ b/app/frontend/public/fragmenta_background.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:048aea503935f9763e76db3f5d1fcd6d561d3db9aeac415605c46527a3d6631b -size 132732 +oid sha256:f7c5c50356c595570f790621b89da04b93680b2be43803810b33a165111e8600 +size 161840 diff --git a/app/frontend/public/interface.png b/app/frontend/public/interface.png index b2b04fd1fe7e66e8f0087cfa87ea1bc2b55b9997..e8f6b221013a9f1e59c74a293ff895dc0a091be2 100644 --- a/app/frontend/public/interface.png +++ b/app/frontend/public/interface.png @@ -1,3 +1,3 @@ version 
https://git-lfs.github.com/spec/v1 -oid sha256:00d2730e2f53440597b018538ed30200928e26d1034c51ec8ef7a95fc0477e98 -size 1807207 +oid sha256:a05704da0c9b7ea812b44d94186d81fe969a3963bf11cca6c79fbadf5d33f645 +size 1590904 diff --git a/app/frontend/src/App.js b/app/frontend/src/App.js index c839bdcf9b6d040d112e2fd19ecfef7cc1a1dd77..94f3ac5ec5825d20b3df7140d9828ce6ec61b278 100644 --- a/app/frontend/src/App.js +++ b/app/frontend/src/App.js @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useMemo, useRef, Suspense, lazy } from 'react'; +import React, { useCallback, useState, useEffect, useMemo, useRef } from 'react'; import { Container, Box, @@ -21,6 +21,11 @@ import { FormControl, Select, MenuItem, + Menu, + ListItemIcon, + ListItemText, + Divider, + Snackbar, Accordion, AccordionSummary, AccordionDetails, @@ -31,8 +36,9 @@ import { useMediaQuery, ToggleButton, ToggleButtonGroup, - Tooltip, } from '@mui/material'; +import { TIPS } from './tooltips'; +import Tooltip from './components/Tooltip'; import { Plus as AddIcon, Database as UploadIcon, @@ -46,61 +52,76 @@ import { CloudDownload as CloudDownloadIcon, FolderOpen as FolderOpenIcon, Info as InfoIcon, - BookOpen as BookOpenIcon, + HelpCircle as InfoViewIcon, Moon as MoonIcon, Sun as SunIcon, Piano as PerformanceIcon, AlertCircle as AlertIcon, Wand2 as WandIcon, - Trash2 as DeleteIcon + Trash2 as DeleteIcon, + Menu as MenuIcon, + CheckCircle2 as CheckCircleIcon, } from 'lucide-react'; import api from './api'; -import HfAuthDialog from './components/HfAuthDialog'; +import AboutDialog from './components/AboutDialog'; +import { InfoViewProvider } from './components/InfoView'; import TabPanel from './components/TabPanel'; -import AudioUploadRow from './components/AudioUploadRow'; -import BulkAnnotatePanel from './components/BulkAnnotatePanel'; -import CsvImportPanel from './components/CsvImportPanel'; +import DatasetPrep from './components/DatasetPrep'; import TrainingMonitor from './components/TrainingMonitor'; -import 
ModelUnwrapButton from './components/ModelUnwrapButton'; -import CheckpointManager from './components/CheckpointManager'; +import CheckpointManagerWindow from './components/CheckpointManagerWindow'; +import LoraStack from './components/LoraStack'; +import EditPanel from './components/EditPanel'; import GeneratedFragmentsWindow from './components/GeneratedFragmentsWindow'; import WelcomePage from './components/WelcomePage'; -import { clearPerformanceSession } from './components/usePerformanceSession'; import { formatDuration } from './utils/format'; import theme, { appStyles, lightTheme } from './theme'; -const PerformancePanel = lazy(() => import('./components/PerformancePanel')); +import PerformancePanel from './components/PerformancePanel'; const COLOR_MODE_STORAGE_KEY = 'fragmenta-color-mode'; -const HIDE_WELCOME_PAGE_KEY = 'fragmenta-hide-welcome'; -const PERFORMANCE_ENABLED_KEY = 'fragmenta-performance-enabled'; +const HIDE_WELCOME_PAGE_KEY = 'fragmenta-hide-welcome-v2'; +const INFO_VIEW_STORAGE_KEY = 'fragmenta-info-view'; + +// Persisted across reload so the user lands back where they were. +// Tabs are: 0=Dataset, 1=Training, 2=Generation, 3=Performance. +const TAB_STORAGE_KEY = 'fragmenta.lastTab'; +const TAB_COUNT = 4; +const readStoredTab = () => { + try { + const raw = window.localStorage.getItem(TAB_STORAGE_KEY); + const n = Number(raw); + return Number.isFinite(n) && n >= 0 && n < TAB_COUNT ? n : 0; + } catch { + return 0; + } +}; function App() { - const [tabValue, setTabValue] = useState(0); - const [uploadRows, setUploadRows] = useState([ - { file: null, prompt: '', audioUrl: '' } - ]); + const [tabValue, setTabValue] = useState(readStoredTab); + // Lags behind tabValue by ~fadeDuration so content swap happens + // while the panel is invisible (cross-fade between pages). + const [displayedTab, setDisplayedTab] = useState(readStoredTab); + const TAB_FADE_MS = 180; + + // Persist the active tab so a reload returns the user to it. 
+ useEffect(() => { + try { window.localStorage.setItem(TAB_STORAGE_KEY, String(tabValue)); } catch {} + }, [tabValue]); + // Header sticky chrome only kicks in once the page has scrolled. + const [isScrolled, setIsScrolled] = useState(false); + // Measure the header's actual rendered height so the fixed nav + // rail can be pinned at exactly the first card's top edge. + const headerRef = useRef(null); + const [navTopPx, setNavTopPx] = useState(94); const [processingStatus, setProcessingStatus] = useState(''); const [isProcessing, setIsProcessing] = useState(false); - const [processedCount, setProcessedCount] = useState(0); - const [chunksPreview, setChunksPreview] = useState([]); const [showWelcomePage, setShowWelcomePage] = useState( () => window.localStorage.getItem(HIDE_WELCOME_PAGE_KEY) !== 'true' ); - const [performanceEnabled, setPerformanceEnabled] = useState( - () => window.localStorage.getItem(PERFORMANCE_ENABLED_KEY) === 'true' - ); - const togglePerformance = () => { - setPerformanceEnabled((prev) => { - const next = !prev; - window.localStorage.setItem(PERFORMANCE_ENABLED_KEY, next ? 'true' : 'false'); - if (!next && tabValue === 3) setTabValue(0); - if (next) setTabValue(3); - return next; - }); - }; - const [authDialogOpen, setAuthDialogOpen] = useState(false); + const [checkpointMgrOpen, setCheckpointMgrOpen] = useState(false); + const [generationModelSelectOpen, setGenerationModelSelectOpen] = useState(false); + const [trainingBaseModelSelectOpen, setTrainingBaseModelSelectOpen] = useState(false); const [showInfoDialog, setShowInfoDialog] = useState(false); const [isOpeningDocumentation, setIsOpeningDocumentation] = useState(false); const [colorMode, setColorMode] = useState(() => { @@ -116,22 +137,48 @@ function App() { return 'dark'; }); + // Ableton-style Info View: when on, control help text shows in a fixed + // bottom bar (fed by the shared ) instead of popping over each + // control. Off by default; preference persisted. 
+ const [infoViewEnabled, setInfoViewEnabled] = useState(() => { + if (typeof window === 'undefined') return false; + // Off by default — only on if the user explicitly turned it on. + return window.localStorage.getItem(INFO_VIEW_STORAGE_KEY) === 'on'; + }); + const toggleInfoView = useCallback(() => { + setInfoViewEnabled((prev) => { + const next = !prev; + try { window.localStorage.setItem(INFO_VIEW_STORAGE_KEY, next ? 'on' : 'off'); } catch (_) {} + return next; + }); + }, []); + const [trainingConfig, setTrainingConfig] = useState({ - mode: 'lora', - epochs: 30, - checkpointSteps: 500, + steps: 1000, // SA3 quick-start + checkpointSteps: 250, checkpointAuto: true, - batchSize: 4, + batchSize: 1, // SA3 examples all use 1 learningRate: 1e-4, - modelName: 'my_fine_tuned_model', - baseModel: 'stable-audio-open-1.0', - saveWrappedCheckpoint: false, - precision: 'auto', + modelName: 'my_lora', + baseModel: 'sa3-small-music-base', // only *-base checkpoints are valid targets + precision: 'bf16', + // Training window defaults to the base model's native length (small + // ≈120s; medium ≈380s — set on base-model change). Default base is + // small-music-base → 120s. + duration: 120.0, loraRank: 16, loraAlpha: 16, loraDropout: 0, - loraMultiplier: 1.0, + adapterType: 'dora-rows', // SA3 upstream default + seedRandom: true, // fresh random seed each run (recorded server-side) + seed: 42, // used only when seedRandom is off + + // SA3 docs' "common case" layer filter — prevents conditioner-hijacking + // on small datasets. Stored as space-separated strings (the format SA3's + // CLI consumes) so the Advanced TextFields can edit them directly. 
+ include: 'transformer.layers', + exclude: 'seconds_total to_local_embed', }); const [checkpointPreview, setCheckpointPreview] = useState(null); const [suggestionDialog, setSuggestionDialog] = useState({ open: false, data: null, loading: false }); @@ -143,14 +190,18 @@ function App() { const [trainingStartTime, setTrainingStartTime] = useState(null); const [trainingError, setTrainingError] = useState(null); + // Generation panel top-level mode: 'create' (text → audio) or + // 'edit' (audio → audio: style transfer, inpaint, extend). + const [generationMode, setGenerationMode] = useState('create'); const [generationPrompt, setGenerationPrompt] = useState(''); + const [negativePrompt, setNegativePrompt] = useState(''); + const [loraStack, setLoraStack] = useState([]); // [{path, strength}] const [generationDuration, setGenerationDuration] = useState(10); const [generatedAudio, setGeneratedAudio] = useState(null); const [generatedAudioBlob, setGeneratedAudioBlob] = useState(null); const [isGenerating, setIsGenerating] = useState(false); const [generationProgress, setGenerationProgress] = useState(0); const [selectedModel, setSelectedModel] = useState(''); - const [selectedUnwrappedModel, setSelectedUnwrappedModel] = useState(''); const [generatedFragments, setGeneratedFragments] = useState([]); const [currentFilename, setCurrentFilename] = useState(''); const [cfgScale, setCfgScale] = useState(7.0); @@ -200,48 +251,199 @@ function App() { } }; - const downloadFragment = (fragment) => { - const link = document.createElement('a'); - link.href = fragment.audioUrl; - link.download = fragment.filename; - document.body.appendChild(link); - link.click(); - document.body.removeChild(link); + const deleteFragment = async (fragment) => { + if (!fragment?.filename) return; + try { + await api.delete(`/api/fragments/${encodeURIComponent(fragment.filename)}`); + setGeneratedFragments(prev => prev.filter(f => f.id !== fragment.id)); + // Best-effort revoke of blob URLs created 
during this session so + // we don't leak object URLs after delete. + if (fragment.audioUrl?.startsWith('blob:')) { + try { URL.revokeObjectURL(fragment.audioUrl); } catch { /* ignore */ } + } + } catch (err) { + console.error('Delete fragment failed:', err); + } + }; + + const clearAllFragments = async () => { + try { + await api.delete('/api/fragments'); + // Revoke any in-session blob URLs before clearing state. + generatedFragments.forEach(f => { + if (f.audioUrl?.startsWith('blob:')) { + try { URL.revokeObjectURL(f.audioUrl); } catch { /* ignore */ } + } + }); + setGeneratedFragments([]); + } catch (err) { + console.error('Clear all fragments failed:', err); + } }; - const [systemStatus, setSystemStatus] = useState(null); - const [isStatusLoading, setIsStatusLoading] = useState(false); const [availableModels, setAvailableModels] = useState([]); const [gpuMemoryStatus, setGpuMemoryStatus] = useState(null); const [isUpdatingGpuMemory, setIsUpdatingGpuMemory] = useState(false); const [baseModels, setBaseModels] = useState([ - { - name: 'stable-audio-open-small', - displayName: 'Stable Audio Open Small (Recommended)', - description: 'Faster - Lower memory usage', - type: 'base', - path: '/models/pretrained/stable-audio-open-small-model.safetensors', - configPath: '/models/config/model_config_small.json', - downloaded: false - }, - { - name: 'stable-audio-open-1.0', - displayName: 'Stable Audio Open 1.0', - description: 'Higher quality - Requires more memory', - type: 'base', - path: '/models/pretrained/stable-audio-open-model.safetensors', - configPath: '/models/config/model_config.json', - downloaded: false - } + { name: 'sa3-small-music', displayName: 'Small - Music', description: 'CPU/GPU · ≤ 120s', kind: 'post-trained', downloaded: false }, + { name: 'sa3-small-sfx', displayName: 'Small - SFX', description: 'CPU/GPU · ≤ 120s', kind: 'post-trained', downloaded: false }, + { name: 'sa3-medium', displayName: 'Medium', description: 'CUDA + Flash-Attn · ≤ 380s', 
kind: 'post-trained', downloaded: false }, + { name: 'sa3-small-music-base', displayName: 'Small - Music (Base)', description: 'CPU/GPU · ≤ 120s', kind: 'base', downloaded: false }, + { name: 'sa3-small-sfx-base', displayName: 'Small - SFX (Base)', description: 'CPU/GPU · ≤ 120s', kind: 'base', downloaded: false }, + { name: 'sa3-medium-base', displayName: 'Medium (Base)', description: 'CUDA + Flash-Attn · ≤ 380s', kind: 'base', downloaded: false }, ]); - const [showStartFreshDialog, setShowStartFreshDialog] = useState(false); - const [isStartingFresh, setIsStartingFresh] = useState(false); - const [uploadKey, setUploadKey] = useState(0); - // Bumping this key forces the performance panel to remount, which is how - // we flush its in-memory session state on Fresh Start (clearing localStorage - // alone wouldn't reset the mounted panel's useState mirrors). - const [performanceResetKey, setPerformanceResetKey] = useState(0); + // Dataset Workbench projects available as training inputs. Refreshed on + // mount and every time the Training tab becomes visible (in case the user + // just committed a project on the Dataset tab). + const [trainingProjects, setTrainingProjects] = useState([]); + const [trainingProject, setTrainingProject] = useState(() => { + try { return window.localStorage.getItem('fragmenta.training.lastProject') || ''; } + catch { return ''; } + }); + // Phase 6 — pre-encode state for the selected training project. 
+ // { latents_count, latents_present, job: {state, current, total, ...} | null } + const [trainingPreEncode, setTrainingPreEncode] = useState({ + latents_count: 0, + latents_present: false, + job: null, + }); + const preEncodePollRef = useRef(null); + const refreshTrainingProjects = useCallback(async () => { + try { + const { data } = await api.get('/api/projects'); + setTrainingProjects(data.projects || []); + } catch { /* non-fatal */ } + }, []); + useEffect(() => { refreshTrainingProjects(); }, [refreshTrainingProjects]); + useEffect(() => { + if (tabValue === 1) refreshTrainingProjects(); + }, [tabValue, refreshTrainingProjects]); + + // Hydrate the Generated Fragments panel from disk on mount. Each + // /api/generate writes a sidecar JSON next to the WAV; this restores + // the latest 100 across page reloads. Server returns newest-first; we + // reverse so the in-memory order stays oldest-first (matches the + // append-at-end pattern used elsewhere — GeneratedFragmentsWindow + // reverses for display). + useEffect(() => { + let cancelled = false; + (async () => { + try { + const r = await api.get('/api/fragments?limit=100'); + if (cancelled) return; + const items = (r.data?.fragments || []) + // Performance-tab master recordings live in the same output + // folder but aren't generations — keep them out of here. + .filter((f) => f.source !== 'performance') + // Cap the browser at the 50 most recent generations. + .slice(0, 50) + .map((f, i) => ({ + id: f.created_at ? Math.round(f.created_at * 1000) + i : Date.now() - i, + prompt: f.prompt || '', + duration: f.duration, + cfgScale: f.cfg_scale, + steps: f.steps, + seed: f.seed, + modelId: f.model_id || '', + batchIndex: 1, + batchTotal: f.batch_size || 1, + audioUrl: `/api/fragments/${encodeURIComponent(f.filename)}`, + audioBlob: null, + filename: f.filename, + timestamp: f.created_at + ? new Date(f.created_at * 1000).toLocaleString() + : '', + createdAt: f.created_at ? 
f.created_at * 1000 : null, + editMode: f.edit_mode || null, + })); + // Server sends newest-first; reverse to keep the in-memory + // append-at-end convention. + items.reverse(); + setGeneratedFragments(items); + } catch (err) { + // Non-fatal — empty list is fine. + console.warn('Failed to hydrate fragments from server:', err); + } + })(); + return () => { cancelled = true; }; + }, []); + useEffect(() => { + try { + if (trainingProject) window.localStorage.setItem('fragmenta.training.lastProject', trainingProject); + } catch {} + }, [trainingProject]); + // If the persisted project no longer exists, clear it so the picker shows "(none)". + useEffect(() => { + if (trainingProject && trainingProjects.length > 0 && !trainingProjects.some(p => p.name === trainingProject)) { + setTrainingProject(''); + } + }, [trainingProject, trainingProjects]); + + // Phase 6 — refresh pre-encode state when the user changes which project + // they're training on, and keep polling while a job is in flight. + const refreshTrainingPreEncode = useCallback(async (name) => { + if (!name) { + setTrainingPreEncode({ latents_count: 0, latents_present: false, job: null }); + return; + } + try { + const [proj, status] = await Promise.all([ + api.get(`/api/projects/${encodeURIComponent(name)}`), + api.get(`/api/projects/${encodeURIComponent(name)}/pre-encode/status`), + ]); + setTrainingPreEncode({ + latents_count: proj.data.latents_count ?? 0, + latents_present: !!proj.data.latents_present, + job: status.data.job ?? null, + }); + } catch { /* non-fatal */ } + }, []); + + useEffect(() => { + refreshTrainingPreEncode(trainingProject); + }, [trainingProject, refreshTrainingPreEncode]); + + // Poll while a job is queued/running. Clean up on project change or unmount. 
+ useEffect(() => { + const job = trainingPreEncode.job; + const inFlight = job && (job.state === 'queued' || job.state === 'running'); + if (!inFlight || !trainingProject) { + if (preEncodePollRef.current) { + window.clearTimeout(preEncodePollRef.current); + preEncodePollRef.current = null; + } + return; + } + preEncodePollRef.current = window.setTimeout(() => { + refreshTrainingPreEncode(trainingProject); + }, 750); + return () => { + if (preEncodePollRef.current) { + window.clearTimeout(preEncodePollRef.current); + preEncodePollRef.current = null; + } + }; + }, [trainingProject, trainingPreEncode.job, refreshTrainingPreEncode]); + + const startTrainingPreEncode = useCallback(async () => { + if (!trainingProject) return; + try { + await api.post(`/api/projects/${encodeURIComponent(trainingProject)}/pre-encode`); + refreshTrainingPreEncode(trainingProject); + } catch (e) { + console.error('Failed to start pre-encode', e); + } + }, [trainingProject, refreshTrainingPreEncode]); + + const cancelTrainingPreEncode = useCallback(async () => { + if (!trainingProject) return; + try { + await api.post(`/api/projects/${encodeURIComponent(trainingProject)}/pre-encode/cancel`); + refreshTrainingPreEncode(trainingProject); + } catch (e) { /* non-fatal */ } + }, [trainingProject, refreshTrainingPreEncode]); + const [isFreeingGPU, setIsFreeingGPU] = useState(false); const [showFreeGPUDialog, setShowFreeGPUDialog] = useState(false); const [modelWarning, setModelWarning] = useState({ @@ -255,11 +457,18 @@ function App() { [colorMode] ); const isCompactLayout = useMediaQuery(appTheme.breakpoints.down('md')); - const isIconOnlySidebar = useMediaQuery(appTheme.breakpoints.between('md', 'lg')); - - useEffect(() => { - setSelectedUnwrappedModel(''); - }, [selectedModel]); + // Vertical icon-only mode: between the compact (horizontal) threshold + // and a custom upper bound. The MUI `lg` breakpoint at 1200 was too + // eager — labels collapsed while there was still plenty of room. 
+ const isIconOnlySidebar = useMediaQuery('(min-width: 900px) and (max-width: 1099.95px)'); + // Mobile/very-small width — the nav rail goes horizontal (compact) + // AND drops the text labels, matching the icon-only treatment used + // on mid-size vertical. + const isMobileLayout = useMediaQuery(appTheme.breakpoints.down('sm')); + // Dock collapses to a hamburger at the same threshold where the nav + // rail flips horizontal — keeps the chrome transition unified. + const isDockCollapsed = isCompactLayout; + const [dockMenuAnchor, setDockMenuAnchor] = useState(null); useEffect(() => { console.log('Model changed:', selectedModel); @@ -268,35 +477,27 @@ function App() { setSelectedLora(''); }, [selectedModel]); - // Resolve the base model identity for the currently-selected entry. Works - // for both base-model selections (selectedModel === 'stable-audio-open-...') - // and fine-tunes (where the API returns base_model from training_metadata). + // Resolve the base SA3 model identity for the currently-selected entry. + // For a direct base pick it's selectedModel itself; for a fine-tune we + // read base_model from the training_metadata exposed by /api/models. const resolvedBaseModel = (() => { if (!selectedModel) return null; - if (selectedModel === 'stable-audio-open-small' || selectedModel === 'stable-audio-open-1.0') { - return selectedModel; - } + if (selectedModel.startsWith('sa3-')) return selectedModel; const model = availableModels.find(m => m.name === selectedModel); - if (model?.base_model) return model.base_model; - // Legacy fine-tunes without base_model metadata: fall back to the - // unwrapped-file size heuristic. - if (model && selectedUnwrappedModel) { - const u = model.unwrapped_models?.find(x => x.path === selectedUnwrappedModel); - if (u) return (u.size_mb || 0) < 2000 ? 
'stable-audio-open-small' : 'stable-audio-open-1.0'; - } - return null; + return model?.base_model || null; })(); - // True only for the original distilled small base, NOT for fine-tunes of - // it. Fine-tuning destroys the CFG distillation, so the 8-step / CFG-1.0 - // lock no longer applies — the user controls steps and CFG normally. - const isDistilledBase = selectedModel === 'stable-audio-open-small'; + // All three user-visible SA3 models are post-trained (distilled to 8 + // steps, CFG baked at 1.0). The backend ignores cfg_scale on these and + // defaults steps to 8 — the UI just mirrors that so the controls don't + // show misleading values. + const isDistilledBase = !!selectedModel && selectedModel.startsWith('sa3-') && !selectedModel.endsWith('-base'); const getMaxDuration = () => { - if (!selectedModel) return 10; - if (resolvedBaseModel === 'stable-audio-open-small') return 11; - if (resolvedBaseModel === 'stable-audio-open-1.0') return 47; - return 10; + if (!selectedModel) return 30; + if (resolvedBaseModel === 'sa3-medium' || resolvedBaseModel === 'sa3-medium-base') return 380; + if (resolvedBaseModel && resolvedBaseModel.startsWith('sa3-')) return 120; + return 30; }; useEffect(() => { @@ -304,49 +505,65 @@ function App() { if (generationDuration > maxDuration) { setGenerationDuration(maxDuration); } - // The distilled small model is hard-coded to 8 steps + pingpong sampler - // at the backend regardless of slider value; snap the slider so the UI - // reflects what will actually run. When switching BACK to a non- - // distilled model, restore a sensible default — otherwise the slider - // is stuck at 8 from the prior selection and the big model runs 8 - // steps (which produces noise). + // SA3 post-trained models run at 8 steps with CFG=1.0; base variants + // want ~50 steps with CFG~7. Snap the slider so the UI reflects what + // will actually run. 
if (isDistilledBase && steps !== 8) { setSteps(8); } else if (!isDistilledBase && steps < 50) { - setSteps(250); + setSteps(50); } - }, [selectedModel, selectedUnwrappedModel, isDistilledBase]); + }, [selectedModel, isDistilledBase]); const handleTabChange = (event, newValue) => { + if (newValue === tabValue) return; setTabValue(newValue); }; - const addUploadRow = () => { - setUploadRows([...uploadRows, { file: null, prompt: '', audioUrl: '' }]); - }; + // Sync displayedTab to tabValue with a fade-out delay so content + // swap happens while the wrapper opacity is at 0. Works for any + // code path that updates tabValue (Tabs click, model-warning + // auto-jump, etc). + useEffect(() => { + if (tabValue === displayedTab) return; + const t = window.setTimeout(() => setDisplayedTab(tabValue), TAB_FADE_MS); + return () => window.clearTimeout(t); + }, [tabValue, displayedTab]); - const removeUploadRow = (index) => { - const newRows = uploadRows.filter((_, i) => i !== index); - setUploadRows(newRows); - }; + useEffect(() => { + const onScroll = () => setIsScrolled(window.scrollY > 8); + onScroll(); + window.addEventListener('scroll', onScroll, { passive: true }); + return () => window.removeEventListener('scroll', onScroll); + }, []); - const updateUploadRow = (index, data) => { - const newRows = [...uploadRows]; - newRows[index] = data; - setUploadRows(newRows); - }; + // Re-measure header bottom edge on mount, resize, and content + // reflows. Nav rail's `top` = headerBottom + headerRow.mb + + // tabPanelStyles.pt so it lines up with the first card. + useEffect(() => { + if (!headerRef.current) return undefined; + const el = headerRef.current; + const measure = () => { + // Header is sticky at top: 0, so rect.bottom is already the + // viewport y of the header's bottom edge. + const rect = el.getBoundingClientRect(); + const w = window.innerWidth; + const offset = w >= 900 ? 18 : w >= 600 ? 
14 : 12; + setNavTopPx(rect.bottom + offset); + }; + measure(); + // Re-measure only when the header's actual size changes (e.g. + // GPU card transitions detected ↔ not on first load) or the + // window resizes — never on scroll, never on poll churn. + const ro = new ResizeObserver(measure); + ro.observe(el); + window.addEventListener('resize', measure); + return () => { + ro.disconnect(); + window.removeEventListener('resize', measure); + }; + }, []); - const fetchSystemStatus = async () => { - setIsStatusLoading(true); - try { - const response = await api.get('/api/status'); - setSystemStatus(response.data); - } catch (error) { - console.error('Error fetching system status:', error); - } finally { - setIsStatusLoading(false); - } - }; const fetchAvailableModels = async () => { try { @@ -369,17 +586,18 @@ function App() { const fetchBaseModelsStatus = async () => { try { - const response = await api.get('/api/base-models/status'); - const baseModelsStatus = response.data.base_models; - + const response = await api.get('/api/checkpoints'); + const byId = Object.fromEntries( + (response.data.checkpoints || []).map(c => [c.id, c]) + ); setBaseModels(prevModels => prevModels.map(model => ({ ...model, - downloaded: baseModelsStatus[model.name]?.downloaded || false + downloaded: byId[model.name]?.downloaded || false, })) ); } catch (error) { - console.error('Error fetching base models status:', error); + console.error('Error fetching checkpoint status:', error); } }; @@ -402,7 +620,6 @@ function App() { } else { if (selectedModel === name) { setSelectedModel(''); - setSelectedUnwrappedModel(''); } } refreshAllModels(); @@ -435,7 +652,6 @@ function App() { }; useEffect(() => { - fetchSystemStatus(); fetchAvailableModels(); fetchBaseModelsStatus(); fetchAvailableLoras(); @@ -467,7 +683,7 @@ function App() { }, 300); return () => clearTimeout(handle); }, [ - trainingConfig.epochs, + trainingConfig.steps, trainingConfig.batchSize, trainingConfig.checkpointSteps, 
trainingConfig.checkpointAuto, @@ -496,10 +712,10 @@ function App() { const newEntry = { timestamp: Date.now(), progress: currentStatus.progress || 0, - current_epoch: currentStatus.current_epoch || 0, - current_step: currentStatus.current_step || 0, + current_step: currentStatus.current_step ?? currentStatus.step ?? 0, loss: currentStatus.loss, - checkpoints_saved: currentStatus.checkpoints_saved || 0, + checkpoints_saved: currentStatus.checkpoints_saved + ?? (currentStatus.checkpoints?.length || 0), is_training: currentStatus.is_training, message: currentStatus.error || (currentStatus.progress > 0 ? `Progress: ${currentStatus.progress}%` : 'Starting...') @@ -508,7 +724,6 @@ function App() { const lastEntry = prev[prev.length - 1]; if (!lastEntry || lastEntry.progress !== newEntry.progress || - lastEntry.current_epoch !== newEntry.current_epoch || lastEntry.current_step !== newEntry.current_step || lastEntry.loss !== newEntry.loss || lastEntry.checkpoints_saved !== newEntry.checkpoints_saved || @@ -530,11 +745,9 @@ function App() { setTrainingProgress(100); } setTimeout(() => { - fetchSystemStatus(); - // refreshAllModels picks up the new LoRA too if - // this was a LoRA run — without it, the LoRA - // picker stays empty until the user manually hits - // refresh. + // refreshAllModels picks up the new LoRA — without it, + // the LoRA picker stays empty until the user manually + // hits refresh. 
refreshAllModels(); }, 0); } @@ -552,41 +765,23 @@ function App() { }; }, [isTraining]); - const processFiles = async () => { - setIsProcessing(true); - setProcessingStatus('Processing files...'); - - try { - const formData = new FormData(); - - uploadRows.forEach((row, index) => { - if (row.file && row.prompt) { - formData.append(`file_${index}`, row.file); - formData.append(`prompt_${index}`, row.prompt); - } - }); - - const response = await api.post('/api/process-files', formData); - - setProcessingStatus(response.data.message); - setProcessedCount(response.data.processed_count); - setChunksPreview(response.data.chunks_preview || []); - - setUploadRows([{ file: null, prompt: '', audioUrl: '' }]); - - fetchSystemStatus(); - } catch (error) { - setProcessingStatus(`Error: ${error.response?.data?.error || error.message}`); - } finally { - setIsProcessing(false); - } - }; const fetchHyperparamSuggestion = async () => { setShowRationale(false); + if (!trainingProject) { + setSuggestionDialog({ + open: true, + data: { ok: false, error: "Pick a dataset project first." }, + loading: false, + }); + return; + } setSuggestionDialog({ open: true, data: null, loading: true }); try { - const resp = await api.get(`/api/training/suggest-hyperparams?mode=${trainingConfig.mode}`); + const url = `/api/training/suggest-hyperparams` + + `?project_name=${encodeURIComponent(trainingProject)}` + + `&base_model=${encodeURIComponent(trainingConfig.baseModel || '')}`; + const resp = await api.get(url); setSuggestionDialog({ open: true, data: resp.data, loading: false }); } catch (e) { setSuggestionDialog({ @@ -600,11 +795,25 @@ function App() { const applyHyperparamSuggestion = () => { const cfg = suggestionDialog.data?.config; if (!cfg) return; - setTrainingConfig({ ...trainingConfig, ...cfg }); + // Suggester returns include/exclude as arrays; the form edits them as + // space-separated strings. Backend's sa3_trainer accepts either. 
+ const normalized = { + ...cfg, + include: Array.isArray(cfg.include) ? cfg.include.join(' ') : (cfg.include || ''), + exclude: Array.isArray(cfg.exclude) ? cfg.exclude.join(' ') : (cfg.exclude || ''), + }; + setTrainingConfig({ ...trainingConfig, ...normalized }); setSuggestionDialog({ open: false, data: null, loading: false }); }; - const startTraining = async () => { + // Confirm dialog for the same-name LoRA collision case. + const [overwriteConfirm, setOverwriteConfirm] = useState(null); + + const startTraining = async (overwrite = false) => { + // Defensive: an `onClick={startTraining}` would pass React's + // SyntheticEvent in as the first arg; coerce so it can never + // leak into the JSON payload as a circular DOM reference. + overwrite = overwrite === true; const selectedBaseModel = baseModels.find(m => m.name === trainingConfig.baseModel); if (!selectedBaseModel) { showModelWarning({ @@ -624,19 +833,33 @@ function App() { return; } + if (!trainingProject) { + showModelWarning({ + title: 'Dataset Required', + message: 'Pick a dataset project before starting training. ' + + 'Create one in the Dataset tab if you don\'t have any yet.', + canOpenModels: false, + }); + return; + } + setIsTraining(true); setTrainingProgress(0); setTrainingError(null); setTrainingStartTime(Date.now()); setTrainingHistory([]); - await api.post('/api/bulk-annotate/unload-clap').catch(() => {}); + await api.post('/api/clap/unload').catch(() => {}); try { - const { checkpointAuto, ...rest } = trainingConfig; + const { checkpointAuto, seedRandom, ...rest } = trainingConfig; const payload = { ...rest, + projectName: trainingProject, checkpointSteps: checkpointAuto ? null : trainingConfig.checkpointSteps, + // null = let the backend roll a fresh seed and record it. + seed: seedRandom ? 
null : trainingConfig.seed, + overwrite: overwrite, }; const response = await api.post('/api/start-training', payload); setProcessingStatus('Training started successfully!'); @@ -644,6 +867,19 @@ function App() { const errorData = error.response?.data; const errorMessage = errorData?.error || error.message; + // Same-name collision (HTTP 409) — surface a confirm dialog so the + // user can choose to overwrite the previous run rather than + // co-mingling its checkpoints. + if (error.response?.status === 409 && errorData?.code === 'run_exists') { + setIsTraining(false); + setOverwriteConfirm({ + runName: errorData.run_name, + checkpointCount: errorData.checkpoint_count, + message: errorData.message, + }); + return; + } + if (errorData?.checkpoint_warning) { setTrainingError(errorMessage); setProcessingStatus(errorMessage); @@ -677,36 +913,56 @@ function App() { const baseRequestData = { prompt: generationPrompt, duration: generationDuration, - cfg_scale: cfgScale, - steps: steps + steps: steps, }; + const negTrim = negativePrompt.trim(); + if (negTrim) { + baseRequestData.negative_prompt = negTrim; + } + + // LoRA stack — LoraStack is the single source of truth for the + // Generation panel. Empty slots (path === '') are filtered out + // so an unused slot doesn't break the request. + const activeLoras = (loraStack || []).filter(s => s.path); + if (activeLoras.length) { + // Bypassed slots stay in the stack (load order preserved) but + // contribute nothing — send strength 0. + baseRequestData.loras = activeLoras.map(s => ({ + path: s.path, + strength: s.bypassed ? 0 : s.strength, + })); + } + // SA3 post-trained models bake CFG at 1.0 — only the *-base variants + // honour cfg_scale. Sending it on a post-trained model is harmless + // (backend forces 1.0), but we only attach it for base variants so + // the UI matches what the backend will use. 
+ if (!isDistilledBase) { + baseRequestData.cfg_scale = cfgScale; + } const baseModel = baseModels.find(m => m.name === selectedModel); if (baseModel) { if (!baseModel.downloaded) { showModelWarning({ - title: 'Base Model Not Downloaded', - message: `The selected base model "${baseModel.displayName}" is not downloaded.`, + title: 'Model Not Downloaded', + message: `"${baseModel.displayName}" hasn't been downloaded yet. Open the Checkpoint Manager to fetch it.`, canOpenModels: true, }); return; } - - baseRequestData.model_name = selectedModel; - } else if (selectedUnwrappedModel) { - baseRequestData.unwrapped_model_path = selectedUnwrappedModel; + baseRequestData.model_id = selectedModel; + } else if (selectedModel && selectedModel.startsWith('sa3-')) { + // Hidden SA3 variant (base or AE) reachable via /api/checkpoints?include=all. + baseRequestData.model_id = selectedModel; } else { - setProcessingStatus('Please select a model'); + setProcessingStatus( + selectedModel + ? `'${selectedModel}' is an SA2 fine-tune; SA3 cannot load it. Pick a Stable Audio 3 model.` + : 'Please select a model' + ); return; } - // LoRA only meaningful on top of a base model (the LoRA was trained - // against that exact base — applying it to a full-FT model is undefined). 
- if (selectedLora && baseModel) { - baseRequestData.lora_path = selectedLora; - baseRequestData.lora_multiplier = loraMultiplier; - } - const parsedSeed = parseInt(seedValue, 10); if (!randomSeed && (Number.isNaN(parsedSeed) || parsedSeed < 0)) { setProcessingStatus('Please enter a non-negative integer seed, or enable Random Seed'); @@ -715,7 +971,7 @@ function App() { const totalRuns = Math.max(1, Math.min(10, batchCount)); - await api.post('/api/bulk-annotate/unload-clap').catch(() => {}); + await api.post('/api/clap/unload').catch(() => {}); stopGenerationRef.current = false; const abortController = new AbortController(); @@ -724,14 +980,26 @@ function App() { setIsGenerating(true); setGenerationProgress(0); + // Real progress polling — the backend exposes /api/generation-progress + // which reflects the SA3 sampler's per-ODE-step callback. We poll at + // ~250ms; sampling is N steps total (8 for distilled, ~50 for base) + // so each step takes hundreds of ms to several seconds — finer polling + // is unnecessary. let progressInterval; const startProgressTicker = () => { - progressInterval = setInterval(() => { - setGenerationProgress(prev => { - if (prev >= 90) return prev; - return prev + Math.random() * 3; - }); - }, 1000); + progressInterval = setInterval(async () => { + try { + const r = await api.get('/api/generation-progress'); + const d = r.data || {}; + // Don't drop to 0 just because backend briefly reports + // idle between batch elements; clamp monotonic until + // we hand off to setGenerationProgress(100) on response. 
+ const pct = Number(d.progress) || 0; + setGenerationProgress(prev => Math.max(prev, Math.min(95, pct))); + } catch { + /* poll failure is non-fatal — bar just freezes briefly */ + } + }, 250); }; const stopProgressTicker = () => { if (progressInterval) { @@ -780,9 +1048,15 @@ function App() { setGenerationProgress(100); const audioUrl = URL.createObjectURL(response.data); - const fragmentFilename = buildFragmentFilename( - generationPrompt, batchTimestamp, batchIndex, totalRuns - ); + // The backend is authoritative for the on-disk name (it writes + // the WAV + sidecar). Use the header it returns so reveal / + // delete / serve all hit the real file; only fall back to a + // locally-built name if the header is somehow missing. + const fragmentFilename = + response.headers?.['x-fragment-filename'] || + buildFragmentFilename( + generationPrompt, batchTimestamp, batchIndex, totalRuns + ); setGeneratedAudio(audioUrl); setGeneratedAudioBlob(response.data); @@ -795,15 +1069,20 @@ function App() { cfgScale, steps, seed: seedForRun, + modelId: selectedModel, batchIndex, batchTotal: totalRuns, audioUrl, audioBlob: response.data, filename: fragmentFilename, - timestamp: new Date().toLocaleString() + timestamp: new Date().toLocaleString(), + createdAt: Date.now(), }; - setGeneratedFragments(prev => [...prev, newFragment]); + setGeneratedFragments(prev => { + const next = [...prev, newFragment]; + return next.length > 100 ? 
next.slice(next.length - 100) : next; + }); completedRuns += 1; } @@ -852,40 +1131,6 @@ function App() { setProcessingStatus('Stopping generation…'); }; - const handleStartFresh = async () => { - setIsStartingFresh(true); - setShowStartFreshDialog(false); - - try { - const response = await api.post('/api/start-fresh'); - - setUploadRows([{ file: null, prompt: '', audioUrl: '' }]); - setProcessedCount(0); - setChunksPreview([]); - setGeneratedAudio(null); - setGeneratedAudioBlob(null); - setGeneratedFragments([]); - setProcessingStatus(''); - setGenerationPrompt(''); - setUploadKey(prev => prev + 1); - - // Wipe persisted performance session and force-remount the panel so - // its in-memory state resets to defaults along with localStorage. - // (MIDI mappings and other app preferences are intentionally kept.) - clearPerformanceSession(); - setPerformanceResetKey(prev => prev + 1); - - setProcessingStatus(response.data.message); - - fetchSystemStatus(); - - } catch (error) { - setProcessingStatus(`Start fresh error: ${error.response?.data?.error || error.message}`); - } finally { - setIsStartingFresh(false); - } - }; - const handleFreeGPUMemory = async () => { setIsFreeingGPU(true); setShowFreeGPUDialog(false); @@ -946,32 +1191,9 @@ function App() { }; const getSelectedModelDisplayName = () => { - console.log('=== GETTING DISPLAY NAME ==='); - console.log('selectedModel:', selectedModel); - console.log('selectedUnwrappedModel:', selectedUnwrappedModel); - - if (!selectedModel) { - console.log('No selectedModel, returning empty string'); - return ''; - } - + if (!selectedModel) return ''; const baseModel = baseModels.find(m => m.name === selectedModel); - if (baseModel) { - console.log('Found base model:', baseModel.displayName); - return baseModel.displayName; - } - - const model = availableModels.find(m => m.name === selectedModel); - if (model && selectedUnwrappedModel) { - const selectedUnwrapped = model.unwrapped_models?.find(u => u.path === 
selectedUnwrappedModel); - if (selectedUnwrapped) { - const displayName = `${model.name} (${selectedUnwrapped.name})`; - console.log('Generated fine-tuned display name:', displayName); - return displayName; - } - } - - console.log('Using fallback name:', selectedModel); + if (baseModel) return baseModel.displayName; return selectedModel; }; @@ -984,8 +1206,6 @@ function App() { const newSelectedModel = event.target.value; setSelectedModel(newSelectedModel); - setSelectedUnwrappedModel(''); - const selectedBaseModel = baseModels.find(m => m.name === newSelectedModel); if (selectedBaseModel && !selectedBaseModel.downloaded) { showModelWarning({ @@ -1011,7 +1231,7 @@ function App() { const handleOpenModelsFromWarning = () => { closeModelWarning(); - setAuthDialogOpen(true); + setCheckpointMgrOpen(true); }; const getTrainingIndicatorState = () => { @@ -1032,6 +1252,7 @@ function App() { return ( + - + {/* Logo */} @@ -1067,127 +1288,46 @@ function App() { - - - - - - - - + {gpuMemoryStatus && gpuMemoryStatus.cuda ? ( <> - - - GPU Memory + + + GPU - - 2 ? 'good' : gpuMemoryStatus.cuda.free > 0.5 ? 'low' : 'critical' - )} - /> - - {gpuMemoryStatus.cuda.free > 2 ? 'Good' : - gpuMemoryStatus.cuda.free > 0.5 ? 'Low' : 'Critical'} - - - - - - - - - - - - - - {gpuMemoryStatus.cuda.free.toFixed(1)}GB free - - - {gpuMemoryStatus.cuda.total.toFixed(1)}GB total + + {gpuMemoryStatus.cuda.free.toFixed(1)} / {gpuMemoryStatus.cuda.total.toFixed(0)} GB free + + + ) : ( - <> - - - GPU Status - - - - - No GPU - - - - - - No CUDA GPU detected + + + GPU - - Using CPU for processing + + Not detected · CPU mode - + )} - + {/* Main Content with Sidebar Layout */} - + {/* Left Sidebar with Vertical Tabs */} - + - } iconPosition={isIconOnlySidebar ? 'top' : 'start'} label={isIconOnlySidebar ? undefined : 'Data Processing'} /> - } iconPosition={isIconOnlySidebar ? 'top' : 'start'} label={isIconOnlySidebar ? undefined : 'Training'} /> - } iconPosition={isIconOnlySidebar ? 
'top' : 'start'} label={isIconOnlySidebar ? undefined : 'Generation'} /> + } iconPosition={isIconOnlySidebar ? 'top' : 'start'} label={(isIconOnlySidebar || isMobileLayout) ? undefined : 'Dataset'} /> + } iconPosition={isIconOnlySidebar ? 'top' : 'start'} label={(isIconOnlySidebar || isMobileLayout) ? undefined : 'Training'} /> + } iconPosition={isIconOnlySidebar ? 'top' : 'start'} label={(isIconOnlySidebar || isMobileLayout) ? undefined : 'Generation'} /> } iconPosition={isIconOnlySidebar ? 'top' : 'start'} - label={isIconOnlySidebar ? undefined : ( - - Performance - {}} - onClick={(e) => { e.stopPropagation(); togglePerformance(); }} - sx={{ transform: 'scale(0.75)' }} - /> - - )} - sx={{ opacity: performanceEnabled ? 1 : 0.5, transition: 'opacity 0.2s' }} + label={(isIconOnlySidebar || isMobileLayout) ? undefined : 'Performance'} /> {/* Main Content Area */} - - - {/* Data Processing Tab */} - - - - - - - - Manual Annotation - - - Upload audio files one by one and annotate them yourself. - Use this when you want full control over every annotation. - - - {uploadRows.map((row, index) => ( - - ))} - - - - - - - - - - - - - - - {processingStatus && ( - - {processingStatus} - - )} - - {!systemStatus && ( - - - - - - Dataset Status - - - - - Scanning dataset… - - - - )} - - {systemStatus && ( - - - - - - Dataset Status - {isStatusLoading && ( - - )} - - Raw Files: {systemStatus.raw_files} - - Total Duration: {formatDuration(systemStatus.total_duration || 0)} - - - Custom Metadata: {systemStatus.has_metadata_json ? 
'Yes' : 'Not Found'} - - {systemStatus.raw_file_names && systemStatus.raw_file_names.length > 0 && ( - - - Recent files: {systemStatus.raw_file_names.join(', ')} - - - )} - - )} + + - - + {/* Dataset Tab */} + + setCheckpointMgrOpen(true)} /> {/* Training Tab */} - + @@ -1343,34 +1377,6 @@ function App() { Training Configuration - - - Training mode - - { - if (newMode !== null) { - setTrainingConfig({ ...trainingConfig, mode: newMode }); - } - }} - fullWidth - > - - - LoRA Adapter - - - - - Full Fine-tune - - - - - - - }> - Advanced Settings - - - - - Epochs - - + + Dataset + + {trainingProjects.length === 0 ? ( + + No projects yet — create one in the Dataset tab. + + ) : ( + + )} + + {/* Phase 6 — pre-encode latents button. State machine: + no latents → "Pre-encode latents · N clips" (clickable, outlined) + running → "Encoding… X / Y" (disabled, with Stop button) + present → "✓ Pre-encoded · N latents" (disabled, outlined, green tint). */} + {trainingProject && (() => { + const job = trainingPreEncode.job; + const inFlight = job && (job.state === 'queued' || job.state === 'running'); + const ready = trainingPreEncode.latents_present && !inFlight; + const project = trainingProjects.find(p => p.name === trainingProject); + const clipCount = project?.clip_count ?? 0; + let label = `Pre-encode latents · ${clipCount} clip${clipCount === 1 ? '' : 's'}`; + if (inFlight) { + label = job.total > 0 + ? `Encoding… ${job.current} / ${job.total}` + : 'Encoding…'; + } else if (ready) { + label = `Pre-encoded · ${trainingPreEncode.latents_count} latent${trainingPreEncode.latents_count === 1 ? 
'' : 's'}`; + } + return ( + + + {inFlight && ( + + )} + + ); + })()} + + + + + Base model to fine-tune + + + + + + }> + Advanced Settings + + + + + + + Training Steps + + setTrainingConfig({ ...trainingConfig, - epochs: value + steps: value })} - min={1} - max={1000} + min={500} + max={20000} + step={500} + marks={[ + { value: 1000, label: '1k' }, + { value: 5000, label: '5k' }, + { value: 10000, label: '10k' }, + { value: 20000, label: '20k' }, + ]} valueLabelDisplay="auto" sx={appStyles.sliderFlexGrow} /> { - const val = parseInt(e.target.value) || 1; + const val = parseInt(e.target.value) || 500; setTrainingConfig({ ...trainingConfig, - epochs: Math.max(1, Math.min(1000, val)) + steps: Math.max(500, Math.min(20000, val)) }); }} - inputProps={{ min: 1, max: 1000, step: 1 }} + inputProps={{ min: 500, max: 20000, step: 100 }} sx={appStyles.sliderInputSmall} size="small" /> + + + + + Adapter Type + + + + + + + + Checkpoint Interval (steps) )} + + - Learning Rate + + + Learning Rate + + - Batch Size + + + Batch Size @@ -1542,21 +1765,22 @@ function App() { const val = parseInt(e.target.value, 10) || 1; setTrainingConfig({ ...trainingConfig, - batchSize: Math.max(1, Math.min(32, val)) + batchSize: Math.max(1, Math.min(8, val)) }); }} - inputProps={{ min: 1, max: 32, step: 1 }} + inputProps={{ min: 1, max: 8, step: 1 }} sx={appStyles.sliderInputSmall} size="small" /> - - Lower this if you hit CUDA out-of-memory; raise it for faster training on large GPUs. - + + - Precision + + + Base-model Precision - - Auto picks bf16-mixed on modern CUDA, 16-mixed on older cards, fp32 on CPU/MPS. - + + - {trainingConfig.mode === 'lora' && ( - - - LoRA settings - + + + LoRA settings + - Rank + + + Rank - - Higher rank = more capacity but more VRAM. r=16 fits comfortably on 16 GB. - - - Alpha + + + + + Alpha - - Scaling factor for the LoRA update. Conventional choice: alpha = rank. - - - Dropout + + + + + Dropout - - Regularization for the LoRA layers. 
0 is fine for most cases; raise if overfitting on small datasets. - - - )} + + + + + + + Seed + setTrainingConfig({ + ...trainingConfig, + seedRandom: e.target.checked, + })} + /> + } + label="Random" + labelPlacement="start" + /> + + { + const v = parseInt(e.target.value, 10); + setTrainingConfig({ + ...trainingConfig, + seed: Number.isFinite(v) ? v : 42, + }); + }} + inputProps={{ min: 0, step: 1 }} + /> + + + + + + Training Window (seconds) + + setTrainingConfig({ + ...trainingConfig, + duration: value, + })} + min={5} + max={(trainingConfig.baseModel || '').includes('medium') ? 380 : 120} + step={1} + marks={[{ value: 30, label: '30s' }]} + valueLabelDisplay="auto" + sx={appStyles.sliderFlexGrow} + /> + { + const cap = (trainingConfig.baseModel || '').includes('medium') ? 380 : 120; + const v = Math.max(5, Math.min(cap, parseFloat(e.target.value) || 30)); + setTrainingConfig({ ...trainingConfig, duration: v }); + }} + inputProps={{ min: 5, max: (trainingConfig.baseModel || '').includes('medium') ? 380 : 120, step: 1 }} + sx={appStyles.sliderInputSmall} + size="small" + /> + + + + {/* include/exclude layer targeting is intentionally not + exposed — the default (transformer.layers / exclude + seconds_total to_local_embed) is SA3's documented + small-dataset-safe filter; a wrong value silently + degrades training. Still sent from trainingConfig. */} + @@ -1686,8 +1983,8 @@ function App() { - - - - {/* Free GPU Memory Confirmation Dialog */} - - {colorMode === 'light' ? : } - - - setShowInfoDialog(true)} - sx={appStyles.infoButton} + { if (reason !== 'clickaway') setProcessingStatus(''); }} + anchorOrigin={{ vertical: 'bottom', horizontal: 'right' }} > - - - - setShowInfoDialog(false)} - aria-labelledby="about-documentation-dialog-title" - maxWidth="sm" - fullWidth - > - - - setProcessingStatus('')} + severity={ + /error|failed/i.test(processingStatus) ? 'error' + : /completed|success/i.test(processingStatus) ? 
'success' + : 'info' + } + variant="filled" + sx={{ minWidth: 280, boxShadow: 6 }} + > + {processingStatus} + + + + {isDockCollapsed ? ( + <> + setDockMenuAnchor(e.currentTarget)} + sx={appStyles.dockHamburger} + > + + + setDockMenuAnchor(null)} + anchorOrigin={{ vertical: 'top', horizontal: 'left' }} + transformOrigin={{ vertical: 'bottom', horizontal: 'left' }} + > + { setDockMenuAnchor(null); setCheckpointMgrOpen(true); }} + > + + Get Models + + { setDockMenuAnchor(null); handleOpenOutputFolder(); }} + > + + Outputs + + { setDockMenuAnchor(null); setShowFreeGPUDialog(true); }} + disabled={isFreeingGPU || !(gpuMemoryStatus && gpuMemoryStatus.cuda)} + > + + {isFreeingGPU ? : } + + {isFreeingGPU ? 'Freeing…' : 'Free GPU'} + + + { setDockMenuAnchor(null); toggleColorMode(); }} + > + + {colorMode === 'light' ? : } + + {colorMode === 'light' ? 'Dark Mode' : 'Light Mode'} + + { setDockMenuAnchor(null); setShowInfoDialog(true); }} + > + + About + + + + ) : ( + ({ + position: 'fixed', + left: { xs: theme.spacing(1.5), sm: theme.spacing(2), md: theme.spacing(3) }, + bottom: { xs: theme.spacing(7), sm: theme.spacing(9), md: theme.spacing(12) }, + zIndex: 1350, + display: 'flex', + flexDirection: 'column', + alignItems: 'center', + gap: 0.75, + })} + > + {/* Small Info View toggle, sitting above the dock card. */} + ({ + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + gap: 0.4, + px: 0.5, + py: 0.25, + m: 0, border: 'none', - boxShadow: 'none', - filter: 'none', - }} /> - - Fragmenta - + background: 'transparent', + cursor: 'pointer', + borderRadius: 999, + fontFamily: 'inherit', + fontSize: '0.6rem', + lineHeight: 1, + letterSpacing: '0.02em', + color: infoViewEnabled ? theme.palette.primary.main : theme.palette.text.disabled, + opacity: infoViewEnabled ? 0.9 : 0.55, + transition: 'color 160ms ease, opacity 160ms ease', + '&:hover': { + opacity: 1, + color: infoViewEnabled ? 
theme.palette.primary.light : theme.palette.text.secondary, + }, + })} + > + + Info - - - - Fragmenta is an open source, local-first pipeline to fine-tune, LoRA, train, generate and perform with text-to-audio diffusion models. - Made by the composer and researcher Misagh Azimi. - - - Resources - - - - - - + {isFreeingGPU ? : } + + + {isFreeingGPU ? 'Freeing…' : 'Free GPU'} + - - - Powered by Stability AI —{' '} - - Stable Audio Open - {' '} - models, governed by the{' '} - - Stability AI Community License - . + + + {colorMode === 'light' ? : } + + + {colorMode === 'light' ? 'Dark Mode' : 'Light Mode'} - - "This Stability AI Model is licensed under the Stability AI Community License,{' '} - Copyright © Stability AI Ltd. All Rights Reserved" + + + + setShowInfoDialog(true)} + sx={appStyles.dockIconButton} + > + + + + About - - - - - + + + )} + + setShowInfoDialog(false)} + onOpenDocumentation={handleOpenDocumentation} + isOpeningDocumentation={isOpeningDocumentation} + /> )} {!suggestionDialog.loading && suggestionDialog.data?.ok && (() => { - const { stats, config, rationale } = suggestionDialog.data; + const { stats, config, rationale, warnings } = suggestionDialog.data; + const includeStr = (config.include || []).join(', ') || '(all layers)'; + const excludeStr = (config.exclude || []).join(', ') || '(none)'; return ( {stats.file_count} files · {stats.duration_human} - {stats.vram_gb ? ` · GPU ${stats.vram_gb} GB` : ''} + {stats.median_clip_sec ? ` · median ${stats.median_clip_sec.toFixed(1)}s` : ''} + {stats.vram_gb ? 
` · GPU ${stats.vram_gb} GB` : ' · no GPU'} + {(warnings || []).map((w, i) => ( + + {w} + + ))} + + Steps + {config.steps.toLocaleString()} Batch size {config.batchSize} Learning rate {config.learningRate} - Epochs - {config.epochs} - {trainingConfig.mode === 'lora' && ( - <> - LoRA rank / alpha - - {config.loraRank} / {config.loraAlpha} - - - )} - Total steps - {stats.total_steps} + Training window + {config.duration.toFixed(0)}s + Adapter · rank / α + + {config.adapterType} · {config.loraRank} / {config.loraAlpha} + + Dropout · precision + + {config.loraDropout} · {config.precision} + + Include layers + {includeStr} + Exclude layers + {excludeStr} + Checkpoint every + {config.checkpointSteps.toLocaleString()} steps - { - setAuthDialogOpen(false); - if (success) { - refreshAllModels(); - } + setOverwriteConfirm(null)} + maxWidth="xs" + fullWidth + > + Overwrite existing run? + + + {overwriteConfirm?.message} + + + The previous run dir for {overwriteConfirm?.runName} will + be deleted, including {overwriteConfirm?.checkpointCount} checkpoint(s), + training.log, metrics.csv and any Lightning logs. This cannot be undone. 
+ + + + + + + + + { + setCheckpointMgrOpen(false); + refreshAllModels(); }} /> + ); } diff --git a/app/frontend/src/api.js b/app/frontend/src/api.js index d338d3b0cd92d7e04d827e409031986006c6a80b..bc994655b65739585e2514078adc699f09be4099 100644 --- a/app/frontend/src/api.js +++ b/app/frontend/src/api.js @@ -37,6 +37,7 @@ const api = { get: (url, config) => request('GET', url, null, config), post: (url, body, config) => request('POST', url, body, config), put: (url, body, config) => request('PUT', url, body, config), + patch: (url, body, config) => request('PATCH', url, body, config), delete: (url, config) => request('DELETE', url, null, config), }; diff --git a/app/frontend/src/components/AboutDialog.js b/app/frontend/src/components/AboutDialog.js new file mode 100644 index 0000000000000000000000000000000000000000..cfb3ddb5227eb92ab6ed1038f845931dc96b1e8a --- /dev/null +++ b/app/frontend/src/components/AboutDialog.js @@ -0,0 +1,130 @@ +import React from 'react'; +import { + Box, + Button, + Dialog, + DialogActions, + DialogContent, + DialogTitle, + Typography, +} from '@mui/material'; +import { + Info as InfoIcon, + BookOpen as BookOpenIcon, +} from 'lucide-react'; +import { appStyles } from '../theme'; +import { APP_VERSION } from '../version'; + +/** + * "About Fragmenta" dialog — logo + title, short intro, three doc buttons + * (About / Documentation / Tutorials), and the Stability AI Community + * License attribution footer. + * + * Props: + * open: bool + * onClose: () => void + * onOpenDocumentation: ('about' | 'documentation') => void + * isOpeningDocumentation: bool — disables the doc buttons while a + * native open-file call is in flight + */ +export default function AboutDialog({ + open, + onClose, + onOpenDocumentation, + isOpeningDocumentation, +}) { + return ( + + + + + + Fragmenta + + + v{APP_VERSION} + + + + + + Fragmenta is an open source, local-first suit to prepare datasets, train, generate and perform with text-to-audio diffusion models. 
+ Made by the composer and researcher Misagh Azimi. + + + + + + + + + + + Powered by{' '} + + Stable Audio 3 + {' '}by Stability AI. "This Stability AI Model is licensed under the{' '} + + Stability AI Community License + ,{' '} + Copyright © Stability AI Ltd. All Rights Reserved" + + + + + + + + ); +} diff --git a/app/frontend/src/components/AudioWaveform.js b/app/frontend/src/components/AudioWaveform.js new file mode 100644 index 0000000000000000000000000000000000000000..254e13d101cc8cb32f4ffbc5dce6cfe569515ef2 --- /dev/null +++ b/app/frontend/src/components/AudioWaveform.js @@ -0,0 +1,258 @@ +import React, { useEffect, useRef, useState, useCallback } from 'react'; +import { Box, Typography } from '@mui/material'; + +/** + * Canvas waveform with a single draggable region (for SA3 inpaint UX). + * + * Decodes the supplied File via the Web Audio API (no network round-trip), + * computes per-pixel min/max peaks once per (file, width) pair, and renders + * a region overlay + two draggable handles. Region drag in three modes: + * - drag the left handle → adjust start + * - drag the right handle → adjust end + * - drag the body → shift the whole region in place + * + * Region is controlled: parent owns `start` / `end` in seconds. 
+ * + * Props: + * file: File | null — source audio + * duration: number — clip length in seconds (must be passed; we + * don't infer it from decoded length so the + * caller can drive a probe before decode + * finishes) + * start, end: number — region in seconds + * onRegionChange: (start, end) => void + * minRegionSec: number — default 0.1 + * height: number — canvas height in px (default 96) + * color: CSS color — waveform peak color (default theme accent) + * regionColor: CSS color — fill for the region rect + */ +export default function AudioWaveform({ + file, + duration, + start, + end, + onRegionChange, + minRegionSec = 0.1, + height = 96, + color = '#279FBB', + regionColor = 'rgba(253, 162, 43, 0.28)', +}) { + const canvasRef = useRef(null); + const containerRef = useRef(null); + const [width, setWidth] = useState(0); + const [peaks, setPeaks] = useState(null); + const [decoding, setDecoding] = useState(false); + const [decodeError, setDecodeError] = useState(null); + // Drag state lives in a ref to avoid re-renders during pointer move. + const dragRef = useRef(null); + + // --- responsive width via ResizeObserver ----------------------------- + useEffect(() => { + const el = containerRef.current; + if (!el) return; + const ro = new ResizeObserver((entries) => { + const w = Math.max(1, Math.floor(entries[0].contentRect.width)); + setWidth(w); + }); + ro.observe(el); + return () => ro.disconnect(); + }, []); + + // --- decode + peak computation --------------------------------------- + useEffect(() => { + if (!file || !width) return; + let cancelled = false; + setDecoding(true); + setDecodeError(null); + + (async () => { + try { + const buf = await file.arrayBuffer(); + if (cancelled) return; + // Decode with a throwaway context created for this decode only. Safari and Chrome both + // permit creating an offline one for pure decode without user + // gesture, which is what we want. 
+ const Ctx = window.OfflineAudioContext || window.webkitOfflineAudioContext; + const tmpCtx = Ctx + ? new Ctx(1, 44100, 44100) + : new (window.AudioContext || window.webkitAudioContext)(); + const audio = await tmpCtx.decodeAudioData(buf.slice(0)); + if (cancelled) return; + + // Average the first two channels into mono peaks, then bucket into + // `width` columns. Each column gets (min, max) in [-1, 1]. + const ch0 = audio.getChannelData(0); + const ch1 = audio.numberOfChannels > 1 ? audio.getChannelData(1) : null; + const totalSamples = ch0.length; + const bucketSize = Math.max(1, Math.floor(totalSamples / width)); + const out = new Float32Array(width * 2); + for (let i = 0; i < width; i++) { + const s = i * bucketSize; + const e = Math.min(totalSamples, s + bucketSize); + let mn = 0, mx = 0; + for (let j = s; j < e; j++) { + const v = ch1 ? (ch0[j] + ch1[j]) * 0.5 : ch0[j]; + if (v < mn) mn = v; + if (v > mx) mx = v; + } + out[i * 2] = mn; + out[i * 2 + 1] = mx; + } + setPeaks(out); + } catch (err) { + setDecodeError(err.message || 'Failed to decode audio'); + } finally { + if (!cancelled) setDecoding(false); + } + })(); + + return () => { cancelled = true; }; + }, [file, width]); + + // --- canvas drawing -------------------------------------------------- + const draw = useCallback(() => { + const canvas = canvasRef.current; + if (!canvas || !width || !height) return; + const dpr = window.devicePixelRatio || 1; + canvas.width = width * dpr; + canvas.height = height * dpr; + const ctx = canvas.getContext('2d'); + ctx.setTransform(dpr, 0, 0, dpr, 0, 0); + ctx.clearRect(0, 0, width, height); + + // Background: faint center line so empty audio still shows scale. 
+ ctx.fillStyle = 'rgba(255, 255, 255, 0.05)'; + ctx.fillRect(0, height / 2 - 0.5, width, 1); + + // Peaks + if (peaks) { + ctx.fillStyle = color; + const mid = height / 2; + const scale = (height - 4) / 2; + for (let i = 0; i < width; i++) { + const mn = peaks[i * 2]; + const mx = peaks[i * 2 + 1]; + const y0 = mid - mx * scale; + const y1 = mid - mn * scale; + ctx.fillRect(i, y0, 1, Math.max(1, y1 - y0)); + } + } + + // Region overlay + if (duration > 0 && Number.isFinite(start) && Number.isFinite(end)) { + const sPx = Math.max(0, Math.min(width, (start / duration) * width)); + const ePx = Math.max(0, Math.min(width, (end / duration) * width)); + const rectW = Math.max(1, ePx - sPx); + ctx.fillStyle = regionColor; + ctx.fillRect(sPx, 0, rectW, height); + // Handles + ctx.fillStyle = '#FDA22B'; + ctx.fillRect(sPx - 1, 0, 2, height); + ctx.fillRect(ePx - 1, 0, 2, height); + } + }, [width, height, peaks, color, regionColor, start, end, duration]); + + useEffect(() => { draw(); }, [draw]); + + // --- pointer interaction -------------------------------------------- + const HIT_PX = 8; + const pxToSec = useCallback((px) => { + return Math.max(0, Math.min(duration, (px / width) * duration)); + }, [width, duration]); + + const onPointerDown = (e) => { + if (!duration || !width) return; + const rect = canvasRef.current.getBoundingClientRect(); + const px = e.clientX - rect.left; + const sPx = (start / duration) * width; + const ePx = (end / duration) * width; + let mode; + if (Math.abs(px - sPx) <= HIT_PX) mode = 'start'; + else if (Math.abs(px - ePx) <= HIT_PX) mode = 'end'; + else if (px > sPx && px < ePx) mode = 'body'; + else mode = 'new'; // start a new region by drag + dragRef.current = { + mode, + startPx: px, + origStart: start, + origEnd: end, + }; + canvasRef.current.setPointerCapture(e.pointerId); + if (mode === 'new') { + const t = pxToSec(px); + onRegionChange?.(t, Math.min(duration, t + minRegionSec)); + dragRef.current.mode = 'end'; + 
dragRef.current.origStart = t; + dragRef.current.origEnd = t + minRegionSec; + } + }; + + const onPointerMove = (e) => { + const d = dragRef.current; + if (!d) return; + const rect = canvasRef.current.getBoundingClientRect(); + const px = e.clientX - rect.left; + const delta = pxToSec(px) - pxToSec(d.startPx); + let s = d.origStart; + let en = d.origEnd; + if (d.mode === 'start') { + s = Math.max(0, Math.min(d.origEnd - minRegionSec, d.origStart + delta)); + } else if (d.mode === 'end') { + en = Math.max(d.origStart + minRegionSec, Math.min(duration, d.origEnd + delta)); + } else if (d.mode === 'body') { + const span = d.origEnd - d.origStart; + s = Math.max(0, Math.min(duration - span, d.origStart + delta)); + en = s + span; + } + onRegionChange?.(s, en); + }; + + const onPointerUp = (e) => { + if (dragRef.current) { + canvasRef.current.releasePointerCapture(e.pointerId); + dragRef.current = null; + } + }; + + // --- render ---------------------------------------------------------- + return ( + + + {(decoding || decodeError || !file) && ( + + + {decodeError + ? `decode failed: ${decodeError}` + : !file + ? 
'no source loaded' + : 'decoding…'} + + + )} + + ); +} diff --git a/app/frontend/src/components/ChannelFragmentHistory.js b/app/frontend/src/components/ChannelFragmentHistory.js new file mode 100644 index 0000000000000000000000000000000000000000..c258a094c7ef1ff0f712d3638a565bd864b6689a --- /dev/null +++ b/app/frontend/src/components/ChannelFragmentHistory.js @@ -0,0 +1,217 @@ +import React, { useState } from 'react'; +import { + Box, + IconButton, + Dialog, + DialogTitle, + DialogContent, + DialogContentText, + DialogActions, + Button, +} from '@mui/material'; +import { TIPS } from '../tooltips'; +import Tooltip from './Tooltip'; +import { + Play as PlayIcon, + Square as StopIcon, + Star as StarIcon, + Trash2 as DeleteIcon, + Check as CommitIcon, + Eraser as ClearAllIcon, +} from 'lucide-react'; +import { performanceChannelStyles as styles } from '../theme'; +import { MidiMappable } from './MidiContext'; + +/** + * Per-channel rolling fragment history. Always visible (empty-state included) + * so the user knows the strip exists. Chronological order — oldest at + * the top, newest at the bottom; scrolls vertically when the list grows + * past ~4 visible rows. + * + * Each row exposes four actions, all visible by default (no hover-reveal — + * Performance use is fast, can't afford the discoverability tax): + * • Cue ▶/■ — audition through the cue output (separate from main mix) + * • Star ★/☆ — mark as a keeper. Starred fragments survive the cap + * eviction; unstarred get dropped FIFO when over cap. + * • Delete ⌫ — remove this fragment from history (cancellable confirm not + * shown for single deletes — the entry can be regenerated + * or audition can be retriggered after a quick re-tap). + * • Load ✓ — commit this fragment to the channel strip (becomes the + * audio the channel plays). Disabled while already loaded. 
+ * + * Props: + * fragments: [{ id, audioUrl, blob, prompt, duration, createdAt, + * starred, number }] + * color: channel accent color + * auditioningId: the id currently playing through cue, or null + * committedId: the id currently loaded into the channel strip, or null + * maxFragments: cap, default 50 (informational; eviction lives in parent) + * on{Audition,Commit,ToggleStar,Delete}: (fragmentId) => void + * onClearAll: () => void (parent confirms separately — we still show + * a confirm dialog here for the trash-everything action) + */ +export default function ChannelFragmentHistory({ + fragments, + color, + channelIndex, + auditioningId, + committedId, + maxFragments = 50, + onAudition, + onCommit, + onToggleStar, + onDelete, + onClearAll, +}) { + const [clearConfirmOpen, setClearConfirmOpen] = useState(false); + // Channel-scoped MIME type for drag-and-drop. The waveform drop target on + // this same channel listens for this exact type — cross-channel drags + // won't highlight or accept because the mime won't match. + const dragMime = `application/x-fragmenta-fragment-ch${channelIndex}`; + + return ( + + + + Fragments + + {fragments.length > 0 && ( + setClearConfirmOpen(true)} + sx={styles.fragmentHistoryHeaderBtn} + aria-label="Clear all fragments" + > + + + )} + + + {fragments.length === 0 ? ( + Empty + ) : ( + + {fragments.map((fragment) => { + const isAuditioning = auditioningId === fragment.id; + const isCommitted = committedId === fragment.id; + return ( + { + e.dataTransfer.setData(dragMime, fragment.id); + e.dataTransfer.effectAllowed = 'copy'; + }} + sx={{ + ...styles.fragmentRow(color, isCommitted, isAuditioning), + cursor: 'grab', + '&:active': { cursor: 'grabbing' }, + }} + > + onAudition(fragment.id)} + > + + onAudition(fragment.id)} + sx={styles.fragmentIconBtn(color, isAuditioning, true)} + aria-label={isAuditioning ? 'Stop cue' : 'Audition'} + > + {isAuditioning + ? 
+ : } + + + + + + + F{fragment.number} + + + + + onToggleStar(fragment.id)} + sx={styles.fragmentIconBtn(color, fragment.starred)} + aria-label={fragment.starred ? 'Unstar fragment' : 'Star fragment'} + > + + + + + onDelete(fragment.id)} + sx={styles.fragmentDeleteBtn} + aria-label="Delete fragment" + > + + + + + + onCommit(fragment.id)} + disabled={isCommitted} + sx={styles.fragmentIconBtn(color, isCommitted, true)} + aria-label="Load fragment into channel" + > + + + + + + ); + })} + + )} + + setClearConfirmOpen(false)}> + Clear fragment history? + + + Removes all {fragments.length} fragments from this channel's history, + including starred ones. The currently loaded clip stays loaded + — only the history entries are dropped. + + + + + + + + + ); +} diff --git a/app/frontend/src/components/CheckpointManagerWindow.js b/app/frontend/src/components/CheckpointManagerWindow.js new file mode 100644 index 0000000000000000000000000000000000000000..b8ddf603cf4a7130dc730e769ce86350a8a5373b --- /dev/null +++ b/app/frontend/src/components/CheckpointManagerWindow.js @@ -0,0 +1,243 @@ +import React, { useCallback, useEffect, useState } from 'react'; +import { + Dialog, + DialogTitle, + DialogContent, + DialogActions, + Box, + Typography, + Button, + IconButton, + Stack, + Alert, + TextField, + LinearProgress, +} from '@mui/material'; +import { + X as CloseIcon, + HardDrive as StorageIcon, + LogIn as LoginIcon, + LogOut as LogoutIcon, +} from 'lucide-react'; +import api from '../api'; +import CheckpointRow from './CheckpointRow'; +import StorageDrilldown from './StorageDrilldown'; + +const fmtBytes = (n) => { + if (!n && n !== 0) return '—'; + const units = ['B', 'KB', 'MB', 'GB', 'TB']; + let v = n; + let u = 0; + while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; } + return `${v.toFixed(v < 10 ? 
2 : 1)} ${units[u]}`; +}; + +export default function CheckpointManagerWindow({ open, onClose }) { + const [catalog, setCatalog] = useState([]); + const [storage, setStorage] = useState(null); + const [env, setEnv] = useState(null); + const [hfAuth, setHfAuth] = useState({ signed_in: false, username: null }); + const [tokenDraft, setTokenDraft] = useState(''); + const [showTokenInput, setShowTokenInput] = useState(false); + const [authError, setAuthError] = useState(null); + const [showStorage, setShowStorage] = useState(false); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + + const refresh = useCallback(async () => { + setLoading(true); + setError(null); + try { + const [cat, store, auth, environment] = await Promise.all([ + api.get('/api/checkpoints'), + api.get('/api/checkpoints/storage'), + api.get('/api/hf-auth/status'), + api.get('/api/environment'), + ]); + setCatalog(cat.data.checkpoints); + setStorage(store.data); + setHfAuth(auth.data); + setEnv(environment.data); + } catch (e) { + setError(e.response?.data?.error || e.message); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + if (open) refresh(); + }, [open, refresh]); + + const submitToken = async () => { + setAuthError(null); + try { + await api.post('/api/hf-auth', { token: tokenDraft.trim() }); + setTokenDraft(''); + setShowTokenInput(false); + refresh(); + } catch (e) { + setAuthError(e.response?.data?.error || e.message); + } + }; + + const logout = async () => { + try { + await api.delete('/api/hf-auth'); + refresh(); + } catch (e) { + setAuthError(e.response?.data?.error || e.message); + } + }; + + const anyInstalled = catalog.some(c => c.downloaded); + + return ( + <> + + + Checkpoint Manager + + + + + + + + + + + {hfAuth.signed_in ? ( + + + HuggingFace: signed in as {hfAuth.username} + + + + ) : showTokenInput ? 
( + + setTokenDraft(e.target.value)} + type="password" + sx={{ width: 240 }} + /> + + + + ) : ( + + )} + + {authError && {authError}} + + + {!hfAuth.signed_in ? ( + + SA3 checkpoints are gated on HuggingFace. You need a{' '} + + HuggingFace account + + {' '}to continue. Then{' '} + + create a Read access token + + {' '}and sign in above. + + ) : ( + + You're signed in. Each model is gated — click its name below to open the + HuggingFace page and accept the model's terms before downloading. + + )} + + {error && {error}} + {loading && } + + {!loading && !anyInstalled && catalog.length > 0 && ( + + + Pick a model to get started. + + + Small - Music (1.2 GB) is a good first choice on a laptop or any GPU. + + + )} + + {[ + { kind: 'post-trained', label: 'Distilled (fast)', hint: '8 steps, cfg locked at 1.0. Prompt, duration and seed only.' }, + { kind: 'base', label: 'Base (full control)', hint: 'CFG-aware. ~50 steps, cfg ~7. Cfg-scale and steps are live controls.' }, + { kind: 'tagger', label: 'Auto-annotation tools', hint: 'Optional helpers for dataset prep. CLAP scores audio against your vocabulary.' 
}, + ].map(group => { + const rows = catalog.filter(c => c.kind === group.kind); + if (!rows.length) return null; + return ( + + {group.label} + + {group.hint} + + + {rows.map(c => ( + setShowTokenInput(true)} + onChanged={refresh} + /> + ))} + + + ); + })} + + + + + + + + + setShowStorage(false)} + storage={storage} + catalog={catalog} + /> + + ); +} diff --git a/app/frontend/src/components/CheckpointRow.js b/app/frontend/src/components/CheckpointRow.js new file mode 100644 index 0000000000000000000000000000000000000000..406c1a27fc2c4a9a0dde92d8c3f56da52b2aa111 --- /dev/null +++ b/app/frontend/src/components/CheckpointRow.js @@ -0,0 +1,270 @@ +import React, { useEffect, useRef, useState } from 'react'; +import { + Box, + Typography, + Button, + Chip, + LinearProgress, + Stack, + IconButton, +} from '@mui/material'; +import { TIPS } from '../tooltips'; +import Tooltip from './Tooltip'; +import { + CloudDownload as DownloadIcon, + Trash2 as DeleteIcon, + X as CancelIcon, +} from 'lucide-react'; +import api from '../api'; + +const fmtBytes = (n) => { + if (!n && n !== 0) return '—'; + const units = ['B', 'KB', 'MB', 'GB', 'TB']; + let v = n; + let u = 0; + while (v >= 1000 && u < units.length - 1) { v /= 1000; u += 1; } + return `${v.toFixed(v < 10 ? 2 : 1)} ${units[u]}`; +}; + +const hardwareLabel = (hw) => ({ + 'cpu': 'CPU / GPU', + 'cuda': 'CUDA', + 'cuda+flash-attn': 'CUDA + Flash-Attn', +}[hw] || hw); + +// Why this host can't run a given model, or null if it can. Mirrors the gate +// in audio_generator._ensure_model. `env` comes from GET /api/environment. +const hostIncompatReason = (hw, env) => { + if (!env) return null; // capabilities unknown — don't block + if (hw === 'cuda+flash-attn') { + if (!env.cuda_available) { + return 'Requires an NVIDIA CUDA GPU. 
Use a Small model — those run on CPU, Apple Silicon, or any GPU.'; + } + // Gate on the real capability, not the platform: Windows works once a + // matching flash-attn wheel is installed (Blackwell/Ampere + cu12x). + // No wheel → guide the user to install one (or use Docker on WSL2). + if (!env.flash_attn_available) { + return env.platform === 'Windows' + ? 'Requires Flash Attention 2 (flash-attn). No official Windows wheel — install a matching prebuilt/built wheel for your torch+CUDA, or run via Docker on WSL2.' + : 'Requires Flash Attention 2 (flash-attn) — not installed. Install it, or use a Small model.'; + } + } + if (hw === 'cuda' && !env.cuda_available) { + return 'Recommended on an NVIDIA CUDA GPU; this host has none.'; + } + return null; +}; + +export default function CheckpointRow({ checkpoint, env, onAuthRequired, onChanged }) { + const [jobId, setJobId] = useState(checkpoint.active_job?.job_id || null); + const [job, setJob] = useState(checkpoint.active_job || null); + const [error, setError] = useState(null); + const [busy, setBusy] = useState(false); + const pollTimer = useRef(null); + + // If the parent's refresh tells us about an in-flight job and we don't + // already have one locally (typical case: dialog was closed mid-download + // and just got reopened), adopt it. Don't stomp a freshly-started local + // job_id with stale catalog data — only sync when the local state is empty + // or a *different* job is now active for this checkpoint. 
+ useEffect(() => { + const incoming = checkpoint.active_job?.job_id || null; + if (incoming && incoming !== jobId) { + setJobId(incoming); + setJob(checkpoint.active_job); + } + }, [checkpoint.active_job, jobId]); + + useEffect(() => { + if (!jobId) return undefined; + const tick = async () => { + try { + const r = await api.get(`/api/checkpoints/jobs/${jobId}`); + setJob(r.data); + if (['complete', 'failed', 'cancelled'].includes(r.data.status)) { + if (r.data.status === 'failed' && (r.data.error || '').startsWith('hf_auth_required')) { + onAuthRequired?.(); + } else if (r.data.status === 'failed') { + setError(r.data.error); + } + setJobId(null); + onChanged?.(); + } + } catch (e) { + setError(e.response?.data?.error || e.message); + setJobId(null); + } + }; + tick(); + pollTimer.current = setInterval(tick, 1500); + return () => clearInterval(pollTimer.current); + }, [jobId, onAuthRequired, onChanged]); + + const startDownload = async () => { + setBusy(true); + setError(null); + try { + const r = await api.post(`/api/checkpoints/${checkpoint.id}/download`); + setJobId(r.data.job_id); + } catch (e) { + setError(e.response?.data?.error || e.message); + } finally { + setBusy(false); + } + }; + + const cancelDownload = async () => { + try { + await api.post(`/api/checkpoints/${checkpoint.id}/cancel-download`); + } catch (e) { + setError(e.response?.data?.error || e.message); + } + }; + + const deleteCheckpoint = async () => { + if (!window.confirm(`Delete ${checkpoint.name} (${fmtBytes(checkpoint.downloaded_bytes)})?`)) return; + setBusy(true); + try { + await api.delete(`/api/checkpoints/${checkpoint.id}`); + onChanged?.(); + } catch (e) { + setError(e.response?.data?.error || e.message); + } finally { + setBusy(false); + } + }; + + const downloading = !!jobId && job?.status === 'running'; + const queued = !!jobId && job?.status === 'queued'; + const pct = job?.total_bytes ? 
(job.downloaded_bytes / job.total_bytes) * 100 : 0; + const incompatReason = hostIncompatReason(checkpoint.hardware, env); + + const renderAction = () => { + if (downloading || queued) { + return ( + + ); + } + if (checkpoint.downloaded) { + return ( + + + + ); + } + if (incompatReason) { + return ( + + {/* span wrapper so the tooltip works on a disabled button */} + + + + + ); + } + return ( + + ); + }; + + return ( + + + + + + + {checkpoint.name} + + + + {checkpoint.downloaded && ( + + )} + + + {fmtBytes(checkpoint.size_bytes)} + {checkpoint.max_duration_sec && ` · up to ${checkpoint.max_duration_sec}s`} + + {incompatReason && !checkpoint.downloaded && ( + + Not supported on this machine + + )} + + {renderAction()} + + + {(downloading || queued) && ( + + + + {queued ? 'Queued…' : `${fmtBytes(job?.downloaded_bytes)} / ${fmtBytes(job?.total_bytes)}`} + + + )} + + {error && ( + + {error} + + )} + + ); +} diff --git a/app/frontend/src/components/DatasetPrep.js b/app/frontend/src/components/DatasetPrep.js new file mode 100644 index 0000000000000000000000000000000000000000..47746411f7f23fc7f52fca31a46fd5aa24bd01f4 --- /dev/null +++ b/app/frontend/src/components/DatasetPrep.js @@ -0,0 +1,1823 @@ +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { + Accordion, + AccordionDetails, + AccordionSummary, + Alert, + Autocomplete, + Box, + Button, + Checkbox, + Chip, + Dialog, + DialogActions, + DialogContent, + DialogTitle, + FormControl, + FormControlLabel, + IconButton, + InputLabel, + LinearProgress, + MenuItem, + Paper, + Portal, + Radio, + RadioGroup, + Select, + Snackbar, + Stack, + Switch, + Table, + TableBody, + TableCell, + TableContainer, + TableHead, + TableRow, + TextField, + Typography, + useTheme, +} from '@mui/material'; +import { TIPS } from '../tooltips'; +import Tooltip from './Tooltip'; +import { + ChevronDown as ChevronDownIcon, + FolderOpenIcon, + PlusIcon, + WandSparkles, + SaveIcon, + Database as Database, + DatabaseZap as 
DatasetIcon, + Square as StopIcon, + Trash2 as TrashIcon, + Play as PlayIcon, + Pause as PauseIcon, + Scissors as ScissorsIcon, + Music as MusicIcon, + Activity as HealthIcon, +} from 'lucide-react'; +import api from '../api'; +import { appStyles } from '../theme'; + +/** + * DatasetPrep — sidecar-native dataset surface with a buffered editing model. + * + * One page, no modes. Pick or create a project. The dataset folder on disk + * is the *committed* state. Edits, auto-annotate output, and just-ingested + * audio all live in an in-memory session until the user explicitly hits + * Save (writes a draft) or Commit (writes .txt sidecars). + */ +export default function DatasetPrep({ onOpenCheckpointManager }) { + const [projects, setProjects] = useState([]); + const [selectedName, setSelectedName] = useState(() => { + try { return window.localStorage.getItem('fragmenta.datasetPrep.lastProject') || ''; } + catch { return ''; } + }); + const [project, setProject] = useState(null); + const [createOpen, setCreateOpen] = useState(false); + const [loadOpen, setLoadOpen] = useState(false); + const [ingestOpen, setIngestOpen] = useState(false); + const [sliceTarget, setSliceTarget] = useState(null); // file_name or null + // Single confirm-dialog state powering destructive actions. Mirrors the + // Free GPU / Start Fresh confirm style from App.js — replaces the + // browser-native window.confirm() prompts so the UX is consistent. 
+ const [confirm, setConfirm] = useState(null); + const [confirmBusy, setConfirmBusy] = useState(false); + const [error, setError] = useState(''); + + const [errorCode, setErrorCode] = useState(''); + const [errorExtra, setErrorExtra] = useState(null); + const [annotateJob, setAnnotateJob] = useState(null); + const [notice, setNotice] = useState(null); // { severity, message } | null + // Phase 6 — pre-encoded latents + const [preEncodeJob, setPreEncodeJob] = useState(null); + const [preEncodeOffer, setPreEncodeOffer] = useState(false); // post-commit dialog + const [tier, setTier] = useState(() => { + try { return window.localStorage.getItem('fragmenta.datasetPrep.tier') || 'basic'; } + catch { return 'basic'; } + }); + const [skipExisting, setSkipExisting] = useState(true); + + const pollHandleRef = useRef(null); + const preEncodePollRef = useRef(null); + const isAnnotating = annotateJob?.state === 'running'; + const isPreEncoding = preEncodeJob?.state === 'running' || preEncodeJob?.state === 'queued'; + + // --- Multi-row selection (for bulk Slice) ----------------------------- + // Set of clip file_names. Reset whenever the active project + // changes, since selections from a different project are meaningless. + const [selectedFiles, setSelectedFiles] = useState(() => new Set()); + useEffect(() => { setSelectedFiles(new Set()); }, [selectedName]); + + const toggleSelected = useCallback((fileName) => { + setSelectedFiles((prev) => { + const next = new Set(prev); + if (next.has(fileName)) next.delete(fileName); + else next.add(fileName); + return next; + }); + }, []); + const toggleSelectAll = useCallback((clips) => { + setSelectedFiles((prev) => { + const allNames = clips.map((c) => c.file_name); + const allSelected = allNames.length > 0 && allNames.every((n) => prev.has(n)); + return allSelected ? 
new Set() : new Set(allNames); + }); + }, []); + const clearSelection = useCallback(() => setSelectedFiles(new Set()), []); + + // --- Per-row audio preview -------------------------------------------- + // One