# Tajweed-AI / shared_state.py
# Last commit: 53f4a87 "Add multi-model" (hetchyy)
"""
Shared state module for Gradio hot-reload compatibility.
This module holds all global state (models, processor, chapters) in a singleton
pattern that survives Gradio's module reloading. By keeping state here instead
of in app.py, we avoid re-initialization on every code change.
"""
from dataclasses import dataclass, field
from typing import Any, Optional
@dataclass
class AppState:
    """Container for all application state shared through this module.

    Holds two kinds of data:
    - loaded resources (processor, model, chapters, model_bundles) plus the
      ``initialized`` flag;
    - result caches that the module-level ``set_*`` / ``clear_*`` helpers
      fill and reset.

    Field declaration order defines the generated ``__init__`` signature, so
    do not reorder fields.
    """
    # Loaded resources. Typed as Any here; presumably the audio processor and
    # recognition model objects — TODO confirm against the loader in app.py.
    processor: Optional[Any] = None
    model: Optional[Any] = None
    chapters: Optional[list] = None
    # default_factory gives each AppState instance its own list (a bare `[]`
    # default is not allowed for dataclass fields).
    model_bundles: list = field(default_factory=list)
    # Written by set_initialized(), read by is_initialized().
    initialized: bool = False
    # Cache for RecitationResult with duration data from audio processing
    # (non-segmented mode), plus the verse reference it was computed for.
    last_recitation_result: Optional[Any] = None
    last_recitation_verse_ref: Optional[str] = None
    # Cache for segment mode: List[RecitationResult] with duration data, and
    # the per-segment refs (presumably index-aligned with the results).
    last_segment_results: Optional[list] = None
    last_segment_refs: Optional[list] = None
    # Cache for error analysis segment HTML (for re-rendering with sort modes)
    last_error_segment_data: Optional[list] = None
    # Cache for multi-model mode: {idx: RecitationResult} or {idx: List[RecitationResult]}.
    # The non-segmented and segmented dicts are kept mutually exclusive by
    # their setters; both share last_multi_model_verse_ref.
    last_multi_model_results: Optional[dict] = None # Non-segmented
    last_multi_model_segment_results: Optional[dict] = None # Segmented
    last_multi_model_segment_refs: Optional[dict] = None
    last_multi_model_verse_ref: Optional[str] = None
# Module-level singleton returned by get_state() and mutated by the helpers
# below. It lives as long as this module object does: re-executing this module
# (e.g. importlib.reload(shared_state)) would create a fresh AppState. The
# design assumes Gradio's hot reload re-runs the app module but not this one.
# NOTE(review): confirm that assumption against the Gradio reload setup.
_state: AppState = AppState()
def get_state() -> AppState:
    """Return the live module-level ``AppState`` singleton (not a copy)."""
    shared = _state
    return shared
def is_initialized() -> bool:
    """Return the ``initialized`` flag stored on the shared state."""
    flag = _state.initialized
    return flag
def set_initialized(value: bool = True) -> None:
    """Store ``value`` as the initialized flag; with no argument, mark the app ready."""
    shared = _state
    shared.initialized = value
def get_processor():
    """Return the stored processor object (None until one has been set)."""
    shared = _state
    return shared.processor
def set_processor(processor):
    """Store ``processor`` on the shared state, replacing any previous value."""
    shared = _state
    shared.processor = processor
def get_model():
    """Return the stored model object (None until one has been set)."""
    shared = _state
    return shared.model
def set_model(model):
    """Store ``model`` on the shared state, replacing any previous value."""
    shared = _state
    shared.model = model
def get_chapters():
    """Return the stored chapters list (None until one has been set)."""
    shared = _state
    return shared.chapters
def set_chapters(chapters):
    """Store ``chapters`` on the shared state, replacing any previous list."""
    shared = _state
    shared.chapters = chapters
def get_model_bundles():
    """Return the stored model-bundle list.

    This is the live list object, so in-place mutations by a caller are
    visible to every other caller.
    """
    shared = _state
    return shared.model_bundles
def set_model_bundles(bundles):
    """Replace the stored model-bundle list with ``bundles`` (stored as-is, not copied)."""
    shared = _state
    shared.model_bundles = bundles
def get_last_recitation_result():
    """Return ``(result, verse_ref)`` from the most recent non-segmented audio run.

    Both elements are None when nothing has been cached.
    """
    shared = _state
    cached_result = shared.last_recitation_result
    cached_ref = shared.last_recitation_verse_ref
    return cached_result, cached_ref
