File size: 22,409 Bytes
46a8a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
"""
spatial content planning and story board curation
"""

import json
from pathlib import Path
from typing import Dict, Any, List

from src.state.poster_state import PosterState
from utils.langgraph_utils import LangGraphAgent, extract_json, load_prompt
from utils.src.logging_utils import log_agent_info, log_agent_success, log_agent_error, log_agent_warning
from src.config.poster_config import load_config
from jinja2 import Template

class StoryBoardCurator:
    """creates spatial content plan and story board

    Asks the text model for a ``spatial_content_plan`` that places every poster
    section in a column (left/middle/right) at a vertical priority
    (top/middle/bottom), validates the returned JSON against the limits in the
    ``validation`` config, estimates how much of each column the plan fills,
    and saves the result to ``<output_dir>/content/story_board.json``.
    """

    # a visual taller than this percentage of a column's available height is "oversized"
    OVERSIZED_HEIGHT_PERCENT = 50.0
    VALID_COLUMNS = ("left", "middle", "right")
    VALID_PRIORITIES = ("top", "middle", "bottom")
    REQUIRED_SECTION_FIELDS = (
        "section_id", "section_title", "column_assignment", "vertical_priority", "text_content",
    )
    # both the ASCII "..." and the single-character U+2026 mark truncated bullet text
    ELLIPSIS_MARKERS = ("...", "\u2026")

    def __init__(self):
        self.name = "spatial_content_planner"
        self.spatial_planning_prompt = load_prompt("config/prompts/spatial_content_planner.txt")
        self.config = load_config()
        self.validation_config = self.config["validation"]
        self.utilization_config = self.config["utilization_thresholds"]

    def __call__(self, state: PosterState) -> PosterState:
        """Build, validate and persist the story board for ``state``.

        Reads the required parser outputs ``structured_sections``,
        ``narrative_content`` and ``classified_visuals`` plus the optional
        ``images``/``tables`` dicts, then writes ``story_board`` and
        ``current_agent`` and adds the LLM token usage to ``state["tokens"]``.
        Any failure is logged and appended to ``state["errors"]`` rather than
        raised.

        Returns:
            The same (mutated) ``state`` object.
        """
        log_agent_info(self.name, "creating spatial content plan")

        try:
            structured_sections = state.get("structured_sections")
            narrative_content = state.get("narrative_content")
            classified_visuals = state.get("classified_visuals")

            # the except clause below logs and records the error, so just raise here
            for key, value in (
                ("structured_sections", structured_sections),
                ("narrative_content", narrative_content),
                ("classified_visuals", classified_visuals),
            ):
                if not value:
                    raise ValueError(f"missing {key} from parser")

            # per-visual height estimates feed both the prompt and the validators
            visual_context = self._prepare_visual_context_for_curator(state)

            story_board, inp, out = self._create_story_board(
                structured_sections, narrative_content, classified_visuals,
                state.get("images") or {}, state.get("tables") or {},
                visual_context, state["text_model"]
            )
            state["tokens"].add_text(inp, out)

            # height distribution problems are reported, never fatal
            validation_result = self._validate_height_distribution(story_board, visual_context)
            if validation_result["warnings"]:
                log_agent_warning(self.name, f"height validation warnings: {validation_result['warnings']}")
            log_agent_info(self.name, f"column utilizations: {validation_result['column_utilizations']}")

            state["story_board"] = story_board
            state["current_agent"] = self.name

            self._save_story_board(state)

            sections = story_board.get("spatial_content_plan", {}).get("sections", [])
            total_visuals = sum(len(section.get("visual_assets") or []) for section in sections)

            log_agent_success(self.name, f"created story board with {len(sections)} sections")
            log_agent_success(self.name, f"selected {total_visuals} visual assets")

        except Exception as e:
            log_agent_error(self.name, f"failed: {e}")
            state["errors"].append(f"{self.name}: {e}")

        return state

    def _create_story_board(self, structured_sections, narrative_content, classified_visuals, images, tables, visual_context, config):
        """Generate a story board with the LLM, retrying until one validates.

        Args:
            structured_sections, narrative_content, classified_visuals: parser
                outputs, serialised into the prompt as JSON.
            images, tables: asset dicts keyed by id; only each entry's
                ``caption`` and ``aspect`` are sent to the model.
            visual_context: output of ``_prepare_visual_context_for_curator``.
            config: text-model configuration handed to ``LangGraphAgent``.

        Returns:
            ``(story_board, input_tokens, output_tokens)``. The token counts are
            summed over every attempt that got a model response, because
            rejected attempts consume tokens too.

        Raises:
            ValueError: when none of the ``validation.max_llm_attempts``
                attempts produced a valid story board.
        """
        log_agent_info(self.name, "generating spatial content plan")
        agent = LangGraphAgent("expert spatial poster designer", config)

        template_data = {
            "structured_sections": json.dumps(structured_sections, indent=2),
            "narrative_content": json.dumps(narrative_content, indent=2),
            "classified_visuals": json.dumps(classified_visuals, indent=2),
            "available_images": json.dumps({k: {"caption": v.get("caption", ""), "aspect": v.get("aspect", 1.0)}
                                          for k, v in images.items()}, indent=2),
            "available_tables": json.dumps({k: {"caption": v.get("caption", ""), "aspect": v.get("aspect", 1.0)}
                                          for k, v in tables.items()}, indent=2),
            "available_height_per_column": visual_context["available_height_per_column"],
            "visual_heights_info": json.dumps(visual_context["visual_assets_heights"], indent=2)
        }
        # the prompt is identical for every attempt, so render it only once
        prompt = Template(self.spatial_planning_prompt).render(**template_data)

        total_input_tokens = 0
        total_output_tokens = 0
        last_error = None
        max_attempts = self.validation_config["max_llm_attempts"]
        for attempt in range(1, max_attempts + 1):
            try:
                agent.reset()
                response = agent.step(prompt)
                total_input_tokens += response.input_tokens or 0
                total_output_tokens += response.output_tokens or 0
                story_board = extract_json(response.content)
                is_valid = self._validate_story_board(story_board, classified_visuals, visual_context)
            except Exception as e:
                log_agent_warning(self.name, f"story board attempt {attempt} failed: {e}")
                last_error = e
                continue

            last_error = None
            if is_valid:
                log_agent_success(self.name, f"successfully created story board on attempt {attempt}")
                return story_board, total_input_tokens, total_output_tokens
            log_agent_warning(self.name, f"attempt {attempt}: validation failed, retrying")

        # the message distinguishes "last attempt crashed" from "last attempt was invalid"
        if last_error is not None:
            raise ValueError("failed to create story board after multiple attempts") from last_error
        raise ValueError("failed to create story board")

    @staticmethod
    def _visual_ids(section: Dict) -> List[Any]:
        """Return the ``visual_id`` of each dict entry in ``section["visual_assets"]``."""
        return [
            visual.get("visual_id")
            for visual in section.get("visual_assets") or []
            if isinstance(visual, dict)
        ]

    def _validate_story_board(self, story_board: Dict, classified_visuals: Dict = None, visual_context: Dict = None) -> bool:
        """validate story board structure and constraints

        Checks, in order: the top-level ``spatial_content_plan``/``sections``
        shape, the section count limits from the ``validation`` config, every
        section's fields, the key-visual placement (when
        ``classified_visuals`` is given), and oversized-visual usage (when
        ``visual_context`` is given). The first violation is logged as a
        warning.

        Returns:
            True when every check passes, otherwise False.
        """
        if not isinstance(story_board, dict) or "spatial_content_plan" not in story_board:
            log_agent_warning(self.name, "validation error: missing 'spatial_content_plan'")
            return False

        scp = story_board["spatial_content_plan"]
        sections = scp.get("sections") if isinstance(scp, dict) else None
        if not isinstance(sections, list):
            log_agent_warning(self.name, "validation error: missing or invalid 'sections'")
            return False

        min_sections = self.validation_config["min_section_count"]
        max_sections = self.validation_config["max_section_count"]
        if not min_sections <= len(sections) <= max_sections:
            log_agent_warning(
                self.name,
                f"validation error: need {min_sections}-{max_sections} sections, got {len(sections)}",
            )
            return False

        for i, section in enumerate(sections):
            if not self._validate_section(i, section):
                return False

        if classified_visuals and not self._validate_key_visual(sections, classified_visuals):
            return False

        if visual_context and not self._validate_oversized_visuals(sections, visual_context):
            return False

        return True

    def _validate_section(self, index: int, section: Any) -> bool:
        """Check one section's required fields, placement, title length and bullets."""
        if not isinstance(section, dict):
            log_agent_warning(self.name, f"validation error: section {index} is not an object")
            return False

        for field in self.REQUIRED_SECTION_FIELDS:
            if field not in section:
                log_agent_warning(self.name, f"validation error: section {index} missing '{field}'")
                return False

        if section["column_assignment"] not in self.VALID_COLUMNS:
            log_agent_warning(self.name, f"validation error: section {index} invalid column_assignment")
            return False

        if section["vertical_priority"] not in self.VALID_PRIORITIES:
            log_agent_warning(self.name, f"validation error: section {index} invalid vertical_priority")
            return False

        title = section["section_title"]
        if not isinstance(title, str):
            log_agent_warning(self.name, f"validation error: section {index} invalid section_title")
            return False
        title_words = len(title.split())
        if title_words > self.validation_config["max_title_words"]:
            log_agent_warning(self.name, f"validation error: section {index} title too long ({title_words} words): '{title}'")
            return False

        text_content = section["text_content"]
        min_items = self.validation_config["min_text_content_items"]
        if not isinstance(text_content, list) or len(text_content) < min_items:
            log_agent_warning(self.name, f"validation error: section {index} invalid text_content")
            return False

        # an ellipsis means the model truncated a bullet instead of finishing it
        for j, text in enumerate(text_content):
            if isinstance(text, str) and any(marker in text for marker in self.ELLIPSIS_MARKERS):
                log_agent_warning(self.name, f"validation error: section {index} bullet {j} contains ellipsis")
                return False

        return True

    def _validate_key_visual(self, sections: List[Dict], classified_visuals: Dict) -> bool:
        """Require the classifier's ``key_visual`` (if any) in a middle-column, top-priority section."""
        key_visual = classified_visuals.get("key_visual")
        if not key_visual:
            return True

        # only the first section that references the key visual is checked
        host = next((s for s in sections if key_visual in self._visual_ids(s)), None)
        if host is None:
            log_agent_warning(self.name, f"validation error: key_visual '{key_visual}' not found in any section")
            return False

        if host.get("column_assignment") != "middle" or host.get("vertical_priority") != "top":
            log_agent_warning(self.name, f"validation error: key_visual '{key_visual}' not placed in middle column, top priority")
            return False

        return True

    def _validate_oversized_visuals(self, sections: List[Dict], visual_context: Dict) -> bool:
        """Allow at most one selected visual taller than ``OVERSIZED_HEIGHT_PERCENT`` of a column."""
        visual_heights = visual_context.get("visual_assets_heights", {})

        oversized = []  # (visual_id, numeric percentage, original "NN%" string)
        for section in sections:
            for visual_id in self._visual_ids(section):
                if visual_id not in visual_heights:
                    continue
                # height_percentage is stored as a string such as "91%"
                height_str = visual_heights[visual_id].get("height_percentage", "0%")
                height_percentage = float(height_str.rstrip('%'))
                if height_percentage > self.OVERSIZED_HEIGHT_PERCENT:
                    oversized.append((visual_id, height_percentage, height_str))

        if not oversized:
            return True

        if len(oversized) == 1:
            visual_id, _, height_str = oversized[0]
            log_agent_info(self.name, f"fallback applied: allowing single oversized visual: {visual_id} ({height_str})")
            return True

        # several oversized visuals: report every one except the smallest, which would be tolerated alone
        smallest = min(oversized, key=lambda item: item[1])
        invalid_visuals = [f"{vid} ({h_str})" for vid, _, h_str in oversized if vid != smallest[0]]
        log_agent_warning(self.name, f"validation error: oversized visuals (>50% height) selected: {invalid_visuals} (fallback: only smallest allowed: {smallest[0]} ({smallest[2]}))")
        return False

    def _prepare_visual_context_for_curator(self, state: PosterState) -> Dict[str, Any]:
        """prepare visual assets height information for curator's spatial planning

        Column geometry comes from ``self.config["layout"]``. The vertical budget
        is the poster height minus top/bottom margins, minus the title region
        (``title_height_fraction`` of that height). The width is one of three
        columns separated by two ``column_spacing`` gaps, minus left/right text
        padding. Each figure/table is assumed to span that width, so its height
        is ``effective_width / aspect``.

        Returns:
            dict with ``available_height_per_column``, ``visual_assets_heights``
            (keyed ``figure_<id>`` / ``table_<id>``), ``column_width`` and
            ``effective_width``; lengths are rounded to one decimal place.
        """
        layout = self.config["layout"]
        poster_width = state["poster_width"]
        poster_height = state["poster_height"]

        margins = 2 * layout["poster_margin"]
        effective_height = poster_height - margins
        title_region_height = effective_height * layout["title_height_fraction"]
        available_height = effective_height - title_region_height

        column_spacing = 2 * layout["column_spacing"]  # 2 gaps between 3 columns
        column_width = (poster_width - margins - column_spacing) / 3
        effective_width = column_width - 2 * layout["text_padding"]["left_right"]

        log_agent_info(self.name, f"visual context: available_height={available_height:.1f}\", effective_width={effective_width:.1f}\"")

        visual_heights = {}
        for kind, assets in (("figure", state.get("images") or {}), ("table", state.get("tables") or {})):
            for asset_id, asset_data in assets.items():
                key = f"{kind}_{asset_id}"
                aspect_ratio = asset_data.get("aspect", 1.0)
                if not aspect_ratio or aspect_ratio <= 0:
                    # avoid dividing by zero/None; fall back to the same square default used for missing aspects
                    log_agent_warning(self.name, f"{key}: invalid aspect ratio {aspect_ratio!r}, assuming 1.0")
                    aspect_ratio = 1.0
                visual_height = effective_width / aspect_ratio
                # with no usable column height every visual counts as infinitely tall (hence oversized)
                height_percentage = (
                    (visual_height / available_height) * 100 if available_height > 0 else float("inf")
                )

                visual_heights[key] = {
                    "height_inches": round(visual_height, 1),
                    "height_percentage": f"{height_percentage:.0f}%",
                    "type": kind,
                    "aspect_ratio": aspect_ratio
                }
                log_agent_info(self.name, f"{key}: {visual_height:.1f}\" ({height_percentage:.0f}% of column)")

        return {
            "available_height_per_column": round(available_height, 1),
            "visual_assets_heights": visual_heights,
            "column_width": round(column_width, 1),
            "effective_width": round(effective_width, 1)
        }

    def _validate_height_distribution(self, story_board: Dict, visual_context: Dict) -> Dict[str, Any]:
        """validate spatial plan for height constraints and generate warnings

        Sums the estimated height of each column's sections and compares it
        with the available column height using the ``utilization_thresholds``
        config (``overflow_critical`` > ``overflow_warning`` > ``underutilized``).

        Returns:
            dict with per-column ``column_utilizations``, a list of human-readable
            ``warnings`` and an ``overall_status`` of PASS/NEEDS_OPTIMIZATION.
            For an empty plan only ``warnings`` and ``column_utilizations``.
        """
        available_height = visual_context["available_height_per_column"]
        visual_heights = visual_context["visual_assets_heights"]

        sections = story_board.get("spatial_content_plan", {}).get("sections", [])
        if not sections:
            return {"warnings": ["No sections found in story board"], "column_utilizations": {}}

        # sections with an unknown column are ignored here
        columns = {column: [] for column in self.VALID_COLUMNS}
        for section in sections:
            column = section.get("column_assignment", "left")
            if column in columns:
                columns[column].append(section)

        thresholds = self.utilization_config
        column_utilizations = {}
        warnings = []

        for column_name, column_sections in columns.items():
            total_height = 0
            total_visual_height = 0
            total_visuals = 0
            section_details = []

            for section in column_sections:
                section_height = self._estimate_section_height(section, visual_heights, self.config)
                total_height += section_height

                # only visuals with a known height contribute to density/counts
                known_ids = [vid for vid in self._visual_ids(section) if vid in visual_heights]
                section_visual_height = sum(visual_heights[vid]["height_inches"] for vid in known_ids)
                total_visuals += len(known_ids)
                total_visual_height += section_visual_height

                section_details.append({
                    "section_id": section.get("section_id", "unknown"),
                    "estimated_height": section_height,
                    "visual_count": len(section.get("visual_assets") or []),
                    "visual_height": round(section_visual_height, 1)
                })

            utilization = total_height / available_height if available_height > 0 else 0
            visual_density = total_visual_height / available_height if available_height > 0 else 0

            column_utilizations[column_name] = {
                "total_height": round(total_height, 1),
                "utilization_percent": f"{utilization*100:.0f}%",
                "visual_density_percent": f"{visual_density*100:.0f}%",
                "section_count": len(column_sections),
                "total_visuals": total_visuals,
                "sections": section_details,
                "status": "OK" if utilization <= thresholds["overflow_critical"] else "OVERFLOW"
            }

            if utilization > thresholds["overflow_critical"]:
                warnings.append(f"{column_name} column serious overflow: {utilization*100:.0f}% (visual density: {visual_density*100:.0f}%)")
            elif utilization > thresholds["overflow_warning"]:
                warnings.append(f"{column_name} column minor overflow: {utilization*100:.0f}% (visual density: {visual_density*100:.0f}%)")
            elif utilization < thresholds["underutilized"]:
                warnings.append(f"{column_name} column underutilized: {utilization*100:.0f}% (visual density: {visual_density*100:.0f}%)")

            if total_visuals == 0:
                warnings.append(f"{column_name} column has no visuals - add visual assets")

        return {
            "column_utilizations": column_utilizations,
            "warnings": warnings,
            "overall_status": "PASS" if not warnings else "NEEDS_OPTIMIZATION"
        }

    def _estimate_section_height(self, section: Dict, visual_heights: Dict, config: Dict) -> float:
        """estimate total height for a section including visuals and text

        Height = title + (each known visual + spacing below it)
        + one ``bullet_point_height`` per text item (a rough estimate that
        ignores line wrapping) + title-to-content spacing + section spacing.
        """
        layout = config["layout"]
        estimation = config["section_estimation"]
        visual_spacing = layout["visual_spacing"]["below_visual"]

        total_height = estimation["base_title_height"]

        total_height += sum(
            visual_heights[vid]["height_inches"] + visual_spacing
            for vid in self._visual_ids(section)
            if vid in visual_heights
        )

        text_items = len(section.get("text_content") or [])
        total_height += text_items * estimation["bullet_point_height"]

        total_height += layout["title_to_content_spacing"]
        total_height += layout["section_spacing"]

        return total_height

    def _save_story_board(self, state: PosterState):
        """save story board to ``<output_dir>/content/story_board.json`` (UTF-8)"""
        output_dir = Path(state["output_dir"]) / "content"
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "story_board.json", "w", encoding='utf-8') as f:
            # the file is UTF-8, so keep non-ASCII text readable instead of \u-escaped
            json.dump(state.get("story_board", {}), f, indent=2, ensure_ascii=False)


def curator_node(state) -> Dict[str, Any]:
    """LangGraph node wrapper around ``StoryBoardCurator``.

    Returns the updated state as a new dict. When the curator fails it records
    the error in ``errors`` but never sets ``story_board`` (and may not set
    ``current_agent``). Those keys are therefore read with ``.get`` so that the
    recorded error reaches later nodes instead of being masked by a KeyError
    here.
    """
    result = StoryBoardCurator()(state)
    return {
        **state,
        "story_board": result.get("story_board"),
        "tokens": result["tokens"],
        "current_agent": result.get("current_agent"),
        "errors": result["errors"]
    }