annaferrari02 commited on
Commit
fc1291b
·
verified ·
1 Parent(s): 745c622

Upload 3 files

Browse files
Files changed (3) hide show
  1. cvggnet_optimized_small.pth +3 -0
  2. script.py +245 -0
  3. train.py +735 -0
cvggnet_optimized_small.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5127cc2f34223c3f37b9f5ded78464d77221ff19a2f9698ac5958455e9ad8b63
3
+ size 290618196
script.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for CVGGNet-ResNet50
3
+ Compatible with ResNet-50 + CBAM architecture
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import models, transforms
10
+ from PIL import Image
11
+ import pandas as pd
12
+ import numpy as np
13
+ import cv2
14
+ from tqdm import tqdm
15
+
16
+
17
+ # ==================== CBAM MODULES (must match training) ====================
18
+
19
class ChannelAttention(nn.Module):
    """CBAM channel gate: squeeze H×W via avg- and max-pooling, feed both
    descriptors through a shared bottleneck MLP, and sigmoid their sum."""

    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = channels // reduction
        # 1x1 convs act as a per-channel MLP shared by both pooled paths.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Fuse the avg- and max-pooled branches by addition, then gate to (0, 1).
        fused = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        return self.sigmoid(fused)
37
+
38
+
39
class SpatialAttention(nn.Module):
    """CBAM spatial gate: convolve the channel-wise mean/max maps into a
    single sigmoid attention map over H×W."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv(stacked))
51
+
52
+
53
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate followed by
    spatial gate, each applied as a multiplicative mask."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # Spatial attention is computed on the already channel-gated tensor.
        gated = x * self.channel_attention(x)
        return gated * self.spatial_attention(gated)
63
+
64
+
65
+ # ==================== MODEL ARCHITECTURE ====================
66
+
67
class CVGGNetResNet50(nn.Module):
    """CVGGNet variant for inference: ResNet-50 trunk -> CBAM attention ->
    global average pool -> compact MLP head.

    Attribute names (features/cbam/avgpool/classifier) intentionally mirror
    the training-time model so saved state_dicts load unchanged.
    """

    def __init__(self, num_classes=3, pretrained=False):
        super(CVGGNetResNet50, self).__init__()

        # Backbone: every ResNet-50 child except its own avgpool and fc head.
        backbone = models.resnet50(pretrained=pretrained)
        self.features = nn.Sequential(*list(backbone.children())[:-2])

        # Attention over the 2048-channel final-stage feature maps.
        self.cbam = CBAM(channels=2048, reduction=16)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Lightweight classifier (must match the training architecture).
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        attended = self.cbam(self.features(x))
        pooled = torch.flatten(self.avgpool(attended), 1)
        return self.classifier(pooled)
103
+
104
+
105
+ # ==================== BILATERAL FILTER ====================
106
+
107
def rapid_bilateral_filter(image, radius=5, sigma_color=150, sigma_space=8):
    """Edge-preserving smoothing via OpenCV's bilateral filter.

    Parameters match those used at training time. Accepts either a PIL
    image or a numpy array; returns a filtered numpy array.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    return cv2.bilateralFilter(pixels, radius, sigma_color, sigma_space)
114
+
115
+
116
+ # ==================== INFERENCE FUNCTION ====================
117
+
118
def run_inference(test_images_path, model, image_size, submission_csv_path,
                  use_bilateral_filter=True, device='cpu'):
    """
    Classify every image in a directory and write a submission CSV.

    Args:
        test_images_path: Directory containing the test images.
        model: Trained classification model.
        image_size: Side length used to resize each (square) input.
        submission_csv_path: Destination CSV (file_name, category_id).
        use_bilateral_filter: Apply the Rapid Bilateral Filter first.
        device: Device string to run inference on ('cpu' or 'cuda').

    Returns:
        pandas.DataFrame with one prediction row per image.
    """
    model.eval()
    model = model.to(device)

    # Deterministic ordering keeps the CSV rows stable across runs.
    test_images = sorted(os.listdir(test_images_path))

    # Same resize/normalization statistics as used at training time.
    test_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    predictions = []

    print(f"Running inference on {len(test_images)} images...")

    for image_name in tqdm(test_images):
        image = Image.open(os.path.join(test_images_path, image_name)).convert('RGB')

        # Optional edge-preserving denoise before tensor conversion.
        if use_bilateral_filter:
            image = Image.fromarray(rapid_bilateral_filter(image))

        batch = test_transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch)
            predictions.append(torch.argmax(logits, dim=1).cpu().item())

    df_predictions = pd.DataFrame({
        'file_name': test_images,
        'category_id': predictions
    })

    df_predictions.to_csv(submission_csv_path, index=False)
    print(f"\n✓ Predictions saved to: {submission_csv_path}")

    # Summarize how the predictions spread over the three classes.
    print("\nPrediction Distribution:")
    for class_id in range(3):
        count = (df_predictions['category_id'] == class_id).sum()
        percentage = 100 * count / len(df_predictions)
        print(f" Class {class_id}: {count} images ({percentage:.1f}%)")

    return df_predictions
