File size: 4,616 Bytes
8b4d6a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Annotation QA Environment — Type-Safe Models.

Defines the API contract for the Annotation QA Environment:
- AnnotationQAAction: What corrections the agent can make
- AnnotationQAObservation: What the agent sees (scene + annotations)
- AnnotationQAState: Episode metadata

The agent reviews intentionally-flawed annotations on synthetic scenes
and must fix bounding boxes, correct class labels, add missing annotations,
or remove spurious ones.
"""

from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field


# ──────────────────────────────────────────────
# Annotation data structure
# ──────────────────────────────────────────────

class Annotation(BaseModel):
    """One labelled region: an identifier, a bounding box and a class label.

    All three fields are required; none of them has a default value.
    """

    # Integer identifier of this annotation.
    id: int

    # The list must hold exactly four floats (min_length == max_length == 4).
    # The 0.0-1.0 range mentioned in the description is NOT checked here.
    bbox: List[float] = Field(
        ...,
        min_length=4,
        max_length=4,
        description="Bounding box as [x, y, w, h] normalized to 0.0–1.0",
    )

    # Free-form string; this model does not restrict it to a fixed label set.
    class_label: str = Field(
        ...,
        description="Object class label, e.g. 'car', 'person'",
    )


# ──────────────────────────────────────────────
# Action
# ──────────────────────────────────────────────

class AnnotationQAAction(BaseModel):
    """
    A single correction request issued by the agent.

    Only ``action_type`` is mandatory at the schema level. The remaining
    fields are expected according to the chosen action:

    - "adjust_bbox":       annotation_id + new_bbox
    - "change_class":      annotation_id + new_class
    - "add_annotation":    new_bbox + new_class
    - "remove_annotation": annotation_id
    - "submit":            nothing else (finalizes the episode)

    NOTE: this class declares no validator for those combinations, so an
    instance with a missing companion field still constructs successfully.
    """

    # Closed set of operations; any other string fails validation.
    action_type: Literal[
        "adjust_bbox",
        "change_class",
        "add_annotation",
        "remove_annotation",
        "submit",
    ]

    # Target annotation; defaults to None when the action has no target.
    annotation_id: Optional[int] = Field(
        default=None,
        description="ID of the annotation to modify",
    )

    # When supplied, the list must contain exactly four floats.
    new_bbox: Optional[List[float]] = Field(
        default=None,
        min_length=4,
        max_length=4,
        description="New bounding box [x, y, w, h] in 0.0–1.0",
    )

    # Replacement (or initial) label; defaults to None.
    new_class: Optional[str] = Field(
        default=None,
        description="New class label",
    )

    # Arbitrary extra key/value data; each instance gets its own empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)


# ──────────────────────────────────────────────
# Observation
# ──────────────────────────────────────────────

class AnnotationQAObservation(BaseModel):
    """
    Snapshot handed to the agent after each step.

    Bundles the scene description, the annotations under review (which may
    contain errors), the permitted class labels, progress counters and
    feedback on the last action. Every field has a default, so an empty
    observation can be constructed with no arguments.
    """

    # --- Episode status -----------------------------------------------
    done: bool = Field(default=False)
    reward: Optional[float] = Field(default=None)

    # --- Scene information --------------------------------------------
    scene_description: str = Field(
        default="",
        description="Natural-language description of the scene",
    )
    # Loosely typed: each entry is a plain dict with no schema enforced here.
    scene_objects: List[Dict[str, Any]] = Field(
        description="Ground-truth object list with positions (visible to agent as scene context)",
        default_factory=list,
    )

    # --- Annotations under review (may contain errors) ----------------
    annotations: List[Annotation] = Field(
        description="Current annotations the agent should review/fix",
        default_factory=list,
    )

    # --- Task context -------------------------------------------------
    available_classes: List[str] = Field(
        description="Valid class labels for this task",
        default_factory=list,
    )
    task_id: str = Field(default="")
    task_description: str = Field(default="")

    # --- Progress counters --------------------------------------------
    corrections_made: int = Field(default=0)
    step_count: int = Field(default=0)
    max_steps: int = Field(default=20)

    # --- Feedback -----------------------------------------------------
    message: str = Field(default="")
    # None by default; presumably set when the previous action was
    # rejected -- confirm against the environment implementation.
    last_action_error: Optional[str] = Field(default=None)


# ──────────────────────────────────────────────
# State
# ──────────────────────────────────────────────

class AnnotationQAState(BaseModel):
    """Episode metadata — internal state tracked by the environment.

    Every field has a default, so the model can be built with no arguments.
    """

    # Episode identifier; defaults to None.
    episode_id: Optional[str] = Field(default=None)

    # Step counter, starting at zero.
    step_count: int = Field(default=0)

    # Task and sample identifiers; both default to the empty string.
    task_id: str = Field(default="")
    sample_id: str = Field(default="")

    # Quality scores; presumably measured at episode start and at the
    # current step -- confirm against the environment implementation.
    initial_quality: float = Field(default=0.0)
    current_quality: float = Field(default=0.0)

    # Correction counter, starting at zero.
    corrections_made: int = Field(default=0)