Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Shared state module for Gradio hot-reload compatibility. | |
| This module holds all global state (models, processor, chapters) in a singleton | |
| pattern that survives Gradio's module reloading. By keeping state here instead | |
| of in app.py, we avoid re-initialization on every code change. | |
| """ | |
| from dataclasses import dataclass, field | |
| from typing import Any, Optional | |
@dataclass(eq=False)
class AppState:
    """Container for all application state.

    Declared as a dataclass so that ``field(default_factory=list)`` gives every
    instance its own empty ``model_bundles`` list. Without the decorator,
    ``field(...)`` is just a ``dataclasses.Field`` object stored on the class,
    and ``AppState().model_bundles`` would return that object instead of a list.
    ``eq=False`` keeps the default identity-based ``==`` and ``__hash__`` of a
    plain class, which suits a singleton.
    """

    processor: Optional[Any] = None
    model: Optional[Any] = None
    chapters: Optional[list] = None
    model_bundles: list = field(default_factory=list)
    initialized: bool = False

    # Cache for RecitationResult with duration data from audio processing
    last_recitation_result: Optional[Any] = None
    last_recitation_verse_ref: Optional[str] = None

    # Cache for segment mode: List[RecitationResult] with duration data
    last_segment_results: Optional[list] = None
    last_segment_refs: Optional[list] = None

    # Cache for error analysis segment data (for re-rendering with sort modes)
    last_error_segment_data: Optional[list] = None

    # Cache for multi-model mode: {idx: RecitationResult} or {idx: List[RecitationResult]}
    last_multi_model_results: Optional[dict] = None  # Non-segmented
    last_multi_model_segment_results: Optional[dict] = None  # Segmented
    last_multi_model_segment_refs: Optional[dict] = None
    last_multi_model_verse_ref: Optional[str] = None  # Shared by both modes
# The one AppState instance that every accessor in this module reads and writes.
# It lives as long as this module object stays loaded; executing this module's
# code again would rebind it to a fresh, empty AppState().
_state: AppState = AppState()
def get_state() -> AppState:
    """Return the live shared AppState object itself (not a copy)."""
    shared = _state
    return shared
def is_initialized() -> bool:
    """Return the shared ``initialized`` flag."""
    flag = _state.initialized
    return flag
def set_initialized(value: bool = True) -> None:
    """Store ``value`` as the shared ``initialized`` flag (True by default)."""
    state = _state
    state.initialized = value
def get_processor():
    """Return the processor held in the shared state (None until one is set)."""
    current = _state.processor
    return current
def set_processor(processor):
    """Replace the processor held in the shared state with ``processor``."""
    state = _state
    state.processor = processor
def get_model():
    """Return the model held in the shared state (None until one is set)."""
    current = _state.model
    return current
def set_model(model):
    """Replace the model held in the shared state with ``model``."""
    state = _state
    state.model = model
def get_chapters():
    """Return the chapters list held in the shared state (None until set)."""
    current = _state.chapters
    return current
def set_chapters(chapters):
    """Replace the chapters list held in the shared state with ``chapters``."""
    state = _state
    state.chapters = chapters
def get_model_bundles():
    """Return the model bundles list held in the shared state."""
    bundles = _state.model_bundles
    return bundles
def set_model_bundles(bundles):
    """Replace the model bundles list held in the shared state with ``bundles``."""
    state = _state
    state.model_bundles = bundles
def get_last_recitation_result():
    """Return ``(result, verse_ref)`` for the cached single RecitationResult.

    Both elements are None while nothing has been cached.
    """
    state = _state
    cached_result = state.last_recitation_result
    cached_ref = state.last_recitation_verse_ref
    return cached_result, cached_ref
def set_last_recitation_result(result, verse_ref: Optional[str]):
    """Cache a RecitationResult with duration data for ghunnah/madd tabs.

    Args:
        result: The RecitationResult to cache, or None to clear the cache.
        verse_ref: Verse reference the result belongs to, or None when
            clearing (``clear_all_audio_caches`` calls this with ``(None, None)``).
    """
    _state.last_recitation_result = result
    _state.last_recitation_verse_ref = verse_ref
def get_last_segment_results():
    """Return ``(segment_results, segment_refs)`` from the last segmented run."""
    state = _state
    return state.last_segment_results, state.last_segment_refs
def set_last_segment_results(results: list, segment_refs: list):
    """Store per-segment RecitationResults (with durations) and their refs for the ghunnah/madd tabs."""
    state = _state
    state.last_segment_results, state.last_segment_refs = results, segment_refs
def clear_segment_results():
    """Drop the cached segment results and refs (used when leaving segment mode)."""
    state = _state
    state.last_segment_results = state.last_segment_refs = None
def get_last_error_segment_data():
    """Return the cached error-analysis segment data used for re-rendering (or None)."""
    cached = _state.last_error_segment_data
    return cached
def set_last_error_segment_data(data: list):
    """Store error-analysis segment data so it can be re-rendered under other sort modes."""
    state = _state
    state.last_error_segment_data = data
def clear_error_segment_data():
    """Reset the cached error-analysis segment data to None."""
    state = _state
    state.last_error_segment_data = None
def get_last_multi_model_results():
    """Return ``(results, verse_ref)`` for the non-segmented multi-model cache."""
    state = _state
    results = state.last_multi_model_results
    verse_ref = state.last_multi_model_verse_ref
    return results, verse_ref
def set_last_multi_model_results(results: dict, verse_ref: str):
    """Store non-segmented multi-model RecitationResults for the ghunnah/madd tabs.

    The verse ref is shared with segmented mode. The segmented multi-model
    results and refs are reset to None so only one mode is cached at a time.
    """
    state = _state
    state.last_multi_model_results, state.last_multi_model_verse_ref = results, verse_ref
    # Invalidate the segmented-mode cache.
    state.last_multi_model_segment_results = state.last_multi_model_segment_refs = None
def get_last_multi_model_segment_results():
    """Return ``(segment_results, segment_refs, verse_ref)`` for segmented multi-model mode.

    ``verse_ref`` is the verse ref shared by both multi-model modes, so it can
    be set while the segmented results are None.
    """
    state = _state
    segment_results = state.last_multi_model_segment_results
    segment_refs = state.last_multi_model_segment_refs
    verse_ref = state.last_multi_model_verse_ref
    return segment_results, segment_refs, verse_ref
def set_last_multi_model_segment_results(results: dict, refs: dict, verse_ref: str):
    """Store segmented multi-model results and refs for the ghunnah/madd tabs.

    The shared verse ref is updated. The non-segmented multi-model results are
    reset to None so the two modes are never cached together.
    """
    state = _state
    (
        state.last_multi_model_segment_results,
        state.last_multi_model_segment_refs,
        state.last_multi_model_verse_ref,
    ) = (results, refs, verse_ref)
    # Invalidate the non-segmented-mode cache.
    state.last_multi_model_results = None
def clear_multi_model_results():
    """Reset every multi-model cache field (both modes and the shared verse ref) to None."""
    state = _state
    state.last_multi_model_results = state.last_multi_model_segment_results = (
        state.last_multi_model_segment_refs
    ) = state.last_multi_model_verse_ref = None
def clear_all_audio_caches():
    """Drop every audio-derived cache held in the shared state, then the reference-audio caches.

    Clears segment results, error-analysis segment data, all multi-model caches
    and the single cached RecitationResult/verse ref. It then calls
    ``recitation_engine.reference_audio.clear_audio_caches``.
    """
    clear_segment_results()
    clear_error_segment_data()
    clear_multi_model_results()
    set_last_recitation_result(None, None)

    # Imported at call time rather than at module load. Presumably this avoids
    # an import cycle or a heavy import; not verified from here.
    from recitation_engine.reference_audio import clear_audio_caches as _clear_reference_audio

    _clear_reference_audio()
def reset_state():
    """Replace the shared state with a brand-new AppState (for testing or forced re-init).

    This rebinds the module-level ``_state``. Objects obtained earlier through
    ``get_state()`` still refer to the old instance.
    """
    global _state
    fresh = AppState()
    _state = fresh
def recompute_errors_for_result(result):
    """
    Re-run error detection on a cached RecitationResult under the current settings.

    Settings such as the iqlab/ikhfaa sound change the canonical phonemes. This
    function therefore re-phonemizes the result's reference to get fresh
    canonical phonemes, then re-runs error detection against the phonemes
    originally detected. Transcription is not re-run.

    Args:
        result: RecitationResult with duration data, or None.

    Returns:
        - None when ``result`` is None.
        - ``result`` itself, unchanged, when no phonemizer result is available
          for its reference.
        - Otherwise, a new RecitationResult with the following contents:
          - Errors and ghunnah/madd instances are recomputed.
          - The original FA/duration fields and user audio clip are carried over.
          - Instance durations are recomputed when FA segments and a phoneme
            alignment are available.
    """
    if result is None:
        return None

    from recitation_analysis.error_analysis import ErrorPipeline
    from recitation_analysis.result import RecitationResult
    from recitation_analysis.result_builder import get_result_builder
    from utils.phonemizer_utils import get_cached_phonemizer_result

    # Prefer the segment's word-range ref; fall back to the whole verse ref.
    lookup_ref = result.segment_ref or result.verse_ref
    phonemized = get_cached_phonemizer_result(lookup_ref, ["compulsory_stop"])
    if phonemized is None:
        # Cannot re-phonemize, so hand back the cached result untouched.
        return result

    # Base result whose canonical phonemes reflect the current settings.
    base = get_result_builder().build_from_phonemizer_result(
        phonemized,
        verse_ref=result.verse_ref,
        segment_ref=result.segment_ref,
    )

    # Fresh canonical phonemes vs. the originally detected phonemes.
    detected = ErrorPipeline().detect_errors(base, list(result.detected_phonemes))

    # detect_errors() does not carry duration fields, so copy them from the original.
    updated = RecitationResult(
        verse_ref=detected.verse_ref,
        segment_ref=detected.segment_ref,
        canonical_phonemes=detected.canonical_phonemes,
        detected_phonemes=detected.detected_phonemes,
        canonical_words=detected.canonical_words,
        detected_words=detected.detected_words,
        alignment=detected.alignment,
        madd_mappings=detected.madd_mappings,
        phoneme_alignment=detected.phoneme_alignment,
        errors=detected.errors,
        ghunnah_instances=detected.ghunnah_instances,
        madd_instances=detected.madd_instances,
        fa_segments=result.fa_segments,
        cvc_avg_duration_ms=result.cvc_avg_duration_ms,
        cvc_instance_count=result.cvc_instance_count,
        fa_visualization_data=result.fa_visualization_data,
        user_audio_clip=result.user_audio_clip,
    )

    fa_segments = result.fa_segments
    if not (fa_segments and updated.phoneme_alignment):
        return updated

    from recitation_analysis.duration_analysis.duration_calculator import (
        build_canonical_to_decoded_map,
        compute_instance_durations,
    )

    # The mapping must come from the NEW alignment (post-setting-change canonicals).
    canonical_to_decoded = build_canonical_to_decoded_map(updated.phoneme_alignment)

    # detect_errors() creates instances without duration_ms; fill them in from
    # the preserved FA segments (ghunnah first, then madd).
    for instances in (updated.ghunnah_instances, updated.madd_instances):
        if not instances:
            continue
        compute_instance_durations(
            instances=instances,
            canonical_phonemes=updated.canonical_phonemes,
            decoded_phonemes=updated.detected_phonemes,
            fa_segments=fa_segments,
            canonical_to_decoded=canonical_to_decoded,
        )

    return updated