""" EXPERIMENT 3: Neural Network Activation Comparison Implement HSK activation (F(z) = sum z^n/n^n approx) and compare with ReLU, Swish, GELU on MNIST classification. Key insight: F(z) can be approximated by a truncated sum, but this is expensive. We test both: a) Truncated F(z) as activation (20 terms) b) Simplified "HSK" from the document (grad-based approximation) """ import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms import time import json # ============================================================ # Activation Functions # ============================================================ class HSKActivation(nn.Module): """Full truncated F(z) = sum_{n=1}^{N} z^n / n^n as activation""" def __init__(self, N=20): super().__init__() self.N = N # Precompute n^n terms self.register_buffer('nn_terms', torch.tensor([float(n**n) for n in range(1, N+1)])) def forward(self, z): result = torch.zeros_like(z) z_power = torch.ones_like(z) for n in range(1, self.N + 1): z_power = z_power * z # z^n result = result + z_power / self.nn_terms[n-1] return result class HSKApproxActivation(nn.Module): """Approximation from the document: linear growth shift""" def __init__(self): super().__init__() self.inv_e = 0.3678794412 def forward(self, z): # The document uses z * 0.367 as the "linear growth shift" # This is essentially: output = z / e # Which is just a scaled identity - terrible as activation (no nonlinearity!) # The "gradient" they compute is for F'/F, not for the activation itself return z * self.inv_e class SwishActivation(nn.Module): def forward(self, z): return z * torch.sigmoid(z) # ============================================================ # Network Architecture # ============================================================ class DeepNet(nn.Module): def __init__(self, activation_fn, hidden_size=128, num_layers=10): super().__init__() self.activation_name = activation_fn.__class__.__name__ layers = [] layers.append(nn.Linear(784, hidden_size)) layers.append(activation_fn) for _ in range(num_layers - 1): layers.append(nn.Linear(hidden_size, hidden_size)) # Need new activation instance for each layer (some have state) if isinstance(activation_fn, HSKActivation): layers.append(HSKActivation(N=20)) elif isinstance(activation_fn, HSKApproxActivation): layers.append(HSKApproxActivation()) elif isinstance(activation_fn, SwishActivation): layers.append(SwishActivation()) else: # ReLU, GELU etc - stateless, can reuse layers.append(activation_fn) layers.append(nn.Linear(hidden_size, 10)) self.net = nn.Sequential(*layers) def forward(self, x): return self.net(x.view(-1, 784)) # ============================================================ # Training Loop # ============================================================ def train_and_evaluate(activation_fn, name, hidden_size=128, num_layers=10, epochs=5, lr=0.001): device = torch.device('cpu') train_loader = torch.utils.data.DataLoader( datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transforms.ToTensor()), batch_size=128, shuffle=True) test_loader = torch.utils.data.DataLoader( datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transforms.ToTensor()), batch_size=256, shuffle=False) model = DeepNet(activation_fn, hidden_size, num_layers).to(device) optimizer = optim.Adam(model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() results = { 'name': name, 'epochs': [], 'final_test_acc': 0, 'total_time': 0, 'grad_norms': [], } start_time = time.time() for epoch in range(epochs): model.train() total_loss = 0 batch_count = 0 grad_norms_epoch = [] for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() # Track gradient norms total_grad_norm = 0 for p in model.parameters(): if p.grad is not None: total_grad_norm += p.grad.data.norm(2).item() ** 2 total_grad_norm = total_grad_norm ** 0.5 grad_norms_epoch.append(total_grad_norm) # Check for NaN/Inf if torch.isnan(loss) or torch.isinf(loss): print(f" WARNING: NaN/Inf loss at epoch {epoch+1}, batch {batch_idx}") results['epochs'].append({'epoch': epoch+1, 'loss': float('nan'), 'acc': 0}) results['final_test_acc'] = 0 results['total_time'] = time.time() - start_time return results torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) optimizer.step() total_loss += loss.item() batch_count += 1 # Test model.eval() correct = 0 total = 0 with torch.no_grad(): for data, target in test_loader: output = model(data.view(-1, 784)) pred = output.argmax(dim=1) correct += (pred == target).sum().item() total += target.size(0) test_acc = correct / total avg_loss = total_loss / batch_count avg_grad = sum(grad_norms_epoch) / len(grad_norms_epoch) results['epochs'].append({ 'epoch': epoch+1, 'loss': avg_loss, 'test_acc': test_acc, 'avg_grad_norm': avg_grad, }) results['grad_norms'].extend(grad_norms_epoch) print(f" Epoch {epoch+1}: loss={avg_loss:.4f}, test_acc={test_acc:.4f}, grad_norm={avg_grad:.2f}") results['final_test_acc'] = test_acc results['total_time'] = time.time() - start_time return results # ============================================================ # Run Experiments # ============================================================ print("=" * 65) print("EXPERIMENT 3: Neural Network Activation Comparison") print("=" * 65) activations = [ (nn.ReLU(), "ReLU"), (nn.GELU(), "GELU"), (SwishActivation(), "Swish"), (HSKActivation(N=20), "HSK-Truncated(F_z)"), (HSKApproxActivation(), "HSK-Approx(z/e)"), ] all_results = {} for act_fn, name in activations: print(f"\n--- Training with {name} activation ---") try: results = train_and_evaluate(act_fn, name, hidden_size=128, num_layers=10, epochs=5) all_results[name] = results print(f" Final accuracy: {results['final_test_acc']:.4f}") print(f" Total time: {results['total_time']:.1f}s") except Exception as e: print(f" FAILED: {e}") all_results[name] = {'error': str(e)} # Summary print("\n" + "=" * 65) print("SUMMARY") print("=" * 65) print(f"{'Activation':>20s} {'Test Acc':>10s} {'Time':>8s} {'Final Grad Norm':>15s}") for name, res in all_results.items(): if 'error' in res: print(f"{name:>20s} FAILED: {res['error']}") else: acc = res['final_test_acc'] t = res['total_time'] last_epoch = res['epochs'][-1] gn = last_epoch.get('avg_grad_norm', 0) print(f"{name:>20s} {acc:10.4f} {t:8.1f}s {gn:15.2f}") # Save with open('/app/exp3_results.json', 'w') as f: # Convert any non-serializable types def default_handler(obj): if isinstance(obj, float) and (torch.isnan(torch.tensor(obj)) or torch.isinf(torch.tensor(obj))): return str(obj) return obj json.dump(all_results, f, default=default_handler, indent=2) print("\nSaved to /app/exp3_results.json")