Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,333 Bytes
e877de6 fcc17af 3e13e03 53f4a87 e877de6 fcc17af 3e13e03 53f4a87 0417385 3e13e03 53f4a87 0417385 e877de6 8ea5e1e c6974e2 8ea5e1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
"""
Shared state module for Gradio hot-reload compatibility.
This module holds all global state (models, processor, chapters) in a singleton
pattern that survives Gradio's module reloading. By keeping state here instead
of in app.py, we avoid re-initialization on every code change.
"""
from dataclasses import dataclass, field
from typing import Any, Optional
@dataclass
class AppState:
    """Mutable bag holding every piece of state the app shares between modules.

    Fields fall into three groups: loaded resources (processor, model,
    chapters, model_bundles), the ``initialized`` flag, and caches of the most
    recent analysis results (used e.g. by the ghunnah/madd tabs and by the
    error-segment view when it re-renders with a different sort mode).
    """

    # Loaded resources and the init flag.
    processor: Optional[Any] = field(default=None)
    model: Optional[Any] = field(default=None)
    chapters: Optional[list] = field(default=None)
    model_bundles: list = field(default_factory=list)
    initialized: bool = field(default=False)

    # Single (non-segmented) RecitationResult, with duration data, and its verse ref.
    last_recitation_result: Optional[Any] = field(default=None)
    last_recitation_verse_ref: Optional[str] = field(default=None)

    # Segment mode: list of RecitationResult (with durations) plus per-segment refs.
    last_segment_results: Optional[list] = field(default=None)
    last_segment_refs: Optional[list] = field(default=None)

    # Error-analysis segment data, kept so the view can be re-sorted without recompute.
    last_error_segment_data: Optional[list] = field(default=None)

    # Multi-model caches keyed by model index:
    # {idx: RecitationResult} (non-segmented) or {idx: List[RecitationResult]} (segmented).
    last_multi_model_results: Optional[dict] = field(default=None)
    last_multi_model_segment_results: Optional[dict] = field(default=None)
    last_multi_model_segment_refs: Optional[dict] = field(default=None)
    last_multi_model_verse_ref: Optional[str] = field(default=None)


# Module-level singleton. NOTE(review): it only outlives Gradio hot reloads as
# long as the reloader does not re-import this module itself.
_state = AppState()
def get_state() -> AppState:
    """Return the AppState instance currently held by this module."""
    current = _state
    return current


def is_initialized() -> bool:
    """Return the ``initialized`` flag stored on the shared state."""
    current = _state
    return current.initialized


def set_initialized(value: bool = True) -> None:
    """Write ``value`` to the shared ``initialized`` flag (marks ready by default)."""
    current = _state
    current.initialized = value
def get_processor():
    """Return the processor stored on the shared state (None if never set)."""
    processor = _state.processor
    return processor


def set_processor(processor):
    """Store ``processor`` on the shared state, replacing any previous value."""
    target = _state
    target.processor = processor
def get_model():
    """Return the model stored on the shared state (None if never set)."""
    model = _state.model
    return model


def set_model(model):
    """Store ``model`` on the shared state, replacing any previous value."""
    target = _state
    target.model = model
def get_chapters():
    """Return the chapters list stored on the shared state (None if never set)."""
    chapters = _state.chapters
    return chapters


def set_chapters(chapters):
    """Store ``chapters`` on the shared state, replacing any previous list."""
    target = _state
    target.chapters = chapters
def get_model_bundles():
    """Return the model-bundles list stored on the shared state."""
    bundles = _state.model_bundles
    return bundles


def set_model_bundles(bundles):
    """Store ``bundles`` on the shared state, replacing the previous list."""
    target = _state
    target.model_bundles = bundles
def get_last_recitation_result():
    """Get the cached RecitationResult from last audio processing.

    Returns:
        A ``(result, verse_ref)`` tuple. Both elements are None when nothing
        has been cached yet or the cache has been cleared.
    """
    return _state.last_recitation_result, _state.last_recitation_verse_ref


def set_last_recitation_result(result, verse_ref: Optional[str]):
    """Cache a RecitationResult with duration data for ghunnah/madd tabs.

    Args:
        result: RecitationResult to cache, or None to clear the cache.
        verse_ref: Reference the result belongs to, or None when clearing.
            Typed ``Optional[str]`` because callers clear the cache by
            passing ``(None, None)``.
    """
    _state.last_recitation_result = result
    _state.last_recitation_verse_ref = verse_ref
