SpindleFlow-RL / training /specialist_improvement_callback.py
garvitsachdeva's picture
SpindleFlow RL — periodic push + log persistence
02ff91f
"""
SB3 callback that periodically improves specialist prompts using
SpecialistFinetuner + SpecialistMemory.
Wired into model.learn() alongside CheckpointCallback in train.py.
Triggers every `improve_every_n_episodes` completed episodes.
"""
from __future__ import annotations
from stable_baselines3.common.callbacks import BaseCallback
class SpecialistImprovementCallback(BaseCallback):
    """
    After every `improve_every_n_episodes` episodes, run the finetuner over
    all specialists that have enough memory entries and below-threshold reward.
    Also saves the memory file after each improvement pass.

    Completed episodes are counted from the ``dones`` entry of ``self.locals``
    on every step, so with several parallel envs one step can contribute more
    than one episode.
    """

    def __init__(self, improve_every_n_episodes: int = 100, verbose: int = 0):
        super().__init__(verbose)
        # Number of completed episodes between two improvement passes.
        self._improve_every = improve_every_n_episodes
        # Completed episodes accumulated since the last improvement pass.
        self._episode_count = 0

    def _on_step(self) -> bool:
        """Count finished episodes and trigger an improvement pass when due.

        Returns:
            Always True, so this callback never stops training.
        """
        dones = self.locals.get("dones", [])
        self._episode_count += int(sum(dones))
        if self._episode_count >= self._improve_every:
            # Keep the surplus instead of resetting to 0: when several envs
            # finish on the same step the counter can overshoot, and those
            # episodes belong to the next interval.
            if self._improve_every > 0:
                self._episode_count %= self._improve_every
            else:
                self._episode_count = 0
            self._run_improvement()
        return True

    def _run_improvement(self) -> None:
        """Run one finetuning pass over the env's specialists and save memory.

        Returns early (doing nothing) when the env, its ``specialist_memory``
        or its ``registry`` cannot be located.
        """
        # Local import kept from the original; presumably avoids importing the
        # agents package at module load time — confirm before hoisting.
        from agents.specialist_finetuner import SpecialistFinetuner

        env = self._get_env()
        if env is None:
            return
        memory = self._find_env_attr(env, "specialist_memory")
        registry = self._find_env_attr(env, "registry")
        if memory is None or registry is None:
            return

        # `or {}` also covers a config (or section) that exists but is None,
        # which previously crashed on `.get`.
        cfg = self._find_env_attr(env, "config") or {}
        si_cfg = cfg.get("specialist_improvement") or {}
        min_entries = si_cfg.get("min_entries_to_improve", 10)
        threshold = si_cfg.get("improve_avg_reward_threshold", 0.70)

        finetuner = SpecialistFinetuner(
            min_entries=min_entries,
            improve_threshold=threshold,
        )
        n = finetuner.improve_all(registry, memory)
        memory.save()
        if self.verbose and n > 0:
            print(f"[SpecialistImprovementCallback] Improved {n} specialist(s).")

    @staticmethod
    def _find_env_attr(env, name: str):
        """Return attribute ``name`` from ``env`` or the first wrapped layer having it.

        Walks the ``.env`` chain (gym/gymnasium wrapper convention) because
        gymnasium >= 1.0 wrappers such as SB3's ``Monitor`` no longer forward
        unknown attribute lookups to the wrapped env.

        Returns:
            The first non-None value found, or None.
        """
        seen = set()
        current = env
        while current is not None and id(current) not in seen:
            seen.add(id(current))
            value = getattr(current, name, None)
            if value is not None:
                return value
            current = getattr(current, "env", None)
        return None

    def _get_env(self):
        """Unwrap VecEnv wrappers (e.g. VecNormalize) → DummyVecEnv → first env.

        Follows ``.venv`` links until a VecEnv exposing ``.envs`` is reached.

        Returns:
            The first sub-environment, or None if the training env is not
            available or exposes no in-process ``.envs`` list.
        """
        try:
            venv = self.training_env
            seen = set()
            while venv is not None and id(venv) not in seen:
                seen.add(id(venv))
                envs = getattr(venv, "envs", None)
                if envs:
                    return envs[0]
                venv = getattr(venv, "venv", None)
        except Exception as exc:
            # Best-effort: env discovery must never crash training, but do not
            # hide the failure completely.
            if self.verbose:
                print(
                    "[SpecialistImprovementCallback] Could not locate "
                    f"training env: {exc!r}"
                )
        return None