Siteshcodes commited on
Commit
30f8f3a
·
1 Parent(s): 6ff231e

fix: add MAX_STEPS and rate limit delay to baseline

Browse files
Files changed (2) hide show
  1. baseline.py +5 -2
  2. client.py +15 -47
baseline.py CHANGED
@@ -7,6 +7,7 @@ import json
7
  from groq import Groq
8
  from client import BugTriageClient
9
  from model import TriageAction
 
10
 
11
  # ── config ─────────────────────────────────────────────────
12
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
@@ -90,8 +91,9 @@ def main():
90
 
91
  with BugTriageClient() as env:
92
  obs = env.reset()
93
-
94
- while not obs.done:
 
95
  task = obs.task_id
96
  print(f"\n── Task: {task.upper()} ──")
97
  print(f" Bug: {obs.bug_report.title}")
@@ -112,6 +114,7 @@ def main():
112
 
113
  scores[task] = result.reward
114
  step_count += 1
 
115
 
116
  print("\n" + "=" * 50)
117
  print(" BASELINE SCORES")
 
7
  from groq import Groq
8
  from client import BugTriageClient
9
  from model import TriageAction
10
+ import time
11
 
12
  # ── config ─────────────────────────────────────────────────
13
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
91
 
92
  with BugTriageClient() as env:
93
  obs = env.reset()
94
+ MAX_STEPS = 3
95
+ step_count = 0
96
+ while not obs.done and step_count < MAX_STEPS:
97
  task = obs.task_id
98
  print(f"\n── Task: {task.upper()} ──")
99
  print(f" Bug: {obs.bug_report.title}")
 
114
 
115
  scores[task] = result.reward
116
  step_count += 1
117
+ time.sleep(2)
118
 
119
  print("\n" + "=" * 50)
120
  print(" BASELINE SCORES")
client.py CHANGED
@@ -1,7 +1,6 @@
1
  # client.py
2
  import os
3
  import requests
4
- from dataclasses import asdict
5
  from typing import Optional
6
  from model import TriageAction, TriageObservation, BugReport
7
 
@@ -16,38 +15,18 @@ class StepResult:
16
 
17
  def _parse_observation(data: dict) -> TriageObservation:
18
  bug_data = data["bug_report"]
19
- bug = BugReport(
20
- id=bug_data["id"],
21
- title=bug_data["title"],
22
- body=bug_data["body"],
23
- author=bug_data["author"],
24
- labels_hint=bug_data.get("labels_hint", []),
25
- comments=bug_data.get("comments", []),
26
- )
27
  return TriageObservation(
28
  bug_report=bug,
29
- task_id=data["task_id"],
30
- score=data["score"],
31
- feedback=data["feedback"],
32
- done=data["done"],
33
- reward=data["reward"],
34
  )
35
 
36
 
37
  class BugTriageClient:
38
- """
39
- HTTP REST client for BugTriageEnvironment.
40
- Uses POST /reset and POST /step endpoints.
41
-
42
- Usage:
43
- with BugTriageClient() as env:
44
- obs = env.reset()
45
- while not obs.done:
46
- action = TriageAction(...)
47
- result = env.step(action)
48
- obs = result.observation
49
- """
50
-
51
  def __init__(self, base_url: Optional[str] = None):
52
  self.base_url = (
53
  base_url
@@ -57,27 +36,20 @@ class BugTriageClient:
57
  self.session.headers.update({"Content-Type": "application/json"})
58
 
59
  def reset(self) -> TriageObservation:
60
- """Call POST /reset to start a new episode."""
61
- response = self.session.post(
62
- f"{self.base_url}/reset",
63
- json={},
64
- timeout=30,
65
- )
66
  response.raise_for_status()
67
  data = response.json()
68
- return _parse_observation(data["observation"])
 
69
 
70
  def step(self, action: TriageAction) -> StepResult:
71
- """Call POST /step with the triage action."""
72
- payload = {"action": asdict(action)}
73
- response = self.session.post(
74
- f"{self.base_url}/step",
75
- json=payload,
76
- timeout=30,
77
- )
78
  response.raise_for_status()
79
  data = response.json()
80
- obs = _parse_observation(data["observation"])
 
81
  return StepResult(
82
  observation=obs,
83
  reward=data.get("reward", obs.reward) or 0.0,
@@ -86,11 +58,7 @@ class BugTriageClient:
86
  )
87
 
88
  def state(self) -> dict:
89
- """Call GET /state to get current environment state."""
90
- response = self.session.get(
91
- f"{self.base_url}/state",
92
- timeout=30,
93
- )
94
  response.raise_for_status()
95
  return response.json()
96
 
 
1
  # client.py
2
  import os
3
  import requests
 
4
  from typing import Optional
5
  from model import TriageAction, TriageObservation, BugReport
6
 
 
15
 
16
  def _parse_observation(data: dict) -> TriageObservation:
17
  bug_data = data["bug_report"]
18
+ bug = BugReport(**bug_data)
 
 
 
 
 
 
 
19
  return TriageObservation(
20
  bug_report=bug,
21
+ task_id=data.get("task_id", "easy"),
22
+ score=data.get("score", 0.0),
23
+ feedback=data.get("feedback", ""),
24
+ done=data.get("done", False),
25
+ reward=data.get("reward", 0.0),
26
  )
27
 
28
 
29
  class BugTriageClient:
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def __init__(self, base_url: Optional[str] = None):
31
  self.base_url = (
32
  base_url
 
36
  self.session.headers.update({"Content-Type": "application/json"})
37
 
38
  def reset(self) -> TriageObservation:
39
+ response = self.session.post(f"{self.base_url}/reset", json={}, timeout=30)
 
 
 
 
 
40
  response.raise_for_status()
41
  data = response.json()
42
+ obs_data = data.get("observation", data)
43
+ return _parse_observation(obs_data)
44
 
45
  def step(self, action: TriageAction) -> StepResult:
46
+ from dataclasses import asdict
47
+ payload = {"action": action.dict()}
48
+ response = self.session.post(f"{self.base_url}/step", json=payload, timeout=30)
 
 
 
 
49
  response.raise_for_status()
50
  data = response.json()
51
+ obs_data = data.get("observation", data)
52
+ obs = _parse_observation(obs_data)
53
  return StepResult(
54
  observation=obs,
55
  reward=data.get("reward", obs.reward) or 0.0,
 
58
  )
59
 
60
  def state(self) -> dict:
61
+ response = self.session.get(f"{self.base_url}/state", timeout=30)
 
 
 
 
62
  response.raise_for_status()
63
  return response.json()
64