Spaces:
Sleeping
Sleeping
| """ | |
| gate_model.py | |
| ============= | |
| Gating Network that takes probability matrix inputs and predicts | |
| whether a patch contains an archaeological site. | |
| Input: (4 channels, 64, 64) | |
| - Channel 0: Autoencoder probability | |
| - Channel 1: Isolation Forest probability | |
| - Channel 2: K-Means probability | |
| - Channel 3: Archaeological Site Similarity | |
| Output: Binary classification [0, 1] | |
| - 1 = Archaeological site | |
| - 0 = Not archaeological site | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler | |
| from pathlib import Path | |
| from typing import Tuple, List, Dict, Optional | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import ( | |
| accuracy_score, precision_score, recall_score, f1_score, | |
| confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve | |
| ) | |
| import json | |
| # ============================================================================== | |
| # GATE ARCHITECTURE | |
| # ============================================================================== | |
class GateNetwork(nn.Module):
    """
    Gating network for archaeological site detection.

    Architecture:
        - Input: (batch, 4, 64, 64) probability matrices
        - Four conv stages (each: conv-BN-ReLU x2 + pooling), spatially
          64 -> 32 -> 16 -> 8 -> 1 via a final global average pool
        - MLP classifier head with dropout
        - Output: (batch, 1) logits and sigmoid probabilities
    """

    def __init__(self, in_channels: int = 4, dropout: float = 0.3):
        super().__init__()
        # Convolutional encoder, built stage by stage. Creation order matches
        # channel depth so checkpoints remain compatible (conv1..conv4).
        self.conv1 = self._stage(in_channels, 32, nn.MaxPool2d(2))   # 64 -> 32
        self.conv2 = self._stage(32, 64, nn.MaxPool2d(2))            # 32 -> 16
        self.conv3 = self._stage(64, 128, nn.MaxPool2d(2))           # 16 -> 8
        self.conv4 = self._stage(128, 256, nn.AdaptiveAvgPool2d(1))  # global pool
        # Classifier head: 256 -> 128 -> 64 -> 1 logit, dropout between layers.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    @staticmethod
    def _stage(cin: int, cout: int, pool: nn.Module) -> nn.Sequential:
        """Two conv-BN-ReLU layers followed by the given pooling module."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
            nn.Conv2d(cout, cout, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
            pool,
        )

    def forward(self, x):
        """
        Args:
            x: (batch, 4, 64, 64) probability matrices

        Returns:
            logits: (batch, 1) raw logits
            probs: (batch, 1) sigmoid probabilities
        """
        feats = self.conv4(self.conv3(self.conv2(self.conv1(x))))
        logits = self.classifier(feats)
        return logits, torch.sigmoid(logits)

    def predict(self, x):
        """Predict hard class labels (0 or 1) at a 0.5 threshold."""
        probs = self.forward(x)[1]
        return (probs > 0.5).long()

    def count_parameters(self):
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
| # ============================================================================== | |
| # DATASET | |
| # ============================================================================== | |
class ArchaeologyGateDataset(Dataset):
    """
    Dataset for training the Gate Network.

    Loads unified probability matrices (one .npz per AOI, each containing a
    stack of (64, 64, 4) patches under 'unified_matrix') and pairs every
    patch with its AOI-level label.
    """

    def __init__(
        self,
        aoi_list: List[str],
        labels: List[int],
        prob_matrix_dir: Path,
        augment: bool = False
    ):
        """
        Args:
            aoi_list: List of AOI names
            labels: List of labels (1 = archaeology, 0 = not)
            prob_matrix_dir: Directory containing unified probability matrices
            augment: Whether to apply random flip/rotation augmentation
        """
        self.aoi_list = aoi_list
        self.labels = labels
        self.prob_matrix_dir = prob_matrix_dir
        self.augment = augment
        assert len(aoi_list) == len(labels), "AOI list and labels must match"
        # Load all data into memory
        self.data = []
        self.metadata_list = []
        print(f"Loading {len(aoi_list)} AOIs...")
        for aoi_name, label in tqdm(zip(aoi_list, labels), total=len(aoi_list)):
            matrix_path = prob_matrix_dir / f"{aoi_name}_unified_prob_matrix.npz"
            if not matrix_path.exists():
                print(f"⚠️ Warning: Missing {matrix_path}")
                continue
            with np.load(matrix_path, allow_pickle=True) as data:
                unified_matrix = data['unified_matrix']  # (num_patches, 64, 64, 4)
                metadata = data['metadata']
            # BUGFIX: metadata was previously loaded but never stored,
            # leaving self.metadata_list permanently empty.
            self.metadata_list.append(metadata)
            # Add all patches from this AOI, channels-first for PyTorch
            for patch_matrix in unified_matrix:
                patch_matrix = np.transpose(patch_matrix, (2, 0, 1))  # -> (4, 64, 64)
                self.data.append((patch_matrix, label))
        print(f"✓ Loaded {len(self.data)} patches")
        # Class distribution (guarded: avoids ZeroDivisionError when no
        # matrix file could be found)
        if self.data:
            labels_array = np.array([lbl for _, lbl in self.data])
            n_positive = int(np.sum(labels_array == 1))
            n_negative = int(np.sum(labels_array == 0))
            print(f"  Positive (archaeology): {n_positive} ({n_positive/len(self.data)*100:.1f}%)")
            print(f"  Negative (not): {n_negative} ({n_negative/len(self.data)*100:.1f}%)")
        else:
            print("⚠️ Warning: no patches loaded — check prob_matrix_dir and AOI names")

    def __len__(self):
        """Number of patches across all loaded AOIs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (patch, label) as float32 tensors, optionally augmented."""
        patch_matrix, label = self.data[idx]
        # Convert to torch tensors
        patch_matrix = torch.from_numpy(patch_matrix).float()
        label = torch.tensor(label, dtype=torch.float32)
        # Data augmentation (optional): random flips + 90-degree rotations,
        # all label-preserving for overhead imagery.
        if self.augment:
            if torch.rand(1) > 0.5:
                patch_matrix = torch.flip(patch_matrix, dims=[1])
            if torch.rand(1) > 0.5:
                patch_matrix = torch.flip(patch_matrix, dims=[2])
            k = torch.randint(0, 4, (1,)).item()
            if k > 0:
                patch_matrix = torch.rot90(patch_matrix, k=k, dims=[1, 2])
        return patch_matrix, label