def set_last_recitation_result(result, verse_ref: Optional[str]):
    """Cache a RecitationResult with duration data for the ghunnah/madd tabs.

    Args:
        result: RecitationResult to cache, or None to clear the slot.
        verse_ref: Verse reference the result belongs to, or None to clear it.
            Typed ``Optional[str]`` because ``clear_all_audio_caches`` passes
            ``None`` here to reset the cache.
    """
    _state.last_recitation_result = result
    _state.last_recitation_verse_ref = verse_ref
def get_last_segment_results():
    """Return ``(segment_results, segment_refs)`` from the last segmented audio run.

    Both elements are None when nothing has been cached.
    """
    shared = _state
    cached_results = shared.last_segment_results
    cached_refs = shared.last_segment_refs
    return cached_results, cached_refs
def set_last_segment_results(results: list, segment_refs: list):
    """Cache segment-mode RecitationResults (with duration data) and their refs for the ghunnah/madd tabs."""
    shared = _state
    shared.last_segment_results, shared.last_segment_refs = results, segment_refs
def clear_segment_results():
    """Forget the segment-mode cache (used when switching back to non-segment mode)."""
    shared = _state
    shared.last_segment_results = shared.last_segment_refs = None
def get_last_error_segment_data():
    """Return the cached error-analysis segment data, or None if nothing is cached."""
    shared = _state
    return shared.last_error_segment_data
def set_last_error_segment_data(data: list):
    """Store error-analysis segment data so it can be re-rendered under a different sort mode."""
    shared = _state
    shared.last_error_segment_data = data
def clear_error_segment_data():
    """Reset the error-analysis segment data cache to None."""
    shared = _state
    shared.last_error_segment_data = None
def get_last_multi_model_results():
    """Return ``(results_by_model, verse_ref)`` cached for non-segmented multi-model mode."""
    shared = _state
    by_model = shared.last_multi_model_results
    ref = shared.last_multi_model_verse_ref
    return by_model, ref
def set_last_multi_model_results(results: dict, verse_ref: str):
    """Cache per-model RecitationResults for the ghunnah/madd tabs (non-segmented mode).

    The segmented and non-segmented multi-model caches are mutually exclusive,
    so storing non-segmented results drops any segmented results and refs.
    """
    shared = _state
    shared.last_multi_model_results, shared.last_multi_model_verse_ref = results, verse_ref
    shared.last_multi_model_segment_results = shared.last_multi_model_segment_refs = None
def get_last_multi_model_segment_results():
    """Return ``(results_by_model, refs_by_model, verse_ref)`` cached for segmented multi-model mode."""
    shared = _state
    by_model = shared.last_multi_model_segment_results
    refs = shared.last_multi_model_segment_refs
    ref = shared.last_multi_model_verse_ref
    return by_model, refs, ref
def set_last_multi_model_segment_results(results: dict, refs: dict, verse_ref: str):
    """Cache per-model segment results for the ghunnah/madd tabs (segmented mode).

    Storing segmented results drops the non-segmented multi-model cache, because
    the two modes are kept mutually exclusive.
    """
    shared = _state
    shared.last_multi_model_segment_results = results
    shared.last_multi_model_segment_refs = refs
    shared.last_multi_model_verse_ref = verse_ref
    shared.last_multi_model_results = None
def clear_multi_model_results():
    """Reset every multi-model cache: segmented, non-segmented, and the shared verse ref."""
    shared = _state
    shared.last_multi_model_results = shared.last_multi_model_segment_results = None
    shared.last_multi_model_segment_refs = shared.last_multi_model_verse_ref = None
def clear_all_audio_caches():
    """Drop every audio-derived cache in the shared state, then the reference-audio caches.

    In-memory slots reset here:
    - the non-segmented recitation result and its verse ref;
    - all multi-model caches;
    - error-analysis segment data;
    - segment-mode results and refs.

    It then calls ``recitation_engine.reference_audio.clear_audio_caches``.
    """
    set_last_recitation_result(None, None)
    clear_multi_model_results()
    clear_error_segment_data()
    clear_segment_results()
    # Imported at call time rather than at module import (presumably to avoid
    # an import cycle or a heavy import at startup).
    from recitation_engine.reference_audio import clear_audio_caches as clear_reference_audio
    clear_reference_audio()
