Arijit-07 committed on
Commit
39e59a5
·
1 Parent(s): e16d919

feat: Add fast mode with auto-fallback (INFERENCE_MODE=fast / --fast / 12s auto-switch)

Browse files
Files changed (1) hide show
  1. inference.py +94 -48
inference.py CHANGED
@@ -6,13 +6,22 @@ MANDATORY env vars:
6
  MODEL_NAME The model identifier
7
  HF_TOKEN Your Hugging Face / API key
8
 
 
 
 
 
 
 
9
  Run:
10
  API_BASE_URL=... MODEL_NAME=... HF_TOKEN=... python inference.py
 
11
  """
12
 
13
  import os
 
14
  import json
15
  import re
 
16
  import textwrap
17
  from typing import Optional
18
 
@@ -23,11 +32,16 @@ from models import Action, ActionType, Observation
23
  from graders.grader import grade_episode
24
 
25
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
26
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
27
- MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
 
 
 
 
 
28
 
29
- TEMPERATURE = 0.1
30
- MAX_TOKENS = 512
31
  FALLBACK_ACTION = Action(action_type=ActionType.NOOP, reason="parse_failure")
32
 
33
  SYSTEM_PROMPT = textwrap.dedent("""
@@ -183,13 +197,66 @@ def parse_action(response_text: str) -> Action:
183
  return FALLBACK_ACTION
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  def run_task(client: OpenAI, task_id: str, seed: int = 42) -> dict:
187
  env = DevOpsIncidentEnv(task_id=task_id, seed=seed)
188
  obs = env.reset()
189
 
190
- print(f"[START] task={task_id} seed={seed} model={MODEL_NAME}", flush=True)
 
 
 
 
191
  print(f"\n{'━'*64}")
192
- print(f" Task: {task_id.upper()} | Seed: {seed} | Model: {MODEL_NAME}")
193
  print(f"{'━'*64}")
194
 
195
  done = False
@@ -200,48 +267,20 @@ def run_task(client: OpenAI, task_id: str, seed: int = 42) -> dict:
200
  prompt = observation_to_text(obs)
201
 
202
  try:
203
- reasoning_completion = client.chat.completions.create(
204
- model=MODEL_NAME,
205
- messages=[
206
- {"role": "system", "content": REASONING_PROMPT},
207
- {"role": "user", "content": prompt},
208
- ],
209
- temperature=0.3,
210
- max_tokens=256,
211
- )
212
- reasoning = reasoning_completion.choices[0].message.content or ""
213
-
214
- action_prompt = f"""
215
- Based on your analysis:
216
- {reasoning}
217
-
218
- Now output your action as a JSON object:
219
- {{
220
- "action_type": "...",
221
- "service": "...",
222
- "query": "...",
223
- "root_cause": "...",
224
- "runbook": "...",
225
- "version": "...",
226
- "reason": "one sentence summary"
227
- }}
228
- Output ONLY the JSON object.
229
- """.strip()
230
-
231
- action_completion = client.chat.completions.create(
232
- model=MODEL_NAME,
233
- messages=[
234
- {"role": "system", "content": SYSTEM_PROMPT},
235
- {"role": "user", "content": prompt},
236
- {"role": "assistant", "content": reasoning},
237
- {"role": "user", "content": action_prompt},
238
- ],
239
- temperature=0.1,
240
- max_tokens=200,
241
- )
242
- response_text = action_completion.choices[0].message.content or ""
243
  except Exception as exc:
244
- print(f" Step {step:02d}: API error — {exc}")
245
  reasoning = "(error)"
246
  response_text = ""
247
 
@@ -299,15 +338,21 @@ def run_task(client: OpenAI, task_id: str, seed: int = 42) -> dict:
299
 
300
 
301
  def main():
 
302
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
303
 
 
 
 
 
 
304
  results = []
305
  for task_id in ["easy", "medium", "hard", "bonus"]:
306
  r = run_task(client, task_id, seed=42)
307
  results.append(r)
308
 
309
  print(f"\n{'━'*64}")
310
- print(" BASELINE SCORES")
311
  print(f"{'━'*64}")
312
  total = 0.0
313
  for r in results:
@@ -325,3 +370,4 @@ def main():
325
 
326
  if __name__ == "__main__":
327
  main()
 
 
6
  MODEL_NAME The model identifier
7
  HF_TOKEN Your Hugging Face / API key
8
 
9
+ Optional:
10
+ INFERENCE_MODE Set to 'fast' to skip Chain-of-Thought (1 call/step).
11
+ Default is 'cot' (2 calls/step, better scores).
12
+ Auto-switches to fast if any step exceeds STEP_TIMEOUT_S.
13
+ STEP_TIMEOUT_S Max seconds per CoT step before auto-switching (default 12).
14
+
15
  Run:
16
  API_BASE_URL=... MODEL_NAME=... HF_TOKEN=... python inference.py
