import copy
import math
from typing import Any, Dict, List, Optional, Tuple

from models.models import Task, Observation, Action
from tasks.tasks import EASY_TASK, MEDIUM_TASK, HARD_TASK, TaskConfig


class MetaLearningPriorityPanicEnv:
    """Task-prioritisation environment with energy, deadlines and social debt.

    An episode starts with ``reset()``, which loads one of the EASY / MEDIUM /
    HARD task configurations. Each ``step()`` applies a single ``Action``,
    computes a shaped reward and returns a dict with the keys
    ``observation``, ``reward``, ``done`` and ``info``.

    Rewards returned by ``step()`` are always clamped to
    ``[REWARD_MIN, REWARD_MAX]``.
    """

    # Action types accepted by step(); anything else is penalised.
    VALID_ACTIONS = ("complete_task", "skip", "noop")
    # Only this many task ids are processed per "complete_task" action.
    MAX_TASKS_PER_ACTION = 2
    # Bounds of the reward reported to the caller.
    REWARD_MIN = 0.01
    REWARD_MAX = 0.99

    # Extra reward per completed task by priority; other priorities get 0.1.
    _PRIORITY_BONUS = {"high": 0.4, "medium": 0.2}
    _DEFAULT_PRIORITY_BONUS = 0.1

    # Tasks that may be injected mid-episode, keyed by the step they appear on.
    # A task is only injected when its step is also listed in the config's
    # "task_injection_steps".
    _INJECTED_TASKS = {
        3: {
            "id": 98,
            "description": "Urgent CEO request",
            "priority": "high",
            "energy_cost": 3,
        },
        6: {
            "id": 99,
            "description": "Resolve production bug",
            "priority": "medium",
            "energy_cost": 2,
        },
    }
    # Injected tasks are due this many steps after the injection step.
    _INJECTED_DEADLINE_OFFSET = 4

    def __init__(self, task_id: str = "easy", **kwargs):
        """Create an environment; call ``reset()`` to load a configuration.

        Args:
            task_id: Difficulty selector ("easy", "medium" or "hard").
            **kwargs: Accepted and ignored.
        """
        self.task_id = task_id
        self.tasks: List[Task] = []
        self.energy: int = 10
        self.step_count: int = 0
        self.social_debt: float = 0.0
        self.streak: int = 0
        self.last_action_result: str = "Environment initialized"
        self.last_action_str: str = ""
        self.max_steps = 10
        # Populated by reset(); None until the first episode starts.
        self.config: Optional[TaskConfig] = None

    # ================================
    # RESET
    # ================================
    def reset(self, task_id: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """Start a new episode and return the initial transition dict.

        Args:
            task_id: Optional difficulty override. When falsy, the current
                ``self.task_id`` is kept. Any value other than "hard" or
                "medium" selects the easy configuration.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with the initial observation, a reward of 0.01,
            ``done=False`` and an empty ``info`` dict.
        """
        if task_id:
            self.task_id = task_id

        configs = {"hard": HARD_TASK, "medium": MEDIUM_TASK}
        self.config = configs.get(self.task_id, EASY_TASK)

        self.energy = self.config["initial_energy"]
        self.max_steps = self.config["max_steps"]
        self.step_count = 0
        self.social_debt = 0.0
        self.streak = 0
        self.last_action_result = "Environment initialized"
        self.last_action_str = ""
        # Deep copy so completing tasks never mutates the shared config.
        self.tasks = copy.deepcopy(self.config["initial_tasks"])

        return {
            "observation": self._get_observation().model_dump(),
            "reward": 0.01,
            "done": False,
            "info": {},
        }

    # ================================
    # OBSERVATION
    # ================================
    def _get_observation(self) -> Observation:
        """Build an ``Observation`` snapshot of the current state.

        The task list is deep-copied so the returned observation does not
        alias the environment's own task objects.
        """
        return Observation(
            tasks=copy.deepcopy(self.tasks),
            energy=self.energy,
            step_count=self.step_count,
            social_debt=self.social_debt,
            streak=self.streak,
            last_action_result=self.last_action_result,
        )

    # ================================
    # STEP FUNCTION
    # ================================
    def step(self, action: Action) -> Dict[str, Any]:
        """Apply ``action``, advance one step and return the transition dict.

        Reward components, in order of application: repeated-action penalty,
        action outcome, capped deadline penalty, social-debt penalties,
        all-high-priority-done bonus and streak bonus. The sum is clamped to
        ``[REWARD_MIN, REWARD_MAX]``.

        Args:
            action: Action with ``action_type`` and ``task_ids``.

        Returns:
            Dict with ``observation``, ``reward``, ``done`` and ``info``.
        """
        # Calling step() before reset() used to fail on `self.config.get`
        # after partially mutating state; start an episode instead.
        if self.config is None:
            self.reset()

        raw_score = 0.0
        completed_count = 0
        self.last_action_result = "Action processed."

        action_type = action.action_type
        # Tolerate a missing id list (e.g. on skip / noop actions).
        task_ids = list(action.task_ids or [])
        # Only these ids are actually processed by a complete_task action.
        attempted_ids = task_ids[: self.MAX_TASKS_PER_ACTION]

        # Snapshot BEFORE the action mutates task state: the social-debt check
        # must compare against what was pending when the agent chose.
        pending_high_before = {
            t.id for t in self.tasks if t.priority == "high" and not t.completed
        }

        # Anti-repeat penalty
        action_str = f"{action_type}:{','.join(map(str, sorted(task_ids)))}"
        if action_str == self.last_action_str and self.step_count > 0:
            raw_score -= 0.2
            self.last_action_result = "Penalty: Repeated action."
        self.last_action_str = action_str

        # Validate and dispatch the action
        if action_type not in self.VALID_ACTIONS:
            raw_score -= 0.2
            self.last_action_result = "Penalty: Invalid action."
        elif action_type == "complete_task":
            delta, completed_count = self._complete_tasks(attempted_ids)
            raw_score += delta
            # Multi-task bonus: only when more than one task really completed.
            if completed_count > 1:
                raw_score += 0.05
        elif action_type == "skip":
            self.last_action_result = "Skipped."
        else:  # "noop"
            raw_score -= 0.1
            self.last_action_result = "No-op."

        worked = completed_count > 0

        # ================================
        # DEADLINE PENALTY (CAPPED)
        # ================================
        missed = sum(
            1
            for t in self.tasks
            if not t.completed and self.step_count >= t.deadline
        )
        raw_score -= min(0.2, 0.1 * missed)

        # ================================
        # SOCIAL DEBT SYSTEM
        # ================================
        targeted_ids = set(attempted_ids) if action_type == "complete_task" else set()
        # High-priority work was ignored when some was pending at the start of
        # the step and the action addressed none of it.
        ignored_high = bool(pending_high_before) and pending_high_before.isdisjoint(
            targeted_ids
        )
        social_debt_active = self.config.get("social_debt_active", False)

        if ignored_high and social_debt_active:
            self.social_debt += 1.0
            raw_score -= 0.2

        if self.social_debt > 0 and social_debt_active:
            raw_score -= 0.1

        # ================================
        # BONUS: ALL HIGH TASKS DONE
        # ================================
        high_all = [t for t in self.tasks if t.priority == "high"]
        if high_all and all(t.completed for t in high_all):
            raw_score += 0.1

        # ================================
        # TASK INJECTION
        # ================================
        # Runs after scoring so a freshly injected task cannot affect the
        # reward of the step on which it appears.
        self._inject_tasks()

        # ================================
        # STREAK SYSTEM
        # ================================
        if action_type == "complete_task" and worked:
            self.streak += 1
            raw_score += 0.1 * self.streak
        else:
            self.streak = 0

        # ================================
        # STEP UPDATE
        # ================================
        self.step_count += 1

        reward = self._normalize_reward(raw_score)
        done = self.step_count >= self.max_steps

        return {
            "observation": self._get_observation().model_dump(),
            "reward": reward,
            "done": done,
            "info": {},
        }

    def _complete_tasks(self, task_ids: List[Any]) -> Tuple[float, int]:
        """Try to complete each task in ``task_ids``; return (score, completed).

        Unknown or already-completed ids cost 0.2 each; a task the agent
        cannot afford costs 0.05. A completed task spends its energy cost and
        earns ``(0.3 + priority bonus) * 0.8``.
        """
        score = 0.0
        completed = 0

        for t_id in task_ids:
            task = next((t for t in self.tasks if t.id == t_id), None)

            if task is None or task.completed:
                score -= 0.2
                continue

            if self.energy < task.energy_cost:
                score -= 0.05
                continue

            self.energy -= task.energy_cost
            task.completed = True
            completed += 1

            base = 0.3
            bonus = self._PRIORITY_BONUS.get(
                task.priority, self._DEFAULT_PRIORITY_BONUS
            )
            # Scaled by 0.8 to keep per-step rewards bounded.
            score += (base + bonus) * 0.8
            self.last_action_result = f"Task {t_id} completed."

        return score, completed

    def _inject_tasks(self) -> None:
        """Append the scheduled task for the current step, if one is due."""
        if self.step_count not in self.config.get("task_injection_steps", []):
            return

        template = self._INJECTED_TASKS.get(self.step_count)
        if template is None:
            return

        self.tasks.append(
            Task(
                deadline=self.step_count + self._INJECTED_DEADLINE_OFFSET,
                completed=False,
                **template,
            )
        )

    @classmethod
    def _normalize_reward(cls, raw_score: float) -> float:
        """Clamp ``raw_score`` to ``[REWARD_MIN, REWARD_MAX]``.

        Non-numeric or non-finite scores map to ``REWARD_MIN``.
        """
        try:
            value = float(raw_score)
        except (TypeError, ValueError, OverflowError):
            return cls.REWARD_MIN

        if not math.isfinite(value):
            return cls.REWARD_MIN

        return min(max(value, cls.REWARD_MIN), cls.REWARD_MAX)