| # ============================================================================== | |
| # TRAINING | |
| # ============================================================================== | |
class GateTrainer:
    """Trainer for the Gate Network.

    Runs the optimization loop, tracks metrics, schedules the learning rate
    on validation F1, and checkpoints the best-performing model.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        pos_weight: Optional[float] = None
    ):
        """
        Args:
            model: Network returning (logits, probs) from forward().
            train_loader: Training batches of (matrices, labels).
            val_loader: Validation batches of (matrices, labels).
            device: Device to train on.
            lr: Initial AdamW learning rate.
            weight_decay: AdamW weight decay.
            pos_weight: Optional positive-class weight for BCE (class imbalance).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )
        # FIX: dropped the deprecated `verbose=True` kwarg — it was removed
        # from ReduceLROnPlateau in torch >= 2.4 and raises TypeError there.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',     # we step on validation F1, which we maximize
            factor=0.5,
            patience=5
        )
        # Loss function with optional class weighting
        if pos_weight is not None:
            pos_weight_tensor = torch.tensor([pos_weight]).to(device)
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
        else:
            self.criterion = nn.BCEWithLogitsLoss()
        # Per-epoch metric history
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_f1': [], 'val_f1': [],
            'lr': []
        }
        self.best_val_f1 = 0.0
        self.best_model_state = None

    def train_epoch(self):
        """Train for one epoch; returns (loss, accuracy, F1)."""
        self.model.train()
        running_loss = 0.0
        all_preds = []
        all_labels = []
        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch_matrices, batch_labels in pbar:
            batch_matrices = batch_matrices.to(self.device)
            batch_labels = batch_labels.to(self.device).unsqueeze(1)
            # Forward
            logits, probs = self.model(batch_matrices)
            loss = self.criterion(logits, batch_labels)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Track (loss weighted by batch size for a correct dataset mean)
            running_loss += loss.item() * batch_matrices.size(0)
            all_preds.extend((probs > 0.5).cpu().numpy().flatten())
            all_labels.extend(batch_labels.cpu().numpy().flatten())
            pbar.set_postfix({'loss': loss.item()})
        epoch_loss = running_loss / len(self.train_loader.dataset)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, zero_division=0)
        return epoch_loss, epoch_acc, epoch_f1

    def validate(self):
        """Validate on the validation set; returns a dict of metrics."""
        self.model.eval()
        running_loss = 0.0
        all_preds = []
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for batch_matrices, batch_labels in tqdm(self.val_loader, desc="Validating", leave=False):
                batch_matrices = batch_matrices.to(self.device)
                batch_labels = batch_labels.to(self.device).unsqueeze(1)
                logits, probs = self.model(batch_matrices)
                loss = self.criterion(logits, batch_labels)
                running_loss += loss.item() * batch_matrices.size(0)
                all_preds.extend((probs > 0.5).cpu().numpy().flatten())
                all_probs.extend(probs.cpu().numpy().flatten())
                all_labels.extend(batch_labels.cpu().numpy().flatten())
        epoch_loss = running_loss / len(self.val_loader.dataset)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, zero_division=0)
        epoch_prec = precision_score(all_labels, all_preds, zero_division=0)
        epoch_rec = recall_score(all_labels, all_preds, zero_division=0)
        # FIX: narrowed from a bare `except:` — roc_auc_score raises
        # ValueError when only one class is present in the labels.
        try:
            epoch_auc = roc_auc_score(all_labels, all_probs)
        except ValueError:
            epoch_auc = 0.0
        return {
            'loss': epoch_loss,
            'acc': epoch_acc,
            'f1': epoch_f1,
            'precision': epoch_prec,
            'recall': epoch_rec,
            'auc': epoch_auc,
            'preds': all_preds,
            'probs': all_probs,
            'labels': all_labels
        }

    def train(self, num_epochs: int, save_dir: Path):
        """Train for multiple epochs, checkpointing the best model by val F1."""
        save_dir.mkdir(exist_ok=True, parents=True)
        print(f"\n{'='*80}")
        print(f"TRAINING GATE NETWORK")
        print(f"{'='*80}")
        print(f"Model parameters: {self.model.count_parameters():,}")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset)}")
        print(f"Epochs: {num_epochs}")
        print(f"{'='*80}\n")
        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            print("-" * 40)
            # Train
            train_loss, train_acc, train_f1 = self.train_epoch()
            # Validate
            val_metrics = self.validate()
            # Update scheduler on validation F1
            self.scheduler.step(val_metrics['f1'])
            # Track history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_metrics['acc'])
            self.history['train_f1'].append(train_f1)
            self.history['val_f1'].append(val_metrics['f1'])
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])
            # Print metrics
            print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
            print(f"Val Loss: {val_metrics['loss']:.4f} | Acc: {val_metrics['acc']:.4f} | F1: {val_metrics['f1']:.4f}")
            print(f" Prec: {val_metrics['precision']:.4f} | Rec: {val_metrics['recall']:.4f} | AUC: {val_metrics['auc']:.4f}")
            # Save best model
            if val_metrics['f1'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['f1']
                # BUGFIX: `.state_dict().copy()` is a *shallow* copy — the
                # tensors it holds keep being updated in-place by training,
                # so the in-memory "best" state silently became the last
                # epoch's weights. Clone each tensor for a true snapshot.
                self.best_model_state = {
                    k: v.detach().clone()
                    for k, v in self.model.state_dict().items()
                }
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.best_model_state,
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_f1': self.best_val_f1,
                    'history': self.history
                }, save_dir / 'best_gate_model.pth')
                print(f"✓ Saved best model (F1: {self.best_val_f1:.4f})")
        # Restore the best checkpoint. Guarded: if no epoch ever beat
        # F1 = 0.0, best_model_state is still None and load would crash.
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
        print(f"\n{'='*80}")
        print(f"TRAINING COMPLETE")
        print(f"Best Validation F1: {self.best_val_f1:.4f}")
        print(f"{'='*80}\n")
        return self.history
