armaan020 commited on
Commit
7f2ea18
Β·
verified Β·
1 Parent(s): a4d54e7

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. client_env.py +3 -45
  2. inference.py +11 -17
  3. server/__init__.py +0 -0
  4. server/server.py +2 -2
client_env.py CHANGED
@@ -1,22 +1,12 @@
1
- """
2
- AegisGym WebSocket client β€” concrete subclass of openenv EnvClient.
3
-
4
- Usage:
5
- client = AegisGymWsClient()
6
- sync_client = client.sync()
7
- obs = sync_client.reset()
8
- obs = sync_client.step({...})
9
- """
10
  from typing import Any, Dict
11
  from openenv.core.env_client import EnvClient
12
  from openenv.core.sync_client import SyncEnvClient
13
  from server.models import AuditAction, AuditObservation
14
 
15
- HF_SPACE_WSS = "wss://armaan020-aegisgym.hf.space"
16
-
17
 
18
  class AegisGymWsClient(EnvClient):
19
- """Concrete EnvClient implementation for the AegisGym HF Space."""
20
 
21
  def _step_payload(self, action: Dict[str, Any]) -> Dict[str, Any]:
22
  """Convert an action dict into the WS step payload."""
@@ -24,44 +14,12 @@ class AegisGymWsClient(EnvClient):
24
 
25
  def _parse_result(self, payload: Dict[str, Any]) -> Any:
26
  """Parse reset/step response from the server into usable result."""
27
- return payload # keep as dict; training code accesses .observation, .reward, .done
28
 
29
  def _parse_state(self, payload: Dict[str, Any]) -> Any:
30
  """Parse the state endpoint response."""
31
  return payload
32
 
33
-
34
  def get_sync_client(ws_url: str = HF_SPACE_WSS) -> SyncEnvClient:
35
  """Return a synchronous wrapper over the WebSocket client."""
36
  return AegisGymWsClient(base_url=ws_url).sync()
37
-
38
-
39
- def test_live_env():
40
- print("=== Testing Live AegisGym Space (wss) ===")
41
- client = get_sync_client()
42
-
43
- print("reset() ...")
44
- result = client.reset()
45
- print(f" result type: {type(result)}")
46
- print(f" keys: {list(result.keys()) if isinstance(result, dict) else dir(result)}")
47
-
48
- action = AuditAction(
49
- action_type="FLAG",
50
- target_id="ACC-BL-001",
51
- regulation_citation="EU-AI-Act-Art-57"
52
- ).model_dump()
53
-
54
- print("step(FLAG ACC-BL-001) ...")
55
- result = client.step(action)
56
- print(f" reward={result.get('reward') if isinstance(result, dict) else getattr(result,'reward',None)}")
57
- print(f" done={result.get('done') if isinstance(result, dict) else getattr(result,'done',None)}")
58
-
59
- print("state() ...")
60
- s = client.state()
61
- print(f" State: {s}")
62
-
63
- print("\n=== Live environment OK! ===")
64
-
65
-
66
- if __name__ == "__main__":
67
- test_live_env()
 
 
 
 
 
 
 
 
 
 
1
  from typing import Any, Dict
2
  from openenv.core.env_client import EnvClient
3
  from openenv.core.sync_client import SyncEnvClient
4
  from server.models import AuditAction, AuditObservation
5
 
6
+ HF_SPACE_WSS = "wss://armaan020-aegisopenenv.hf.space"
 
7
 
8
  class AegisGymWsClient(EnvClient):
9
+ """Concrete EnvClient implementation for the AegisOpenEnv HF Space."""
10
 
11
  def _step_payload(self, action: Dict[str, Any]) -> Dict[str, Any]:
12
  """Convert an action dict into the WS step payload."""
 
14
 
15
  def _parse_result(self, payload: Dict[str, Any]) -> Any:
16
  """Parse reset/step response from the server into usable result."""
17
+ return payload
18
 
19
  def _parse_state(self, payload: Dict[str, Any]) -> Any:
20
  """Parse the state endpoint response."""
21
  return payload
22
 
 
23
  def get_sync_client(ws_url: str = HF_SPACE_WSS) -> SyncEnvClient:
24
  """Return a synchronous wrapper over the WebSocket client."""
25
  return AegisGymWsClient(base_url=ws_url).sync()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference.py CHANGED
@@ -1,7 +1,3 @@
1
- """
2
- AegisGym Baseline Inference Script
3
- Requirement: Must use OpenAI client and env vars for credentials.
4
- """
5
  import os
6
  import json
7
  from openai import OpenAI
@@ -9,21 +5,14 @@ from client_env import get_sync_client
9
  from server.models import AuditAction
10
 
11
  # ─── Config (Required by Meta OpenEnv) ──────────────────────────────────────
12
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
13
- MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
14
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
15
- ENV_URL = os.getenv("ENV_URL", "https://armaan020-aegisgym.hf.space")
16
-
17
- # OpenRouter recommends specific headers for their API
18
- extra_headers = {
19
- "HTTP-Referer": "https://huggingface.co/spaces/armaan020/AegisGym",
20
- "X-Title": "AegisGym OpenEnv Submission"
21
- }
22
 
