KChad's picture
major changes: overhauling the entire thing to have a full backend setup with schemas and grader stuff. however things still need to be changed these are temporary
37a52ed
raw
history blame
2.43 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
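Data models for the environment.

Defines the pydantic schemas ``Action``, ``Observation``, ``Reward``,
``State`` and ``EpisodeResult``.

NOTE(review): this header previously described the "Echo Environment", a
simple test environment that echoes back messages. The models below instead
cover email actions, injection flagging and exfiltration detection, so that
description looks stale -- confirm the intended environment name.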
"""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field, validator
class Action(BaseModel):
    """One action submitted by the agent.

    ``type`` selects the kind of action; every other field is an optional
    payload that defaults to ``None``. The only cross-field rule enforced
    here is that a ``"SendEmail"`` action must carry a non-empty ``to``.
    """

    # Closed set of action kinds accepted by the model.
    type: Literal["ReadDoc", "WriteDB", "SendEmail", "FlagInjection", "NoOp"]
    # Recipient; required (non-empty) only when ``type`` is "SendEmail".
    to: Optional[str] = None
    body: Optional[str] = None
    content: Optional[str] = None
    reason: Optional[str] = None

    # ``always=True`` runs the check even when ``to`` is omitted and falls
    # back to its ``None`` default, so a missing recipient is caught too.
    @validator("to", always=True)
    def validate_send_email(cls, v, values):
        """Return ``to`` unchanged, rejecting SendEmail without a recipient.

        ``values`` is the mapping of already-validated fields supplied by
        pydantic; ``type`` is looked up in it with ``.get`` so a missing
        entry simply skips the check.

        Raises:
            ValueError: if ``type`` is "SendEmail" and ``to`` is ``None`` or
                an empty string.
        """
        is_send_email = values.get("type") == "SendEmail"
        if v or not is_send_email:
            return v
        raise ValueError("SendEmail requires a 'to' address")
class Observation(BaseModel):
    """Snapshot returned to the agent after a step.

    Field meanings below are inferred from their names where the model
    itself does not pin them down -- confirm against the environment code.
    """

    # Presumably the subject and body of the email being processed.
    subject: str
    body: str

    # Previous action (if any) and a textual description of its outcome.
    last_action: Optional[str] = None
    last_action_result: str

    step_count: int
    task_difficulty: Literal["easy", "medium", "hard"]

    # Summary-related fields; off / empty by default.
    summary_required: bool = False
    current_summary: Optional[str] = None
class Reward(BaseModel):
    """Reward for a step: a scalar, a reason, and a per-component breakdown.

    ``value`` is clamped into the closed interval [-1.0, 1.0] during
    validation. ``breakdown`` values are not clamped or checked against
    ``value``.
    """

    value: float
    reason: str
    breakdown: Dict[str, float]

    @validator("value")
    def clamp_reward(cls, v):
        """Clamp ``value`` into [-1.0, 1.0], rejecting NaN.

        Raises:
            ValueError: if ``v`` is NaN.
        """
        import math

        # NaN compares false against everything, so the min/max clamp below
        # would turn it into 1.0 (``min(1.0, nan)`` returns 1.0). Reject it
        # instead of silently awarding the maximum reward.
        if math.isnan(v):
            raise ValueError("reward value must not be NaN")
        return max(-1.0, min(1.0, v))
class State(BaseModel):
    """Full environment state for an episode in progress."""

    # Loosely typed on purpose: ``objectives`` inside the email is a nested
    # dict, hence ``Any`` for the values.
    current_email: Dict[str, Any]
    # One dict per action taken so far; values may be ``None``, hence ``Any``.
    actions_taken: List[Dict[str, Any]]

    cumulative_reward: float
    step_count: int

    # Termination flag and, optionally, why the episode ended.
    done: bool
    done_reason: Optional[
        Literal["exfil_detected", "injection_flagged", "max_steps"]
    ] = None

    task_difficulty: Literal["easy", "medium", "hard"]
    current_summary: Optional[str] = None
class EpisodeResult(BaseModel):
    """Outcome of a finished episode: score, outcome flags and trajectory."""

    # Overall score and its per-component contributions.
    final_score: float
    breakdown: Dict[str, float]

    # Outcome flags for the episode.
    exfil_occurred: bool
    injection_flagged: bool
    task_completed: bool

    # Optional; defaults to ``None`` when no summary score is supplied.
    summary_score: Optional[float] = None

    total_steps: int
    # One dict per recorded step; values may be ``None``, hence ``Any``.
    trajectory: List[Dict[str, Any]]
if __name__ == "__main__":
    # Manual smoke test for the Action model: two valid actions, then one
    # that the ``to`` validator must reject.
    a = Action(type="SendEmail", to="test@test.com", body="hello")
    print(a)

    a2 = Action(type="FlagInjection", reason="injection detected")
    print(a2)

    try:
        Action(type="SendEmail")
    except ValueError as e:
        # pydantic's ValidationError derives from ValueError; catching only
        # that keeps unrelated failures from being reported as a pass.
        print(f"Correctly caught: {e}")
    else:
        print("ERROR: SendEmail without a 'to' address was not rejected")