File size: 5,052 Bytes
65b799e
 
 
 
2f5db5e
daba1b9
65b799e
 
2fccde8
 
 
 
 
 
fe3a41d
 
 
 
 
 
65b799e
 
fe3a41d
daba1b9
 
fe3a41d
daba1b9
 
 
fe3a41d
65b799e
 
6deaccc
 
 
 
 
 
 
 
 
5e0e606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdc237b
 
2fccde8
 
 
5e0e606
cdc237b
5e0e606
cdc237b
 
5e0e606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdc237b
5e0e606
 
 
 
 
 
 
 
2f5db5e
65b799e
daba1b9
65b799e
 
 
 
 
2f5db5e
 
daba1b9
2f5db5e
daba1b9
 
2fccde8
 
 
 
daba1b9
 
 
fe3a41d
 
 
2f5db5e
 
cdc237b
6deaccc
 
2f5db5e
 
5e0e606
 
 
 
2f5db5e
 
 
6deaccc
 
 
 
 
 
65b799e
 
daba1b9
65b799e
5e0e606
cdc237b
 
65b799e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from __future__ import annotations

from typing import Literal

from openenv.core import Action, Observation, State
from pydantic import BaseModel, Field

# Top-level choice an agent makes on each step (StellaratorAction.intent);
# also recorded on RewardBreakdown.intent and ActionMonitor.intent.
ActionIntent = Literal["run", "submit", "restore_best"]
# Constraint identifiers reported in StellaratorObservation.dominant_constraint;
# "none" is that field's default.
ConstraintName = Literal[
    "none",
    "aspect_ratio",
    "average_triangularity",
    "edge_iota_over_nfp",
]
# Tunable knobs an action can move; each value is exactly a field name on
# LowDimBoundaryParams.
ParameterName = Literal[
    "aspect_ratio",
    "elongation",
    "rotational_transform",
    "triangularity_scale",
]
# Sign and step size of a parameter move (StellaratorAction.direction /
# StellaratorAction.magnitude). Actual step sizes are defined elsewhere.
DirectionName = Literal["increase", "decrease"]
MagnitudeName = Literal["small", "medium", "large"]
# Fidelity tag of an evaluation (StellaratorObservation.evaluation_fidelity).
EvaluationFidelityName = Literal["low", "high"]


class LowDimBoundaryParams(BaseModel):
    """Four-knob, low-dimensional parameter set for the boundary.

    All four fields are required floats (no defaults). Field names are exactly
    the values of ``ParameterName``; baseline values are produced by
    ``default_low_dim_boundary_params``. Units and valid ranges are not
    defined in this module (presumably enforced/clamped by the environment —
    confirm).
    """

    aspect_ratio: float
    elongation: float
    rotational_transform: float
    triangularity_scale: float


def default_low_dim_boundary_params() -> LowDimBoundaryParams:
    """Build a new ``LowDimBoundaryParams`` holding the baseline values.

    A fresh object is created on every call, so callers (including the
    ``default_factory`` fields that use this function) never share one.
    """
    baseline = {
        "aspect_ratio": 3.6,
        "elongation": 1.4,
        "rotational_transform": 1.5,
        "triangularity_scale": 0.55,
    }
    return LowDimBoundaryParams(**baseline)


class RewardBreakdown(BaseModel):
    """Itemized record of the reward assigned on one step.

    Every field has a neutral default (intent ``"run"``, ``0.0``, ``False`` or
    ``None``), so ``RewardBreakdown()`` is a valid empty breakdown. ``total``
    is presumably the sum of the individual term fields below — confirm
    against the reward computation, which is not in this module.
    """

    intent: ActionIntent = "run"
    total: float = 0.0
    # Evaluation status flags for this step.
    evaluation_failed: bool = False
    recovered_from_failure: bool = False
    # Reference-evaluation metrics; None presumably means "not evaluated on
    # this step" — confirm.
    reference_constraints_satisfied: bool = False
    reference_score: float | None = None
    reference_feasibility: float | None = None
    reference_max_elongation: float | None = None
    initial_reference_score: float | None = None
    terminal_score_ratio: float | None = None
    # Individual reward terms (penalties, bonuses, shaping deltas). Sign
    # conventions are decided by the reward code, not here — TODO confirm.
    invalid_action_penalty: float = 0.0
    failure_penalty: float = 0.0
    failure_submit_penalty: float = 0.0
    failure_budget_penalty: float = 0.0
    feasibility_crossing_bonus: float = 0.0
    feasibility_regression_penalty: float = 0.0
    feasibility_delta_reward: float = 0.0
    best_feasibility_bonus: float = 0.0
    near_feasible_bonus: float = 0.0
    aspect_ratio_repair_reward: float = 0.0
    triangularity_repair_reward: float = 0.0
    iota_repair_reward: float = 0.0
    objective_delta_reward: float = 0.0
    best_score_bonus: float = 0.0
    step_cost: float = 0.0
    no_progress_penalty: float = 0.0
    repeat_state_penalty: float = 0.0
    recovery_bonus: float = 0.0
    # Terms presumably applied only on the terminal step — confirm.
    terminal_improvement_bonus: float = 0.0
    terminal_budget_bonus: float = 0.0
    terminal_no_improvement_penalty: float = 0.0


def default_reward_breakdown() -> RewardBreakdown:
    """Create an all-default ``RewardBreakdown``.

    Serves as a ``default_factory`` so every observation gets its own,
    independent breakdown object.
    """
    breakdown = RewardBreakdown()
    return breakdown


