|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import torch.optim as optim
|
|
|
from torchvision import datasets, transforms
|
|
|
from torch.utils.data import DataLoader, Subset
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
# --- Experiment configuration ---
BATCH_SIZE = 64      # mini-batch size for training
TRAIN_SIZE = 4000    # number of MNIST training samples used (subset of the full 60k)
TEST_SIZE = 1000     # number of MNIST test samples evaluated
EPOCHS = 5           # training epochs over the subset
SP_TARGET = 0.85     # target "Surgical Precision" (1 - normalized entropy) for adaptive inference
MAX_STEPS = 15       # cap on forward passes per sample in the adaptive (SCI) loop
TEMPERATURE = 0.5    # softmax temperature divisor (<1 sharpens the distribution)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
class SimpleCNN(nn.Module):
    """Small two-conv CNN for 10-class 28x28 grayscale classification.

    Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU
    -> maxpool(2) -> dropout(0.25) -> flatten -> fc(9216->128) -> ReLU
    -> dropout(0.5) -> fc(128->10). Returns raw logits (no softmax).
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        # Submodules are registered in the same order as before so that a
        # seeded initialization produces identical parameters.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) logits."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return self.fc2(self.dropout2(hidden))
|
|
|
|
|
|
|
|
|
def compute_sp(probs):
    """Compute Surgical Precision: SP = 1 - (Entropy / MaxEntropy).

    SP is 1.0 for a one-hot (fully confident) row and 0.0 for a uniform row.

    Args:
        probs: Tensor of shape (batch, num_classes); each row is a
            probability distribution (rows should sum to 1).

    Returns:
        Tensor of shape (batch,) with the per-row SP score in [0, 1].
    """
    # Clamp to avoid log(0) = -inf for zero-probability classes.
    probs = torch.clamp(probs, min=1e-9)
    entropy = -torch.sum(probs * torch.log(probs), dim=1)
    # Generalized from the hard-coded np.log(10): normalize by the maximum
    # possible entropy for the actual number of classes in the input.
    # Identical to the old behavior for the 10-class case used in this file.
    max_entropy = np.log(probs.size(1))
    sp = 1.0 - (entropy / max_entropy)
    return sp
|
|
|
|
|
|
|
|
|
def train_model():
    """Train a SimpleCNN on a TRAIN_SIZE-sample subset of MNIST.

    Downloads MNIST into ./data if needed, runs EPOCHS epochs of Adam
    optimization with cross-entropy loss, and returns the trained model
    (left in train() mode).
    """
    print(f"Loading MNIST (Train: {TRAIN_SIZE}, Test: {TEST_SIZE})...")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    full_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(
        Subset(full_train, range(TRAIN_SIZE)),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )

    model = SimpleCNN().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    print(f"Training for {EPOCHS} epochs...")
    for _epoch in range(EPOCHS):
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            batch_loss = F.cross_entropy(model(images), labels)
            batch_loss.backward()
            optimizer.step()
    return model
|
|
|
|
|
|
|
|
|
def evaluate(model):
    """Evaluate `model` on a TEST_SIZE-sample MNIST test subset.

    Compares two inference schemes on the same samples:
      * Baseline: one forward pass per sample; temperature-scaled softmax.
      * SCI (adaptive): keep running extra forward passes, averaging the
        accumulated logits, until the Surgical Precision (SP) of the
        temperature-scaled softmax reaches SP_TARGET or MAX_STEPS passes.

    Prints an accuracy / SP / step-count comparison table. Returns None.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    # batch_size=1: the adaptive loop below chooses a step count per sample.
    test_loader = DataLoader(Subset(datasets.MNIST('./data', train=False, transform=transform), range(TEST_SIZE)), batch_size=1, shuffle=False)

    base_acc, sci_acc = 0, 0            # correct-prediction counters
    base_sp_list, sci_sp_list = [], []  # per-sample SP scores
    sci_steps_list = []                 # forward passes used per sample (SCI)

    # NOTE(review): train() (not eval()) keeps dropout active, so repeated
    # forward passes on the same input differ — the SCI while-loop below
    # depends on that stochasticity (MC-dropout-style sampling). It also
    # makes the baseline prediction stochastic; confirm this is intended.
    model.train()

    print(f"Running Inference (Target SP={SP_TARGET}, Temp={TEMPERATURE})...")

    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data, target = data.to(DEVICE), target.to(DEVICE)

            # --- Baseline: single (dropout-perturbed) forward pass ---
            logits = model(data)

            probs = F.softmax(logits / TEMPERATURE, dim=1)
            sp = compute_sp(probs)
            pred = probs.argmax(dim=1)

            base_acc += pred.eq(target).sum().item()
            base_sp_list.append(sp.item())  # batch size is 1, so .item() is safe

            # --- SCI: accumulate logits until SP target or step cap ---
            accum_logits = logits.clone()  # clone: don't alias the baseline logits
            steps = 1
            current_sp = sp.item()

            while current_sp < SP_TARGET and steps < MAX_STEPS:
                new_logits = model(data)  # fresh dropout mask each pass
                accum_logits += new_logits
                steps += 1

                # Recompute SP from the running mean of the logits.
                mean_logits = accum_logits / steps
                current_probs = F.softmax(mean_logits / TEMPERATURE, dim=1)
                current_sp = compute_sp(current_probs).item()

            # Final SCI prediction from the averaged logits.
            final_mean_logits = accum_logits / steps
            sci_probs = F.softmax(final_mean_logits / TEMPERATURE, dim=1)
            sci_pred = sci_probs.argmax(dim=1)

            sci_acc += sci_pred.eq(target).sum().item()
            sci_sp_list.append(current_sp)
            sci_steps_list.append(steps)

    # --- Aggregate metrics ---
    base_acc_pct = 100.0 * base_acc / TEST_SIZE
    sci_acc_pct = 100.0 * sci_acc / TEST_SIZE
    mean_base_sp = np.mean(base_sp_list)
    mean_sci_sp = np.mean(sci_sp_list)

    # "Interpretive error": per-sample distance of SP from the target.
    base_errors = [abs(SP_TARGET - sp) for sp in base_sp_list]
    sci_errors = [abs(SP_TARGET - sp) for sp in sci_sp_list]

    mean_base_error = np.mean(base_errors)
    mean_sci_error = np.mean(sci_errors)
    # NOTE(review): divides by mean_base_error — would raise on a zero mean
    # baseline error (not expected in practice, but unguarded).
    reduction = (mean_base_error - mean_sci_error) / mean_base_error * 100.0
    avg_steps = np.mean(sci_steps_list)

    # --- Report ---
    print("\n" + "="*65)
    print(f"RESULTS v3: SCI (Logit Avg + Temp Scaling) vs Baseline")
    print("="*65)
    print(f"{'Metric':<25} | {'Baseline':<10} | {'SCI (Adaptive)':<15}")
    print("-" * 65)
    print(f"{'Accuracy':<25} | {base_acc_pct:.2f}% | {sci_acc_pct:.2f}%")
    print(f"{'Mean Surgical Precision':<25} | {mean_base_sp:.4f} | {mean_sci_sp:.4f}")
    print(f"{'Mean Steps':<25} | {1.0:.2f} | {avg_steps:.2f}")
    print("-" * 65)
    print(f"{'Interpretive Error (dSP)':<25} | {mean_base_error:.4f} | {mean_sci_error:.4f}")
    print(f"{'Error Reduction':<25} | - | {reduction:.2f}%")
    print("="*65)
|
|
|
|
|
|
if __name__ == "__main__":
    # Train on the MNIST subset, then run the baseline-vs-SCI comparison.
    trained_model = train_model()
    evaluate(trained_model)