# Author: hari15prasad
# Commit: Initial clean deployment to Hugging Face
# Revision: 6f44ddb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from typing import Any
from collections import deque
# ---------------------------------------------------------------------------
# Multi-Head DQN Architecture (Task 6)
# Shared dense trunk with one Q-value head per knowledge domain, used by the
# EduForge pedagogical framework.
# ---------------------------------------------------------------------------
class DQN(nn.Module):
    """Multi-head Q-network: a shared MLP trunk followed by one linear head per domain.

    Head order follows the original layout:
    0=procedural, 1=conceptual, 2=factual, 3=transfer.

    Args:
        input_dim: Size of the state feature vector.
        output_dim: Number of actions (Q-values produced by every head).
        hidden_dim: Width of the two trunk layers (default 128).
        num_heads: Number of domain-specific heads (default 4).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 128,
        num_heads: int = 4,
    ):
        super().__init__()
        # Shared trunk. Attribute names are kept so existing state dicts load.
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # One output layer per domain.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, output_dim) for _ in range(num_heads)]
        )

    def forward(self, x: torch.Tensor, domain_idx: Any = None) -> torch.Tensor:
        """Return Q-values for ``x`` using the head(s) selected by ``domain_idx``.

        Args:
            x: Batch of states, shape ``(batch, input_dim)``.
            domain_idx: Selects the head.
                * ``None`` -> head 0 (safe default for evaluation).
                * ``int`` / NumPy integer -> that single head for the whole batch.
                * tensor or NumPy array with one index per sample (any shape
                  that flattens to ``batch`` entries, e.g. ``(batch, 1)``)
                  -> per-sample head selection.
                * anything else -> head 0.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        trunk_out = self.trunk(x)

        if domain_idx is None:
            return self.heads[0](trunk_out)

        if isinstance(domain_idx, (int, np.integer)):
            return self.heads[int(domain_idx)](trunk_out)

        # NumPy index arrays used to fall through to head 0 silently; treat
        # them as a per-sample index exactly like a tensor.
        if isinstance(domain_idx, np.ndarray):
            domain_idx = torch.as_tensor(domain_idx)

        if torch.is_tensor(domain_idx):
            # Evaluate every head once: (batch, num_heads, output_dim).
            all_q = torch.stack([head(trunk_out) for head in self.heads], dim=1)
            # gather() requires an int64 index on the same device as its input.
            idx = domain_idx.to(device=all_q.device, dtype=torch.long)
            # (batch, 1, output_dim): each sample picks one head across all actions.
            idx = idx.reshape(-1, 1, 1).expand(-1, 1, all_q.size(-1))
            return all_q.gather(1, idx).squeeze(1)

        return self.heads[0](trunk_out)
