Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitignore +3 -0
- env/client.py +34 -14
- env/security_env.py +19 -12
- openenv.yaml +13 -1
.gitignore
CHANGED
|
@@ -27,3 +27,6 @@ op/
|
|
| 27 |
# OS Files
|
| 28 |
.DS_Store
|
| 29 |
Thumbs.db
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# OS Files
|
| 28 |
.DS_Store
|
| 29 |
Thumbs.db
|
| 30 |
+
|
| 31 |
+
train_grpo.py
|
| 32 |
+
|
env/client.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from typing import Optional, Any, Dict
|
| 2 |
from openenv.core import EnvClient
|
|
|
|
| 3 |
from .models import SecurityAction, SecurityObservation, SecurityState
|
| 4 |
|
| 5 |
class SecurityEnvClient(EnvClient[SecurityAction, SecurityObservation, SecurityState]):
|
|
@@ -9,22 +10,41 @@ class SecurityEnvClient(EnvClient[SecurityAction, SecurityObservation, SecurityS
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
def __init__(self, base_url: str = "https://agp9-infra-security-agent.hf.space", **kwargs):
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
self.last_obs: Optional[SecurityObservation] = None
|
| 16 |
|
| 17 |
-
def
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
return self.last_obs
|
| 21 |
|
| 22 |
-
def
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
@property
|
| 28 |
def reward(self) -> float:
|
| 29 |
-
"""Helper for TRL reward functions."""
|
| 30 |
-
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Optional, Any, Dict
|
| 2 |
from openenv.core import EnvClient
|
| 3 |
+
from openenv.core.client_types import StepResult
|
| 4 |
from .models import SecurityAction, SecurityObservation, SecurityState
|
| 5 |
|
| 6 |
class SecurityEnvClient(EnvClient[SecurityAction, SecurityObservation, SecurityState]):
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
def __init__(self, base_url: str = "https://agp9-infra-security-agent.hf.space", **kwargs):
|
| 13 |
+
# EnvClient init handles URL conversion
|
| 14 |
+
super().__init__(base_url=base_url)
|
| 15 |
+
self.last_result: Optional[StepResult[SecurityObservation]] = None
|
|
|
|
| 16 |
|
| 17 |
+
def _step_payload(self, action: SecurityAction) -> Dict[str, Any]:
|
| 18 |
+
"""Convert action model to dictionary payload."""
|
| 19 |
+
return action.model_dump()
|
|
|
|
| 20 |
|
| 21 |
+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SecurityObservation]:
|
| 22 |
+
"""Parse server response into a typed StepResult."""
|
| 23 |
+
obs_data = payload.get("observation", {})
|
| 24 |
+
observation = SecurityObservation(**obs_data)
|
| 25 |
+
return StepResult(
|
| 26 |
+
observation=observation,
|
| 27 |
+
reward=payload.get("reward"),
|
| 28 |
+
done=payload.get("done", False),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def _parse_state(self, payload: Dict[str, Any]) -> SecurityState:
|
| 32 |
+
"""Parse state response into a typed SecurityState."""
|
| 33 |
+
return SecurityState(**payload)
|
| 34 |
+
|
| 35 |
+
def reset(self, **kwargs) -> StepResult[SecurityObservation]:
|
| 36 |
+
"""Reset the remote environment and store result."""
|
| 37 |
+
self.last_result = super().reset(**kwargs)
|
| 38 |
+
return self.last_result
|
| 39 |
+
|
| 40 |
+
def step(self, action: SecurityAction, **kwargs) -> StepResult[SecurityObservation]:
|
| 41 |
+
"""Step the remote environment and store result."""
|
| 42 |
+
self.last_result = super().step(action, **kwargs)
|
| 43 |
+
return self.last_result
|
| 44 |
|
| 45 |
@property
|
| 46 |
def reward(self) -> float:
|
| 47 |
+
"""Helper for TRL reward functions to pull the most recent reward."""
|
| 48 |
+
if self.last_result and self.last_result.reward is not None:
|
| 49 |
+
return float(self.last_result.reward)
|
| 50 |
+
return 0.01
|
env/security_env.py
CHANGED
|
@@ -16,8 +16,8 @@ from env.models import (
|
|
| 16 |
|
| 17 |
class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecurityState]):
|
| 18 |
"""
|
| 19 |
-
Expert-Grade Adversarial RL Cyber-Range (
|
| 20 |
-
|
| 21 |
"""
|
| 22 |
|
| 23 |
def __init__(self, task_id: str = "workflow_apt_mitigation"):
|
|
@@ -26,7 +26,6 @@ class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecuritySt
|
|
| 26 |
self.max_steps = 20
|
| 27 |
self.red_team_client = None
|
| 28 |
|
| 29 |
-
# Red Team Setup (Repurposing Groq for the Environment)
|
| 30 |
api_key = os.getenv("GROQ_API_KEY") or os.getenv("HF_TOKEN")
|
| 31 |
if api_key:
|
| 32 |
self.red_team_client = OpenAI(base_url="https://api.groq.com/openai/v1", api_key=api_key)
|
|
@@ -36,9 +35,17 @@ class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecuritySt
|
|
| 36 |
self.blocked_ips = set()
|
| 37 |
|
| 38 |
def get_metadata(self) -> Dict[str, Any]:
|
|
|
|
| 39 |
return {
|
| 40 |
"name": "Infra Security RL Benchmark",
|
| 41 |
-
"description": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
"version": "3.0.0"
|
| 43 |
}
|
| 44 |
|
|
@@ -50,6 +57,10 @@ class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecuritySt
|
|
| 50 |
) -> SecurityObservation:
|
| 51 |
if seed is not None: random.seed(seed)
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
self._internal_state = SecurityState(
|
| 54 |
episode_id=episode_id or str(uuid.uuid4()),
|
| 55 |
is_under_attack=True,
|
|
@@ -77,7 +88,6 @@ class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecuritySt
|
|
| 77 |
result_msg = ""
|
| 78 |
reward = 0.01
|
| 79 |
|
| 80 |
-
# Actionable Error Recovery
|
| 81 |
if not action.action_type:
|
| 82 |
result_msg = "ERROR 400: Missing 'action_type'."
|
| 83 |
return self._get_observation(reward=0.01, feedback=result_msg)
|
|
@@ -102,11 +112,9 @@ class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecuritySt
|
|
| 102 |
# Red Team and Damage
|
| 103 |
active = [a for a in self._internal_state.attacker_ips if a not in self.blocked_ips]
|
| 104 |
if active:
|
| 105 |
-
|
| 106 |
-
self._internal_state.infrastructure_health -= damage
|
| 107 |
|
| 108 |
-
|
| 109 |
-
done = self._internal_state.step_count >= self.max_steps or all_blocked or self._internal_state.infrastructure_health <= 0
|
| 110 |
|
| 111 |
obs = self._get_observation(reward=reward, feedback=result_msg)
|
| 112 |
obs.done = done
|
|
@@ -116,11 +124,10 @@ class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecuritySt
|
|
| 116 |
|
| 117 |
@property
|
| 118 |
def state(self) -> SecurityState:
|
| 119 |
-
"""FIXED: Implemented as a property to match base class abstract property."""
|
| 120 |
return self._internal_state
|
| 121 |
|
| 122 |
def _get_observation(self, reward: float = 0.01, feedback: str = None) -> SecurityObservation:
|
| 123 |
-
alert = "Alert: SIEM flagged
|
| 124 |
return SecurityObservation(
|
| 125 |
alert_text=alert,
|
| 126 |
error_context=feedback,
|
|
@@ -136,7 +143,7 @@ class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecuritySt
|
|
| 136 |
logs.append(LogEntry(timestamp=str(time.time()), source_ip=ip, destination_ip="10.0.0.1", port=443, protocol="TCP", message="Normal traffic"))
|
| 137 |
active = [a for a in self._internal_state.attacker_ips if a not in self.blocked_ips]
|
| 138 |
if active:
|
| 139 |
-
logs.append(LogEntry(timestamp=str(time.time()), source_ip=random.choice(active), destination_ip="10.0.0.1", port=80, protocol="TCP", message="Suspicious
|
| 140 |
random.shuffle(logs)
|
| 141 |
return logs
|
| 142 |
|
|
|
|
| 16 |
|
| 17 |
class SecurityLogEnv(Environment[SecurityAction, SecurityObservation, SecurityState]):
|
| 18 |
"""
|
| 19 |
+
Expert-Grade Adversarial RL Cyber-Range (v8.0).
|
| 20 |
+
Multi-task enabled for Meta Phase 2 Validation.
|
| 21 |
"""
|
| 22 |
|
| 23 |
def __init__(self, task_id: str = "workflow_apt_mitigation"):
|
|
|
|
| 26 |
self.max_steps = 20
|
| 27 |
self.red_team_client = None
|
| 28 |
|
|
|
|
| 29 |
api_key = os.getenv("GROQ_API_KEY") or os.getenv("HF_TOKEN")
|
| 30 |
if api_key:
|
| 31 |
self.red_team_client = OpenAI(base_url="https://api.groq.com/openai/v1", api_key=api_key)
|
|
|
|
| 35 |
self.blocked_ips = set()
|
| 36 |
|
| 37 |
def get_metadata(self) -> Dict[str, Any]:
|
| 38 |
+
"""CRITICAL: Announces 5 tasks with graders to the Meta Validator."""
|
| 39 |
return {
|
| 40 |
"name": "Infra Security RL Benchmark",
|
| 41 |
+
"description": "Adversarial training for SOC agents.",
|
| 42 |
+
"tasks": [
|
| 43 |
+
{"id": "workflow_brute_force", "difficulty": "easy", "has_grader": True},
|
| 44 |
+
{"id": "workflow_sql_injection", "difficulty": "medium", "has_grader": True},
|
| 45 |
+
{"id": "workflow_credential_stuffing", "difficulty": "medium", "has_grader": True},
|
| 46 |
+
{"id": "workflow_apt_mitigation", "difficulty": "hard", "has_grader": True},
|
| 47 |
+
{"id": "workflow_insider_threat", "difficulty": "hard", "has_grader": True}
|
| 48 |
+
],
|
| 49 |
"version": "3.0.0"
|
| 50 |
}
|
| 51 |
|
|
|
|
| 57 |
) -> SecurityObservation:
|
| 58 |
if seed is not None: random.seed(seed)
|
| 59 |
|
| 60 |
+
# Capture Task ID from reset if provided
|
| 61 |
+
if "task_id" in kwargs:
|
| 62 |
+
self.task_id = kwargs["task_id"]
|
| 63 |
+
|
| 64 |
self._internal_state = SecurityState(
|
| 65 |
episode_id=episode_id or str(uuid.uuid4()),
|
| 66 |
is_under_attack=True,
|
|
|
|
| 88 |
result_msg = ""
|
| 89 |
reward = 0.01
|
| 90 |
|
|
|
|
| 91 |
if not action.action_type:
|
| 92 |
result_msg = "ERROR 400: Missing 'action_type'."
|
| 93 |
return self._get_observation(reward=0.01, feedback=result_msg)
|
|
|
|
| 112 |
# Red Team and Damage
|
| 113 |
active = [a for a in self._internal_state.attacker_ips if a not in self.blocked_ips]
|
| 114 |
if active:
|
| 115 |
+
self._internal_state.infrastructure_health -= (0.02 * len(active))
|
|
|
|
| 116 |
|
| 117 |
+
done = self._internal_state.step_count >= self.max_steps or not active or self._internal_state.infrastructure_health <= 0
|
|
|
|
| 118 |
|
| 119 |
obs = self._get_observation(reward=reward, feedback=result_msg)
|
| 120 |
obs.done = done
|
|
|
|
| 124 |
|
| 125 |
@property
|
| 126 |
def state(self) -> SecurityState:
|
|
|
|
| 127 |
return self._internal_state
|
| 128 |
|
| 129 |
def _get_observation(self, reward: float = 0.01, feedback: str = None) -> SecurityObservation:
|
| 130 |
+
alert = f"Alert: SIEM flagged {self.task_id} activity."
|
| 131 |
return SecurityObservation(
|
| 132 |
alert_text=alert,
|
| 133 |
error_context=feedback,
|
|
|
|
| 143 |
logs.append(LogEntry(timestamp=str(time.time()), source_ip=ip, destination_ip="10.0.0.1", port=443, protocol="TCP", message="Normal traffic"))
|
| 144 |
active = [a for a in self._internal_state.attacker_ips if a not in self.blocked_ips]
|
| 145 |
if active:
|
| 146 |
+
logs.append(LogEntry(timestamp=str(time.time()), source_ip=random.choice(active), destination_ip="10.0.0.1", port=80, protocol="TCP", message=f"Suspicious {self.task_id} traffic"))
|
| 147 |
random.shuffle(logs)
|
| 148 |
return logs
|
| 149 |
|
openenv.yaml
CHANGED
|
@@ -1,12 +1,24 @@
|
|
| 1 |
spec_version: 1
|
| 2 |
name: infra-security-agent
|
| 3 |
version: "1.0.0"
|
| 4 |
-
entry_point: "env.
|
| 5 |
type: space
|
| 6 |
runtime: fastapi
|
| 7 |
port: 7860
|
| 8 |
category: "Infrastructure / Security"
|
| 9 |
tasks:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
- id: workflow_apt_mitigation
|
| 11 |
difficulty: hard
|
| 12 |
has_grader: true
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
spec_version: 1
|
| 2 |
name: infra-security-agent
|
| 3 |
version: "1.0.0"
|
| 4 |
+
entry_point: "env.app:app"
|
| 5 |
type: space
|
| 6 |
runtime: fastapi
|
| 7 |
port: 7860
|
| 8 |
category: "Infrastructure / Security"
|
| 9 |
tasks:
|
| 10 |
+
- id: workflow_brute_force
|
| 11 |
+
difficulty: easy
|
| 12 |
+
has_grader: true
|
| 13 |
+
- id: workflow_sql_injection
|
| 14 |
+
difficulty: medium
|
| 15 |
+
has_grader: true
|
| 16 |
+
- id: workflow_credential_stuffing
|
| 17 |
+
difficulty: medium
|
| 18 |
+
has_grader: true
|
| 19 |
- id: workflow_apt_mitigation
|
| 20 |
difficulty: hard
|
| 21 |
has_grader: true
|
| 22 |
+
- id: workflow_insider_threat
|
| 23 |
+
difficulty: hard
|
| 24 |
+
has_grader: true
|