# sci/run_sci_signal_v2.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
# --- CONFIGURATION v3 ---
BATCH_SIZE = 64
TRAIN_SIZE = 4000 # Increased for stability
TEST_SIZE = 1000
EPOCHS = 5 # Increased for better convergence
SP_TARGET = 0.85 # Realistically calibrated target (was 0.95)
MAX_STEPS = 15 # Give controller room to work
TEMPERATURE = 0.5 # Temperature scaling (sharpening)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
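# Worked example (illustrative sketch, not used by the pipeline): dividing
# logits by T = 0.5 before the softmax sharpens the distribution. For logits
# [2.0, 1.0], softmax gives ~[0.73, 0.27]; softmax([4.0, 2.0]) gives
# ~[0.88, 0.12] -- same argmax, higher SP.
def _demo_temperature_sharpening():
    logits = torch.tensor([[2.0, 1.0]])
    print(F.softmax(logits, dim=1))                # ~[0.73, 0.27]
    print(F.softmax(logits / TEMPERATURE, dim=1))  # ~[0.88, 0.12]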
# --- 1. MODEL DEFINITION ---
class SimpleCNN(nn.Module):
def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)  # 1 -> 32 channels, 3x3 kernels, stride 1
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after pooling = 9216
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
return x # Returns logits
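# Shape trace (sanity sketch, assuming 1x28x28 MNIST input):
# 1x28x28 -> conv1 (3x3) -> 32x26x26 -> conv2 (3x3) -> 64x24x24
# -> max_pool2d(2) -> 64x12x12 -> flatten -> 9216, matching fc1's input size.
def _check_output_shape():
    with torch.no_grad():
        assert SimpleCNN()(torch.zeros(1, 1, 28, 28)).shape == (1, 10)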
# --- 2. UTILS: SP CALCULATION ---
def compute_sp(probs):
"""SP = 1 - (Entropy / MaxEntropy)"""
probs = torch.clamp(probs, min=1e-9)
entropy = -torch.sum(probs * torch.log(probs), dim=1)
max_entropy = np.log(10)
sp = 1.0 - (entropy / max_entropy)
return sp
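# Sanity values (illustrative, not called in the pipeline): a uniform
# distribution over 10 classes has entropy log(10), so SP = 0; a one-hot
# distribution has entropy ~0 after clamping, so SP ~ 1.
def _demo_sp_bounds():
    print(compute_sp(torch.full((1, 10), 0.1)))                              # ~tensor([0.])
    print(compute_sp(F.one_hot(torch.tensor([3]), num_classes=10).float()))  # ~tensor([1.])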
# --- 3. TRAINING ---
def train_model():
print(f"Loading MNIST (Train: {TRAIN_SIZE}, Test: {TEST_SIZE})...")
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
full_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(Subset(full_train, range(TRAIN_SIZE)), batch_size=BATCH_SIZE, shuffle=True)
model = SimpleCNN().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
print(f"Training for {EPOCHS} epochs...")
for epoch in range(EPOCHS):
for data, target in train_loader:
data, target = data.to(DEVICE), target.to(DEVICE)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
return model
# --- 4. EVALUATION ---
def evaluate(model):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    test_set = Subset(datasets.MNIST('./data', train=False, transform=transform), range(TEST_SIZE))
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)  # batch_size=1: per-sample SP control
base_acc, sci_acc = 0, 0
base_sp_list, sci_sp_list = [], []
sci_steps_list = []
    model.train()  # Keep dropout active so each forward pass is stochastic (MC-dropout style)
print(f"Running Inference (Target SP={SP_TARGET}, Temp={TEMPERATURE})...")
with torch.no_grad():
        for data, target in test_loader:
data, target = data.to(DEVICE), target.to(DEVICE)
# --- BASELINE (1 Pass) ---
logits = model(data)
# Apply Temperature Scaling to Baseline too for fair comparison
probs = F.softmax(logits / TEMPERATURE, dim=1)
sp = compute_sp(probs)
pred = probs.argmax(dim=1)
base_acc += pred.eq(target).sum().item()
base_sp_list.append(sp.item())
# --- SCI (Logit Averaging Controller) ---
accum_logits = logits.clone() # Start with first pass logits
steps = 1
current_sp = sp.item()
# Loop: while quality is low, compute more
while current_sp < SP_TARGET and steps < MAX_STEPS:
new_logits = model(data)
accum_logits += new_logits
steps += 1
# KEY CHANGE: Average Logits -> Softmax (Not Average Probs)
mean_logits = accum_logits / steps
current_probs = F.softmax(mean_logits / TEMPERATURE, dim=1)
current_sp = compute_sp(current_probs).item()
# Final Decision
final_mean_logits = accum_logits / steps
sci_probs = F.softmax(final_mean_logits / TEMPERATURE, dim=1)
sci_pred = sci_probs.argmax(dim=1)
sci_acc += sci_pred.eq(target).sum().item()
sci_sp_list.append(current_sp)
sci_steps_list.append(steps)
# --- 5. STATS ---
base_acc_pct = 100.0 * base_acc / TEST_SIZE
sci_acc_pct = 100.0 * sci_acc / TEST_SIZE
mean_base_sp = np.mean(base_sp_list)
mean_sci_sp = np.mean(sci_sp_list)
base_errors = [abs(SP_TARGET - sp) for sp in base_sp_list]
sci_errors = [abs(SP_TARGET - sp) for sp in sci_sp_list]
mean_base_error = np.mean(base_errors)
mean_sci_error = np.mean(sci_errors)
reduction = (mean_base_error - mean_sci_error) / mean_base_error * 100.0
avg_steps = np.mean(sci_steps_list)
print("\n" + "="*65)
print(f"RESULTS v3: SCI (Logit Avg + Temp Scaling) vs Baseline")
print("="*65)
print(f"{'Metric':<25} | {'Baseline':<10} | {'SCI (Adaptive)':<15}")
print("-" * 65)
print(f"{'Accuracy':<25} | {base_acc_pct:.2f}% | {sci_acc_pct:.2f}%")
print(f"{'Mean Surgical Precision':<25} | {mean_base_sp:.4f} | {mean_sci_sp:.4f}")
print(f"{'Mean Steps':<25} | {1.0:.2f} | {avg_steps:.2f}")
print("-" * 65)
print(f"{'Interpretive Error (dSP)':<25} | {mean_base_error:.4f} | {mean_sci_error:.4f}")
print(f"{'Error Reduction':<25} | - | {reduction:.2f}%")
print("="*65)
if __name__ == "__main__":
trained_model = train_model()
evaluate(trained_model)