# ---------------------------------------------------------------------------
# Double DQN Agent
# Decouples selection from evaluation to stabilize Q-value estimation.
# ---------------------------------------------------------------------------
class DQNAgent:
    """Double-DQN agent with action masking and prioritized replay.

    A policy network selects actions and is trained; a target network, synced
    on demand via ``sync_target``, evaluates the bootstrapped value.

    Args:
        input_dim: Size of the state feature vector.
        output_dim: Number of discrete actions.
        lr: Adam learning rate.
        gamma: Discount factor.
        buffer_size: Capacity of the replay buffer.
    """

    def __init__(self, input_dim, output_dim, lr=1e-3, gamma=0.99, buffer_size=20000):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.gamma = gamma
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Networks: policy (live) and target (lagging copy, never trained directly).
        self.policy_net = DQN(input_dim, output_dim).to(self.device)
        self.target_net = DQN(input_dim, output_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(buffer_size)

    def _to_tensor(self, data, dtype=torch.float32):
        """Convert array-like ``data`` to a tensor of ``dtype`` on the agent's device."""
        return torch.as_tensor(data, dtype=dtype, device=self.device)

    def select_action(self, state, epsilon, domain_idx, mask=None):
        """Pick an action epsilon-greedily, respecting an optional action mask.

        Args:
            state: State features for a single step (no batch dimension).
            epsilon: Probability of taking a random (exploratory) action.
            domain_idx: Domain head used for the greedy choice.
            mask: Optional additive mask, one entry per action
                (0.0 = allowed, -inf = blocked).

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < epsilon:
            # Exploration: uniform over the allowed actions.
            if mask is None:
                return random.choice(range(self.output_dim))
            allowed = [i for i, m in enumerate(mask) if m == 0]
            if not allowed:
                # Every action is blocked. The greedy branch still returns an
                # action in this situation, so do the same here instead of
                # raising IndexError from random.choice on an empty list.
                allowed = range(self.output_dim)
            return random.choice(allowed)

        # Exploitation: greedy over (masked) Q-values.
        state_t = self._to_tensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state_t, domain_idx)
            if mask is not None:
                # Adding -inf removes blocked actions from the argmax.
                q_values = q_values + self._to_tensor(mask)
            return q_values.argmax(dim=1).item()

    def update(self, batch_size, beta=0.4):
        """Run one Double-DQN optimisation step on a prioritized batch.

        Args:
            batch_size: Number of transitions to sample.
            beta: Importance-sampling exponent forwarded to the replay buffer
                (default 0.4, the buffer's own default).

        Returns:
            The scalar loss as a ``float``, or ``None`` when the buffer holds
            fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < batch_size:
            return None

        (
            states,
            actions,
            rewards,
            next_states,
            dones,
            next_masks,
            indices,
            weights,
            domain_idxs,
        ) = self.memory.sample(batch_size, beta=beta)

        states_t = self._to_tensor(states)
        actions_t = self._to_tensor(actions, torch.long).unsqueeze(1)
        rewards_t = self._to_tensor(rewards)
        next_states_t = self._to_tensor(next_states)
        dones_t = self._to_tensor(dones)
        next_masks_t = self._to_tensor(next_masks)
        weights_t = self._to_tensor(weights)
        domain_idxs_t = self._to_tensor(domain_idxs, torch.long).unsqueeze(1)

        # 1. Q(s, a) from the policy net's domain-specific head.
        #    squeeze(1) keeps the batch dimension even when batch_size == 1.
        current_q = (
            self.policy_net(states_t, domain_idxs_t).gather(1, actions_t).squeeze(1)
        )

        # 2. Double DQN target.
        # NOTE(review): the next state is evaluated with the *current* state's
        # domain index because the buffer stores only one index per transition.
        with torch.no_grad():
            # Step A: policy net picks the best *allowed* next action.
            next_q_policy = self.policy_net(next_states_t, domain_idxs_t) + next_masks_t
            best_actions = next_q_policy.argmax(dim=1, keepdim=True)
            # Step B: target net evaluates that action.
            next_q_target = (
                self.target_net(next_states_t, domain_idxs_t)
                .gather(1, best_actions)
                .squeeze(1)
            )
            # Step C: clip rewards to [-10, 10] and bootstrap unless terminal.
            clipped_rewards = torch.clamp(rewards_t, -10.0, 10.0)
            target_q = clipped_rewards + self.gamma * next_q_target * (1 - dones_t)

        # 3. Refresh priorities with |TD error|, then compute the
        #    importance-weighted Huber loss.
        td_errors = (current_q - target_q).abs().detach().cpu().numpy()
        self.memory.update_priorities(indices, td_errors + 1e-6)
        per_sample = F.smooth_l1_loss(current_q, target_q, reduction="none")
        loss = (weights_t * per_sample).mean()

        # 4. Backprop with gradient-norm clipping.
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()
        return loss.item()

    def sync_target(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
# ---------------------------------------------------------------------------
# Prioritized Replay Buffer (Task 4)
# ---------------------------------------------------------------------------
class ReplayBuffer:
    """Fixed-capacity ring buffer with proportional prioritized sampling.

    Transitions are stored as 7-tuples
    ``(state, action, reward, next_state, done, next_mask, domain_idx)``.
    Each slot has a float32 priority; sampling probability is proportional to
    ``priority ** alpha``.
    """

    def __init__(self, capacity: int, alpha: float = 0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.memory = []
        # One priority per slot; slots that are still empty stay at 0.
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        # Next slot to write (wraps around once the buffer is full).
        self.pos = 0

    def push(self, state, action, reward, next_state, done, next_mask, domain_idx):
        """Store a transition, overwriting the oldest one when full.

        The new entry receives the current maximum priority (1.0 for the very
        first entry).
        """
        transition = (state, action, reward, next_state, done, next_mask, domain_idx)

        if self.memory:
            top_priority = self.priorities.max()
        else:
            top_priority = 1.0

        if len(self.memory) >= self.capacity:
            self.memory[self.pos] = transition
        else:
            self.memory.append(transition)

        slot = self.pos
        self.priorities[slot] = top_priority
        self.pos = (slot + 1) % self.capacity

    def sample(self, batch_size: int, beta: float = 0.4):
        """Draw ``batch_size`` transitions (with replacement) by priority.

        Returns:
            A 9-tuple ``(states, actions, rewards, next_states, dones,
            next_masks, indices, weights, domain_idxs)``. ``indices`` are the
            sampled slot numbers and ``weights`` are float32 importance-sampling
            weights ``(N * p_i) ** -beta`` normalised by their batch maximum.
        """
        size = len(self.memory)

        # Only the filled prefix takes part in sampling.
        scaled = self.priorities[:size] ** self.alpha
        scaled /= scaled.sum()

        indices = np.random.choice(size, batch_size, p=scaled)

        is_weights = (size * scaled[indices]) ** (-beta)
        is_weights /= is_weights.max()
        is_weights = np.array(is_weights, dtype=np.float32)

        # Transpose the list of transitions into one column per field.
        columns = list(zip(*(self.memory[i] for i in indices)))
        states, actions, rewards, next_states, dones, next_masks, domain_idxs = columns

        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards),
            np.stack(next_states),
            np.array(dones),
            np.stack(next_masks),
            indices,
            is_weights,
            np.array(domain_idxs),
        )

    def update_priorities(self, batch_indices, batch_priorities):
        """Overwrite the priority of each listed slot with its new value."""
        pairs = zip(batch_indices, batch_priorities)
        for slot, new_priority in pairs:
            self.priorities[slot] = new_priority

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.memory)