# SONAR / gate_model.py
# (scraped repository-page header removed; kept as a comment so the module parses)
"""
gate_model.py
=============
Gating Network that takes probability matrix inputs and predicts
whether a patch contains an archaeological site.
Input: (4 channels, 64, 64)
- Channel 0: Autoencoder probability
- Channel 1: Isolation Forest probability
- Channel 2: K-Means probability
- Channel 3: Archaeological Site Similarity
Output: Binary classification [0, 1]
- 1 = Archaeological site
- 0 = Not archaeological site
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from pathlib import Path
from typing import Tuple, List, Dict, Optional
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve
)
import json
# ==============================================================================
# GATE ARCHITECTURE
# ==============================================================================
class GateNetwork(nn.Module):
    """
    Gating network for archaeological site detection.

    Consumes a stack of per-method probability maps and emits one
    site / no-site score per patch.

    Architecture:
        - Input: (batch, 4, 64, 64) probability matrices
        - Four conv stages (32 -> 64 -> 128 -> 256 channels); the first
          three halve the spatial size, the last ends in global average
          pooling
        - MLP head with dropout producing a single logit
        - Output: (batch, 1) logit and its sigmoid probability
    """

    def __init__(self, in_channels: int = 4, dropout: float = 0.3):
        super().__init__()
        # Convolutional encoder: spatial size 64 -> 32 -> 16 -> 8 -> 1
        self.conv1 = self._stage(in_channels, 32, nn.MaxPool2d(2))
        self.conv2 = self._stage(32, 64, nn.MaxPool2d(2))
        self.conv3 = self._stage(64, 128, nn.MaxPool2d(2))
        self.conv4 = self._stage(128, 256, nn.AdaptiveAvgPool2d(1))
        # Classifier head: 256 -> 128 -> 64 -> 1 with dropout between layers
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, 1)
        )

    @staticmethod
    def _stage(c_in: int, c_out: int, pool: nn.Module) -> nn.Sequential:
        """Two conv+BN+ReLU layers followed by the given pooling module."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
            pool
        )

    def forward(self, x):
        """
        Args:
            x: (batch, 4, 64, 64) probability matrices
        Returns:
            logits: (batch, 1) raw logits
            probs: (batch, 1) sigmoid probabilities
        """
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        logits = self.classifier(x)
        return logits, torch.sigmoid(logits)

    def predict(self, x):
        """Threshold the sigmoid output at 0.5 and return 0/1 labels."""
        probs = self.forward(x)[1]
        return (probs > 0.5).long()

    def count_parameters(self):
        """Total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ==============================================================================
# DATASET
# ==============================================================================
class ArchaeologyGateDataset(Dataset):
    """
    Dataset for training the Gate Network.

    Loads unified probability matrices (one .npz per AOI) and expands them
    into individual (4, 64, 64) patches, each labelled with its AOI's label.
    """
    def __init__(
        self,
        aoi_list: List[str],
        labels: List[int],
        prob_matrix_dir: Path,
        augment: bool = False
    ):
        """
        Args:
            aoi_list: List of AOI names
            labels: List of labels (1 = archaeology, 0 = not)
            prob_matrix_dir: Directory containing unified probability matrices
            augment: Whether to apply random flips/rotations in __getitem__

        Raises:
            ValueError: if no patches could be loaded (all files missing/empty).
        """
        self.aoi_list = aoi_list
        self.labels = labels
        self.prob_matrix_dir = prob_matrix_dir
        self.augment = augment
        assert len(aoi_list) == len(labels), "AOI list and labels must match"
        # Load all data into memory
        self.data = []
        self.metadata_list = []
        print(f"Loading {len(aoi_list)} AOIs...")
        for aoi_name, label in tqdm(zip(aoi_list, labels), total=len(aoi_list)):
            matrix_path = prob_matrix_dir / f"{aoi_name}_unified_prob_matrix.npz"
            if not matrix_path.exists():
                print(f"⚠️ Warning: Missing {matrix_path}")
                continue
            with np.load(matrix_path, allow_pickle=True) as data:
                unified_matrix = data['unified_matrix']  # (num_patches, 64, 64, 4)
                metadata = data['metadata']
                # Fix: metadata was read but never stored, leaving
                # self.metadata_list permanently empty.
                self.metadata_list.append(metadata)
                # Add all patches from this AOI
                for patch_matrix in unified_matrix:
                    # Transpose to channels-first (4, 64, 64) for PyTorch
                    patch_matrix = np.transpose(patch_matrix, (2, 0, 1))
                    self.data.append((patch_matrix, label))
        # Fail fast with a clear message instead of a ZeroDivisionError
        # in the class-distribution prints below.
        if not self.data:
            raise ValueError(
                f"No patches loaded from {prob_matrix_dir} for {len(aoi_list)} AOIs"
            )
        print(f"✓ Loaded {len(self.data)} patches")
        # Class distribution
        labels_array = np.array([label for _, label in self.data])
        n_positive = np.sum(labels_array == 1)
        n_negative = np.sum(labels_array == 0)
        print(f" Positive (archaeology): {n_positive} ({n_positive/len(self.data)*100:.1f}%)")
        print(f" Negative (not): {n_negative} ({n_negative/len(self.data)*100:.1f}%)")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (patch, label) as float tensors, optionally augmented."""
        patch_matrix, label = self.data[idx]
        # Convert to torch tensors
        patch_matrix = torch.from_numpy(patch_matrix).float()
        label = torch.tensor(label, dtype=torch.float32)
        # Data augmentation (optional): random flips and 90-degree rotations,
        # all label-preserving for top-down imagery.
        if self.augment:
            # Random flip
            if torch.rand(1) > 0.5:
                patch_matrix = torch.flip(patch_matrix, dims=[1])
            if torch.rand(1) > 0.5:
                patch_matrix = torch.flip(patch_matrix, dims=[2])
            # Random rotation (90, 180, 270 degrees)
            k = torch.randint(0, 4, (1,)).item()
            if k > 0:
                patch_matrix = torch.rot90(patch_matrix, k=k, dims=[1, 2])
        return patch_matrix, label
