meta_ai_hackathon / models.py
GOOD CAT
Final submission prep
ec8c511
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Prefer the standard OpenEnv base types when openenv-core is installed.
try:
    from openenv.core.env_server.types import Action, Observation
except ImportError:
    # openenv-core is unavailable: substitute empty pydantic bases so the
    # firewall subclasses below can still be defined. These fallbacks declare
    # no fields of their own; the real openenv classes may declare some.
    class Action(BaseModel):
        """Fallback action base, used only when openenv-core is missing."""

    class Observation(BaseModel):
        """Fallback observation base, used only when openenv-core is missing."""
# --- Firewall-specific Action/Observation subclasses (following the pattern shown in the reference video) ---
class FirewallAction(Action):
    """Action for the AI Firewall environment.

    Attributes:
        action: Index of the firewall decision to apply:
            0=ALLOW, 1=BLOCK, 2=INSPECT, 3=SANDBOX, 4=RATE_LIMIT, 5=QUARANTINE.
            Indices outside 0-5 are rejected during validation.
        session_id: Optional ID of a specific session to act upon; None
            when omitted (how None is interpreted is up to the server).
    """

    # Bounds mirror the six actions listed in the description, so a bad
    # index fails validation here instead of reaching the environment.
    action: int = Field(
        ...,
        ge=0,
        le=5,
        description="Action index: 0=ALLOW, 1=BLOCK, 2=INSPECT, 3=SANDBOX, 4=RATE_LIMIT, 5=QUARANTINE",
    )
    session_id: Optional[str] = Field(None, description="Specific session to act upon")
class FirewallObservation(Observation):
    """Observation emitted by the AI Firewall environment.

    Attributes:
        features: Normalized feature vector, documented as 22-dimensional
            (the length itself is not validated by this model).
        focus_session_id: ID of the session currently in focus, or None.
    """

    features: List[float] = Field(
        ...,
        description="22-dimensional normalized feature vector",
    )
    focus_session_id: Optional[str] = Field(
        default=None,
        description="ID of the session currently in focus",
    )
# --- Original models from env/models.py ---
class ActionRecord(BaseModel):
    """Record of one action applied to a session and the reward it earned."""

    # When and on which session the action was taken.
    tick: int
    session_id: str
    # The action as a numeric index and as a readable name.
    action: int
    action_name: str
    # Whether the session was malicious, and the resulting reward.
    malicious: bool
    reward: float
    # Named reward components (presumably a breakdown of `reward` — TODO confirm).
    components: Dict[str, float]
class ResetRequest(BaseModel):
    """Request body for an environment reset.

    Both fields are optional: ``task`` defaults to "easy" and ``seed``
    defaults to None.
    """

    task: str = Field(
        "easy",
        description="Task difficulty: easy, medium, hard",
    )
    seed: Optional[int] = Field(
        None,
        description="Random seed for reproducibility",
    )
class StepRequest(BaseModel):
    """Request body for a batched step: one action index per session ID."""

    # default_factory gives every request its own fresh empty dict.
    actions: Dict[str, int] = Field(
        default_factory=dict,
        description="Map of session_id to action index",
    )
class StepSingleRequest(BaseModel):
    """Request body for stepping the environment with a single action.

    Attributes:
        action: Action index applied to the current focus session. Must be
            in 0-5; other values are rejected during validation.
    """

    # Same 0-5 range as the six firewall actions; reject anything else early
    # rather than letting an invalid index reach the environment.
    action: int = Field(
        ...,
        ge=0,
        le=5,
        description="Action index (0-5) for the current focus session",
    )
class ToolRequest(BaseModel):
    """Request body carrying keyword arguments for a tool call."""

    # Arbitrary keyword arguments; a fresh empty dict when omitted.
    kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Arguments for the tool call",
    )
