Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- client_env.py +3 -45
- inference.py +11 -17
- server/__init__.py +0 -0
- server/server.py +2 -2
client_env.py
CHANGED
|
@@ -1,22 +1,12 @@
|
|
| 1 |
-
"""
|
| 2 |
-
AegisGym WebSocket client β concrete subclass of openenv EnvClient.
|
| 3 |
-
|
| 4 |
-
Usage:
|
| 5 |
-
client = AegisGymWsClient()
|
| 6 |
-
sync_client = client.sync()
|
| 7 |
-
obs = sync_client.reset()
|
| 8 |
-
obs = sync_client.step({...})
|
| 9 |
-
"""
|
| 10 |
from typing import Any, Dict
|
| 11 |
from openenv.core.env_client import EnvClient
|
| 12 |
from openenv.core.sync_client import SyncEnvClient
|
| 13 |
from server.models import AuditAction, AuditObservation
|
| 14 |
|
| 15 |
-
HF_SPACE_WSS = "wss://armaan020-
|
| 16 |
-
|
| 17 |
|
| 18 |
class AegisGymWsClient(EnvClient):
|
| 19 |
-
"""Concrete EnvClient implementation for the
|
| 20 |
|
| 21 |
def _step_payload(self, action: Dict[str, Any]) -> Dict[str, Any]:
|
| 22 |
"""Convert an action dict into the WS step payload."""
|
|
@@ -24,44 +14,12 @@ class AegisGymWsClient(EnvClient):
|
|
| 24 |
|
| 25 |
def _parse_result(self, payload: Dict[str, Any]) -> Any:
|
| 26 |
"""Parse reset/step response from the server into usable result."""
|
| 27 |
-
return payload
|
| 28 |
|
| 29 |
def _parse_state(self, payload: Dict[str, Any]) -> Any:
|
| 30 |
"""Parse the state endpoint response."""
|
| 31 |
return payload
|
| 32 |
|
| 33 |
-
|
| 34 |
def get_sync_client(ws_url: str = HF_SPACE_WSS) -> SyncEnvClient:
|
| 35 |
"""Return a synchronous wrapper over the WebSocket client."""
|
| 36 |
return AegisGymWsClient(base_url=ws_url).sync()
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def test_live_env():
|
| 40 |
-
print("=== Testing Live AegisGym Space (wss) ===")
|
| 41 |
-
client = get_sync_client()
|
| 42 |
-
|
| 43 |
-
print("reset() ...")
|
| 44 |
-
result = client.reset()
|
| 45 |
-
print(f" result type: {type(result)}")
|
| 46 |
-
print(f" keys: {list(result.keys()) if isinstance(result, dict) else dir(result)}")
|
| 47 |
-
|
| 48 |
-
action = AuditAction(
|
| 49 |
-
action_type="FLAG",
|
| 50 |
-
target_id="ACC-BL-001",
|
| 51 |
-
regulation_citation="EU-AI-Act-Art-57"
|
| 52 |
-
).model_dump()
|
| 53 |
-
|
| 54 |
-
print("step(FLAG ACC-BL-001) ...")
|
| 55 |
-
result = client.step(action)
|
| 56 |
-
print(f" reward={result.get('reward') if isinstance(result, dict) else getattr(result,'reward',None)}")
|
| 57 |
-
print(f" done={result.get('done') if isinstance(result, dict) else getattr(result,'done',None)}")
|
| 58 |
-
|
| 59 |
-
print("state() ...")
|
| 60 |
-
s = client.state()
|
| 61 |
-
print(f" State: {s}")
|
| 62 |
-
|
| 63 |
-
print("\n=== Live environment OK! ===")
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
if __name__ == "__main__":
|
| 67 |
-
test_live_env()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Any, Dict
|
| 2 |
from openenv.core.env_client import EnvClient
|
| 3 |
from openenv.core.sync_client import SyncEnvClient
|
| 4 |
from server.models import AuditAction, AuditObservation
|
| 5 |
|
| 6 |
+
HF_SPACE_WSS = "wss://armaan020-aegisopenenv.hf.space"
|
|
|
|
| 7 |
|
| 8 |
class AegisGymWsClient(EnvClient):
|
| 9 |
+
"""Concrete EnvClient implementation for the AegisOpenEnv HF Space."""
|
| 10 |
|
| 11 |
def _step_payload(self, action: Dict[str, Any]) -> Dict[str, Any]:
|
| 12 |
"""Convert an action dict into the WS step payload."""
|
|
|
|
| 14 |
|
| 15 |
def _parse_result(self, payload: Dict[str, Any]) -> Any:
|
| 16 |
"""Parse reset/step response from the server into usable result."""
|
| 17 |
+
return payload
|
| 18 |
|
| 19 |
def _parse_state(self, payload: Dict[str, Any]) -> Any:
|
| 20 |
"""Parse the state endpoint response."""
|
| 21 |
return payload
|
| 22 |
|
|
|
|
| 23 |
def get_sync_client(ws_url: str = HF_SPACE_WSS) -> SyncEnvClient:
|
| 24 |
"""Return a synchronous wrapper over the WebSocket client."""
|
| 25 |
return AegisGymWsClient(base_url=ws_url).sync()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference.py
CHANGED
|
@@ -1,7 +1,3 @@
|
|
| 1 |
-
"""
|
| 2 |
-
AegisGym Baseline Inference Script
|
| 3 |
-
Requirement: Must use OpenAI client and env vars for credentials.
|
| 4 |
-
"""
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
from openai import OpenAI
|
|
@@ -9,21 +5,14 @@ from client_env import get_sync_client
|
|
| 9 |
from server.models import AuditAction
|
| 10 |
|
| 11 |
# βββ Config (Required by Meta OpenEnv) ββββββββββββββββββββββββββββββββββββββ
|
| 12 |
-
API_BASE_URL = os.getenv("API_BASE_URL", "https://
|
| 13 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "
|
| 14 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
| 15 |
-
ENV_URL = os.getenv("ENV_URL", "https://armaan020-
|
| 16 |
-
|
| 17 |
-
# OpenRouter recommends specific headers for their API
|
| 18 |
-
extra_headers = {
|
| 19 |
-
"HTTP-Referer": "https://huggingface.co/spaces/armaan020/AegisGym",
|
| 20 |
-
"X-Title": "AegisGym OpenEnv Submission"
|
| 21 |
-
}
|
| 22 |
|
| 23 |
client = OpenAI(
|
| 24 |
api_key=OPENAI_API_KEY,
|
| 25 |
-
base_url=API_BASE_URL
|
| 26 |
-
default_headers=extra_headers if "openrouter" in API_BASE_URL.lower() else None
|
| 27 |
)
|
| 28 |
|
| 29 |
SYSTEM_PROMPT = """You are a high-performance financial auditor AI.
|
|
@@ -48,6 +37,7 @@ def run_baseline(num_episodes=10):
|
|
| 48 |
try:
|
| 49 |
obs_payload = env.reset()
|
| 50 |
obs = obs_payload.get("observation", {})
|
|
|
|
| 51 |
|
| 52 |
user_msg = (
|
| 53 |
f"Audit the transaction.\n"
|
|
@@ -56,7 +46,6 @@ def run_baseline(num_episodes=10):
|
|
| 56 |
f"Account: {obs.get('account_metadata', {})}"
|
| 57 |
)
|
| 58 |
|
| 59 |
-
# OpenAI API Call
|
| 60 |
response = client.chat.completions.create(
|
| 61 |
model=MODEL_NAME,
|
| 62 |
messages=[
|
|
@@ -68,6 +57,11 @@ def run_baseline(num_episodes=10):
|
|
| 68 |
|
| 69 |
content = response.choices[0].message.content
|
| 70 |
action_data = json.loads(content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
print(f" Target: {action_data.get('target_id')} | Action: {action_data.get('action_type')}")
|
| 72 |
|
| 73 |
result = env.step(action_data)
|
|
@@ -78,7 +72,7 @@ def run_baseline(num_episodes=10):
|
|
| 78 |
episodes_run += 1
|
| 79 |
|
| 80 |
except Exception as e:
|
| 81 |
-
print(f" Error in episode: {e}")
|
| 82 |
continue
|
| 83 |
|
| 84 |
print(f"\n--- AegisGym Reproducibility Report ---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
from openai import OpenAI
|
|
|
|
| 5 |
from server.models import AuditAction
|
| 6 |
|
| 7 |
# βββ Config (Required by Meta OpenEnv) ββββββββββββββββββββββββββββββββββββββ
|
| 8 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://openrouter.ai/api/v1")
|
| 9 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "stepfun/step-3.5-flash:free")
|
| 10 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
| 11 |
+
ENV_URL = os.getenv("ENV_URL", "https://armaan020-aegisopenenv.hf.space")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
client = OpenAI(
|
| 14 |
api_key=OPENAI_API_KEY,
|
| 15 |
+
base_url=API_BASE_URL
|
|
|
|
| 16 |
)
|
| 17 |
|
| 18 |
SYSTEM_PROMPT = """You are a high-performance financial auditor AI.
|
|
|
|
| 37 |
try:
|
| 38 |
obs_payload = env.reset()
|
| 39 |
obs = obs_payload.get("observation", {})
|
| 40 |
+
target_id = obs.get("account_metadata", {}).get("target_id", "N/A")
|
| 41 |
|
| 42 |
user_msg = (
|
| 43 |
f"Audit the transaction.\n"
|
|
|
|
| 46 |
f"Account: {obs.get('account_metadata', {})}"
|
| 47 |
)
|
| 48 |
|
|
|
|
| 49 |
response = client.chat.completions.create(
|
| 50 |
model=MODEL_NAME,
|
| 51 |
messages=[
|
|
|
|
| 57 |
|
| 58 |
content = response.choices[0].message.content
|
| 59 |
action_data = json.loads(content)
|
| 60 |
+
|
| 61 |
+
# Robust Fallbacks
|
| 62 |
+
if "target_id" not in action_data: action_data["target_id"] = target_id
|
| 63 |
+
if "action_type" not in action_data: action_data["action_type"] = "APPROVE"
|
| 64 |
+
|
| 65 |
print(f" Target: {action_data.get('target_id')} | Action: {action_data.get('action_type')}")
|
| 66 |
|
| 67 |
result = env.step(action_data)
|
|
|
|
| 72 |
episodes_run += 1
|
| 73 |
|
| 74 |
except Exception as e:
|
| 75 |
+
print(f" Error in episode {i+1}: {e}")
|
| 76 |
continue
|
| 77 |
|
| 78 |
print(f"\n--- AegisGym Reproducibility Report ---")
|
server/__init__.py
CHANGED
|
Binary files a/server/__init__.py and b/server/__init__.py differ
|
|
|
server/server.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import random
|
| 2 |
-
from typing import Dict, Any, Tuple
|
| 3 |
from openenv.core.env_server import Environment
|
| 4 |
from .models import AuditAction, AuditObservation, AuditState
|
| 5 |
from .grader import Grader
|
|
@@ -63,7 +63,7 @@ class AegisOpenEnv(Environment):
|
|
| 63 |
|
| 64 |
def _set_next_scenario(self):
|
| 65 |
tiers = ["easy", "medium", "hard"]
|
| 66 |
-
self.current_tier = tiers[self.step_count %
|
| 67 |
|
| 68 |
if self.current_tier == "easy":
|
| 69 |
self.current_target_id = random.choice(["ACC-BL-001", "ACC-CLEAN-01"])
|
|
|
|
| 1 |
import random
|
| 2 |
+
from typing import Dict, Any, Tuple, List
|
| 3 |
from openenv.core.env_server import Environment
|
| 4 |
from .models import AuditAction, AuditObservation, AuditState
|
| 5 |
from .grader import Grader
|
|
|
|
| 63 |
|
| 64 |
def _set_next_scenario(self):
|
| 65 |
tiers = ["easy", "medium", "hard"]
|
| 66 |
+
self.current_tier = tiers[self.step_count % len(tiers)]
|
| 67 |
|
| 68 |
if self.current_tier == "easy":
|
| 69 |
self.current_target_id = random.choice(["ACC-BL-001", "ACC-CLEAN-01"])
|