# ==============================================================================
# TRAINING
# ==============================================================================
class GateTrainer:
    """Trainer for Gate Network.

    Handles the optimization loop, validation metrics, LR scheduling on
    validation F1, and checkpointing of the best-performing weights.
    """
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        pos_weight: Optional[float] = None
    ):
        """
        Args:
            model: network whose forward() returns (logits, probs)
            train_loader: yields (matrices, labels) training batches
            val_loader: yields (matrices, labels) validation batches
            device: device to train on
            lr: initial AdamW learning rate
            weight_decay: AdamW weight decay
            pos_weight: optional positive-class weight for BCEWithLogitsLoss
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )
        # NOTE: `verbose=True` dropped — the argument is deprecated and has
        # been removed in recent PyTorch releases; it only printed messages.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',  # maximizing validation F1
            factor=0.5,
            patience=5
        )
        # Loss function with optional class weighting
        if pos_weight is not None:
            pos_weight_tensor = torch.tensor([pos_weight]).to(device)
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
        else:
            self.criterion = nn.BCEWithLogitsLoss()
        # Tracking
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_f1': [], 'val_f1': [],
            'lr': []
        }
        self.best_val_f1 = 0.0
        self.best_model_state = None

    def train_epoch(self):
        """Train for one epoch.

        Returns:
            (epoch_loss, epoch_acc, epoch_f1) averaged/computed over the
            whole training set.
        """
        self.model.train()
        running_loss = 0.0
        all_preds = []
        all_labels = []
        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch_matrices, batch_labels in pbar:
            batch_matrices = batch_matrices.to(self.device)
            # BCEWithLogitsLoss expects (batch, 1) to match the logits
            batch_labels = batch_labels.to(self.device).unsqueeze(1)
            # Forward
            logits, probs = self.model(batch_matrices)
            loss = self.criterion(logits, batch_labels)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Track (weight loss by batch size for a correct dataset average)
            running_loss += loss.item() * batch_matrices.size(0)
            all_preds.extend((probs > 0.5).cpu().numpy().flatten())
            all_labels.extend(batch_labels.cpu().numpy().flatten())
            pbar.set_postfix({'loss': loss.item()})
        epoch_loss = running_loss / len(self.train_loader.dataset)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, zero_division=0)
        return epoch_loss, epoch_acc, epoch_f1

    def validate(self):
        """Evaluate on the validation set.

        Returns:
            Dict with loss, acc, f1, precision, recall, auc plus the raw
            preds / probs / labels lists.
        """
        self.model.eval()
        running_loss = 0.0
        all_preds = []
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for batch_matrices, batch_labels in tqdm(self.val_loader, desc="Validating", leave=False):
                batch_matrices = batch_matrices.to(self.device)
                batch_labels = batch_labels.to(self.device).unsqueeze(1)
                logits, probs = self.model(batch_matrices)
                loss = self.criterion(logits, batch_labels)
                running_loss += loss.item() * batch_matrices.size(0)
                all_preds.extend((probs > 0.5).cpu().numpy().flatten())
                all_probs.extend(probs.cpu().numpy().flatten())
                all_labels.extend(batch_labels.cpu().numpy().flatten())
        epoch_loss = running_loss / len(self.val_loader.dataset)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, zero_division=0)
        epoch_prec = precision_score(all_labels, all_preds, zero_division=0)
        epoch_rec = recall_score(all_labels, all_preds, zero_division=0)
        try:
            epoch_auc = roc_auc_score(all_labels, all_probs)
        except ValueError:
            # roc_auc_score raises when only one class is present
            epoch_auc = 0.0
        return {
            'loss': epoch_loss,
            'acc': epoch_acc,
            'f1': epoch_f1,
            'precision': epoch_prec,
            'recall': epoch_rec,
            'auc': epoch_auc,
            'preds': all_preds,
            'probs': all_probs,
            'labels': all_labels
        }

    def train(self, num_epochs: int, save_dir: Path):
        """Train for multiple epochs, checkpointing the best-F1 model.

        Args:
            num_epochs: number of epochs to run
            save_dir: directory for the best-model checkpoint

        Returns:
            The history dict of per-epoch metrics.
        """
        save_dir.mkdir(exist_ok=True, parents=True)
        print(f"\n{'='*80}")
        print(f"TRAINING GATE NETWORK")
        print(f"{'='*80}")
        print(f"Model parameters: {self.model.count_parameters():,}")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset)}")
        print(f"Epochs: {num_epochs}")
        print(f"{'='*80}\n")
        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            print("-" * 40)
            # Train
            train_loss, train_acc, train_f1 = self.train_epoch()
            # Validate
            val_metrics = self.validate()
            # Update scheduler
            self.scheduler.step(val_metrics['f1'])
            # Track history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_metrics['acc'])
            self.history['train_f1'].append(train_f1)
            self.history['val_f1'].append(val_metrics['f1'])
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])
            # Print metrics
            print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
            print(f"Val Loss: {val_metrics['loss']:.4f} | Acc: {val_metrics['acc']:.4f} | F1: {val_metrics['f1']:.4f}")
            print(f" Prec: {val_metrics['precision']:.4f} | Rec: {val_metrics['recall']:.4f} | AUC: {val_metrics['auc']:.4f}")
            # Save best model
            if val_metrics['f1'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['f1']
                # BUGFIX: state_dict().copy() was a shallow dict copy whose
                # tensors alias the live parameters, so the "best" weights
                # were silently overwritten by later epochs. Clone each
                # tensor to take a real snapshot.
                self.best_model_state = {
                    k: v.detach().clone()
                    for k, v in self.model.state_dict().items()
                }
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.best_model_state,
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_f1': self.best_val_f1,
                    'history': self.history
                }, save_dir / 'best_gate_model.pth')
                print(f"✓ Saved best model (F1: {self.best_val_f1:.4f})")
        # Load best model (guard: no epoch may have beaten the initial 0.0 F1)
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
        print(f"\n{'='*80}")
        print(f"TRAINING COMPLETE")
        print(f"Best Validation F1: {self.best_val_f1:.4f}")
        print(f"{'='*80}\n")
        return self.history
# ==============================================================================
# EVALUATION & VISUALIZATION
# ==============================================================================
def plot_training_history(history: Dict, save_path: Optional[Path] = None):
    """Plot the training curves: loss, accuracy, F1, and learning rate.

    Args:
        history: dict with 'train_*'/'val_*' metric lists and 'lr'
        save_path: optional file to save the figure to
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Three train/val comparison panels share the same layout.
    panels = [
        (axes[0, 0], 'loss', 'Loss', 'Training & Validation Loss'),
        (axes[0, 1], 'acc', 'Accuracy', 'Training & Validation Accuracy'),
        (axes[1, 0], 'f1', 'F1 Score', 'Training & Validation F1'),
    ]
    for ax, key, ylabel, title in panels:
        ax.plot(history[f'train_{key}'], label='Train')
        ax.plot(history[f'val_{key}'], label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    # Learning rate panel (log scale to make plateau drops visible)
    lr_ax = axes[1, 1]
    lr_ax.plot(history['lr'])
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.set_yscale('log')
    lr_ax.grid(True)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved training history: {save_path}")
    plt.show()
def plot_confusion_matrix(labels, preds, save_path: Optional[Path] = None):
    """Render the 2x2 confusion matrix as an annotated heatmap.

    Args:
        labels: true binary labels
        preds: predicted binary labels
        save_path: optional file to save the figure to
    """
    cm = confusion_matrix(labels, preds)
    fig, ax = plt.subplots(figsize=(8, 6))
    heat = ax.imshow(cm, cmap='Blues')

    class_names = ['Not Arch', 'Arch']
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')

    # Annotate each cell; white text on dark cells for contrast.
    threshold = cm.max() / 2
    for row in range(2):
        for col in range(2):
            ax.text(col, row, cm[row, col],
                    ha="center", va="center",
                    color="white" if cm[row, col] > threshold else "black",
                    fontsize=20, fontweight='bold')

    plt.colorbar(heat, ax=ax)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved confusion matrix: {save_path}")
    plt.show()
def plot_roc_curve(labels, probs, save_path: Optional[Path] = None):
    """Plot the ROC curve with its AUC against the random-chance diagonal.

    Args:
        labels: true binary labels
        probs: predicted positive-class probabilities
        save_path: optional file to save the figure to
    """
    fpr, tpr, _ = roc_curve(labels, probs)
    auc = roc_auc_score(labels, probs)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, linewidth=2, label=f'ROC (AUC = {auc:.3f})')
    # Diagonal = performance of a random classifier
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve')
    ax.legend()
    ax.grid(True)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved ROC curve: {save_path}")
    plt.show()
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    save_dir: Optional[Path] = None
) -> Dict:
    """Comprehensive model evaluation on a held-out test set.

    Args:
        model: network whose forward() returns (logits, probs)
        test_loader: yields (matrices, labels) batches; labels stay on CPU
        device: device to run inference on
        save_dir: if given, plots and a JSON summary are written here

    Returns:
        Dict with scalar metrics, confusion matrix, and raw
        predictions/probabilities/labels.
    """
    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []
    print("\n🔬 Evaluating model on test set...")
    with torch.no_grad():
        for batch_matrices, batch_labels in tqdm(test_loader, desc="Testing"):
            batch_matrices = batch_matrices.to(device)
            _, probs = model(batch_matrices)
            all_preds.extend((probs > 0.5).cpu().numpy().flatten())
            all_probs.extend(probs.cpu().numpy().flatten())
            all_labels.extend(batch_labels.numpy().flatten())
    # Compute metrics. BUGFIX: sklearn returns numpy.float64, which the
    # stdlib json encoder cannot serialize — cast to builtin float so the
    # JSON dump below does not raise TypeError.
    acc = float(accuracy_score(all_labels, all_preds))
    prec = float(precision_score(all_labels, all_preds, zero_division=0))
    rec = float(recall_score(all_labels, all_preds, zero_division=0))
    f1 = float(f1_score(all_labels, all_preds, zero_division=0))
    try:
        auc = float(roc_auc_score(all_labels, all_probs))
    except ValueError:
        # roc_auc_score raises when only one class is present in labels
        auc = 0.0
    cm = confusion_matrix(all_labels, all_preds)
    results = {
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1': f1,
        'auc': auc,
        'confusion_matrix': cm.tolist(),
        'predictions': all_preds,
        'probabilities': all_probs,
        'labels': all_labels
    }
    # Print results
    print(f"\n{'='*60}")
    print(f"TEST SET EVALUATION")
    print(f"{'='*60}")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")
    print(f"\nConfusion Matrix:")
    print(f" Predicted")
    print(f" Not Arch | Arch")
    print(f"True Not Arch {cm[0,0]:4d} | {cm[0,1]:4d}")
    print(f"True Arch {cm[1,0]:4d} | {cm[1,1]:4d}")
    print(f"{'='*60}\n")
    # Visualizations
    if save_dir:
        save_dir.mkdir(exist_ok=True, parents=True)
        plot_confusion_matrix(all_labels, all_preds, save_dir / 'confusion_matrix.png')
        plot_roc_curve(all_labels, all_probs, save_dir / 'roc_curve.png')
        # Save scalar results as JSON (raw arrays excluded: not serializable
        # and redundant with the returned dict)
        with open(save_dir / 'test_results.json', 'w') as f:
            json_results = {k: v for k, v in results.items()
                            if k not in ['predictions', 'probabilities', 'labels']}
            json.dump(json_results, f, indent=2)
        print(f"💾 Saved evaluation results to {save_dir}")
    return results