| # ============================================================================== | |
| # EVALUATION & VISUALIZATION | |
| # ============================================================================== | |
def plot_training_history(history: Dict, save_path: Optional[Path] = None):
    """Plot loss / accuracy / F1 train-vs-val curves plus the LR schedule."""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    # (axis, train key, val key, y-label, title) for the three paired panels
    paired_panels = [
        (axes[0, 0], 'train_loss', 'val_loss', 'Loss', 'Training & Validation Loss'),
        (axes[0, 1], 'train_acc', 'val_acc', 'Accuracy', 'Training & Validation Accuracy'),
        (axes[1, 0], 'train_f1', 'val_f1', 'F1 Score', 'Training & Validation F1'),
    ]
    for ax, train_key, val_key, ylabel, title in paired_panels:
        ax.plot(history[train_key], label='Train')
        ax.plot(history[val_key], label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
    # Learning-rate schedule on a log scale (ReduceLROnPlateau steps)
    lr_ax = axes[1, 1]
    lr_ax.plot(history['lr'])
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.set_yscale('log')
    lr_ax.grid(True)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved training history: {save_path}")
    plt.show()
def plot_confusion_matrix(labels, preds, save_path: Optional[Path] = None):
    """Render the 2x2 confusion matrix with per-cell count annotations."""
    cm = confusion_matrix(labels, preds)
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(cm, cmap='Blues')
    class_names = ['Not Arch', 'Arch']
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    # Annotate every cell; switch to white text on dark cells for contrast
    half_max = cm.max() / 2
    for row in range(2):
        for col in range(2):
            cell_color = "white" if cm[row, col] > half_max else "black"
            ax.text(col, row, cm[row, col],
                    ha="center", va="center", color=cell_color,
                    fontsize=20, fontweight='bold')
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved confusion matrix: {save_path}")
    plt.show()
def plot_roc_curve(labels, probs, save_path: Optional[Path] = None):
    """Plot the ROC curve (with AUC) against the random-chance diagonal."""
    fpr, tpr, _ = roc_curve(labels, probs)
    auc = roc_auc_score(labels, probs)
    fig, ax = plt.subplots(figsize=(8, 6))
    # Model curve plus the y = x baseline for reference
    ax.plot(fpr, tpr, linewidth=2, label=f'ROC (AUC = {auc:.3f})')
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve')
    ax.legend()
    ax.grid(True)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved ROC curve: {save_path}")
    plt.show()
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    save_dir: Optional[Path] = None
) -> Dict:
    """Comprehensive model evaluation on a held-out test set.

    Args:
        model: Trained network whose forward() returns (logits, probs).
        test_loader: DataLoader yielding (matrices, labels) batches.
        device: Device to run inference on.
        save_dir: If given, plots and a JSON summary are written here.

    Returns:
        Dict with scalar metrics, the confusion matrix, and the raw
        per-sample predictions / probabilities / labels.
    """
    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []
    print("\n🔬 Evaluating model on test set...")
    with torch.no_grad():
        for batch_matrices, batch_labels in tqdm(test_loader, desc="Testing"):
            batch_matrices = batch_matrices.to(device)
            _, probs = model(batch_matrices)
            all_preds.extend((probs > 0.5).cpu().numpy().flatten())
            all_probs.extend(probs.cpu().numpy().flatten())
            all_labels.extend(batch_labels.numpy().flatten())
    # Compute metrics. FIX: cast sklearn's numpy scalars to plain floats so
    # the results dict survives json.dump below (np.float64 is not
    # JSON-serializable).
    acc = float(accuracy_score(all_labels, all_preds))
    prec = float(precision_score(all_labels, all_preds, zero_division=0))
    rec = float(recall_score(all_labels, all_preds, zero_division=0))
    f1 = float(f1_score(all_labels, all_preds, zero_division=0))
    # FIX: roc_auc_score raises ValueError when the test labels contain a
    # single class — fall back to 0.0 like GateTrainer.validate does.
    try:
        auc = float(roc_auc_score(all_labels, all_probs))
    except ValueError:
        auc = 0.0
    cm = confusion_matrix(all_labels, all_preds)
    results = {
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1': f1,
        'auc': auc,
        'confusion_matrix': cm.tolist(),
        'predictions': all_preds,
        'probabilities': all_probs,
        'labels': all_labels
    }
    # Print results
    print(f"\n{'='*60}")
    print(f"TEST SET EVALUATION")
    print(f"{'='*60}")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc:.4f}")
    print(f"\nConfusion Matrix:")
    print(f" Predicted")
    print(f" Not Arch | Arch")
    print(f"True Not Arch {cm[0,0]:4d} | {cm[0,1]:4d}")
    print(f"True Arch {cm[1,0]:4d} | {cm[1,1]:4d}")
    print(f"{'='*60}\n")
    # Visualizations + JSON summary (scalar metrics only — the raw
    # per-sample arrays are excluded from the JSON file)
    if save_dir:
        save_dir.mkdir(exist_ok=True, parents=True)
        plot_confusion_matrix(all_labels, all_preds, save_dir / 'confusion_matrix.png')
        plot_roc_curve(all_labels, all_probs, save_dir / 'roc_curve.png')
        with open(save_dir / 'test_results.json', 'w') as f:
            json_results = {k: v for k, v in results.items()
                            if k not in ['predictions', 'probabilities', 'labels']}
            json.dump(json_results, f, indent=2)
        print(f"💾 Saved evaluation results to {save_dir}")
    return results