Spaces:
Runtime error
Runtime error
| """ | |
| SB3 callback that periodically improves specialist prompts using | |
| SpecialistFinetuner + SpecialistMemory. | |
| Wired into model.learn() alongside CheckpointCallback in train.py. | |
| Triggers every `improve_every_n_episodes` completed episodes. | |
| """ | |
| from __future__ import annotations | |
| from stable_baselines3.common.callbacks import BaseCallback | |
class SpecialistImprovementCallback(BaseCallback):
    """
    Periodically run ``SpecialistFinetuner`` over the environment's specialists.

    Completed episodes are counted from the ``dones`` entry of the training
    loop's locals. Each time the count reaches ``improve_every_n_episodes``
    the callback builds a ``SpecialistFinetuner`` from the environment's
    ``config["specialist_improvement"]`` section, calls
    ``improve_all(registry, memory)`` and then ``memory.save()``.

    The pass is skipped silently when the underlying environment, its
    ``specialist_memory`` or its ``registry`` cannot be found.

    Args:
        improve_every_n_episodes: Number of completed episodes between two
            improvement passes.
        verbose: When truthy, print how many specialists were improved.
    """

    def __init__(self, improve_every_n_episodes: int = 100, verbose: int = 0):
        super().__init__(verbose)
        self._improve_every = improve_every_n_episodes
        self._episode_count = 0

    def _on_step(self) -> bool:
        """Count finished episodes and trigger a pass when the quota is met.

        Returns:
            Always ``True``; this callback never asks training to stop.
        """
        dones = self.locals.get("dones")
        if dones is not None:
            self._episode_count += int(sum(dones))

        if self._episode_count >= self._improve_every:
            if self._improve_every > 0:
                # Several vectorised envs can finish on the same step; keep
                # the surplus so the trigger period does not drift.
                self._episode_count %= self._improve_every
            else:
                self._episode_count = 0
            self._run_improvement()
        return True

    def _run_improvement(self) -> None:
        """Run one finetuning pass over the registry and persist the memory."""
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle or a hard dependency at module import time).
        from agents.specialist_finetuner import SpecialistFinetuner

        env = self._get_env()
        if env is None:
            return

        memory = self._find_attr(env, "specialist_memory")
        registry = self._find_attr(env, "registry")
        if memory is None or registry is None:
            return

        # `config` or its sub-section may be present but None (e.g. an empty
        # YAML key); fall back to an empty mapping instead of crashing the run.
        cfg = self._find_attr(env, "config") or {}
        si_cfg = cfg.get("specialist_improvement") or {}

        finetuner = SpecialistFinetuner(
            min_entries=si_cfg.get("min_entries_to_improve", 10),
            improve_threshold=si_cfg.get("improve_avg_reward_threshold", 0.70),
        )
        n = finetuner.improve_all(registry, memory)
        # Saved after every pass, even when no specialist was changed.
        memory.save()

        if self.verbose and n > 0:
            print(f"[SpecialistImprovementCallback] Improved {n} specialist(s).")

    @staticmethod
    def _find_attr(env, name):
        """Return ``env.<name>``, falling back to ``env.unwrapped.<name>``.

        ``envs[0]`` is often a gym wrapper (e.g. ``Monitor``); wrappers do not
        necessarily forward arbitrary attributes, so the base environment is
        consulted when the wrapper itself does not provide the attribute.
        Returns ``None`` when neither object has it.
        """
        value = getattr(env, name, None)
        if value is None:
            base = getattr(env, "unwrapped", None)
            if base is not None and base is not env:
                value = getattr(base, name, None)
        return value

    def _get_env(self):
        """Unwrap VecEnv wrappers (e.g. VecNormalize) -> DummyVecEnv -> first env.

        Returns:
            The first sub-environment, or ``None`` when the training env is
            unavailable or exposes no ``envs`` list (e.g. a subprocess-based
            vector env).
        """
        try:
            venv = self.training_env
        except (AttributeError, AssertionError):
            # Callback not attached to a model/env yet.
            return None

        # VecEnv wrappers keep the wrapped env in `.venv`; peel every layer,
        # not just the outermost one.
        while True:
            inner = getattr(venv, "venv", None)
            if inner is None or inner is venv:
                break
            venv = inner

        envs = getattr(venv, "envs", None)
        if envs:
            return envs[0]
        return None