Siteshcodes commited on
Commit
9e050fb
Β·
1 Parent(s): 30f8f3a

add inference.py for hackathon submission

Browse files
Files changed (1) hide show
  1. inference.py +227 -0
inference.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py β€” Bug Triage Env
3
+ OpenEnv Hackathon submission inference script.
4
+
5
+ Required env vars:
6
+ API_BASE_URL LLM endpoint (default: HuggingFace router)
7
+ MODEL_NAME Model identifier
8
+ HF_TOKEN HuggingFace / API key
9
+ ENV_BASE_URL Bug Triage env URL (default: HF Space)
10
+ """
11
+
12
+ import os
13
+ import json
14
+ import time
15
+ import textwrap
16
+ from typing import List, Optional
17
+
18
+ from openai import OpenAI
19
+ from client import BugTriageClient
20
+ from model import TriageAction
21
+
22
+ # ── config ──────────────────────────────────────────────────────────────
23
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
24
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
25
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
26
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://siteshcodes-bug-triage-env.hf.space")
27
+
28
+ TASK_NAME = "bug-triage"
29
+ BENCHMARK = "bug-triage-env"
30
+ MAX_STEPS = 3
31
+ TEMPERATURE = 0.0
32
+ MAX_TOKENS = 400
33
+ SUCCESS_SCORE_THRESHOLD = 0.4
34
+
35
+ SYSTEM_PROMPT = textwrap.dedent("""
36
+ You are a senior software engineering manager.
37
+ You will receive a bug report and must triage it. Respond ONLY with
38
+ valid JSON β€” no markdown, no explanation, no backticks.
39
+
40
+ Return exactly this structure:
41
+ {
42
+ "priority": "P0",
43
+ "labels": ["bug"],
44
+ "assigned_team": "backend",
45
+ "milestone": "hotfix",
46
+ "reasoning": "one sentence explaining your decision"
47
+ }
48
+
49
+ Priority guide:
50
+ P0 β€” production down, data loss, security vulnerability, 100% user impact
51
+ P1 β€” major feature broken, significant user impact, no workaround
52
+ P2 β€” degraded experience, workaround exists
53
+ P3 β€” minor, cosmetic, docs, low impact
54
+
55
+ Teams: backend | frontend | infra | security | devx
56
+ Milestones: hotfix | v2.1 | backlog
57
+ """).strip()
58
+
59
+
60
+ # ── logging helpers ──────────────────────────────────────────────────────
61
+
62
+ def log_start(task: str, env: str, model: str) -> None:
63
+ print(f"[START] task={task} env={env} model={model}", flush=True)
64
+
65
+
66
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
67
+ error_val = error if error else "null"
68
+ done_val = str(done).lower()
69
+ print(
70
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
71
+ flush=True,
72
+ )
73
+
74
+
75
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
76
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
77
+ print(
78
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
79
+ flush=True,
80
+ )
81
+
82
+
83
+ # ── model call ───────────────────────────────────────────────────────────
84
+
85
+ def format_bug(obs) -> str:
86
+ bug = obs.bug_report
87
+ comments = "\n".join(f" - {c}" for c in bug.comments) or " None"
88
+ return (
89
+ f"Title: {bug.title}\n\n"
90
+ f"Description:\n{bug.body}\n\n"
91
+ f"Existing labels: {', '.join(bug.labels_hint) or 'none'}\n"
92
+ f"Comments:\n{comments}"
93
+ )
94
+
95
+
96
+ def call_model(client: OpenAI, bug_text: str) -> TriageAction:
97
+ try:
98
+ completion = client.chat.completions.create(
99
+ model=MODEL_NAME,
100
+ messages=[
101
+ {"role": "system", "content": SYSTEM_PROMPT},
102
+ {"role": "user", "content": bug_text},
103
+ ],
104
+ temperature=TEMPERATURE,
105
+ max_tokens=MAX_TOKENS,
106
+ stream=False,
107
+ )
108
+ raw = (completion.choices[0].message.content or "").strip()
109
+
110
+ # strip accidental markdown fences
111
+ if raw.startswith("```"):
112
+ raw = raw.split("```")[1]
113
+ if raw.startswith("json"):
114
+ raw = raw[4:]
115
+
116
+ data = json.loads(raw)
117
+ return TriageAction(
118
+ priority=data.get("priority", "P2"),
119
+ labels=data.get("labels", ["bug"]),
120
+ assigned_team=data.get("assigned_team", "backend"),
121
+ milestone=data.get("milestone", "backlog"),
122
+ reasoning=data.get("reasoning", ""),
123
+ )
124
+ except Exception as exc:
125
+ print(f"[DEBUG] Model call failed: {exc}", flush=True)
126
+ # fallback action
127
+ return TriageAction(
128
+ priority="P2",
129
+ labels=["bug"],
130
+ assigned_team="backend",
131
+ milestone="backlog",
132
+ reasoning="fallback due to model error",
133
+ )
134
+
135
+
136
+ # ── main ────────────────────────────────────────────────────────────────��
137
+
138
+ def main() -> None:
139
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
140
+
141
+ rewards: List[float] = []
142
+ steps_taken = 0
143
+ score = 0.0
144
+ success = False
145
+
146
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
147
+
148
+ try:
149
+ with BugTriageClient(base_url=ENV_BASE_URL) as env:
150
+ obs = env.reset()
151
+ step_count = 0
152
+
153
+ while not obs.done and step_count < MAX_STEPS:
154
+ step_count += 1
155
+ task = obs.task_id
156
+
157
+ print(f"\n── Task: {task.upper()} ──", flush=True)
158
+ print(f" Bug: {obs.bug_report.title}", flush=True)
159
+
160
+ bug_text = format_bug(obs)
161
+ action = call_model(client, bug_text)
162
+
163
+ print(f" β†’ Priority: {action.priority}", flush=True)
164
+ print(f" β†’ Labels: {action.labels}", flush=True)
165
+ print(f" β†’ Team: {action.assigned_team}", flush=True)
166
+ print(f" β†’ Milestone: {action.milestone}", flush=True)
167
+
168
+ result = env.step(action)
169
+ obs = result.observation
170
+
171
+ reward = result.reward or 0.0
172
+ done = result.done
173
+ rewards.append(reward)
174
+ steps_taken = step_count
175
+
176
+ # action summary for [STEP] log
177
+ action_str = (
178
+ f"priority={action.priority},"
179
+ f"team={action.assigned_team},"
180
+ f"milestone={action.milestone}"
181
+ )
182
+
183
+ log_step(
184
+ step=step_count,
185
+ action=action_str,
186
+ reward=reward,
187
+ done=done,
188
+ error=None,
189
+ )
190
+
191
+ print(f" βœ“ Reward: {reward:.3f}", flush=True)
192
+ print(f" βœ“ Feedback: {obs.feedback}", flush=True)
193
+
194
+ time.sleep(1) # avoid rate limiting
195
+
196
+ # compute final score
197
+ score = sum(rewards) / MAX_STEPS if MAX_STEPS > 0 else 0.0
198
+ score = min(max(score, 0.0), 1.0)
199
+ success = score >= SUCCESS_SCORE_THRESHOLD
200
+
201
+ # print score table
202
+ task_order = ["easy", "medium", "hard"]
203
+ print("\n" + "=" * 50, flush=True)
204
+ print(" BASELINE SCORES", flush=True)
205
+ print("=" * 50, flush=True)
206
+ for i, task in enumerate(task_order):
207
+ r = rewards[i] if i < len(rewards) else 0.0
208
+ bar = "β–ˆ" * int(r * 20) + "β–‘" * (20 - int(r * 20))
209
+ print(f" {task:<8} {bar} {r:.3f}", flush=True)
210
+ print(f"\n Average score: {score:.3f}", flush=True)
211
+ print("=" * 50, flush=True)
212
+
213
+ except Exception as exc:
214
+ print(f"[DEBUG] Episode error: {exc}", flush=True)
215
+ success = False
216
+
217
+ finally:
218
+ log_end(
219
+ success=success,
220
+ steps=steps_taken,
221
+ score=score,
222
+ rewards=rewards,
223
+ )
224
+
225
+
226
+ if __name__ == "__main__":
227
+ main()