M ShreeRaj commited on
Commit
3f809e4
Β·
unverified Β·
1 Parent(s): dc27fc0

Refactor environment variable handling and client initialization

Browse files
Files changed (1) hide show
  1. inference.py +52 -47
inference.py CHANGED
@@ -28,14 +28,19 @@ def post_json(url: str, payload: dict) -> dict:
28
  except urllib.error.HTTPError as e:
29
  raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
30
 
 
 
 
31
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
32
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
33
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
34
- HF_TOKEN = os.getenv("HF_TOKEN")
 
 
35
 
36
  TASK_NAME = "schedule-optimization"
37
  BENCHMARK = "cognitive-load-manager"
38
- SUCCESS_SCORE_THRESHOLD = 0.5 # Need 50% score basically
39
  MAX_STEPS = 50
40
 
41
  def log_start(task: str, env: str, model: str) -> None:
@@ -54,16 +59,15 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
54
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
55
 
56
  def main():
57
- # OpenAI client mapping to Hugging Face router, requiring HF_TOKEN
58
- client = None
59
- if HF_TOKEN:
60
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
61
 
62
- # Initialize Environment
63
  task_id = os.getenv("CLM_LEVEL", "hard")
64
-
65
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
66
-
67
  # 1. Reset Environment
68
  try:
69
  data = post_json(f"{ENV_BASE_URL}/reset", {"task_id": task_id})
@@ -71,20 +75,20 @@ def main():
71
  log_step(step=0, action="reset", reward=0.0, done=True, error=str(e)[:50])
72
  log_end(success=False, steps=0, score=0.0, rewards=[])
73
  return
74
-
75
  session_id = data["session_id"]
76
  observation = data["observation"]
77
-
78
  done = False
79
  step = 0
80
  rewards = []
81
  history = []
82
  info = {}
83
-
84
  while not done and step < MAX_STEPS:
85
  step += 1
86
-
87
- # 2. Extract action via OpenAI interface (pointing to HF)
88
  history_str = "\n".join(history[-5:]) if history else "No previous actions."
89
  system_prompt = """
90
  You are an AI task scheduler managing cognitive load.
@@ -106,52 +110,53 @@ What is your next action JSON?
106
  """
107
  action = None
108
  error_msg = None
109
-
110
- if client:
111
- try:
112
- completion = client.chat.completions.create(
113
- model=MODEL_NAME,
114
- messages=[
115
- {"role": "system", "content": system_prompt.strip()},
116
- {"role": "user", "content": user_prompt.strip()}
117
- ],
118
- temperature=0.1,
119
- max_tokens=150
120
- )
121
- action_text = (completion.choices[0].message.content or "").strip()
122
- # strip potential code blocks if model hallucinates them
123
- if action_text.startswith("```json"): action_text = action_text[7:]
124
- if action_text.endswith("```"): action_text = action_text[:-3]
125
-
126
- start_idx = action_text.find("{")
127
- end_idx = action_text.rfind("}")
128
- if start_idx != -1 and end_idx != -1:
129
- json_str = action_text[start_idx:end_idx+1]
130
- action = json.loads(json_str)
131
- except Exception as e:
132
- error_msg = str(e)[:50]
133
-
134
- # Fallback heuristic logic if action could not be parsed
 
 
 
135
  if not action:
136
  tasks = observation.get("tasks", [])
137
  incomp = [t for t in tasks if t.get("progress", 0.0) < 1.0]
138
- if observation.get("visible_state", {}).get("fatigue_level") == "high":
139
  action = {"type": "break"}
140
  elif incomp:
141
  action = {"type": "work", "task_id": incomp[0]["id"]}
142
  else:
143
  action = {"type": "delay"}
144
 
145
- # Stringify action densely for stdout formatting
146
  action_str = json.dumps(action).replace(" ", "")
147
-
148
- # 3. Process action in Env
149
  try:
150
  step_data = post_json(f"{ENV_BASE_URL}/step", {
151
  "session_id": session_id,
152
  "action": action
153
  })
154
-
155
  observation = step_data["observation"]
156
  reward = step_data.get("reward", 0.0)
157
  done = step_data.get("done", False)
@@ -160,7 +165,7 @@ What is your next action JSON?
160
  reward = 0.0
161
  done = True
162
  error_msg = error_msg or str(e)[:50]
163
-
164
  rewards.append(reward)
165
  history.append(f"Step {step} Action: {action_str} -> Reward: {reward}")
