phase1b-resnet / script.py
annaferrari02's picture
Upload 3 files
fc1291b verified
"""
Inference script for CVGGNet-ResNet50
Compatible with ResNet-50 + CBAM architecture
"""
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import pandas as pd
import numpy as np
import cv2
from tqdm import tqdm
# ==================== CBAM MODULES (must match training) ====================
class ChannelAttention(nn.Module):
    """Channel attention gate from CBAM.

    Squeezes the spatial dimensions with both average- and max-pooling, sends
    each pooled descriptor through a shared 1x1-conv bottleneck, sums the two
    results and squashes them with a sigmoid to obtain one weight per channel.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck ratio; the hidden width is ``channels // reduction``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Attribute names and layer order are fixed: they define the state_dict
        # keys (fc.0.weight, fc.2.weight) that saved checkpoints are loaded into.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """For an (N, C, H, W) input, return channel weights of shape (N, C, 1, 1)."""
        from_avg = self.fc(self.avg_pool(x))
        from_max = self.fc(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)
class SpatialAttention(nn.Module):
    """Spatial attention gate from CBAM.

    Collapses the channel axis into two maps (channel-wise mean and max),
    stacks them, and convolves the pair into one sigmoid-activated weight map.

    Args:
        kernel_size: Size of the 2->1 channel convolution; padding is
            ``kernel_size // 2``, so odd sizes preserve H x W.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """For an (N, C, H, W) input, return an (N, 1, H', W') weight map in (0, 1)."""
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat([mean_map, max_map], dim=1)
        return self.sigmoid(self.conv(stacked))
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by a spatial gate.

    Args:
        channels: Channel count of the feature map being refined.
        reduction: Bottleneck ratio forwarded to ``ChannelAttention``.
        kernel_size: Convolution size forwarded to ``SpatialAttention``.
    """

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        # Submodule names are part of the state_dict keys; keep them stable.
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Reweight *x* per channel, then per spatial location, and return the result."""
        channel_refined = x * self.channel_attention(x)
        return channel_refined * self.spatial_attention(channel_refined)
# ==================== MODEL ARCHITECTURE ====================
class CVGGNetResNet50(nn.Module):
    """CVGGNet with ResNet-50 backbone + CBAM attention.

    Pipeline: ResNet-50 convolutional trunk (all children except avgpool and
    fc) -> CBAM on the trunk's output feature map -> global average pooling ->
    three-layer MLP head producing ``num_classes`` logits.

    Args:
        num_classes: Number of output classes.
        pretrained: If True, initialise the trunk with torchvision's ImageNet
            weights (``IMAGENET1K_V1``, the weights the legacy
            ``pretrained=True`` flag selected). Leave False when all weights
            come from a checkpoint.
    """

    def __init__(self, num_classes=3, pretrained=False):
        super(CVGGNetResNet50, self).__init__()
        # torchvision >= 0.13 replaced ``pretrained=`` with ``weights=`` and
        # deprecates the old flag; fall back to it only on older releases.
        if hasattr(models, "ResNet50_Weights"):
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
        else:
            resnet = models.resnet50(pretrained=pretrained)

        feature_dim = resnet.fc.in_features  # 2048 for ResNet-50

        # Extract feature layers (drop the final avgpool and fc children).
        # Attribute names below define the state_dict keys of saved checkpoints.
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        # CBAM attention on the trunk's output feature map
        self.cbam = CBAM(channels=feature_dim, reduction=16)
        # Pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Classifier head; layer sizes must match the checkpoint being loaded.
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map an image batch to class logits of shape (batch, num_classes)."""
        x = self.features(x)
        x = self.cbam(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
# ==================== BILATERAL FILTER ====================
def rapid_bilateral_filter(image, radius=5, sigma_color=150, sigma_space=8):
    """Apply OpenCV's edge-preserving bilateral filter to an image.

    The default parameters are intended to match the training preprocessing
    (NOTE(review): not verifiable from this file).

    Args:
        image: A PIL image, converted to a NumPy array first, or an array that
            ``cv2.bilateralFilter`` accepts directly.
        radius: Passed as OpenCV's ``d`` argument. Despite the name, OpenCV
            interprets ``d`` as the neighbourhood *diameter*.
        sigma_color: Filter sigma in the colour space.
        sigma_space: Filter sigma in the coordinate space.

    Returns:
        The filtered image as a NumPy array.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    return cv2.bilateralFilter(pixels, radius, sigma_color, sigma_space)