def get_last_segment_results():
    """Return the ``(results, segment_refs)`` pair cached by the last segmented run."""
    state = _state
    return state.last_segment_results, state.last_segment_refs


def set_last_segment_results(results: list, segment_refs: list):
    """Remember per-segment RecitationResults (with durations) and their refs for ghunnah/madd tabs."""
    state = _state
    state.last_segment_results, state.last_segment_refs = results, segment_refs


def clear_segment_results():
    """Forget cached segment results (invoked when leaving segment mode)."""
    set_last_segment_results(None, None)
def get_last_error_segment_data():
    """Return the cached error-analysis segment data, or None if none is cached."""
    state = _state
    return state.last_error_segment_data


def set_last_error_segment_data(data: list):
    """Keep error-analysis segment data so it can be re-rendered in another sort mode."""
    state = _state
    state.last_error_segment_data = data


def clear_error_segment_data():
    """Drop the cached error-analysis segment data."""
    set_last_error_segment_data(None)
def get_last_multi_model_results():
    """Return ``(results, verse_ref)`` from the non-segmented multi-model cache.

    ``verse_ref`` is shared with the segmented cache, so it may be set even
    when ``results`` is None.
    """
    state = _state
    return state.last_multi_model_results, state.last_multi_model_verse_ref


def set_last_multi_model_results(results: dict, verse_ref: str):
    """Store non-segmented multi-model RecitationResults for the ghunnah/madd tabs."""
    state = _state
    # The segmented and non-segmented multi-model caches are mutually
    # exclusive: discard the segmented one, then store the new results.
    state.last_multi_model_segment_results = None
    state.last_multi_model_segment_refs = None
    state.last_multi_model_results = results
    state.last_multi_model_verse_ref = verse_ref
def get_last_multi_model_segment_results():
    """Return ``(segment_results, segment_refs, verse_ref)`` from the segmented multi-model cache."""
    state = _state
    segment_results = state.last_multi_model_segment_results
    segment_refs = state.last_multi_model_segment_refs
    verse_ref = state.last_multi_model_verse_ref
    return segment_results, segment_refs, verse_ref


def set_last_multi_model_segment_results(results: dict, refs: dict, verse_ref: str):
    """Store segmented multi-model results for the ghunnah/madd tabs."""
    state = _state
    # Mutually exclusive with the non-segmented cache: discard that one first.
    state.last_multi_model_results = None
    (
        state.last_multi_model_segment_results,
        state.last_multi_model_segment_refs,
        state.last_multi_model_verse_ref,
    ) = (results, refs, verse_ref)
def clear_multi_model_results():
    """Reset every multi-model cache field (segmented and non-segmented) to None."""
    for attr_name in (
        "last_multi_model_results",
        "last_multi_model_segment_results",
        "last_multi_model_segment_refs",
        "last_multi_model_verse_ref",
    ):
        setattr(_state, attr_name, None)
def clear_all_audio_caches():
    """Drop every cached audio-analysis artefact.

    Clears, in order: segment results, error-segment data, the multi-model
    caches, and the single recitation result held in the shared state. It
    then clears the reference-audio caches in
    ``recitation_engine.reference_audio``.
    """
    for clear_cache in (
        clear_segment_results,
        clear_error_segment_data,
        clear_multi_model_results,
    ):
        clear_cache()
    set_last_recitation_result(None, None)
    # Deferred import, run after the in-memory clears.
    # NOTE(review): presumably kept local to avoid an import cycle — confirm.
    from recitation_engine.reference_audio import clear_audio_caches
    clear_audio_caches()
def reset_state():
    """Reset all state to fresh AppState defaults (for testing or forced re-init).

    The reset is done in place on the existing singleton instead of rebinding
    the module global. Any object previously returned by ``get_state()``
    therefore keeps pointing at the live state rather than at an orphaned,
    stale instance.
    """
    from dataclasses import fields

    fresh = AppState()
    for spec in fields(AppState):
        # Copy from a freshly built instance so default_factory fields
        # (e.g. model_bundles) get a new object, not a shared one.
        setattr(_state, spec.name, getattr(fresh, spec.name))
