"""
gate_model.py
=============

Gating network that takes stacked probability matrices as input and predicts
whether a patch contains an archaeological site.

Input: (4 channels, 64, 64)
    - Channel 0: Autoencoder probability
    - Channel 1: Isolation Forest probability
    - Channel 2: K-Means probability
    - Channel 3: Archaeological site similarity

Output: Binary classification [0, 1]
    - 1 = Archaeological site
    - 0 = Not an archaeological site
"""

import copy
import json
from pathlib import Path
from typing import Tuple, List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve
)


# ==============================================================================
# GATE ARCHITECTURE
# ==============================================================================

class GateNetwork(nn.Module):
    """
    Gating network for archaeological site detection.

    Architecture:
        - Input: (batch, in_channels, H, W) probability matrices
          (designed for (batch, 4, 64, 64))
        - CNN encoder: four stages of 2 x (Conv3x3 -> BatchNorm -> ReLU);
          the first three end in 2x2 max-pooling, the last in global
          average pooling
        - MLP classifier head with dropout
        - Output: (batch, 1) raw logits and (batch, 1) sigmoid probabilities
    """

    def __init__(self, in_channels: int = 4, dropout: float = 0.3):
        super().__init__()

        # Convolutional encoder. _conv_stage builds exactly the same
        # nn.Sequential layout (conv, bn, relu, conv, bn, relu, pool) for each
        # stage, so state_dict keys such as "conv1.0.weight" are unchanged and
        # previously saved checkpoints still load.
        self.conv1 = self._conv_stage(in_channels, 32, nn.MaxPool2d(2))   # 64 -> 32
        self.conv2 = self._conv_stage(32, 64, nn.MaxPool2d(2))            # 32 -> 16
        self.conv3 = self._conv_stage(64, 128, nn.MaxPool2d(2))           # 16 -> 8
        self.conv4 = self._conv_stage(128, 256, nn.AdaptiveAvgPool2d(1))  # global avg pool

        # Classifier head: (batch, 256, 1, 1) -> (batch, 1) logit
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    @staticmethod
    def _conv_stage(in_ch: int, out_ch: int, pool: nn.Module) -> nn.Sequential:
        """Build one encoder stage: 2 x (Conv3x3 -> BatchNorm -> ReLU) followed by `pool`."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            pool,
        )

    def forward(self, x):
        """
        Args:
            x: (batch, 4, 64, 64) probability matrices

        Returns:
            logits: (batch, 1) raw logits
            probs: (batch, 1) sigmoid probabilities
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        logits = self.classifier(x)
        probs = torch.sigmoid(logits)
        return logits, probs

    @torch.no_grad()
    def predict(self, x, threshold: float = 0.5):
        """
        Predict class labels (0 or 1) as a (batch, 1) long tensor.

        A sample is labelled 1 when its probability is strictly greater than
        `threshold`. This does not switch the module to eval mode; call
        `model.eval()` first so dropout/batch-norm behave deterministically.
        """
        _, probs = self.forward(x)
        return (probs > threshold).long()

    def count_parameters(self):
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# ==============================================================================
# DATASET
# ==============================================================================

