import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
import torch.fft
import torch.nn.functional as F
# --- 1. DATA PREPARATION (Split-MNIST) ---
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
def get_split_dataloaders(dataset, classes, batch_size=64):
    indices = [i for i, target in enumerate(dataset.targets) if target in classes]
    subset = torch.utils.data.Subset(dataset, indices)
    return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True)
# Task A: Digits 0-4
train_loader_A = get_split_dataloaders(trainset, [0, 1, 2, 3, 4])
test_loader_A = get_split_dataloaders(testset, [0, 1, 2, 3, 4])
# Task B: Digits 5-9
train_loader_B = get_split_dataloaders(trainset, [5, 6, 7, 8, 9])
test_loader_B = get_split_dataloaders(testset, [5, 6, 7, 8, 9])
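# Each loader yields only its own five digits, while the single 10-way head
# defined below is shared across both tasks, so training on Task B can
# overwrite Task A's weights unless something constrains them.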
# --- 2. NEURAL NETWORK ARCHITECTURE (CNN) ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)  # 28x28 -> 14x14 -> 7x7 after two pools
        self.fc2 = nn.Linear(128, 10)          # single 10-way head shared by both tasks

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
def evaluate_accuracy(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total
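# Note: accuracy is computed over the full 10-way output but on each task's
# own digits only, so forgetting shows up as 0-4 images drifting to 5-9 labels.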
# --- 3. ANASTROPHIC REGULARIZATION ---
class AnastrophicRegularizer(nn.Module):
    def __init__(self, lambda_reg=1.0, eta_reg=3.0):
        super().__init__()
        self.lambda_reg = lambda_reg  # weight on the coherence term (1 - Phi)
        self.eta_reg = eta_reg        # weight on the spectral-drift term (Beta)

    def compute_phi(self, w):
        """Calculates Spectral Coherence (Phi) via 1D FFT.

        Phi = |sum_j p_j * exp(i*theta_j)|, where p_j is the normalized
        spectral power and theta_j the phase of frequency bin j; Phi = 1
        means perfectly aligned phases.
        """
        fft_w = torch.fft.fft(w.view(-1))
        amplitudes = torch.abs(fft_w)
        phases = torch.angle(fft_w)
        # Normalized power distribution; epsilon guards against division by zero
        p_j = (amplitudes ** 2) / (torch.sum(amplitudes ** 2) + 1e-8)
        complex_sum = torch.sum(p_j * torch.exp(1j * phases))
        return torch.abs(complex_sum)

    def compute_beta_proxy(self, w, w_prev):
        """Continuous proxy for Anastrophic Beta (β) measuring harmonic tension."""
        fft_w = torch.fft.fft(w.view(-1))
        fft_prev = torch.fft.fft(w_prev.view(-1))
        # Compare real and imaginary parts directly to penalize spectral drift
        complex_w = torch.view_as_real(fft_w)
        complex_prev = torch.view_as_real(fft_prev)
        return F.mse_loss(complex_w, complex_prev)

    def forward(self, model, model_prev):
        loss_ana = 0.0
        for (name, param), (name_prev, param_prev) in zip(model.named_parameters(), model_prev.named_parameters()):
            if 'weight' in name:
                phi = self.compute_phi(param)
                # .detach() is critical to anchor the previous structural state
                beta = self.compute_beta_proxy(param, param_prev.detach())
                loss_ana += self.lambda_reg * (1 - phi) + self.eta_reg * beta
        return loss_ana
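# Illustrative sanity check (a sketch, not part of the training run): a
# constant vector puts all spectral power in the zero-phase DC bin, so
# Phi should be ~1.0, while random weights typically score much lower.
_reg_demo = AnastrophicRegularizer()
assert _reg_demo.compute_phi(torch.ones(64)) > 0.99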
# --- 4. PHASE 1: TRAINING TASK A ---
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print("--- Starting Task A (Digits 0-4) Training ---")
model.train()
for epoch in range(3):
    for images, labels in train_loader_A:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
acc_A = evaluate_accuracy(model, test_loader_A)
print(f"Accuracy on Task A after training Task A: {acc_A:.2f}%\n")
# Freeze the base model to preserve its structural return invariants
model_A_frozen = copy.deepcopy(model)
model_A_frozen.eval()
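# The frozen copy provides the fixed spectral reference for compute_beta_proxy;
# only the live model's parameters are handed to optimizer_B below.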
# --- 5. PHASE 2: TRAINING TASK B WITH ANASTROPHIC REGULARIZATION ---
regularizer = AnastrophicRegularizer(lambda_reg=1.0, eta_reg=3.0)
print("--- Starting Task B (Digits 5-9) Training with R_ana ---")
optimizer_B = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(3):
    model.train()
    running_loss_class = 0.0
    running_loss_ana = 0.0
    for images, labels in train_loader_B:
        optimizer_B.zero_grad()
        outputs = model(images)
        loss_class = criterion(outputs, labels)
        # Apply Anastrophic Theory to preserve structural relationships
        loss_ana = regularizer(model, model_A_frozen)
        loss = loss_class + loss_ana
        loss.backward()
        optimizer_B.step()
        running_loss_class += loss_class.item()
        running_loss_ana += loss_ana.item()
    print(f"Epoch {epoch+1} | Classification Loss: {running_loss_class/len(train_loader_B):.4f} | R_ana Loss: {running_loss_ana/len(train_loader_B):.4f}")
# --- 6. FINAL EVALUATION ---
print("\n--- Final Results ---")
acc_B_final = evaluate_accuracy(model, test_loader_B)
acc_A_final = evaluate_accuracy(model, test_loader_A)
print(f"Accuracy on NEW Task B (5-9): {acc_B_final:.2f}%")
print(f"RETAINED Accuracy on Task A (0-4): {acc_A_final:.2f}%")