166
  log_step(step=step, action=action_str, reward=reward, done=done, error=error_msg)
 
28
  except urllib.error.HTTPError as e:
29
  raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
30
 
31
+ # ── Environment variables ────────────────────────────────────────────────────
32
+ # API_BASE_URL and API_KEY are injected by the hackathon LiteLLM proxy.
33
+ # HF_TOKEN is kept as a fallback for local testing only.
34
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
35
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
36
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
37
+
38
+ # Prefer the hackathon-injected API_KEY; fall back to HF_TOKEN for local runs
39
+ API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "")
40
 
41
  TASK_NAME = "schedule-optimization"
42
  BENCHMARK = "cognitive-load-manager"
43
+ SUCCESS_SCORE_THRESHOLD = 0.5
44
  MAX_STEPS = 50
45
 
46
  def log_start(task: str, env: str, model: str) -> None:
 
59
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
60
 
61
  def main():
62
+ # Always initialise the OpenAI client using the proxy URL and API key.
63
+ # The hackathon validator requires ALL LLM calls to go through API_BASE_URL
64
+ # with the provided API_KEY β€” never bypass this with hardcoded credentials.
65
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
66
 
 
67
  task_id = os.getenv("CLM_LEVEL", "hard")
68
+
69
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
70
+
71
  # 1. Reset Environment
72
  try:
73
  data = post_json(f"{ENV_BASE_URL}/reset", {"task_id": task_id})
 
75
  log_step(step=0, action="reset", reward=0.0, done=True, error=str(e)[:50])
76
  log_end(success=False, steps=0, score=0.0, rewards=[])
77
  return
78
+
79
  session_id = data["session_id"]
80
  observation = data["observation"]
81
+
82
  done = False
83
  step = 0
84
  rewards = []
85
  history = []
86
  info = {}
87
+
88
  while not done and step < MAX_STEPS:
89
  step += 1
90
+
91
+ # 2. Get next action from LLM via the hackathon proxy
92
  history_str = "\n".join(history[-5:]) if history else "No previous actions."
93
  system_prompt = """
94
  You are an AI task scheduler managing cognitive load.
 
110
  """
111
  action = None
112
  error_msg = None
113
+
114
+ try:
115
+ completion = client.chat.completions.create(
116
+ model=MODEL_NAME,
117
+ messages=[
118
+ {"role": "system", "content": system_prompt.strip()},
119
+ {"role": "user", "content": user_prompt.strip()}
120
+ ],
121
+ temperature=0.1,
122
+ max_tokens=150
123
+ )
124
+ action_text = (completion.choices[0].message.content or "").strip()
125
+
126
+ # Strip accidental markdown code fences
127
+ if action_text.startswith("```json"):
128
+ action_text = action_text[7:]
129
+ if action_text.startswith("```"):
130
+ action_text = action_text[3:]
131
+ if action_text.endswith("```"):
132
+ action_text = action_text[:-3]
133
+
134
+ start_idx = action_text.find("{")
135
+ end_idx = action_text.rfind("}")
136
+ if start_idx != -1 and end_idx != -1:
137
+ action = json.loads(action_text[start_idx:end_idx + 1])
138
+ except Exception as e:
139
+ error_msg = str(e)[:50]
140
+
141
+ # Fallback heuristic only if LLM call failed / returned unparseable output
142
  if not action:
143
  tasks = observation.get("tasks", [])
144
  incomp = [t for t in tasks if t.get("progress", 0.0) < 1.0]
145
+ if observation.get("visible_state", {}).get("fatigue_level") in ("high", "medium"):
146
  action = {"type": "break"}
147
  elif incomp:
148
  action = {"type": "work", "task_id": incomp[0]["id"]}
149
  else:
150
  action = {"type": "delay"}
151
 
 
152
  action_str = json.dumps(action).replace(" ", "")
153
+
154
+ # 3. Step the environment
155
  try:
156
  step_data = post_json(f"{ENV_BASE_URL}/step", {
157
  "session_id": session_id,
158
  "action": action
159
  })
 
160
  observation = step_data["observation"]
161
  reward = step_data.get("reward", 0.0)
162
  done = step_data.get("done", False)
 
165
  reward = 0.0
166
  done = True
167
  error_msg = error_msg or str(e)[:50]
168
+
169
  rewards.append(reward)
170
  history.append(f"Step {step} Action: {action_str} -> Reward: {reward}")
171
  log_step(step=step, action=action_str, reward=reward, done=done, error=error_msg)