185
+
186
+
187
+ # ==================== MAIN SCRIPT ====================
188
+
189
if __name__ == "__main__":

    # Paths: weights/output live next to this script; test images come from
    # the fixed competition mount point.
    current_directory = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"  # HuggingFace standard path
    MODEL_WEIGHTS_PATH = os.path.join(current_directory, "cvggnet_optimized_small.pth")
    SUBMISSION_CSV_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    # Configuration (MUST MATCH TRAINING)
    NUM_CLASSES = 3
    IMAGE_SIZE = 224  # ResNet standard input size
    USE_BILATERAL_FILTER = True  # Match your training setting
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    print("="*60)
    print("CVGGNet-ResNet50 Inference")
    print("="*60)
    print(f"Device: {DEVICE}")
    print(f"Model weights: {MODEL_WEIGHTS_PATH}")
    print(f"Test images: {TEST_IMAGE_PATH}")
    print(f"Output: {SUBMISSION_CSV_SAVE_PATH}")
    print(f"Bilateral filter: {USE_BILATERAL_FILTER}")
    print("="*60 + "\n")

    # Load model (architecture only; weights are loaded below, so
    # pretrained=False avoids a pointless ImageNet download).
    print("Loading ResNet-50 model...")
    model = CVGGNetResNet50(num_classes=NUM_CLASSES, pretrained=False)

    # Load weights
    checkpoint = torch.load(MODEL_WEIGHTS_PATH, map_location=torch.device(DEVICE))

    # Handle different checkpoint formats: a full training checkpoint dict
    # vs. a bare state_dict.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✓ Model loaded from epoch {checkpoint.get('epoch', 'unknown')}")
        if 'val_acc' in checkpoint:
            print(f" Validation accuracy: {checkpoint.get('val_acc', 0):.2f}%")
    else:
        model.load_state_dict(checkpoint)
        print("✓ Model weights loaded")

    # Check model size (on-disk file size, informational only)
    model_size_bytes = os.path.getsize(MODEL_WEIGHTS_PATH)
    model_size_mb = model_size_bytes / (1024**2)
    print(f" Model size: {model_size_mb:.1f} MB\n")

    # Run inference and write the submission CSV
    predictions_df = run_inference(
        test_images_path=TEST_IMAGE_PATH,
        model=model,
        image_size=IMAGE_SIZE,
        submission_csv_path=SUBMISSION_CSV_SAVE_PATH,
        use_bilateral_filter=USE_BILATERAL_FILTER,
        device=DEVICE
    )

    print("\n✓ Inference complete!")
