anthonym21's picture
Initial Commit with GRPO notebook
935a6ef
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Data models for the Slipstream Governance Environment."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import Field
try:
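# Preferred import path when openenv-core is installed.
# NOTE(review): the fallback in the `except ImportError` branch below imports the
# exact same module path, so it cannot succeed if this import fails — confirm the
# intended fallback location.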
from openenv.core.env_server.types import Action, Observation, State
except ImportError: # pragma: no cover
from openenv.core.env_server.types import Action, Observation, State
class SlipstreamAction(Action):
    """Single-field action carrying the model's output for the governor.

    The only field, ``message``, is required and must contain at least one
    character; it is expected to hold a SLIP message.
    """

    message: str = Field(
        ...,
        description="Model output containing a SLIP message",
        min_length=1,
    )
class SlipstreamObservation(Observation):
    """Result reported by the governor once it has validated and scored output.

    All fields carry defaults, so an observation can be built with only the
    subset that is relevant: the task prompt for a reset, or the parsing and
    scoring details for an evaluated step.
    """

    # Reset: the prompt handed to the model.
    task_prompt: Optional[str] = Field(
        None,
        description="Prompt for the model (natural-language intent + constraints)",
    )

    # Step (evaluation): what was parsed from the output and how it scored.
    parsed_slip: Optional[str] = Field(
        None,
        description="Extracted SLIP line (normalized)",
    )
    expected_anchor: Optional[str] = Field(
        None,
        description="Scenario's expected anchor",
    )
    predicted_anchor: Optional[str] = Field(
        None,
        description="Anchor parsed from model output",
    )
    # Constrained to the closed interval [0.0, 1.0].
    arg_overlap: float = Field(
        0.0,
        description="Fraction of expected args present in output",
        ge=0.0,
        le=1.0,
    )
    # Factories give every instance its own fresh list / dict.
    violations: List[str] = Field(
        description="Rule violations detected by the governor",
        default_factory=list,
    )
    metrics: Dict[str, Any] = Field(
        description="Extra metrics for debugging / dashboards",
        default_factory=dict,
    )
class SlipstreamState(State):
    """State of the Slipstream environment.

    Holds the id of the current scenario (``None`` by default) and a flag
    recording whether the episode included a secret-injection attack.
    """

    scenario_id: Optional[int] = Field(
        None,
        description="Current scenario id",
    )
    attack: bool = Field(
        False,
        description="Whether this episode included a secret-injection attack",
    )