File size: 9,844 Bytes
a09cfc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from create_ascii_captions import assign_caption

# Quantity terms ordered from smallest to largest; the index distance between
# two terms measures how far apart the described quantities are (see
# quantity_score below).
QUANTITY_TERMS = ["one", "two", "a few", "several", "many"]

# Topics to compare. Order matters: longer, more specific keywords (e.g.
# "broken pipe") must precede their substrings (e.g. "pipe") so a phrase is
# claimed by the most specific topic first (see extract_phrases).
TOPIC_KEYWORDS = [
    #"giant gap", # I think all gaps are subsumed by the floor topic 
    "floor", "ceiling", 
    "broken pipe", "upside down pipe", "pipe", 
    "coin line", "coin",
    "platform", "tower", #"wall", 
    "broken cannon", "cannon",
    "ascending staircase", "descending staircase",
    "rectangular",
    "irregular", 
    "question block", "loose block",
    "enem"  # catch "enemy"/"enemies"
]

# Need list because the order matters: longer keywords are handled before the
# substrings they contain.
# NOTE(review): not referenced in this chunk; presumably maps keywords to
# their plural forms for building negated ("no ...") phrases — confirm usage
# elsewhere in the project.
KEYWORD_TO_NEGATED_PLURAL = [
    (" broken pipe.", ""), # If not the first phrase
    ("broken pipe. ", ""), # If the first phrase (after removing all others)
    (" broken cannon.", ""), # If not the first phrase
    ("broken cannon. ", ""), # If the first phrase (after removing all others)
    ("pipe", "pipes"),
    ("cannon", "cannons"),
    ("platform", "platforms"),
    ("tower", "towers"),
    ("staircase", "staircases"),
    ("enem", "enemies"),
    ("rectangular", "rectangular block clusters"),
    ("irregular", "irregular block clusters"),
    ("coin line", "coin lines"),
    ("coin.", "coins."), # Need period to avoid matching "coin line"
    ("question block", "question blocks"),
    ("loose block", "loose blocks")
]

BROKEN_TOPICS = 2 # Number of topics that are considered "broken" (e.g., "broken pipe", "broken cannon")
# NOTE(review): BROKEN_TOPICS is not referenced in this chunk — confirm callers.

# Plural normalization map (irregulars) used by normalize_plural; regular
# plurals are handled by stripping a trailing 's'.
PLURAL_EXCEPTIONS = {
    "enemies": "enemy",
}

def normalize_plural(phrase):
    """Return *phrase* with plural nouns reduced to singular form.

    Irregular plurals listed in PLURAL_EXCEPTIONS are substituted first;
    then every word ending in a single 's' (but not 'ss', so words like
    "class"/"boss" survive) has the trailing 's' stripped.
    """
    # Irregulars first (e.g. "enemies" -> "enemy") so the generic rule
    # below never sees them.
    for plural_form, singular_form in PLURAL_EXCEPTIONS.items():
        phrase = phrase.replace(plural_form, singular_form)

    # Generic rule: drop a lone trailing 's' from each word.
    singularized = [
        token[:-1] if token.endswith('s') and not token.endswith('ss') else token
        for token in phrase.split()
    ]
    return ' '.join(singularized)

def extract_phrases(caption, debug=False):
    """Map each topic in TOPIC_KEYWORDS to the first caption phrase mentioning it.

    The caption is split on '.' into phrases. Topics are scanned in
    TOPIC_KEYWORDS order, so longer/more specific topics (e.g. "broken pipe")
    claim a phrase before shorter substrings (e.g. "pipe") can. A phrase
    beginning with "no " is treated the same as an absent topic (None).
    """
    phrases = [part.strip() for part in caption.split('.') if part.strip()]
    consumed = set()  # phrases already claimed by an earlier (longer) topic
    topic_to_phrase = {}

    for topic in TOPIC_KEYWORDS:
        candidates = [p for p in phrases if topic in p and p not in consumed]

        if not candidates:
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': no phrase found")
            continue

        phrase = candidates[0]
        if phrase.lower().startswith("no "):
            # "no ..." is equivalent to the topic being absent; note that the
            # phrase is NOT consumed, matching the original behavior.
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
        else:
            topic_to_phrase[topic] = phrase
            consumed.add(phrase)  # prevent shorter topics from re-matching it
            if debug:
                print(f"[Extract] Topic '{topic}': found phrase '{phrase}'")

    return topic_to_phrase

def quantity_score(phrase1, phrase2, debug=False):
    """Score how closely the quantity terms of two phrases agree.

    Returns 1.0 for identical quantities, scaling down linearly with the
    distance between the terms' positions in QUANTITY_TERMS; returns 0.1
    when either phrase lacks a recognizable quantity term.
    """
    def _first_quantity(text):
        # First term (in QUANTITY_TERMS order) occurring anywhere in text.
        return next((term for term in QUANTITY_TERMS if term in text), None)

    qty1 = _first_quantity(phrase1)
    qty2 = _first_quantity(phrase2)

    if debug:
        print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")

    if qty1 and qty2:
        idx1 = QUANTITY_TERMS.index(qty1)
        idx2 = QUANTITY_TERMS.index(qty2)
        diff = abs(idx1 - idx2)
        max_diff = len(QUANTITY_TERMS) - 1
        score = 1.0 - (diff / max_diff)
        if debug:
            print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
        return score

    if debug:
        print("[Quantity] At least one quantity missing, assigning partial score 0.1")
    return 0.1

