100enigma's picture
SimQuantum — AMD Developer Hackathon
da98415
Raw
History Blame Contribute Delete
15.7 kB
"""
qdot/perception/classifier.py
==============================
TinyCNN — 3-class primary classifier for 2D stability diagrams.
EnsembleCNN — 5-model ensemble with max-disagreement uncertainty.
Key change from hackathon:
The CNN is now the PRIMARY CLASSIFIER, not an embedding extractor.
It has a 3-class softmax head and is trained end-to-end on CIM data.
Physics features (FFT, diagonal) are a *validator* layer, not the
primary signal. See blueprint §7.1 for why this matters.
Architecture (TinyCNN):
Input → (1, 64, 64) — log-preprocessed normalised conductance
Conv1 → (16, 32, 32) — 3×3, stride 2, BN, ReLU
Conv2 → (32, 16, 16) — 3×3, stride 2, BN, ReLU
Conv3 → (64, 8, 8) — 3×3, stride 2, BN, ReLU
Conv4 → (64, 4, 4) — 3×3, stride 2, BN, ReLU
GAP → (64,) — global average pooling
FC → (32,) — linear + ReLU ← OOD detector attaches here
Head → (3,) — linear (logits)
Ensemble:
5 independent TinyCNN instances trained from different random seeds.
uncertainty = max over all model pairs (i,k) of max_j | p_j^(i) - p_j^(k) |
where p^(i) is the softmax output of model i.
This is the max pairwise L∞ disagreement between models, feeds directly
into the Risk Score formula (§4.1): disagreement > 0.3 → r += 0.35.
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# ---------------------------------------------------------------------------
# TinyCNN architecture
# ---------------------------------------------------------------------------
class TinyCNN(nn.Module):
    """
    Small convolutional network that maps a single-channel stability-diagram
    patch to logits over three classes.

    Layout: four stride-2 3×3 conv blocks (Conv → BatchNorm → ReLU), global
    average pooling, a dropout + 64→32 linear + ReLU "penultimate" stage, and
    a final 32→N_CLASSES linear head.

    Intended for fast CPU inference during real-device experiments
    (target: < 5 ms per 64×64 patch on a modern laptop CPU).
    """

    N_CLASSES = 3

    def __init__(self, dropout_p: float = 0.2) -> None:
        """
        Build the network.

        Args:
            dropout_p: dropout probability applied before the 64→32 linear layer.
        """
        super().__init__()

        # Channel progression 1 → 16 → 32 → 64 → 64; every block halves the
        # spatial size (64 → 32 → 16 → 8 → 4 for a 64×64 input).
        widths = (1, 16, 32, 64, 64)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend(
                [
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                ]
            )
        self.encoder = nn.Sequential(*stages)

        # Global average pooling collapses each feature map to one value.
        self.gap = nn.AdaptiveAvgPool2d(1)

        # Penultimate stage — its 32-d output is what the OOD detector reads.
        self.penultimate = nn.Sequential(
            nn.Dropout(p=dropout_p),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
        )

        # Classification head producing raw logits.
        self.head = nn.Linear(32, self.N_CLASSES)

    def _embed(self, x: torch.Tensor) -> torch.Tensor:
        """Run encoder → pooling → penultimate stage; output has 32 features."""
        pooled = self.gap(self.encoder(x))
        # The pooled map is 1×1 spatially, so indexing [0, 0] on the last two
        # axes drops them without touching the leading dimensions.
        return self.penultimate(pooled[..., 0, 0])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalised) class logits, shape (B, 3)."""
        return self.head(self._embed(x))

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the penultimate-layer activations used for OOD detection.

        Gradients are disabled. Output shape: (B, 32).
        """
        with torch.no_grad():
            features = self._embed(x)
        return features

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax class probabilities, shape (B, 3), without gradients."""
        with torch.no_grad():
            logits = self.forward(x)
            return torch.softmax(logits, dim=-1)
