anshumanatrey committed on
Commit
3e7a0ef
·
verified ·
1 Parent(s): db7b6af

Fix HF_TOKEN handling, [END] always emitted, add openenv tag

Browse files
Files changed (2) hide show
  1. README.md +2 -0
  2. inference.py +105 -99
README.md CHANGED
@@ -5,6 +5,8 @@ colorFrom: blue
5
  colorTo: purple
6
  sdk: docker
7
  app_port: 8000
 
 
8
  short_description: "Can your AI reason from raw evidence or just parse labels?"
9
  ---
10
 
 
5
  colorTo: purple
6
  sdk: docker
7
  app_port: 8000
8
+ tags:
9
+ - openenv
10
  short_description: "Can your AI reason from raw evidence or just parse labels?"
11
  ---
12
 
inference.py CHANGED
@@ -23,8 +23,11 @@ from openai import OpenAI
23
 
24
  # --- ENV VARS ---
25
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
26
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
27
  MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
 
 
 
 
28
 
29
  # --- CONFIG ---
30
  SCENARIO_MAX_STEPS = {"easy": 25, "medium": 35, "hard": 45}
@@ -160,108 +163,111 @@ def run_scenario(client: OpenAI, scenario_id: str, env_url: str) -> float:
160
  success = False
161
  last_error = None
162
 
163
- with SecurityAuditEnv(base_url=env_url).sync() as env:
164
- result = env.reset(scenario_id=scenario_id)
165
- observation = result.observation
166
- history: List[str] = []
167
-
168
- for step in range(1, max_steps + 1):
169
- if result.done:
170
- break
171
-
172
- prompt = build_prompt(step, observation, history, max_steps=max_steps)
173
- messages = [
174
- {"role": "system", "content": SYSTEM_PROMPT},
175
- {"role": "user", "content": prompt},
176
- ]
177
-
178
- last_error = None
179
- try:
180
- completion = client.chat.completions.create(
181
- model=MODEL_NAME,
182
- messages=messages,
183
- temperature=TEMPERATURE,
184
- max_tokens=MAX_TOKENS,
185
- stream=False,
186
- )
187
- response_text = completion.choices[0].message.content or ""
188
- except Exception as exc:
189
- last_error = str(exc)
190
- response_text = '{"action_type": "list_tools"}'
191
-
192
- action_dict = parse_action(response_text)
193
- if not action_dict:
194
- last_error = "Could not parse LLM response as JSON"
195
- action_dict = {"action_type": "list_tools"}
196
-
197
- action_type = action_dict.get("action_type", "list_tools")
198
- tool_name = action_dict.get("tool_name")
199
- arguments = action_dict.get("arguments", {})
200
-
201
- action_str = action_type
202
- if tool_name:
203
- action_str += f"({tool_name})"
204
-
205
- try:
206
- action = SecurityAuditAction(
207
- action_type=action_type,
208
- tool_name=tool_name,
209
- arguments=arguments,
210
- )
211
- result = env.step(action)
212
- observation = result.observation
213
  last_error = None
214
- except Exception as exc:
215
- last_error = str(exc)
216
- reward = 0.0
217
- all_rewards.append(reward)
218
- total_steps = step
219
- # --- MANDATORY STDOUT: [STEP] ---
220
- error_str = last_error.replace("\n", " ") if last_error else "null"
221
- print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done=false error={error_str}", flush=True)
222
- break
223
-
224
- reward = result.reward or 0.0
225
- all_rewards.append(reward)
226
- total_steps = step
227
-
228
- history.append(f"Step {step}: {action_str} → reward {reward:+.2f}")
229
-
230
- # --- MANDATORY STDOUT: [STEP] ---
231
- done_str = "true" if result.done else "false"
232
- error_str = last_error.replace("\n", " ") if last_error else "null"
233
- print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done={done_str} error={error_str}", flush=True)
234
-
235
- if result.done:
236
- grades = getattr(observation, "metadata", {}) or {}
237
- grades = grades.get("grades", {})
238
- final_score = grades.get("final_score", reward)
239
- success = final_score > 0
240
- break
241
- else:
242
- # Didn't finish — force report generation
243
- try:
244
- action = SecurityAuditAction(action_type="generate_report")
245
- result = env.step(action)
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  reward = result.reward or 0.0
247
  all_rewards.append(reward)
248
- total_steps += 1
249
-
250
- done_str = "true" if result.done else "false"
251
- print(f"[STEP] step={total_steps} action=generate_report reward={reward:.2f} done={done_str} error=null", flush=True)
252
 
253
- grades = getattr(result.observation, "metadata", {}) or {}
254
- grades = grades.get("grades", {})
255
- final_score = grades.get("final_score", 0.0)
256
- success = final_score > 0
257
- except Exception as exc:
258
- final_score = 0.0
259
- last_error = str(exc)
260
 