def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
    """Score a generated caption against a reference caption, topic by topic.

    Per-topic scoring:
      * both captions omit the topic          -> +1.0
      * exactly one caption has the topic     -> -1.0 (a generated-only
        phrase is recorded as excess)
      * phrases identical after plural
        normalization                         -> +1.0 (exact match)
      * both phrases contain quantity terms   -> quantity_score (partial)
      * any other overlap on the topic        -> +0.1 (partial)

    The running total is divided by the number of topics. When
    return_matches is True, the exact-match, partial-match, and excess
    phrase lists are returned alongside the score.
    """
    ref_by_topic = extract_phrases(correct_caption, debug=debug)
    gen_by_topic = extract_phrases(generated_caption, debug=debug)

    num_topics = len(TOPIC_KEYWORDS)
    total_score = 0.0
    exact_matches = []
    partial_matches = []
    excess_phrases = []

    if debug:
        print("\n--- Starting Topic Comparison ---\n")

    for topic in TOPIC_KEYWORDS:
        correct = ref_by_topic[topic]
        generated = gen_by_topic[topic]

        if debug:
            print(f"[Topic: {topic}] Correct: {correct} | Generated: {generated}")

        if correct is None and generated is None:
            # Agreement on absence counts as a full match.
            total_score += 1.0
            if debug:
                print(f"[Topic: {topic}] Both None — full score: 1.0\n")
        elif correct is None or generated is None:
            # One caption mentions the topic and the other does not.
            total_score -= 1.0
            if generated is not None:
                excess_phrases.append(generated)  # generated-only phrase
            if debug:
                print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
        else:
            # Normalize pluralization before any comparison.
            norm_correct = normalize_plural(correct)
            norm_generated = normalize_plural(generated)

            if debug:
                print(f"[Topic: {topic}] Normalized: Correct: '{norm_correct}' | Generated: '{norm_generated}'")

            def _has_quantity(text):
                return any(term in text for term in QUANTITY_TERMS)

            if norm_correct == norm_generated:
                total_score += 1.0
                exact_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Exact match — score: 1.0\n")
            elif _has_quantity(norm_correct) and _has_quantity(norm_generated):
                qty_score = quantity_score(norm_correct, norm_generated, debug=debug)
                total_score += qty_score
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Quantity-based partial score: {qty_score:.2f}\n")
            else:
                total_score += 0.1
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")

        if debug:
            print(f"[Topic: {topic}] Current total score: {total_score:.4f}\n")

    if debug:
        print("total_score before normalization:", total_score)
        print(f"Number of topics: {num_topics}")

    final_score = total_score / num_topics
    if debug:
        print(f"--- Final score: {final_score:.4f} ---\n")

    if return_matches:
        return final_score, exact_matches, partial_matches, excess_phrases
    return final_score

def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
    """Partition a scene into fixed-width segments, caption each, and score
    every caption against the prompt.

    Args:
        scene (list): The scene to process, represented as a 2D list of rows.
        segment_width (int): The width (in columns) of each segment; the last
            segment may be narrower if the scene width is not a multiple.
        prompt (str): The reference caption to compare captions against.
        id_to_char (dict): Mapping from tile IDs to characters.
        char_to_id (dict): Mapping from characters to tile IDs.
        tile_descriptors (dict): Descriptions of individual tile types.
        describe_locations (bool): Whether to include location descriptions
            in captions.
        describe_absence (bool): Whether to indicate absence of items in
            captions.
        verbose (bool): If True, print captions and scores for each segment.

    Returns:
        tuple: (average comparison score, list of segment captions,
        list of segment scores).
    """
    # Slice every row at the same column offsets to form vertical segments.
    segments = [
        [row[start:start + segment_width] for row in scene]
        for start in range(0, len(scene[0]), segment_width)
    ]

    segment_captions = []
    segment_scores = []
    for idx, segment in enumerate(segments):
        caption = assign_caption(segment, id_to_char, char_to_id,
                                 tile_descriptors, describe_locations,
                                 describe_absence)
        score = compare_captions(prompt, caption)
        segment_captions.append(caption)
        segment_scores.append(score)

        if verbose:
            print(f"Segment {idx + 1} caption: {caption}")
            print(f"Segment {idx + 1} comparison score: {score}")

    # Average over segments; 0 when the scene produced no segments.
    average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0

    if verbose:
        print(f"Average comparison score across all segments: {average_score}")

    return average_score, segment_captions, segment_scores

if __name__ == '__main__':
    # Quick manual sanity check: score a deliberately mismatched generated
    # caption against a reference caption with debug tracing enabled.
    ref = "floor with one gap. two enemies. one platform. one tower."
    gen = "giant gap with one chunk of floor. two enemies. one platform. one tower."

    result = compare_captions(ref, gen, debug=True)

    print(f"Should be: {ref}")
    print(f"  but was: {gen}")
    print(f"Score: {result}")