TheQuantEd's picture
deploy: AMD EA Strategy Optimizer — Neo4j + FastAPI + Streamlit
6252f54
"""
REINFORCE trainer for the EA Policy Network.
Runs on AMD ROCm (PyTorch CUDA interface) or CPU.
Trains in ~30 seconds on CPU, <5 seconds on ROCm.
"""
import os
import json
import time
import logging
import numpy as np
import torch
import torch.optim as optim
from backend.drl.policy_network import EAPolicyNetwork, get_device
from backend.drl.environment import EAEnvironment
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# Checkpoints are kept in a "checkpoints" folder that sits next to this file.
CHECKPOINT_DIR = os.path.join(
    os.path.dirname(__file__),
    "checkpoints",
)

# Default file used for both saving and loading the policy weights.
DEFAULT_CHECKPOINT = os.path.join(
    CHECKPOINT_DIR,
    "ea_policy_v1.pt",
)
class REINFORCETrainer:
    """Vanilla policy gradient (REINFORCE) trainer for EAPolicyNetwork on EAEnvironment.

    Each episode consists of a single environment step: the policy samples one
    action from the reset state, the environment returns a reward, and one
    gradient step is taken on ``-(log_prob * return)``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(
        self,
        policy: EAPolicyNetwork | None = None,
        env: EAEnvironment | None = None,
        lr: float = 1e-3,
        device: torch.device | None = None,
    ):
        """Build the trainer, creating default components for omitted arguments.

        Args:
            policy: Network to train; a fresh ``EAPolicyNetwork()`` when None.
            env: Environment to interact with; a fresh ``EAEnvironment()`` when None.
            lr: Learning rate handed to the Adam optimizer.
            device: Device to train on; ``get_device()`` decides when None.
        """
        # Explicit ``is None`` checks: ``x or default`` would silently replace
        # a caller-supplied object that merely happens to evaluate as falsy.
        self.device = get_device() if device is None else device
        self.policy = (EAPolicyNetwork() if policy is None else policy).to(self.device)
        self.env = EAEnvironment() if env is None else env
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        # One dict per trained episode; persisted inside checkpoints and as JSON.
        self.training_log: list[dict] = []

    def compute_returns(self, rewards: list[float], gamma: float = 0.99) -> torch.Tensor:
        """Compute discounted returns and normalise them when possible.

        For a single-step episode the result is simply the reward itself.

        Args:
            rewards: Per-step rewards in chronological order.
            gamma: Discount factor applied to future rewards.

        Returns:
            A float32 tensor on ``self.device`` with one return per reward.
            When there are at least two returns and their standard deviation
            exceeds 1e-8, the returns are standardised (zero mean, unit std).
        """
        discounted: list[float] = []
        running = 0.0
        # Accumulate from the last step backwards; appending and reversing once
        # avoids the quadratic cost of inserting at the front of a list.
        for r in reversed(rewards):
            running = r + gamma * running
            discounted.append(running)
        discounted.reverse()

        returns_t = torch.tensor(discounted, dtype=torch.float32, device=self.device)

        # The sample std of fewer than two values is NaN (and PyTorch warns
        # about it), so normalisation is only attempted for longer sequences.
        if returns_t.numel() > 1:
            spread = returns_t.std()
            if spread > 1e-8:
                returns_t = (returns_t - returns_t.mean()) / (spread + 1e-8)
        return returns_t

    def train(self, n_episodes: int = 100, gamma: float = 0.99) -> list[dict]:
        """Run REINFORCE for ``n_episodes`` single-step episodes.

        A checkpoint is written to ``DEFAULT_CHECKPOINT`` on every 25th episode
        (starting with the first) and on the last one. After training, the
        training log is also dumped as JSON next to the checkpoint.

        Args:
            n_episodes: Number of episodes to run.
            gamma: Discount factor forwarded to ``compute_returns``.

        Returns:
            ``self.training_log``: a list of dicts with the keys ``episode``,
            ``reward``, ``loss`` and ``avg_reward_last10``. Entries from earlier
            ``train`` calls or a loaded checkpoint are kept in the list.
        """
        log.info(f"Starting REINFORCE training: {n_episodes} episodes on {self.device}")
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        t0 = time.perf_counter()  # monotonic clock: safe for measuring durations
        episode_rewards: list[float] = []

        for episode in range(n_episodes):
            self.policy.train()
            state = self.env.reset()
            # torch.tensor copies the data, so later changes the environment
            # might make to its own state buffer cannot reach the tensor that
            # autograd keeps for the backward pass.
            state_t = torch.tensor(state, dtype=torch.float32, device=self.device)

            # Single-step episode: sample one ordering
            action_indices, log_prob_sum = self.policy.sample_action(state_t)
            action_np = action_indices.detach().cpu().numpy()
            _, reward, _ = self.env.step(action_np)
            # Store a plain Python float so the log stays JSON-serialisable and
            # the checkpoint contains no NumPy scalar objects.
            # NOTE(review): assumes env.step returns a scalar reward — confirm.
            reward = float(reward)

            # REINFORCE update
            returns = self.compute_returns([reward], gamma)
            loss = -(log_prob_sum * returns[0])
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.optimizer.step()

            loss_value = loss.item()
            episode_rewards.append(reward)
            entry = {
                "episode": episode,
                "reward": round(reward, 4),
                "loss": round(loss_value, 6),
                "avg_reward_last10": round(float(np.mean(episode_rewards[-10:])), 4),
            }
            self.training_log.append(entry)

            if episode % 25 == 0 or episode == n_episodes - 1:
                log.info(
                    f"  Episode {episode+1}/{n_episodes} | "
                    f"reward={reward:.4f} | loss={loss_value:.6f} | "
                    f"avg10={entry['avg_reward_last10']:.4f}"
                )
                self.save_checkpoint(DEFAULT_CHECKPOINT)

        elapsed = time.perf_counter() - t0
        # With zero episodes there is nothing to average; report NaN directly
        # instead of letting NumPy warn about the mean of an empty slice.
        final_avg = float(np.mean(episode_rewards[-20:])) if episode_rewards else float("nan")
        log.info(f"Training completed in {elapsed:.1f}s | final avg reward: {final_avg:.4f}")

        # Save training log alongside checkpoint. Only the file extension is
        # swapped; a plain str.replace would also rewrite ".pt" occurring in a
        # directory name of the path.
        log_path = os.path.splitext(DEFAULT_CHECKPOINT)[0] + "_log.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(self.training_log, f, indent=2)
        return self.training_log

    def save_checkpoint(self, path: str) -> None:
        """Write policy weights, optimizer state, training log and device name to ``path``.

        The parent directory of ``path`` is created if it does not exist yet.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({
            "policy_state_dict": self.policy.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "training_log": self.training_log,
            "device": str(self.device),
        }, path)

    def load_checkpoint(self, path: str = DEFAULT_CHECKPOINT) -> bool:
        """Restore trainer state from a checkpoint file.

        Policy weights are always restored; optimizer state and the training
        log are restored only when the checkpoint contains them.

        Args:
            path: Checkpoint file to read.

        Returns:
            True when the checkpoint was loaded, False when ``path`` does not exist.
        """
        if not os.path.exists(path):
            log.warning(f"Checkpoint not found: {path}")
            return False
        # NOTE(security): torch.load reads a pickle-based format; only load
        # checkpoint files that come from a trusted source.
        checkpoint = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(checkpoint["policy_state_dict"])
        if "optimizer_state_dict" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "training_log" in checkpoint:
            self.training_log = checkpoint["training_log"]
        log.info(f"Loaded checkpoint from {path}")
        return True
def load_trained_policy(checkpoint_path: str = DEFAULT_CHECKPOINT) -> EAPolicyNetwork:
    """Return a policy network in eval mode, ready for inference.

    Weights are read from ``checkpoint_path`` when that file exists. When it
    does not, a warning is logged and the freshly constructed (untrained)
    network is returned instead.
    """
    target_device = get_device()
    net = EAPolicyNetwork()
    net = net.to(target_device)

    if not os.path.exists(checkpoint_path):
        # No saved weights: keep the network exactly as it was constructed.
        log.warning(f"No checkpoint at {checkpoint_path} — using random policy. Run: python -m pipeline.train_drl")
    else:
        saved = torch.load(checkpoint_path, map_location=target_device)
        net.load_state_dict(saved["policy_state_dict"])
        log.info(f"Policy loaded from {checkpoint_path}")

    # Switch to evaluation mode in both cases before handing the network out.
    net.eval()
    return net
if __name__ == "__main__":
    # Script entry point: configure logging, train a default trainer for
    # 150 episodes, then report where the checkpoint was written.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        level=logging.INFO,
    )
    REINFORCETrainer().train(n_episodes=150)
    print(f"Checkpoint saved to: {DEFAULT_CHECKPOINT}")