| | from __future__ import annotations |
| |
|
| | |
| |
|
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | from src.coherence.drift_detector import detect_drift |
| | from src.coherence.msci import compute_msci_v0 |
| | from src.coherence.reporting import build_final_assessment |
| | from src.coherence.scorer import CoherenceScorer |
| | from src.coherence.controller import route_retry |
| | from src.coherence.retry.retry_si_a import retry_si_a |
| | from src.coherence.retry.retry_st_i import retry_st_i |
| | from src.embeddings.aligned_embeddings import AlignedEmbedder |
| | from src.generators.audio.generator import AudioGenerator |
| | from src.generators.image.generator import ImageRetrievalGenerator |
| | from src.generators.text.generator import TextGenerator |
| | from src.narrative.generator import NarrativeGenerator |
| | from src.orchestrator.regeneration_policy import decide_regeneration |
| | from src.orchestrator.run_manager import create_run_paths |
| | from src.planner.council import SemanticPlanningCouncil |
| | from src.planner.schema import SemanticPlan |
| | from src.planner.schema_to_text import plan_to_canonical_text |
| | from src.storage.metadata import write_run_metadata |
| |
|
| |
|
@dataclass(frozen=True)
class RunOutput:
    """Immutable result of one ``Orchestrator.run`` call.

    The artifact fields (``narrative_text``, ``image_path``, ``audio_path``,
    ``scores``, ``drift``, ``attempts``) describe the attempt with the highest
    MSCI score; ``decisions`` keeps the record of every logged attempt.
    """

    # Identifier of the run directory created for this invocation.
    run_id: str
    # Serialized semantic plan (``plan.model_dump()``).
    semantic_plan: Dict[str, Any]
    # Planner agreement summary: "agreement_score", "per_section_agreement",
    # "conflicts" and "notes".
    merge_report: Dict[str, Any]
    # Serialized planner outputs: either {"unified": ...} or
    # {"plan_a": ..., "plan_b": ..., "plan_c": ...}.
    planner_outputs: Dict[str, Any]
    # Serialized structured narrative (``narrative_structured.model_dump()``).
    narrative_structured: Dict[str, Any]
    # Narrative text of the best-scoring attempt.
    narrative_text: str
    # Image and audio paths of the best-scoring attempt.
    image_path: str
    audio_path: str
    # Score dictionary of the best attempt ("msci", "st_i", "st_a", "si_a",
    # plus the planner agreement values).
    scores: Dict[str, Any]
    # Coherence scorer output for the best attempt.
    coherence: Dict[str, Any]
    # Output of ``build_final_assessment`` for the best attempt.
    final_assessment: Dict[str, Any]
    # Drift flags of the best attempt; the orchestrator reads "global_drift".
    drift: Dict[str, bool]
    # 1-based number of the attempt that produced the best result.
    attempts: int
    # One record per logged attempt, in execution order.
    decisions: List[Dict[str, Any]]
