pragunk committed on
Commit
7790e85
·
verified ·
1 Parent(s): 0e6ed14

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +71 -62
inference.py CHANGED
@@ -1,63 +1,72 @@
1
- import os
2
- import json
3
- from openai import OpenAI
4
- from adaptive_cache.env import AdaptiveCacheEnv, Action
5
-
6
- def run_baseline(task_level: str):
7
- print(f"\n--- Running Baseline for Task: {task_level.upper()} ---")
8
-
9
- # 1. Initialize the official OpenAI client, but point it to Groq's free endpoint
10
- api_key = os.environ.get("GROQ_API_KEY")
11
- if not api_key:
12
- print("ERROR: GROQ_API_KEY environment variable not set.")
13
- return
14
-
15
- client = OpenAI(
16
- base_url="https://api.groq.com/openai/v1",
17
- api_key=api_key
18
- )
19
-
20
- env = AdaptiveCacheEnv(task_level=task_level)
21
- obs = env.reset()
22
- done = False
23
- total_reward = 0.0
24
-
25
- system_prompt = """
26
- You are an intelligent Cache Manager.
27
- You must decide which cache slot index (0 to 9) to evict.
28
- Respond ONLY with a JSON object matching this schema: {"evict_index": integer}
29
- """
30
-
31
- while not done:
32
- try:
33
- # 2. Call Groq's high-speed open source model
34
- response = client.chat.completions.create(
35
- model="llama-3.1-8b-instant", # Groq's powerful model
36
- response_format={ "type": "json_object" },
37
- messages=[
38
- {"role": "system", "content": system_prompt},
39
- {"role": "user", "content": f"Current State: {obs.model_dump_json()}"}
40
- ],
41
- temperature=0.0
42
- )
43
-
44
- content = response.choices[0].message.content
45
- action_dict = json.loads(content)
46
- action = Action(**action_dict)
47
-
48
- except Exception as e:
49
- # Failsafe so the script doesn't crash on a bad JSON format
50
- print(f"LLM Parsing failed ({e}). Defaulting to slot 0.")
51
- action = Action(evict_index=0)
52
-
53
- obs, reward, done, info = env.step(action)
54
- total_reward += reward
55
-
56
- print(f"Episode Finished.")
57
- print(f"Total Reward: {total_reward}")
58
- print(f"Final Grader Score (Hit Rate): {info.get('score', 0.0):.2f} / 1.00")
59
-
60
- if __name__ == "__main__":
61
- run_baseline("easy")
62
- run_baseline("medium")
 
 
 
 
 
 
 
 
 
63
  run_baseline("hard")
 
1
+ import os
2
+ import json
3
+ from dotenv import load_dotenv # <-- ADDED
4
+ from openai import OpenAI
5
+ from adaptive_cache.env import AdaptiveCacheEnv, Action
6
+
7
+ # Load variables from .env file into the environment
8
+ load_dotenv() # <-- ADDED
9
+
10
def _resolve_llm_settings():
    """Read the LLM connection settings from the process environment.

    Returns:
        A ``(api_key, base_url, model_name)`` tuple built from the
        ``LLM_API_KEY``, ``LLM_BASE_URL`` and ``LLM_MODEL_NAME`` variables.
        Blank or whitespace-only values count as unset, so ``api_key`` may
        be ``""``, ``base_url`` may be ``None`` and ``model_name`` falls
        back to ``"gpt-4o-mini"``.
    """
    # `.env` files often carry `KEY=` placeholders; os.environ.get() would
    # hand those back as "" and silently skip the intended defaults.
    api_key = (os.environ.get("LLM_API_KEY") or "").strip()
    base_url = (os.environ.get("LLM_BASE_URL") or "").strip() or None
    model_name = (os.environ.get("LLM_MODEL_NAME") or "").strip() or "gpt-4o-mini"
    return api_key, base_url, model_name


def run_baseline(task_level: str):
    """Run one LLM-driven episode of ``AdaptiveCacheEnv`` and print the result.

    Each step sends the current observation (serialized with
    ``model_dump_json()``) to the configured chat model, parses the JSON
    reply into an ``Action`` and feeds it to ``env.step``. If the request or
    the parsing fails for any reason, the step falls back to evicting slot 0
    so the episode can still finish.

    Args:
        task_level: Task level passed to ``AdaptiveCacheEnv``; also shown
            upper-cased in the banner line.

    Returns:
        None. Progress and results are printed. The function returns early
        (after printing an error) when ``LLM_API_KEY`` is missing or blank.
    """
    print(f"\n--- Running Baseline for Task: {task_level.upper()} ---")

    # 1. Initialize the official OpenAI client from the generic LLM_ variables.
    api_key, base_url, model_name = _resolve_llm_settings()

    if not api_key:
        print("ERROR: LLM_API_KEY environment variable not set. Check your .env file.")
        return

    # base_url=None leaves endpoint selection to the client library's default.
    client = OpenAI(
        base_url=base_url,
        api_key=api_key,
    )

    env = AdaptiveCacheEnv(task_level=task_level)
    obs = env.reset()
    done = False
    total_reward = 0.0

    system_prompt = """
    You are an intelligent Cache Manager.
    You must decide which cache slot index (0 to 9) to evict.
    Respond ONLY with a JSON object matching this schema: {"evict_index": integer}
    """

    while not done:
        try:
            # 2. Call the dynamically configured model.
            response = client.chat.completions.create(
                model=model_name,
                response_format={"type": "json_object"},
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Current State: {obs.model_dump_json()}"},
                ],
                temperature=0.0,
            )

            content = response.choices[0].message.content
            action_dict = json.loads(content)
            action = Action(**action_dict)

        except Exception as e:
            # Deliberate best-effort boundary: any request, JSON or Action
            # construction failure falls back to slot 0 instead of aborting
            # the episode.
            print(f"LLM Parsing failed ({e}). Defaulting to slot 0.")
            action = Action(evict_index=0)

        obs, reward, done, info = env.step(action)
        total_reward += reward

    print("Episode Finished.")
    print(f"Total Reward: {total_reward}")
    # `info` is the value returned by the last env.step() call; a missing
    # 'score' entry is reported as 0.0.
    print(f"Final Grader Score (Hit Rate): {info.get('score', 0.0):.2f} / 1.00")
68
+
69
if __name__ == "__main__":
    # Script entry point: run the baseline once per task level, in order.
    for level in ("easy", "medium", "hard"):
        run_baseline(level)