# meta-learning-priority-panic / env/priority_panic_env.py
# vanshjagetia's picture
# Update env/priority_panic_env.py
# d0142c7 verified
import copy
import math
from typing import Any, Dict, List, Optional

from models.models import Action, Observation, Task
from tasks.tasks import EASY_TASK, HARD_TASK, MEDIUM_TASK, TaskConfig
class MetaLearningPriorityPanicEnv:
    """Turn-based task-prioritisation environment.

    On every step the agent submits an action whose ``action_type`` is one of
    ``"complete_task"``, ``"skip"`` or ``"noop"``. Completing tasks spends
    energy and earns reward; repeated actions, invalid actions, no-ops,
    unfinished tasks past their deadline and (when the task configuration
    enables it) neglected high-priority tasks cost reward. Every reward
    returned to the caller lies in ``[0.01, 0.99]``.
    """

    # Bounds applied to every reward returned by ``reset`` and ``step``.
    _MIN_REWARD = 0.01
    _MAX_REWARD = 0.99

    def __init__(self, task_id: str = "easy", **kwargs):
        """Create an environment for ``task_id``.

        The task configuration is only loaded by ``reset``; until then the
        attributes below hold placeholder defaults. Extra keyword arguments
        are accepted and ignored.
        """
        self.task_id = task_id
        self.tasks: List[Task] = []
        self.energy: int = 10
        self.step_count: int = 0
        self.social_debt: float = 0.0
        self.streak: int = 0
        self.last_action_result: str = "Environment initialized"
        self.last_action_str: str = ""
        self.max_steps = 10
        self.config: Optional[TaskConfig] = None

    # ================================
    # RESET
    # ================================
    def reset(self, task_id: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """Start a new episode and return the initial step result.

        Args:
            task_id: Optional difficulty to switch to. ``"hard"`` and
                ``"medium"`` select the matching configuration; any other
                value (including the current one when omitted) that is not
                one of those two falls back to the easy configuration.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``observation``, ``reward`` (0.01), ``done`` (False)
            and an empty ``info`` mapping.
        """
        if task_id:
            self.task_id = task_id

        if self.task_id == "hard":
            self.config = HARD_TASK
        elif self.task_id == "medium":
            self.config = MEDIUM_TASK
        else:
            self.config = EASY_TASK

        self.energy = self.config["initial_energy"]
        self.max_steps = self.config["max_steps"]
        self.step_count = 0
        self.social_debt = 0.0
        self.streak = 0
        self.last_action_result = "Environment initialized"
        self.last_action_str = ""
        # Deep copy so completing tasks never mutates the shared config.
        self.tasks = copy.deepcopy(self.config["initial_tasks"])

        return {
            "observation": self._get_observation().model_dump(),
            "reward": self._MIN_REWARD,
            "done": False,
            "info": {},
        }

    # ================================
    # OBSERVATION
    # ================================
    def _get_observation(self) -> "Observation":
        """Build an ``Observation`` from the current state (tasks are copied)."""
        return Observation(
            tasks=copy.deepcopy(self.tasks),
            energy=self.energy,
            step_count=self.step_count,
            social_debt=self.social_debt,
            streak=self.streak,
            last_action_result=self.last_action_result,
        )

    # ================================
    # STEP FUNCTION
    # ================================
    def step(self, action: "Action") -> Dict[str, Any]:
        """Apply ``action``, advance one step and return the step result.

        If ``reset`` has not been called yet the environment resets itself
        first, so the configuration lookups below never hit ``None``.

        Returns:
            A dict with ``observation``, ``reward`` (clamped to
            ``[0.01, 0.99]``), ``done`` (True once ``step_count`` reaches
            ``max_steps``) and an empty ``info`` mapping.
        """
        if self.config is None:
            self.reset()

        raw_score = self._score_action(action)
        # Injection keys off the pre-increment step index.
        self._inject_scheduled_tasks()
        self.step_count += 1

        return {
            "observation": self._get_observation().model_dump(),
            "reward": self._normalize_reward(raw_score),
            "done": self.step_count >= self.max_steps,
            "info": {},
        }

    def _score_action(self, action: "Action") -> float:
        """Apply ``action`` to the task list and return the unclamped reward.

        Mutates ``energy``, ``tasks``, ``social_debt``, ``streak``,
        ``last_action_str`` and ``last_action_result``; does not advance
        ``step_count``.
        """
        raw_score = 0.0
        worked = False
        self.last_action_result = "Action processed."

        # Snapshot BEFORE the action runs: a high-priority task completed in
        # this step must still count as "pending" when deciding whether the
        # agent ignored high-priority work.
        pending_high_ids = {
            t.id for t in self.tasks if t.priority == "high" and not t.completed
        }

        # Anti-repeat penalty. The message set here may be replaced below by
        # the action-specific message.
        joined_ids = ",".join(map(str, sorted(action.task_ids)))
        action_str = f"{action.action_type}:{joined_ids}"
        if action_str == self.last_action_str and self.step_count > 0:
            raw_score -= 0.2
            self.last_action_result = "Penalty: Repeated action."
        self.last_action_str = action_str

        if action.action_type not in ("complete_task", "skip", "noop"):
            raw_score -= 0.2
            self.last_action_result = "Penalty: Invalid action."
        elif action.action_type == "complete_task":
            # Only the first two requested tasks are processed.
            for t_id in action.task_ids[:2]:
                task = next((t for t in self.tasks if t.id == t_id), None)
                if task is None or task.completed:
                    # Unknown id or already-finished task.
                    raw_score -= 0.2
                    continue
                if self.energy < task.energy_cost:
                    # Not enough energy left for this task.
                    raw_score -= 0.05
                    continue

                self.energy -= task.energy_cost
                task.completed = True
                worked = True

                base = 0.3
                bonus = (
                    0.4 if task.priority == "high"
                    else 0.2 if task.priority == "medium"
                    else 0.1
                )
                # Scaled down so a single step cannot dominate the reward.
                raw_score += (base + bonus) * 0.8
                self.last_action_result = f"Task {t_id} completed."

            # Multi-task bonus: more than one id requested and at least one
            # task actually completed.
            if worked and len(action.task_ids) > 1:
                raw_score += 0.05
        elif action.action_type == "skip":
            self.last_action_result = "Skipped."
        elif action.action_type == "noop":
            raw_score -= 0.1
            self.last_action_result = "No-op."

        # Deadline penalty, capped at 0.2 per step. Uses the step index
        # before it is incremented.
        missed = sum(
            1 for t in self.tasks
            if not t.completed and self.step_count >= t.deadline
        )
        raw_score -= min(0.2, 0.1 * missed)

        # Social debt: accrues when high-priority work was pending at the
        # start of the step and none of it was targeted by this action.
        social_debt_active = self.config.get("social_debt_active", False)
        if action.action_type == "complete_task":
            attempted_ids = set(action.task_ids)
        else:
            attempted_ids = set()
        ignored_high = bool(pending_high_ids) and not (pending_high_ids & attempted_ids)
        if ignored_high and social_debt_active:
            self.social_debt += 1.0
            raw_score -= 0.2
        # Any outstanding debt keeps costing reward on every step.
        if self.social_debt > 0 and social_debt_active:
            raw_score -= 0.1

        # Bonus while every high-priority task in the list is completed.
        high_all = [t for t in self.tasks if t.priority == "high"]
        if high_all and all(t.completed for t in high_all):
            raw_score += 0.1

        # Streak: consecutive steps that completed at least one task.
        if worked:
            self.streak += 1
            raw_score += 0.1 * self.streak
        else:
            self.streak = 0

        return raw_score

    def _inject_scheduled_tasks(self) -> None:
        """Append the scripted extra task for the current step, if configured.

        Only steps 3 and 6 have a scripted task; other values listed in
        ``task_injection_steps`` add nothing.
        """
        if self.step_count not in self.config.get("task_injection_steps", []):
            return

        if self.step_count == 3:
            self.tasks.append(
                Task(
                    id=98,
                    description="Urgent CEO request",
                    priority="high",
                    deadline=self.step_count + 4,
                    energy_cost=3,
                    completed=False,
                )
            )
        elif self.step_count == 6:
            self.tasks.append(
                Task(
                    id=99,
                    description="Resolve production bug",
                    priority="medium",
                    deadline=self.step_count + 4,
                    energy_cost=2,
                    completed=False,
                )
            )

    @classmethod
    def _normalize_reward(cls, raw_score: Any) -> float:
        """Clamp ``raw_score`` to ``[0.01, 0.99]``.

        Values that cannot be converted to ``float``, NaN and both
        infinities map to the lower bound.
        """
        try:
            value = float(raw_score)
        except (TypeError, ValueError, OverflowError):
            return cls._MIN_REWARD
        if not math.isfinite(value):
            return cls._MIN_REWARD
        return max(cls._MIN_REWARD, min(value, cls._MAX_REWARD))