File size: 13,184 Bytes
44a2550
a5359f9
 
 
 
 
 
 
 
44a2550
 
 
5ad0f35
44a2550
f09b953
44a2550
75d3906
44a2550
 
5ad0f35
 
44a2550
 
 
 
 
ac5c764
44a2550
 
 
 
 
 
 
 
 
75d3906
44a2550
 
 
 
 
 
 
 
 
 
f849d05
44a2550
 
 
 
 
 
 
 
f849d05
 
 
75d3906
f849d05
 
 
 
44a2550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5359f9
44a2550
 
a5359f9
 
 
44a2550
 
a5359f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44a2550
 
a5359f9
 
 
44a2550
 
 
a5359f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44a2550
 
 
 
 
 
 
 
 
 
 
3c8b7e3
 
 
 
6293e69
 
44a2550
6293e69
44a2550
6293e69
 
3c8b7e3
6293e69
 
 
 
 
 
 
44a2550
 
 
 
 
 
 
 
75d3906
 
6293e69
44a2550
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6293e69
 
 
 
 
 
 
 
 
 
 
44a2550
 
 
 
 
6293e69
 
 
44a2550
 
 
 
 
 
 
 
 
 
6293e69
 
44a2550
 
 
 
 
6293e69
 
 
44a2550
 
6293e69
 
44a2550
 
 
ac5c764
44a2550
ac5c764
44a2550
ac5c764
44a2550
 
 
 
 
 
 
ac5c764
 
 
 
 
 
44a2550
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
"""Celery tasks for background job processing."""
import sys
from pathlib import Path

# Ensure backend directory is in Python path for imports
backend_dir = Path(__file__).parent.resolve()
if str(backend_dir) not in sys.path:
    sys.path.insert(0, str(backend_dir))

from celery import Task
from celery_app import celery_app
from pipeline import TranscriptionPipeline, run_transcription_pipeline
from redis_client import get_redis_client
import json
import os
from datetime import datetime
from app_config import settings
import shutil

# Shared Redis client singleton (obtained from get_redis_client() once, at
# import time). Every task and helper in this module uses it for the job
# hashes, the per-job progress history lists and the pub/sub update channels.
redis_client = get_redis_client()


class TranscriptionTask(Task):
    """Celery base task that records and broadcasts job progress.

    Tasks declared with ``base=TranscriptionTask`` gain :meth:`update_progress`,
    which writes the latest progress into the job's Redis hash, appends it to a
    per-job history list, and publishes it on the job's pub/sub channel.
    """

    def update_progress(self, job_id: str, progress: int, stage: str, message: str) -> None:
        """
        Update job progress in Redis and publish to WebSocket subscribers.

        Writes ``progress``/``current_stage``/``status_message``/``updated_at``
        into the hash ``job:<job_id>``, appends a JSON record to the list
        ``job:<job_id>:progress_history`` and publishes the same record on
        ``job:<job_id>:updates``.

        Args:
            job_id: Job identifier
            progress: Progress percentage (0-100)
            stage: Current stage name
            message: Status message
        """
        print(f"[PROGRESS] {progress}% - {stage} - {message}")

        # Take a single timestamp so the hash's ``updated_at`` and the
        # history/pub-sub record's ``timestamp`` describe the same instant.
        now = datetime.utcnow().isoformat()

        # Latest-state snapshot, read by anything polling the job hash.
        redis_client.hset(f"job:{job_id}", mapping={
            "progress": progress,
            "current_stage": stage,
            "status_message": message,
            "updated_at": now,
        })

        update = {
            "type": "progress",
            "job_id": job_id,
            "progress": progress,
            "stage": stage,
            "message": message,
            "timestamp": now,
        }
        payload = json.dumps(update)

        # History list lets clients that connect late (e.g. in eager mode)
        # replay updates they missed on the live channel.
        redis_client.rpush(f"job:{job_id}:progress_history", payload)

        # Live fan-out to any WebSocket clients already subscribed.
        num_subscribers = redis_client.publish(f"job:{job_id}:updates", payload)
        if num_subscribers > 0:
            print(f"[PROGRESS] Published to {num_subscribers} subscribers")
        else:
            print("[PROGRESS] Stored in history (no live subscribers)")


# Exceptions treated as transient and therefore retried. NOTE: in Python 3
# IOError is an alias of OSError, and ConnectionError / TimeoutError are OSError
# subclasses, so this tuple effectively matches *every* OSError (including
# FileNotFoundError and PermissionError), not only network/disk hiccups.
_RETRYABLE_EXCEPTIONS = (
    ConnectionError,  # Network errors
    TimeoutError,     # Timeout errors
    IOError,          # I/O errors (file system, disk full, etc.)
)