class StateResponse(BaseModel):
    """Snapshot of the environment's current state.

    ``focus_session_id`` is typed Optional and may therefore be None.
    """

    # Episode bookkeeping.
    episode_id: int
    task: str
    step_count: int
    current_tick: int
    # Sizes of the observation vector and the action set.
    observation_dim: int
    num_actions: int
    # Running budget and reward totals.
    budget_remaining: float
    total_reward: float
    # Session queues: counts first, then the matching ID lists.
    pending_session_count: int
    inspected_session_count: int
    pending_session_ids: List[str]
    inspected_session_ids: List[str]
    queue_length: int
    # The focused session and its observation vector.
    focus_session_id: Optional[str]
    focus_observation: List[float]
class StepResponse(BaseModel):
    """Result of a batched step: reward, termination flag, new state, extras."""

    reward: float
    done: bool
    # Full environment snapshot after the step.
    state: StateResponse
    # Free-form diagnostic information.
    info: Dict[str, Any]
class StepSingleResponse(BaseModel):
    """Result of a single-action step, including the next observation vector."""

    # Observation vector returned alongside the step outcome.
    observation: List[float]
    reward: float
    done: bool
    # Full environment snapshot after the step.
    state: StateResponse
    # Free-form diagnostic information.
    info: Dict[str, Any]
class EvaluateSessionResponse(BaseModel):
    """Details about a single session: raw features, observation, and status."""

    session_id: str
    # Named raw features and the numeric observation vector.
    features: Dict[str, Any]
    observation: List[float]
    # Inspection status; revealed_malicious may be None (presumably until
    # the session has been inspected — TODO confirm against the server).
    is_inspected: bool
    revealed_malicious: Optional[bool]
    expires_tick: int
class NetworkStatsResponse(BaseModel):
    """Aggregate statistics for the current episode."""

    # Episode identity and progress.
    episode_id: int
    task: str
    tick: int
    step_count: int
    # Reward and budget.
    total_reward: float
    budget_remaining: float
    budget_used_pct: float
    # Session population counts.
    total_malicious: int
    total_benign: int
    # Derived performance metrics.
    detection_rate: float
    false_positive_rate: float
    efficiency: float
    early_detection_bonus: float
    cascade_prevention: float
    # Outcome counters.
    correct_allows: int
    inspections: int
    expired_malicious: int
    expired_benign: int
class HealthResponse(BaseModel):
    """Health-check payload: a status string plus the service version."""

    status: str
    version: str
class ToolsListResponse(BaseModel):
    """List of available tool names."""

    tools: List[str]
class TakeActionResponse(BaseModel):
    """Outcome of taking one action: its reward and the full action record."""

    reward: float
    # Detailed record of the action that produced this reward.
    record: ActionRecord
class LLMChatRequest(BaseModel):
    """Request body for sending a prompt to an LLM.

    Attributes:
        prompt: Text sent to the model.
        api_key: Optional API key. It is excluded from the model's repr/str
            so that logging the request does not leak the secret; it is
            still validated, stored, and serialized normally.
        base_url: Optional endpoint override.
        model: Optional model name override.

    When the optional fields are None, the server presumably falls back to
    its configured defaults — TODO confirm against the endpoint handler.
    """

    prompt: str
    # repr=False keeps the secret out of repr()/str() (and thus out of logs)
    # without changing its default, type, or serialization.
    api_key: Optional[str] = Field(default=None, repr=False)
    base_url: Optional[str] = None
    model: Optional[str] = None
class LLMChatResponse(BaseModel):
    """LLM reply: the generated text and the name of the model that produced it."""

    content: str
    model: str
class LLMConfigResponse(BaseModel):
    """Current LLM configuration.

    Exposes only a boolean for the API key, never the key itself.
    """

    base_url: str
    model: str
    has_api_key: bool
class LLMTestResponse(BaseModel):
    """Result of an LLM connectivity test: success flag, model, and reply text."""

    ok: bool
    model: str
    content: str