train.py ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BASED ON: "Deepnet-based surgical tools detection in laparoscopic videos"
3
+ AUTHORS: Praveen SR Konduri, G Siva Nageswara Rao
4
+ DOI: https://doi.org/10.1016/j.knosys.2025.113517
5
+
6
+ """
7
+
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from torchvision import models, transforms
14
+ from PIL import Image
15
+ import pandas as pd
16
+ import numpy as np
17
+ import cv2
18
+ from sklearn.metrics import classification_report, confusion_matrix
19
+ from tqdm import tqdm
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+
23
+
24
# CONFIGURATION

BASE_PATH = r"C:\Users\anna2\ISM"  # Adjust to your path
PATH_TO_IMAGES = os.path.join(BASE_PATH, "images")
PATH_TO_TRAIN_GT = os.path.join(BASE_PATH, "Baselines", "phase_1b", "gt_for_classification_multiclass_from_filenames_0_index.csv")

MODEL_SAVE_PATH = os.path.join(BASE_PATH, "ANNA", "phase1b-6", "cvggnet_optimized_small.pth")

# Hyperparameters
VAL_FRACTION = 0.1  # fraction of ground-truth rows held out for validation
IMAGE_SIZE = 224  # Standard VGG input
MAX_EPOCHS = 15  # was 3 before
BATCH_SIZE = 48
NUM_CLASSES = 3
LEARNING_RATE = 0.0012  # Slightly reduced for stability
# To try later: scheduler = optim.lr_scheduler.CosineAnnealingLR(
#     optimizer, T_max=MAX_EPOCHS, eta_min=1e-6)
WEIGHT_DECAY = 5e-4  # INCREASED for regularization
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Feature toggles
USE_BILATERAL_FILTER = True  # Rapid Bilateral Filter preprocessing (section 3.2)
USE_CLASS_WEIGHTS = False  # inverse-frequency loss weighting (currently off)
USE_EARLY_STOPPING = True
EARLY_STOP_PATIENCE = 3  # epochs without val-loss improvement before stopping
49
+
50
+
51
+ #CBAM ATTENTION MODULE (section 3.3)
52
+
53
class ChannelAttention(nn.Module):
    """Channel Attention Module from CBAM: global avg/max pooling feed a
    shared bottleneck MLP whose summed output is squashed to (0, 1)."""

    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = channels // reduction
        # 1x1 convolutions implement the shared per-channel MLP.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        fused = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        return self.sigmoid(fused)
72
+
73
+
74
class SpatialAttention(nn.Module):
    """Spatial Attention Module from CBAM: a single conv over the stacked
    channel-wise mean and max maps, gated by sigmoid."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv(stacked))