# Program assigned to vocal tracks when the job options don't choose one
# (0-indexed General MIDI program 40 = Violin).
_DEFAULT_VOCAL_PROGRAM = 40

# Program the vocal-remap step looks for; the original code labels it
# "Singing Voice". NOTE(review): in 0-indexed General MIDI (as used by
# pretty_midi) 65 is Alto Sax — confirm this matches what the pipeline emits.
_VOCALS_SOURCE_PROGRAM = 65


def _parse_job_options(job_data):
    """Return ``(instruments, vocal_program)`` from the job's JSON ``options`` field.

    Falls back to ``['piano']`` and ``_DEFAULT_VOCAL_PROGRAM`` when the field is
    missing, is not valid JSON, is not a JSON object, or holds wrongly-typed
    values (so bad user input can't crash the pipeline halfway through).
    """
    instruments = ['piano']
    vocal_program = _DEFAULT_VOCAL_PROGRAM

    raw_options = job_data.get('options')
    if raw_options is None:
        return instruments, vocal_program
    try:
        options = json.loads(raw_options)
    except (json.JSONDecodeError, TypeError):
        return instruments, vocal_program
    if not isinstance(options, dict):
        return instruments, vocal_program

    requested = options.get('instruments', instruments)
    if isinstance(requested, list) and all(isinstance(name, str) for name in requested):
        instruments = requested
    else:
        print(f"   [WARNING] Ignoring invalid 'instruments' option: {requested!r}")

    program = options.get('vocal_instrument', _DEFAULT_VOCAL_PROGRAM)
    # bool is an int subclass, so reject it explicitly; MIDI programs are 0-127.
    if isinstance(program, int) and not isinstance(program, bool) and 0 <= program <= 127:
        vocal_program = program
    else:
        print(f"   [WARNING] Ignoring invalid 'vocal_instrument' option: {program!r}")

    return instruments, vocal_program


def _prepare_audio(pipeline, job_id, upload_path, youtube_url):
    """Produce the job's working audio file and return its path.

    Uploaded ``.wav`` files are copied, other uploads are converted with ffmpeg,
    and YouTube jobs are downloaded via the pipeline.

    Raises:
        RuntimeError: ffmpeg conversion failed.
        ValueError: the job has neither an upload nor a YouTube URL.
    """
    import subprocess

    audio_path = pipeline.temp_dir / "audio.wav"

    if upload_path:
        upload_file = Path(upload_path)
        if upload_file.suffix.lower() == '.wav':
            shutil.copy(str(upload_file), str(audio_path))
        else:
            # -y: overwrite a leftover audio.wav (e.g. from an earlier attempt)
            # instead of prompting; stdin=DEVNULL keeps ffmpeg from reading the
            # worker's stdin.
            result = subprocess.run([
                'ffmpeg', '-y', '-i', str(upload_file),
                '-ar', '44100', '-ac', '2',
                str(audio_path)
            ], capture_output=True, text=True, stdin=subprocess.DEVNULL)
            if result.returncode != 0:
                raise RuntimeError(f"Audio conversion failed: {result.stderr}")
        return audio_path

    if youtube_url:
        pipeline.progress(0, "download", "Starting audio download")
        return pipeline.download_audio()

    raise ValueError(f"Job missing both youtube_url and upload_path: {job_id}")


def _select_stems(all_stems, instruments):
    """Pick the separated stems to transcribe for the requested instruments.

    Falls back to the ``piano`` stem, then the ``other`` stem, when none of the
    requested instruments were produced.

    Raises:
        ValueError: no requested stem and neither fallback stem is available.
    """
    selected = {}
    for instrument in instruments:
        if instrument in all_stems:
            selected[instrument] = all_stems[instrument]
            print(f"   [DEBUG] Will transcribe {instrument} stem")
        else:
            print(f"   [WARNING] {instrument} stem not found in separated audio")

    if selected:
        return selected

    print("   [WARNING] No selected stems found, falling back to piano")
    for fallback in ('piano', 'other'):
        if fallback in all_stems:
            return {fallback: all_stems[fallback]}
    raise ValueError(
        f"Source separation produced no usable stems (available: {sorted(all_stems)})"
    )


def _remap_vocal_program(midi_path, program):
    """Rewrite non-drum tracks using ``_VOCALS_SOURCE_PROGRAM`` to ``program``, in place."""
    import pretty_midi

    print(f"   [DEBUG] Remapping vocals MIDI program from 65 to {program}")
    pm = pretty_midi.PrettyMIDI(str(midi_path))
    for inst in pm.instruments:
        if inst.program == _VOCALS_SOURCE_PROGRAM and not inst.is_drum:
            inst.program = program
            print(f"   [DEBUG] Changed track '{inst.name}' program to {program}")
    pm.write(str(midi_path))


