Spaces:
Running
Running
| """ | |
| spatial content planning and story board curation | |
| """ | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, Any, List | |
| from src.state.poster_state import PosterState | |
| from utils.langgraph_utils import LangGraphAgent, extract_json, load_prompt | |
| from utils.src.logging_utils import log_agent_info, log_agent_success, log_agent_error, log_agent_warning | |
| from src.config.poster_config import load_config | |
| from jinja2 import Template | |
class StoryBoardCurator:
    """Create the spatial content plan ("story board") for a poster.

    The curator renders a prompt from the parsed sections, narrative content
    and classified visuals, asks the text model for a JSON story board,
    validates the result and stores it in the pipeline state.
    """

    def __init__(self):
        self.name = "spatial_content_planner"
        self.spatial_planning_prompt = load_prompt("config/prompts/spatial_content_planner.txt")
        # Load the configuration once; every method below reuses this copy.
        self.config = load_config()
        self.validation_config = self.config["validation"]
        self.utilization_config = self.config["utilization_thresholds"]

    def __call__(self, state: "PosterState") -> "PosterState":
        """Run the curator on ``state`` and return the same (mutated) state.

        On success ``state["story_board"]`` and ``state["current_agent"]``
        are set and the story board is written to disk.  Any failure is
        logged and appended to ``state["errors"]`` instead of being raised.
        """
        log_agent_info(self.name, "creating spatial content plan")
        try:
            structured_sections = state.get("structured_sections")
            narrative_content = state.get("narrative_content")
            classified_visuals = state.get("classified_visuals")

            if not structured_sections:
                log_agent_error(self.name, "missing structured_sections from parser")
                raise ValueError("missing structured_sections from parser")
            if not narrative_content:
                log_agent_error(self.name, "missing narrative_content from parser")
                raise ValueError("missing narrative_content from parser")
            if not classified_visuals:
                log_agent_error(self.name, "missing classified_visuals from parser")
                raise ValueError("missing classified_visuals from parser")

            # prepare visual height context for spatial planning
            visual_context = self._prepare_visual_context_for_curator(state)

            # `or {}` also covers keys that are present but set to None
            story_board, inp, out = self._create_story_board(
                structured_sections, narrative_content, classified_visuals,
                state.get("images") or {}, state.get("tables") or {},
                visual_context, state["text_model"]
            )
            state["tokens"].add_text(inp, out)

            # validate height distribution (advisory only: warnings are logged)
            validation_result = self._validate_height_distribution(story_board, visual_context)
            if validation_result["warnings"]:
                log_agent_warning(self.name, f"height validation warnings: {validation_result['warnings']}")
            log_agent_info(self.name, f"column utilizations: {validation_result['column_utilizations']}")

            state["story_board"] = story_board
            state["current_agent"] = self.name
            self._save_story_board(state)

            # log story board summary
            sections = story_board.get("spatial_content_plan", {}).get("sections", [])
            total_visuals = sum(len(section.get("visual_assets", [])) for section in sections)
            log_agent_success(self.name, f"created story board with {len(sections)} sections")
            log_agent_success(self.name, f"selected {total_visuals} visual assets")
        except Exception as e:
            # Top-level boundary of this agent: record the failure in the state.
            log_agent_error(self.name, f"failed: {e}")
            state["errors"].append(f"{self.name}: {e}")
        return state

    def _create_story_board(self, structured_sections, narrative_content, classified_visuals, images, tables, visual_context, config):
        """Ask the model for a story board, retrying until one validates.

        Returns a ``(story_board, input_tokens, output_tokens)`` tuple taken
        from the first response that passes ``_validate_story_board``.

        Raises:
            ValueError: if no attempt produced a valid story board.
        """
        log_agent_info(self.name, "generating spatial content plan")
        agent = LangGraphAgent("expert spatial poster designer", config)

        def describe_assets(assets):
            # Only caption and aspect ratio are exposed to the prompt.
            return json.dumps(
                {k: {"caption": v.get("caption", ""), "aspect": v.get("aspect", 1.0)}
                 for k, v in (assets or {}).items()},
                indent=2,
            )

        template_data = {
            "structured_sections": json.dumps(structured_sections, indent=2),
            "narrative_content": json.dumps(narrative_content, indent=2),
            "classified_visuals": json.dumps(classified_visuals, indent=2),
            "available_images": describe_assets(images),
            "available_tables": describe_assets(tables),
            "available_height_per_column": visual_context["available_height_per_column"],
            "visual_heights_info": json.dumps(visual_context["visual_assets_heights"], indent=2)
        }

        max_attempts = self.validation_config["max_llm_attempts"]
        for attempt in range(max_attempts):
            try:
                prompt = Template(self.spatial_planning_prompt).render(**template_data)
                agent.reset()
                response = agent.step(prompt)
                story_board = extract_json(response.content)
                if self._validate_story_board(story_board, classified_visuals, visual_context):
                    log_agent_success(self.name, f"successfully created story board on attempt {attempt + 1}")
                    return story_board, response.input_tokens, response.output_tokens
                else:
                    log_agent_warning(self.name, f"attempt {attempt + 1}: validation failed, retrying")
            except Exception as e:
                log_agent_warning(self.name, f"story board attempt {attempt + 1} failed: {e}")
                if attempt == max_attempts - 1:
                    # Chain the cause so the underlying error stays visible.
                    raise ValueError("failed to create story board after multiple attempts") from e
        # Reached when the last attempt returned but failed validation
        # (or when max_attempts is 0).
        raise ValueError("failed to create story board")

    def _validate_story_board(self, story_board: Dict, classified_visuals: Dict = None, visual_context: Dict = None) -> bool:
        """Validate story board structure and constraints.

        Checks, in order: overall shape, section count, per-section required
        fields / column / priority / title length / bullet list, key-visual
        placement (when ``classified_visuals`` is given) and the oversized
        visual rule (when ``visual_context`` is given).  Every failure is
        logged as a warning and reported by returning ``False``.
        """
        # Model output is untrusted: reject wrong shapes instead of raising.
        if not isinstance(story_board, dict):
            log_agent_warning(self.name, "validation error: story board is not a JSON object")
            return False
        if "spatial_content_plan" not in story_board:
            log_agent_warning(self.name, "validation error: missing 'spatial_content_plan'")
            return False
        scp = story_board["spatial_content_plan"]

        # check sections
        if not isinstance(scp, dict) or "sections" not in scp or not isinstance(scp["sections"], list):
            log_agent_warning(self.name, "validation error: missing or invalid 'sections'")
            return False
        sections = scp["sections"]

        min_sections = self.validation_config["min_section_count"]
        max_sections = self.validation_config["max_section_count"]
        if len(sections) < min_sections or len(sections) > max_sections:
            # Report the configured limits rather than a hard-coded range.
            log_agent_warning(self.name, f"validation error: need {min_sections}-{max_sections} sections, got {len(sections)}")
            return False

        # validate each section
        required_fields = ["section_id", "section_title", "column_assignment", "vertical_priority", "text_content"]
        max_words = self.validation_config["max_title_words"]
        min_items = self.validation_config["min_text_content_items"]
        for i, section in enumerate(sections):
            if not isinstance(section, dict):
                log_agent_warning(self.name, f"validation error: section {i} is not a JSON object")
                return False
            for field in required_fields:
                if field not in section:
                    log_agent_warning(self.name, f"validation error: section {i} missing '{field}'")
                    return False

            # check column assignment is valid
            if section["column_assignment"] not in ["left", "middle", "right"]:
                log_agent_warning(self.name, f"validation error: section {i} invalid column_assignment")
                return False

            # check vertical priority is valid
            if section["vertical_priority"] not in ["top", "middle", "bottom"]:
                log_agent_warning(self.name, f"validation error: section {i} invalid vertical_priority")
                return False

            # check section title length (word limit comes from config)
            title = section.get("section_title", "")
            title_words = len(title.split())
            if title_words > max_words:
                log_agent_warning(self.name, f"validation error: section {i} title too long ({title_words} words): '{title}'")
                return False

            # check text content is list of bullet points
            if not isinstance(section["text_content"], list) or len(section["text_content"]) < min_items:
                log_agent_warning(self.name, f"validation error: section {i} invalid text_content")
                return False

            # check for ellipsis in text content
            for j, text in enumerate(section["text_content"]):
                if "..." in text:
                    log_agent_warning(self.name, f"validation error: section {i} bullet {j} contains ellipsis")
                    return False

        # validate key_visual placement if classified_visuals provided
        if classified_visuals:
            key_visual = classified_visuals.get("key_visual")
            if key_visual:
                key_visual_found = False
                key_visual_in_middle_top = False
                # Only the first occurrence of the key visual is inspected.
                for section in sections:
                    for visual in section.get("visual_assets", []):
                        if visual.get("visual_id") == key_visual:
                            key_visual_found = True
                            if (section.get("column_assignment") == "middle" and
                                    section.get("vertical_priority") == "top"):
                                key_visual_in_middle_top = True
                            break
                    if key_visual_found:
                        break
                if not key_visual_found:
                    log_agent_warning(self.name, f"validation error: key_visual '{key_visual}' not found in any section")
                    return False
                if not key_visual_in_middle_top:
                    log_agent_warning(self.name, f"validation error: key_visual '{key_visual}' not placed in middle column, top priority")
                    return False

        # validate height exclusion compliance if visual_context provided
        if visual_context:
            visual_heights = visual_context.get("visual_assets_heights", {})
            # (visual_id, numeric percentage, original string such as "91%")
            oversized = []
            for section in sections:
                for visual in section.get("visual_assets", []):
                    visual_id = visual.get("visual_id")
                    if visual_id in visual_heights:
                        height_str = visual_heights[visual_id].get("height_percentage", "0%")
                        height_percentage = float(height_str.rstrip('%'))
                        if height_percentage > 50:
                            oversized.append((visual_id, height_percentage, height_str))

            if len(oversized) == 1:
                # only one oversized visual selected, allow it as fallback
                vid, _, h_str = oversized[0]
                log_agent_info(self.name, f"fallback applied: allowing single oversized visual: {vid} ({h_str})")
            elif oversized:
                # multiple oversized visuals selected, only the smallest is acceptable
                smallest = min(oversized, key=lambda item: item[1])
                invalid_visuals = [f"{vid} ({h_str})" for vid, h, h_str in oversized if vid != smallest[0]]
                log_agent_warning(self.name, f"validation error: oversized visuals (>50% height) selected: {invalid_visuals} (fallback: only smallest allowed: {smallest[0]} ({smallest[2]}))")
                return False
        return True

    def _prepare_visual_context_for_curator(self, state: "PosterState") -> Dict[str, Any]:
        """Prepare visual asset height information for the curator's spatial planning.

        Returns a dict with ``available_height_per_column``,
        ``visual_assets_heights`` (keyed ``figure_<id>`` / ``table_<id>``),
        ``column_width`` and ``effective_width``.
        """
        layout = self.config["layout"]

        # get poster dimensions
        poster_width = state["poster_width"]
        poster_height = state["poster_height"]

        # available height per column: poster height minus margins minus the
        # title region (a configured fraction of the effective height)
        poster_margins = 2 * layout["poster_margin"]
        effective_height = poster_height - poster_margins
        title_region_height = effective_height * layout["title_height_fraction"]
        available_height = effective_height - title_region_height

        # calculate effective column width for visual sizing
        column_margins = 2 * layout["poster_margin"]
        column_spacing = 2 * layout["column_spacing"]  # 2 gaps between 3 columns
        total_column_width = poster_width - column_margins - column_spacing
        column_width = total_column_width / 3

        # account for text padding within each column
        text_padding = 2 * layout["text_padding"]["left_right"]
        effective_width = column_width - text_padding
        log_agent_info(self.name, f"visual context: available_height={available_height:.1f}\", effective_width={effective_width:.1f}\"")

        # calculate height for each visual asset (figures are the state's images)
        visual_heights = {}
        visual_heights.update(
            self._measure_visuals(state.get("images") or {}, "figure", effective_width, available_height)
        )
        visual_heights.update(
            self._measure_visuals(state.get("tables") or {}, "table", effective_width, available_height)
        )

        return {
            "available_height_per_column": round(available_height, 1),
            "visual_assets_heights": visual_heights,
            "column_width": round(column_width, 1),
            "effective_width": round(effective_width, 1)
        }

    def _measure_visuals(self, assets: Dict, kind: str, effective_width: float, available_height: float) -> Dict[str, Any]:
        """Compute the rendered height of each asset when scaled to ``effective_width``."""
        measured = {}
        for asset_id, asset_data in assets.items():
            key = f"{kind}_{asset_id}"
            aspect_ratio = asset_data.get("aspect", 1.0)
            if not aspect_ratio or aspect_ratio <= 0:
                # A zero/None aspect would divide by zero; use the same
                # default that applies when the key is missing.
                log_agent_warning(self.name, f"{key}: invalid aspect ratio {aspect_ratio!r}, assuming 1.0")
                aspect_ratio = 1.0
            visual_height = effective_width / aspect_ratio
            height_percentage = (visual_height / available_height) * 100
            measured[key] = {
                "height_inches": round(visual_height, 1),
                "height_percentage": f"{height_percentage:.0f}%",
                "type": kind,
                "aspect_ratio": aspect_ratio
            }
            log_agent_info(self.name, f"{key}: {visual_height:.1f}\" ({height_percentage:.0f}% of column)")
        return measured

    def _validate_height_distribution(self, story_board: Dict, visual_context: Dict) -> Dict[str, Any]:
        """Estimate per-column height usage of the plan and collect warnings.

        Returns a dict with ``column_utilizations`` (one entry per column),
        ``warnings`` (list of strings) and ``overall_status``.  When the
        story board has no sections only ``warnings`` and an empty
        ``column_utilizations`` are returned.
        """
        config = self.config
        available_height = visual_context["available_height_per_column"]
        visual_heights = visual_context["visual_assets_heights"]

        # extract sections from story board
        sections = story_board.get("spatial_content_plan", {}).get("sections", [])
        if not sections:
            return {"warnings": ["No sections found in story board"], "column_utilizations": {}}

        # organize sections by column (unknown column names are ignored)
        columns = {"left": [], "middle": [], "right": []}
        for section in sections:
            column = section.get("column_assignment", "left")
            if column in columns:
                columns[column].append(section)

        thresholds = self.utilization_config
        column_utilizations = {}
        warnings = []
        for column_name, column_sections in columns.items():
            total_height = 0
            total_visual_height = 0
            total_visuals = 0
            section_details = []
            for section in column_sections:
                section_height = self._estimate_section_height(section, visual_heights, config)
                total_height += section_height

                # visual contribution for this section (known visuals only)
                section_visual_height = 0
                visual_assets = section.get("visual_assets", [])
                for visual_asset in visual_assets:
                    visual_id = visual_asset.get("visual_id", "")
                    if visual_id in visual_heights:
                        section_visual_height += visual_heights[visual_id]["height_inches"]
                        total_visuals += 1
                total_visual_height += section_visual_height

                section_details.append({
                    "section_id": section.get("section_id", "unknown"),
                    "estimated_height": section_height,
                    "visual_count": len(visual_assets),
                    "visual_height": round(section_visual_height, 1)
                })

            utilization = total_height / available_height if available_height > 0 else 0
            visual_density = total_visual_height / available_height if available_height > 0 else 0
            column_utilizations[column_name] = {
                "total_height": round(total_height, 1),
                "utilization_percent": f"{utilization*100:.0f}%",
                "visual_density_percent": f"{visual_density*100:.0f}%",
                "section_count": len(column_sections),
                "total_visuals": total_visuals,
                "sections": section_details,
                "status": "OK" if utilization <= thresholds["overflow_critical"] else "OVERFLOW"
            }

            if utilization > thresholds["overflow_critical"]:
                warnings.append(f"{column_name} column serious overflow: {utilization*100:.0f}% (visual density: {visual_density*100:.0f}%)")
            elif utilization > thresholds["overflow_warning"]:
                warnings.append(f"{column_name} column minor overflow: {utilization*100:.0f}% (visual density: {visual_density*100:.0f}%)")
            elif utilization < thresholds["underutilized"]:
                warnings.append(f"{column_name} column underutilized: {utilization*100:.0f}% (visual density: {visual_density*100:.0f}%)")
            if total_visuals == 0:
                warnings.append(f"{column_name} column has no visuals - add visual assets")

        return {
            "column_utilizations": column_utilizations,
            "warnings": warnings,
            "overall_status": "PASS" if not warnings else "NEEDS_OPTIMIZATION"
        }

    def _estimate_section_height(self, section: Dict, visual_heights: Dict, config: Dict) -> float:
        """Estimate total height for a section including visuals and text.

        The estimate is the sum of: title height, each known visual plus the
        spacing below it, one bullet height per text item, the
        title-to-content spacing and the section bottom spacing.
        """
        # section title height (from config)
        total_height = 0
        total_height += config["section_estimation"]["base_title_height"]

        # visual assets height; visuals missing from visual_heights add nothing
        for visual_asset in section.get("visual_assets", []):
            visual_id = visual_asset.get("visual_id", "")
            if visual_id in visual_heights:
                visual_height = visual_heights[visual_id]["height_inches"]
                visual_spacing = config["layout"]["visual_spacing"]["below_visual"]
                total_height += visual_height + visual_spacing

        # text content height (rough estimation: one fixed height per bullet)
        text_lines = len(section.get("text_content", []))
        total_height += text_lines * config["section_estimation"]["bullet_point_height"]

        # spacing between title and content
        total_height += config["layout"]["title_to_content_spacing"]

        # section bottom spacing
        total_height += config["layout"]["section_spacing"]
        return total_height

    def _save_story_board(self, state: "PosterState"):
        """Save the story board to ``<output_dir>/content/story_board.json``."""
        output_dir = Path(state["output_dir"]) / "content"
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "story_board.json", "w", encoding='utf-8') as f:
            json.dump(state.get("story_board", {}), f, indent=2)
def curator_node(state) -> Dict[str, Any]:
    """Run ``StoryBoardCurator`` on ``state`` and return the updated state mapping.

    Presumably registered as a node in the poster-generation graph (confirm
    against the graph definition).  The curator records failures in
    ``state["errors"]`` instead of raising, so the keys it only writes on
    success are read with ``.get``; on failure they come back as ``None``
    (or whatever an earlier step stored) and the recorded error is not
    masked by a ``KeyError`` here.
    """
    result = StoryBoardCurator()(state)
    return {
        **state,
        # Only written by the curator on success.
        "story_board": result.get("story_board"),
        "tokens": result["tokens"],
        "current_agent": result.get("current_agent"),
        "errors": result["errors"],
    }