"""Fine-tune a pretrained timm classifier on a group-aware train/validation split.

Pipeline:
1. Read the ground-truth CSV and assign every image to a group derived from an
   8-digit token in its file name, so images sharing that token never end up
   on both sides of the split.
2. Write the split (a ``validation_set`` 0/1 column) to a new CSV.
3. Train with class-weighted, label-smoothed cross entropy and a one-cycle
   learning-rate schedule.
4. After every epoch, evaluate macro F1 on the validation set and keep the
   best-scoring weights on disk.
"""

import os
import re

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader
from tqdm import tqdm

from submission.utils.utils import ImageData

# --- CONFIGURATION ---
# NOTE: BASE_PATH is an absolute, machine-specific location; adjust it when
# running on another machine.
BASE_PATH = "/Users/yusufbardolia/Documents/Intelligent System In Medicine/phase_1a"
PATH_TO_IMAGES = os.path.join(BASE_PATH, "images")
PATH_TO_GT = os.path.join(BASE_PATH, "gt_for_classification_multiclass_from_filenames_0_index.csv")
PATH_TO_SPLIT_GT = os.path.join(os.getcwd(), "honest_split_gt.csv")
MODEL_SAVE_PATH = os.path.join("submission", "multiclass_model.pth")

# --- UPGRADES ---
MODEL_NAME = 'efficientnet_b3'  # Larger, more powerful model
IMAGE_SIZE = (300, 300)  # EfficientNet-B3 native resolution
MAX_EPOCHS = 15
BATCH_SIZE = 16  # Smaller batch for larger model
NUM_CLASSES = 3
LEARNING_RATE = 0.0003

# Prefer Apple's MPS backend (original behaviour), then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✅ Using Apple M-Series GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Compiled once instead of on every file name. Matches an 8-digit token that
# starts with "202" (presumably a YYYYMMDD surgery date -- confirm against the
# dataset's naming scheme).
_SURGERY_DATE_RE = re.compile(r'(202\d{5})')

# Channel statistics shared by the training and validation pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _surgery_group(file_name):
    """Return the grouping key for *file_name*, or ``"unknown"`` if none is found."""
    match = _SURGERY_DATE_RE.search(file_name)
    return match.group(1) if match else "unknown"


def create_honest_split():
    """Create the group-aware split CSV and the class weights for training.

    Reads ``PATH_TO_GT`` (columns ``file_name`` and ``category_id`` are
    required), takes the first fold of a 5-fold ``StratifiedGroupKFold`` as the
    validation set, and writes the frame with an added ``validation_set``
    column (1 = validation, 0 = training) to ``PATH_TO_SPLIT_GT``.

    Returns:
        A ``(csv_path, class_weights)`` tuple: the path of the written CSV and
        a float32 tensor of balanced class weights, computed on the training
        rows only and already moved to ``DEVICE``.
    """
    print("Creating honest, stratified data split...")
    df = pd.read_csv(PATH_TO_GT)

    # All files without a recognisable token share the single "unknown" group.
    groups = np.array([_surgery_group(name) for name in df["file_name"]])
    y = df["category_id"].values

    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    train_idx, val_idx = next(sgkf.split(df, y, groups=groups))

    # split() yields *positional* indices, so mark rows by position rather
    # than through label-based df.loc (which only agrees for a RangeIndex).
    is_validation = np.zeros(len(df), dtype=np.int64)
    is_validation[val_idx] = 1
    df["validation_set"] = is_validation
    df.to_csv(PATH_TO_SPLIT_GT, index=False)

    # NOTE(review): compute_class_weight is expected to raise ValueError when
    # a label present in the full dataset is absent from the training fold.
    classes = np.unique(y)
    weights = compute_class_weight(class_weight='balanced', classes=classes, y=y[train_idx])

    # Cast to float32 before moving to the device.
    return PATH_TO_SPLIT_GT, torch.tensor(weights, dtype=torch.float32).to(DEVICE)


def _build_transforms():
    """Return ``(train_transforms, val_transforms)`` for the image pipelines."""
    normalize = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

    # Transforms (Heavy Augmentation)
    train_transforms = transforms.Compose([
        transforms.Resize((320, 320)),  # Resize larger first
        transforms.RandomCrop(IMAGE_SIZE),  # Then random crop (better data aug)
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=45),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation is deterministic: resize straight to the model input size.
    val_transforms = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transforms, val_transforms


def _train_one_epoch(model, loader, criterion, optimizer, scheduler, epoch):
    """Run one optimisation pass over *loader*, stepping the scheduler per batch."""
    model.train()
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")
    for img, label in pbar:
        img, label = img.to(DEVICE), label.to(DEVICE)

        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        # OneCycleLR is configured with steps_per_epoch, so it advances once
        # per batch rather than once per epoch.
        scheduler.step()

        pbar.set_postfix({"Loss": f"{loss.item():.4f}"})


def _evaluate_macro_f1(model, loader):
    """Return the macro-averaged F1 score of *model* over *loader*."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for img, label in loader:
            img, label = img.to(DEVICE), label.to(DEVICE)
            output = model(img)
            preds = torch.argmax(output, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(label.cpu().numpy())

    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    return report['macro avg']['f1-score']


def main():
    """Build the split, train the model, and save the best checkpoint.

    The weights with the highest validation macro F1 seen so far are written
    to ``MODEL_SAVE_PATH`` (a ``state_dict``); nothing is returned.
    """
    split_csv_path, class_weights = create_honest_split()

    train_transforms, val_transforms = _build_transforms()

    # ImageData is a project-local dataset; the loops below rely on it
    # yielding (image, label) pairs.
    train_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=0, transform=train_transforms)
    val_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=1, transform=val_transforms)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Loading {MODEL_NAME}...")
    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        steps_per_epoch=len(train_loader),
        epochs=MAX_EPOCHS,
    )

    print("Starting training...")
    best_f1 = 0.0

    for epoch in range(MAX_EPOCHS):
        _train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch)

        # Validation
        curr_f1 = _evaluate_macro_f1(model, val_loader)
        print(f"Val F1: {curr_f1:.4f}")

        # Only a strict improvement overwrites the saved checkpoint.
        if curr_f1 > best_f1:
            best_f1 = curr_f1
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"🚀 Saved {MODEL_SAVE_PATH}")

    print(f"Done. Best F1: {best_f1:.4f}")


if __name__ == "__main__":
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("submission", exist_ok=True)
    main()