mnoorchenar's picture
Update 2026-03-23 09:33:03
f19eb84
"""
Offline RL trainer for the ad recommendation DQN.
Converts logged impression data into (s, a, r, s′) transitions,
fills the replay buffer, then runs batched TD-learning updates.
"""
import time
import numpy as np
class AdTrainer:
    """Offline trainer that replays logged ad impressions through a DQN agent.

    The trainer interacts with its collaborators only through the members used
    below: ``dataset.user_sequences`` / ``get_user_features`` /
    ``get_context_features`` and ``agent.get_state`` / ``store`` /
    ``train_step`` / ``buffer`` / ``batch_size`` / ``epsilon``.
    """

    def __init__(self, agent, dataset, max_seq_len: int = 10):
        """Bind the trainer to an agent and a logged-impression dataset.

        Args:
            agent: DQN agent that builds states and owns the replay buffer.
            dataset: Source of per-user impression sequences and features.
            max_seq_len: Stored on the instance but not read anywhere in this
                class. NOTE(review): presumably meant to cap the history
                passed to ``agent.get_state`` — confirm where truncation
                actually happens.
        """
        self.agent = agent
        self.dataset = dataset
        self.max_seq_len = max_seq_len
        # Accumulates across train_offline() calls: one averaged loss per epoch.
        self.history = {'losses': [], 'epochs': 0}

    # ------------------------------------------------------------------
    def _build_transitions(self) -> list:
        """Turn every user's logged sequence into (s, a, r, s', done) tuples.

        Each element of a user's sequence is unpacked as
        ``(ad_id, reward, clicked, converted)``. For step ``t`` the state is
        built from the events before ``t``, the action/reward come from event
        ``t``, and the next state is built from the events up to and including
        ``t``. Every event yields one transition; the last event of each user
        is marked ``done=True``. Users with an empty sequence are skipped.

        Returns:
            List of ``(state, ad_id: int, reward: float, next_state, done)``.
        """
        transitions = []
        for uid, seq in self.dataset.user_sequences.items():
            if not seq:
                continue
            user_feat = self.dataset.get_user_features(uid)
            # Called without arguments (original note: defaults to the current
            # time), so all of this user's transitions share one context vector.
            # NOTE(review): logged event times are not used here — confirm.
            ctx_feat = self.dataset.get_context_features()
            last = len(seq) - 1
            for t, (ad_id, reward, _clicked, _converted) in enumerate(seq):
                # All events are used: the next state seq[:t + 1] always exists,
                # so the final event is a valid terminal transition, not one to drop.
                state = self.agent.get_state(seq[:t], user_feat, ctx_feat)
                next_state = self.agent.get_state(seq[:t + 1], user_feat, ctx_feat)
                transitions.append(
                    (state, int(ad_id), float(reward), next_state, t == last)
                )
        return transitions

    # ------------------------------------------------------------------
    def _fill_buffer(self, transitions) -> None:
        """Push every transition into the agent's replay buffer, in order."""
        for s, a, r, ns, d in transitions:
            self.agent.store(s, a, r, ns, d)

    def _run_epoch(self, max_steps: int) -> float:
        """Run one epoch of agent updates; return the mean loss (0.0 if none)."""
        # Roughly one pass over the buffer in mini-batches, capped at max_steps.
        n_batches = len(self.agent.buffer) // max(1, self.agent.batch_size)
        steps = min(max_steps, n_batches)
        losses = []
        for _ in range(steps):
            loss = self.agent.train_step()
            # train_step may return None (presumably when no update was made);
            # such steps are excluded from the mean.
            if loss is not None:
                losses.append(loss)
        return float(np.mean(losses)) if losses else 0.0

    def train_offline(self, n_epochs: int = 5, callback=None, *,
                      max_steps_per_epoch: int = 600) -> dict:
        """Fill the replay buffer from logged data, then train for ``n_epochs``.

        Args:
            n_epochs: Number of epochs to run.
            callback: Optional callable invoked after each epoch with a dict
                holding ``epoch``, ``total_epochs``, ``loss``, ``epsilon`` and
                ``progress`` (fraction of epochs completed).
            max_steps_per_epoch: Upper bound on ``agent.train_step()`` calls per
                epoch (previously hard-coded to 600).

        Returns:
            ``self.history`` — the same dict object, updated in place. An epoch
            with no recorded losses is logged as 0.0.
        """
        t0 = time.time()
        n_users = len(self.dataset.user_sequences)
        print(f"[Trainer] Building transitions from {n_users} users…")
        transitions = self._build_transitions()
        print(f"[Trainer] {len(transitions):,} transitions → filling replay buffer")
        self._fill_buffer(transitions)

        for epoch in range(n_epochs):
            avg = self._run_epoch(max_steps_per_epoch)
            self.history['losses'].append(round(avg, 6))
            self.history['epochs'] += 1
            print(
                f"[Trainer] Epoch {epoch + 1}/{n_epochs} | "
                f"loss={avg:.5f} | ε={self.agent.epsilon:.4f} | "
                f"buffer={len(self.agent.buffer):,}"
            )
            if callback:
                callback({
                    'epoch': epoch + 1,
                    'total_epochs': n_epochs,
                    'loss': avg,
                    'epsilon': self.agent.epsilon,
                    'progress': (epoch + 1) / n_epochs,
                })

        elapsed = time.time() - t0
        print(f"[Trainer] Done in {elapsed:.1f}s")
        return self.history