| | from create_ascii_captions import assign_caption
|
| |
|
| |
|
# Quantity words ordered from smallest to largest; quantity_score() uses the
# distance between list positions to grade how close two quantities are.
QUANTITY_TERMS = [
    "one",
    "two",
    "a few",
    "several",
    "many",
]

# Substrings that identify the topic of a caption phrase. Order matters:
# extract_phrases() walks this list front to back and a phrase claimed by one
# topic is not offered to later ones, so the more specific keyword
# (e.g. "broken pipe") has to come before the general one ("pipe").
TOPIC_KEYWORDS = [
    # level boundaries
    "floor",
    "ceiling",
    # pipes
    "broken pipe",
    "upside down pipe",
    "pipe",
    # coins
    "coin line",
    "coin",
    # raised structures
    "platform",
    "tower",
    # cannons
    "broken cannon",
    "cannon",
    # staircases
    "ascending staircase",
    "descending staircase",
    # block clusters
    "rectangular",
    "irregular",
    # single blocks
    "question block",
    "loose block",
    # "enem" is a stem shared by "enemy" and "enemies"
    "enem",
]

# (keyword, replacement) pairs. The first four entries map "broken ..."
# fragments (with their adjacent period/space) to the empty string; the
# remaining entries map a keyword to its plural form.
# NOTE(review): not referenced by any code visible in this file; presumably
# used by another module to build negated captions -- confirm.
KEYWORD_TO_NEGATED_PLURAL = [
    (" broken pipe.", ""),
    ("broken pipe. ", ""),
    (" broken cannon.", ""),
    ("broken cannon. ", ""),
    ("pipe", "pipes"),
    ("cannon", "cannons"),
    ("platform", "platforms"),
    ("tower", "towers"),
    ("staircase", "staircases"),
    ("enem", "enemies"),
    ("rectangular", "rectangular block clusters"),
    ("irregular", "irregular block clusters"),
    ("coin line", "coin lines"),
    ("coin.", "coins."),
    ("question block", "question blocks"),
    ("loose block", "loose blocks"),
]

# NOTE(review): not referenced by any code visible in this file; presumably
# the number of "broken ..." topics (broken pipe, broken cannon) -- confirm.
BROKEN_TOPICS = 2

# Irregular plurals that cannot be handled by stripping a trailing "s";
# normalize_plural() applies these replacements first.
PLURAL_EXCEPTIONS = {
    "enemies": "enemy",
}
|
| |
|
def normalize_plural(phrase):
    """Return *phrase* with plural words reduced to a singular form.

    Irregular plurals listed in PLURAL_EXCEPTIONS are replaced first (as plain
    substring replacements). Afterwards the phrase is split on whitespace and
    every word that ends in a single "s" (but not "ss") loses that final
    letter. Words are re-joined with single spaces, so runs of whitespace in
    the input are collapsed.

    Args:
        phrase (str): The caption phrase to normalize.

    Returns:
        str: The normalized phrase.
    """
    for irregular, base_form in PLURAL_EXCEPTIONS.items():
        phrase = phrase.replace(irregular, base_form)

    def _singularize(token):
        # Keep words such as "glass" and anything not ending in "s" untouched.
        if token.endswith('ss') or not token.endswith('s'):
            return token
        return token[:-1]

    return ' '.join(_singularize(token) for token in phrase.split())
