import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
import torch.fft
import torch.nn.functional as F
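
# Continual-learning demo: train a CNN on MNIST digits 0-4 (Task A), then on
# digits 5-9 (Task B) while an FFT-based anastrophic regularizer (R_ana) pulls
# the weights toward the spectral signature of the frozen Task-A solution.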
|
|
# Load MNIST with normalization to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
|
|
def get_split_dataloaders(dataset, classes, batch_size=64):
    """Return a DataLoader over only the samples whose label is in `classes`."""
    indices = [i for i, target in enumerate(dataset.targets) if target.item() in classes]
    subset = torch.utils.data.Subset(dataset, indices)
    return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True)
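
# Labels keep their original 0-9 values in both splits, so the single 10-way
# classification head defined below serves both tasks without remapping.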
|
|
# Task A: digits 0-4.
train_loader_A = get_split_dataloaders(trainset, [0, 1, 2, 3, 4])
test_loader_A = get_split_dataloaders(testset, [0, 1, 2, 3, 4])
|
|
# Task B: digits 5-9.
train_loader_B = get_split_dataloaders(trainset, [5, 6, 7, 8, 9])
test_loader_B = get_split_dataloaders(testset, [5, 6, 7, 8, 9])
|
|
# Small CNN backbone: two conv/pool blocks, then a two-layer classifier head.
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(self.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = x.view(-1, 32 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
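
# Illustrative shape check (not part of the original script): a dummy batch
# should come out as (N, 10) logits.
#   logits = SimpleCNN()(torch.zeros(2, 1, 28, 28))
#   assert logits.shape == (2, 10)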
|
|
def evaluate_accuracy(model, dataloader):
    """Top-1 accuracy (%) of `model` over `dataloader`."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total
|
|
# Anastrophic regularizer: rewards spectral coherence within each weight tensor
# and penalizes harmonic drift away from the frozen Task-A solution.
class AnastrophicRegularizer(nn.Module):
    def __init__(self, lambda_reg=1.0, eta_reg=3.0):
        super().__init__()
        self.lambda_reg = lambda_reg  # weight on the (1 - Phi) coherence term
        self.eta_reg = eta_reg        # weight on the Beta drift term

    def compute_phi(self, w):
        """Calculate spectral coherence (Phi) of a flattened weight tensor via 1D FFT.

        Phi is the magnitude of the power-weighted mean phase vector: 1 when
        all spectral components share a phase, near 0 when phases scatter.
        """
        fft_w = torch.fft.fft(w.view(-1))
        amplitudes = torch.abs(fft_w)
        phases = torch.angle(fft_w)

        # Normalized power spectrum; the epsilon guards against a zero tensor.
        p_j = (amplitudes ** 2) / (torch.sum(amplitudes ** 2) + 1e-8)
        complex_sum = torch.sum(p_j * torch.exp(1j * phases))

        return torch.abs(complex_sum)

    def compute_beta_proxy(self, w, w_prev):
        """Continuous proxy for anastrophic beta (BB): harmonic tension, measured
        as MSE between current and previous weights in the complex frequency domain."""
        fft_w = torch.fft.fft(w.view(-1))
        fft_prev = torch.fft.fft(w_prev.view(-1))

        # view_as_real maps a complex tensor to a (..., 2) real tensor so that
        # mse_loss compares real and imaginary parts jointly.
        complex_w = torch.view_as_real(fft_w)
        complex_prev = torch.view_as_real(fft_prev)

        return F.mse_loss(complex_w, complex_prev)

    def forward(self, model, model_prev):
        loss_ana = 0.0
        for (name, param), (_, param_prev) in zip(model.named_parameters(), model_prev.named_parameters()):
            if 'weight' in name:
                phi = self.compute_phi(param)
                # The frozen reference provides no gradient path.
                beta = self.compute_beta_proxy(param, param_prev.detach())
                loss_ana += self.lambda_reg * (1 - phi) + self.eta_reg * beta
        return loss_ana
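
# Per weight tensor the penalty is
#   R_ana(w) = lambda * (1 - Phi(w)) + eta * Beta(w, w_prev)
# where Phi(w) = | sum_j p_j * exp(i * theta_j) |, with p_j the normalized power
# spectrum and theta_j the phase of FFT coefficient j, and Beta is the
# frequency-domain MSE against the frozen Task-A weights.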
|
|
# Fresh model and standard classification objective for Task A.
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
|
print("--- Starting Task A (Digits 0-4) Training ---")
model.train()
for epoch in range(3):
    for images, labels in train_loader_A:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

acc_A = evaluate_accuracy(model, test_loader_A)
print(f"Accuracy on Task A after training Task A: {acc_A:.2f}%\n")
|
|
# Freeze a copy of the Task-A model to serve as the spectral reference; its
# parameters are detached inside the regularizer, so gradients never reach it.
model_A_frozen = copy.deepcopy(model)
model_A_frozen.eval()
|
|
# R_ana coefficients: lambda weights coherence, eta weights drift.
regularizer = AnastrophicRegularizer(lambda_reg=1.0, eta_reg=3.0)
|
|
print("--- Starting Task B (Digits 5-9) Training with R_ana ---")
optimizer_B = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    model.train()
    running_loss_class = 0.0
    running_loss_ana = 0.0

    for images, labels in train_loader_B:
        optimizer_B.zero_grad()
        outputs = model(images)

        loss_class = criterion(outputs, labels)

        # Regularize toward the spectral structure of the frozen Task-A model.
        loss_ana = regularizer(model, model_A_frozen)

        loss = loss_class + loss_ana
        loss.backward()
        optimizer_B.step()

        running_loss_class += loss_class.item()
        running_loss_ana += loss_ana.item()

    print(f"Epoch {epoch+1} | Classification Loss: {running_loss_class/len(train_loader_B):.4f} | "
          f"R_ana Loss: {running_loss_ana/len(train_loader_B):.4f}")
|
|
# Final evaluation: plasticity on the new task and retention on the old one.
print("\n--- Final Results ---")
acc_B_final = evaluate_accuracy(model, test_loader_B)
acc_A_final = evaluate_accuracy(model, test_loader_A)

print(f"Accuracy on NEW Task B (5-9): {acc_B_final:.2f}%")
print(f"RETAINED Accuracy on Task A (0-4): {acc_A_final:.2f}%")
|
|