""" BASED ON: "Deepnet-based surgical tools detection in laparoscopic videos" AUTHORS: Praveen SR Konduri, G Siva Nageswara Rao DOI: https://doi.org/10.1016/j.knosys.2025.113517 """ import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset from torchvision import models, transforms from PIL import Image import pandas as pd import numpy as np import cv2 from sklearn.metrics import classification_report, confusion_matrix from tqdm import tqdm import matplotlib.pyplot as plt import seaborn as sns # CONFIGURATION BASE_PATH = r"C:\Users\anna2\ISM" # Adjust to your path PATH_TO_IMAGES = os.path.join(BASE_PATH, "images") PATH_TO_TRAIN_GT = os.path.join(BASE_PATH, "Baselines", "phase_1b", "gt_for_classification_multiclass_from_filenames_0_index.csv") MODEL_SAVE_PATH = os.path.join(BASE_PATH, "ANNA", "phase1b-6", "cvggnet_optimized_small.pth") # Hyperparameters VAL_FRACTION = 0.1 IMAGE_SIZE = 224 # Standard VGG input MAX_EPOCHS = 15 # they were3 before BATCH_SIZE = 48 NUM_CLASSES = 3 LEARNING_RATE = 0.0012 # Slightly reduced for stability # da tentare dopo: scheduler = optim.lr_scheduler.CosineAnnealingLR( # optimizer, T_max=MAX_EPOCHS, eta_min=1e-6) WEIGHT_DECAY = 5e-4 # INCREASED for regularization DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Features USE_BILATERAL_FILTER = True USE_CLASS_WEIGHTS = False USE_EARLY_STOPPING = True EARLY_STOP_PATIENCE = 3 #CBAM ATTENTION MODULE (section 3.3) class ChannelAttention(nn.Module): """Channel Attention Module from CBAM""" def __init__(self, channels, reduction=16): super(ChannelAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc = nn.Sequential( nn.Conv2d(channels, channels // reduction, 1, bias=False), nn.ReLU(inplace=True), nn.Conv2d(channels // reduction, channels, 1, bias=False) ) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = self.fc(self.avg_pool(x)) max_out = self.fc(self.max_pool(x)) out = avg_out + max_out return self.sigmoid(out) class SpatialAttention(nn.Module): """Spatial Attention Module from CBAM""" def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__() self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) x = torch.cat([avg_out, max_out], dim=1) x = self.conv(x) return self.sigmoid(x) class CBAM(nn.Module): """Convolutional Block Attention Module""" def __init__(self, channels, reduction=16, kernel_size=7): super(CBAM, self).__init__() self.channel_attention = ChannelAttention(channels, reduction) self.spatial_attention = SpatialAttention(kernel_size) def forward(self, x): x = x * self.channel_attention(x) x = x * self.spatial_attention(x) return x # ULTRA-OPTIMIZED CVGGNet-16 MODEL ''' class CVGGNet16UltraOptimized(nn.Module): """ CVGGNet-16 with Ultra-Aggressive Optimization VGG-16 Structure (5 conv blocks): Block 1: conv1_1, conv1_2 (64 channels) ← FROZEN Block 2: conv2_1, conv2_2 (128 channels) ← FROZEN Block 3: conv3_1, conv3_2, conv3_3 (256) ← FROZEN Block 4: conv4_1, conv4_2, conv4_3 (512) ← FROZEN (NEW) Block 5: conv5_1, conv5_2, conv5_3 (512) ← TRAINABLE (only this!) 
# ULTRA-OPTIMIZED CVGGNet-16 MODEL (kept for reference; superseded by the ResNet-50 variant below)
'''
class CVGGNet16UltraOptimized(nn.Module):
    """
    CVGGNet-16 with Ultra-Aggressive Optimization

    VGG-16 Structure (5 conv blocks):
        Block 1: conv1_1, conv1_2 (64 channels)   ← FROZEN
        Block 2: conv2_1, conv2_2 (128 channels)  ← FROZEN
        Block 3: conv3_1, conv3_2, conv3_3 (256)  ← FROZEN
        Block 4: conv4_1, conv4_2, conv4_3 (512)  ← FROZEN (NEW)
        Block 5: conv5_1, conv5_2, conv5_3 (512)  ← TRAINABLE (only this!)
    Classifier: Lightweight 512→128→3 (vs original 4096→4096→3)

    Key Changes:
    - Freeze blocks 1-4 (only train block 5)
    - Tiny classifier (99% parameter reduction)
    - Model size: ~200MB (down from 1.6GB)
    - Trainable params: ~15% (down from 43%)
    """
    def __init__(self, num_classes=3, pretrained=True):
        super(CVGGNet16UltraOptimized, self).__init__()
        # Load pre-trained VGG-16
        vgg16 = models.vgg16(pretrained=pretrained)
        # Extract features
        self.features = vgg16.features
        # CBAM attention
        self.cbam = CBAM(channels=512, reduction=16)
        # Pooling
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # LIGHTWEIGHT classifier (critical fix for model size)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 512),  # ~12.8M params (vs ~103M for 25088→4096 in stock VGG-16)
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),  # INCREASED dropout against overfitting
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # INCREASED dropout
            nn.Linear(128, num_classes)
        )
        # Apply aggressive freezing
        self._freeze_early_layers()

    def _freeze_early_layers(self):
        """
        ULTRA-AGGRESSIVE FREEZING: freeze blocks 1-4, train ONLY block 5.

        VGG-16 features structure:
        - Indices 0-4:   Block 1  ← FROZEN
        - Indices 5-9:   Block 2  ← FROZEN
        - Indices 10-16: Block 3  ← FROZEN
        - Indices 17-23: Block 4  ← FROZEN (NEW)
        - Indices 24-30: Block 5  ← TRAINABLE (only this!)
        """
        print("\n" + "="*70)
        print("Applying ULTRA-AGGRESSIVE Layer Freezing")
        print("="*70)

        # Freeze blocks 1-4 (indices 0-23); index 24 is the start of block 5
        freeze_until_idx = 24
        for idx, layer in enumerate(self.features):
            if idx < freeze_until_idx:
                for param in layer.parameters():
                    param.requires_grad = False

        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        frozen_params = total_params - trainable_params
        print(f"\nParameter Summary:")
        print(f"  Total parameters:     {total_params:,}")
        print(f"  Frozen parameters:    {frozen_params:,} ({100*frozen_params/total_params:.1f}%)")
        print(f"  Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")
        print(f"\nLayer Status:")
        print(f"  ✗ FROZEN:    VGG-16 Blocks 1-4 (conv1-conv4)")
        print(f"  ✓ TRAINABLE: VGG-16 Block 5 ONLY (conv5)")
        print(f"  ✓ TRAINABLE: CBAM Attention")
        print(f"  ✓ TRAINABLE: Lightweight Classifier (512→128→3)")

        # Estimate model size (4 bytes per float32 parameter)
        model_size_mb = (total_params * 4) / (1024**2)
        print(f"\nEstimated Model Size:")
        print(f"  Full precision (FP32): ~{model_size_mb:.1f} MB")
        print(f"  Half precision (FP16): ~{model_size_mb/2:.1f} MB")
        print("="*70 + "\n")

    def forward(self, x):
        x = self.features(x)
        x = self.cbam(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
'''


class CVGGNetResNet50(nn.Module):
    """CVGGNet variant with a ResNet-50 backbone, CBAM attention and a lightweight classifier."""
    def __init__(self, num_classes=3, pretrained=True):
        super(CVGGNetResNet50, self).__init__()
        # Load ResNet-50 (pretrained= is deprecated on torchvision >= 0.13;
        # see load_resnet50_backbone below for the weights= equivalent)
        resnet = models.resnet50(pretrained=pretrained)
        # Extract feature layers
        # Index mapping:
        #   0: conv1, 1: bn1, 2: relu, 3: maxpool
        #   4: layer1, 5: layer2, 6: layer3, 7: layer4
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        # CBAM attention on final feature maps (2048 channels)
        self.cbam = CBAM(channels=2048, reduction=16)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Lightweight classifier
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        # Apply freezing
        self._freeze_early_layers()

    def _print_freeze_summary(self):
        """Print a detailed freezing summary."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        frozen_params = total_params - trainable_params
        print(f"\nParameter Summary:")
        print(f"  Total parameters:     {total_params:,}")
        print(f"  Frozen parameters:    {frozen_params:,} ({100*frozen_params/total_params:.1f}%)")
        print(f"  Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")
        print(f"\nLayer Status:")
        print(f"  ✗ FROZEN:    conv1 + bn1 (initial conv)")
        print(f"  ✗ FROZEN:    layer1 (3 blocks, 256 channels)")
        print(f"  ✗ FROZEN:    layer2 (4 blocks, 512 channels)")
        print(f"  ✓ TRAINABLE: layer3 (6 blocks, 1024 channels)")
        print(f"  ✓ TRAINABLE: layer4 (3 blocks, 2048 channels)")
        print(f"  ✓ TRAINABLE: CBAM Attention")
        print(f"  ✓ TRAINABLE: Classifier (2048→512→128→3)")
        model_size_mb = (total_params * 4) / (1024**2)
        print(f"\nEstimated Model Size: ~{model_size_mb:.1f} MB")
        print("="*70 + "\n")

    def _freeze_early_layers(self):
        """RECOMMENDED strategy: freeze the stem and layers 1-2, train layers 3-4."""
        print("\n" + "="*70)
        print("ResNet-50 Layer Freezing Strategy")
        print("="*70)
        # Freeze initial conv block
        for param in self.features[0].parameters():  # conv1
            param.requires_grad = False
        for param in self.features[1].parameters():  # bn1
            param.requires_grad = False
        # Freeze layer1 (early low-level features)
        for param in self.features[4].parameters():
            param.requires_grad = False
        # Freeze layer2 (mid-level features)
        for param in self.features[5].parameters():
            param.requires_grad = False
        # layer3 and layer4 remain trainable
        self._print_freeze_summary()

    def forward(self, x):
        x = self.features(x)
        x = self.cbam(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
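
# On torchvision >= 0.13 the `pretrained=` argument is deprecated in favor of
# the `weights=` enum. A drop-in sketch (IMAGENET1K_V1 corresponds to the old
# pretrained=True weights); `load_resnet50_backbone` is a helper introduced
# here and is not wired into CVGGNetResNet50 above:
def load_resnet50_backbone(pretrained=True):
    """Load ResNet-50 via the newer torchvision weights API."""
    weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
    return models.resnet50(weights=weights)
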
# RAPID BILATERAL FILTER (Section 3.2 of the paper)
# ref: "Bilateral Filtering: Theory and Applications"
# by Sylvain Paris, Pierre Kornprobst, Jack Tumblin and Frédo Durand
# DOI: 10.1561/0600000020
def rapid_bilateral_filter(image, radius=5, sigma_color=150, sigma_space=8):
    """Rapid bilateral filter for image contrast enhancement.
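
# Quick visual check (a sketch; the file and output names are illustrative).
# Bilateral filtering smooths homogeneous regions while keeping edges sharp,
# which is why the paper uses it to suppress irrelevant texture around tools.
def preview_bilateral_filter(image_path, save_path="bilateral_preview.png"):
    """Save an original/filtered side-by-side comparison for one frame."""
    original = np.array(Image.open(image_path).convert('RGB'))
    filtered = rapid_bilateral_filter(original)
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, img, title in zip(axes, [original, filtered], ["Original", "Bilateral filtered"]):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
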
    Returns a smoothed image in which salient features (e.g. tool edges) are
    preserved while irrelevant fine-grained detail is suppressed.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    # cv2.bilateralFilter(src, d, sigmaColor, sigmaSpace): `radius` is passed
    # as the pixel-neighborhood diameter d
    filtered = cv2.bilateralFilter(image, radius, sigma_color, sigma_space)
    return filtered


# DATASET
class SurgicalToolDataset(Dataset):
    """Dataset with optional Rapid Bilateral Filter preprocessing."""
    def __init__(self, img_dir, annotation_file, transform=None,
                 validation_set=False, use_bilateral_filter=True):
        gt = pd.read_csv(annotation_file)
        if validation_set:
            self.img_labels = gt[gt["validation_set"] == 1]
        else:
            self.img_labels = gt[gt["validation_set"] == 0]
        self.img_dir = img_dir
        self.transform = transform
        self.use_bilateral_filter = use_bilateral_filter
        self.images = self.img_labels["file_name"].values
        self.labels = self.img_labels["category_id"].values

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        if self.use_bilateral_filter:
            image = rapid_bilateral_filter(image)
            image = Image.fromarray(image)
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


# EARLY STOPPING
class EarlyStopping:
    """Early stopping to prevent overfitting."""
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            # No sufficient improvement this epoch
            self.counter += 1
            if self.counter >= self.patience:
                return True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return False


# TRAINING FUNCTIONS
def compute_class_weights(labels, num_classes):
    """Compute inverse-frequency class weights for imbalanced datasets."""
    class_counts = np.bincount(labels, minlength=num_classes)
    total_samples = len(labels)
    # np.maximum guards against division by zero for empty classes
    weights = total_samples / (num_classes * np.maximum(class_counts, 1))
    weights = torch.FloatTensor(weights)
    print(f"\nClass weights computed: {weights.numpy()}")
    return weights


def train_epoch(model, train_loader, criterion, optimizer, device, class_weights=None):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    if class_weights is not None:
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    pbar = tqdm(train_loader, desc="Training", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}',
                          'acc': f'{100.*correct/total:.2f}%'})
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc
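
# Alternative to loss-side class weights (a sketch, not used by main()):
# oversample minority classes at the DataLoader level instead. Pass the
# returned sampler to DataLoader(..., sampler=...) and leave shuffle=False,
# since sampler and shuffle are mutually exclusive. `make_weighted_sampler`
# is a helper name introduced here for illustration.
def make_weighted_sampler(labels, num_classes=NUM_CLASSES):
    """Build a WeightedRandomSampler with per-sample weights ~ 1 / class count."""
    from torch.utils.data import WeightedRandomSampler
    counts = np.bincount(labels, minlength=num_classes)
    sample_weights = 1.0 / np.maximum(counts[labels], 1)
    return WeightedRandomSampler(
        weights=torch.as_tensor(sample_weights, dtype=torch.double),
        num_samples=len(labels),
        replacement=True,
    )
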
def validate(model, val_loader, criterion, device):
    """Validate the model."""
    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    val_loss = running_loss / len(val_loader)
    return val_loss, all_predictions, all_labels


def plot_confusion_matrix(labels, predictions, save_path):
    """Plot confusion matrix."""
    cm = confusion_matrix(labels, predictions)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[f'Class {i}' for i in range(len(cm))],
                yticklabels=[f'Class {i}' for i in range(len(cm))])
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Confusion matrix saved to {save_path}")


def plot_training_history(train_losses, val_losses, train_accs, val_accs, save_path):
    """Plot training history."""
    epochs = range(1, len(train_losses) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Loss plot
    ax1.plot(epochs, train_losses, 'b-o', label='Train Loss', linewidth=2)
    ax1.plot(epochs, val_losses, 'r-s', label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)

    # Accuracy plot
    ax2.plot(epochs, train_accs, 'b-o', label='Train Acc', linewidth=2)
    ax2.plot(epochs, val_accs, 'r-s', label='Val Acc', linewidth=2)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Training history saved to {save_path}")
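
# The hyperparameter block at the top notes CosineAnnealingLR as something to
# try later. A minimal sketch (one scheduler.step() per epoch is assumed;
# `make_cosine_scheduler` is a name introduced here, not from the paper):
def make_cosine_scheduler(optimizer, max_epochs=MAX_EPOCHS):
    """Alternative to ReduceLROnPlateau: anneal the LR down to 1e-6 over training.
    Unlike ReduceLROnPlateau, step() takes no validation-loss argument."""
    return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=1e-6)
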
# MAIN TRAINING FUNCTION
def main():
    """Main training pipeline."""
    # Set seeds for reproducibility
    torch.manual_seed(543)
    np.random.seed(543)

    print("="*70)
    print("CVGGNet-ResNet50 ULTRA-OPTIMIZED Training")
    print("Strategy: Partial Backbone Freezing + Lightweight Classifier")
    print("="*70)
    print(f"Device: {DEVICE}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Max epochs: {MAX_EPOCHS}")
    print(f"Learning rate: {LEARNING_RATE}")
    print(f"Weight decay: {WEIGHT_DECAY} (increased for regularization)")
    print(f"Bilateral filter: {USE_BILATERAL_FILTER}")
    print(f"Early stopping: {USE_EARLY_STOPPING} (patience={EARLY_STOP_PATIENCE})")
    print("="*70 + "\n")

    # DATA PREPARATION
    # Create validation split (persisted back into the ground-truth CSV on first run)
    df = pd.read_csv(PATH_TO_TRAIN_GT)
    if "validation_set" not in df.columns:
        df["validation_set"] = 0
        val_indices = df.sample(frac=VAL_FRACTION, random_state=42).index
        df.loc[val_indices, "validation_set"] = 1
        df.to_csv(PATH_TO_TRAIN_GT, index=False)
        print(f"✓ Created validation split ({VAL_FRACTION*100:.0f}%)\n")

    # REDUCED data augmentation (the previous pipeline was too aggressive)
    train_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),  # reduced rotation range
        # transforms.AugMix(severity=2),  # disabled
        # ColorJitter removed - too aggressive for surgical images
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    train_dataset = SurgicalToolDataset(
        img_dir=PATH_TO_IMAGES,
        annotation_file=PATH_TO_TRAIN_GT,
        transform=train_transform,
        validation_set=False,
        use_bilateral_filter=USE_BILATERAL_FILTER
    )
    val_dataset = SurgicalToolDataset(
        img_dir=PATH_TO_IMAGES,
        annotation_file=PATH_TO_TRAIN_GT,
        transform=val_transform,
        validation_set=True,
        use_bilateral_filter=USE_BILATERAL_FILTER
    )

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=6, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=6, pin_memory=True)

    print(f"Dataset sizes:")
    print(f"  Training:   {len(train_dataset)} images")
    print(f"  Validation: {len(val_dataset)} images")
    print(f"  Batches per epoch: {len(train_loader)} (train), {len(val_loader)} (val)")

    # Compute class weights
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = compute_class_weights(train_dataset.labels, NUM_CLASSES)

    # MODEL SETUP
    print(f"\nCreating CVGGNetResNet50 model...")
    model = CVGGNetResNet50(num_classes=NUM_CLASSES, pretrained=True).to(DEVICE)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()

    # Optimizer - only over trainable parameters
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    # Learning rate scheduler (verbose= dropped: deprecated in recent PyTorch;
    # the current LR is printed manually at the end of each epoch)
    # TODO: review this scheduler choice (see make_cosine_scheduler above)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    # Early stopping
    early_stopping = None
    if USE_EARLY_STOPPING:
        early_stopping = EarlyStopping(patience=EARLY_STOP_PATIENCE, min_delta=0.001)

    # TRAINING LOOP
    best_val_loss = float('inf')
    best_val_acc = 0.0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    print("\n" + "="*70)
    print("Starting Training")
    print("="*70 + "\n")

    import time
    training_start_time = time.time()

    # Make sure the checkpoint directory exists before training
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

    for epoch in range(MAX_EPOCHS):
        epoch_start_time = time.time()
        print(f"\nEpoch [{epoch+1}/{MAX_EPOCHS}]")
        print("-" * 70)

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, DEVICE, class_weights
        )
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Validate
        val_loss, val_predictions, val_labels = validate(
            model, val_loader, criterion, DEVICE
        )
        val_acc = 100. * np.sum(np.array(val_predictions) == np.array(val_labels)) / len(val_labels)
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        # Classification report
        print("\nValidation Metrics:")
        report = classification_report(
            val_labels, val_predictions,
            target_names=[f'Class {i}' for i in range(NUM_CLASSES)],
            digits=4
        )
        print(report)

        # Save history
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'train_loss': train_loss,
            }, MODEL_SAVE_PATH)
            print(f"\n✓ Best model saved! (Val Acc: {val_acc:.2f}%)")

        # Early stopping check
        if early_stopping is not None:
            if early_stopping(val_loss):
                print(f"\n⚠️ Early stopping at epoch {epoch+1}")
                break

        epoch_time = time.time() - epoch_start_time
        print(f"\nEpoch time: {epoch_time/60:.2f} minutes")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    training_time = time.time() - training_start_time

    # FINAL EVALUATION
    print("\n" + "="*70)
    print("Training Complete!")
    print("="*70)
    print(f"Total training time: {training_time/60:.2f} minutes")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    print(f"Model saved to: {MODEL_SAVE_PATH}")

    # Check model size
    model_size_bytes = os.path.getsize(MODEL_SAVE_PATH)
    model_size_mb = model_size_bytes / (1024**2)
    print(f"Model file size: {model_size_mb:.1f} MB")
    if model_size_mb > 500:
        print("⚠️ WARNING: Model still large (>500MB). Check classifier architecture.")
    else:
        print("✓ Model size is good for HuggingFace upload!")

    # Load best model for final evaluation
    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final validation
    _, final_predictions, final_labels = validate(model, val_loader, criterion, DEVICE)

    # Plot confusion matrix
    cm_path = os.path.join(BASE_PATH, 'confusion_matrix_ultra_optimized.png')
    plot_confusion_matrix(final_labels, final_predictions, cm_path)

    # Plot training history
    history_path = os.path.join(BASE_PATH, 'training_history_ultra_optimized.png')
    plot_training_history(train_losses, val_losses, train_accs, val_accs, history_path)

    # Final metrics
    print("\n" + "="*70)
    print("Final Validation Metrics:")
    print("="*70)
    final_report = classification_report(
        final_labels, final_predictions,
        target_names=[f'Class {i}' for i in range(NUM_CLASSES)],
        digits=4
    )
    print(final_report)

    print(f"\n✓ All done! Results saved in {BASE_PATH}")
    print("="*70)

    return model
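
# Single-image inference (a minimal sketch, not called by main(); the
# `predict_image` name and its image_path argument are illustrative). It
# mirrors the validation preprocessing, including the optional bilateral filter.
def predict_image(image_path, checkpoint_path=MODEL_SAVE_PATH, device=DEVICE):
    """Classify one frame with a saved CVGGNetResNet50 checkpoint.

    Returns (predicted_class_index, per-class softmax probabilities).
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    model = CVGGNetResNet50(num_classes=NUM_CLASSES, pretrained=False).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    image = Image.open(image_path).convert('RGB')
    if USE_BILATERAL_FILTER:
        image = Image.fromarray(rapid_bilateral_filter(image))
    x = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1).squeeze(0)
    return int(probs.argmax().item()), probs.cpu().numpy()
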
if __name__ == "__main__":
    model = main()