|
| |
|
def extract_phrases(caption, debug=False):
    """Map every topic keyword to the caption phrase that mentions it.

    The caption is split on "." into stripped, non-empty phrases. Topics are
    visited in TOPIC_KEYWORDS order; each topic takes the first phrase that
    contains the keyword and has not already been claimed by an earlier
    topic. A phrase beginning with "no " (case-insensitive) counts as an
    explicit absence: the topic maps to None and the phrase stays unclaimed.

    Args:
        caption (str): Full caption text, phrases separated by periods.
        debug (bool): If True, print the decision taken for every topic.

    Returns:
        dict: One entry per topic keyword, holding the matched phrase or
        None when the topic is absent or negated.
    """
    sentences = []
    for fragment in caption.split('.'):
        cleaned = fragment.strip()
        if cleaned:
            sentences.append(cleaned)

    claimed = set()
    result = {}

    for topic in TOPIC_KEYWORDS:
        # First unclaimed phrase mentioning the topic, if any.
        candidate = next(
            (s for s in sentences if topic in s and s not in claimed),
            None,
        )

        if candidate is None:
            result[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': no phrase found")
            continue

        if candidate.lower().startswith("no "):
            result[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
            continue

        result[topic] = candidate
        claimed.add(candidate)
        if debug:
            print(f"[Extract] Topic '{topic}': found phrase '{candidate}'")

    return result
|
| |
|
def quantity_score(phrase1, phrase2, debug=False):
    """Score how close the quantity words of two phrases are.

    Each phrase is searched for the first entry of QUANTITY_TERMS (in list
    order) that occurs in it as a whole word or whole multi-word term. When
    both phrases contain a quantity, the score falls linearly from 1.0 (same
    term) to 0.0 (terms at opposite ends of QUANTITY_TERMS). When either
    phrase has no quantity term, a fixed partial score of 0.1 is returned.

    Args:
        phrase1 (str): First phrase.
        phrase2 (str): Second phrase.
        debug (bool): If True, print the intermediate values.

    Returns:
        float: Similarity score between 0.0 and 1.0.
    """
    # Local import keeps this function self-contained within the module.
    import re

    def find_quantity(phrase):
        for term in QUANTITY_TERMS:
            # Word boundaries stop "one" from matching inside words such as
            # "stone" or "none", which plain substring search would accept.
            if re.search(rf"\b{re.escape(term)}\b", phrase):
                return term
        return None

    qty1 = find_quantity(phrase1)
    qty2 = find_quantity(phrase2)

    if debug:
        print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")

    if qty1 is not None and qty2 is not None:
        idx1 = QUANTITY_TERMS.index(qty1)
        idx2 = QUANTITY_TERMS.index(qty2)
        diff = abs(idx1 - idx2)
        max_diff = len(QUANTITY_TERMS) - 1
        # With a single quantity term both phrases necessarily agree; avoid
        # dividing by zero in that degenerate configuration.
        score = 1.0 if max_diff == 0 else 1.0 - (diff / max_diff)
        if debug:
            print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
        return score

    if debug:
        print("[Quantity] At least one quantity missing, assigning partial score 0.1")
    return 0.1
|
| |
|
def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
    """Score a generated caption against the correct caption, topic by topic.

    Both captions are broken into per-topic phrases with extract_phrases().
    Every topic in TOPIC_KEYWORDS then contributes:

    * +1.0 when neither caption mentions the topic;
    * -1.0 when exactly one caption mentions it (a phrase present only in the
      generated caption is recorded as excess);
    * +1.0 when both phrases are equal after normalize_plural();
    * quantity_score() when both normalized phrases contain a quantity term;
    * +0.1 otherwise.

    The sum is divided by the number of topics.

    Args:
        correct_caption (str): Reference caption.
        generated_caption (str): Caption being evaluated.
        debug (bool): If True, print a trace of the scoring.
        return_matches (bool): If True, also return the matched phrase lists.

    Returns:
        float, or tuple: The final score; when return_matches is True, a
        tuple (final_score, exact_matches, partial_matches, excess_phrases).
    """
    def mentions_quantity(text):
        return any(term in text for term in QUANTITY_TERMS)

    reference_by_topic = extract_phrases(correct_caption, debug=debug)
    candidate_by_topic = extract_phrases(generated_caption, debug=debug)

    topic_count = len(TOPIC_KEYWORDS)
    running_total = 0.0
    exact_matches, partial_matches, excess_phrases = [], [], []

    if debug:
        print("\n--- Starting Topic Comparison ---\n")

    for topic in TOPIC_KEYWORDS:
        expected = reference_by_topic[topic]
        produced = candidate_by_topic[topic]

        if debug:
            print(f"[Topic: {topic}] Correct: {expected} | Generated: {produced}")

        absent = (expected is None, produced is None)

        if all(absent):
            running_total += 1.0
            if debug:
                print(f"[Topic: {topic}] Both None — full score: 1.0\n")
        elif any(absent):
            running_total -= 1.0
            if produced is not None:
                excess_phrases.append(produced)
            if debug:
                print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
        else:
            expected_norm = normalize_plural(expected)
            produced_norm = normalize_plural(produced)

            if debug:
                print(f"[Topic: {topic}] Normalized: Correct: '{expected_norm}' | Generated: '{produced_norm}'")

            if expected_norm == produced_norm:
                running_total += 1.0
                exact_matches.append(produced)
                if debug:
                    print(f"[Topic: {topic}] Exact match — score: 1.0\n")
            elif mentions_quantity(expected_norm) and mentions_quantity(produced_norm):
                qty_score = quantity_score(expected_norm, produced_norm, debug=debug)
                running_total += qty_score
                partial_matches.append(produced)
                if debug:
                    print(f"[Topic: {topic}] Quantity-based partial score: {qty_score:.2f}\n")
            else:
                running_total += 0.1
                partial_matches.append(produced)
                if debug:
                    print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")

        if debug:
            print(f"[Topic: {topic}] Current total score: {running_total:.4f}\n")

    if debug:
        print("total_score before normalization:", running_total)
        print(f"Number of topics: {topic_count}")

    final_score = running_total / topic_count
    if debug:
        print(f"--- Final score: {final_score:.4f} ---\n")

    if not return_matches:
        return final_score

    return final_score, exact_matches, partial_matches, excess_phrases
|
| |
|
def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
    """
    Process a scene by partitioning it into segments, assigning captions, and computing comparison scores.

    The scene is cut into vertical slices of ``segment_width`` columns, using
    the length of the first row as the scene width. If that width is not a
    multiple of ``segment_width`` the last segment is narrower. Each segment
    is captioned with ``assign_caption`` and scored against ``prompt`` with
    ``compare_captions``.

    Args:
        scene (list): The scene to process, represented as a 2D list.
        segment_width (int): The width of each segment; must be positive.
        prompt (str): The prompt to compare captions against.
        id_to_char (dict): Mapping from tile IDs to characters.
        char_to_id (dict): Mapping from characters to tile IDs.
        tile_descriptors (dict): Descriptions of individual tile types.
        describe_locations (bool): Whether to include location descriptions in captions.
        describe_absence (bool): Whether to indicate absence of items in captions.
        verbose (bool): If True, print captions and scores for each segment.

    Returns:
        tuple: A tuple containing the average comparison score, captions for each segment, and scores for each segment.
        An empty scene (no rows, or rows with no columns) yields ``(0, [], [])``.

    Raises:
        ValueError: If ``segment_width`` is zero or negative.
    """
    if segment_width <= 0:
        # A zero step makes range() fail with an unrelated message and a
        # negative one silently produces no segments; reject both up front.
        raise ValueError(f"segment_width must be a positive integer, got {segment_width!r}")

    # A scene without rows has no first row to measure; treat its width as 0
    # so it falls through to the "no segments" result instead of raising.
    scene_width = len(scene[0]) if len(scene) > 0 else 0

    segments = [
        [row[start:start + segment_width] for row in scene]
        for start in range(0, scene_width, segment_width)
    ]

    segment_scores = []
    segment_captions = []
    for idx, segment in enumerate(segments):
        segment_caption = assign_caption(segment, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence)
        segment_score = compare_captions(prompt, segment_caption)
        segment_scores.append(segment_score)
        segment_captions.append(segment_caption)

        if verbose:
            print(f"Segment {idx + 1} caption: {segment_caption}")
            print(f"Segment {idx + 1} comparison score: {segment_score}")

    average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0

    if verbose:
        print(f"Average comparison score across all segments: {average_score}")

    return average_score, segment_captions, segment_scores
|
| |
|
if __name__ == '__main__':
    # Manual smoke check: score a caption whose first phrase is reworded
    # against the reference, printing the full debug trace.
    reference_caption = "floor with one gap. two enemies. one platform. one tower."
    generated_caption = "giant gap with one chunk of floor. two enemies. one platform. one tower."

    demo_score = compare_captions(reference_caption, generated_caption, debug=True)

    print(
        f"Should be: {reference_caption}",
        f" but was: {generated_caption}",
        f"Score: {demo_score}",
        sep="\n",
    )
|
| |
|