class ArchaeologyGateDataset(Dataset):
    """
    Dataset for training the Gate Network.

    Loads every patch of each AOI's unified probability matrix into memory.
    All patches of one AOI receive that AOI's label.
    """

    def __init__(
        self,
        aoi_list: List[str],
        labels: List[int],
        prob_matrix_dir: Path,
        augment: bool = False
    ):
        """
        Args:
            aoi_list: List of AOI names
            labels: List of labels (1 = archaeology, 0 = not), one per AOI
            prob_matrix_dir: Directory containing
                "<aoi_name>_unified_prob_matrix.npz" files
            augment: Whether to apply random flip/rotation augmentation
        """
        self.aoi_list = aoi_list
        self.labels = labels
        self.prob_matrix_dir = prob_matrix_dir
        self.augment = augment

        assert len(aoi_list) == len(labels), "AOI list and labels must match"

        # Load all data into memory: list of (patch (C, H, W) ndarray, label)
        self.data = []
        # NOTE(review): not populated by this class; kept so existing callers
        # that read the attribute keep working.
        self.metadata_list = []

        print(f"Loading {len(aoi_list)} AOIs...")
        for aoi_name, label in tqdm(zip(aoi_list, labels), total=len(aoi_list)):
            matrix_path = prob_matrix_dir / f"{aoi_name}_unified_prob_matrix.npz"
            if not matrix_path.exists():
                print(f"⚠️ Warning: Missing {matrix_path}")
                continue

            # SECURITY: allow_pickle=True lets a crafted .npz execute code when
            # an object array is read; only load files from trusted sources.
            with np.load(matrix_path, allow_pickle=True) as data:
                unified_matrix = data['unified_matrix']  # (num_patches, 64, 64, 4)

            # Add all patches from this AOI, transposed to (4, 64, 64) for PyTorch
            for patch_matrix in unified_matrix:
                self.data.append((np.transpose(patch_matrix, (2, 0, 1)), label))

        n_total = len(self.data)
        print(f"✓ Loaded {n_total} patches")

        # Class distribution (an empty dataset would otherwise divide by zero)
        if n_total == 0:
            print("⚠️ Warning: no patches were loaded")
            return

        labels_array = np.array([patch_label for _, patch_label in self.data])
        n_positive = int(np.sum(labels_array == 1))
        n_negative = int(np.sum(labels_array == 0))
        print(f"  Positive (archaeology): {n_positive} ({n_positive / n_total * 100:.1f}%)")
        print(f"  Negative (not):         {n_negative} ({n_negative / n_total * 100:.1f}%)")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (patch float32 tensor (C, H, W), label float32 scalar tensor)."""
        patch_matrix, label = self.data[idx]

        # Convert to torch tensors
        patch_matrix = torch.from_numpy(patch_matrix).float()
        label = torch.tensor(label, dtype=torch.float32)

        # Data augmentation (optional)
        if self.augment:
            # Random flips along the two spatial axes
            if torch.rand(1) > 0.5:
                patch_matrix = torch.flip(patch_matrix, dims=[1])
            if torch.rand(1) > 0.5:
                patch_matrix = torch.flip(patch_matrix, dims=[2])

            # Random rotation by k * 90 degrees (k = 0 leaves the patch as is)
            k = torch.randint(0, 4, (1,)).item()
            if k > 0:
                patch_matrix = torch.rot90(patch_matrix, k=k, dims=[1, 2])

        return patch_matrix, label


# ==============================================================================
# TRAINING
# ==============================================================================

class GateTrainer:
    """Trainer for the Gate Network (AdamW, ReduceLROnPlateau on validation F1)."""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        pos_weight: Optional[float] = None
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        # `verbose` is deprecated for PyTorch LR schedulers; LR reductions are
        # reported by train() instead.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=5
        )

        # Loss function with optional positive-class weighting. The explicit
        # float32 dtype keeps the weight floating-point even if an int is passed.
        if pos_weight is not None:
            pos_weight_tensor = torch.tensor([pos_weight], dtype=torch.float32, device=device)
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
        else:
            self.criterion = nn.BCEWithLogitsLoss()

        # Tracking
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_f1': [], 'val_f1': [],
            'lr': []
        }
        self.best_val_f1 = 0.0
        self.best_model_state = None

    def train_epoch(self):
        """
        Train for one epoch.

        Returns:
            (mean loss per sample, accuracy, F1) over the samples seen.
        """
        self.model.train()
        running_loss = 0.0
        n_seen = 0
        all_preds = []
        all_labels = []

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch_matrices, batch_labels in pbar:
            batch_matrices = batch_matrices.to(self.device)
            batch_labels = batch_labels.to(self.device).unsqueeze(1)

            # Forward
            logits, probs = self.model(batch_matrices)
            loss = self.criterion(logits, batch_labels)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Track. Count samples actually seen: with a sampler that draws a
            # different number of samples than len(dataset) (e.g.
            # WeightedRandomSampler) or drop_last=True, dividing by the dataset
            # length would give a wrong mean.
            batch_size = batch_matrices.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size
            all_preds.extend((probs > 0.5).cpu().numpy().flatten())
            all_labels.extend(batch_labels.cpu().numpy().flatten())

            pbar.set_postfix({'loss': loss.item()})

        epoch_loss = running_loss / max(n_seen, 1)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, zero_division=0)

        return epoch_loss, epoch_acc, epoch_f1

    def validate(self):
        """
        Evaluate on the validation set without gradient tracking.

        Returns:
            dict with scalar 'loss', 'acc', 'f1', 'precision', 'recall', 'auc'
            and flat per-sample lists 'preds', 'probs', 'labels'. 'auc' is 0.0
            when it is undefined (only one class present in the labels).
        """
        self.model.eval()
        running_loss = 0.0
        n_seen = 0
        all_preds = []
        all_probs = []
        all_labels = []

        with torch.no_grad():
            for batch_matrices, batch_labels in tqdm(self.val_loader, desc="Validating", leave=False):
                batch_matrices = batch_matrices.to(self.device)
                batch_labels = batch_labels.to(self.device).unsqueeze(1)

                logits, probs = self.model(batch_matrices)
                loss = self.criterion(logits, batch_labels)

                batch_size = batch_matrices.size(0)
                running_loss += loss.item() * batch_size
                n_seen += batch_size
                all_preds.extend((probs > 0.5).cpu().numpy().flatten())
                all_probs.extend(probs.cpu().numpy().flatten())
                all_labels.extend(batch_labels.cpu().numpy().flatten())

        epoch_loss = running_loss / max(n_seen, 1)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, zero_division=0)
        epoch_prec = precision_score(all_labels, all_preds, zero_division=0)
        epoch_rec = recall_score(all_labels, all_preds, zero_division=0)

        try:
            epoch_auc = roc_auc_score(all_labels, all_probs)
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present.
            epoch_auc = 0.0

        return {
            'loss': epoch_loss,
            'acc': epoch_acc,
            'f1': epoch_f1,
            'precision': epoch_prec,
            'recall': epoch_rec,
            'auc': epoch_auc,
            'preds': all_preds,
            'probs': all_probs,
            'labels': all_labels
        }

    def train(self, num_epochs: int, save_dir: Path):
        """
        Train for `num_epochs` epochs.

        Saves a checkpoint to save_dir / 'best_gate_model.pth' every time the
        validation F1 improves. At the end, reloads the best weights into the
        model (if any epoch improved on the initial best F1 of 0.0).

        Returns:
            The training history dict.
        """
        save_dir.mkdir(exist_ok=True, parents=True)

        print(f"\n{'='*80}")
        print("TRAINING GATE NETWORK")
        print(f"{'='*80}")
        print(f"Model parameters: {self.model.count_parameters():,}")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset)}")
        print(f"Epochs: {num_epochs}")
        print(f"{'='*80}\n")

        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            print("-" * 40)

            # Train
            train_loss, train_acc, train_f1 = self.train_epoch()

            # Validate
            val_metrics = self.validate()

            # Update scheduler (mode='max' on validation F1) and report reductions
            prev_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_metrics['f1'])
            new_lr = self.optimizer.param_groups[0]['lr']
            if new_lr < prev_lr:
                print(f"Reducing learning rate: {prev_lr:.2e} -> {new_lr:.2e}")

            # Track history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_metrics['acc'])
            self.history['train_f1'].append(train_f1)
            self.history['val_f1'].append(val_metrics['f1'])
            self.history['lr'].append(new_lr)

            # Print metrics
            print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
            print(f"Val Loss:   {val_metrics['loss']:.4f} | Acc: {val_metrics['acc']:.4f} | F1: {val_metrics['f1']:.4f}")
            print(f"            Prec: {val_metrics['precision']:.4f} | Rec: {val_metrics['recall']:.4f} | AUC: {val_metrics['auc']:.4f}")

            # Save best model. state_dict() returns references to the live
            # parameter tensors, so a shallow .copy() would keep changing as
            # training continues; deepcopy takes a real snapshot.
            if val_metrics['f1'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['f1']
                self.best_model_state = copy.deepcopy(self.model.state_dict())

                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.best_model_state,
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_f1': self.best_val_f1,
                    'history': self.history
                }, save_dir / 'best_gate_model.pth')

                print(f"✓ Saved best model (F1: {self.best_val_f1:.4f})")

        # Load best model (none exists if validation F1 never exceeded 0.0)
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
        else:
            print("⚠️ Warning: validation F1 never improved; keeping final-epoch weights")

        print(f"\n{'='*80}")
        print("TRAINING COMPLETE")
        print(f"Best Validation F1: {self.best_val_f1:.4f}")
        print(f"{'='*80}\n")

        return self.history


# ==============================================================================
# EVALUATION & VISUALIZATION
# ==============================================================================

def plot_training_history(history: Dict, save_path: Optional[Path] = None):
    """Plot loss, accuracy, F1 and learning-rate curves against 1-based epochs."""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    # Train/val paired panels: (axis, history key suffix, y-label, title)
    paired_panels = [
        (axes[0, 0], 'loss', 'Loss', 'Training & Validation Loss'),
        (axes[0, 1], 'acc', 'Accuracy', 'Training & Validation Accuracy'),
        (axes[1, 0], 'f1', 'F1 Score', 'Training & Validation F1'),
    ]
    for ax, key, ylabel, title in paired_panels:
        ax.plot(epochs, history[f'train_{key}'], label='Train')
        ax.plot(epochs, history[f'val_{key}'], label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    # Learning rate
    lr_ax = axes[1, 1]
    lr_ax.plot(epochs, history['lr'])
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.set_yscale('log')
    lr_ax.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved training history: {save_path}")

    plt.show()
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive


def plot_confusion_matrix(labels, preds, save_path: Optional[Path] = None):
    """Plot a 2x2 confusion matrix (rows = true class, columns = predicted)."""
    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent.
    cm = confusion_matrix(labels, preds, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(cm, cmap='Blues')

    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(['Not Arch', 'Arch'])
    ax.set_yticklabels(['Not Arch', 'Arch'])
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')

    # Add count annotations; white text on dark cells for contrast
    threshold = cm.max() / 2
    for i in range(2):
        for j in range(2):
            ax.text(j, i, cm[i, j],
                    ha="center", va="center",
                    color="white" if cm[i, j] > threshold else "black",
                    fontsize=20, fontweight='bold')

    plt.colorbar(im, ax=ax)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved confusion matrix: {save_path}")

    plt.show()
    plt.close(fig)


def plot_roc_curve(labels, probs, save_path: Optional[Path] = None):
    """Plot the ROC curve; skipped with a warning if only one class is present."""
    try:
        auc = roc_auc_score(labels, probs)
    except ValueError as err:
        # ROC/AUC are undefined when labels contain a single class.
        print(f"⚠️ Skipping ROC curve: {err}")
        return

    fpr, tpr, _ = roc_curve(labels, probs)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, linewidth=2, label=f'ROC (AUC = {auc:.3f})')
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve')
    ax.legend()
    ax.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved ROC curve: {save_path}")

    plt.show()
    plt.close(fig)


def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    save_dir: Optional[Path] = None
) -> Dict:
    """
    Evaluate `model` on `test_loader` at a 0.5 probability threshold.

    Prints accuracy/precision/recall/F1/AUC and the confusion matrix. If
    `save_dir` is given, also writes confusion_matrix.png, roc_curve.png and
    test_results.json (scalar metrics + confusion matrix) there.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1', 'auc',
        'confusion_matrix' (2x2 nested list) and per-sample 'predictions',
        'probabilities', 'labels'. 'auc' is 0.0 when it is undefined (only
        one class in the test labels).
    """
    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []

    print("\n🔬 Evaluating model on test set...")

    with torch.no_grad():
        for batch_matrices, batch_labels in tqdm(test_loader, desc="Testing"):
            batch_matrices = batch_matrices.to(device)
            _, probs = model(batch_matrices)

            all_preds.extend((probs > 0.5).cpu().numpy().flatten())
            all_probs.extend(probs.cpu().numpy().flatten())
            all_labels.extend(batch_labels.numpy().flatten())

    # Compute metrics
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, zero_division=0)
    rec = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    try:
        auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # Same 0.0 sentinel as GateTrainer.validate for an undefined AUC.
        print("⚠️ Warning: AUC is undefined (only one class in test labels); reporting 0.0")
        auc = 0.0
    # labels=[0, 1] keeps the matrix 2x2 so the cm[1, 1] indexing below is safe.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    results = {
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1': f1,
        'auc': auc,
        'confusion_matrix': cm.tolist(),
        'predictions': all_preds,
        'probabilities': all_probs,
        'labels': all_labels
    }

    # Print results
    print(f"\n{'='*60}")
    print("TEST SET EVALUATION")
    print(f"{'='*60}")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall:    {rec:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print(f"AUC:       {auc:.4f}")
    print("\nConfusion Matrix:")
    print("                    Predicted")
    print("                Not Arch | Arch")
    print(f"True Not Arch     {cm[0, 0]:4d}   | {cm[0, 1]:4d}")
    print(f"True Arch         {cm[1, 0]:4d}   | {cm[1, 1]:4d}")
    print(f"{'='*60}\n")

    # Visualizations
    if save_dir:
        save_dir.mkdir(exist_ok=True, parents=True)

        plot_confusion_matrix(all_labels, all_preds, save_dir / 'confusion_matrix.png')
        plot_roc_curve(all_labels, all_probs, save_dir / 'roc_curve.png')

        # Save scalar results as JSON (per-sample lists are omitted)
        with open(save_dir / 'test_results.json', 'w') as f:
            json_results = {
                k: v for k, v in results.items()
                if k not in ['predictions', 'probabilities', 'labels']
            }
            json.dump(json_results, f, indent=2)

        print(f"💾 Saved evaluation results to {save_dir}")

    return results