5dimension's picture
Upload scripts/exp3_nn.py with huggingface_hub
dbaca87 verified
"""
EXPERIMENT 3: Neural Network Activation Comparison
Implement the HSK activation F(z) = sum_{n>=1} z^n / n^n (approximated by a
truncated sum) and compare it with ReLU, Swish, and GELU on MNIST classification.
Key insight: F(z) can be approximated by a truncated sum, but
this is expensive. We test both:
a) Truncated F(z) as activation (20 terms)
b) Simplified "HSK" from the document (grad-based approximation)
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import time
import json
# ============================================================
# Activation Functions
# ============================================================
class HSKActivation(nn.Module):
    """Truncated power series F(z) = sum_{n=1}^{N} z^n / n^n used as an activation.

    Args:
        N: Number of series terms to keep (default 20).

    The series is evaluated with Horner's scheme,
    ``F(z) = z*(c_1 + z*(c_2 + ... + z*c_N))`` with ``c_n = 1/n^n``,
    instead of accumulating raw powers ``z^n``. Raw powers overflow float32
    long before the series itself does (e.g. ``100^20 = 1e40`` while
    ``(100/20)^20`` is about ``9.5e13``), which turned the output into
    ``inf`` and the gradients into ``nan``.
    """

    def __init__(self, N=20):
        super().__init__()
        self.N = N
        # Precompute n^n for n = 1..N. The buffer name and contents are kept
        # as-is so the module's state_dict layout does not change.
        self.register_buffer('nn_terms', torch.tensor([float(n**n) for n in range(1, N+1)]))

    def forward(self, z):
        """Return the truncated series evaluated element-wise at ``z``."""
        # c_n = 1 / n^n; an n^n that saturated to inf becomes a zero coefficient.
        coeffs = self.nn_terms.reciprocal()
        result = torch.zeros_like(z)
        # Horner's scheme from the highest-order term down: every intermediate
        # value is a partial sum of the series, never a bare power of z.
        for n in range(self.N, 0, -1):
            result = (result + coeffs[n - 1]) * z
        return result
class HSKApproxActivation(nn.Module):
    """The document's "linear growth shift" approximation: output = z / e.

    This is only a rescaled identity, so it adds no nonlinearity and is
    expected to perform poorly as an activation. The "gradient" the document
    computes refers to F'/F, not to the activation itself.
    """

    def __init__(self):
        super().__init__()
        # 1/e rounded to ten decimal places (the document quotes it as 0.367).
        self.inv_e = 0.3678794412

    def forward(self, z):
        """Scale ``z`` element-wise by the stored 1/e constant."""
        return self.inv_e * z
class SwishActivation(nn.Module):
    """Swish activation: each element is multiplied by its own sigmoid."""

    def forward(self, z):
        """Return ``z * sigmoid(z)`` element-wise."""
        gate = torch.sigmoid(z)
        return z * gate
# ============================================================
# Network Architecture
# ============================================================
class DeepNet(nn.Module):
    """Fully connected classifier with one activation after every hidden layer.

    The stack is ``Linear(input_size, hidden_size)`` + activation, followed by
    ``num_layers - 1`` repetitions of ``Linear(hidden_size, hidden_size)`` +
    activation, and a final ``Linear(hidden_size, num_classes)`` head.

    Args:
        activation_fn: Activation module used after the first layer. Every
            later layer receives an independent deep copy of it, so the
            activation's own configuration (e.g. ``HSKActivation.N``) and any
            state it holds are preserved per layer rather than shared or
            replaced by hard-coded defaults.
        hidden_size: Width of every hidden layer.
        num_layers: Number of hidden Linear+activation pairs.
        input_size: Number of features per flattened sample (784 for MNIST).
        num_classes: Number of output logits.
    """

    def __init__(self, activation_fn, hidden_size=128, num_layers=10,
                 input_size=784, num_classes=10):
        import copy  # local import: only needed while building the layer stack

        super().__init__()
        self.activation_name = activation_fn.__class__.__name__
        self.input_size = input_size

        layers = [nn.Linear(input_size, hidden_size), activation_fn]
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_size, hidden_size))
            # A deep copy works for any activation type and keeps its settings,
            # unlike re-constructing known classes with fixed arguments.
            layers.append(copy.deepcopy(activation_fn))
        layers.append(nn.Linear(hidden_size, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_size)`` and return the class logits."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        return self.net(x.reshape(-1, self.input_size))
# ============================================================
# Training Loop
# ============================================================
def _evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            # Keep evaluation batches on the same device as the model.
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total if total else 0.0


def train_and_evaluate(activation_fn, name, hidden_size=128, num_layers=10,
                       epochs=5, lr=0.001):
    """Train a ``DeepNet`` using ``activation_fn`` on MNIST and collect metrics.

    Args:
        activation_fn: Activation module handed to ``DeepNet``.
        name: Label stored in the returned results.
        hidden_size: Hidden layer width passed to ``DeepNet``.
        num_layers: Number of hidden layers passed to ``DeepNet``.
        epochs: Number of passes over the training set.
        lr: Adam learning rate.

    Returns:
        dict with keys ``name``, ``epochs`` (one record per epoch holding
        ``epoch``, ``loss``, ``test_acc`` and ``avg_grad_norm``),
        ``final_test_acc``, ``total_time`` (seconds) and ``grad_norms``
        (pre-clipping gradient norm of every optimised batch).

        If the loss becomes NaN/Inf, training stops immediately: the last
        epoch record has ``loss`` = NaN and accuracy 0, and
        ``final_test_acc`` is 0.
    """
    device = torch.device('cpu')
    to_tensor = transforms.ToTensor()
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('/tmp/mnist', train=True, download=True,
                       transform=to_tensor),
        batch_size=128, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('/tmp/mnist', train=False, download=True,
                       transform=to_tensor),
        batch_size=256, shuffle=False)

    model = DeepNet(activation_fn, hidden_size, num_layers).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    results = {
        'name': name,
        'epochs': [],
        'final_test_acc': 0,
        'total_time': 0,
        'grad_norms': [],
    }
    test_acc = 0.0  # reported as-is when epochs == 0
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        batch_count = 0
        grad_norms_epoch = []

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)

            # Bail out before the backward pass: a non-finite loss cannot
            # produce a usable gradient.
            if not torch.isfinite(loss):
                print(f" WARNING: NaN/Inf loss at epoch {epoch+1}, batch {batch_idx}")
                results['epochs'].append({
                    'epoch': epoch+1,
                    'loss': float('nan'),
                    'acc': 0,       # legacy key, kept for existing consumers
                    'test_acc': 0,  # same key the regular epoch records use
                })
                # Keep the norms observed before the divergence.
                results['grad_norms'].extend(grad_norms_epoch)
                results['final_test_acc'] = 0
                results['total_time'] = time.time() - start_time
                return results

            loss.backward()
            # clip_grad_norm_ returns the total gradient norm measured before
            # clipping, which is the value tracked here.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            grad_norms_epoch.append(float(grad_norm))
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        # Test
        test_acc = _evaluate_accuracy(model, test_loader, device)
        avg_loss = total_loss / batch_count
        avg_grad = sum(grad_norms_epoch) / len(grad_norms_epoch)
        results['epochs'].append({
            'epoch': epoch+1,
            'loss': avg_loss,
            'test_acc': test_acc,
            'avg_grad_norm': avg_grad,
        })
        results['grad_norms'].extend(grad_norms_epoch)
        print(f" Epoch {epoch+1}: loss={avg_loss:.4f}, test_acc={test_acc:.4f}, grad_norm={avg_grad:.2f}")

    results['final_test_acc'] = test_acc
    results['total_time'] = time.time() - start_time
    return results
# ============================================================
# Run Experiments
# ============================================================
rule = "=" * 65
print(rule)
print("EXPERIMENT 3: Neural Network Activation Comparison")
print(rule)

# Each entry pairs an activation module with the label used in logs and results.
activations = [
    (nn.ReLU(), "ReLU"),
    (nn.GELU(), "GELU"),
    (SwishActivation(), "Swish"),
    (HSKActivation(N=20), "HSK-Truncated(F_z)"),
    (HSKApproxActivation(), "HSK-Approx(z/e)"),
]

# Maps label -> metrics dict, or {'error': message} when the run raised.
all_results = {}
for act_fn, name in activations:
    print(f"\n--- Training with {name} activation ---")
    try:
        results = train_and_evaluate(act_fn, name, hidden_size=128, num_layers=10, epochs=5)
    except Exception as e:
        # One failing activation must not abort the rest of the comparison.
        print(f" FAILED: {e}")
        all_results[name] = {'error': str(e)}
    else:
        all_results[name] = results
        print(f" Final accuracy: {results['final_test_acc']:.4f}")
        print(f" Total time: {results['total_time']:.1f}s")
# Summary table: one row per activation, failed runs show their error instead.
print("\n" + "=" * 65)
print("SUMMARY")
print("=" * 65)
print(f"{'Activation':>20s} {'Test Acc':>10s} {'Time':>8s} {'Final Grad Norm':>15s}")
for name, res in all_results.items():
    if 'error' in res:
        print(f"{name:>20s} FAILED: {res['error']}")
        continue
    final_acc = res['final_test_acc']
    elapsed = res['total_time']
    # Runs aborted on a NaN/Inf loss carry no average norm; show 0 for those.
    grad_norm = res['epochs'][-1].get('avg_grad_norm', 0)
    print(f"{name:>20s} {final_acc:10.4f} {elapsed:8.1f}s {grad_norm:15.2f}")
# Save
def _json_safe(value):
    """Return a copy of ``value`` in which non-finite floats are replaced by strings.

    ``json.dump`` only calls a ``default=`` hook for objects it cannot
    serialise; floats never reach it, so NaN/Infinity would be written as
    bare ``NaN``/``Infinity`` tokens, which are not valid JSON. Converting
    them up front ('nan', 'inf', '-inf') keeps the output file parseable by
    strict JSON readers.
    """
    import math  # local import: only this helper needs it

    if isinstance(value, float) and not math.isfinite(value):
        return str(value)
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


with open('/app/exp3_results.json', 'w') as f:
    # allow_nan=False makes any non-finite float that slipped through an error
    # instead of silently producing invalid JSON.
    json.dump(_json_safe(all_results), f, indent=2, allow_nan=False)
print("\nSaved to /app/exp3_results.json")