File size: 5,057 Bytes
01a014b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Grading utilities for the Annotation QA Environment.

Provides deterministic scoring (0.0-1.0) based on:
- IoU (Intersection over Union) of bounding boxes
- Class label accuracy
- Precision (penalizes spurious annotations)
- Recall (penalizes missed annotations)

Uses Hungarian matching to optimally pair predicted vs gold annotations.
"""

from collections import Counter
from typing import Dict, List, Tuple


def compute_iou(box_a: List[float], box_b: List[float]) -> float:
    """
    Return the Intersection over Union of two axis-aligned boxes.

    Each box is [x, y, w, h] with coordinates normalized to 0.0-1.0.
    Returns 0.0 when the union area is numerically zero.
    """
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b

    # Overlap extent along each axis, clamped to zero when the boxes
    # are disjoint on that axis.
    overlap_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    overlap_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    intersection = overlap_w * overlap_h

    union = aw * ah + bw * bh - intersection
    if union < 1e-8:
        # Degenerate (zero-area) boxes: avoid division by ~zero.
        return 0.0

    return intersection / union


def compute_annotation_quality(
    annotations: List[Dict],
    gold_annotations: List[Dict],
) -> float:
    """
    Score a set of annotations against the gold set (0.0-1.0).

    Weighted components:
    - Spurious precision (35%): fraction of non-flag predictions whose
      "id" exists in the gold set.
    - Class accuracy (35%): fraction of id-matched predictions whose
      class_label equals the gold label.
    - Missing-flag recall (30%): fraction of gold objects absent from the
      predictions that were flagged with a "missing_<class>" label.

    Annotations are matched to gold purely by their "id" field.
    """
    if not gold_annotations:
        # No gold objects: an empty prediction set is perfect; anything
        # else gets partial credit (spurious boxes, but nothing missed).
        return 1.0 if not annotations else 0.5

    gold_map = {a["id"]: a for a in gold_annotations}

    def _is_flag(ann: Dict) -> bool:
        # "missing_<class>" labels mark objects the annotator flagged as absent.
        return ann.get("class_label", "").startswith("missing_")

    # 1. Spurious precision: non-flag predictions must refer to real gold ids.
    predictions_valid = [a for a in annotations if not _is_flag(a)]
    matched = [a for a in predictions_valid if a["id"] in gold_map]
    precision = len(matched) / len(predictions_valid) if predictions_valid else 0.0

    # 2. Class accuracy over the id-matched predictions.
    if matched:
        correct = sum(
            1
            for a in matched
            if a.get("class_label", "") == gold_map[a["id"]].get("class_label", "")
        )
        class_acc = correct / len(matched)
    else:
        class_acc = 0.0

    # 3. Missing-flag recall. The per-class multiset difference between the
    # gold labels and the (id-matched, non-flag) predicted labels gives the
    # truly-missing instances; Counter subtraction drops non-positive counts.
    exp_counts = Counter(g.get("class_label", "") for g in gold_annotations)
    pres_counts = Counter(
        a.get("class_label", "")
        for a in annotations
        if a["id"] in gold_map and not _is_flag(a)
    )
    actual_missing = exp_counts - pres_counts

    total_missing = sum(actual_missing.values())
    if total_missing == 0:
        missing_acc = 1.0
    else:
        flagged_counts = Counter(
            a.get("class_label", "").replace("missing_", "", 1)
            for a in annotations
            if _is_flag(a)
        )
        # Each flag can account for at most one missing instance of its class.
        caught = sum(
            min(count, flagged_counts.get(cls, 0))
            for cls, count in actual_missing.items()
        )
        missing_acc = caught / total_missing

    quality = 0.35 * class_acc + 0.35 * precision + 0.30 * missing_acc
    return max(0.0, min(1.0, quality))


def grade_episode(
    initial_annotations: List[Dict],
    final_annotations: List[Dict],
    gold_annotations: List[Dict],
) -> float:
    """
    Grade an episode (0.0-1.0) as the fraction of the available quality
    headroom that the final annotations actually recovered.

    When the initial annotations are already (near-)perfect there is no
    headroom: full credit for not regressing, half credit otherwise.
    """
    before = compute_annotation_quality(initial_annotations, gold_annotations)
    after = compute_annotation_quality(final_annotations, gold_annotations)

    headroom = 1.0 - before
    if headroom < 0.01:
        # Nothing meaningful to fix; only penalize regressions.
        return 1.0 if after >= before - 0.01 else 0.5

    fraction_recovered = (after - before) / headroom
    return max(0.0, min(1.0, fraction_recovered))


def compute_step_reward(
    old_annotations: List[Dict],
    new_annotations: List[Dict],
    gold_annotations: List[Dict],
    action_type: str,
) -> float:
    """
    Dense per-step reward, rounded to 4 decimals: twice the quality delta,
    minus a flat per-step cost, plus a small bonus for submitting.
    """
    delta = (
        compute_annotation_quality(new_annotations, gold_annotations)
        - compute_annotation_quality(old_annotations, gold_annotations)
    )
    reward = 2.0 * delta - 0.01  # scaled improvement, flat step penalty
    if action_type == "submit":
        reward += 0.05  # encourage ending the episode with a submission
    return round(reward, 4)