| |
|
| |
|
class Orchestrator:
    """Drive one multimodal generation run with coherence-guided retries.

    A run plans the scene with the planning council, generates a narrative,
    picks an image from the retrieval pool, generates audio, scores the
    result with MSCI / drift / coherence classification, and retries up to
    ``max_attempts`` times while the result is below ``msci_threshold`` or
    shows global drift.
    """

    # Metric keys forwarded to the coherence scorer and logged per retry.
    _METRIC_KEYS = ("msci", "st_i", "st_a", "si_a")
    # Classification labels that end the narrative repair loop early.
    _ACCEPTABLE_LABELS = frozenset({"HIGH_COHERENCE", "LOCAL_MODALITY_WEAKNESS"})
    # Number of image candidates requested from the retrieval backend.
    _IMAGE_POOL_SIZE = 8
    # Maximum number of narrative repairs per attempt.
    _MAX_REPAIRS = 2

    def __init__(
        self,
        council: SemanticPlanningCouncil,
        text_gen: TextGenerator,
        image_gen: ImageRetrievalGenerator,
        audio_gen: AudioGenerator,
        msci_threshold: float = 0.42,
        max_attempts: int = 4,
        runs_dir: str = "runs",
    ):
        """Store the collaborators and build the embedder and scorers.

        Args:
            council: Planner whose ``run(prompt)`` returns either a
                ``SemanticPlan`` or a council result with ``merged_plan``,
                ``merge_report`` and ``plan_a``/``plan_b``/``plan_c``.
            text_gen: Text generator used for narrative rewrites.
            image_gen: Image retrieval backend (``retrieve_top_k``).
            audio_gen: Audio generator (``generate(prompt, path)``).
            msci_threshold: MSCI value at or above which an attempt without
                global drift is accepted.
            max_attempts: Upper bound on generation attempts per run.
            runs_dir: Directory under which run folders are created.
        """
        self.council = council
        self.text_gen = text_gen
        self.image_gen = image_gen
        self.audio_gen = audio_gen
        self.msci_threshold = msci_threshold
        self.max_attempts = max_attempts
        self.runs_dir = runs_dir

        self.embedder = AlignedEmbedder(target_dim=512)
        self.narrative_generator = NarrativeGenerator()
        self.coherence_scorer = CoherenceScorer()

    def _plan_with_council(
        self, user_prompt: str
    ) -> Tuple[SemanticPlan, Dict[str, Any], Dict[str, Any]]:
        """Run the council and return ``(plan, merge_report, planner_outputs)``."""
        council_result = self.council.run(user_prompt)
        if isinstance(council_result, SemanticPlan):
            # A unified planner returns the plan directly; report full agreement.
            plan = council_result
            merge_report = {
                "agreement_score": 1.0,
                "per_section_agreement": {},
                "conflicts": {},
                "notes": "unified_planner",
            }
            planner_outputs = {"unified": plan.model_dump()}
            return plan, merge_report, planner_outputs

        report = council_result.merge_report
        merge_report = {
            "agreement_score": report.agreement_score,
            "per_section_agreement": report.per_section_agreement,
            "conflicts": report.conflicts,
            "notes": report.notes,
        }
        planner_outputs = {
            "plan_a": council_result.plan_a.model_dump(),
            "plan_b": council_result.plan_b.model_dump(),
            "plan_c": council_result.plan_c.model_dump(),
        }
        return council_result.merged_plan, merge_report, planner_outputs

    def _retrieve_image_pool(self, query: str):
        """Return the image candidates for ``query``; raise if there are none.

        Raises:
            RuntimeError: If the retrieval backend returns no candidates.
        """
        img_pool = self.image_gen.retrieve_top_k(query, k=self._IMAGE_POOL_SIZE)
        if not img_pool:
            index_path = getattr(self.image_gen, "index_path", None)
            hint = f" Expected index at {index_path}." if index_path else ""
            raise RuntimeError(
                "No image candidates retrieved. Build the image index or switch to a"
                f" generative image backend.{hint}"
            )
        return img_pool

    @staticmethod
    def _build_audio_prompt(plan) -> str:
        """Compose the audio prompt from the plan summary, sounds and mood."""
        return (
            f"{plan.scene_summary}. Soundscape: {', '.join(plan.audio_elements)}. "
            f"Mood: {', '.join(plan.mood_emotion)}."
        )

    @staticmethod
    def _needs_visual_repair(classification: Dict[str, Any]) -> bool:
        """Return True when text-image alignment is the failing modality."""
        return (
            classification["label"] == "MODALITY_FAILURE"
            and classification["weakest_metric"] == "st_i"
        )

    def _score_attempt(
        self,
        plan_embedding,
        image_emb,
        audio_emb,
        merge_report: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], Dict[str, bool], Dict[str, Any]]:
        """Score one text/image/audio triple.

        Returns:
            ``(scores, drift, coherence_step)`` where ``coherence_step`` is the
            coherence scorer output extended with a ``"needs_repair"`` flag.
        """
        msci = compute_msci_v0(
            plan_embedding,
            image_emb,
            audio_emb,
            include_image_audio=True,
        )
        drift = detect_drift(msci.msci, msci.st_i, msci.st_a, msci.si_a)
        scores = {
            "msci": msci.msci,
            "st_i": msci.st_i,
            "st_a": msci.st_a,
            "si_a": msci.si_a,
            "agreement_score": merge_report["agreement_score"],
            "per_section_agreement": merge_report["per_section_agreement"],
        }
        metric_scores = {key: scores[key] for key in self._METRIC_KEYS}
        coherence_step = self.coherence_scorer.score(
            scores=metric_scores,
            global_drift=drift["global_drift"],
        )
        coherence_step["needs_repair"] = self._needs_visual_repair(
            coherence_step["classification"]
        )
        return scores, drift, coherence_step

    def run(self, user_prompt: str) -> RunOutput:
        """Execute a full generation run for ``user_prompt``.

        Writes ``run.json`` (and ``retry_outcome.json`` when any retry was
        evaluated) into the run's log directory.

        Args:
            user_prompt: Free-text prompt handed to the planning council.

        Returns:
            A ``RunOutput`` describing the best-scoring attempt and the full
            attempt history.

        Raises:
            RuntimeError: If no image candidates can be retrieved, the
                selected image path is empty, or no attempt was executed
                (``max_attempts`` < 1).
        """
        paths = create_run_paths(self.runs_dir)

        plan, merge_report, planner_outputs = self._plan_with_council(user_prompt)
        plan_text = plan_to_canonical_text(plan)
        plan_embedding = self.embedder.embed_text(plan_text)
        img_pool = self._retrieve_image_pool(plan_text)

        # (msci, narrative, image_path, audio_path, scores, drift, attempt)
        best_state: Optional[
            Tuple[float, str, str, str, Dict[str, Any], Dict[str, bool], int]
        ] = None
        decisions: List[Dict[str, Any]] = []
        retry_outcomes: List[Dict[str, Any]] = []

        narrative_structured = self.narrative_generator.generate(plan.model_dump())
        narrative = narrative_structured.combined_scene

        image_path = img_pool[0][0]
        audio_path = str(paths.audio_dir / "audio_attempt1.wav")
        audio_prompt = self._build_audio_prompt(plan)
        audio_backend = None

        # Minimum metric gain for a retry to count as an improvement.
        epsilon = 0.01
        for attempt in range(1, self.max_attempts + 1):
            retry_strategy = None
            retry_metric = None
            retry_trigger = None

            if attempt == 1:
                audio_result = self.audio_gen.generate(audio_prompt, audio_path)
                audio_path = audio_result.audio_path
                audio_backend = audio_result.backend
            else:
                last_scores = decisions[-1]["scores"]
                classification = (
                    decisions[-1].get("coherence", {}).get("classification", {})
                )
                context = {
                    "semantic_plan": plan.model_dump(),
                    "narrative_structured": narrative_structured.model_dump(),
                    "plan_text": plan_text,
                    "image_path": image_path,
                    "audio_path": audio_path,
                    "image_generator": self.image_gen,
                    "audio_generator": self.audio_gen,
                }

                retry_trigger = classification.get("label")
                modality_failure = retry_trigger == "MODALITY_FAILURE"
                # True when a targeted retry or a full regeneration already
                # produced this attempt's artifacts.
                handled_regen = False

                if modality_failure and classification.get("weakest_metric") == "st_i":
                    context = retry_st_i(context)
                    image_path = (
                        context.get("image") or context.get("image_path") or image_path
                    )
                    retry_action = {
                        "regenerate": "image",
                        "failed_metric": "st_i",
                        "strategy": "ALIGN_IMAGE_TO_TEXT",
                    }
                    handled_regen = True
                elif modality_failure and classification.get("weakest_metric") == "si_a":
                    context["audio_path"] = str(
                        paths.audio_dir / f"audio_attempt{attempt}.wav"
                    )
                    context = retry_si_a(context)
                    audio_path = (
                        context.get("audio") or context.get("audio_path") or audio_path
                    )
                    audio_backend = context.get("audio_backend")
                    retry_action = {
                        "regenerate": "audio",
                        "failed_metric": "si_a",
                        "strategy": context.get("retry", {}).get(
                            "strategy", "ALIGN_AUDIO_TO_IMAGE"
                        ),
                    }
                    handled_regen = True
                else:
                    retry_action = route_retry(classification, context)

                regenerate = retry_action.get("regenerate") if retry_action else None
                if regenerate == "full":
                    retry_strategy = retry_action.get("strategy")
                    retry_metric = retry_action.get("failed_metric")
                    handled_regen = True

                    # Start over: new plan, narrative, image pool and audio.
                    plan, merge_report, planner_outputs = self._plan_with_council(
                        user_prompt
                    )
                    plan_text = plan_to_canonical_text(plan)
                    plan_embedding = self.embedder.embed_text(plan_text)
                    narrative_structured = self.narrative_generator.generate(
                        plan.model_dump()
                    )
                    narrative = narrative_structured.combined_scene

                    img_pool = self._retrieve_image_pool(plan_text)
                    image_path = img_pool[0][0]

                    audio_prompt = self._build_audio_prompt(plan)
                    audio_path = str(paths.audio_dir / f"audio_attempt{attempt}.wav")
                    audio_result = self.audio_gen.generate(audio_prompt, audio_path)
                    audio_path = audio_result.audio_path
                    audio_backend = audio_result.backend
                    target = "full"
                elif regenerate in {"audio", "image"}:
                    target = regenerate
                    retry_strategy = retry_action.get("strategy")
                    retry_metric = retry_action.get("failed_metric")
                    if target == "audio" and retry_action.get("audio_prompt"):
                        audio_prompt = retry_action["audio_prompt"]
                    if target == "image" and retry_action.get("image_prompt"):
                        # An empty pool is tolerated here; the current image
                        # is kept in that case.
                        img_pool = self.image_gen.retrieve_top_k(
                            retry_action["image_prompt"],
                            k=self._IMAGE_POOL_SIZE,
                        )
                else:
                    target = decide_regeneration(
                        last_scores["msci"],
                        last_scores["st_i"],
                        last_scores["st_a"],
                        self.msci_threshold,
                    )

                if not handled_regen:
                    if target == "image":
                        # Walk down the ranked pool, clamped to its last entry.
                        idx = min(attempt - 1, max(len(img_pool) - 1, 0))
                        image_path = img_pool[idx][0] if img_pool else image_path
                    elif target == "audio":
                        audio_path = str(
                            paths.audio_dir / f"audio_attempt{attempt}.wav"
                        )
                        audio_prompt_variant = (
                            audio_prompt + f" Intensity level: {attempt}."
                        )
                        audio_result = self.audio_gen.generate(
                            audio_prompt_variant, audio_path
                        )
                        # Use the path reported by the generator, as the first
                        # attempt and the full-regeneration branch do.
                        audio_path = audio_result.audio_path
                        audio_backend = audio_result.backend
                    elif target == "text":
                        narrative = self.text_gen.generate(
                            f"{plan_text}\n\nRewrite concisely, keep the same meaning:\n"
                        ).text

            if not image_path:
                raise RuntimeError(
                    "Image path is empty; retrieval produced no candidates."
                )
            image_emb = self.embedder.embed_image(image_path)
            audio_emb = self.embedder.embed_audio(audio_path)

            scores, drift, coherence_step = self._score_attempt(
                plan_embedding, image_emb, audio_emb, merge_report
            )

            # Repair the narrative's visual description while text-image
            # alignment is the failing metric. The repaired description
            # replaces the text embedding for this and later attempts.
            for _ in range(self._MAX_REPAIRS):
                if not coherence_step["needs_repair"]:
                    break
                narrative_structured = (
                    self.narrative_generator.repair_visual_description(
                        plan.model_dump(),
                        image_path=image_path,
                    )
                )
                narrative = narrative_structured.combined_scene
                plan_embedding = self.embedder.embed_text(
                    narrative_structured.visual_description
                )
                scores, drift, coherence_step = self._score_attempt(
                    plan_embedding, image_emb, audio_emb, merge_report
                )
                # NOTE: the source's indentation was lost, so the nesting of
                # this early exit was reconstructed. It is kept inside the
                # repair loop so every attempt is still recorded below.
                if coherence_step["classification"]["label"] in self._ACCEPTABLE_LABELS:
                    break

            decisions.append(
                {
                    "attempt": attempt,
                    "image_path": image_path,
                    "audio_path": audio_path,
                    "audio_backend": audio_backend,
                    "scores": scores,
                    "coherence": coherence_step,
                    "drift": drift,
                    "retry_strategy": retry_strategy,
                    "retry_metric": retry_metric,
                }
            )

            if attempt > 1 and retry_metric:
                prev_scores = decisions[-2].get("scores", {})
                before = prev_scores.get(retry_metric)
                after = scores.get(retry_metric)
                if before is not None and after is not None:
                    thresholds = self.coherence_scorer.thresholds
                    before_status = thresholds.classify_value(retry_metric, before)
                    after_status = thresholds.classify_value(retry_metric, after)
                    # Success: the metric left FAIL, or it gained more than
                    # epsilon.
                    success = (
                        before_status == "FAIL" and after_status in {"WEAK", "GOOD"}
                    ) or (after > before + epsilon)
                    retry_outcomes.append(
                        {
                            "strategy": retry_strategy,
                            "trigger": retry_trigger,
                            "weakest_metric": retry_metric,
                            "before": {
                                key: prev_scores.get(key)
                                for key in self._METRIC_KEYS
                            },
                            "after": {
                                key: scores.get(key) for key in self._METRIC_KEYS
                            },
                            "epsilon": epsilon,
                            "success": success,
                        }
                    )

            if best_state is None or scores["msci"] > best_state[0]:
                best_state = (
                    scores["msci"],
                    narrative,
                    image_path,
                    audio_path,
                    scores,
                    drift,
                    attempt,
                )

            if scores["msci"] >= self.msci_threshold and not drift["global_drift"]:
                break

        if best_state is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise RuntimeError(
                "No generation attempt was executed; max_attempts must be at least 1."
            )
        (
            _,
            best_text,
            best_img,
            best_aud,
            best_scores,
            best_drift,
            best_attempt,
        ) = best_state

        metric_scores = {
            key: best_scores[key] for key in self._METRIC_KEYS if key in best_scores
        }
        coherence = self.coherence_scorer.score(
            scores=metric_scores,
            global_drift=best_drift["global_drift"],
        )
        final_assessment = build_final_assessment(coherence, retry_outcomes)

        # NOTE(review): plan, merge_report, planner_outputs and
        # narrative_structured are taken from the last attempt, which may
        # differ from the best-scoring attempt after a "full" regeneration.
        # Confirm whether they should be tracked with best_state.
        out = RunOutput(
            run_id=paths.run_id,
            semantic_plan=plan.model_dump(),
            merge_report=merge_report,
            planner_outputs=planner_outputs,
            narrative_structured=narrative_structured.model_dump(),
            narrative_text=best_text,
            image_path=best_img,
            audio_path=best_aud,
            scores=best_scores,
            coherence=coherence,
            final_assessment=final_assessment,
            drift=best_drift,
            attempts=best_attempt,
            decisions=decisions,
        )

        write_run_metadata(
            paths.logs_dir / "run.json",
            {
                "run_id": out.run_id,
                "user_prompt": user_prompt,
                "semantic_plan": out.semantic_plan,
                "merge_report": out.merge_report,
                "planner_outputs": out.planner_outputs,
                "narrative_structured": out.narrative_structured,
                "final": {
                    "narrative_text": out.narrative_text,
                    "image_path": out.image_path,
                    "audio_path": out.audio_path,
                    "scores": out.scores,
                    "coherence": out.coherence,
                    "final_assessment": out.final_assessment,
                    "drift": out.drift,
                    "attempts": out.attempts,
                },
                "attempt_history": out.decisions,
            },
        )
        if retry_outcomes:
            write_run_metadata(
                paths.logs_dir / "retry_outcome.json",
                {"retries": retry_outcomes},
            )

        return out
| |
|