# ==================== INFERENCE FUNCTION ====================
def run_inference(test_images_path, model, image_size, submission_csv_path,
                  use_bilateral_filter=True, device='cpu'):
    """
    Run inference on every image file in a directory and save a submission CSV.

    Args:
        test_images_path: Path to the test images directory. Only regular files
            are processed; sub-directories are skipped.
        model: Trained model. Its output is assumed to be scores of shape
            (batch, num_classes); the prediction is the argmax over dim 1.
        image_size: Input image size (single int for square images).
        submission_csv_path: Path to save the predictions CSV.
        use_bilateral_filter: Whether to apply bilateral filter preprocessing.
        device: Device to run inference on ('cpu' or 'cuda').

    Returns:
        pandas.DataFrame with columns ``file_name`` and ``category_id``, one
        row per processed image, in sorted filename order.
    """
    model.eval()
    model = model.to(device)

    # os.listdir also returns sub-directories, which Image.open cannot read.
    test_images = sorted(
        name for name in os.listdir(test_images_path)
        if os.path.isfile(os.path.join(test_images_path, name))
    )

    # Preprocessing transform (ImageNet normalisation; should match training)
    test_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    predictions = []
    num_classes = None  # taken from the model's output width
    print(f"Running inference on {len(test_images)} images...")
    with torch.no_grad():
        for image_name in tqdm(test_images):
            img_path = os.path.join(test_images_path, image_name)
            # The context manager closes the file handle right away;
            # convert() returns an independent, fully loaded copy.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if use_bilateral_filter:
                image = Image.fromarray(rapid_bilateral_filter(image))

            img_tensor = test_transform(image).unsqueeze(0).to(device)
            output = model(img_tensor)
            num_classes = output.shape[1]
            predictions.append(torch.argmax(output, dim=1).item())

    # Create submission DataFrame
    df_predictions = pd.DataFrame({
        'file_name': test_images,
        'category_id': predictions
    })
    # Save to CSV
    df_predictions.to_csv(submission_csv_path, index=False)
    print(f"\n✓ Predictions saved to: {submission_csv_path}")

    # Display prediction distribution
    print("\nPrediction Distribution:")
    total = len(df_predictions)
    if total == 0:
        # Avoids dividing by zero when the directory held no image files.
        print(" (no images found)")
    else:
        counts = df_predictions['category_id'].value_counts()
        for class_id in range(num_classes):
            count = int(counts.get(class_id, 0))
            percentage = 100 * count / total
            print(f" Class {class_id}: {count} images ({percentage:.1f}%)")
    return df_predictions
# ==================== MAIN SCRIPT ====================
def _extract_state_dict(checkpoint):
    """
    Return the model weights contained in a loaded checkpoint.

    Accepts a training checkpoint dict that holds the weights under
    ``'model_state_dict'`` or ``'state_dict'``, or a bare state dict. If every
    key carries the ``'module.'`` prefix added by ``nn.DataParallel``-style
    wrappers, the prefix is stripped so the weights load into the unwrapped model.
    """
    for key in ('model_state_dict', 'state_dict'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break
    else:
        state_dict = checkpoint

    prefix = 'module.'
    if state_dict and all(name.startswith(prefix) for name in state_dict):
        state_dict = {name[len(prefix):]: tensor for name, tensor in state_dict.items()}
    return state_dict


if __name__ == "__main__":
    # Paths
    current_directory = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"  # HuggingFace standard path
    MODEL_WEIGHTS_PATH = os.path.join(current_directory, "cvggnet_optimized_small.pth")
    SUBMISSION_CSV_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    # Configuration (MUST MATCH TRAINING)
    NUM_CLASSES = 3
    IMAGE_SIZE = 224  # ResNet standard input size
    USE_BILATERAL_FILTER = True  # Match your training setting
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    print("="*60)
    print("CVGGNet-ResNet50 Inference")
    print("="*60)
    print(f"Device: {DEVICE}")
    print(f"Model weights: {MODEL_WEIGHTS_PATH}")
    print(f"Test images: {TEST_IMAGE_PATH}")
    print(f"Output: {SUBMISSION_CSV_SAVE_PATH}")
    print(f"Bilateral filter: {USE_BILATERAL_FILTER}")
    print("="*60 + "\n")

    # Load model
    print("Loading ResNet-50 model...")
    model = CVGGNetResNet50(num_classes=NUM_CLASSES, pretrained=False)

    # Load weights.
    # NOTE: torch.load is pickle-based; only load checkpoints from a trusted
    # source (here, the weights file shipped next to this script).
    checkpoint = torch.load(MODEL_WEIGHTS_PATH, map_location=torch.device(DEVICE))

    # Handle different checkpoint formats (training dict, 'state_dict' dict,
    # bare state dict, optionally with a DataParallel 'module.' prefix)
    model.load_state_dict(_extract_state_dict(checkpoint))
    if 'model_state_dict' in checkpoint:
        print(f"✓ Model loaded from epoch {checkpoint.get('epoch', 'unknown')}")
        if 'val_acc' in checkpoint:
            print(f" Validation accuracy: {checkpoint['val_acc']:.2f}%")
    else:
        print("✓ Model weights loaded")

    # Check model size
    model_size_bytes = os.path.getsize(MODEL_WEIGHTS_PATH)
    model_size_mb = model_size_bytes / (1024**2)
    print(f" Model size: {model_size_mb:.1f} MB\n")

    # Run inference
    predictions_df = run_inference(
        test_images_path=TEST_IMAGE_PATH,
        model=model,
        image_size=IMAGE_SIZE,
        submission_csv_path=SUBMISSION_CSV_SAVE_PATH,
        use_bilateral_filter=USE_BILATERAL_FILTER,
        device=DEVICE
    )
    print("\n✓ Inference complete!")