87
+
88
+
89
class CBAM(nn.Module):
    """Convolutional Block Attention Module (section 3.3): channel gate
    first, then spatial gate on the channel-gated tensor."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        gated = x * self.channel_attention(x)
        return gated * self.spatial_attention(gated)
100
+
101
+
102
+ # ULTRA-OPTIMIZED CVGGNet-16 MODEL
103
+ '''
104
+ class CVGGNet16UltraOptimized(nn.Module):
105
+ """
106
+ CVGGNet-16 with Ultra-Aggressive Optimization
107
+
108
+ VGG-16 Structure (5 conv blocks):
109
+ Block 1: conv1_1, conv1_2 (64 channels) ← FROZEN
110
+ Block 2: conv2_1, conv2_2 (128 channels) ← FROZEN
111
+ Block 3: conv3_1, conv3_2, conv3_3 (256) ← FROZEN
112
+ Block 4: conv4_1, conv4_2, conv4_3 (512) ← FROZEN (NEW)
113
+ Block 5: conv5_1, conv5_2, conv5_3 (512) ← TRAINABLE (only this!)
114
+
115
+ Classifier: Lightweight 512→128→3 (vs original 4096→4096→3)
116
+
117
+ Key Changes:
118
+ - Freeze blocks 1-4 (only train block 5)
119
+ - Tiny classifier (99% parameter reduction)
120
+ - Model size: ~200MB (down from 1.6GB)
121
+ - Trainable params: ~15% (down from 43%)
122
+ """
123
+ def __init__(self, num_classes=3, pretrained=True):
124
+ super(CVGGNet16UltraOptimized, self).__init__()
125
+
126
+ # Load pre-trained VGG-16
127
+ vgg16 = models.vgg16(pretrained=pretrained)
128
+
129
+ # Extract features
130
+ self.features = vgg16.features
131
+
132
+ # CBAM attention
133
+ self.cbam = CBAM(channels=512, reduction=16)
134
+
135
+ # Pooling
136
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
137
+
138
+ # LIGHTWEIGHT Classifier (CRITICAL FIX for model size)
139
+ self.classifier = nn.Sequential(
140
+ nn.Linear(512 * 7 * 7, 512), # 25K params (vs 100M in original)
141
+ nn.ReLU(inplace=True),
142
+ nn.Dropout(0.6), # INCREASED dropout for overfitting
143
+ nn.Linear(512, 128),
144
+ nn.ReLU(inplace=True),
145
+ nn.Dropout(0.5), # INCREASED dropout
146
+ nn.Linear(128, num_classes)
147
+ )
148
+
149
+ # Apply aggressive freezing
150
+ self._freeze_early_layers()
151
+
152
+ def _freeze_early_layers(self):
153
+ """
154
+ ULTRA-AGGRESSIVE FREEZING: Freeze blocks 1-4, train ONLY block 5
155
+
156
+ VGG-16 features structure:
157
+ - Indices 0-4: Block 1 ← FROZEN
158
+ - Indices 5-9: Block 2 ← FROZEN
159
+ - Indices 10-16: Block 3 ← FROZEN
160
+ - Indices 17-23: Block 4 ← FROZEN (NEW)
161
+ - Indices 24-30: Block 5 ← TRAINABLE (only this!)
162
+ """
163
+ print("\n" + "="*70)
164
+ print("Applying ULTRA-AGGRESSIVE Layer Freezing")
165
+ print("="*70)
166
+
167
+ # Freeze blocks 1-4 (indices 0-23)
168
+ freeze_until_idx = 10 # Start of block 5 - MOST AGGRESSIVE
169
+
170
+ for idx, layer in enumerate(self.features):
171
+ if idx < freeze_until_idx:
172
+ for param in layer.parameters():
173
+ param.requires_grad = False
174
+
175
+ # Count parameters
176
+ total_params = sum(p.numel() for p in self.parameters())
177
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
178
+ frozen_params = total_params - trainable_params
179
+
180
+ print(f"\nParameter Summary:")
181
+ print(f" Total parameters: {total_params:,}")
182
+ print(f" Frozen parameters: {frozen_params:,} ({100*frozen_params/total_params:.1f}%)")
183
+ print(f" Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")
184
+
185
+ print(f"\nLayer Status:")
186
+ print(f" ✗ FROZEN: VGG-16 Blocks 1-4 (conv1-conv4)")
187
+ print(f" ✓ TRAINABLE: VGG-16 Block 5 ONLY (conv5)")
188
+ print(f" ✓ TRAINABLE: CBAM Attention")
189
+ print(f" ✓ TRAINABLE: Lightweight Classifier (512→128→3)")
190
+
191
+ # Calculate model size
192
+ model_size_mb = (total_params * 4) / (1024**2) # 4 bytes per float32
193
+ print(f"\nEstimated Model Size:")
194
+ print(f" Full precision (FP32): ~{model_size_mb:.1f} MB")
195
+ print(f" Half precision (FP16): ~{model_size_mb/2:.1f} MB")
196
+ print("="*70 + "\n")
197
+
198
+ def forward(self, x):
199
+ x = self.features(x)
200
+ x = self.cbam(x)
201
+ x = self.avgpool(x)
202
+ x = torch.flatten(x, 1)
203
+ x = self.classifier(x)
204
+ return x
205
+ '''
206
+
207
class CVGGNetResNet50(nn.Module):
    """ResNet-50 backbone + CBAM attention + lightweight MLP classifier.

    Early backbone stages are frozen (see _freeze_early_layers); only
    layer3, layer4, the CBAM module and the classifier are trained.
    """

    def __init__(self, num_classes=3, pretrained=True):
        super(CVGGNetResNet50, self).__init__()

        # Load ResNet-50 (ImageNet weights when pretrained=True)
        resnet = models.resnet50(pretrained=pretrained)

        # Extract feature layers (drop avgpool and fc head)
        # Index mapping:
        # 0: conv1, 1: bn1, 2: relu, 3: maxpool
        # 4: layer1, 5: layer2, 6: layer3, 7: layer4
        self.features = nn.Sequential(*list(resnet.children())[:-2])

        # CBAM attention on final feature maps (2048 channels)
        self.cbam = CBAM(channels=2048, reduction=16)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Lightweight classifier (2048 -> 512 -> 128 -> num_classes)
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

        # Apply freezing of the early backbone stages
        self._freeze_early_layers()

    def _print_freeze_summary(self):
        """Print parameter counts and per-stage trainable status."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        frozen_params = total_params - trainable_params

        print(f"\nParameter Summary:")
        print(f" Total parameters: {total_params:,}")
        print(f" Frozen parameters: {frozen_params:,} ({100*frozen_params/total_params:.1f}%)")
        print(f" Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")

        print(f"\nLayer Status:")
        print(f" ❌ FROZEN: conv1 + bn1 (initial conv)")
        print(f" ❌ FROZEN: layer1 (3 blocks, 256 channels)")
        print(f" ❌ FROZEN: layer2 (4 blocks, 512 channels)")
        print(f" ✓ TRAINABLE: layer3 (6 blocks, 1024 channels)")
        print(f" ✓ TRAINABLE: layer4 (3 blocks, 2048 channels)")
        print(f" ✓ TRAINABLE: CBAM Attention")
        print(f" ✓ TRAINABLE: Classifier (2048→512→128→3)")

        # Rough FP32 size estimate: 4 bytes per parameter.
        model_size_mb = (total_params * 4) / (1024**2)
        print(f"\nEstimated Model Size: ~{model_size_mb:.1f} MB")
        print("="*70 + "\n")

    def _freeze_early_layers(self):
        """
        Freeze conv1/bn1 and layer1-layer2; train layer3-layer4.
        """
        print("\n" + "="*70)
        print("ResNet-50 Layer Freezing Strategy")
        print("="*70)

        # Freeze initial conv block
        for param in self.features[0].parameters():  # conv1
            param.requires_grad = False
        for param in self.features[1].parameters():  # bn1
            param.requires_grad = False

        # Freeze layer1 (early low-level features)
        for param in self.features[4].parameters():
            param.requires_grad = False

        # Freeze layer2 (mid-level features)
        for param in self.features[5].parameters():
            param.requires_grad = False

        # layer3 and layer4 remain trainable

        self._print_freeze_summary()

    def forward(self, x):
        x = self.features(x)
        x = self.cbam(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
296
+
297
+ # RAPID BILATERAL FILTER (section 3.2 of paper)
298
+ # ref: "Bilateral Filtering: Theory and Applications"
299
+ # By Sylvain Paris, Pierre Kornprobst, Jack Tumblin and Frédo Durand
300
+ # DOI: 10.1561/0600000020
301
+
302
def rapid_bilateral_filter(image, radius=5, sigma_color=150, sigma_space=8):
    """Rapid Bilateral Filter (section 3.2): edge-preserving smoothing that
    keeps salient structures while suppressing irrelevant detail.

    Accepts a PIL image or numpy array; returns a filtered numpy array.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    return cv2.bilateralFilter(pixels, radius, sigma_color, sigma_space)
312
+
313
+
314
+ # DATASET
315
+
316
class SurgicalToolDataset(Dataset):
    """Surgical-tool image dataset with optional Rapid Bilateral Filter
    preprocessing; rows are selected by the CSV's `validation_set` flag."""

    def __init__(self, img_dir, annotation_file, transform=None,
                 validation_set=False, use_bilateral_filter=True):
        gt = pd.read_csv(annotation_file)

        # Keep only the requested split (1 = validation, 0 = training).
        split_flag = 1 if validation_set else 0
        self.img_labels = gt[gt["validation_set"] == split_flag]

        self.img_dir = img_dir
        self.transform = transform
        self.use_bilateral_filter = use_bilateral_filter

        self.images = self.img_labels["file_name"].values
        self.labels = self.img_labels["category_id"].values

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.img_dir, self.images[idx])).convert('RGB')

        if self.use_bilateral_filter:
            image = Image.fromarray(rapid_bilateral_filter(image))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx]