# ---------------------------------------------------------------------------
# Ensemble wrapper
# ---------------------------------------------------------------------------
class EnsembleCNN:
    """
    Ensemble of TinyCNN instances (N_MODELS = 5 by default).

    Provides:
        classify(x)         — argmax label of the ensemble-mean distribution,
                              its probability, and the disagreement metric
        predict_proba(x)    — mean softmax probabilities across the ensemble
        uncertainty(x)      — max-disagreement metric ∈ [0, 1]
        extract_features(x) — penultimate-layer features from model 0

    The uncertainty metric feeds directly into the Risk Score:
        state.ensemble_disagreement > 0.30  →  r += 0.35

    Usage:
        ensemble = EnsembleCNN.load(model_dir)
        label, confidence, disagreement = ensemble.classify(array)
    """

    N_MODELS = 5

    def __init__(
        self,
        models: Optional[List[TinyCNN]] = None,
        device: str = "cpu",
    ) -> None:
        """
        Args:
            models: pre-built ensemble members. When None (or empty), N_MODELS
                    freshly initialised TinyCNN instances are created.
            device: torch device string the models are moved to.
        """
        self.device = torch.device(device)
        self.models: List[TinyCNN] = models or [TinyCNN() for _ in range(self.N_MODELS)]
        # The wrapper is inference-oriented: every member lives on the target
        # device in eval mode (dropout off, BatchNorm using running stats).
        for m in self.models:
            m.to(self.device)
            m.eval()

    # -----------------------------------------------------------------------
    # Primary interface
    # -----------------------------------------------------------------------
    def classify(
        self, array: np.ndarray
    ) -> Tuple[int, float, float]:
        """
        Classify a single 2D stability diagram.

        Args:
            array: float32 array of shape (H, W) or (1, H, W) or (1, 1, H, W).
                   Will be preprocessed automatically.

        Returns:
            (label_idx, confidence, disagreement)
            label_idx:    int ∈ {0, 1, 2} (DOUBLE_DOT, SINGLE_DOT, MISC) —
                          argmax of the ensemble-mean probabilities
            confidence:   float ∈ [0, 1] — ensemble-mean probability of label_idx
            disagreement: float ∈ [0, 1] — max-disagreement metric
        """
        x = self._prepare(array)
        all_probs = self._all_probabilities(x)      # (N_MODELS, 3)
        mean_probs = all_probs.mean(axis=0)         # (3,)
        label_idx = int(np.argmax(mean_probs))
        confidence = float(mean_probs[label_idx])
        disagreement = self._disagreement(all_probs)
        return label_idx, confidence, disagreement

    def predict_proba(self, array: np.ndarray) -> np.ndarray:
        """Mean softmax probabilities across ensemble. Shape: (3,)."""
        x = self._prepare(array)
        return self._all_probabilities(x).mean(axis=0)

    def uncertainty(self, array: np.ndarray) -> float:
        """Max-disagreement metric ∈ [0, 1]."""
        x = self._prepare(array)
        return self._disagreement(self._all_probabilities(x))

    def extract_features(self, array: np.ndarray) -> np.ndarray:
        """
        Extract penultimate-layer features from model 0.

        Used by MahalanobisOOD — we use a single reference model for OOD
        to keep the feature space stable, then the ensemble provides UQ.

        Returns: float32 array of shape (32,)
        """
        x = self._prepare(array)
        return self.models[0].extract_features(x).cpu().numpy().squeeze()

    # -----------------------------------------------------------------------
    # Training
    # -----------------------------------------------------------------------
    @classmethod
    def train_from_data(
        cls,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: np.ndarray,
        y_val: np.ndarray,
        n_epochs: int = 30,
        batch_size: int = 128,
        lr: float = 3e-4,
        device: str = "cpu",
        model_dir: Optional[str] = None,
        verbose: bool = True,
    ) -> "EnsembleCNN":
        """
        Train all ensemble members from scratch.

        Member i uses seed ``i * 100 + 42`` for both its weight
        initialisation and its training run, so a repeated call with the
        same data starts every member from the same weights.

        Args:
            X_train: float32 (N, 1, 64, 64)
            y_train: int64 (N,)
            X_val: float32 (M, 1, 64, 64)
            y_val: int64 (M,)
            n_epochs, batch_size, lr: forwarded to the per-model training loop.
            device: torch device string used for training and inference.
            model_dir: if provided, saves each model checkpoint here as
                       ``model_{i}.pt`` right after that model finishes training.
            verbose: print progress information.

        Returns:
            Trained EnsembleCNN ready for inference.
        """
        seeds = [i * 100 + 42 for i in range(cls.N_MODELS)]

        # Seed *before* constructing each member: layer weights are drawn at
        # construction time, so seeding only inside the training loop would
        # leave the initialisation dependent on ambient RNG state.
        members: List[TinyCNN] = []
        for seed in seeds:
            torch.manual_seed(seed)
            members.append(TinyCNN())
        ensemble = cls(models=members, device=device)

        out_dir = Path(model_dir) if model_dir else None
        if out_dir is not None:
            out_dir.mkdir(parents=True, exist_ok=True)

        for i, (model, seed) in enumerate(zip(ensemble.models, seeds)):
            if verbose:
                print(f"\n=== Training model {i+1}/{cls.N_MODELS} ===")
            _train_single(
                model=model,
                X_train=X_train,
                y_train=y_train,
                X_val=X_val,
                y_val=y_val,
                n_epochs=n_epochs,
                batch_size=batch_size,
                lr=lr,
                device=torch.device(device),
                seed=seed,
                verbose=verbose,
            )
            # Checkpoint incrementally so finished members survive a later failure.
            if out_dir is not None:
                torch.save(model.state_dict(), out_dir / f"model_{i}.pt")
        return ensemble

    # -----------------------------------------------------------------------
    # Persistence
    # -----------------------------------------------------------------------
    def save(self, model_dir: str) -> None:
        """Save every member's state_dict as ``model_{i}.pt`` inside model_dir."""
        out_dir = Path(model_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        for i, model in enumerate(self.models):
            torch.save(model.state_dict(), out_dir / f"model_{i}.pt")

    @classmethod
    def load(cls, model_dir: str, device: str = "cpu") -> "EnsembleCNN":
        """
        Load ``model_0.pt`` … ``model_{N_MODELS-1}.pt`` from a directory.

        Args:
            model_dir: directory previously written by save() / train_from_data().
            device: torch device string to map the weights onto.

        Returns:
            EnsembleCNN with all members in eval mode.
        """
        src_dir = Path(model_dir)
        models = []
        for i in range(cls.N_MODELS):
            path = src_dir / f"model_{i}.pt"
            model = TinyCNN()
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from trusted sources (or pass weights_only=True where the
            # installed torch version supports it).
            model.load_state_dict(torch.load(path, map_location=device))
            model.eval()
            models.append(model)
        return cls(models=models, device=device)

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------
    def _prepare(self, array: np.ndarray) -> torch.Tensor:
        """
        Prepare input array for inference.

        Strips leading singleton dims, applies log_preprocess, resizes to
        64×64 and returns a (1, 1, 64, 64) float32 tensor on self.device.

        Raises:
            ValueError: if the input is not 2D after stripping leading
                        singleton dimensions, or has a zero-length axis.
        """
        # Local import, as in the original layout of this module.
        from qdot.perception.features import log_preprocess

        arr = np.asarray(array, dtype=np.float32)

        # Strip batch/channel dims if present
        while arr.ndim > 2 and arr.shape[0] == 1:
            arr = arr.squeeze(0)
        if arr.ndim != 2:
            raise ValueError(f"Expected 2D array after squeezing, got shape {arr.shape}")

        arr = log_preprocess(arr)
        arr = self._resize_to_input(np.asarray(arr))

        # Force contiguous float32 so the tensor dtype matches the model
        # weights whatever dtype the preprocessing step produced.
        arr = np.ascontiguousarray(arr, dtype=np.float32)

        tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
        return tensor.to(self.device)

    @staticmethod
    def _resize_to_input(arr: np.ndarray, size: int = 64) -> np.ndarray:
        """
        Resize a 2D array to (size, size) with linear interpolation.

        Each axis gets its own zoom factor, so non-square inputs are handled.
        Arrays that already have the target shape are returned unchanged.
        Resized output is float32 and clipped to [0, 1].

        Raises:
            ValueError: if either axis of ``arr`` has length zero.
        """
        if arr.shape == (size, size):
            return arr
        if 0 in arr.shape:
            raise ValueError(f"Cannot resize empty array of shape {arr.shape}")

        from scipy.ndimage import zoom

        factors = (size / arr.shape[0], size / arr.shape[1])
        resized = zoom(arr.astype(np.float64), factors, order=1).astype(np.float32)
        return np.clip(resized, 0.0, 1.0)

    def _all_probabilities(self, x: torch.Tensor) -> np.ndarray:
        """
        Returns softmax probabilities from all models.
        Shape: (N_MODELS, N_CLASSES) for a single-sample input.
        """
        with torch.no_grad():
            per_model = [
                model.predict_proba(x).cpu().numpy().squeeze()
                for model in self.models
            ]
        return np.stack(per_model, axis=0)

    @staticmethod
    def _disagreement(all_probs: np.ndarray) -> float:
        """
        Max-disagreement metric across ensemble.

        Defined as the maximum, over all pairs of models (i, j), of the L∞
        distance between their softmax vectors. That equals the largest
        per-class spread (max over models minus min over models), which is
        what is computed here.

        This is more interpretable than entropy because it directly
        measures the worst-case disagreement between any two classifiers.

        Returns 0.0 when fewer than two models are present.
        """
        probs = np.asarray(all_probs)
        if probs.shape[0] < 2:
            return 0.0
        spread = probs.max(axis=0) - probs.min(axis=0)
        return float(np.max(spread))
# ---------------------------------------------------------------------------
# Single model training loop
# ---------------------------------------------------------------------------
def _train_single(
    model: TinyCNN,
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    n_epochs: int,
    batch_size: int,
    lr: float,
    device: torch.device,
    seed: int,
    verbose: bool,
) -> None:
    """
    Train a single model in place with cosine-annealing LR and
    class-balanced sampling.

    After every epoch the model is evaluated on the validation set; the
    weights of the epoch with the highest validation accuracy are restored
    at the end. If validation accuracy never rises above 0 (including the
    case of an empty validation set) the final-epoch weights are kept.

    Args:
        model: network to train; modified in place and left in eval mode.
        X_train, y_train: training inputs and integer class labels.
        X_val, y_val: validation inputs and labels (may be empty).
        n_epochs: number of passes; also the cosine schedule length.
        batch_size: mini-batch size for both loaders.
        lr: initial AdamW learning rate (annealed down to 1% of this value).
        device: device on which training runs.
        seed: seed applied to the torch and numpy global RNGs.
        verbose: print metrics every 5 epochs and the best accuracy at the end.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    model.to(device).train()

    # Class-balanced sampler: each sample is drawn with probability inversely
    # proportional to its class frequency (the epsilon avoids dividing by
    # zero for label values that never occur).
    counts = np.bincount(y_train)
    class_weights = 1.0 / (counts + 1e-8)
    sample_weights = class_weights[y_train]
    sample_weights = sample_weights / sample_weights.sum()
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.from_numpy(sample_weights.astype(np.float64)),
        num_samples=len(y_train),
        replacement=True,
    )

    train_ds = TensorDataset(
        torch.from_numpy(X_train).float(),
        torch.from_numpy(y_train).long(),
    )
    val_ds = TensorDataset(
        torch.from_numpy(X_val).float(),
        torch.from_numpy(y_val).long(),
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    optimiser = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimiser, T_max=n_epochs, eta_min=lr * 0.01
    )
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.0
    best_state = None

    for epoch in range(n_epochs):
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimiser.zero_grad()
            logits = model(X_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimiser.step()
            # Weight the batch-mean loss by batch size so the epoch figure is
            # a true per-sample average.
            train_loss += loss.item() * len(y_batch)
            train_correct += (logits.argmax(dim=1) == y_batch).sum().item()
            train_total += len(y_batch)
        # One scheduler step per epoch (T_max is expressed in epochs).
        scheduler.step()

        # Validation
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                logits = model(X_batch)
                val_correct += (logits.argmax(dim=1) == y_batch).sum().item()
                val_total += len(y_batch)

        # An empty validation set yields accuracy 0.0 instead of raising
        # ZeroDivisionError; no checkpoint is then selected.
        val_acc = val_correct / val_total if val_total else 0.0
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Clone: state_dict() tensors alias the live parameters.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}

        if verbose and (epoch + 1) % 5 == 0:
            seen = max(train_total, 1)  # guard against an epoch with no batches
            train_acc = train_correct / seen
            print(
                f" Epoch {epoch+1:3d}/{n_epochs} | "
                f"train_loss={train_loss/seen:.4f} | "
                f"train_acc={train_acc:.3f} | "
                f"val_acc={val_acc:.3f}"
            )

    # Restore best weights
    if best_state is not None:
        model.load_state_dict(best_state)
    if verbose:
        print(f" Best val_acc: {best_val_acc:.4f}")
    model.eval()