schrum2 commited on
Commit
35bc800
·
verified ·
1 Parent(s): 634e6bc

Don't think I need this

Browse files
Files changed (1) hide show
  1. caption_match.py +0 -247
caption_match.py DELETED
@@ -1,247 +0,0 @@
1
- from create_ascii_captions import assign_caption
2
-
3
- # Quantity order for scoring partial matches
4
- QUANTITY_TERMS = ["one", "two", "a few", "several", "many"]
5
-
6
- # Topics to compare
7
- TOPIC_KEYWORDS = [
8
- #"giant gap", # I think all gaps are subsumed by the floor topic
9
- "floor", "ceiling",
10
- "broken pipe", "upside down pipe", "pipe",
11
- "coin line", "coin",
12
- "platform", "tower", #"wall",
13
- "broken cannon", "cannon",
14
- "ascending staircase", "descending staircase",
15
- "rectangular",
16
- "irregular",
17
- "question block", "loose block",
18
- "enem" # catch "enemy"/"enemies"
19
- ]
20
-
21
- # Need list because the order matters
22
- KEYWORD_TO_NEGATED_PLURAL = [
23
- (" broken pipe.", ""), # If not the first phrase
24
- ("broken pipe. ", ""), # If the first phrase (after removing all others)
25
- (" broken cannon.", ""), # If not the first phrase
26
- ("broken cannon. ", ""), # If the first phrase (after removing all others)
27
- ("pipe", "pipes"),
28
- ("cannon", "cannons"),
29
- ("platform", "platforms"),
30
- ("tower", "towers"),
31
- ("staircase", "staircases"),
32
- ("enem", "enemies"),
33
- ("rectangular", "rectangular block clusters"),
34
- ("irregular", "irregular block clusters"),
35
- ("coin line", "coin lines"),
36
- ("coin.", "coins."), # Need period to avoid matching "coin line"
37
- ("question block", "question blocks"),
38
- ("loose block", "loose blocks")
39
- ]
40
-
41
- BROKEN_TOPICS = 2 # Number of topics that are considered "broken" (e.g., "broken pipe", "broken cannon")
42
-
43
- # Plural normalization map (irregulars)
44
- PLURAL_EXCEPTIONS = {
45
- "enemies": "enemy",
46
- }
47
-
48
- def normalize_plural(phrase):
49
- # Normalize known irregular plurals
50
- for plural, singular in PLURAL_EXCEPTIONS.items():
51
- phrase = phrase.replace(plural, singular)
52
-
53
- # Normalize regular plurals (basic "s" endings)
54
- words = phrase.split()
55
- normalized_words = []
56
- for word in words:
57
- if word.endswith('s') and not word.endswith('ss'): # avoid "class", "boss"
58
- singular = word[:-1]
59
- normalized_words.append(singular)
60
- else:
61
- normalized_words.append(word)
62
- return ' '.join(normalized_words)
63
-
64
- def extract_phrases(caption, debug=False):
65
- phrases = [phrase.strip() for phrase in caption.split('.') if phrase.strip()]
66
- topic_to_phrase = {}
67
- already_matched_phrases = set() # Track phrases that have been matched
68
-
69
- for topic in TOPIC_KEYWORDS:
70
- matching_phrases = []
71
-
72
- for p in phrases:
73
- # Only consider phrases that haven't been matched to longer topics
74
- if topic in p and p not in already_matched_phrases:
75
- matching_phrases.append(p)
76
-
77
- if matching_phrases:
78
- # Filter out "no ..." phrases as equivalent to absence
79
- phrase = matching_phrases[0]
80
- if phrase.lower().startswith("no "):
81
- topic_to_phrase[topic] = None
82
- if debug:
83
- print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
84
- else:
85
- topic_to_phrase[topic] = phrase
86
- already_matched_phrases.add(phrase) # Mark this phrase as matched
87
- if debug:
88
- print(f"[Extract] Topic '{topic}': found phrase '{phrase}'")
89
- else:
90
- topic_to_phrase[topic] = None
91
- if debug:
92
- print(f"[Extract] Topic '{topic}': no phrase found")
93
-
94
- return topic_to_phrase
95
-
96
- def quantity_score(phrase1, phrase2, debug=False):
97
- def find_quantity(phrase):
98
- for term in QUANTITY_TERMS:
99
- if term in phrase:
100
- return term
101
- return None
102
-
103
- qty1 = find_quantity(phrase1)
104
- qty2 = find_quantity(phrase2)
105
-
106
- if debug:
107
- print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")
108
-
109
- if qty1 and qty2:
110
- idx1 = QUANTITY_TERMS.index(qty1)
111
- idx2 = QUANTITY_TERMS.index(qty2)
112
- diff = abs(idx1 - idx2)
113
- max_diff = len(QUANTITY_TERMS) - 1
114
- score = 1.0 - (diff / max_diff)
115
- if debug:
116
- print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
117
- return score
118
- if debug:
119
- print("[Quantity] At least one quantity missing, assigning partial score 0.1")
120
- return 0.1
121
-
122
- def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
123
- correct_phrases = extract_phrases(correct_caption, debug=debug)
124
- generated_phrases = extract_phrases(generated_caption, debug=debug)
125
-
126
- total_score = 0.0
127
- num_topics = len(TOPIC_KEYWORDS)
128
-
129
- exact_matches = []
130
- partial_matches = []
131
- excess_phrases = []
132
-
133
- if debug:
134
- print("\n--- Starting Topic Comparison ---\n")
135
-
136
- for topic in TOPIC_KEYWORDS:
137
- correct = correct_phrases[topic]
138
- generated = generated_phrases[topic]
139
-
140
- if debug:
141
- print(f"[Topic: {topic}] Correct: {correct} | Generated: {generated}")
142
-
143
- if correct is None and generated is None:
144
- total_score += 1.0
145
- if debug:
146
- print(f"[Topic: {topic}] Both None — full score: 1.0\n")
147
- elif correct is None or generated is None:
148
- total_score += -1.0
149
- if generated is not None: # Considered an excess phrase
150
- excess_phrases.append(generated)
151
- if debug:
152
- print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
153
- else:
154
- # Normalize pluralization before comparison
155
- norm_correct = normalize_plural(correct)
156
- norm_generated = normalize_plural(generated)
157
-
158
- if debug:
159
- print(f"[Topic: {topic}] Normalized: Correct: '{norm_correct}' | Generated: '{norm_generated}'")
160
-
161
- if norm_correct == norm_generated:
162
- total_score += 1.0
163
- exact_matches.append(generated)
164
- if debug:
165
- print(f"[Topic: {topic}] Exact match — score: 1.0\n")
166
- elif any(term in norm_correct for term in QUANTITY_TERMS) and any(term in norm_generated for term in QUANTITY_TERMS):
167
- qty_score = quantity_score(norm_correct, norm_generated, debug=debug)
168
- total_score += qty_score
169
- partial_matches.append(generated)
170
- if debug:
171
- print(f"[Topic: {topic}] Quantity-based partial score: {qty_score:.2f}\n")
172
- else:
173
- total_score += 0.1
174
- partial_matches.append(generated)
175
- if debug:
176
- print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")
177
-
178
- if debug:
179
- print(f"[Topic: {topic}] Current total score: {total_score:.4f}\n")
180
-
181
- if debug:
182
- print("total_score before normalization:", total_score)
183
- print(f"Number of topics: {num_topics}")
184
-
185
- final_score = total_score / num_topics
186
- if debug:
187
- print(f"--- Final score: {final_score:.4f} ---\n")
188
-
189
- if return_matches:
190
- return final_score, exact_matches, partial_matches, excess_phrases
191
-
192
- return final_score
193
-
194
- def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
195
- """
196
- Process a scene by partitioning it into segments, assigning captions, and computing comparison scores.
197
-
198
- Args:
199
- scene (list): The scene to process, represented as a 2D list.
200
- segment_width (int): The width of each segment.
201
- prompt (str): The prompt to compare captions against.
202
- id_to_char (dict): Mapping from tile IDs to characters.
203
- char_to_id (dict): Mapping from characters to tile IDs.
204
- tile_descriptors (dict): Descriptions of individual tile types.
205
- describe_locations (bool): Whether to include location descriptions in captions.
206
- describe_absence (bool): Whether to indicate absence of items in captions.
207
- verbose (bool): If True, print captions and scores for each segment.
208
-
209
- Returns:
210
- tuple: A tuple containing the average comparison score, captions for each segment, and scores for each segment.
211
- """
212
- # Partition the scene into segments of the specified width
213
- segments = [
214
- [row[i:i+segment_width] for row in scene] # Properly slice each row of the scene
215
- for i in range(0, len(scene[0]), segment_width)
216
- ]
217
-
218
- # Assign captions and compute scores for each segment
219
- segment_scores = []
220
- segment_captions = []
221
- for idx, segment in enumerate(segments):
222
- segment_caption = assign_caption(segment, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence)
223
- segment_score = compare_captions(prompt, segment_caption)
224
- segment_scores.append(segment_score)
225
- segment_captions.append(segment_caption)
226
-
227
- if verbose:
228
- print(f"Segment {idx + 1} caption: {segment_caption}")
229
- print(f"Segment {idx + 1} comparison score: {segment_score}")
230
-
231
- # Compute the average comparison score
232
- average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0
233
-
234
- if verbose:
235
- print(f"Average comparison score across all segments: {average_score}")
236
-
237
- return average_score, segment_captions, segment_scores
238
-
239
- if __name__ == '__main__':
240
-
241
- ref = "floor with one gap. two enemies. one platform. one tower."
242
- gen = "giant gap with one chunk of floor. two enemies. one platform. one tower."
243
-
244
- score = compare_captions(ref, gen, debug=True)
245
- print(f"Should be: {ref}")
246
- print(f" but was: {gen}")
247
- print(f"Score: {score}")