def reset_state():
    """Reset all state to the ``AppState`` defaults (for testing or forced re-init).

    The existing singleton is reset in place rather than replaced. Rebinding
    the module global would leave any reference previously returned by
    ``get_state()`` pointing at a stale, detached object that no longer
    reflects what this module reads and writes.
    """
    from dataclasses import fields

    # Use a throwaway instance so default_factory fields (e.g. model_bundles)
    # receive fresh objects, exactly as a brand-new AppState would.
    defaults = AppState()
    for spec in fields(AppState):
        setattr(_state, spec.name, getattr(defaults, spec.name))
def recompute_errors_for_result(result):
    """
    Re-run error detection on a cached RecitationResult after a settings change.

    Settings such as the iqlab/ikhfaa sound change the canonical phonemes. The
    reference is therefore re-phonemized first, and the fresh canonical
    sequence is compared against the ORIGINAL detected phonemes. No new
    transcription pass is needed.

    Args:
        result: RecitationResult (typically one carrying duration data), or None.

    Returns:
        - None if ``result`` is None.
        - ``result`` itself, unchanged, if no phonemizer result is available
          for its ref.
        - Otherwise a new RecitationResult with recomputed errors and
          ghunnah/madd instances. Duration fields are copied from ``result``,
          and instance durations are recomputed when ``result.fa_segments``
          and a phoneme alignment are both present.
    """
    if result is None:
        return None

    from recitation_analysis.error_analysis import ErrorPipeline
    from recitation_analysis.result import RecitationResult
    from recitation_analysis.result_builder import get_result_builder
    from utils.phonemizer_utils import get_cached_phonemizer_result

    # Prefer the segment's word-range ref; fall back to the whole verse.
    lookup_ref = result.segment_ref or result.verse_ref
    phonemized = get_cached_phonemizer_result(lookup_ref, ["compulsory_stop"])
    if phonemized is None:
        # Nothing to recompute against, so return the cached result untouched.
        return result

    # Base result carrying the NEW canonical phonemes for the current settings.
    base = get_result_builder().build_from_phonemizer_result(
        phonemized,
        verse_ref=result.verse_ref,
        segment_ref=result.segment_ref,
    )
    # Score the fresh canonical phonemes against the previously detected ones.
    rescored = ErrorPipeline().detect_errors(base, list(result.detected_phonemes))

    # Only the fields listed here are carried over. NOTE(review): if
    # RecitationResult gains more fields, confirm whether they also need copying.
    merged = RecitationResult(
        verse_ref=rescored.verse_ref,
        segment_ref=rescored.segment_ref,
        canonical_phonemes=rescored.canonical_phonemes,
        detected_phonemes=rescored.detected_phonemes,
        canonical_words=rescored.canonical_words,
        detected_words=rescored.detected_words,
        alignment=rescored.alignment,
        madd_mappings=rescored.madd_mappings,
        phoneme_alignment=rescored.phoneme_alignment,
        errors=rescored.errors,
        ghunnah_instances=rescored.ghunnah_instances,
        madd_instances=rescored.madd_instances,
        # Duration data is not produced by detect_errors; keep the original's.
        fa_segments=result.fa_segments,
        cvc_avg_duration_ms=result.cvc_avg_duration_ms,
        cvc_instance_count=result.cvc_instance_count,
        fa_visualization_data=result.fa_visualization_data,
        user_audio_clip=result.user_audio_clip,
    )

    if result.fa_segments and merged.phoneme_alignment:
        _recompute_instance_durations(merged, result.fa_segments)
    return merged


def _recompute_instance_durations(target, fa_segments):
    """Recompute durations of ``target``'s ghunnah and madd instances from preserved FA segments.

    Needed because detect_errors() creates new instances without duration data.
    """
    from recitation_analysis.duration_analysis.duration_calculator import (
        build_canonical_to_decoded_map,
        compute_instance_durations,
    )

    # The map must come from the NEW alignment, because the canonical phonemes
    # may have changed with the settings.
    canonical_to_decoded = build_canonical_to_decoded_map(target.phoneme_alignment)
    # Ghunnah first, then madd, matching the original processing order.
    for attr_name in ("ghunnah_instances", "madd_instances"):
        instances = getattr(target, attr_name)
        if not instances:
            continue
        compute_instance_durations(
            instances=instances,
            canonical_phonemes=target.canonical_phonemes,
            decoded_phonemes=target.detected_phonemes,
            fa_segments=fa_segments,
            canonical_to_decoded=canonical_to_decoded,
        )