"""
Offline RL trainer for the ad recommendation DQN.

Converts logged impression data into (s, a, r, s′) transitions,
fills the replay buffer, then runs batched TD-learning updates.
"""
| import time |
| import numpy as np |
|
|
|
|
class AdTrainer:
    """Offline trainer that replays logged ad impressions into a DQN agent.

    Both collaborators are duck-typed; this class uses exactly the following:

    * ``agent``: ``get_state(history, user_feat, ctx_feat)``,
      ``store(state, action, reward, next_state, done)``, ``train_step()``
      and the attributes ``buffer`` (must support ``len``), ``batch_size``
      and ``epsilon``.
    * ``dataset``: ``user_sequences`` (mapping of user id to a sliceable
      sequence of 4-item events ``(ad_id, reward, clicked, converted)``),
      ``get_user_features(uid)`` and ``get_context_features()``.
    """

    def __init__(self, agent, dataset, max_seq_len: int = 10):
        """Bind the agent and dataset and start an empty training history.

        Args:
            agent: Learner that builds states, stores transitions and trains.
            dataset: Source of per-user event sequences and feature vectors.
            max_seq_len: Stored on the instance as ``self.max_seq_len``.
                NOTE(review): nothing in this class reads it -- histories
                are handed to ``agent.get_state`` untruncated. Confirm that
                the agent applies its own window.
        """
        self.agent = agent
        self.dataset = dataset
        self.max_seq_len = max_seq_len
        # 'losses' gets one rounded mean loss per epoch; 'epochs' counts the
        # epochs run over the lifetime of this trainer (across calls).
        self.history = {'losses': [], 'epochs': 0}

    def _build_transitions(self) -> list:
        """Convert every user sequence into ``(s, a, r, s', done)`` tuples.

        For each position ``t`` except the last one, the state is built from
        the events before ``t``, the action and reward come from event ``t``,
        and the next state is built from the events up to and including
        ``t``. Sequences with fewer than two events are skipped.

        NOTE(review): the final event of a sequence is never emitted as an
        action and never appears in any history passed to ``get_state``.
        Confirm this is intentional and not an off-by-one.

        Returns:
            A list of ``(state, ad_id as int, reward as float, next_state,
            done)`` tuples, where ``done`` is True only for the last
            transition emitted for a user.
        """
        transitions = []
        for uid, seq in self.dataset.user_sequences.items():
            if len(seq) < 2:
                continue

            user_feat = self.dataset.get_user_features(uid)
            ctx_feat = self.dataset.get_context_features()
            last_step = len(seq) - 2

            for t in range(len(seq) - 1):
                # Only the ad id and reward are used; the click / conversion
                # flags are unpacked to keep the 4-item event contract.
                ad_id, reward, _clicked, _converted = seq[t]

                state = self.agent.get_state(seq[:t], user_feat, ctx_feat)
                next_state = self.agent.get_state(
                    seq[:t + 1], user_feat, ctx_feat
                )

                transitions.append(
                    (state, int(ad_id), float(reward), next_state,
                     t == last_step)
                )

        return transitions

    def train_offline(self, n_epochs: int = 5, callback=None, *,
                      max_steps_per_epoch: int = 600) -> dict:
        """Fill the replay buffer from the dataset and run training epochs.

        Every call rebuilds all transitions and stores them in the agent
        again before training, so calling this repeatedly on the same
        trainer stores the same transitions again.

        Args:
            n_epochs: Number of epochs to run.
            callback: Optional callable invoked after each epoch with a dict
                holding ``epoch``, ``total_epochs``, ``loss``, ``epsilon``
                and ``progress`` (fraction of epochs completed).
            max_steps_per_epoch: Upper bound on ``agent.train_step()`` calls
                per epoch. The actual count is the smaller of this bound and
                ``len(agent.buffer) // batch_size``.

        Returns:
            ``self.history``: the same dict object that is updated in place,
            with one rounded mean loss appended per epoch.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments during training.
        started = time.perf_counter()
        n_users = len(self.dataset.user_sequences)
        print(f"[Trainer] Building transitions from {n_users} users…")

        transitions = self._build_transitions()
        print(f"[Trainer] {len(transitions):,} transitions → filling replay buffer")

        for state, action, reward, next_state, done in transitions:
            self.agent.store(state, action, reward, next_state, done)

        for epoch in range(n_epochs):
            # max(1, ...) guards against a zero batch size in the division.
            steps = min(
                max_steps_per_epoch,
                len(self.agent.buffer) // max(1, self.agent.batch_size),
            )

            epoch_losses = []
            for _ in range(steps):
                loss = self.agent.train_step()
                # Steps that report no loss are left out of the average.
                if loss is not None:
                    epoch_losses.append(loss)

            # An epoch without any reported loss is recorded as 0.0.
            avg = float(np.mean(epoch_losses)) if epoch_losses else 0.0
            self.history['losses'].append(round(avg, 6))
            self.history['epochs'] += 1

            print(
                f"[Trainer] Epoch {epoch + 1}/{n_epochs} | "
                f"loss={avg:.5f} | ε={self.agent.epsilon:.4f} | "
                f"buffer={len(self.agent.buffer):,}"
            )

            if callback:
                callback({
                    'epoch': epoch + 1,
                    'total_epochs': n_epochs,
                    'loss': avg,
                    'epsilon': self.agent.epsilon,
                    'progress': (epoch + 1) / n_epochs,
                })

        elapsed = time.perf_counter() - started
        print(f"[Trainer] Done in {elapsed:.1f}s")
        return self.history
|
|