352
+
353
+
354
+ # EARLY STOPPING
355
+
356
class EarlyStopping:
    """Signal a training stop after `patience` consecutive calls without a
    validation-loss improvement of at least `min_delta`."""

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None

    def __call__(self, val_loss):
        if self.best_loss is None:
            # First observation becomes the baseline.
            self.best_loss = val_loss
        elif val_loss <= self.best_loss - self.min_delta:
            # Meaningful improvement: remember it and reset the stall count.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
375
+
376
+
377
+ #TRAINING FUNCTIONS
378
+
379
def compute_class_weights(labels, num_classes):
    """Return inverse-frequency class weights (FloatTensor) for use with a
    weighted CrossEntropyLoss on imbalanced datasets."""
    counts = np.bincount(labels, minlength=num_classes)
    # total / (num_classes * count_c): rarer classes get larger weights.
    weights = torch.FloatTensor(len(labels) / (num_classes * counts))
    print(f"\nClass weights computed: {weights.numpy()}")
    return weights
387
+
388
+
389
def train_epoch(model, train_loader, criterion, optimizer, device, class_weights=None):
    """Run one optimization pass over train_loader.

    Returns (mean batch loss, accuracy in percent).
    """
    model.train()
    running_loss = 0.0
    correct, total = 0, 0

    # When class weights are supplied, they override the passed-in criterion.
    if class_weights is not None:
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    pbar = tqdm(train_loader, desc="Training", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip gradients to stabilize fine-tuning.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.data.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}',
                          'acc': f'{100.*correct/total:.2f}%'})

    return running_loss / len(train_loader), 100. * correct / total
