Spaces:
Sleeping
Sleeping
Fix HF_TOKEN handling, [END] always emitted, add openenv tag
Browse files- README.md +2 -0
- inference.py +105 -99
README.md
CHANGED
|
@@ -5,6 +5,8 @@ colorFrom: blue
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
app_port: 8000
|
|
|
|
|
|
|
| 8 |
short_description: "Can your AI reason from raw evidence or just parse labels?"
|
| 9 |
---
|
| 10 |
|
|
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
app_port: 8000
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
short_description: "Can your AI reason from raw evidence or just parse labels?"
|
| 11 |
---
|
| 12 |
|
inference.py
CHANGED
|
@@ -23,8 +23,11 @@ from openai import OpenAI
|
|
| 23 |
|
| 24 |
# --- ENV VARS ---
|
| 25 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 26 |
-
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
|
| 27 |
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# --- CONFIG ---
|
| 30 |
SCENARIO_MAX_STEPS = {"easy": 25, "medium": 35, "hard": 45}
|
|
@@ -160,108 +163,111 @@ def run_scenario(client: OpenAI, scenario_id: str, env_url: str) -> float:
|
|
| 160 |
success = False
|
| 161 |
last_error = None
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
try:
|
| 180 |
-
completion = client.chat.completions.create(
|
| 181 |
-
model=MODEL_NAME,
|
| 182 |
-
messages=messages,
|
| 183 |
-
temperature=TEMPERATURE,
|
| 184 |
-
max_tokens=MAX_TOKENS,
|
| 185 |
-
stream=False,
|
| 186 |
-
)
|
| 187 |
-
response_text = completion.choices[0].message.content or ""
|
| 188 |
-
except Exception as exc:
|
| 189 |
-
last_error = str(exc)
|
| 190 |
-
response_text = '{"action_type": "list_tools"}'
|
| 191 |
-
|
| 192 |
-
action_dict = parse_action(response_text)
|
| 193 |
-
if not action_dict:
|
| 194 |
-
last_error = "Could not parse LLM response as JSON"
|
| 195 |
-
action_dict = {"action_type": "list_tools"}
|
| 196 |
-
|
| 197 |
-
action_type = action_dict.get("action_type", "list_tools")
|
| 198 |
-
tool_name = action_dict.get("tool_name")
|
| 199 |
-
arguments = action_dict.get("arguments", {})
|
| 200 |
-
|
| 201 |
-
action_str = action_type
|
| 202 |
-
if tool_name:
|
| 203 |
-
action_str += f"({tool_name})"
|
| 204 |
-
|
| 205 |
-
try:
|
| 206 |
-
action = SecurityAuditAction(
|
| 207 |
-
action_type=action_type,
|
| 208 |
-
tool_name=tool_name,
|
| 209 |
-
arguments=arguments,
|
| 210 |
-
)
|
| 211 |
-
result = env.step(action)
|
| 212 |
-
observation = result.observation
|
| 213 |
last_error = None
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
reward = result.reward or 0.0
|
| 247 |
all_rewards.append(reward)
|
| 248 |
-
total_steps
|
| 249 |
-
|
| 250 |
-
done_str = "true" if result.done else "false"
|
| 251 |
-
print(f"[STEP] step={total_steps} action=generate_report reward={reward:.2f} done={done_str} error=null", flush=True)
|
| 252 |
|
| 253 |
-
|
| 254 |
-
grades = grades.get("grades", {})
|
| 255 |
-
final_score = grades.get("final_score", 0.0)
|
| 256 |
-
success = final_score > 0
|
| 257 |
-
except Exception as exc:
|
| 258 |
-
final_score = 0.0
|
| 259 |
-
last_error = str(exc)
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
return final_score
|
| 267 |
|
|
@@ -272,7 +278,7 @@ def main():
|
|
| 272 |
print(f"API: {API_BASE_URL}")
|
| 273 |
print(f"Model: {MODEL_NAME}")
|
| 274 |
|
| 275 |
-
llm_client = OpenAI(base_url=API_BASE_URL, api_key=
|
| 276 |
env_url = os.getenv("ENV_URL", "http://localhost:8000")
|
| 277 |
|
| 278 |
scores = {}
|
|
|
|
| 23 |
|
| 24 |
# --- ENV VARS ---
|
| 25 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
|
|
|
| 26 |
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
|
| 27 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 28 |
+
|
| 29 |
+
if HF_TOKEN is None:
|
| 30 |
+
raise ValueError("HF_TOKEN environment variable is required")
|
| 31 |
|
| 32 |
# --- CONFIG ---
|
| 33 |
SCENARIO_MAX_STEPS = {"easy": 25, "medium": 35, "hard": 45}
|
|
|
|
| 163 |
success = False
|
| 164 |
last_error = None
|
| 165 |
|
| 166 |
+
try:
|
| 167 |
+
with SecurityAuditEnv(base_url=env_url).sync() as env:
|
| 168 |
+
result = env.reset(scenario_id=scenario_id)
|
| 169 |
+
observation = result.observation
|
| 170 |
+
history: List[str] = []
|
| 171 |
+
|
| 172 |
+
for step in range(1, max_steps + 1):
|
| 173 |
+
if result.done:
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
prompt = build_prompt(step, observation, history, max_steps=max_steps)
|
| 177 |
+
messages = [
|
| 178 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 179 |
+
{"role": "user", "content": prompt},
|
| 180 |
+
]
|
| 181 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
last_error = None
|
| 183 |
+
try:
|
| 184 |
+
completion = client.chat.completions.create(
|
| 185 |
+
model=MODEL_NAME,
|
| 186 |
+
messages=messages,
|
| 187 |
+
temperature=TEMPERATURE,
|
| 188 |
+
max_tokens=MAX_TOKENS,
|
| 189 |
+
stream=False,
|
| 190 |
+
)
|
| 191 |
+
response_text = completion.choices[0].message.content or ""
|
| 192 |
+
except Exception as exc:
|
| 193 |
+
last_error = str(exc)
|
| 194 |
+
response_text = '{"action_type": "list_tools"}'
|
| 195 |
+
|
| 196 |
+
action_dict = parse_action(response_text)
|
| 197 |
+
if not action_dict:
|
| 198 |
+
last_error = "Could not parse LLM response as JSON"
|
| 199 |
+
action_dict = {"action_type": "list_tools"}
|
| 200 |
+
|
| 201 |
+
action_type = action_dict.get("action_type", "list_tools")
|
| 202 |
+
tool_name = action_dict.get("tool_name")
|
| 203 |
+
arguments = action_dict.get("arguments", {})
|
| 204 |
+
|
| 205 |
+
action_str = action_type
|
| 206 |
+
if tool_name:
|
| 207 |
+
action_str += f"({tool_name})"
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
action = SecurityAuditAction(
|
| 211 |
+
action_type=action_type,
|
| 212 |
+
tool_name=tool_name,
|
| 213 |
+
arguments=arguments,
|
| 214 |
+
)
|
| 215 |
+
result = env.step(action)
|
| 216 |
+
observation = result.observation
|
| 217 |
+
last_error = None
|
| 218 |
+
except Exception as exc:
|
| 219 |
+
last_error = str(exc)
|
| 220 |
+
reward = 0.0
|
| 221 |
+
all_rewards.append(reward)
|
| 222 |
+
total_steps = step
|
| 223 |
+
# --- MANDATORY STDOUT: [STEP] ---
|
| 224 |
+
error_str = last_error.replace("\n", " ") if last_error else "null"
|
| 225 |
+
print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done=false error={error_str}", flush=True)
|
| 226 |
+
break
|
| 227 |
+
|
| 228 |
reward = result.reward or 0.0
|
| 229 |
all_rewards.append(reward)
|
| 230 |
+
total_steps = step
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
+
history.append(f"Step {step}: {action_str} → reward {reward:+.2f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
+
# --- MANDATORY STDOUT: [STEP] ---
|
| 235 |
+
done_str = "true" if result.done else "false"
|
| 236 |
+
error_str = last_error.replace("\n", " ") if last_error else "null"
|
| 237 |
+
print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done={done_str} error={error_str}", flush=True)
|
| 238 |
+
|
| 239 |
+
if result.done:
|
| 240 |
+
grades = getattr(observation, "metadata", {}) or {}
|
| 241 |
+
grades = grades.get("grades", {})
|
| 242 |
+
final_score = grades.get("final_score", reward)
|
| 243 |
+
success = final_score > 0
|
| 244 |
+
break
|
| 245 |
+
else:
|
| 246 |
+
# Didn't finish — force report generation
|
| 247 |
+
try:
|
| 248 |
+
action = SecurityAuditAction(action_type="generate_report")
|
| 249 |
+
result = env.step(action)
|
| 250 |
+
reward = result.reward or 0.0
|
| 251 |
+
all_rewards.append(reward)
|
| 252 |
+
total_steps += 1
|
| 253 |
+
|
| 254 |
+
done_str = "true" if result.done else "false"
|
| 255 |
+
print(f"[STEP] step={total_steps} action=generate_report reward={reward:.2f} done={done_str} error=null", flush=True)
|
| 256 |
+
|
| 257 |
+
grades = getattr(result.observation, "metadata", {}) or {}
|
| 258 |
+
grades = grades.get("grades", {})
|
| 259 |
+
final_score = grades.get("final_score", 0.0)
|
| 260 |
+
success = final_score > 0
|
| 261 |
+
except Exception as exc:
|
| 262 |
+
final_score = 0.0
|
| 263 |
+
last_error = str(exc)
|
| 264 |
+
except Exception as exc:
|
| 265 |
+
last_error = str(exc)
|
| 266 |
+
finally:
|
| 267 |
+
# --- MANDATORY STDOUT: [END] (always emitted, even on exception) ---
|
| 268 |
+
rewards_str = ",".join(f"{r:.2f}" for r in all_rewards)
|
| 269 |
+
success_str = "true" if success else "false"
|
| 270 |
+
print(f"[END] success={success_str} steps={total_steps} rewards={rewards_str}", flush=True)
|
| 271 |
|
| 272 |
return final_score
|
| 273 |
|
|
|
|
| 278 |
print(f"API: {API_BASE_URL}")
|
| 279 |
print(f"Model: {MODEL_NAME}")
|
| 280 |
|
| 281 |
+
llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 282 |
env_url = os.getenv("ENV_URL", "http://localhost:8000")
|
| 283 |
|
| 284 |
scores = {}
|