261
- # --- MANDATORY STDOUT: [END] ---
262
- rewards_str = ",".join(f"{r:.2f}" for r in all_rewards)
263
- success_str = "true" if success else "false"
264
- print(f"[END] success={success_str} steps={total_steps} rewards={rewards_str}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  return final_score
267
 
@@ -272,7 +278,7 @@ def main():
272
  print(f"API: {API_BASE_URL}")
273
  print(f"Model: {MODEL_NAME}")
274
 
275
- llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
276
  env_url = os.getenv("ENV_URL", "http://localhost:8000")
277
 
278
  scores = {}
 
23
 
24
  # --- ENV VARS ---
25
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
 
26
  MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
27
+ HF_TOKEN = os.getenv("HF_TOKEN")
28
+
29
+ if HF_TOKEN is None:
30
+ raise ValueError("HF_TOKEN environment variable is required")
31
 
32
  # --- CONFIG ---
33
  SCENARIO_MAX_STEPS = {"easy": 25, "medium": 35, "hard": 45}
 
163
  success = False
164
  last_error = None
165
 
166
+ try:
167
+ with SecurityAuditEnv(base_url=env_url).sync() as env:
168
+ result = env.reset(scenario_id=scenario_id)
169
+ observation = result.observation
170
+ history: List[str] = []
171
+
172
+ for step in range(1, max_steps + 1):
173
+ if result.done:
174
+ break
175
+
176
+ prompt = build_prompt(step, observation, history, max_steps=max_steps)
177
+ messages = [
178
+ {"role": "system", "content": SYSTEM_PROMPT},
179
+ {"role": "user", "content": prompt},
180
+ ]
181
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  last_error = None
183
+ try:
184
+ completion = client.chat.completions.create(
185
+ model=MODEL_NAME,
186
+ messages=messages,
187
+ temperature=TEMPERATURE,
188
+ max_tokens=MAX_TOKENS,
189
+ stream=False,
190
+ )
191
+ response_text = completion.choices[0].message.content or ""
192
+ except Exception as exc:
193
+ last_error = str(exc)
194
+ response_text = '{"action_type": "list_tools"}'
195
+
196
+ action_dict = parse_action(response_text)
197
+ if not action_dict:
198
+ last_error = "Could not parse LLM response as JSON"
199
+ action_dict = {"action_type": "list_tools"}
200
+
201
+ action_type = action_dict.get("action_type", "list_tools")
202
+ tool_name = action_dict.get("tool_name")
203
+ arguments = action_dict.get("arguments", {})
204
+
205
+ action_str = action_type
206
+ if tool_name:
207
+ action_str += f"({tool_name})"
208
+
209
+ try:
210
+ action = SecurityAuditAction(
211
+ action_type=action_type,
212
+ tool_name=tool_name,
213
+ arguments=arguments,
214
+ )
215
+ result = env.step(action)
216
+ observation = result.observation
217
+ last_error = None
218
+ except Exception as exc:
219
+ last_error = str(exc)
220
+ reward = 0.0
221
+ all_rewards.append(reward)
222
+ total_steps = step
223
+ # --- MANDATORY STDOUT: [STEP] ---
224
+ error_str = last_error.replace("\n", " ") if last_error else "null"
225
+ print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done=false error={error_str}", flush=True)
226
+ break
227
+
228
  reward = result.reward or 0.0
229
  all_rewards.append(reward)
230
+ total_steps = step
 
 
 
231
 
232
+ history.append(f"Step {step}: {action_str} → reward {reward:+.2f}")
 
 
 
 
 
 
233
 
234
+ # --- MANDATORY STDOUT: [STEP] ---
235
+ done_str = "true" if result.done else "false"
236
+ error_str = last_error.replace("\n", " ") if last_error else "null"
237
+ print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done={done_str} error={error_str}", flush=True)
238
+
239
+ if result.done:
240
+ grades = getattr(observation, "metadata", {}) or {}
241
+ grades = grades.get("grades", {})
242
+ final_score = grades.get("final_score", reward)
243
+ success = final_score > 0
244
+ break
245
+ else:
246
+ # Didn't finish — force report generation
247
+ try:
248
+ action = SecurityAuditAction(action_type="generate_report")
249
+ result = env.step(action)
250
+ reward = result.reward or 0.0
251
+ all_rewards.append(reward)
252
+ total_steps += 1
253
+
254
+ done_str = "true" if result.done else "false"
255
+ print(f"[STEP] step={total_steps} action=generate_report reward={reward:.2f} done={done_str} error=null", flush=True)
256
+
257
+ grades = getattr(result.observation, "metadata", {}) or {}
258
+ grades = grades.get("grades", {})
259
+ final_score = grades.get("final_score", 0.0)
260
+ success = final_score > 0
261
+ except Exception as exc:
262
+ final_score = 0.0
263
+ last_error = str(exc)
264
+ except Exception as exc:
265
+ last_error = str(exc)
266
+ finally:
267
+ # --- MANDATORY STDOUT: [END] (always emitted, even on exception) ---
268
+ rewards_str = ",".join(f"{r:.2f}" for r in all_rewards)
269
+ success_str = "true" if success else "false"
270
+ print(f"[END] success={success_str} steps={total_steps} rewards={rewards_str}", flush=True)
271
 
272
  return final_score
273
 
 
278
  print(f"API: {API_BASE_URL}")
279
  print(f"Model: {MODEL_NAME}")
280
 
281
+ llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
282
  env_url = os.getenv("ENV_URL", "http://localhost:8000")
283
 
284
  scores = {}