pragunk commited on
Commit
ec6e336
·
verified ·
1 Parent(s): 06fb3eb

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +42 -25
inference.py CHANGED
@@ -1,34 +1,33 @@
1
  import os
2
  import json
3
- from dotenv import load_dotenv # <-- ADDED
4
  from openai import OpenAI
5
  from adaptive_cache.env import AdaptiveCacheEnv, Action
6
 
7
- # Load variables from .env file into the environment
8
- load_dotenv() # <-- ADDED
 
 
 
 
 
 
 
9
 
10
  def run_baseline(task_level: str):
11
- print(f"\n--- Running Baseline for Task: {task_level.upper()} ---")
12
-
13
- # 1. Initialize the official OpenAI client dynamically
14
- # <-- CHANGED: Now pulls from generic LLM_ variables
15
- api_key = os.environ.get("LLM_API_KEY")
16
- base_url = os.environ.get("LLM_BASE_URL")
17
- model_name = os.environ.get("LLM_MODEL_NAME", "gpt-4o-mini")
18
-
19
- if not api_key:
20
- print("ERROR: LLM_API_KEY environment variable not set. Check your .env file.")
21
  return
22
 
 
23
  client = OpenAI(
24
- base_url=base_url,
25
- api_key=api_key
26
  )
27
 
28
  env = AdaptiveCacheEnv(task_level=task_level)
29
  obs = env.reset()
30
  done = False
31
- total_reward = 0.0
32
 
33
  system_prompt = """
34
  You are an intelligent Cache Manager.
@@ -36,12 +35,21 @@ def run_baseline(task_level: str):
36
  Respond ONLY with a JSON object matching this schema: {"evict_index": integer}
37
  """
38
 
 
 
 
 
 
 
 
39
  while not done:
 
 
 
 
40
  try:
41
- # 2. Call the dynamically configured model
42
- # <-- CHANGED: Now uses the model_name variable
43
  response = client.chat.completions.create(
44
- model=model_name,
45
  response_format={ "type": "json_object" },
46
  messages=[
47
  {"role": "system", "content": system_prompt},
@@ -53,18 +61,27 @@ def run_baseline(task_level: str):
53
  content = response.choices[0].message.content
54
  action_dict = json.loads(content)
55
  action = Action(**action_dict)
 
56
 
57
  except Exception as e:
58
- # Failsafe so the script doesn't crash on a bad JSON format
59
- print(f"LLM Parsing failed ({e}). Defaulting to slot 0.")
 
60
  action = Action(evict_index=0)
61
 
62
  obs, reward, done, info = env.step(action)
63
- total_reward += reward
64
 
65
- print(f"Episode Finished.")
66
- print(f"Total Reward: {total_reward}")
67
- print(f"Final Grader Score (Hit Rate): {info.get('score', 0.0):.2f} / 1.00")
 
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
  run_baseline("easy")
 
1
  import os
2
  import json
3
+ from dotenv import load_dotenv
4
  from openai import OpenAI
5
  from adaptive_cache.env import AdaptiveCacheEnv, Action
6
 
7
+ # Load variables from local .env file (for local testing)
8
+ load_dotenv()
9
+
10
+ # 1. STRICT COMPLIANCE: Match the pre-submission checklist exactly
11
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
12
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
13
+ HF_TOKEN = os.getenv("HF_TOKEN")
14
+
15
+ BENCHMARK = "adaptive-cache"
16
 
17
  def run_baseline(task_level: str):
18
+ if not HF_TOKEN:
19
+ print("ERROR: HF_TOKEN environment variable not set.", flush=True)
 
 
 
 
 
 
 
 
20
  return
21
 
22
+ # Pass the HF_TOKEN to the api_key parameter of the OpenAI client
23
  client = OpenAI(
24
+ base_url=API_BASE_URL,
25
+ api_key=HF_TOKEN
26
  )
27
 
28
  env = AdaptiveCacheEnv(task_level=task_level)
29
  obs = env.reset()
30
  done = False
 
31
 
32
  system_prompt = """
33
  You are an intelligent Cache Manager.
 
35
  Respond ONLY with a JSON object matching this schema: {"evict_index": integer}
36
  """
37
 
38
+ # Trackers required for the grader's END log
39
+ rewards_history = []
40
+ step_count = 0
41
+
42
+ # 2. REQUIRED LOG FORMAT: START
43
+ print(f"[START] task={task_level} env={BENCHMARK} model={MODEL_NAME}", flush=True)
44
+
45
  while not done:
46
+ step_count += 1
47
+ error_msg = "null"
48
+ action_str = ""
49
+
50
  try:
 
 
51
  response = client.chat.completions.create(
52
+ model=MODEL_NAME,
53
  response_format={ "type": "json_object" },
54
  messages=[
55
  {"role": "system", "content": system_prompt},
 
61
  content = response.choices[0].message.content
62
  action_dict = json.loads(content)
63
  action = Action(**action_dict)
64
+ action_str = str(action.evict_index)
65
 
66
  except Exception as e:
67
+ # Format error to be strictly on a single line for the grader
68
+ error_msg = str(e).replace('\n', ' ')
69
+ action_str = "0"
70
  action = Action(evict_index=0)
71
 
72
  obs, reward, done, info = env.step(action)
73
+ rewards_history.append(reward)
74
 
75
+ # 3. REQUIRED LOG FORMAT: STEP
76
+ done_str = str(done).lower()
77
+ print(f"[STEP] step={step_count} action={action_str} reward={reward:.2f} done={done_str} error={error_msg}", flush=True)
78
+
79
+ # 4. REQUIRED LOG FORMAT: END
80
+ score = info.get('score', 0.0)
81
+ success_str = str(score > 0.0).lower()
82
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
83
+
84
+ print(f"[END] success={success_str} steps={step_count} score={score:.3f} rewards={rewards_str}", flush=True)
85
 
86
  if __name__ == "__main__":
87
  run_baseline("easy")