Somin-Aggarwal committed
Commit efaa228 · verified · 1 Parent(s): 01a014b

Upload 8 files

Files changed (8)
  1. README.md +87 -6
  2. client.py +88 -0
  3. debug_overlay_test.jpg +0 -0
  4. inference.py +418 -0
  5. models.py +149 -0
  6. openenv.yaml +6 -0
  7. pyproject.toml +43 -0
  8. uv.lock +0 -0
README.md CHANGED
@@ -1,10 +1,91 @@
  ---
- title: AnnotationReviewer
- emoji: 🐠
- colorFrom: purple
- colorTo: pink
  sdk: docker
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Semantic Annotation QA Env
+ emoji: 🔍
+ colorFrom: blue
+ colorTo: indigo
  sdk: docker
+ app_port: 8000
  ---
+ # 🔍 Semantic Annotation QA Environment

+ An **OpenEnv** environment in which a Vision-Language Model (VLM) agent reviews and corrects intentionally flawed machine-learning annotations on **real COCO val2017 images**.
+
+ This environment simulates a critical **real-world task**: human-in-the-loop ML data QA and content cleaning. By having an agent actively audit and correct data labels, it tests a *valid domain* while serving as a pure evaluation bed for multimodal agent alignment.
+
+ ## 🎯 The Challenge & Novelty
+
+ Spatial bounding-box regression tasks test VLMs poorly because model tokenizers fragment contiguous pixel geometry. **We sidestep this.**
+
+ Instead of asking the model to regress bounding-box coordinates, we use a **"Set-of-Mark"** overlay: the environment renders the image with ID tags drawn directly on the visual feed, turning the VLM into a pure **Semantic Auditor**. This approach cleanly tests a multimodal agent's reasoning without fractional-coordinate failure modes.
+
+ 1. **Agent receives** a real COCO image + the current annotation state
+ 2. **Agent visually inspects** the IDs using a continuous inference loop (`openai` client)
+ 3. **Agent corrects** errors by calling `REMOVE`, `CHANGE_CLASS`, or `FLAG_MISSING`
+ 4. **Agent receives dense rewards** at every step based on strict quality tracking
+
+ ## 📋 3 Tiered Tasks
+
+ The environment supports exactly 3 progressively difficult tasks, giving a deterministic difficulty ramp that challenges even frontier models.
+
+ | Task | Difficulty | Objective | Max Steps |
+ |------|-----------|--------|-----------|
+ | `remove_spurious` | Easy 🟢 | Detect and delete fake/hallucinated bounding boxes that enclose nothing. | 15 |
+ | `fix_classes` | Medium 🟡 | Combines spurious errors with deliberate cross-class confusion (e.g. `car` ↔ `truck`). | 20 |
+ | `find_missing` | Hard 🔴 | Objects are scrubbed from the label set; the VLM must actively spot the missing targets. | 30 |
+
+ ## ⚙️ Environment Design & Rewards
+
+ The environment enforces the RL (Reinforcement Learning) episode structure required to actually train agents (e.g. PPO/GRPO setups):
+
+ - **Clean Boundaries:** `reset()` initializes a fresh scene ID mapping. Episodes end the moment `SUBMIT` is invoked or max steps are exhausted.
+ - **Dense Fractional Reward:** The reward function provides a continuous trajectory signal. Using `quality_delta = new_quality - old_quality`, the environment emits fractional improvements (`+0.25`, `+0.34`, etc.) every time the agent makes a correct move, rather than a sparse binary end-of-episode signal.
+ - **Built-in Guardrails:** The reward deducts `-0.01` for every executed step, penalizing runaway loops, blind guessing, and destructive behavior.
+
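The dense reward above can be sketched in a few lines. This is an illustrative sketch only: the real bookkeeping lives in the environment server (not shown in this diff), and `STEP_PENALTY`/`step_reward` are hypothetical names.

```python
# Illustrative sketch of the dense reward: quality delta minus a flat
# per-step cost. Names here are hypothetical, not the env's internals.
STEP_PENALTY = 0.01

def step_reward(old_quality: float, new_quality: float) -> float:
    quality_delta = new_quality - old_quality
    return quality_delta - STEP_PENALTY

# A fix lifting quality from 0.50 to 0.75 earns roughly +0.24;
# a wasted step costs -0.01.
print(step_reward(0.50, 0.75))
print(step_reward(0.60, 0.60))
```

Note how a no-op step is strictly negative, which is what discourages runaway loops.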
+ ## 📊 Deterministic Grading (0.0 to 1.0)
+
+ At every step, the agent receives a hard-to-game score out of `1.0` computed from deterministic boolean checks (fully reproducible):
+
+ - **Spurious Precision (35%)** — Did you remove fake boxes without destroying real ones?
+ - **Class Match Accuracy (35%)** — For existing valid boxes, did you change to the correct gold label?
+ - **Missing Flag Recall (30%)** — Did you use `FLAG_MISSING` for objects stripped from the annotations?
+
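A sketch of how the three weighted components combine. The function name is hypothetical; the environment computes the component fractions from its internal gold annotations.

```python
# Weighted quality score per the 35/35/30 split documented above.
# The three inputs are assumed to be fractions in [0, 1].
def quality_score(spurious_precision: float,
                  class_accuracy: float,
                  missing_recall: float) -> float:
    return (0.35 * spurious_precision
            + 0.35 * class_accuracy
            + 0.30 * missing_recall)

# A perfect audit scores ~1.0; perfect removals and labels with no
# missing-object flags caps out around 0.70.
print(round(quality_score(1.0, 1.0, 1.0), 2))
print(round(quality_score(1.0, 1.0, 0.0), 2))
```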
+ ## 💻 Spec Compliance & Quick Start
+
+ This repository is **100% OpenEnv spec compliant**: `openenv validate` passes, `openenv.yaml` handles routing, and all interface types (observations, actions, reward signals) are typed Pydantic models in `models.py`.
+
+ ### 1. Zero-Storage Setup
+ Because we dynamically fetch raw annotations via explicit COCO API URLs inside `data/prepare_coco.py`, the dataset footprint shrinks to ~2.5 MB, enabling fast Docker deployments and HF Space hosting.
+ ```bash
+ # Verify Environment
+ uv run openenv validate
+
+ # Containerize
+ docker build -t annotation-qa-env:latest .
+ docker run -d -p 8000:8000 annotation-qa-env:latest
+ ```
+
+ ### 2. VLM Baseline Inference
+ We test via the native OpenAI client against the Hugging Face router. Make sure the endpoint serves a vision-capable model.
+
+ ```bash
+ # For HF Serverless Router
+ export OPENAI_API_KEY="your_api_token"
+ export API_BASE_URL="https://router.huggingface.co/v1"
+ export MODEL_NAME="Qwen/Qwen3-VL-8B-Instruct"
+
+ # Reproduce the baseline
+ python3 inference.py
+ ```
+
+ ## 🤖 Pydantic Action Space
+
+ | Action | Required Fields | Description |
+ |--------|----------------|-------------|
+ | `change_class` | `annotation_id`, `new_class` | Correct a miscategorized label |
+ | `flag_missing` | `missing_class` | Flag a missing target by its class name |
+ | `remove_annotation` | `annotation_id` | Delete a completely spurious annotation |
+ | `submit` | (none) | Finalize audit corrections |
+
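On the wire, each action is a small JSON object whose field names come from `models.py`. A hedged sketch of payloads matching the table above (the HTTP envelope the server wraps these in is not shown in this diff, and the literal values are made up):

```python
import json

# One example payload per action in the table above; only the fields an
# action needs are present. The ids and class names are illustrative.
actions = [
    {"action_type": "change_class", "annotation_id": 3, "new_class": "truck"},
    {"action_type": "flag_missing", "missing_class": "person"},
    {"action_type": "remove_annotation", "annotation_id": 7},
    {"action_type": "submit"},
]

for a in actions:
    print(json.dumps(a))
```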
+ ## 📜 License
+ BSD-3-Clause (matching OpenEnv)
client.py ADDED
@@ -0,0 +1,88 @@
+ """
+ Annotation QA Environment Client.
+
+ Provides the client for connecting to an Annotation QA Environment server.
+ """
+
+ from openenv.core.env_client import EnvClient
+ from openenv.core.client_types import StepResult
+ from .models import (
+     Annotation,
+     AnnotationQAAction,
+     AnnotationQAObservation,
+     AnnotationQAState,
+ )
+
+
+ class AnnotationQAEnv(EnvClient[AnnotationQAAction, AnnotationQAObservation, AnnotationQAState]):
+     """
+     Client for the Annotation QA Environment.
+
+     Example:
+         >>> with AnnotationQAEnv(base_url="http://localhost:8000").sync() as env:
+         ...     result = env.reset(task="fix_classes")
+         ...     print(result.observation.annotations)
+         ...     result = env.step(AnnotationQAAction(
+         ...         action_type="change_class",
+         ...         annotation_id=0,
+         ...         new_class="truck",
+         ...     ))
+         ...     print(result.reward)
+     """
+
+     def _step_payload(self, action: AnnotationQAAction) -> dict:
+         """Convert action to wire format, including only fields that are set."""
+         payload = {"action_type": action.action_type}
+         if action.annotation_id is not None:
+             payload["annotation_id"] = action.annotation_id
+         if action.new_bbox is not None:
+             payload["new_bbox"] = action.new_bbox
+         if action.new_class is not None:
+             payload["new_class"] = action.new_class
+         if action.new_attribute is not None:
+             payload["new_attribute"] = action.new_attribute
+         if action.missing_class is not None:
+             payload["missing_class"] = action.missing_class
+         return payload
+
+     def _parse_result(self, payload: dict) -> StepResult:
+         """Parse server response into typed StepResult."""
+         obs_data = payload.get("observation", payload)
+
+         annotations = []
+         for ann_data in obs_data.get("annotations", []):
+             annotations.append(Annotation(
+                 id=ann_data.get("id", 0),
+                 bbox=ann_data.get("bbox", [0, 0, 0, 0]),
+                 class_label=ann_data.get("class_label", ""),
+             ))
+
+         observation = AnnotationQAObservation(
+             done=payload.get("done", False),
+             reward=payload.get("reward"),
+             image_url=obs_data.get("image_url"),
+             image_width=obs_data.get("image_width", 0),
+             image_height=obs_data.get("image_height", 0),
+             scene_description=obs_data.get("scene_description", ""),
+             scene_objects=obs_data.get("scene_objects", []),
+             annotations=annotations,
+             available_classes=obs_data.get("available_classes", []),
+             task_id=obs_data.get("task_id", ""),
+             task_description=obs_data.get("task_description", ""),
+             corrections_made=obs_data.get("corrections_made", 0),
+             step_count=obs_data.get("step_count", 0),
+             max_steps=obs_data.get("max_steps", 20),
+             message=obs_data.get("message", ""),
+             last_action_error=obs_data.get("last_action_error"),
+         )
+
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: dict) -> AnnotationQAState:
+         """Parse state response."""
+         return AnnotationQAState(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             task_id=payload.get("task_id", ""),
+             sample_id=payload.get("sample_id", ""),
+             initial_quality=payload.get("initial_quality", 0.0),
+             current_quality=payload.get("current_quality", 0.0),
+             corrections_made=payload.get("corrections_made", 0),
+         )
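The `_step_payload` hook above serializes only the optional fields that are actually set. A standalone sketch of that behavior, using a plain dataclass stand-in for the real Pydantic action model:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Action:
    # Minimal stand-in for AnnotationQAAction (illustration only).
    action_type: str
    annotation_id: Optional[int] = None
    new_bbox: Optional[List[float]] = None
    new_class: Optional[str] = None

def step_payload(action: Action) -> dict:
    # Mirror of the client logic: start with the mandatory action_type,
    # then attach only the optional fields that were provided.
    payload = {"action_type": action.action_type}
    if action.annotation_id is not None:
        payload["annotation_id"] = action.annotation_id
    if action.new_bbox is not None:
        payload["new_bbox"] = action.new_bbox
    if action.new_class is not None:
        payload["new_class"] = action.new_class
    return payload

print(step_payload(Action("remove_annotation", annotation_id=2)))
# {'action_type': 'remove_annotation', 'annotation_id': 2}
```

Keeping unset fields out of the payload means the server can distinguish "not provided" from an explicit value.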
debug_overlay_test.jpg ADDED
inference.py ADDED
@@ -0,0 +1,418 @@
+ """
+ Inference Script — Annotation QA Environment (72B One-Shot VQA + Set-of-Mark)
+ ============================================================================
+ MANDATORY
+ - Before submitting, ensure the following variables are defined:
+     API_BASE_URL   The API endpoint for the VLM.
+     MODEL_NAME     The model identifier to use for inference.
+     HF_TOKEN       Your Hugging Face / API key.
+
+ - STDOUT MUST EXACTLY follow the [START], [STEP], and [END] formats.
+
+ 72B ONE-SHOT VQA APPROACH
+ - Uses Qwen2.5-VL-72B-Instruct for high spatial accuracy.
+ - To stay within API rate limits and token budgets, the script makes EXACTLY
+   ONE API call per image.
+ - The VLM acts as a visual reviewer, grading every box in text format.
+ - The Python loop then mechanically executes the parsed actions.
+ """
+
+ import base64
+ import io
+ import json
+ import os
+ import re
+ import sys
+ import textwrap
+ import urllib.request
+ from typing import Any, Dict, List, Optional
+
+ from openai import OpenAI
+
+ # Add parent to path for imports
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+ try:
+     from annotation_qa_env.models import AnnotationQAAction, AnnotationQAObservation
+     from annotation_qa_env.server.environment import AnnotationQAEnvironment
+ except ImportError:
+     from models import AnnotationQAAction, AnnotationQAObservation
+     from server.environment import AnnotationQAEnvironment
+
+ # ──────────────────────────────────────────────
+ # Configuration
+ # ──────────────────────────────────────────────
+
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
+
+ # OPENAI_API_KEY is checked first per the spec, falling back to HF_TOKEN for
+ # Serverless Inference.
+ API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-VL-72B-Instruct")
+
+ BENCHMARK = "annotation_qa_env"
+ TASKS = ["remove_spurious", "fix_classes", "find_missing"]
+ MAX_STEPS_PER_TASK = {"remove_spurious": 15, "fix_classes": 20, "find_missing": 30}
+ TEMPERATURE = 0.2
+ MAX_TOKENS = 1500
+ SUCCESS_SCORE_THRESHOLD = 0.1
+
+ # Raw image cache
+ _raw_image_cache = {}
+
+ SYSTEM_PROMPT = textwrap.dedent("""
+ You are a highly precise AI visual inspector reviewing annotated datasets.
+ You will be provided an image containing multiple drawn objects.
+ Every object has a thick colored bounding box and a distinct label showing `[ID: <number> | <class_label>]`.
+
+ Your task is to systematically analyze EVERY SINGLE box drawn on the image and check for errors, policy violations, incorrect attributes, or missing background objects.
+
+ IF the box tightly bounds the object, the label is exactly correct, and it does not violate any safety policies, its status is KEEP.
+
+ You MUST respond strictly with a line-by-line list grading every single ID you see on the screen.
+ You may also append FLAG_MISSING commands at the very end of your list for objects that the annotator forgot to draw a box around.
+
+ Use EXACTLY this format and nothing else:
+
+ ID <number>: KEEP
+ ID <number>: CHANGE_CLASS <new_correct_class_name>
+ ID <number>: REMOVE
+ ID <number>: FLAG_SAFETY
+ ID <number>: CHANGE_ATTRIBUTE <new_attribute_name>
+ FLAG_MISSING: <missing_class_name>
+
+ Example Output:
+ ID 0: KEEP
+ ID 1: CHANGE_CLASS truck
+ ID 2: REMOVE
+ ID 3: FLAG_SAFETY
+ ID 14: KEEP
+ ID 15: CHANGE_ATTRIBUTE red skateboard
+ FLAG_MISSING: person
+ FLAG_MISSING: bicycle
+
+ Do NOT output any other text: no intro, no JSON, no explanation. Just the list.
+ """).strip()
+
+
+ # ──────────────────────────────────────────────
+ # Logging helpers
+ # ──────────────────────────────────────────────
+
+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
+     error_val = error if error else "null"
+     done_val = str(done).lower()
+     print(
+         f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(
+         f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
+         flush=True,
+     )
+
+
+ # ──────────────────────────────────────────────
+ # Image Overlays
+ # ──────────────────────────────────────────────
+
+ def get_base_image(image_url: str, max_dim: int = 768):
+     from PIL import Image
+
+     if image_url in _raw_image_cache:
+         return _raw_image_cache[image_url]
+
+     try:
+         req = urllib.request.Request(image_url, headers={"User-Agent": "AnnotationQA/1.0"})
+         with urllib.request.urlopen(req, timeout=30) as resp:
+             img_bytes = resp.read()
+
+         img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+         w, h = img.size
+         # For 72B VQA, higher resolution is better. Scale proportionally.
+         if max(w, h) > max_dim:
+             scale = max_dim / max(w, h)
+             new_w, new_h = int(w * scale), int(h * scale)
+             img = img.resize((new_w, new_h), Image.LANCZOS)
+
+         _raw_image_cache[image_url] = img
+         return img
+     except Exception as e:
+         print(f"[DEBUG] Failed to fetch image {image_url}: {e}", flush=True)
+         return None
+
+
+ def fetch_annotated_image_as_base64(obs: AnnotationQAObservation, debug_save: bool = False) -> str:
+     try:
+         from PIL import ImageDraw, ImageFont
+     except ImportError:
+         return ""
+
+     img = get_base_image(obs.image_url)
+     if img is None:
+         return ""
+
+     canvas = img.copy()
+     draw = ImageDraw.Draw(canvas, "RGBA")
+     w, h = canvas.size
+
+     try:
+         fontsize = max(14, int(h * 0.03))
+         try:
+             font = ImageFont.truetype("arial.ttf", fontsize)
+         except OSError:
+             try:
+                 font = ImageFont.truetype("DejaVuSans.ttf", fontsize)
+             except OSError:
+                 font = ImageFont.load_default()
+     except Exception:
+         font = ImageFont.load_default()
+
+     colors = [
+         (0, 255, 0, 255),    # Green
+         (255, 165, 0, 255),  # Orange
+         (0, 255, 255, 255),  # Cyan
+         (255, 0, 255, 255),  # Magenta
+         (255, 255, 0, 255),  # Yellow
+     ]
+
+     for ann in obs.annotations:
+         color = colors[ann.id % len(colors)]
+         x_norm, y_norm, w_norm, h_norm = ann.bbox
+
+         x0 = int(x_norm * w)
+         y0 = int(y_norm * h)
+         x1 = int((x_norm + w_norm) * w)
+         y1 = int((y_norm + h_norm) * h)
+
+         draw.rectangle([x0, y0, x1, y1], outline=color, width=4)
+
+         label_text = f" ID:{ann.id} | {ann.class_label} "
+         try:
+             bbox = font.getbbox(label_text)
+             text_w = bbox[2] - bbox[0]
+             text_h = bbox[3] - bbox[1]
+         except AttributeError:
+             text_w, text_h = 60, 15
+
+         bg_rect = [x0, max(0, y0 - text_h - 4), x0 + text_w, y0]
+         draw.rectangle(bg_rect, fill=color)
+         draw.text((x0, max(0, y0 - text_h - 4)), label_text, fill=(0, 0, 0, 255), font=font)
+
+     if debug_save:
+         canvas.save("debug_overlay_test.jpg")
+
+     buf = io.BytesIO()
+     canvas.save(buf, format="JPEG", quality=85)
+     return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+
+ # ──────────────────────────────────────────────
+ # Prompt building
+ # ──────────────────────────────────────────────
+
+ def build_user_content(obs: AnnotationQAObservation) -> list:
+     content_blocks = []
+
+     if obs.image_url:
+         save_debug = (obs.step_count == 0)
+         b64 = fetch_annotated_image_as_base64(obs, debug_save=save_debug)
+         if b64:
+             content_blocks.append({
+                 "type": "image_url",
+                 "image_url": {
+                     "url": f"data:image/jpeg;base64,{b64}",
+                 },
+             })
+
+     # Prepare an inventory list of existing IDs so the VLM knows what needs checking
+     inventory = [f"ID {a.id}: {a.class_label}" for a in obs.annotations]
+
+     text = f"""Please analyze this image. The bounding boxes are clearly drawn with their current labels.
+ All valid standard COCO classes are supported.
+
+ Here is the inventory of boxes on screen you MUST review:
+ {chr(10).join(inventory)}
+
+ Provide your final line-by-line grading of every ID now:
+ """
+     content_blocks.append({
+         "type": "text",
+         "text": text,
+     })
+
+     return content_blocks
+
+
+ def parse_vqa_actions(response_text: str) -> List[AnnotationQAAction]:
+     """Parse the line-by-line plain-text output into discrete actions."""
+     text = response_text.strip()
+     actions = []
+
+     # Regex match for lines such as "ID X: CHANGE_CLASS dog" or "ID Y: REMOVE"
+     lines = text.split('\n')
+     for line in lines:
+         line = line.strip()
+
+         # 1. Check for FLAG_MISSING (which doesn't have an ID)
+         match_missing = re.search(r'FLAG_MISSING:\s*(.+)', line, re.IGNORECASE)
+         if match_missing:
+             m_class = match_missing.group(1).strip().lower()
+             actions.append(AnnotationQAAction(
+                 action_type="flag_missing",
+                 missing_class=m_class,
+             ))
+             continue
+
+         # 2. Check for ID-based commands
+         match = re.search(r'ID\s*(\d+)[:\-\s]+(.+)', line, re.IGNORECASE)
+         if not match:
+             continue
+
+         ann_id = int(match.group(1))
+         instruction = match.group(2).strip().upper()
+
+         if instruction.startswith("REMOVE"):
+             actions.append(AnnotationQAAction(
+                 action_type="remove_annotation",
+                 annotation_id=ann_id,
+             ))
+         elif instruction.startswith("FLAG_SAFETY"):
+             actions.append(AnnotationQAAction(
+                 action_type="flag_safety",
+                 annotation_id=ann_id,
+             ))
+         elif instruction.startswith("CHANGE_ATTRIBUTE"):
+             # Must be matched before the generic "CHANGE" fallback below,
+             # otherwise CHANGE_ATTRIBUTE lines are misparsed as class changes.
+             parts = instruction.split()
+             if len(parts) > 1:
+                 new_attr = " ".join(parts[1:]).lower()
+                 actions.append(AnnotationQAAction(
+                     action_type="change_attribute",
+                     annotation_id=ann_id,
+                     new_attribute=new_attr,
+                 ))
+         elif instruction.startswith("CHANGE_CLASS") or instruction.startswith("CHANGE"):
+             parts = instruction.split()
+             if len(parts) > 1:
+                 new_class = " ".join(parts[1:]).lower()
+                 actions.append(AnnotationQAAction(
+                     action_type="change_class",
+                     annotation_id=ann_id,
+                     new_class=new_class,
+                 ))
+
+     return actions
+
+
+ # ──────────────────────────────────────────────
+ # Execution logic
+ # ──────────────────────────────────────────────
+
+ def get_vqa_actions(client: OpenAI, obs: AnnotationQAObservation) -> List[AnnotationQAAction]:
+     user_content = build_user_content(obs)
+     try:
+         completion = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": user_content},
+             ],
+             temperature=TEMPERATURE,
+             max_tokens=MAX_TOKENS,
+             stream=False,
+         )
+         response_text = completion.choices[0].message.content or ""
+         print(f"[DEBUG] VLM Output:\n{response_text}\n", flush=True)
+         return parse_vqa_actions(response_text)
+     except Exception as exc:
+         print(f"[DEBUG] Model request failed: {exc}", flush=True)
+         return []
+
+
+ def run_task(client: OpenAI, env: AnnotationQAEnvironment, task_name: str) -> float:
+     global _raw_image_cache
+     _raw_image_cache = {}
+
+     obs = env.reset(task=task_name, seed=42)
+     max_steps = MAX_STEPS_PER_TASK.get(task_name, 20)
+     rewards: List[float] = []
+     steps_taken = 0
+     score = 0.0
+     success = False
+
+     log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
+
+     try:
+         # 1. ONE-SHOT VISUAL INSPECTION
+         # The script makes exactly ONE API call to grade the image.
+         actions_to_take = get_vqa_actions(client, obs)
+
+         # 2. LOCAL SEQUENTIAL EXECUTION
+         # Execute the parsed actions locally, one per environment step.
+         for action in actions_to_take:
+             if obs.done or steps_taken >= max_steps:
+                 break
+
+             steps_taken += 1
+             action_str = f"{action.action_type}("
+             if action.annotation_id is not None:
+                 action_str += f"id={action.annotation_id}"
+             if action.new_class:
+                 action_str += f" cls={action.new_class}"
+             if action.new_attribute:
+                 action_str += f" attr={action.new_attribute}"
+             if action.missing_class:
+                 action_str += f" missing={action.missing_class}"
+             action_str += ")"
+
+             obs = env.step(action)
+             reward = obs.reward if obs.reward is not None else 0.0
+             rewards.append(reward)
+
+             log_step(steps_taken, action_str, reward, obs.done, obs.last_action_error)
+
+         # 3. SUBMIT
+         if not obs.done and steps_taken < max_steps:
+             steps_taken += 1
+             obs = env.step(AnnotationQAAction(action_type="submit"))
+             reward = obs.reward if obs.reward is not None else 0.0
+             rewards.append(reward)
+             log_step(steps_taken, "submit", reward, obs.done, obs.last_action_error)
+
+         if rewards:
+             score = rewards[-1]
+         score = max(0.0, min(1.0, score))
+         success = score >= SUCCESS_SCORE_THRESHOLD
+
+     except Exception as exc:
+         print(f"[DEBUG] Task {task_name} error: {exc}", flush=True)
+
+     log_end(success, steps_taken, score, rewards)
+     return score
+
+
+ def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=600.0)
+     env = AnnotationQAEnvironment()
+
+     total_score = 0.0
+     for task_name in TASKS:
+         print(f"\n{'='*60}", flush=True)
+         print(f"Running task: {task_name} (VLM: {MODEL_NAME})", flush=True)
+         print(f"{'='*60}", flush=True)
+         score = run_task(client, env, task_name)
+         total_score += score
+         print(f"Task {task_name} score: {score:.3f}\n", flush=True)
+
+     avg_score = total_score / len(TASKS)
+     print(f"\n{'='*60}", flush=True)
+     print(f"Average score across {len(TASKS)} tasks: {avg_score:.3f}", flush=True)
+     print(f"{'='*60}", flush=True)
+
+
+ if __name__ == "__main__":
+     main()
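The strict line format demanded by the system prompt makes the parsing step easy to sanity-check in isolation. Below is a minimal re-implementation of the parse logic (not the file's own `parse_vqa_actions`, which depends on the environment's Pydantic models), run on a made-up response rather than real model output:

```python
import re

def parse_lines(text: str) -> list:
    """Minimal mirror of parse_vqa_actions: (action_type, id, arg) tuples."""
    actions = []
    for line in text.splitlines():
        line = line.strip()
        m = re.search(r'FLAG_MISSING:\s*(.+)', line, re.IGNORECASE)
        if m:
            actions.append(("flag_missing", None, m.group(1).strip().lower()))
            continue
        m = re.search(r'ID\s*(\d+)[:\-\s]+(.+)', line, re.IGNORECASE)
        if not m:
            continue
        ann_id, instr = int(m.group(1)), m.group(2).strip().upper()
        if instr.startswith("REMOVE"):
            actions.append(("remove_annotation", ann_id, None))
        elif instr.startswith("CHANGE_CLASS"):
            actions.append(("change_class", ann_id,
                            " ".join(instr.split()[1:]).lower()))
        # KEEP lines intentionally produce no action.
    return actions

sample = "ID 0: KEEP\nID 1: CHANGE_CLASS truck\nID 2: REMOVE\nFLAG_MISSING: person"
print(parse_lines(sample))
# [('change_class', 1, 'truck'), ('remove_annotation', 2, None), ('flag_missing', None, 'person')]
```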
models.py ADDED
@@ -0,0 +1,149 @@
+ """
+ Annotation QA Environment — Type-Safe Models.
+
+ Defines the API contract for the Annotation QA Environment:
+ - AnnotationQAAction: What corrections the agent can make
+ - AnnotationQAObservation: What the agent sees (image + annotations)
+ - AnnotationQAState: Episode metadata
+
+ The agent reviews intentionally-flawed annotations on real COCO val2017 images
+ and must fix bounding boxes, correct class labels, add missing annotations,
+ or remove spurious ones. A VLM (Vision-Language Model) is used to visually
+ inspect the images.
+ """
+
+ from typing import Any, Dict, List, Literal, Optional
+
+ from pydantic import BaseModel, Field
+
+
+ # ──────────────────────────────────────────────
+ # Annotation data structure
+ # ──────────────────────────────────────────────
+
+ class Annotation(BaseModel):
+     """A single annotation: bounding box + class label."""
+     id: int
+     bbox: List[float] = Field(
+         ...,
+         description="Bounding box as [x, y, w, h] normalized to 0.0–1.0",
+         min_length=4,
+         max_length=4,
+     )
+     class_label: str = Field(..., description="Object class label, e.g. 'car', 'person'")
+
+
+ # ──────────────────────────────────────────────
+ # Action
+ # ──────────────────────────────────────────────
+
+ class AnnotationQAAction(BaseModel):
+     """
+     An action the agent can take to correct annotations.
+
+     action_type determines which fields are required:
+     - "adjust_bbox": requires annotation_id, new_bbox
+     - "change_class": requires annotation_id, new_class
+     - "add_annotation": requires new_bbox, new_class
+     - "remove_annotation": requires annotation_id
+     - "flag_safety": requires annotation_id
+     - "change_attribute": requires annotation_id, new_attribute
+     - "flag_missing": requires missing_class
+     - "submit": no extra fields needed (finalizes episode)
+     """
+     action_type: Literal[
+         "adjust_bbox",
+         "change_class",
+         "remove_annotation",
+         "add_annotation",
+         "submit",
+         "flag_safety",
+         "change_attribute",
+         "flag_missing",
+     ]
+     annotation_id: Optional[int] = Field(
+         None, description="ID of the annotation to modify"
+     )
+     new_bbox: Optional[List[float]] = Field(
+         None,
+         description="New bounding box [x, y, w, h] in 0.0–1.0",
+         min_length=4,
+         max_length=4,
+     )
+     new_class: Optional[str] = Field(
+         None, description="New class label"
+     )
+     new_attribute: Optional[str] = Field(
+         None, description="New attribute description for an object"
+     )
+     missing_class: Optional[str] = Field(
+         None, description="Class of an object that is missing a bounding box"
+     )
+     metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
+ # ──────────────────────────────────────────────
+ # Observation
+ # ──────────────────────────────────────────────
+
+ class AnnotationQAObservation(BaseModel):
+     """
+     What the agent sees after each step.
+
+     Includes the image URL, scene description, current annotations (some may
+     be wrong), available classes, and progress info. The VLM agent uses the
+     image_url to visually inspect the scene.
+     """
+     done: bool = False
+     reward: Optional[float] = None
+
+     # Image information (real COCO val2017)
+     image_url: Optional[str] = Field(
+         None, description="Public URL to the COCO val2017 image"
+     )
+     image_width: int = Field(0, description="Image width in pixels")
+     image_height: int = Field(0, description="Image height in pixels")
+
+     # Scene information
+     scene_description: str = Field(
+         "", description="Natural-language description of the scene and its objects"
+     )
+     scene_objects: List[Dict[str, Any]] = Field(
+         default_factory=list,
+         description="Ground-truth object list with positions (visible to agent as scene context)",
+     )
+
+     # Current annotations (may contain errors)
+     annotations: List[Annotation] = Field(
+         default_factory=list,
+         description="Current annotations the agent should review/fix",
+     )
+
+     # Task context
+     available_classes: List[str] = Field(
+         default_factory=list,
+         description="Valid class labels for this task (COCO 80 categories)",
+     )
+     task_id: str = ""
+     task_description: str = ""
+
+     # Progress
+     corrections_made: int = 0
+     step_count: int = 0
+     max_steps: int = 20
+
+     # Feedback
+     message: str = ""
+     last_action_error: Optional[str] = None
+
+
+ # ──────────────────────────────────────────────
+ # State
+ # ──────────────────────────────────────────────
+
+ class AnnotationQAState(BaseModel):
+     """Episode metadata — internal state tracked by the environment."""
+     episode_id: Optional[str] = None
+     step_count: int = 0
+     task_id: str = ""
+     sample_id: str = ""
+     initial_quality: float = 0.0
+     current_quality: float = 0.0
+     corrections_made: int = 0
openenv.yaml ADDED
@@ -0,0 +1,6 @@
+ spec_version: 1
+ name: annotation_qa_env
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
pyproject.toml ADDED
@@ -0,0 +1,43 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-annotation-qa-env"
+ version = "0.2.0"
+ description = "Annotation QA Environment for OpenEnv — AI agent reviews and corrects flawed ML annotations on real COCO val2017 images using a VLM"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv dependencies
+     "openenv-core[core]>=0.2.2",
+     "fastapi>=0.115.0",
+     "pydantic>=2.0.0",
+     "uvicorn>=0.24.0",
+     "requests>=2.31.0",
+     "openai>=1.0.0",
+     "Pillow>=10.0.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ server = "annotation_qa_env.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = [
+     "annotation_qa_env",
+     "annotation_qa_env.server",
+     "annotation_qa_env.data",
+ ]
+
+ [tool.setuptools.package-dir]
+ "annotation_qa_env" = "."
+ "annotation_qa_env.server" = "server"
+ "annotation_qa_env.data" = "data"
+
+ [tool.setuptools.package-data]
+ "annotation_qa_env.data" = ["tasks/**/*.json"]
uv.lock ADDED
The diff for this file is too large to render. See raw diff