Spaces:
Sleeping
Sleeping
Deployment
Fix: Add validation gates, character consistency enforcement, and proper data serialization
7289c0c | """Pipeline orchestrator that manages agent flow and data passing.""" | |
| import json | |
| import logging | |
| import uuid | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| from agents import ( | |
| ShowrunnerAgent, | |
| StoryEditorAgent, | |
| CulturalConsultantAgent, | |
| LeadWriterAgent, | |
| DialogueSpecialistAgent, | |
| ComedyWriterAgent, | |
| ProofreaderAgent, | |
| ) | |
| from hf_uploader import HFUploader | |
| from config import settings | |
# Module logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class PipelineValidationError(Exception):
    """Signal that a pipeline stage returned output that failed validation."""
class PipelineOrchestrator:
    """Orchestrates the multi-agent content generation pipeline.

    Stages run strictly in sequence, each consuming earlier outputs:
    Showrunner -> Story Editor -> Cultural Consultant -> Lead Writer ->
    Dialogue Specialist -> Comedy Writer -> Proofreader. Each stage's output is
    validated and recorded in ``pipeline_state["stages"]``. On success the
    state is saved locally and the final output and metadata are uploaded
    through ``HFUploader``.
    """

    # More cultural flags than this triggers a warning; the run still continues.
    CULTURAL_FLAG_WARNING_THRESHOLD = 2

    def __init__(self):
        """Initialize the orchestrator with all agents and initial run state."""
        self.showrunner = ShowrunnerAgent()
        self.story_editor = StoryEditorAgent()
        self.cultural_consultant = CulturalConsultantAgent()
        self.lead_writer = LeadWriterAgent()
        self.dialogue_specialist = DialogueSpecialistAgent()
        self.comedy_writer = ComedyWriterAgent()
        self.proofreader = ProofreaderAgent()
        self.hf_uploader = HFUploader()
        # Pipeline state
        self.run_id = str(uuid.uuid4())
        self.pipeline_state = {
            "run_id": self.run_id,
            "start_time": datetime.now().isoformat(),
            "stages": {},
        }
        # Extracted character list for consistency
        self.character_list = []
        logger.info(f"Initialized pipeline orchestrator with run_id: {self.run_id}")

    def _validate_output(self, stage_name: str, output: Dict[str, Any], required_keys: list) -> None:
        """Validate that a stage output is a non-empty mapping with populated required keys.

        Args:
            stage_name: Name of the stage (used in error messages)
            output: Output dictionary from the stage
            required_keys: List of keys that must be present and non-empty

        Raises:
            PipelineValidationError: If the output is empty/None or not a
                mapping, a required key is missing, or a required value is
                None or a blank string.
        """
        from collections.abc import Mapping

        if not output:
            raise PipelineValidationError(f"{stage_name}: Output is empty or None")
        if not isinstance(output, Mapping):
            # Without this check a string output would be searched by substring
            # and then fail with an unrelated AttributeError.
            raise PipelineValidationError(
                f"{stage_name}: Expected a mapping output, got {type(output).__name__}"
            )
        for key in required_keys:
            if key not in output:
                raise PipelineValidationError(f"{stage_name}: Missing required key '{key}'")
            value = output[key]
            if value is None or (isinstance(value, str) and not value.strip()):
                raise PipelineValidationError(
                    f"{stage_name}: Required field '{key}' is empty. "
                    f"This indicates a processing failure. Aborting pipeline."
                )

    @staticmethod
    def _parse_character_names(character_bible: Optional[str]) -> list:
        """Return candidate character names from a character bible, in first-seen order.

        Only lines containing "(", ":" or "-" are considered. Leading "-"/"*"
        list markers are skipped, so both "Alex (CEO)" and "- Jordan" yield a
        name. The first remaining word, with every "-" and "*" removed, must be
        purely alphabetic, longer than one character and start with an
        uppercase letter. A first word ending in ":" (e.g. "Age: 34") is
        therefore rejected, which also means "Alex: CEO" is not recognised.

        Args:
            character_bible: Character definitions; None or "" yields [].

        Returns:
            De-duplicated list of names.
        """
        if not character_bible:
            return []
        found = []
        for raw_line in character_bible.splitlines():
            if not any(marker in raw_line for marker in ("(", ":", "-")):
                continue
            # Drop bullet/emphasis prefixes. Otherwise "- Jordan" has "-" as its
            # first word and the name is lost.
            words = raw_line.strip().lstrip("-*").split()
            if not words:
                continue
            candidate = words[0].replace("-", "").replace("*", "")
            if candidate.isalpha() and len(candidate) > 1 and candidate[0].isupper():
                found.append(candidate)
        # dict preserves insertion order, so this de-duplicates stably.
        return list(dict.fromkeys(found))

    def _extract_characters(self, character_bible: str) -> list:
        """Extract character names from the character bible and store them.

        Args:
            character_bible: Character definitions

        Returns:
            List of character names (also stored in ``self.character_list``)
        """
        names = self._parse_character_names(character_bible)
        self.character_list = names
        logger.info(f"Extracted characters: {self.character_list}")
        return names

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return a stage's script field as a string for the next agent.

        Dicts and lists are serialized as indented JSON with non-ASCII
        characters kept verbatim, so accented or non-Latin dialogue is not
        turned into ``\\uXXXX`` escapes. Other non-string values go through
        ``str``.
        """
        if isinstance(value, str):
            return value
        if isinstance(value, (dict, list)):
            return json.dumps(value, indent=2, ensure_ascii=False)
        return str(value)

    def _run_stage(
        self,
        state_key: str,
        stage_name: str,
        agent_call: Any,
        inputs: Dict[str, Any],
        required_keys: list,
    ) -> Dict[str, Any]:
        """Call one agent, validate its output and record it under ``state_key``.

        Args:
            state_key: Key in ``pipeline_state["stages"]`` for this stage
            stage_name: Human-readable stage name used in validation errors
            agent_call: Bound agent method that takes the inputs dict
            inputs: Inputs passed to ``agent_call``
            required_keys: Keys the output must contain

        Returns:
            The validated stage output.

        Raises:
            PipelineValidationError: If the output fails validation.
        """
        output = agent_call(inputs)
        self._validate_output(stage_name, output, required_keys)
        self.pipeline_state["stages"][state_key] = output
        return output

    def _start_run(self) -> None:
        """Reset per-run state before an execution.

        The run_id minted in ``__init__`` is kept for the first execution.
        Later executions on the same instance get a fresh run_id, so they do
        not overwrite the previous run's state file or uploads. Stale keys such
        as ``error`` are dropped, and ``start_time`` records when execution
        actually begins.
        """
        if "status" in self.pipeline_state:
            self.run_id = str(uuid.uuid4())
            logger.info(f"Previous run finished; starting new run_id: {self.run_id}")
        self.pipeline_state = {
            "run_id": self.run_id,
            "start_time": datetime.now().isoformat(),
            "stages": {},
        }

    def _record_failure(self, error: Exception) -> None:
        """Mark the run as failed and persist state without masking ``error``."""
        self.pipeline_state["status"] = "failed"
        self.pipeline_state["error"] = str(error)
        self.pipeline_state["end_time"] = datetime.now().isoformat()
        try:
            self._save_pipeline_state()
        except Exception:
            # Best effort: the caller must receive the original pipeline error,
            # not a secondary failure while writing the state file.
            logger.exception("Could not save pipeline state after failure")

    def execute_pipeline(
        self,
        user_brief: str,
        season_arc_document: str,
        character_bible: str,
        world_building_document: str,
        character_voice_guide: str,
        style_guide: str,
        continuity_log: str,
        hook_brief: Optional[str] = None,
        dialect_slang_reference: str = "",
    ) -> Dict[str, Any]:
        """Execute the full content generation pipeline.

        Each call is a separate run. Per-run state is reset first, and a repeat
        call on the same instance gets a new ``run_id``.

        Args:
            user_brief: Initial user brief
            season_arc_document: Season context
            character_bible: Character definitions
            world_building_document: World context
            character_voice_guide: Character voice definitions
            style_guide: Style reference
            continuity_log: Continuity tracking
            hook_brief: Optional hook brief for the comedy writer; falls back
                to ``user_brief`` when not given or empty
            dialect_slang_reference: Optional dialect/slang notes for the
                dialogue specialist (empty by default)

        Returns:
            Dictionary with ``run_id``, ``status``, ``final_output`` (the
            proofreader output), ``hf_output_url``, ``hf_metadata_url`` and
            ``pipeline_state``.

        Raises:
            PipelineValidationError: If any stage output fails validation.
            Exception: Any error from an agent, the state save or the upload is
                re-raised after the run is marked as failed.
        """
        self._start_run()
        try:
            logger.info("Starting pipeline execution")
            # Extract character list for consistency enforcement
            self._extract_characters(character_bible)

            # Stage 1: Showrunner
            logger.info("Stage 1: Showrunner - Generating directive")
            showrunner_output = self._run_stage(
                "showrunner",
                "Showrunner",
                self.showrunner.generate_directive,
                {
                    "user_brief": user_brief,
                    "season_arc_document": season_arc_document,
                    "character_bible": character_bible,
                },
                ["episode_directive", "story_premise", "tone_brief", "character_focus_notes"],
            )
            logger.info("Stage 1 completed β")

            # Stage 2: Story Editor
            logger.info("Stage 2: Story Editor - Generating outline")
            story_editor_output = self._run_stage(
                "story_editor",
                "Story Editor",
                self.story_editor.generate_outline,
                {
                    "episode_directive": showrunner_output.get("episode_directive", ""),
                    "series_continuity_log": continuity_log,
                    "character_list": self.character_list,  # Pass character list for consistency
                },
                ["episode_outline", "act_structure"],
            )
            logger.info("Stage 2 completed β")

            # Stage 3: Cultural Consultant (runs before the Lead Writer, which
            # consumes its notes)
            logger.info("Stage 3: Cultural Consultant - Reviewing outline")
            cultural_output = self._run_stage(
                "cultural_consultant",
                "Cultural Consultant",
                self.cultural_consultant.review_outline,
                {
                    "episode_outline": story_editor_output.get("episode_outline", ""),
                    "world_building_document": world_building_document,
                    "character_list": self.character_list,
                },
                ["cultural_accuracy_notes"],
            )
            # Many flagged issues only warn; the pipeline is not aborted.
            flagged = cultural_output.get("flagged_inaccuracies") or []
            if len(flagged) > self.CULTURAL_FLAG_WARNING_THRESHOLD:
                logger.warning(f"Cultural Consultant flagged {len(flagged)} issues - proceeding with caution")
            logger.info("Stage 3 completed β")

            # Stage 4: Lead Writer
            logger.info("Stage 4: Lead Writer - Writing script")
            lead_writer_output = self._run_stage(
                "lead_writer",
                "Lead Writer",
                self.lead_writer.write_script,
                {
                    "approved_outline": story_editor_output.get("episode_outline", ""),
                    "cultural_consultant_notes": cultural_output.get("cultural_accuracy_notes", ""),
                    "character_voice_guide": character_voice_guide,
                    "character_list": self.character_list,  # Enforce character consistency
                    "story_premise": showrunner_output.get("story_premise", ""),
                },
                ["full_episode_first_draft"],
            )
            logger.info("Stage 4 completed β")

            # Stage 5: Dialogue Specialist
            logger.info("Stage 5: Dialogue Specialist - Polishing dialogue")
            dialogue_output = self._run_stage(
                "dialogue_specialist",
                "Dialogue Specialist",
                self.dialogue_specialist.polish_dialogue,
                {
                    "first_draft_script": self._as_text(
                        lead_writer_output.get("full_episode_first_draft", "")
                    ),
                    "character_voice_guide": character_voice_guide,
                    "character_list": self.character_list,
                    "dialect_slang_reference": dialect_slang_reference,
                },
                ["dialogue_polished_script"],
            )
            logger.info("Stage 5 completed β")

            # Stage 6: Comedy Writer
            logger.info("Stage 6: Comedy Writer - Adding humor")
            comedy_output = self._run_stage(
                "comedy_writer",
                "Comedy Writer",
                self.comedy_writer.add_humor,
                {
                    "dialogue_polished_script": self._as_text(
                        dialogue_output.get("dialogue_polished_script", "")
                    ),
                    "hook_brief_from_showrunner": hook_brief or user_brief,
                    "character_list": self.character_list,
                    "tone_brief": showrunner_output.get("tone_brief", ""),
                },
                ["comedy_sharpened_script"],
            )
            logger.info("Stage 6 completed β")

            # Stage 7: Proofreader (Final QC)
            logger.info("Stage 7: Proofreader - Final quality control")
            proofreader_output = self._run_stage(
                "proofreader",
                "Proofreader",
                self.proofreader.final_qc,
                {
                    "comedy_sharpened_script": self._as_text(
                        comedy_output.get("comedy_sharpened_script", "")
                    ),
                    "style_guide": style_guide,
                    "continuity_log": continuity_log,
                    "character_list": self.character_list,
                },
                ["final_locked_script"],
            )
            logger.info("Stage 7 completed β")

            # Mark completion
            self.pipeline_state["end_time"] = datetime.now().isoformat()
            self.pipeline_state["status"] = "completed"
            # Save local state
            self._save_pipeline_state()

            # Upload to Hugging Face
            logger.info("Uploading final output to Hugging Face")
            hf_url = self.hf_uploader.upload_final_output(
                proofreader_output, self.run_id
            )
            hf_metadata_url = self.hf_uploader.upload_pipeline_metadata(
                self.pipeline_state
            )

            final_result = {
                "run_id": self.run_id,
                "status": "success",
                "final_output": proofreader_output,
                "hf_output_url": hf_url,
                "hf_metadata_url": hf_metadata_url,
                "pipeline_state": self.pipeline_state,
            }
            logger.info("β Pipeline execution completed successfully")
            return final_result
        except PipelineValidationError as e:
            logger.error(f"β Pipeline validation failed: {str(e)}")
            self._record_failure(e)
            raise
        except Exception as e:
            # logger.exception logs at ERROR level and includes the traceback.
            logger.exception(f"β Pipeline execution failed: {str(e)}")
            self._record_failure(e)
            raise

    def _save_pipeline_state(self) -> None:
        """Save ``pipeline_state`` as JSON to ``<settings.output_dir>/pipeline_<run_id>.json``.

        The JSON is serialized before the file is opened, so a non-serializable
        stage output raises without leaving a truncated state file behind.
        """
        payload = json.dumps(self.pipeline_state, indent=2)
        output_dir = Path(settings.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        state_file = output_dir / f"pipeline_{self.run_id}.json"
        state_file.write_text(payload, encoding="utf-8")
        logger.info(f"Pipeline state saved to {state_file}")