from create_ascii_captions import assign_caption
# Quantity order for scoring partial matches: adjacent terms count as
# "closer" quantities, so the index distance drives the partial score.
QUANTITY_TERMS = ["one", "two", "a few", "several", "many"]
# Topics to compare. ORDER MATTERS: more specific phrases must precede their
# substrings (e.g. "broken pipe" before "pipe", "coin line" before "coin") so
# extract_phrases claims the most specific topic first.
TOPIC_KEYWORDS = [
#"giant gap", # I think all gaps are subsumed by the floor topic
"floor", "ceiling",
"broken pipe", "upside down pipe", "pipe",
"coin line", "coin",
"platform", "tower", #"wall",
"broken cannon", "cannon",
"ascending staircase", "descending staircase",
"rectangular",
"irregular",
"question block", "loose block",
"enem" # catch "enemy"/"enemies"
]
# Need list because the order matters: substitutions are applied in sequence,
# so the "broken ..." removals must run before the bare "pipe"/"cannon" rules.
KEYWORD_TO_NEGATED_PLURAL = [
(" broken pipe.", ""), # If not the first phrase
("broken pipe. ", ""), # If the first phrase (after removing all others)
(" broken cannon.", ""), # If not the first phrase
("broken cannon. ", ""), # If the first phrase (after removing all others)
("pipe", "pipes"),
("cannon", "cannons"),
("platform", "platforms"),
("tower", "towers"),
("staircase", "staircases"),
("enem", "enemies"),
("rectangular", "rectangular block clusters"),
("irregular", "irregular block clusters"),
("coin line", "coin lines"),
("coin.", "coins."), # Need period to avoid matching "coin line"
("question block", "question blocks"),
("loose block", "loose blocks")
]
BROKEN_TOPICS = 2 # Number of topics that are considered "broken" (e.g., "broken pipe", "broken cannon")
# Plural normalization map (irregulars that a simple "strip s" cannot handle)
PLURAL_EXCEPTIONS = {
    "enemies": "enemy",
}
def normalize_plural(phrase):
    """Reduce plural word forms in a phrase to singulars for comparison.

    Irregular plurals are substituted first via PLURAL_EXCEPTIONS; then any
    word ending in a single "s" has that "s" stripped (words ending in "ss",
    such as "boss", are left alone).
    """
    # Substitute known irregular plurals before the generic rule runs
    for irregular, base in PLURAL_EXCEPTIONS.items():
        phrase = phrase.replace(irregular, base)
    # Generic rule: drop a trailing "s", but never from "-ss" words
    singular_words = [
        word[:-1] if word.endswith('s') and not word.endswith('ss') else word
        for word in phrase.split()
    ]
    return ' '.join(singular_words)
def extract_phrases(caption, debug=False):
    """Map each topic keyword to the first caption phrase that mentions it.

    Phrases are the period-separated pieces of the caption. Topics are scanned
    in TOPIC_KEYWORDS order, and a phrase claimed by an earlier (more
    specific) topic cannot be claimed again. A phrase starting with "no " is
    treated as absence, i.e. mapped to None.
    """
    candidate_phrases = [piece.strip() for piece in caption.split('.') if piece.strip()]
    claimed = set()  # phrases already assigned to an earlier topic
    topic_to_phrase = {}
    for topic in TOPIC_KEYWORDS:
        hits = [p for p in candidate_phrases if topic in p and p not in claimed]
        if not hits:
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': no phrase found")
            continue
        first_hit = hits[0]
        if first_hit.lower().startswith("no "):
            # "no ..." phrases are equivalent to the topic being absent
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
        else:
            topic_to_phrase[topic] = first_hit
            claimed.add(first_hit)  # prevent re-matching by a later topic
            if debug:
                print(f"[Extract] Topic '{topic}': found phrase '{first_hit}'")
    return topic_to_phrase
def quantity_score(phrase1, phrase2, debug=False):
    """Score how closely the quantity words of two phrases agree.

    Returns 1.0 for identical quantities, decreasing linearly with the
    distance between the terms' positions in QUANTITY_TERMS, and a flat 0.1
    when either phrase lacks a recognizable quantity term.
    """
    def first_quantity(text):
        # First QUANTITY_TERMS entry that appears in the text, else None
        return next((term for term in QUANTITY_TERMS if term in text), None)
    qty1 = first_quantity(phrase1)
    qty2 = first_quantity(phrase2)
    if debug:
        print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")
    if qty1 and qty2:
        idx1 = QUANTITY_TERMS.index(qty1)
        idx2 = QUANTITY_TERMS.index(qty2)
        diff = abs(idx1 - idx2)
        max_diff = len(QUANTITY_TERMS) - 1
        score = 1.0 - (diff / max_diff)
        if debug:
            print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
        return score
    if debug:
        print("[Quantity] At least one quantity missing, assigning partial score 0.1")
    return 0.1
