"""Tests for Celery tasks.""" import pytest from unittest.mock import patch, MagicMock, call import json class TestProcessTranscriptionTask: """Test the main Celery transcription task.""" @patch('tasks.shutil.copy') @patch('tasks.TranscriptionPipeline') @patch('tasks.redis_client') def test_task_success(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id, temp_storage_dir): """Test successful task execution.""" from tasks import process_transcription_task # Mock job data in Redis - all string values job_data = { 'job_id': str(sample_job_id), 'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ', 'video_id': 'dQw4w9WgXcQ', 'options': '{"instruments": ["piano"]}' } mock_redis.hgetall.return_value = job_data # Ensure pipeline method returns None mock_redis.pipeline.return_value.__enter__.return_value = mock_redis # Create actual files so they exist (temp_storage_dir / "output.musicxml").write_text("") (temp_storage_dir / "output.mid").write_bytes(b"MThd") # Mock successful pipeline instance mock_pipeline = MagicMock() mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml") mock_pipeline.final_midi_path = temp_storage_dir / "output.mid" mock_pipeline.metadata = { "tempo": 120.0, "time_signature": {"numerator": 4, "denominator": 4}, "key_signature": "C" } mock_pipeline_class.return_value = mock_pipeline # Execute task process_transcription_task(sample_job_id) # Verify pipeline ran mock_pipeline.run.assert_called_once() # Verify progress updates were published assert mock_redis.publish.call_count > 0 # Verify final status was set to completed completed_calls = [ call for call in mock_redis.hset.call_args_list if 'completed' in str(call) ] assert len(completed_calls) > 0 @patch('tasks.shutil.copy') @patch('tasks.TranscriptionPipeline') @patch('tasks.redis_client') def test_task_failure(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id): """Test task execution with pipeline failure.""" from tasks import process_transcription_task from celery.exceptions import Retry job_data = { 'job_id': sample_job_id, 'youtube_url': 'https://www.youtube.com/watch?v=invalid', 'video_id': 'invalid', 'options': '{}' } mock_redis.hgetall.return_value = job_data # Mock failed pipeline mock_pipeline = MagicMock() mock_pipeline.run.side_effect = RuntimeError("Download failed") mock_pipeline_class.return_value = mock_pipeline # Execute task - should raise Retry due to Celery's retry mechanism with pytest.raises((Retry, RuntimeError)): process_transcription_task(sample_job_id) # Verify error was stored in Redis before retry error_calls = [ call for call in mock_redis.hset.call_args_list if 'error' in str(call) ] assert len(error_calls) > 0 @patch('tasks.shutil.copy') @patch('tasks.TranscriptionPipeline') @patch('tasks.redis_client') def test_task_progress_updates(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id, temp_storage_dir): """Test that task publishes progress updates.""" from tasks import process_transcription_task job_data = { 'job_id': str(sample_job_id), 'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ', 'video_id': 'dQw4w9WgXcQ', 'options': '{}' } mock_redis.hgetall.return_value = job_data # Create actual files so they exist (temp_storage_dir / "output.musicxml").write_text("") (temp_storage_dir / "output.mid").write_bytes(b"MThd") # Ensure pipeline method returns None mock_redis.pipeline.return_value.__enter__.return_value = mock_redis mock_pipeline = MagicMock() mock_pipeline.run.return_value = str(temp_storage_dir / "output.musicxml") mock_pipeline.final_midi_path = temp_storage_dir / "output.mid" mock_pipeline.metadata = { "tempo": 120.0, "time_signature": {"numerator": 4, "denominator": 4}, "key_signature": "C" } mock_pipeline_class.return_value = mock_pipeline process_transcription_task(sample_job_id) # Verify completion message was published publish_calls = mock_redis.publish.call_args_list assert len(publish_calls) >= 1 # At least completion message # Verify final publish call contains completion info final_call = publish_calls[-1] channel, message = final_call[0] assert channel == f"job:{sample_job_id}:updates" update_data = json.loads(message) assert 'type' in update_data assert update_data['type'] == 'completed' @patch('tasks.redis_client') def test_task_job_not_found(self, mock_redis, sample_job_id): """Test task execution when job doesn't exist.""" from tasks import process_transcription_task mock_redis.hgetall.return_value = {} with pytest.raises(ValueError) as exc_info: process_transcription_task(sample_job_id) assert "Job not found" in str(exc_info.value) @patch('tasks.shutil.copy') @patch('tasks.TranscriptionPipeline') @patch('tasks.redis_client') def test_task_retry_on_network_error(self, mock_redis, mock_pipeline_class, mock_copy, sample_job_id): """Test task retry logic for transient errors.""" from tasks import process_transcription_task from celery.exceptions import Retry job_data = { 'job_id': sample_job_id, 'youtube_url': 'https://www.youtube.com/watch?v=dQw4w9WgXcQ', 'video_id': 'dQw4w9WgXcQ', 'options': '{}' } mock_redis.hgetall.return_value = job_data # Mock transient network error mock_pipeline = MagicMock() mock_pipeline.run.side_effect = ConnectionError("Network timeout") mock_pipeline_class.return_value = mock_pipeline with pytest.raises((Retry, ConnectionError)): process_transcription_task(sample_job_id) class TestProgressCallback: """Test progress callback functionality.""" @patch('tasks.redis_client') def test_update_progress(self, mock_redis, sample_job_id): """Test progress update function.""" from tasks import update_progress update_progress(sample_job_id, 50, "transcription", "Transcribing audio...") # Verify Redis was updated mock_redis.hset.assert_called() call_args = mock_redis.hset.call_args[0] assert call_args[0] == f"job:{sample_job_id}" # Verify WebSocket message was published mock_redis.publish.assert_called() channel, message = mock_redis.publish.call_args[0] assert channel == f"job:{sample_job_id}:updates" update_data = json.loads(message) assert update_data['progress'] == 50 assert update_data['stage'] == "transcription" assert update_data['message'] == "Transcribing audio..." @patch('tasks.redis_client') def test_multiple_progress_updates(self, mock_redis, sample_job_id): """Test sequence of progress updates.""" from tasks import update_progress stages = [ (5, "download", "Downloading audio"), (25, "separation", "Separating audio sources"), (60, "transcription", "Transcribing to MIDI"), (90, "musicxml", "Generating MusicXML"), (100, "completed", "Processing complete") ] for progress, stage, message in stages: update_progress(sample_job_id, progress, stage, message) # Should have 5 updates assert mock_redis.hset.call_count == 5 assert mock_redis.publish.call_count == 5 class TestCleanup: """Test cleanup of temporary files.""" @patch('tasks.shutil.rmtree') def test_cleanup_temp_files(self, mock_rmtree, sample_job_id, temp_storage_dir): """Test cleanup of temporary files after job completion.""" from tasks import cleanup_temp_files # Create the temp directory so cleanup will attempt to remove it temp_dir = temp_storage_dir / "temp" / sample_job_id temp_dir.mkdir(parents=True, exist_ok=True) cleanup_temp_files(sample_job_id, storage_path=temp_storage_dir) # Verify temp directory was removed mock_rmtree.assert_called() def test_cleanup_preserves_output(self, sample_job_id, temp_storage_dir): """Test that cleanup preserves final output files.""" from tasks import cleanup_temp_files # Create a temp directory with files temp_dir = temp_storage_dir / "temp" / sample_job_id temp_dir.mkdir(parents=True, exist_ok=True) # Create temp files (temp_dir / "temp_audio.wav").touch() (temp_dir / "temp_midi.mid").touch() # Create output files outputs_dir = temp_storage_dir / "outputs" outputs_dir.mkdir(parents=True, exist_ok=True) output_files = [ outputs_dir / "output.musicxml", outputs_dir / "output.mid" ] for f in output_files: f.touch() # Run cleanup cleanup_temp_files(sample_job_id, storage_path=temp_storage_dir) # Verify temp directory was removed assert not temp_dir.exists() # Verify output files still exist for f in output_files: assert f.exists()