DataBoySu commited on
Commit ·
97fbf33
1
Parent(s): 81e1efb
deployment without ui
Browse files- .dockerignore +5 -1
- Dockerfile +62 -32
- README.md +85 -49
- client.py +45 -42
- graders/__init__.py +2 -0
- inference.py +5 -5
.dockerignore
CHANGED
|
@@ -45,4 +45,8 @@ pre-val.sh
|
|
| 45 |
# ── Misc ──────────────────────────────────────────────────────────────────────
|
| 46 |
test_redirect.py
|
| 47 |
openenv_AML_env.egg-info/
|
| 48 |
-
openenv_tracefix_rl.egg-info/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# ── Misc ──────────────────────────────────────────────────────────────────────
|
| 46 |
test_redirect.py
|
| 47 |
openenv_AML_env.egg-info/
|
| 48 |
+
openenv_tracefix_rl.egg-info/
|
| 49 |
+
.venv/
|
| 50 |
+
__pycache__/
|
| 51 |
+
*.pyc
|
| 52 |
+
.git/
|
Dockerfile
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# ============================================================
|
| 2 |
# AML Investigator — OpenEnv Environment
|
| 3 |
# Hugging Face Spaces compliant Docker image
|
|
@@ -10,55 +16,79 @@
|
|
| 10 |
# docker run -p 7860:7860 aml-env
|
| 11 |
# ============================================================
|
| 12 |
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
# git
|
| 18 |
RUN apt-get update && \
|
| 19 |
-
apt-get install -y --no-install-recommends
|
| 20 |
rm -rf /var/lib/apt/lists/*
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 26 |
-
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 27 |
-
mv /root/.local/bin/uvx /usr/local/bin/uvx
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
RUN if [ -f uv.lock ]; then \
|
| 41 |
uv sync --frozen --no-editable; \
|
| 42 |
else \
|
| 43 |
uv sync --no-editable; \
|
| 44 |
fi
|
| 45 |
|
| 46 |
-
#
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
ENV PATH="/app/.venv/bin:$PATH"
|
| 49 |
|
| 50 |
-
# PYTHONPATH → repo root
|
| 51 |
-
#
|
| 52 |
-
# from
|
| 53 |
-
|
|
|
|
| 54 |
|
| 55 |
-
# Hugging Face Spaces mandates port 7860
|
| 56 |
EXPOSE 7860
|
| 57 |
|
| 58 |
-
#
|
| 59 |
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
|
| 60 |
CMD curl -f http://localhost:7860/health || exit 1
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
# ============================================================
|
| 8 |
# AML Investigator — OpenEnv Environment
|
| 9 |
# Hugging Face Spaces compliant Docker image
|
|
|
|
| 16 |
# docker run -p 7860:7860 aml-env
|
| 17 |
# ============================================================
|
| 18 |
|
| 19 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 20 |
+
FROM ${BASE_IMAGE} AS builder
|
| 21 |
|
| 22 |
+
WORKDIR /app
|
| 23 |
+
|
| 24 |
+
# git is needed for uv to resolve any VCS dependencies
|
| 25 |
RUN apt-get update && \
|
| 26 |
+
apt-get install -y --no-install-recommends git && \
|
| 27 |
rm -rf /var/lib/apt/lists/*
|
| 28 |
|
| 29 |
+
# Copy full build context (unwanted files pruned by .dockerignore)
|
| 30 |
+
COPY . /app/env
|
| 31 |
+
WORKDIR /app/env
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
# Ensure uv is available (the openenv-base image usually has it; install as fallback)
|
| 34 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 35 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 36 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 37 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 38 |
+
fi
|
| 39 |
|
| 40 |
+
# Install deps only (no project install yet) — uses --frozen so uv.lock is honoured
|
| 41 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 42 |
+
if [ -f uv.lock ]; then \
|
| 43 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 44 |
+
else \
|
| 45 |
+
uv sync --no-install-project --no-editable; \
|
| 46 |
+
fi
|
| 47 |
|
| 48 |
+
# Install the project itself into the venv
|
| 49 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 50 |
+
if [ -f uv.lock ]; then \
|
|
|
|
| 51 |
uv sync --frozen --no-editable; \
|
| 52 |
else \
|
| 53 |
uv sync --no-editable; \
|
| 54 |
fi
|
| 55 |
|
| 56 |
+
# ── Runtime stage ─────────────────────────────────────────────────────────────
|
| 57 |
+
FROM ${BASE_IMAGE}
|
| 58 |
+
|
| 59 |
+
WORKDIR /app
|
| 60 |
+
|
| 61 |
+
# curl is required for the HEALTHCHECK; install it in the RUNTIME stage
|
| 62 |
+
RUN apt-get update && \
|
| 63 |
+
apt-get install -y --no-install-recommends curl && \
|
| 64 |
+
rm -rf /var/lib/apt/lists/*
|
| 65 |
+
|
| 66 |
+
# Copy venv and source from builder
|
| 67 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 68 |
+
COPY --from=builder /app/env /app/env
|
| 69 |
+
|
| 70 |
+
# Create unprivileged user (good practice for HF Spaces)
|
| 71 |
+
RUN useradd -m -u 1000 appuser && \
|
| 72 |
+
chown -R appuser:appuser /app
|
| 73 |
+
|
| 74 |
+
# The venv bin directory must be first on PATH
|
| 75 |
ENV PATH="/app/.venv/bin:$PATH"
|
| 76 |
|
| 77 |
+
# PYTHONPATH → /app/env (repo root inside container)
|
| 78 |
+
# This makes both import styles work:
|
| 79 |
+
# from models import AmlAction (bare)
|
| 80 |
+
# from server.AML_env_environment import … (prefixed)
|
| 81 |
+
ENV PYTHONPATH="/app/env"
|
| 82 |
|
| 83 |
+
# Hugging Face Spaces mandates port 7860
|
| 84 |
EXPOSE 7860
|
| 85 |
|
| 86 |
+
# Health check — verifiable with `docker inspect`
|
| 87 |
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
|
| 88 |
CMD curl -f http://localhost:7860/health || exit 1
|
| 89 |
|
| 90 |
+
WORKDIR /app/env
|
| 91 |
+
USER appuser
|
| 92 |
+
|
| 93 |
+
# Start the OpenEnv FastAPI server
|
| 94 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -11,9 +11,11 @@ tags:
|
|
| 11 |
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
-
#
|
| 15 |
|
| 16 |
-
A
|
|
|
|
|
|
|
| 17 |
|
| 18 |
## Quick Start
|
| 19 |
|
|
@@ -23,26 +25,33 @@ The simplest way to use the Aml Env environment is through the `AmlEnv` class:
|
|
| 23 |
from AML_env import AmlAction, AmlEnv
|
| 24 |
|
| 25 |
try:
|
| 26 |
-
# Create environment from Docker image
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# Reset
|
| 30 |
-
|
| 31 |
-
print(f"
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
finally:
|
| 44 |
-
|
| 45 |
-
AML_envenv.close()
|
| 46 |
```
|
| 47 |
|
| 48 |
That's it! The `AmlEnv.from_docker_image()` method handles:
|
|
@@ -57,7 +66,7 @@ Before using the environment, you need to build the Docker image:
|
|
| 57 |
|
| 58 |
```bash
|
| 59 |
# From project root
|
| 60 |
-
docker build -t
|
| 61 |
```
|
| 62 |
|
| 63 |
## Deploying to Hugging Face Spaces
|
|
@@ -118,23 +127,34 @@ The deployed space includes:
|
|
| 118 |
|
| 119 |
## Environment Details
|
| 120 |
|
| 121 |
-
### Action
|
| 122 |
-
**AmlAction**
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
### Reward
|
| 134 |
-
|
| 135 |
-
-
|
| 136 |
-
-
|
| 137 |
-
-
|
| 138 |
|
| 139 |
## Advanced Usage
|
| 140 |
|
|
@@ -239,17 +259,33 @@ uvicorn server.app:app --reload
|
|
| 239 |
|
| 240 |
```
|
| 241 |
AML_env/
|
| 242 |
-
├──
|
| 243 |
-
├──
|
| 244 |
-
├──
|
| 245 |
-
├──
|
| 246 |
-
├──
|
| 247 |
-
├──
|
| 248 |
-
├──
|
| 249 |
-
├──
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
```
|
|
|
|
| 11 |
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# AML Investigator Environment
|
| 15 |
|
| 16 |
+
A financial crime investigation environment for Reinforcement Learning agents.
|
| 17 |
+
The agent must query a mock banking system (transactions, KYC records) under a strict API budget
|
| 18 |
+
to investigate flagged accounts and submit a final fraud/clear decision.
|
| 19 |
|
| 20 |
## Quick Start
|
| 21 |
|
|
|
|
| 25 |
from AML_env import AmlAction, AmlEnv
|
| 26 |
|
| 27 |
try:
|
| 28 |
+
# Create environment from Docker image (built from root Dockerfile)
|
| 29 |
+
env = AmlEnv.from_docker_image("aml-env:latest")
|
| 30 |
+
|
| 31 |
+
# Reset to a specific task
|
| 32 |
+
obs = env.reset(task="aml_easy")
|
| 33 |
+
print(f"Alert: {obs.observation.alert_details}")
|
| 34 |
+
print(f"Budget: {obs.observation.budget_remaining}")
|
| 35 |
+
|
| 36 |
+
# Query transactions
|
| 37 |
+
result = env.step(AmlAction(action={
|
| 38 |
+
"action_type": "query_transactions",
|
| 39 |
+
"account_id": "ACC-9001",
|
| 40 |
+
"limit": 10,
|
| 41 |
+
"offset": 0,
|
| 42 |
+
}))
|
| 43 |
+
print(f"Transactions: {result.observation.last_action_result}")
|
| 44 |
+
|
| 45 |
+
# Submit final decision
|
| 46 |
+
result = env.step(AmlAction(action={
|
| 47 |
+
"action_type": "submit_decision",
|
| 48 |
+
"decision": "CLEAR",
|
| 49 |
+
"evidence_links": [],
|
| 50 |
+
}))
|
| 51 |
+
print(f"Done: {result.done}, Reward: {result.reward}")
|
| 52 |
|
| 53 |
finally:
|
| 54 |
+
env.close()
|
|
|
|
| 55 |
```
|
| 56 |
|
| 57 |
That's it! The `AmlEnv.from_docker_image()` method handles:
|
|
|
|
| 66 |
|
| 67 |
```bash
|
| 68 |
# From project root
|
| 69 |
+
docker build -t aml-env:latest .
|
| 70 |
```
|
| 71 |
|
| 72 |
## Deploying to Hugging Face Spaces
|
|
|
|
| 127 |
|
| 128 |
## Environment Details
|
| 129 |
|
| 130 |
+
### Action Space
|
| 131 |
+
**AmlAction** wraps one of four tool calls (discriminated by `action_type`):
|
| 132 |
+
|
| 133 |
+
| Tool | Fields | Description |
|
| 134 |
+
|---|---|---|
|
| 135 |
+
| `query_transactions` | `account_id`, `limit`, `offset` | Paginated transaction history for an account |
|
| 136 |
+
| `search_transactions` | `account_id`, `keyword` | Search memo_text of transactions |
|
| 137 |
+
| `get_kyc_record` | `entity_id` | Retrieve KYC data for an entity |
|
| 138 |
+
| `submit_decision` | `decision` (`FRAUD`\|`CLEAR`), `evidence_links` | Final verdict — ends the episode |
|
| 139 |
+
|
| 140 |
+
### Observation Space
|
| 141 |
+
**AmlObservation** is returned after every `reset()` and `step()`:
|
| 142 |
+
|
| 143 |
+
| Field | Type | Description |
|
| 144 |
+
|---|---|---|
|
| 145 |
+
| `alert_details` | `str` | The investigation mission (constant per episode) |
|
| 146 |
+
| `budget_remaining` | `int` | API calls left before forced termination |
|
| 147 |
+
| `last_action` | `str \| None` | Name of the last tool called |
|
| 148 |
+
| `last_action_result` | `Any` | Payload returned by the last tool |
|
| 149 |
+
| `error_message` | `str \| None` | Error string if the last action failed |
|
| 150 |
+
| `done` | `bool` | Whether the episode has ended |
|
| 151 |
+
| `reward` | `float` | Per-step reward signal |
|
| 152 |
|
| 153 |
### Reward
|
| 154 |
+
- **Per step:** `-0.02` (efficiency penalty discourages random looping)
|
| 155 |
+
- **Submit FRAUD (correct):** grader returns `0.4`–`1.0` depending on evidence quality
|
| 156 |
+
- **Submit CLEAR (correct false positive):** grader returns `1.0`
|
| 157 |
+
- **Budget exhausted without submission:** episode ends with accumulated negative rewards
|
| 158 |
|
| 159 |
## Advanced Usage
|
| 160 |
|
|
|
|
| 259 |
|
| 260 |
```
|
| 261 |
AML_env/
|
| 262 |
+
├── Dockerfile # Container image (root, HF Spaces compliant)
|
| 263 |
+
├── .dockerignore # Docker build exclusions
|
| 264 |
+
├── .hfignore # HF Space upload exclusions
|
| 265 |
+
├── .gitignore # Git exclusions
|
| 266 |
+
├── __init__.py # Package exports (AmlEnv, AmlAction, AmlObservation)
|
| 267 |
+
├── client.py # AmlEnv WebSocket client
|
| 268 |
+
├── models.py # Pydantic action/observation schemas
|
| 269 |
+
├── inference.py # Baseline RL agent (OpenAI client, [START]/[STEP]/[END] logs)
|
| 270 |
+
├── openenv.yaml # OpenEnv manifest (tasks, graders, port)
|
| 271 |
+
├── pyproject.toml # Project metadata and uv dependencies
|
| 272 |
+
├── uv.lock # Locked dependency graph
|
| 273 |
+
├── README.md # This file (also HF Space card)
|
| 274 |
+
├── data/
|
| 275 |
+
│ ├── entities.json # 312 KYC entity records
|
| 276 |
+
│ ├── accounts.json # 410 bank accounts
|
| 277 |
+
│ └── transactions.json # 5,079 transactions (haystack + fraud scenarios)
|
| 278 |
+
├── graders/
|
| 279 |
+
│ ├── __init__.py
|
| 280 |
+
│ ├── aml_easy.py # "The False Positive" grader
|
| 281 |
+
│ ├── aml_medium.py # "The Smurf Network" grader
|
| 282 |
+
│ └── aml_hard.py # "The Corporate Mirage" grader
|
| 283 |
+
├── server/
|
| 284 |
+
│ ├── __init__.py
|
| 285 |
+
│ ├── AML_env_environment.py # Core OpenEnv environment (reset/step/state)
|
| 286 |
+
│ ├── app.py # FastAPI server (CORS, create_app wrapper)
|
| 287 |
+
│ └── requirements.txt # Pip fallback requirements
|
| 288 |
+
└── tools/
|
| 289 |
+
├── haystack.py # Financial graph generator
|
| 290 |
+
└── tasks.json # Manual fraud scenario definitions
|
| 291 |
```
|
client.py
CHANGED
|
@@ -4,7 +4,11 @@
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
from typing import Dict
|
| 10 |
|
|
@@ -15,83 +19,82 @@ from openenv.core.env_server.types import State
|
|
| 15 |
from .models import AmlAction, AmlObservation
|
| 16 |
|
| 17 |
|
| 18 |
-
class AmlEnv(
|
| 19 |
-
EnvClient[AmlAction, AmlObservation, State]
|
| 20 |
-
):
|
| 21 |
"""
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
enabling efficient multi-step
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
>>> # Connect to a running server
|
| 30 |
-
>>> with AmlEnv(base_url="http://localhost:8000") as client:
|
| 31 |
-
... result = client.reset()
|
| 32 |
-
... print(result.observation.echoed_message)
|
| 33 |
-
...
|
| 34 |
-
... result = client.step(AmlAction(message="Hello!"))
|
| 35 |
-
... print(result.observation.echoed_message)
|
| 36 |
-
|
| 37 |
-
Example with Docker:
|
| 38 |
-
>>> # Automatically start container and connect
|
| 39 |
-
>>> client = AmlEnv.from_docker_image("AML_env-env:latest")
|
| 40 |
>>> try:
|
| 41 |
-
...
|
| 42 |
-
... result = client.step(AmlAction(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
... finally:
|
| 44 |
... client.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
"""
|
| 46 |
|
| 47 |
def _step_payload(self, action: AmlAction) -> Dict:
|
| 48 |
"""
|
| 49 |
-
|
| 50 |
|
| 51 |
Args:
|
| 52 |
-
action: AmlAction
|
| 53 |
|
| 54 |
Returns:
|
| 55 |
-
|
| 56 |
"""
|
| 57 |
-
return
|
| 58 |
-
"message": action.message,
|
| 59 |
-
}
|
| 60 |
|
| 61 |
def _parse_result(self, payload: Dict) -> StepResult[AmlObservation]:
|
| 62 |
"""
|
| 63 |
-
|
| 64 |
|
| 65 |
Args:
|
| 66 |
-
payload: JSON response
|
| 67 |
|
| 68 |
Returns:
|
| 69 |
-
StepResult
|
| 70 |
"""
|
| 71 |
obs_data = payload.get("observation", {})
|
| 72 |
observation = AmlObservation(
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
| 75 |
done=payload.get("done", False),
|
| 76 |
-
reward=payload.get("reward"),
|
| 77 |
-
metadata=obs_data.get("metadata", {}),
|
| 78 |
)
|
| 79 |
-
|
| 80 |
return StepResult(
|
| 81 |
observation=observation,
|
| 82 |
-
reward=payload.get("reward"),
|
| 83 |
done=payload.get("done", False),
|
| 84 |
)
|
| 85 |
|
| 86 |
def _parse_state(self, payload: Dict) -> State:
|
| 87 |
"""
|
| 88 |
-
|
| 89 |
|
| 90 |
Args:
|
| 91 |
-
payload: JSON response from
|
| 92 |
|
| 93 |
Returns:
|
| 94 |
-
State
|
| 95 |
"""
|
| 96 |
return State(
|
| 97 |
episode_id=payload.get("episode_id"),
|
|
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
+
"""AML Investigator Environment Client.
|
| 8 |
+
|
| 9 |
+
High-level WebSocket client that wraps the OpenEnv EnvClient base class
|
| 10 |
+
with AML-specific action/observation types.
|
| 11 |
+
"""
|
| 12 |
|
| 13 |
from typing import Dict
|
| 14 |
|
|
|
|
| 19 |
from .models import AmlAction, AmlObservation
|
| 20 |
|
| 21 |
|
| 22 |
+
class AmlEnv(EnvClient[AmlAction, AmlObservation, State]):
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
+
WebSocket client for the AML Investigator environment.
|
| 25 |
+
|
| 26 |
+
Maintains a persistent WebSocket connection to the environment server,
|
| 27 |
+
enabling efficient multi-step investigations with lower per-step latency.
|
| 28 |
+
|
| 29 |
+
Example (Docker):
|
| 30 |
+
>>> client = AmlEnv.from_docker_image("aml-env:latest")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
>>> try:
|
| 32 |
+
... obs = client.reset(task="aml_easy")
|
| 33 |
+
... result = client.step(AmlAction(action={
|
| 34 |
+
... "action_type": "query_transactions",
|
| 35 |
+
... "account_id": "ACC-9001"
|
| 36 |
+
... }))
|
| 37 |
+
... print(result.observation.last_action_result)
|
| 38 |
... finally:
|
| 39 |
... client.close()
|
| 40 |
+
|
| 41 |
+
Example (existing server):
|
| 42 |
+
>>> with AmlEnv(base_url="http://localhost:7860") as env:
|
| 43 |
+
... obs = env.reset(task="aml_easy")
|
| 44 |
+
... result = env.step(AmlAction(action={
|
| 45 |
+
... "action_type": "submit_decision",
|
| 46 |
+
... "decision": "CLEAR",
|
| 47 |
+
... "evidence_links": []
|
| 48 |
+
... }))
|
| 49 |
"""
|
| 50 |
|
| 51 |
def _step_payload(self, action: AmlAction) -> Dict:
|
| 52 |
"""
|
| 53 |
+
Serialize AmlAction to the JSON dict sent over the WebSocket.
|
| 54 |
|
| 55 |
Args:
|
| 56 |
+
action: Typed AmlAction wrapper containing the specific tool call.
|
| 57 |
|
| 58 |
Returns:
|
| 59 |
+
Dict with the nested ``action`` key the server expects.
|
| 60 |
"""
|
| 61 |
+
return action.model_dump()
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def _parse_result(self, payload: Dict) -> StepResult[AmlObservation]:
|
| 64 |
"""
|
| 65 |
+
Deserialize the server's JSON response into a typed StepResult.
|
| 66 |
|
| 67 |
Args:
|
| 68 |
+
payload: Raw JSON response dict from the server.
|
| 69 |
|
| 70 |
Returns:
|
| 71 |
+
StepResult containing an AmlObservation.
|
| 72 |
"""
|
| 73 |
obs_data = payload.get("observation", {})
|
| 74 |
observation = AmlObservation(
|
| 75 |
+
alert_details=obs_data.get("alert_details", ""),
|
| 76 |
+
budget_remaining=obs_data.get("budget_remaining", 0),
|
| 77 |
+
last_action=obs_data.get("last_action"),
|
| 78 |
+
last_action_result=obs_data.get("last_action_result"),
|
| 79 |
+
error_message=obs_data.get("error_message"),
|
| 80 |
done=payload.get("done", False),
|
| 81 |
+
reward=payload.get("reward", 0.0),
|
|
|
|
| 82 |
)
|
|
|
|
| 83 |
return StepResult(
|
| 84 |
observation=observation,
|
| 85 |
+
reward=payload.get("reward", 0.0),
|
| 86 |
done=payload.get("done", False),
|
| 87 |
)
|
| 88 |
|
| 89 |
def _parse_state(self, payload: Dict) -> State:
|
| 90 |
"""
|
| 91 |
+
Deserialize the server's /state response into a State object.
|
| 92 |
|
| 93 |
Args:
|
| 94 |
+
payload: Raw JSON response dict from the server.
|
| 95 |
|
| 96 |
Returns:
|
| 97 |
+
State with episode_id and step_count.
|
| 98 |
"""
|
| 99 |
return State(
|
| 100 |
episode_id=payload.get("episode_id"),
|
graders/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Graders package — makes graders/ a proper Python package so OpenEnv can
|
| 2 |
+
# resolve grader paths like "graders.aml_easy:grade" as module imports.
|
inference.py
CHANGED
|
@@ -15,14 +15,14 @@ from openenv.core.env_server.interfaces import Environment
|
|
| 15 |
from server.AML_env_environment import AmlEnvironment
|
| 16 |
from models import AmlAction
|
| 17 |
|
| 18 |
-
API_KEY = os.getenv("HF_TOKEN") or
|
| 19 |
-
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 20 |
-
MODEL_NAME = os.getenv("MODEL_NAME") or "
|
| 21 |
|
| 22 |
# Must match openenv.yaml EXACTLY
|
| 23 |
TASKS = ["aml_easy", "aml_medium", "aml_hard"]
|
| 24 |
BENCHMARK = "aml_investigator"
|
| 25 |
-
MAX_STEPS = 25
|
| 26 |
|
| 27 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 28 |
"""
|
|
@@ -125,7 +125,7 @@ async def main() -> None:
|
|
| 125 |
break
|
| 126 |
|
| 127 |
# Calculate a baseline score for the stdout logs (Graders handle real scoring)
|
| 128 |
-
score = sum(rewards) + 1.0 if "submit_decision" in obs.last_action else 0.0
|
| 129 |
score = min(max(score, 0.0), 1.0)
|
| 130 |
success = score > 0.5
|
| 131 |
|
|
|
|
| 15 |
from server.AML_env_environment import AmlEnvironment
|
| 16 |
from models import AmlAction
|
| 17 |
|
| 18 |
+
API_KEY = os.getenv("HF_TOKEN") or "lm-studio"
|
| 19 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1" or "http://localhost:1234/v1"
|
| 20 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "openai/gpt-oss-20b"
|
| 21 |
|
| 22 |
# Must match openenv.yaml EXACTLY
|
| 23 |
TASKS = ["aml_easy", "aml_medium", "aml_hard"]
|
| 24 |
BENCHMARK = "aml_investigator"
|
| 25 |
+
MAX_STEPS = 25
|
| 26 |
|
| 27 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 28 |
"""
|
|
|
|
| 125 |
break
|
| 126 |
|
| 127 |
# Calculate a baseline score for the stdout logs (Graders handle real scoring)
|
| 128 |
+
score = sum(rewards) + 1.0 if "submit_decision" in (obs.last_action or "") else 0.0
|
| 129 |
score = min(max(score, 0.0), 1.0)
|
| 130 |
success = score > 0.5
|
| 131 |
|