# Somin-Aggarwal's picture
# Upload 3 files
# 01a014b verified
"""
Grading utilities for the Annotation QA Environment.
Provides deterministic scoring (0.0-1.0) based on:
- IoU (Intersection over Union) of bounding boxes
- Class label accuracy
- Precision (penalizes spurious annotations)
- Recall (penalizes missed annotations)
Predicted and gold annotations are paired by their shared ``id`` field in the
scoring functions below (no Hungarian / optimal-assignment matching is performed there).
"""
from typing import Dict, List, Tuple
def compute_iou(box_a: List[float], box_b: List[float]) -> float:
    """
    Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Each box is given as ``[x, y, w, h]``; the opposite corner is derived as
    ``(x + w, y + h)``. Values are expected to lie in 0.0–1.0.

    Returns 0.0 when the union area is smaller than 1e-8, which avoids
    dividing by (almost) zero for degenerate boxes.
    """
    a_min_x, a_min_y, a_width, a_height = box_a
    b_min_x, b_min_y, b_width, b_height = box_b

    # Far corners of each box.
    a_max_x, a_max_y = a_min_x + a_width, a_min_y + a_height
    b_max_x, b_max_y = b_min_x + b_width, b_min_y + b_height

    # Overlap extent along each axis; clamped to 0 when the boxes are disjoint.
    overlap_w = max(0, min(a_max_x, b_max_x) - max(a_min_x, b_min_x))
    overlap_h = max(0, min(a_max_y, b_max_y) - max(a_min_y, b_min_y))
    overlap_area = overlap_w * overlap_h

    # Inclusion-exclusion: the overlap is counted once, not twice.
    total_area = a_width * a_height + b_width * b_height - overlap_area

    return 0.0 if total_area < 1e-8 else overlap_area / total_area
def compute_annotation_quality(
    annotations: List[Dict],
    gold_annotations: List[Dict],
) -> float:
    """
    Compute specific Semantic VLM visual QA testing metrics (0.0-1.0).

    A prediction whose ``class_label`` starts with ``"missing_"`` is treated as
    a "missing object" flag; every other prediction is treated as a box.

    Graded on:
    - Spurious Precision (35%): fraction of predicted boxes whose ``id`` exists
      in the gold set (0.0 when there are no predicted boxes).
    - Class Match Accuracy (35%): among boxes whose ``id`` exists in gold,
      fraction whose ``class_label`` equals the gold label (0.0 when none match).
    - Missing Flag Recall (30%): gold class instances not covered by a matched
      box's predicted label count as "missing"; this is the fraction of them
      covered by a ``missing_<class>`` flag (1.0 when nothing is missing).

    Args:
        annotations: Predicted annotation dicts. ``class_label`` defaults to
            ``""`` when absent. A prediction without an ``id`` key is scored as
            unmatched (spurious if it is a box) instead of raising ``KeyError``.
        gold_annotations: Gold annotation dicts; each must carry an ``id``.

    Returns:
        The weighted score clamped to [0.0, 1.0]. With an empty gold set the
        result is 1.0 for empty predictions and 0.5 otherwise.
    """
    from collections import Counter

    if not gold_annotations:
        return 1.0 if not annotations else 0.5

    flag_prefix = "missing_"
    gold_map = {g["id"]: g for g in gold_annotations}

    # Split predictions into real boxes and "missing object" flags.
    boxes = []
    flagged_classes = []
    for ann in annotations:
        label = ann.get("class_label", "")
        if label.startswith(flag_prefix):
            flagged_classes.append(label[len(flag_prefix):])
        else:
            boxes.append(ann)

    # Boxes that correspond to a gold annotation. Entries lacking an "id"
    # cannot match anything, so they are simply left out.
    matched = [a for a in boxes if "id" in a and a["id"] in gold_map]

    # 1. Spurious Precision
    precision = len(matched) / len(boxes) if boxes else 0.0

    # 2. Class Match Accuracy for valid boxes
    if matched:
        correct = sum(
            1
            for a in matched
            if a.get("class_label", "") == gold_map[a["id"]].get("class_label", "")
        )
        class_acc = correct / len(matched)
    else:
        class_acc = 0.0

    # 3. Missing Object Flag Recall
    expected_counts = Counter(g.get("class_label", "") for g in gold_annotations)
    present_counts = Counter(a.get("class_label", "") for a in matched)
    # Counter subtraction keeps only positive counts: per-class shortfall.
    missing_counts = expected_counts - present_counts
    total_missing = sum(missing_counts.values())
    if total_missing == 0:
        missing_acc = 1.0
    else:
        # Counter intersection takes the per-class minimum, i.e. each flag can
        # account for at most one missing instance of its class.
        caught = sum((missing_counts & Counter(flagged_classes)).values())
        missing_acc = caught / total_missing

    quality = 0.35 * class_acc + 0.35 * precision + 0.30 * missing_acc
    return max(0.0, min(1.0, quality))
def grade_episode(
    initial_annotations: List[Dict],
    final_annotations: List[Dict],
    gold_annotations: List[Dict],
) -> float:
    """
    Compute the episode grade (0.0–1.0).

    The grade is the share of the available quality headroom
    (``1.0 - initial quality``) that the final annotations recovered, clamped
    to [0.0, 1.0]. When the headroom is below 0.01 there is essentially
    nothing to improve: the grade is 1.0 if quality did not drop by more than
    0.01, and 0.5 otherwise.
    """
    start_quality = compute_annotation_quality(initial_annotations, gold_annotations)
    end_quality = compute_annotation_quality(final_annotations, gold_annotations)
    headroom = 1.0 - start_quality

    if headroom < 0.01:
        # Already (near-)perfect at the start: only check for regression.
        if end_quality >= start_quality - 0.01:
            return 1.0
        return 0.5

    recovered_fraction = (end_quality - start_quality) / headroom
    return max(0.0, min(1.0, recovered_fraction))
def compute_step_reward(
    old_annotations: List[Dict],
    new_annotations: List[Dict],
    gold_annotations: List[Dict],
    action_type: str,
) -> float:
    """
    Compute dense per-step reward based on quality delta.

    The reward is twice the change in annotation quality, minus a fixed 0.01
    step penalty, plus a 0.05 bonus when ``action_type`` is ``"submit"``.
    The result is rounded to 4 decimal places.
    """
    quality_before = compute_annotation_quality(old_annotations, gold_annotations)
    quality_after = compute_annotation_quality(new_annotations, gold_annotations)

    # Scaled quality improvement, less the per-step cost.
    step_reward = (quality_after - quality_before) * 2.0 - 0.01

    if action_type == "submit":
        step_reward = step_reward + 0.05

    return round(step_reward, 4)