422
+
423
+
424
def validate(model, val_loader, criterion, device):
    """Evaluate the model on val_loader without gradient tracking.

    Returns (mean batch loss, list of predicted labels, list of true labels).
    """
    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            running_loss += criterion(outputs, labels).item()

            all_predictions.extend(outputs.data.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return running_loss / len(val_loader), all_predictions, all_labels
447
+
448
+
449
def plot_confusion_matrix(labels, predictions, save_path):
    """Render the confusion matrix as a heatmap and save it to save_path."""
    cm = confusion_matrix(labels, predictions)
    class_names = [f'Class {i}' for i in range(len(cm))]

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Confusion matrix saved to {save_path}")
464
+
465
+
466
def plot_training_history(train_losses, val_losses, train_accs, val_accs, save_path):
    """Save side-by-side loss and accuracy curves for train/validation."""
    epochs = range(1, len(train_losses) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # (axis, train series, val series, train label, val label, y-label, title)
    panels = [
        (ax1, train_losses, val_losses, 'Train Loss', 'Val Loss',
         'Loss', 'Training and Validation Loss'),
        (ax2, train_accs, val_accs, 'Train Acc', 'Val Acc',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    ]
    for ax, train_series, val_series, train_lbl, val_lbl, ylab, title in panels:
        ax.plot(epochs, train_series, 'b-o', label=train_lbl, linewidth=2)
        ax.plot(epochs, val_series, 'r-s', label=val_lbl, linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylab, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Training history saved to {save_path}")
494
+
495
+
496
+ # MAIN TRAINING FUNCTION
497
+
498
def main():
    """Main training pipeline: data split, augmentation, model setup,
    training loop with early stopping, and final evaluation/plots.

    Returns the trained model (weights of the best epoch are reloaded
    from MODEL_SAVE_PATH before the final evaluation).
    """

    # Set seeds for reproducibility
    torch.manual_seed(543)
    np.random.seed(543)

    print("="*70)
    print("CVGGNet-16 ULTRA-OPTIMIZED Training")
    print("Strategy: Ultra-Aggressive Freezing + Tiny Classifier")
    print("="*70)
    print(f"Device: {DEVICE}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Max epochs: {MAX_EPOCHS} (REDUCED to prevent overfitting)")
    print(f"Learning rate: {LEARNING_RATE}")
    print(f"Weight decay: {WEIGHT_DECAY} (INCREASED for regularization)")
    print(f"Bilateral filter: {USE_BILATERAL_FILTER}")
    print(f"Early stopping: {USE_EARLY_STOPPING} (patience={EARLY_STOP_PATIENCE})")
    print("="*70 + "\n")

    # DATA PREPARATION

    # Create validation split once and persist it back into the GT CSV,
    # so subsequent runs reuse the same split.
    df = pd.read_csv(PATH_TO_TRAIN_GT)
    if "validation_set" not in df.columns:
        df["validation_set"] = 0
        val_indices = df.sample(frac=VAL_FRACTION, random_state=42).index
        df.loc[val_indices, "validation_set"] = 1
        df.to_csv(PATH_TO_TRAIN_GT, index=False)
        print(f"✓ Created validation split ({VAL_FRACTION*100:.0f}%)\n")

    # REDUCED data augmentation (previous setup was too aggressive)
    train_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        #transforms.AugMix(severity=2),
        # NOTE: ColorJitter removed - too aggressive for surgical images
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Validation uses no augmentation, only resize + normalize.
    val_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    train_dataset = SurgicalToolDataset(
        img_dir=PATH_TO_IMAGES,
        annotation_file=PATH_TO_TRAIN_GT,
        transform=train_transform,
        validation_set=False,
        use_bilateral_filter=USE_BILATERAL_FILTER
    )

    val_dataset = SurgicalToolDataset(
        img_dir=PATH_TO_IMAGES,
        annotation_file=PATH_TO_TRAIN_GT,
        transform=val_transform,
        validation_set=True,
        use_bilateral_filter=USE_BILATERAL_FILTER
    )

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=6, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=6, pin_memory=True)

    print(f"Dataset sizes:")
    print(f" Training: {len(train_dataset)} images")
    print(f" Validation: {len(val_dataset)} images")
    print(f" Batches per epoch: {len(train_loader)} (train), {len(val_loader)} (val)")

    # Compute class weights (optional; off by default)
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = compute_class_weights(train_dataset.labels, NUM_CLASSES)

    # MODEL SETUP

    print(f"\nCreating CVGGNet-Resnet Ultra-Optimized model...")
    model = CVGGNetResNet50(num_classes=NUM_CLASSES, pretrained=True).to(DEVICE)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()

    # Optimizer - only over the trainable (non-frozen) parameters
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    # Learning rate scheduler: halve LR after 2 epochs without val-loss drop
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )  # TODO(review): revisit this scheduler choice (see CosineAnnealingLR note above)

    # Early stopping
    early_stopping = None
    if USE_EARLY_STOPPING:
        early_stopping = EarlyStopping(patience=EARLY_STOP_PATIENCE, min_delta=0.001)

    # TRAINING LOOP

    best_val_loss = float('inf')
    best_val_acc = 0.0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    print("\n" + "="*70)
    print("Starting Training")
    print("="*70 + "\n")

    import time
    training_start_time = time.time()

    for epoch in range(MAX_EPOCHS):
        epoch_start_time = time.time()

        print(f"\nEpoch [{epoch+1}/{MAX_EPOCHS}]")
        print("-" * 70)

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, DEVICE, class_weights
        )

        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Validate
        val_loss, val_predictions, val_labels = validate(
            model, val_loader, criterion, DEVICE
        )

        val_acc = 100. * np.sum(np.array(val_predictions) == np.array(val_labels)) / len(val_labels)

        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        # Per-class classification report for this epoch
        print("\nValidation Metrics:")
        report = classification_report(val_labels, val_predictions,
                                       target_names=[f'Class {i}' for i in range(NUM_CLASSES)],
                                       digits=4)
        print(report)

        # Save history for the final plots
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        # Learning rate scheduling (driven by validation loss)
        scheduler.step(val_loss)

        # Save best model (by validation accuracy)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'train_loss': train_loss,
            }, MODEL_SAVE_PATH)
            print(f"\n✓ Best model saved! (Val Acc: {val_acc:.2f}%)")

        # Early stopping check (driven by validation loss)
        if early_stopping is not None:
            if early_stopping(val_loss):
                print(f"\n⚠️ Early stopping at epoch {epoch+1}")
                break

        epoch_time = time.time() - epoch_start_time
        print(f"\nEpoch time: {epoch_time/60:.2f} minutes")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    training_time = time.time() - training_start_time

    # FINAL EVALUATION

    print("\n" + "="*70)
    print("Training Complete!")
    print("="*70)
    print(f"Total training time: {training_time/60:.2f} minutes")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    print(f"Model saved to: {MODEL_SAVE_PATH}")

    # Check model size on disk
    model_size_bytes = os.path.getsize(MODEL_SAVE_PATH)
    model_size_mb = model_size_bytes / (1024**2)
    print(f"Model file size: {model_size_mb:.1f} MB")

    if model_size_mb > 500:
        print("⚠️ WARNING: Model still large (>500MB). Check classifier architecture.")
    else:
        print("✓ Model size is good for HuggingFace upload!")

    # Load best model for final evaluation
    checkpoint = torch.load(MODEL_SAVE_PATH)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final validation
    _, final_predictions, final_labels = validate(model, val_loader, criterion, DEVICE)

    # Plot confusion matrix
    cm_path = os.path.join(BASE_PATH, 'confusion_matrix_ultra_optimized.png')
    plot_confusion_matrix(final_labels, final_predictions, cm_path)

    # Plot training history
    history_path = os.path.join(BASE_PATH, 'training_history_ultra_optimized.png')
    plot_training_history(train_losses, val_losses, train_accs, val_accs, history_path)

    # Final metrics
    print("\n" + "="*70)
    print("Final Validation Metrics:")
    print("="*70)
    final_report = classification_report(final_labels, final_predictions,
                                         target_names=[f'Class {i}' for i in range(NUM_CLASSES)],
                                         digits=4)
    print(final_report)

    print(f"\n✓ All done! Results saved in {BASE_PATH}")
    print("="*70)

    return model
731
+
732
+
733
if __name__ == "__main__":
    # Entry point: run the full training pipeline; keep the trained model.
    model = main()
735
+