def recompute_errors_for_result(result):
    """
    Re-run error detection on a cached RecitationResult.

    Used when settings change (e.g., iqlab/ikhfaa sound) to update
    error detection without re-running transcription.

    Re-phonemizes to get fresh canonical phonemes that reflect current settings,
    then re-runs error detection against the original detected phonemes.

    Args:
        result: RecitationResult with duration data, or None.

    Returns:
        - None if ``result`` is None.
        - ``result`` itself, unchanged, if no phonemizer result could be
          obtained for its reference.
        - Otherwise a new RecitationResult with updated errors and instances.
          The original fa_segments, cvc_avg_duration_ms, cvc_instance_count,
          fa_visualization_data and user_audio_clip are carried over, and
          ghunnah/madd instance durations are recomputed from them when
          possible.
    """
    if result is None:
        return None

    # Deferred imports. NOTE(review): presumably local to keep this state
    # module light and free of import cycles — confirm.
    from recitation_analysis.error_analysis import ErrorPipeline
    from recitation_analysis.result import RecitationResult
    from recitation_analysis.result_builder import get_result_builder
    from utils.phonemizer_utils import get_cached_phonemizer_result

    # Re-phonemize to get fresh canonical phonemes with current settings.
    # This is critical: when the ikhfaa/iqlab setting changes, the canonical
    # phonemes change. Prefer segment_ref (word-range ref for this segment).
    ref_to_use = result.segment_ref or result.verse_ref
    phonemizer_result = get_cached_phonemizer_result(ref_to_use, ["compulsory_stop"])
    if phonemizer_result is None:
        # Can't re-phonemize; return the original result unchanged.
        return result

    # Build a fresh base result with the NEW canonical phonemes.
    builder = get_result_builder()
    fresh_result = builder.build_from_phonemizer_result(
        phonemizer_result,
        verse_ref=result.verse_ref,
        segment_ref=result.segment_ref,
    )

    # Re-run error detection: FRESH canonical phonemes vs OLD detected phonemes.
    pipeline = ErrorPipeline()
    new_result = pipeline.detect_errors(fresh_result, list(result.detected_phonemes))

    # detect_errors() does not set the duration fields; carry them over.
    updated_result = RecitationResult(
        verse_ref=new_result.verse_ref,
        segment_ref=new_result.segment_ref,
        canonical_phonemes=new_result.canonical_phonemes,
        detected_phonemes=new_result.detected_phonemes,
        canonical_words=new_result.canonical_words,
        detected_words=new_result.detected_words,
        alignment=new_result.alignment,
        madd_mappings=new_result.madd_mappings,
        phoneme_alignment=new_result.phoneme_alignment,
        errors=new_result.errors,
        ghunnah_instances=new_result.ghunnah_instances,
        madd_instances=new_result.madd_instances,
        # Preserve duration data from the original result.
        fa_segments=result.fa_segments,
        cvc_avg_duration_ms=result.cvc_avg_duration_ms,
        cvc_instance_count=result.cvc_instance_count,
        fa_visualization_data=result.fa_visualization_data,
        user_audio_clip=result.user_audio_clip,
    )

    # detect_errors() creates new instances without duration_ms, so recompute
    # them from the preserved forced-alignment segments.
    if result.fa_segments and updated_result.phoneme_alignment:
        _recompute_instance_durations(updated_result, result.fa_segments)

    return updated_result


def _recompute_instance_durations(updated_result, fa_segments):
    """Recompute durations of ghunnah then madd instances on ``updated_result`` from ``fa_segments``."""
    from recitation_analysis.duration_analysis.duration_calculator import (
        build_canonical_to_decoded_map,
        compute_instance_durations,
    )

    # Rebuild the mapping from the new alignment (it reflects the new
    # canonical phonemes after the setting change).
    canonical_to_decoded = build_canonical_to_decoded_map(updated_result.phoneme_alignment)

    # Same arguments for both instance kinds; ghunnah first, then madd.
    for instances in (updated_result.ghunnah_instances, updated_result.madd_instances):
        if not instances:
            continue
        compute_instance_durations(
            instances=instances,
            canonical_phonemes=updated_result.canonical_phonemes,
            decoded_phonemes=updated_result.detected_phonemes,
            fa_segments=fa_segments,
            canonical_to_decoded=canonical_to_decoded,
        )
|