pragunk commited on
Commit
3dd5687
·
verified ·
1 Parent(s): 91d9e5e

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +60 -13
inference.py CHANGED
@@ -1,13 +1,14 @@
1
  import os
2
  import json
 
3
  from dotenv import load_dotenv
4
  from openai import OpenAI
5
  from adaptive_cache.env import AdaptiveCacheEnv, Action
6
 
7
- # Load variables from local .env file (for local testing)
8
  load_dotenv()
9
 
10
- # 1. STRICT COMPLIANCE: Match the pre-submission checklist exactly
11
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
12
  MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
13
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -19,7 +20,6 @@ def run_baseline(task_level: str):
19
  print("ERROR: HF_TOKEN environment variable not set.", flush=True)
20
  return
21
 
22
- # Pass the HF_TOKEN to the api_key parameter of the OpenAI client
23
  client = OpenAI(
24
  base_url=API_BASE_URL,
25
  api_key=HF_TOKEN
@@ -29,17 +29,34 @@ def run_baseline(task_level: str):
29
  obs = env.reset()
30
  done = False
31
 
 
 
 
 
 
 
 
 
32
  system_prompt = """
33
- You are an intelligent Cache Manager.
34
  You must decide which cache slot index (0 to 9) to evict.
35
- Respond ONLY with a JSON object matching this schema: {"evict_index": integer}
 
 
 
 
 
 
 
 
 
 
36
  """
37
 
38
- # Trackers required for the grader's END log
39
  rewards_history = []
40
  step_count = 0
41
 
42
- # 2. REQUIRED LOG FORMAT: START
43
  print(f"[START] task={task_level} env={BENCHMARK} model={MODEL_NAME}", flush=True)
44
 
45
  while not done:
@@ -47,36 +64,66 @@ def run_baseline(task_level: str):
47
  error_msg = "null"
48
  action_str = ""
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
  response = client.chat.completions.create(
52
  model=MODEL_NAME,
53
  response_format={ "type": "json_object" },
54
  messages=[
55
  {"role": "system", "content": system_prompt},
56
- {"role": "user", "content": f"Current State: {obs.model_dump_json()}"}
57
  ],
58
  temperature=0.0
59
  )
60
 
61
  content = response.choices[0].message.content
62
  action_dict = json.loads(content)
63
- action = Action(**action_dict)
 
 
 
 
 
 
64
  action_str = str(action.evict_index)
65
 
66
  except Exception as e:
67
- # Format error to be strictly on a single line for the grader
68
  error_msg = str(e).replace('\n', ' ')
69
  action_str = "0"
70
  action = Action(evict_index=0)
71
 
72
- obs, reward, done, info = env.step(action)
 
 
 
 
 
 
 
 
 
 
 
 
73
  rewards_history.append(reward)
74
 
75
- # 3. REQUIRED LOG FORMAT: STEP
76
  done_str = str(done).lower()
77
  print(f"[STEP] step={step_count} action={action_str} reward={reward:.2f} done={done_str} error={error_msg}", flush=True)
78
 
79
- # 4. REQUIRED LOG FORMAT: END
80
  score = info.get('score', 0.0)
81
  success_str = str(score > 0.0).lower()
82
  rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
 
1
  import os
2
  import json
3
+ from collections import deque
4
  from dotenv import load_dotenv
5
  from openai import OpenAI
6
  from adaptive_cache.env import AdaptiveCacheEnv, Action
7
 
8
+ # Load variables from local .env file
9
  load_dotenv()
10
 
11
+ # STRICT COMPLIANCE: Match the pre-submission checklist exactly
12
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
13
  MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
14
  HF_TOKEN = os.getenv("HF_TOKEN")
 
20
  print("ERROR: HF_TOKEN environment variable not set.", flush=True)
21
  return
22
 
 
23
  client = OpenAI(
24
  base_url=API_BASE_URL,
25
  api_key=HF_TOKEN
 
29
  obs = env.reset()
30
  done = False
31
 
32
+ # ---------------------------------------------------------
33
+ # PHASE 2 UPGRADE: Agentic Memory Trackers
34
+ # ---------------------------------------------------------
35
+ # We keep the last 15 steps of history.
36
+ # If the sequence loop is 12 items long, 15 gives the LLM
37
+ # enough vision to realize the pattern is repeating.
38
+ history_window = deque(maxlen=15)
39
+
40
  system_prompt = """
41
+ You are an advanced OS Cache Manager with memory and pattern recognition.
42
  You must decide which cache slot index (0 to 9) to evict.
43
+
44
+ STRATEGY GUIDE:
45
+ 1. Analyze the "Recent History". Are requests looping? If yes, pin some items by refusing to evict them.
46
+ 2. Has the working set shifted entirely? If yes, aggressively evict the oldest items.
47
+ 3. Learn from your past actions: if evicting a slot led to a MISS later, protect that slot!
48
+
49
+ You MUST respond with a JSON object matching this exact schema:
50
+ {
51
+ "reasoning": "A 1-sentence analysis of the history and your strategy",
52
+ "evict_index": integer
53
+ }
54
  """
55
 
 
56
  rewards_history = []
57
  step_count = 0
58
 
59
+ # REQUIRED LOG FORMAT: START
60
  print(f"[START] task={task_level} env={BENCHMARK} model={MODEL_NAME}", flush=True)
61
 
62
  while not done:
 
64
  error_msg = "null"
65
  action_str = ""
66
 
67
+ # Format the memory for the LLM
68
+ history_str = "\n".join(history_window) if history_window else "No history yet. This is the first step."
69
+
70
+ user_prompt = f"""
71
+ --- RECENT HISTORY (Oldest to Newest) ---
72
+ {history_str}
73
+
74
+ --- CURRENT STATE ---
75
+ Current Cache State: {obs.cache_state}
76
+ Idle Times: {obs.idle_times}
77
+ Incoming Request (Needs to be cached): {obs.incoming_request}
78
+ """
79
+
80
  try:
81
  response = client.chat.completions.create(
82
  model=MODEL_NAME,
83
  response_format={ "type": "json_object" },
84
  messages=[
85
  {"role": "system", "content": system_prompt},
86
+ {"role": "user", "content": user_prompt}
87
  ],
88
  temperature=0.0
89
  )
90
 
91
  content = response.choices[0].message.content
92
  action_dict = json.loads(content)
93
+
94
+ # CRITICAL: We extract ONLY the integer and drop the reasoning
95
+ # so Pydantic doesn't throw a validation error.
96
+ # We also DO NOT print the reasoning, keeping the grader happy.
97
+ evict_idx = int(action_dict.get("evict_index", 0))
98
+
99
+ action = Action(evict_index=evict_idx)
100
  action_str = str(action.evict_index)
101
 
102
  except Exception as e:
 
103
  error_msg = str(e).replace('\n', ' ')
104
  action_str = "0"
105
  action = Action(evict_index=0)
106
 
107
+ # Step the environment
108
+ next_obs, reward, done, info = env.step(action)
109
+
110
+ # ---------------------------------------------------------
111
+ # PHASE 2 UPGRADE: Log the outcome into memory
112
+ # ---------------------------------------------------------
113
+ # We record what was requested, what the agent did, and if it worked.
114
+ result_str = "HIT (+1.0)" if reward > 0 else "MISS (-1.0)"
115
+ memory_entry = f"Step {step_count} | Req: {obs.incoming_request} | Agent Evicted Slot: {action_str} | Result: {result_str}"
116
+ history_window.append(memory_entry)
117
+
118
+ # Update observation for the next loop
119
+ obs = next_obs
120
  rewards_history.append(reward)
121
 
122
+ # REQUIRED LOG FORMAT: STEP
123
  done_str = str(done).lower()
124
  print(f"[STEP] step={step_count} action={action_str} reward={reward:.2f} done={done_str} error={error_msg}", flush=True)
125
 
126
+ # REQUIRED LOG FORMAT: END
127
  score = info.get('score', 0.0)
128
  success_str = str(score > 0.0).lower()
129
  rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)