Upload 9 files
Browse files- multiclass_model.pth +3 -0
- phase_1b_sample_solution_multiclass.py +143 -0
- script.py +80 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-314.pyc +0 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/utils.cpython-314.pyc +0 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/utils.py +56 -0
multiclass_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e9df2fa575a0d20999bffa87c6ed668352a8e7a3f141ebeeb66092280b19f6f
|
| 3 |
+
size 43349799
|
phase_1b_sample_solution_multiclass.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
import timm
|
| 9 |
+
from sklearn.metrics import classification_report
|
| 10 |
+
from sklearn.model_selection import StratifiedGroupKFold
|
| 11 |
+
from sklearn.utils.class_weight import compute_class_weight
|
| 12 |
+
from submission.utils.utils import ImageData
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
# --- CONFIGURATION ---
# Absolute paths to the phase-1a image folder and the multiclass ground truth.
BASE_PATH = "/Users/yusufbardolia/Documents/Intelligent System In Medicine/phase_1a"
PATH_TO_IMAGES = os.path.join(BASE_PATH, "images")
PATH_TO_GT = os.path.join(BASE_PATH, "gt_for_classification_multiclass_from_filenames_0_index.csv")

# Split CSV is written next to the working directory; model goes in ./submission.
PATH_TO_SPLIT_GT = os.path.join(os.getcwd(), "honest_split_gt.csv")
MODEL_SAVE_PATH = os.path.join("submission", "multiclass_model.pth")

# --- UPGRADES ---
MODEL_NAME = 'efficientnet_b3'  # Larger, more powerful model
IMAGE_SIZE = (300, 300)  # EfficientNet-B3 native resolution
MAX_EPOCHS = 15
BATCH_SIZE = 16  # Smaller batch for larger model
NUM_CLASSES = 3
LEARNING_RATE = 0.0003

# Prefer the Apple-silicon GPU backend when available; otherwise fall back to CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
    # Plain string literal: the original used an f-string with no placeholders.
    print("✅ Using Apple M-Series GPU (MPS)")
else:
    DEVICE = "cpu"
|
| 38 |
+
|
| 39 |
+
def create_honest_split():
    """Build a leakage-free train/validation split grouped by surgery date.

    Reads the ground-truth CSV, derives a surgery-date group for every file
    name, and uses StratifiedGroupKFold so frames from the same surgery never
    land in both splits. Persists the annotated CSV and returns its path
    together with balanced class weights computed on the training fold only.

    Returns:
        tuple[str, torch.Tensor]: (path to the split CSV, class-weight tensor
        on DEVICE, dtype float32).
    """
    print("Creating honest, stratified data split...")
    df = pd.read_csv(PATH_TO_GT)

    # Group key: the 8-digit date (202xxxxx) embedded in each file name;
    # files without one collapse into a single "unknown" group.
    date_pattern = re.compile(r'(202\d{5})')
    matches = (date_pattern.search(name) for name in df["file_name"])
    groups = np.array([m.group(1) if m else "unknown" for m in matches])

    y = df["category_id"].values

    splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    train_idx, val_idx = next(splitter.split(df, y, groups=groups))

    # Mark validation rows and persist the annotated split.
    df["validation_set"] = 0
    df.loc[val_idx, "validation_set"] = 1
    df.to_csv(PATH_TO_SPLIT_GT, index=False)

    # Balanced weights from the TRAINING fold only (no validation leakage).
    weights = compute_class_weight(
        class_weight='balanced', classes=np.unique(y), y=y[train_idx]
    )
    return PATH_TO_SPLIT_GT, torch.tensor(weights, dtype=torch.float32).to(DEVICE)