def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
    """Compare two captions topic by topic and return a normalized score.

    Scoring per topic: +1.0 when both captions agree (both absent, or
    identical after plural normalization); -1.0 when exactly one mentions the
    topic; otherwise a partial score (quantity-based when both phrases carry
    quantity terms, flat 0.1 otherwise). The total is divided by the number
    of topics. When return_matches is True, also returns the lists of exact
    matches, partial matches, and excess generated phrases.
    """
    reference = extract_phrases(correct_caption, debug=debug)
    candidate = extract_phrases(generated_caption, debug=debug)
    running_total = 0.0
    num_topics = len(TOPIC_KEYWORDS)
    exact_matches = []
    partial_matches = []
    excess_phrases = []
    if debug:
        print("\n--- Starting Topic Comparison ---\n")
    for topic in TOPIC_KEYWORDS:
        correct = reference[topic]
        generated = candidate[topic]
        if debug:
            print(f"[Topic: {topic}] Correct: {correct} | Generated: {generated}")
        if correct is None and generated is None:
            # Agreement on absence counts as a full match
            running_total += 1.0
            if debug:
                print(f"[Topic: {topic}] Both None — full score: 1.0\n")
        elif correct is None or generated is None:
            # Presence/absence disagreement is penalized
            running_total -= 1.0
            if generated is not None:  # Considered an excess phrase
                excess_phrases.append(generated)
            if debug:
                print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
        else:
            # Normalize pluralization before comparison
            norm_correct = normalize_plural(correct)
            norm_generated = normalize_plural(generated)
            if debug:
                print(f"[Topic: {topic}] Normalized: Correct: '{norm_correct}' | Generated: '{norm_generated}'")
            def has_quantity(text):
                return any(term in text for term in QUANTITY_TERMS)
            if norm_correct == norm_generated:
                running_total += 1.0
                exact_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Exact match — score: 1.0\n")
            elif has_quantity(norm_correct) and has_quantity(norm_generated):
                qty_score = quantity_score(norm_correct, norm_generated, debug=debug)
                running_total += qty_score
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Quantity-based partial score: {qty_score:.2f}\n")
            else:
                running_total += 0.1
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")
        if debug:
            print(f"[Topic: {topic}] Current total score: {running_total:.4f}\n")
    if debug:
        print("total_score before normalization:", running_total)
        print(f"Number of topics: {num_topics}")
    final_score = running_total / num_topics
    if debug:
        print(f"--- Final score: {final_score:.4f} ---\n")
    if return_matches:
        return final_score, exact_matches, partial_matches, excess_phrases
    return final_score
def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
    """
    Process a scene by partitioning it into segments, assigning captions, and computing comparison scores.
    Args:
        scene (list): The scene to process, represented as a 2D list.
        segment_width (int): The width of each segment.
        prompt (str): The prompt to compare captions against.
        id_to_char (dict): Mapping from tile IDs to characters.
        char_to_id (dict): Mapping from characters to tile IDs.
        tile_descriptors (dict): Descriptions of individual tile types.
        describe_locations (bool): Whether to include location descriptions in captions.
        describe_absence (bool): Whether to indicate absence of items in captions.
        verbose (bool): If True, print captions and scores for each segment.
    Returns:
        tuple: A tuple containing the average comparison score, captions for each segment, and scores for each segment.
        For an empty scene, returns (0, [], []).
    """
    # Fix: an empty scene previously raised IndexError on scene[0]
    if not scene or not scene[0]:
        return 0, [], []
    # Partition the scene into vertical slices of the specified width;
    # the last segment may be narrower if the width does not divide evenly
    segments = [
        [row[i:i+segment_width] for row in scene]  # Properly slice each row of the scene
        for i in range(0, len(scene[0]), segment_width)
    ]
    # Assign captions and compute scores for each segment
    segment_scores = []
    segment_captions = []
    for idx, segment in enumerate(segments):
        segment_caption = assign_caption(segment, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence)
        segment_score = compare_captions(prompt, segment_caption)
        segment_scores.append(segment_score)
        segment_captions.append(segment_caption)
        if verbose:
            print(f"Segment {idx + 1} caption: {segment_caption}")
            print(f"Segment {idx + 1} comparison score: {segment_score}")
    # Compute the average comparison score (guard kept for robustness)
    average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0
    if verbose:
        print(f"Average comparison score across all segments: {average_score}")
    return average_score, segment_captions, segment_scores
if __name__ == '__main__':
    # Quick manual sanity check: compare a reference caption against a
    # deliberately different generated caption with verbose debug output.
    ref = "floor with one gap. two enemies. one platform. one tower."
    gen = "giant gap with one chunk of floor. two enemies. one platform. one tower."
    similarity = compare_captions(ref, gen, debug=True)
    print(f"Should be: {ref}")
    print(f"  but was: {gen}")
    print(f"Score: {similarity}")