23
  client = OpenAI(
24
  api_key=OPENAI_API_KEY,
25
- base_url=API_BASE_URL,
26
- default_headers=extra_headers if "openrouter" in API_BASE_URL.lower() else None
27
  )
28
 
29
  SYSTEM_PROMPT = """You are a high-performance financial auditor AI.
@@ -48,6 +37,7 @@ def run_baseline(num_episodes=10):
48
  try:
49
  obs_payload = env.reset()
50
  obs = obs_payload.get("observation", {})
 
51
 
52
  user_msg = (
53
  f"Audit the transaction.\n"
@@ -56,7 +46,6 @@ def run_baseline(num_episodes=10):
56
  f"Account: {obs.get('account_metadata', {})}"
57
  )
58
 
59
- # OpenAI API Call
60
  response = client.chat.completions.create(
61
  model=MODEL_NAME,
62
  messages=[
@@ -68,6 +57,11 @@ def run_baseline(num_episodes=10):
68
 
69
  content = response.choices[0].message.content
70
  action_data = json.loads(content)
 
 
 
 
 
71
  print(f" Target: {action_data.get('target_id')} | Action: {action_data.get('action_type')}")
72
 
73
  result = env.step(action_data)
@@ -78,7 +72,7 @@ def run_baseline(num_episodes=10):
78
  episodes_run += 1
79
 
80
  except Exception as e:
81
- print(f" Error in episode: {e}")
82
  continue
83
 
84
  print(f"\n--- AegisGym Reproducibility Report ---")
 
 
 
 
 
1
  import os
2
  import json
3
  from openai import OpenAI
 
5
  from server.models import AuditAction
6
 
7
  # ─── Config (Required by Meta OpenEnv) ──────────────────────────────────────
8
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://openrouter.ai/api/v1")
9
+ MODEL_NAME = os.getenv("MODEL_NAME", "stepfun/step-3.5-flash:free")
10
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
11
+ ENV_URL = os.getenv("ENV_URL", "https://armaan020-aegisopenenv.hf.space")
 
 
 
 
 
 
12
 
13
  client = OpenAI(
14
  api_key=OPENAI_API_KEY,
15
+ base_url=API_BASE_URL
 
16
  )
17
 
18
  SYSTEM_PROMPT = """You are a high-performance financial auditor AI.
 
37
  try:
38
  obs_payload = env.reset()
39
  obs = obs_payload.get("observation", {})
40
+ target_id = obs.get("account_metadata", {}).get("target_id", "N/A")
41
 
42
  user_msg = (
43
  f"Audit the transaction.\n"
 
46
  f"Account: {obs.get('account_metadata', {})}"
47
  )
48
 
 
49
  response = client.chat.completions.create(
50
  model=MODEL_NAME,
51
  messages=[
 
57
 
58
  content = response.choices[0].message.content
59
  action_data = json.loads(content)
60
+
61
+ # Robust Fallbacks
62
+ if "target_id" not in action_data: action_data["target_id"] = target_id
63
+ if "action_type" not in action_data: action_data["action_type"] = "APPROVE"
64
+
65
  print(f" Target: {action_data.get('target_id')} | Action: {action_data.get('action_type')}")
66
 
67
  result = env.step(action_data)
 
72
  episodes_run += 1
73
 
74
  except Exception as e:
75
+ print(f" Error in episode {i+1}: {e}")
76
  continue
77
 
78
  print(f"\n--- AegisGym Reproducibility Report ---")
server/__init__.py CHANGED
Binary files a/server/__init__.py and b/server/__init__.py differ
 
server/server.py CHANGED
@@ -1,5 +1,5 @@
1
  import random
2
- from typing import Dict, Any, Tuple
3
  from openenv.core.env_server import Environment
4
  from .models import AuditAction, AuditObservation, AuditState
5
  from .grader import Grader
@@ -63,7 +63,7 @@ class AegisOpenEnv(Environment):
63
 
64
  def _set_next_scenario(self):
65
  tiers = ["easy", "medium", "hard"]
66
- self.current_tier = tiers[self.step_count % 3]
67
 
68
  if self.current_tier == "easy":
69
  self.current_target_id = random.choice(["ACC-BL-001", "ACC-CLEAN-01"])
 
1
  import random
2
+ from typing import Dict, Any, Tuple, List
3
  from openenv.core.env_server import Environment
4
  from .models import AuditAction, AuditObservation, AuditState
5
  from .grader import Grader
 
63
 
64
  def _set_next_scenario(self):
65
  tiers = ["easy", "medium", "hard"]
66
+ self.current_tier = tiers[self.step_count % len(tiers)]
67
 
68
  if self.current_tier == "easy":
69
  self.current_target_id = random.choice(["ACC-BL-001", "ACC-CLEAN-01"])