File size: 5,518 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
Phase 3: Validation to ensure generators use ONLY semantic plan, not raw prompt.

This module provides validation functions to detect prompt leakage.
"""

from typing import Dict, Any, List, Tuple


# Common function words that carry no scene-specific meaning. Built once at
# import time instead of on every call.
_STOPWORDS = frozenset({
    "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
    "of", "with", "by", "from", "as", "is", "are", "was", "were", "be",
    "been", "being", "have", "has", "had", "do", "does", "did", "will",
    "would", "should", "could", "may", "might", "must", "can", "this",
    "that", "these", "those", "it", "its", "they", "them", "their",
})


def extract_keywords(text: str) -> set:
    """Extract meaningful keywords from text (simple heuristic).

    The text is lower-cased, commas and periods are treated as separators,
    and the result is split on whitespace. Each token then has surrounding
    ASCII punctuation stripped (so ``"dragon!"`` and ``"(dragon)"`` both
    become ``"dragon"``) before filtering; punctuation inside a token, such
    as a hyphen or apostrophe, is kept.

    Args:
        text: Free-form text. An empty or ``None`` value yields an empty set.

    Returns:
        The set of distinct tokens longer than three characters that are
        not stopwords.
    """
    import string  # function-scope so the module's import block is untouched

    if not text:
        return set()

    tokens = text.lower().replace(",", " ").replace(".", " ").split()
    # Strip punctuation BEFORE the length/stopword checks; otherwise a token
    # like "(that)" would pass both filters with its brackets attached.
    cleaned = (token.strip(string.punctuation) for token in tokens)
    return {word for word in cleaned if len(word) > 3 and word not in _STOPWORDS}


def detect_prompt_leakage(
    original_prompt: str,
    plan: Dict[str, Any],
    text_prompt: str,
    image_prompt: str,
    audio_prompt: str,
) -> Dict[str, Any]:
    """
    Report keywords that reach a generator prompt from the original prompt
    without appearing in the semantic plan.

    A keyword counts as leaked for a modality when it occurs in that
    modality's prompt and in the original prompt, but not in the text
    flattened from the plan.

    Returns:
        {
            "has_leakage": bool,            # True if any modality leaks
            "leakage_details": {
                "text": [...],              # sorted leaked keywords
                "image": [...],
                "audio": [...]
            },
            "plan_coverage": float,         # share of original keywords found
                                            # in the plan, rounded to 4 places
            "original_keywords_count": int,
            "plan_keywords_count": int
        }
    """
    prompt_terms = extract_keywords(original_prompt)
    plan_terms = extract_keywords(_plan_to_text(plan))

    # Terms the user wrote that the plan does not carry; a generator prompt
    # can only contain one of these by bypassing the plan.
    unplanned_terms = prompt_terms - plan_terms

    generator_prompts = (
        ("text", text_prompt),
        ("image", image_prompt),
        ("audio", audio_prompt),
    )
    leaked_by_modality = {
        modality: sorted(extract_keywords(prompt) & unplanned_terms)
        for modality, prompt in generator_prompts
    }

    # max(..., 1) avoids dividing by zero when the original prompt yields
    # no keywords at all.
    covered_count = len(plan_terms & prompt_terms)
    coverage = covered_count / max(len(prompt_terms), 1)

    return {
        "has_leakage": any(leaked_by_modality.values()),
        "leakage_details": leaked_by_modality,
        "plan_coverage": round(coverage, 4),
        "original_keywords_count": len(prompt_terms),
        "plan_keywords_count": len(plan_terms),
    }


def _plan_to_text(plan: Dict[str, Any]) -> str:
    """Convert semantic plan to text representation for keyword extraction."""
    parts = []
    
    if isinstance(plan, dict):
        # Extract all string fields from plan
        for key, value in plan.items():
            if isinstance(value, str):
                parts.append(value)
            elif isinstance(value, list):
                parts.extend(str(v) for v in value if isinstance(v, str))
            elif isinstance(value, dict):
                # Recursively extract from nested dicts
                parts.append(_plan_to_text(value))
    
    return " ".join(parts)


def validate_conditioning_strictness(
    original_prompt: str,
    plan: Dict[str, Any],
    modality_prompts: Dict[str, str],
) -> Tuple[bool, str]:
    """
    Check that no modality prompt leaks original-prompt keywords that the
    plan does not contain.

    The prompts are read from ``modality_prompts`` under the keys
    ``"text_prompt"``, ``"image_prompt"`` and ``"audio_prompt"``; a missing
    key is treated as an empty prompt.

    Returns:
        (is_valid, reason) - ``reason`` lists the leaked keywords per
        modality, joined by "; ", when validation fails.
    """
    report = detect_prompt_leakage(
        original_prompt=original_prompt,
        plan=plan,
        text_prompt=modality_prompts.get("text_prompt", ""),
        image_prompt=modality_prompts.get("image_prompt", ""),
        audio_prompt=modality_prompts.get("audio_prompt", ""),
    )

    if not report["has_leakage"]:
        return True, "All prompts derived from plan"

    details = report["leakage_details"]
    # Fixed order keeps the message stable: text, then image, then audio.
    prefixes = (
        ("text", "Text prompt leaks: "),
        ("image", "Image prompt leaks: "),
        ("audio", "Audio prompt leaks: "),
    )
    messages = [
        f"{prefix}{details[modality]}"
        for modality, prefix in prefixes
        if details[modality]
    ]
    return False, "; ".join(messages)


def check_plan_completeness(plan: Dict[str, Any]) -> Dict[str, bool]:
    """Report which conditioning fields are present (truthy) in the plan.

    Each flag accepts either of two alternative key names. The plan counts
    as complete when it has a scene plus at least one of visual or audio
    content. A non-dict plan yields all flags ``False``.
    """
    if not isinstance(plan, dict):
        return {
            "has_scene": False,
            "has_visual": False,
            "has_audio": False,
            "has_entities": False,
            "is_complete": False,
        }

    def present(first_key: str, second_key: str) -> bool:
        # True when either alternative key holds a truthy value.
        return bool(plan.get(first_key) or plan.get(second_key))

    scene_ok = present("scene_summary", "scene")
    visual_ok = present("visual_attributes", "visual_elements")
    audio_ok = present("audio_elements", "audio_intent")
    entities_ok = present("primary_entities", "main_subjects")

    return {
        "has_scene": scene_ok,
        "has_visual": visual_ok,
        "has_audio": audio_ok,
        "has_entities": entities_ok,
        "is_complete": scene_ok and (visual_ok or audio_ok),
    }