Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import random | |
| from typing import Any | |
| from collections import deque | |
# ---------------------------------------------------------------------------
# Multi-Head DQN Architecture (Task 6)
# Dense shared trunk followed by one output head per knowledge domain,
# for the EduForge pedagogical framework.
# ---------------------------------------------------------------------------
class DQN(nn.Module):
    """Multi-head Q-network: a shared dense trunk followed by per-domain heads.

    Each head is a linear layer mapping the trunk features to ``output_dim``
    Q-values. With the default ``num_heads=4`` the head indices are
    0=procedural, 1=conceptual, 2=factual, 3=transfer.

    Args:
        input_dim: Size of the state vector.
        output_dim: Number of actions (Q-values produced by every head).
        num_heads: Number of domain-specific heads.
        hidden_dim: Width of the two hidden layers in the shared trunk.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_heads: int = 4,
        hidden_dim: int = 128,
    ):
        super().__init__()
        # Shared trunk
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Domain-specific heads
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, output_dim) for _ in range(num_heads)]
        )

    def forward(self, x: torch.Tensor, domain_idx: Any = None) -> torch.Tensor:
        """Return Q-values for ``x`` using the head(s) chosen by ``domain_idx``.

        Args:
            x: Input states; a 2-D ``(batch, input_dim)`` tensor is required
                whenever ``domain_idx`` is array-like.
            domain_idx: ``None`` (use head 0), a single integer head index
                applied to the whole batch, or a tensor / NumPy array / list /
                tuple holding one head index per sample (any shape with
                ``batch`` elements, e.g. ``(batch,)`` or ``(batch, 1)``).
                A single-element array is broadcast over the batch.

        Returns:
            Tensor of Q-values with one row per sample.

        Raises:
            ValueError: If an array-like ``domain_idx`` has neither one
                element nor exactly one element per sample.
        """
        trunk_out = self.trunk(x)

        if domain_idx is None:
            # Default to the first head when no domain is given (safety/eval).
            return self.heads[0](trunk_out)

        if isinstance(domain_idx, (int, np.integer)):
            return self.heads[int(domain_idx)](trunk_out)

        # Array-likes used to fall through to head 0 silently; treat them as
        # per-sample indices instead.
        if isinstance(domain_idx, (list, tuple, np.ndarray)):
            domain_idx = torch.as_tensor(np.asarray(domain_idx))

        if torch.is_tensor(domain_idx):
            # Evaluate every head, then pick one per sample:
            # all_q is (batch, num_heads, output_dim).
            all_q = torch.stack([head(trunk_out) for head in self.heads], dim=1)
            batch = all_q.size(0)

            # gather() needs int64 indices on the same device as its input.
            flat_idx = domain_idx.to(device=all_q.device, dtype=torch.long).reshape(-1)
            if flat_idx.numel() == 1 and batch > 1:
                flat_idx = flat_idx.expand(batch)
            if flat_idx.numel() != batch:
                # A shorter index would make gather() silently drop samples.
                raise ValueError(
                    f"domain_idx has {flat_idx.numel()} entries for a batch of {batch}"
                )

            gather_idx = flat_idx.reshape(-1, 1, 1).expand(-1, 1, all_q.size(-1))
            return all_q.gather(1, gather_idx).squeeze(1)

        # Unrecognised selector type: keep the historical head-0 fallback.
        return self.heads[0](trunk_out)
| # --------------------------------------------------------------------------- | |
| # Double DQN Agent | |
| # Decouples selection from evaluation to stabilize Q-value estimation. | |
| # --------------------------------------------------------------------------- | |
class DQNAgent:
    """Double DQN agent with action masking and prioritized replay.

    The policy network selects the next action and the lagging target network
    evaluates it, which decouples selection from evaluation.

    Args:
        input_dim: Size of the state vector.
        output_dim: Number of discrete actions.
        lr: Adam learning rate for the policy network.
        gamma: Discount factor.
        buffer_size: Capacity of the replay buffer.
    """

    def __init__(self, input_dim, output_dim, lr=1e-3, gamma=0.99, buffer_size=20000):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.gamma = gamma
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Networks: policy (live) and target (lagging copy, eval mode only).
        self.policy_net = DQN(input_dim, output_dim).to(self.device)
        self.target_net = DQN(input_dim, output_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(buffer_size)

    def _to_tensor(self, data, dtype):
        """Convert array-like ``data`` to a tensor of ``dtype`` on the agent's device."""
        return torch.as_tensor(np.asarray(data), dtype=dtype, device=self.device)

    def select_action(self, state, epsilon, domain_idx, mask=None):
        """Pick an action epsilon-greedily, honouring an optional action mask.

        Args:
            state: Current state vector.
            epsilon: Probability of taking a random (exploratory) action.
            domain_idx: Head selector forwarded to the policy network.
            mask: Optional additive mask with one entry per action, where
                0.0 means allowed and -inf means blocked.

        Returns:
            The chosen action index as a Python int.
        """
        if random.random() < epsilon:
            # Exploration: sample uniformly among the allowed actions.
            if mask is None:
                return random.choice(range(self.output_dim))
            allowed = [i for i, m in enumerate(mask) if m == 0]
            if not allowed:
                # Every action is blocked; random.choice([]) would raise
                # IndexError, so fall back to the full action set.
                allowed = range(len(mask))
            return random.choice(allowed)

        # Exploitation: greedy over (masked) Q-values.
        state_t = self._to_tensor(state, torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state_t, domain_idx)
            if mask is not None:
                # Adding -inf removes blocked actions from the argmax.
                q_values = q_values + self._to_tensor(mask, torch.float32)
            return q_values.max(1)[1].item()

    def update(self, batch_size, beta=0.4):
        """Run one Double DQN optimisation step on a prioritized batch.

        Args:
            batch_size: Number of transitions to sample.
            beta: Importance-sampling exponent passed to the replay buffer.

        Returns:
            The scalar loss as a float, or ``None`` when the buffer holds
            fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < batch_size:
            return None

        (states, actions, rewards, next_states, dones, next_masks,
         indices, weights, domain_idxs) = self.memory.sample(batch_size, beta)

        states_t = self._to_tensor(states, torch.float32)
        actions_t = self._to_tensor(actions, torch.long).unsqueeze(1)
        rewards_t = self._to_tensor(rewards, torch.float32)
        next_states_t = self._to_tensor(next_states, torch.float32)
        dones_t = self._to_tensor(dones, torch.float32)
        next_masks_t = self._to_tensor(next_masks, torch.float32)
        weights_t = self._to_tensor(weights, torch.float32)
        domain_idxs_t = self._to_tensor(domain_idxs, torch.long).unsqueeze(1)

        # 1. Current Q(s, a) from the domain-specific head of the policy net.
        #    squeeze(1) (not squeeze()) keeps a batch of one as a 1-D tensor.
        current_q = self.policy_net(states_t, domain_idxs_t).gather(1, actions_t).squeeze(1)

        # 2. Double DQN target; no gradient may flow through it.
        with torch.no_grad():
            # Step A: the policy net picks the best *valid* next action.
            next_q_policy = self.policy_net(next_states_t, domain_idxs_t) + next_masks_t
            best_actions = next_q_policy.max(1)[1].unsqueeze(1)
            # Step B: the target net evaluates that action.
            next_q_target = (
                self.target_net(next_states_t, domain_idxs_t)
                .gather(1, best_actions)
                .squeeze(1)
            )
            # Step C: clip rewards to [-10, 10] and bootstrap unless terminal.
            clipped_rewards = torch.clamp(rewards_t, -10.0, 10.0)
            target_q = clipped_rewards + self.gamma * next_q_target * (1.0 - dones_t)

        # 3. Refresh priorities with |TD error|, then compute the
        #    importance-weighted Huber loss.
        td_errors = (current_q - target_q).abs().detach().cpu().numpy()
        self.memory.update_priorities(indices, td_errors + 1e-6)
        per_sample = F.smooth_l1_loss(current_q, target_q, reduction="none")
        loss = (weights_t * per_sample).mean()

        # 4. Backprop with gradient clipping.
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()
        return loss.item()

    def sync_target(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
| # --------------------------------------------------------------------------- | |
| # Prioritized Replay Buffer (Task 4) | |
| # --------------------------------------------------------------------------- | |
class ReplayBuffer:
    """Fixed-capacity ring buffer with proportional prioritized sampling.

    Transitions are stored as tuples
    ``(state, action, reward, next_state, done, next_mask, domain_idx)``.
    A transition at slot ``i`` is drawn with probability proportional to
    ``priorities[i] ** alpha``.

    Args:
        capacity: Maximum number of transitions kept; the oldest are
            overwritten once the buffer is full.
        alpha: Prioritization exponent (0 gives uniform sampling).

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    def __init__(self, capacity: int, alpha: float = 0.6):
        if capacity <= 0:
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.alpha = alpha
        self.memory = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.pos = 0  # next slot to write

    def push(self, state, action, reward, next_state, done, next_mask, domain_idx):
        """Store a transition, giving it the current maximum priority."""
        transition = (state, action, reward, next_state, done, next_mask, domain_idx)
        # New samples get the max priority seen so far (1.0 for the very
        # first one), so each is likely to be replayed at least once.
        max_prio = self.priorities.max() if self.memory else 1.0

        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.pos] = transition

        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int, beta: float = 0.4):
        """Draw a prioritized batch together with importance-sampling weights.

        Args:
            batch_size: Number of transitions to draw (with replacement).
            beta: Importance-sampling exponent.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones, next_masks,
            indices, weights, domain_idxs)`` of NumPy arrays. ``indices`` are
            the sampled buffer slots and ``weights`` are float32
            importance-sampling weights scaled so the largest is 1.

        Raises:
            ValueError: If the buffer is empty.
        """
        size = len(self.memory)
        if size == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        # Normalise in float64: np.random.choice verifies that p sums to 1,
        # and float32 rounding makes that check fragile.
        scaled = self.priorities[:size].astype(np.float64) ** self.alpha
        total = scaled.sum()
        if total > 0 and np.isfinite(total):
            probs = scaled / total
        else:
            # Degenerate priorities (all zero / non-finite): sample uniformly.
            probs = np.full(size, 1.0 / size)

        indices = np.random.choice(size, batch_size, p=probs)

        weights = (size * probs[indices]) ** (-beta)
        weights = (weights / weights.max()).astype(np.float32)

        states, actions, rewards, next_states, dones, next_masks, domain_idxs = zip(
            *(self.memory[idx] for idx in indices)
        )
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards),
            np.stack(next_states),
            np.array(dones),
            np.stack(next_masks),
            indices,
            weights,
            np.array(domain_idxs),
        )

    def update_priorities(self, batch_indices, batch_priorities):
        """Overwrite the priority of each listed slot with its new value."""
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.memory)