17
+ API_BASE_URL=... MODEL_NAME=... HF_TOKEN=... python inference.py --fast
18
  """
19
 
20
  import os
21
+ import sys
22
  import json
23
  import re
24
+ import time
25
  import textwrap
26
  from typing import Optional
27
 
 
32
  from graders.grader import grade_episode
33
 
34
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
35
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "")
36
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
37
+
38
+ # Inference mode: 'cot' (default) or 'fast'
39
+ _mode_env = os.getenv("INFERENCE_MODE", "cot").lower()
40
+ FAST_MODE = _mode_env == "fast" or "--fast" in sys.argv
41
+ STEP_TIMEOUT = float(os.getenv("STEP_TIMEOUT_S", "12")) # seconds; auto-switch threshold
42
 
43
+ TEMPERATURE = 0.1
44
+ MAX_TOKENS = 512
45
  FALLBACK_ACTION = Action(action_type=ActionType.NOOP, reason="parse_failure")
46
 
47
  SYSTEM_PROMPT = textwrap.dedent("""
 
197
  return FALLBACK_ACTION
198
 
199
 
200
+ def _call_fast(client: OpenAI, prompt: str) -> tuple[str, str]:
201
+ """Single-step: one LLM call returns JSON action directly."""
202
+ completion = client.chat.completions.create(
203
+ model=MODEL_NAME,
204
+ messages=[
205
+ {"role": "system", "content": SYSTEM_PROMPT},
206
+ {"role": "user", "content": prompt},
207
+ ],
208
+ temperature=TEMPERATURE,
209
+ max_tokens=MAX_TOKENS,
210
+ )
211
+ response_text = completion.choices[0].message.content or ""
212
+ return response_text, "(fast-mode)"
213
+
214
+
215
+ def _call_cot(client: OpenAI, prompt: str) -> tuple[str, str]:
216
+ """Two-step Chain-of-Thought: reason first, then emit JSON action."""
217
+ reasoning_completion = client.chat.completions.create(
218
+ model=MODEL_NAME,
219
+ messages=[
220
+ {"role": "system", "content": REASONING_PROMPT},
221
+ {"role": "user", "content": prompt},
222
+ ],
223
+ temperature=0.3,
224
+ max_tokens=256,
225
+ )
226
+ reasoning = reasoning_completion.choices[0].message.content or ""
227
+
228
+ action_prompt = (
229
+ f"Based on your analysis:\n{reasoning}\n\n"
230
+ "Now output your action as a JSON object with fields: "
231
+ "action_type, service, query, root_cause, runbook, version, reason.\n"
232
+ "Output ONLY the JSON object."
233
+ )
234
+ action_completion = client.chat.completions.create(
235
+ model=MODEL_NAME,
236
+ messages=[
237
+ {"role": "system", "content": SYSTEM_PROMPT},
238
+ {"role": "user", "content": prompt},
239
+ {"role": "assistant", "content": reasoning},
240
+ {"role": "user", "content": action_prompt},
241
+ ],
242
+ temperature=0.1,
243
+ max_tokens=200,
244
+ )
245
+ response_text = action_completion.choices[0].message.content or ""
246
+ return response_text, reasoning
247
+
248
+
249
  def run_task(client: OpenAI, task_id: str, seed: int = 42) -> dict:
250
  env = DevOpsIncidentEnv(task_id=task_id, seed=seed)
251
  obs = env.reset()
252
 
253
+ # Respect global mode but allow per-task auto-downgrade if API is slow
254
+ use_fast = FAST_MODE
255
+ mode_label = "fast" if use_fast else "cot"
256
+
257
+ print(f"[START] task={task_id} seed={seed} model={MODEL_NAME} mode={mode_label}", flush=True)
258
  print(f"\n{'━'*64}")
259
+ print(f" Task: {task_id.upper()} | Seed: {seed} | Mode: {mode_label.upper()} | Model: {MODEL_NAME}")
260
  print(f"{'━'*64}")
261
 
262
  done = False
 
267
  prompt = observation_to_text(obs)
268
 
269
  try:
270
+ t0 = time.monotonic()
271
+ if use_fast:
272
+ response_text, reasoning = _call_fast(client, prompt)
273
+ else:
274
+ response_text, reasoning = _call_cot(client, prompt)
275
+ elapsed = time.monotonic() - t0
276
+
277
+ # Auto-switch: if CoT step exceeds threshold, go fast for remainder
278
+ if not use_fast and elapsed > STEP_TIMEOUT:
279
+ use_fast = True
280
+ print(f" ⚑ CoT took {elapsed:.1f}s > {STEP_TIMEOUT}s limit — switching to fast mode", flush=True)
281
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  except Exception as exc:
283
+ print(f" Step {step:02d}: API error — {exc}", flush=True)
284
  reasoning = "(error)"
285
  response_text = ""
286
 
 
338
 
339
 
340
  def main():
341
+ mode_label = "FAST" if FAST_MODE else "COT"
342
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
343
 
344
+ print(f"\n{'━'*64}", flush=True)
345
+ print(f" DevOps Incident Response — OpenEnv Baseline", flush=True)
346
+ print(f" Mode: {mode_label} | Timeout: {STEP_TIMEOUT}s | Model: {MODEL_NAME}", flush=True)
347
+ print(f"{'━'*64}", flush=True)
348
+
349
  results = []
350
  for task_id in ["easy", "medium", "hard", "bonus"]:
351
  r = run_task(client, task_id, seed=42)
352
  results.append(r)
353
 
354
  print(f"\n{'━'*64}")
355
+ print(f" BASELINE SCORES [{mode_label} mode]")
356
  print(f"{'━'*64}")
357
  total = 0.0
358
  for r in results:
 
370
 
371
  if __name__ == "__main__":
372
  main()
373
+