|
| 61 |
+
|
| 62 |
+
def main() -> None:
    """Train an EfficientNet-B3 multiclass classifier on the honest split.

    Checkpoints the model to MODEL_SAVE_PATH whenever the validation
    macro-F1 improves; prints the best score at the end.
    """
    split_csv_path, class_weights = create_honest_split()

    # 2. Transforms (Heavy Augmentation)
    # NOTE(review): rotation up to 45° and colour jitter are strong
    # augmentations — presumably tuned for this surgical-image data; confirm.
    train_transforms = transforms.Compose([
        transforms.Resize((320, 320)), # Resize larger first
        transforms.RandomCrop(IMAGE_SIZE), # Then random crop (better data aug)
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=45),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        # ImageNet normalisation statistics (matches the pretrained backbone).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Validation: deterministic resize + normalise only (no augmentation).
    val_transforms = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=0, transform=train_transforms)
    val_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=1, transform=val_transforms)

    # num_workers=0: single-process loading (safest on macOS/MPS).
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Loading {MODEL_NAME}...")
    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    # Class-weighted loss counters imbalance; label smoothing regularises.
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # OneCycleLR steps once per BATCH (see scheduler.step() inside the loop).
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE, steps_per_epoch=len(train_loader), epochs=MAX_EPOCHS)

    print(f"Starting training...")
    best_f1 = 0.0

    for epoch in range(MAX_EPOCHS):
        model.train()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for img, label in pbar:
            img, label = img.to(DEVICE), label.to(DEVICE)

            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            # Per-batch LR step, as OneCycleLR requires.
            scheduler.step()

            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

        # Validation
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for img, label in val_loader:
                img, label = img.to(DEVICE), label.to(DEVICE)
                output = model(img)
                preds = torch.argmax(output, dim=1).cpu().numpy()
                all_preds.extend(preds)
                all_labels.extend(label.cpu().numpy())

        # zero_division=0: silence warnings for classes absent from a fold.
        report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
        curr_f1 = report['macro avg']['f1-score']

        print(f"Val F1: {curr_f1:.4f}")

        # Keep only the best checkpoint by macro-F1.
        if curr_f1 > best_f1:
            best_f1 = curr_f1
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"🚀 Saved {MODEL_SAVE_PATH}")

    print(f"Done. Best F1: {best_f1:.4f}")
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` and is a no-op when
    # the directory is already there.
    os.makedirs("submission", exist_ok=True)
    main()
|
script.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import timm
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
from scipy.stats import mode
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
# CONFIG MUST MATCH TRAINING
MODEL_NAME = 'efficientnet_b3'  # timm architecture used at train time
IMAGE_SIZE = (300, 300)  # EfficientNet-B3 native input resolution
NUM_CLASSES = 3
# Prefer the Apple-silicon GPU backend when present, otherwise run on CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 15 |
+
|
| 16 |
+
def apply_temporal_smoothing(predictions, window_size=5):
    """Majority-vote smoothing over a sliding window of per-frame predictions.

    Each position is replaced by the modal class of the window centred on it
    (truncated at the sequence boundaries). Votes are always taken from the
    original, unsmoothed predictions; ties resolve to the smallest class id
    (scipy.stats.mode behaviour).
    """
    half = window_size // 2
    n = len(predictions)
    smoothed = predictions.copy()
    for idx in range(n):
        lo = max(0, idx - half)
        hi = min(n, idx + half + 1)
        smoothed[idx] = mode(predictions[lo:hi], keepdims=False)[0]
    return smoothed
|
| 25 |
+
|
| 26 |
+
def run_inference(TEST_IMAGE_PATH, model, SUBMISSION_CSV_SAVE_PATH):
    """Predict a class for every image in TEST_IMAGE_PATH and write a CSV.

    Uses horizontal-flip test-time augmentation (average of softmax
    probabilities), then temporal majority-vote smoothing over the sorted
    file order. Images that fail to load are assigned class 0.

    NOTE(review): smoothing assumes lexicographic file-name order matches
    temporal frame order — confirm for the test-set naming scheme.
    """
    model.eval()
    test_images = os.listdir(TEST_IMAGE_PATH)
    test_images.sort()

    # Deterministic preprocessing, identical to validation at train time.
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    raw_predictions = []
    print(f"Inference with TTA on {len(test_images)} images...")

    with torch.no_grad():
        for img_name in test_images:
            img_path = os.path.join(TEST_IMAGE_PATH, img_name)
            try:
                # Load Original
                img_pil = Image.open(img_path).convert("RGB")
                img_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)

                # Load Flipped (TTA)
                img_flip = transform(img_pil.transpose(Image.FLIP_LEFT_RIGHT)).unsqueeze(0).to(DEVICE)

                # Predict both
                out1 = model(img_tensor)
                out2 = model(img_flip)

                # Average probabilities
                avg_out = (F.softmax(out1, dim=1) + F.softmax(out2, dim=1)) / 2

                pred = torch.argmax(avg_out, dim=1).item()
                raw_predictions.append(pred)
            except Exception as e:
                # Best-effort: unreadable files default to class 0 so the
                # submission stays aligned with the file list.
                print(f"Error {img_name}: {e}")
                raw_predictions.append(0)

    final_predictions = apply_temporal_smoothing(raw_predictions, window_size=5)

    df = pd.DataFrame({"file_name": test_images, "category_id": final_predictions})
    df.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
    print(f"Saved to {SUBMISSION_CSV_SAVE_PATH}")
|
| 69 |
+
|
| 70 |
+
if __name__ == "__main__":
    # Resolve all paths relative to this script so it works from any CWD.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    TEST_PATH = "/tmp/data/test_images"
    MODEL_PATH = os.path.join(current_dir, "multiclass_model.pth")
    SUB_PATH = os.path.join(current_dir, "submission.csv")

    # Rebuild the architecture (no pretrained download) and load local weights.
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    # weights_only=True: the checkpoint is a plain state_dict (saved via
    # torch.save(model.state_dict(), ...)), so refuse to unpickle arbitrary
    # objects — protects against a tampered checkpoint file.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
    model = model.to(DEVICE)

    run_inference(TEST_PATH, model, SUB_PATH)
|
utils/__init__.py
ADDED
|
File without changes
|
utils/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
utils/__pycache__/utils.cpython-314.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
utils/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (1.27 kB). View file
|
|
|
utils/utils.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
class ImageData(Dataset):
    """Image dataset backed by a split-annotated CSV.

    The annotation CSV must provide 'file_name', 'category_id' and
    'validation_set' columns. Rows are filtered to one split:
    validation_set == 0 selects training rows, anything truthy selects
    validation rows.
    """

    def __init__(self, img_dir, annotation_file, validation_set, transform=None):
        self.img_dir = img_dir
        self.transform = transform

        try:
            gt = pd.read_csv(annotation_file)
        except Exception as e:
            # Degrade to an empty dataset rather than crashing during init.
            print(f"Error reading CSV {annotation_file}: {e}")
            self.img_labels = pd.DataFrame()
            self.images = []
            self.labels = []
            return

        # Keep only the rows belonging to the requested split.
        wanted = 1 if validation_set else 0
        self.img_labels = gt[gt["validation_set"] == wanted].reset_index(drop=True)

        # Cache filenames and labels as arrays for fast indexed access.
        self.images = self.img_labels["file_name"].values
        self.labels = self.img_labels["category_id"].values

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        # Force RGB so grayscale/RGBA inputs come out with three channels.
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Plain int label, as expected by CrossEntropyLoss.
        return image, int(self.labels[idx])
|