File size: 9,816 Bytes
9b47159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d192bd6
9b47159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd482eb
 
 
 
9b47159
 
 
cd482eb
d192bd6
9b47159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e60e2c
9b47159
7e60e2c
9b47159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""
Core RL environment implementing the OpenEnv interface.

Provides reset(), step(), and state for the Bug Triage environment.
Thread-safe episode store for concurrent HTTP sessions.
"""

from __future__ import annotations

import json
import logging
import random
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

# Optional dependency: the real OpenEnv base class. When the `core.env_server`
# package is not installed (e.g. running this module standalone or in tests),
# fall back to an empty stub so the module stays importable.
try:
    from core.env_server import Environment as _OpenEnvBase
except ImportError:

    # Minimal no-op stand-in; BugTriageEnvironment implements the full
    # reset/step/state surface itself, so no inherited behavior is needed.
    class _OpenEnvBase:  # type: ignore[no-redef]
        pass


from ..graders import GRADERS
from ..models import (
    DEVELOPERS,
    BugReport,
    BugTriageAction,
    BugTriageObservation,
    BugTriageState,
)

logger = logging.getLogger(__name__)

DATA_PATH = Path(__file__).parent.parent / "data" / "bugs.json"


class BugTriageEnvironment(_OpenEnvBase):
    """
    Bug Triage RL Environment.

    Each episode:
      1. reset(task_id) - selects a random bug report
      2. step(action)   - agent triages the bug; grader scores; done=True
      3. state          - returns episode metadata

    Episode store is thread-safe for concurrent HTTP sessions: all reads and
    writes of ``self._episodes`` (and the check-then-mutate sequence in
    ``step``) happen under ``self._lock``.
    """

    def __init__(self, data_path: Path = DATA_PATH) -> None:
        """Load the bug dataset and initialize an empty episode store.

        Args:
            data_path: Path to the JSON array of bug records. Each record is
                expected to carry at least ``bug_id``, ``title``,
                ``description``, and ``ground_truth`` keys.
        """
        self._bugs: List[Dict[str, Any]] = self._load_data(data_path)
        self._episodes: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock()
        # Fallback episode id for single-session callers that don't pass one.
        self._current_episode_id: Optional[str] = None
        logger.info("Loaded %d bug reports.", len(self._bugs))

    # -- OpenEnv interface -----------------------------------------------------

    def reset(self, task_id: str = "task_1") -> BugTriageObservation:
        """Start a new episode and return the initial observation.

        Args:
            task_id: Grading-task variant key (looked up in ``GRADERS`` at
                step time; unknown keys fall back to ``"task_1"``).

        Returns:
            A non-terminal observation carrying the bug report to triage.
        """
        bug = random.choice(self._bugs)
        episode_id = uuid.uuid4().hex[:12]

        episode: Dict[str, Any] = {
            "episode_id": episode_id,
            "task_id": task_id,
            "bug": bug,
            "ground_truth": bug["ground_truth"],
            "actions": [],
            "done": False,
        }

        with self._lock:
            self._episodes[episode_id] = episode
            self._current_episode_id = episode_id

        bug_report = self._make_bug_report(bug)
        logger.info(
            "Episode %s | task=%s | bug=%s", episode_id, task_id, bug["bug_id"]
        )

        return BugTriageObservation(
            done=False,
            reward=0.01,
            task_id=task_id,
            bug_report=bug_report,
            available_developers=list(DEVELOPERS),
            step_number=0,
            feedback="New bug report received. Please triage.",
            episode_id=episode_id,
        )

    def step(
        self,
        action: BugTriageAction,
        episode_id: Optional[str] = None,
    ) -> BugTriageObservation:
        """Execute one triage action. The episode terminates immediately.

        Args:
            action: The agent's triage decision.
            episode_id: Episode to act on; defaults to the most recent
                ``reset()``'s episode.

        Returns:
            A terminal observation with the shaped reward, grader score,
            and a human-readable feedback string.

        Raises:
            ValueError: If no episode is active, ``episode_id`` is unknown,
                or the episode has already finished.
        """
        ep_id = episode_id or self._current_episode_id
        if ep_id is None:
            raise ValueError("No active episode. Call reset() first.")

        # Lookup, done-check, and mutation must be one atomic unit: without
        # the lock around all three, two concurrent step() calls on the same
        # episode could both pass the done-check and both record an action.
        with self._lock:
            ep = self._episodes.get(ep_id)
            if ep is None:
                raise ValueError(f"Unknown episode_id: {ep_id}")
            if ep["done"]:
                raise ValueError(f"Episode {ep_id} is already done.")
            ep["actions"].append(action.model_dump())
            ep["done"] = True

        # From here on the episode is finalized; its fields are never mutated
        # again, so lock-free reads below are safe.

        # Score via task-specific grader (unknown task ids fall back to task_1).
        grader_fn = GRADERS.get(ep["task_id"], GRADERS["task_1"])
        grader_score = grader_fn(ep["actions"], ep["ground_truth"])

        # Shaped reward: affine-stretch the [0, 1] grader score (score*1.5 - 0.5),
        # add the calibration adjustment, then clamp into [0.01, 0.99].
        reward = (grader_score * 1.5) - 0.5

        # Confidence calibration bonus/penalty (0.0 when confidence omitted).
        calibration_bonus = self._compute_calibration_bonus(
            action.confidence, grader_score
        )
        reward += calibration_bonus
        reward = max(0.01, min(0.99, reward))

        # Build feedback string
        feedback = self._build_feedback(
            action, ep["ground_truth"], ep["task_id"], grader_score
        )
        if action.confidence is not None:
            feedback += (
                f" | confidence={action.confidence:.2f}"
                f" (calibration={calibration_bonus:+.2f})"
            )

        bug_report = self._make_bug_report(ep["bug"])

        # Lazy %-style logging args; calibration only logged when provided.
        log_msg = "Episode %s | score=%.3f | reward=%.3f"
        log_args: list[Any] = [ep_id, grader_score, reward]
        if action.confidence is not None:
            log_msg += " | cal=%+.2f"
            log_args.append(calibration_bonus)
        logger.info(log_msg, *log_args)

        return BugTriageObservation(
            done=True,
            reward=round(reward, 4),
            task_id=ep["task_id"],
            bug_report=bug_report,
            available_developers=list(DEVELOPERS),
            step_number=1,
            feedback=feedback,
            grader_score=round(grader_score, 4),
            episode_id=ep_id,
        )

    @property
    def state(self) -> BugTriageState:
        """Return metadata for the most recently reset episode.

        Returns an empty ``BugTriageState`` when no episode has been started.
        """
        ep_id = self._current_episode_id
        if ep_id is None:
            return BugTriageState()

        with self._lock:
            ep = self._episodes.get(ep_id, {})

        return BugTriageState(
            episode_id=ep_id,
            step_count=len(ep.get("actions", [])),
            task_id=ep.get("task_id", ""),
            bug_id=ep.get("bug", {}).get("bug_id"),
        )

    # -- Grading ---------------------------------------------------------------

    def grade_episode(self, episode_id: str, task_id: str) -> Dict[str, Any]:
        """Grade a completed episode. Used by the /grader endpoint.

        Args:
            episode_id: Id returned by ``reset()``.
            task_id: Grading-task variant key (falls back to ``"task_1"``).

        Returns:
            Dict with ``score`` and per-field ``breakdown``; on failure the
            score is 0.01 and an ``error`` key names the cause.
        """
        with self._lock:
            ep = self._episodes.get(episode_id)

        if ep is None:
            return {"score": 0.01, "breakdown": {}, "error": "episode_not_found"}
        if not ep["actions"]:
            return {"score": 0.01, "breakdown": {}, "error": "no_actions"}

        grader_fn = GRADERS.get(task_id, GRADERS["task_1"])
        score = grader_fn(ep["actions"], ep["ground_truth"])

        # Breakdown compares only the final action against ground truth.
        last_action = ep["actions"][-1]
        gt = ep["ground_truth"]
        breakdown: Dict[str, float] = {
            "bug_type_match": float(
                last_action.get("bug_type") == gt.get("bug_type")
            ),
            "priority_match": float(
                last_action.get("priority") == gt.get("priority")
            ),
        }
        # task_3 additionally grades assignment and suggested action.
        if task_id == "task_3":
            breakdown["developer_match"] = float(
                last_action.get("assigned_developer")
                == gt.get("assigned_developer")
            )
            breakdown["action_match"] = float(
                last_action.get("suggested_action")
                == gt.get("suggested_action")
            )

        return {"score": round(score, 4), "breakdown": breakdown}

    # -- Private helpers -------------------------------------------------------

    @staticmethod
    def _load_data(path: Path) -> List[Dict[str, Any]]:
        """Load and return the bug dataset from a JSON file."""
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)

    @staticmethod
    def _make_bug_report(bug: Dict[str, Any]) -> BugReport:
        """Convert a raw bug record dict into a ``BugReport`` model."""
        return BugReport(
            bug_id=bug["bug_id"],
            title=bug["title"],
            description=bug["description"],
            logs=bug.get("logs"),
            environment=bug.get("environment"),
            reporter=bug.get("reporter"),
            created_at=bug.get("created_at"),
            metadata=bug.get("metadata", {}),
        )

    @staticmethod
    def _compute_calibration_bonus(
        confidence: Optional[float],
        grader_score: float,
    ) -> float:
        """
        Compute reward adjustment based on confidence calibration.

        Returns a bonus in [-0.15, +0.10] or 0.0 if confidence is not provided.
        """
        if confidence is None:
            return 0.0

        calibration_error = abs(confidence - grader_score)

        # Ordered checks: confident-and-right / confident-and-wrong take
        # precedence over the generic calibration-error bands.
        if grader_score >= 0.8 and confidence >= 0.8:
            return 0.10  # correct and confident
        if grader_score < 0.5 and confidence >= 0.8:
            return -0.15  # wrong and overconfident
        if calibration_error < 0.2:
            return 0.05  # well-calibrated
        return -0.05  # poorly calibrated

    @staticmethod
    def _build_feedback(
        action: BugTriageAction,
        ground_truth: Dict[str, Any],
        task_id: str,
        score: float,
    ) -> str:
        """Build a `" | "`-joined feedback string comparing the action's
        fields against ground truth for the fields graded by *task_id*."""
        parts: list[str] = [f"score={score:.2f}"]

        if task_id in ("task_1", "task_3"):
            match = action.bug_type == ground_truth.get("bug_type")
            parts.append(
                f"bug_type={'ok' if match else 'wrong'}"
                f" (pred={action.bug_type}, expected={ground_truth.get('bug_type')})"
            )
        if task_id in ("task_2", "task_3"):
            match = action.priority == ground_truth.get("priority")
            parts.append(
                f"priority={'ok' if match else 'wrong'}"
                f" (pred={action.priority}, expected={ground_truth.get('priority')})"
            )
        if task_id == "task_3":
            match = action.assigned_developer == ground_truth.get(
                "assigned_developer"
            )
            parts.append(
                f"developer={'ok' if match else 'wrong'}"
                f" (pred={action.assigned_developer},"
                f" expected={ground_truth.get('assigned_developer')})"
            )
            match = action.suggested_action == ground_truth.get(
                "suggested_action"
            )
            parts.append(
                f"action={'ok' if match else 'wrong'}"
                f" (pred={action.suggested_action},"
                f" expected={ground_truth.get('suggested_action')})"
            )

        return " | ".join(parts)