from __future__ import annotations # Phase-3C: Targeted Modality Retry Controller from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple from src.coherence.drift_detector import detect_drift from src.coherence.msci import compute_msci_v0 from src.coherence.reporting import build_final_assessment from src.coherence.scorer import CoherenceScorer from src.coherence.controller import route_retry from src.coherence.retry.retry_si_a import retry_si_a from src.coherence.retry.retry_st_i import retry_st_i from src.embeddings.aligned_embeddings import AlignedEmbedder from src.generators.audio.generator import AudioGenerator from src.generators.image.generator import ImageRetrievalGenerator from src.generators.text.generator import TextGenerator from src.narrative.generator import NarrativeGenerator from src.orchestrator.regeneration_policy import decide_regeneration from src.orchestrator.run_manager import create_run_paths from src.planner.council import SemanticPlanningCouncil from src.planner.schema import SemanticPlan from src.planner.schema_to_text import plan_to_canonical_text from src.storage.metadata import write_run_metadata @dataclass(frozen=True) class RunOutput: run_id: str semantic_plan: Dict[str, Any] merge_report: Dict[str, Any] planner_outputs: Dict[str, Any] narrative_structured: Dict[str, Any] narrative_text: str image_path: str audio_path: str scores: Dict[str, Any] coherence: Dict[str, Any] final_assessment: Dict[str, Any] drift: Dict[str, bool] attempts: int decisions: List[Dict[str, Any]] class Orchestrator: def __init__( self, council: SemanticPlanningCouncil, text_gen: TextGenerator, image_gen: ImageRetrievalGenerator, audio_gen: AudioGenerator, msci_threshold: float = 0.42, max_attempts: int = 4, runs_dir: str = "runs", ): self.council = council self.text_gen = text_gen self.image_gen = image_gen self.audio_gen = audio_gen self.msci_threshold = msci_threshold self.max_attempts = max_attempts self.runs_dir = runs_dir self.embedder = AlignedEmbedder(target_dim=512) self.narrative_generator = NarrativeGenerator() self.coherence_scorer = CoherenceScorer() def run(self, user_prompt: str) -> RunOutput: paths = create_run_paths(self.runs_dir) council_result = self.council.run(user_prompt) if isinstance(council_result, SemanticPlan): plan = council_result merge_report = { "agreement_score": 1.0, "per_section_agreement": {}, "conflicts": {}, "notes": "unified_planner", } planner_outputs = {"unified": plan.model_dump()} else: plan = council_result.merged_plan merge_report = { "agreement_score": council_result.merge_report.agreement_score, "per_section_agreement": council_result.merge_report.per_section_agreement, "conflicts": council_result.merge_report.conflicts, "notes": council_result.merge_report.notes, } planner_outputs = { "plan_a": council_result.plan_a.model_dump(), "plan_b": council_result.plan_b.model_dump(), "plan_c": council_result.plan_c.model_dump(), } plan_text = plan_to_canonical_text(plan) plan_embedding = self.embedder.embed_text(plan_text) img_pool = self.image_gen.retrieve_top_k(plan_text, k=8) if not img_pool: index_path = getattr(self.image_gen, "index_path", None) hint = f" Expected index at {index_path}." if index_path else "" raise RuntimeError( "No image candidates retrieved. Build the image index or switch to a" f" generative image backend.{hint}" ) best_state: Optional[ Tuple[float, str, str, str, Dict[str, Any], Dict[str, bool], int] ] = None decisions: List[Dict[str, Any]] = [] retry_outcomes: List[Dict[str, Any]] = [] narrative_structured = self.narrative_generator.generate(plan.model_dump()) narrative = narrative_structured.combined_scene image_path = img_pool[0][0] audio_path = str(paths.audio_dir / "audio_attempt1.wav") audio_prompt = ( f"{plan.scene_summary}. Soundscape: {', '.join(plan.audio_elements)}. " f"Mood: {', '.join(plan.mood_emotion)}." ) retry_analysis: List[Dict[str, Any]] = [] epsilon = 0.01 for attempt in range(1, self.max_attempts + 1): if attempt == 1: audio_result = self.audio_gen.generate(audio_prompt, audio_path) audio_path = audio_result.audio_path audio_backend = audio_result.backend else: last_scores = decisions[-1]["scores"] last_coherence = decisions[-1].get("coherence", {}) classification = last_coherence.get("classification", {}) context = { "semantic_plan": plan.model_dump(), "narrative_structured": narrative_structured.model_dump(), "plan_text": plan_text, "image_path": image_path, "audio_path": audio_path, "image_generator": self.image_gen, "audio_generator": self.audio_gen, } retry_action = None retry_strategy = None retry_metric = None retry_trigger = classification.get("label") handled_regen = False if ( classification.get("label") == "MODALITY_FAILURE" and classification.get("weakest_metric") == "st_i" ): context = retry_st_i(context) image_path = context.get("image") or context.get("image_path") or image_path retry_strategy = "ALIGN_IMAGE_TO_TEXT" retry_metric = "st_i" retry_action = { "regenerate": "image", "failed_metric": "st_i", "strategy": retry_strategy, } handled_regen = True elif ( classification.get("label") == "MODALITY_FAILURE" and classification.get("weakest_metric") == "si_a" ): audio_retry_path = str(paths.audio_dir / f"audio_attempt{attempt}.wav") context["audio_path"] = audio_retry_path context = retry_si_a(context) audio_path = context.get("audio") or context.get("audio_path") or audio_path audio_backend = context.get("audio_backend") retry_meta = context.get("retry", {}) retry_strategy = retry_meta.get("strategy", "ALIGN_AUDIO_TO_IMAGE") retry_metric = "si_a" retry_action = { "regenerate": "audio", "failed_metric": "si_a", "strategy": retry_strategy, } handled_regen = True else: retry_action = route_retry(classification, context) if retry_action and retry_action.get("regenerate") == "full": retry_strategy = retry_action.get("strategy") retry_metric = retry_action.get("failed_metric") handled_regen = True council_result = self.council.run(user_prompt) if isinstance(council_result, SemanticPlan): plan = council_result merge_report = { "agreement_score": 1.0, "per_section_agreement": {}, "conflicts": {}, "notes": "unified_planner", } planner_outputs = {"unified": plan.model_dump()} else: plan = council_result.merged_plan merge_report = { "agreement_score": council_result.merge_report.agreement_score, "per_section_agreement": council_result.merge_report.per_section_agreement, "conflicts": council_result.merge_report.conflicts, "notes": council_result.merge_report.notes, } planner_outputs = { "plan_a": council_result.plan_a.model_dump(), "plan_b": council_result.plan_b.model_dump(), "plan_c": council_result.plan_c.model_dump(), } plan_text = plan_to_canonical_text(plan) plan_embedding = self.embedder.embed_text(plan_text) narrative_structured = self.narrative_generator.generate(plan.model_dump()) narrative = narrative_structured.combined_scene img_pool = self.image_gen.retrieve_top_k(plan_text, k=8) if not img_pool: index_path = getattr(self.image_gen, "index_path", None) hint = f" Expected index at {index_path}." if index_path else "" raise RuntimeError( "No image candidates retrieved. Build the image index or switch to a" f" generative image backend.{hint}" ) image_path = img_pool[0][0] audio_prompt = ( f"{plan.scene_summary}. Soundscape: {', '.join(plan.audio_elements)}. " f"Mood: {', '.join(plan.mood_emotion)}." ) audio_path = str(paths.audio_dir / f"audio_attempt{attempt}.wav") audio_result = self.audio_gen.generate(audio_prompt, audio_path) audio_path = audio_result.audio_path audio_backend = audio_result.backend target = "full" elif retry_action and retry_action.get("regenerate") in {"audio", "image"}: target = retry_action["regenerate"] retry_strategy = retry_action.get("strategy") retry_metric = retry_action.get("failed_metric") if target == "audio" and retry_action.get("audio_prompt"): audio_prompt = retry_action["audio_prompt"] if target == "image" and retry_action.get("image_prompt"): img_pool = self.image_gen.retrieve_top_k( retry_action["image_prompt"], k=8, ) else: target = decide_regeneration( last_scores["msci"], last_scores["st_i"], last_scores["st_a"], self.msci_threshold, ) if not handled_regen and target == "image": idx = min(attempt - 1, max(len(img_pool) - 1, 0)) image_path = img_pool[idx][0] if img_pool else image_path elif not handled_regen and target == "audio": audio_path = str(paths.audio_dir / f"audio_attempt{attempt}.wav") audio_prompt_variant = audio_prompt + f" Intensity level: {attempt}." audio_result = self.audio_gen.generate(audio_prompt_variant, audio_path) audio_backend = audio_result.backend elif not handled_regen and target == "text": narrative = self.text_gen.generate( f"{plan_text}\n\nRewrite concisely, keep the same meaning:\n" ).text else: target = "none" if not image_path: raise RuntimeError("Image path is empty; retrieval produced no candidates.") image_emb = self.embedder.embed_image(image_path) audio_emb = self.embedder.embed_audio(audio_path) msci = compute_msci_v0( plan_embedding, image_emb, audio_emb, include_image_audio=True, ) drift = detect_drift(msci.msci, msci.st_i, msci.st_a, msci.si_a) scores = { "msci": msci.msci, "st_i": msci.st_i, "st_a": msci.st_a, "si_a": msci.si_a, "agreement_score": merge_report["agreement_score"], "per_section_agreement": merge_report["per_section_agreement"], } metric_scores = {k: scores[k] for k in ("msci", "st_i", "st_a", "si_a")} coherence_step = self.coherence_scorer.score( scores=metric_scores, global_drift=drift["global_drift"], ) coherence_step["needs_repair"] = ( coherence_step["classification"]["label"] == "MODALITY_FAILURE" and coherence_step["classification"]["weakest_metric"] == "st_i" ) repair_attempts = 0 while coherence_step["needs_repair"] and repair_attempts < 2: narrative_structured = self.narrative_generator.repair_visual_description( plan.model_dump(), image_path=image_path, ) narrative = narrative_structured.combined_scene plan_embedding = self.embedder.embed_text( narrative_structured.visual_description ) msci = compute_msci_v0( plan_embedding, image_emb, audio_emb, include_image_audio=True, ) drift = detect_drift(msci.msci, msci.st_i, msci.st_a, msci.si_a) scores = { "msci": msci.msci, "st_i": msci.st_i, "st_a": msci.st_a, "si_a": msci.si_a, "agreement_score": merge_report["agreement_score"], "per_section_agreement": merge_report["per_section_agreement"], } metric_scores = {k: scores[k] for k in ("msci", "st_i", "st_a", "si_a")} coherence_step = self.coherence_scorer.score( scores=metric_scores, global_drift=drift["global_drift"], ) coherence_step["needs_repair"] = ( coherence_step["classification"]["label"] == "MODALITY_FAILURE" and coherence_step["classification"]["weakest_metric"] == "st_i" ) repair_attempts += 1 if coherence_step["classification"]["label"] in { "HIGH_COHERENCE", "LOCAL_MODALITY_WEAKNESS", }: break step_decision = { "attempt": attempt, "image_path": image_path, "audio_path": audio_path, "audio_backend": audio_backend if "audio_backend" in locals() else None, "scores": scores, "coherence": coherence_step, "drift": drift, "retry_strategy": retry_strategy if attempt > 1 else None, "retry_metric": retry_metric if attempt > 1 else None, } decisions.append(step_decision) if attempt > 1 and retry_metric: prev_scores = decisions[-2].get("scores", {}) before = prev_scores.get(retry_metric) after = scores.get(retry_metric) if before is not None and after is not None: before_status = self.coherence_scorer.thresholds.classify_value( retry_metric, before, ) after_status = self.coherence_scorer.thresholds.classify_value( retry_metric, after, ) success = (before_status == "FAIL" and after_status in {"WEAK", "GOOD"}) or ( after > before + epsilon ) retry_outcomes.append( { "strategy": retry_strategy, "trigger": retry_trigger, "weakest_metric": retry_metric, "before": { "msci": prev_scores.get("msci"), "st_i": prev_scores.get("st_i"), "st_a": prev_scores.get("st_a"), "si_a": prev_scores.get("si_a"), }, "after": { "msci": scores.get("msci"), "st_i": scores.get("st_i"), "st_a": scores.get("st_a"), "si_a": scores.get("si_a"), }, "epsilon": epsilon, "success": success, } ) if best_state is None or scores["msci"] > best_state[0]: best_state = ( scores["msci"], narrative, image_path, audio_path, scores, drift, attempt, ) if scores["msci"] >= self.msci_threshold and not drift["global_drift"]: break assert best_state is not None _, best_text, best_img, best_aud, best_scores, best_drift, best_attempt = best_state metric_scores = {k: best_scores[k] for k in ("msci", "st_i", "st_a", "si_a") if k in best_scores} coherence = self.coherence_scorer.score( scores=metric_scores, global_drift=best_drift["global_drift"], ) final_assessment = build_final_assessment(coherence, retry_outcomes) out = RunOutput( run_id=paths.run_id, semantic_plan=plan.model_dump(), merge_report=merge_report, planner_outputs=planner_outputs, narrative_structured=narrative_structured.model_dump(), narrative_text=best_text, image_path=best_img, audio_path=best_aud, scores=best_scores, coherence=coherence, final_assessment=final_assessment, drift=best_drift, attempts=best_attempt, decisions=decisions, ) write_run_metadata( paths.logs_dir / "run.json", { "run_id": out.run_id, "user_prompt": user_prompt, "semantic_plan": out.semantic_plan, "merge_report": out.merge_report, "planner_outputs": out.planner_outputs, "narrative_structured": out.narrative_structured, "final": { "narrative_text": out.narrative_text, "image_path": out.image_path, "audio_path": out.audio_path, "scores": out.scores, "coherence": out.coherence, "final_assessment": out.final_assessment, "drift": out.drift, "attempts": out.attempts, }, "attempt_history": out.decisions, }, ) if retry_outcomes: write_run_metadata( paths.logs_dir / "retry_outcome.json", {"retries": retry_outcomes}, ) return out