class ActionMonitor(BaseModel):
    """Diagnostic record of how an action was applied (inferred from usage).

    Carries the same ``intent``/``parameter``/``direction``/``magnitude``
    fields as ``StellaratorAction``, plus the boundary parameters before and
    after the step. ``default_factory`` gives each instance its own
    ``LowDimBoundaryParams`` objects.
    """

    intent: ActionIntent = "run"
    # None when no parameter move is described (presumably for "submit" /
    # "restore_best" intents — confirm).
    parameter: ParameterName | None = None
    direction: DirectionName | None = None
    magnitude: MagnitudeName | None = None
    params_before: LowDimBoundaryParams = Field(default_factory=default_low_dim_boundary_params)
    params_after: LowDimBoundaryParams = Field(default_factory=default_low_dim_boundary_params)
    # Flags set by the environment. Meanings inferred from names only —
    # TODO confirm: clamped = a value was clipped to bounds; no_op = params
    # did not change; repeat_state = params matched a previously visited
    # state; used_best_params = best params were restored.
    clamped: bool = False
    no_op: bool = False
    repeat_state: bool = False
    used_best_params: bool = False


def default_action_monitor() -> ActionMonitor:
    """Return an ``ActionMonitor`` whose before/after params are the baseline.

    ``params_before`` and ``params_after`` each receive their own
    ``LowDimBoundaryParams`` instance. Passing a single shared instance to
    both fields would alias them: by default pydantic v2 stores a model
    instance given to a model-typed field as-is (``revalidate_instances`` is
    ``'never'``), so mutating one field would silently mutate the other.
    """
    return ActionMonitor(
        params_before=default_low_dim_boundary_params(),
        params_after=default_low_dim_boundary_params(),
    )


class StellaratorAction(Action):
    """Action submitted by the agent for one environment step.

    ``intent`` is required. ``parameter``/``direction``/``magnitude`` default
    to None; presumably they are only meaningful for ``intent == "run"``.
    No validator in this module enforces that pairing.
    """

    intent: ActionIntent
    parameter: ParameterName | None = None
    direction: DirectionName | None = None
    magnitude: MagnitudeName | None = None
    # Free-form text from the agent; no constraints beyond ``str`` here.
    reasoning: str = ""


class StellaratorObservation(Observation):
    """Observation presumably returned to the agent after reset/step.

    Every field has a default, so an empty observation is constructible.
    Metric fields default to 0.0; the environment presumably overwrites them
    after each evaluation (not shown in this module).
    """

    # Human-readable diagnostics summary.
    diagnostics_text: str = ""
    # Evaluated geometry / physics metrics.
    max_elongation: float = 0.0
    aspect_ratio: float = 0.0
    average_triangularity: float = 0.0
    edge_iota_over_nfp: float = 0.0
    # Per-constraint violation amounts, and which constraint dominates
    # (one of ConstraintName; "none" by default).
    aspect_ratio_violation: float = 0.0
    triangularity_violation: float = 0.0
    iota_violation: float = 0.0
    dominant_constraint: ConstraintName = "none"
    # Objective/feasibility scores and an additional metric.
    p1_score: float = 0.0
    p1_feasibility: float = 0.0
    vacuum_well: float = 0.0
    # Evaluation provenance and failure info.
    evaluation_fidelity: EvaluationFidelityName = "low"
    evaluation_failed: bool = False
    failure_reason: str = ""
    # Episode progress. The default budget of 6 matches the defaults on
    # StellaratorState.budget_total/budget_remaining; keep them in sync.
    step_number: int = 0
    budget_remaining: int = 6
    no_progress_steps: int = 0
    # Best-so-far tracking; inf is presumably a "nothing recorded yet"
    # sentinel. NOTE(review): pydantic v2 serializes inf as JSON null by
    # default, which would not validate back into a float — verify
    # round-tripping if observations cross a JSON boundary.
    best_low_fidelity_score: float = 0.0
    best_low_fidelity_feasibility: float = float("inf")
    constraints_satisfied: bool = True
    target_spec: str = ""
    # Per-step reward and action diagnostics; default_factory gives each
    # observation its own sub-objects.
    reward_breakdown: RewardBreakdown = Field(default_factory=default_reward_breakdown)
    action_monitor: ActionMonitor = Field(default_factory=default_action_monitor)
    # Running episode totals / summary text.
    episode_total_reward: float = 0.0
    trajectory_summary: str = ""


class StellaratorState(State):
    """Episode state kept by the environment (presumably server-side — confirm).

    Tracks initial/current/best boundary params, best-so-far scores, budget
    and reward bookkeeping. Model-typed and list fields use
    ``default_factory``, so instances never share mutable defaults.
    """

    initial_params: LowDimBoundaryParams = Field(default_factory=default_low_dim_boundary_params)
    current_params: LowDimBoundaryParams = Field(default_factory=default_low_dim_boundary_params)
    best_params: LowDimBoundaryParams = Field(default_factory=default_low_dim_boundary_params)
    initial_low_fidelity_score: float = 0.0
    best_low_fidelity_score: float = 0.0
    # inf presumably means "no feasibility recorded yet" — confirm.
    best_low_fidelity_feasibility: float = float("inf")
    # Default budget of 6 matches StellaratorObservation.budget_remaining.
    budget_total: int = 6
    budget_remaining: int = 6
    episode_done: bool = False
    constraints_satisfied: bool = True
    total_reward: float = 0.0
    no_progress_steps: int = 0
    # Keys of states seen so far (presumably used for repeat-state
    # detection) and a textual step history — TODO confirm.
    visited_state_keys: list[str] = Field(default_factory=list)
    history: list[str] = Field(default_factory=list)