@celery_app.task(base=TranscriptionTask, bind=True)
def process_transcription_task(self, job_id: str):
    """
    Main transcription task.

    Reads the job hash ``job:<job_id>`` from Redis and obtains a WAV file, either
    from an uploaded file or from a YouTube download. It then separates sources,
    transcribes the selected stems to MIDI and generates MusicXML. The MusicXML
    and MIDI are copied into ``settings.outputs_path``.

    Progress is reported through ``self.update_progress``. The job hash moves to
    ``processing`` and then ``completed`` or ``failed``. Completion and error
    messages are published on ``job:<job_id>:updates``.

    Args:
        job_id: Unique job identifier

    Returns:
        String path of the generated MusicXML file.

    Raises:
        Whatever ``self.retry`` raises (Celery's retry signal) for exceptions
        matching ``_RETRYABLE_EXCEPTIONS`` while retries remain; any other
        exception is re-raised after the job is marked failed.
    """
    job_key = f"job:{job_id}"
    pipeline = None  # set once created, so the failure path can clean it up

    try:
        # Read before writing: hset() would otherwise create the hash and the
        # "not found" check below could never trigger.
        job_data = redis_client.hgetall(job_key)
        if not job_data:
            raise ValueError(f"Job not found: {job_id}")

        redis_client.hset(job_key, mapping={
            "status": "processing",
            "started_at": datetime.utcnow().isoformat(),
        })

        upload_path = job_data.get('upload_path')
        youtube_url = job_data.get('youtube_url')
        instruments, vocal_instrument_program = _parse_job_options(job_data)

        pipeline = TranscriptionPipeline(
            job_id=job_id,
            youtube_url=youtube_url or "file://uploaded",  # Dummy URL for file uploads
            storage_path=settings.storage_path,
            instruments=instruments
        )
        pipeline.set_progress_callback(lambda p, s, m: self.update_progress(job_id, p, s, m))

        audio_path = _prepare_audio(pipeline, job_id, upload_path, youtube_url)

        # From here on, uploads and downloads are processed identically.
        if pipeline.config.enable_audio_preprocessing:
            pipeline.progress(10, "preprocess", "Preprocessing audio")
            audio_path = pipeline.preprocess_audio(audio_path)

        pipeline.progress(20, "separate", "Starting source separation")
        all_stems = pipeline.separate_sources(audio_path)
        stems_to_transcribe = _select_stems(all_stems, instruments)

        pipeline.progress(50, "transcribe", f"Transcribing {len(stems_to_transcribe)} instrument(s)")
        if len(stems_to_transcribe) == 1:
            stem_path = next(iter(stems_to_transcribe.values()))
            combined_midi = pipeline.transcribe_to_midi(stem_path)
        else:
            combined_midi = pipeline.transcribe_multiple_stems(stems_to_transcribe)

        filtered_midi = pipeline.filter_midi_by_instruments(combined_midi)

        if 'vocals' in instruments and vocal_instrument_program != _VOCALS_SOURCE_PROGRAM:
            _remap_vocal_program(filtered_midi, vocal_instrument_program)

        final_midi = pipeline.apply_post_processing_filters(filtered_midi)
        pipeline.final_midi_path = final_midi

        # MusicXML generation prefers the piano stem, else any transcribed stem.
        audio_stem = stems_to_transcribe.get('piano') or next(iter(stems_to_transcribe.values()))

        pipeline.progress(90, "musicxml", "Generating MusicXML")
        temp_output_path = pipeline.generate_musicxml_minimal(final_midi, audio_stem)
        pipeline.progress(100, "complete", "Transcription complete")

        # Move results from the temp directory into persistent storage.
        settings.outputs_path.mkdir(parents=True, exist_ok=True)
        output_path = settings.outputs_path / f"{job_id}.musicxml"
        midi_path = settings.outputs_path / f"{job_id}.mid"

        shutil.copy(str(temp_output_path), str(output_path))

        temp_midi_path = Path(final_midi)
        print(f"[DEBUG] Using MIDI from pipeline: {temp_midi_path}")
        print(f"[DEBUG] MIDI exists: {temp_midi_path.exists()}")

        # Record success at copy time: the temp MIDI may live inside the temp
        # directory that pipeline.cleanup() removes, so checking it afterwards
        # would wrongly report the MIDI as missing.
        midi_saved = False
        if temp_midi_path.exists():
            print(f"[DEBUG] Copying MIDI from {temp_midi_path} to {midi_path}")
            shutil.copy(str(temp_midi_path), str(midi_path))
            midi_saved = True
            print(f"[DEBUG] Copy complete, destination exists: {midi_path.exists()}")
        else:
            print(f"[DEBUG] WARNING: No MIDI file found at {temp_midi_path}!")

        metadata = getattr(pipeline, 'metadata', {
            "tempo": 120.0,
            "time_signature": {"numerator": 4, "denominator": 4},
            "key_signature": "C",
        })

        pipeline.cleanup()
        pipeline = None  # already cleaned; nothing left for the failure path

        redis_client.hset(job_key, mapping={
            "status": "completed",
            "progress": 100,
            "output_path": str(output_path.absolute()),
            "midi_path": str(midi_path.absolute()) if midi_saved else "",
            "metadata": json.dumps(metadata),
            "completed_at": datetime.utcnow().isoformat(),
        })

        completion_msg = {
            "type": "completed",
            "job_id": job_id,
            "result_url": f"/api/v1/scores/{job_id}",
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(completion_msg))

        return str(output_path)

    except Exception as e:
        import traceback

        error_traceback = traceback.format_exc()

        # Only retry transient errors (see _RETRYABLE_EXCEPTIONS), not code bugs.
        is_retryable = isinstance(e, _RETRYABLE_EXCEPTIONS) and self.request.retries < self.max_retries

        # Best-effort temp cleanup so failed jobs don't leak large audio/stem
        # files; a retry rebuilds the pipeline from scratch.
        if pipeline is not None:
            try:
                pipeline.cleanup()
            except Exception as cleanup_error:
                print(f"[WARNING] Temp cleanup failed for job {job_id}: {cleanup_error}")

        # NOTE: the job is marked "failed" even when a retry follows; the next
        # attempt sets it back to "processing".
        redis_client.hset(job_key, mapping={
            "status": "failed",
            "error": json.dumps({
                "message": str(e),
                "type": type(e).__name__,
                "retryable": is_retryable,
                "traceback": error_traceback,
            }),
            "failed_at": datetime.utcnow().isoformat(),
        })

        error_msg = {
            "type": "error",
            "job_id": job_id,
            "error": {
                "message": str(e),
                "type": type(e).__name__,
                "retryable": is_retryable,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }
        redis_client.publish(f"job:{job_id}:updates", json.dumps(error_msg))

        if is_retryable:
            print(f"[RETRY] Retrying job {job_id} (attempt {self.request.retries + 1}/{self.max_retries})")
            # Exponential backoff: 1s, 2s, 4s, ...
            raise self.retry(exc=e, countdown=2 ** self.request.retries)

        # Non-retryable error (code bug, validation error, etc.) - fail immediately
        print(f"[ERROR] Non-retryable error for job {job_id}: {type(e).__name__}: {e}")
        raise


# === Module-level helper functions ===

def update_progress(job_id: str, progress: int, stage: str, message: str) -> None:
    """
    Record and broadcast job progress from outside a running task.

    Backward-compatibility wrapper: it builds a throwaway ``TranscriptionTask``
    and forwards to its ``update_progress`` method, so both entry points write
    the same Redis hash fields, history record and pub/sub message.

    Args:
        job_id: Job identifier
        progress: Progress percentage (0-100)
        stage: Current stage name
        message: Status message
    """
    reporter = TranscriptionTask()
    return reporter.update_progress(job_id, progress, stage, message)


def cleanup_temp_files(job_id: str, storage_path: Path = None) -> None:
    """
    Clean up temporary files for a job.

    Recursively deletes ``<storage_path>/temp/<job_id>`` if it exists. Deletion
    errors are ignored (best-effort cleanup), and a missing directory is a no-op.

    Args:
        job_id: Job identifier; must be a single, non-empty path component.
        storage_path: Storage root as a ``Path`` or ``str`` (uses
            ``settings.storage_path`` if not provided).

    Raises:
        ValueError: If ``job_id`` is empty, ``.``/``..``, or contains a path
            separator. Joining such a value would resolve to the shared temp
            directory (or above it) and delete other jobs' files.
    """
    # Path("x") / "temp" / "" is just x/temp, so an empty id would wipe every job.
    if not job_id or job_id in (".", "..") or Path(job_id).name != job_id:
        raise ValueError(f"Invalid job_id for temp cleanup: {job_id!r}")

    if storage_path is None:
        storage_path = settings.storage_path

    # Path() accepts str callers as well as Path objects.
    temp_dir = Path(storage_path) / "temp" / job_id
    if temp_dir.exists():
        shutil.rmtree(temp_dir, ignore_errors=True)