| import pandas as pd |
| import os |
| import re |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| import timm |
| from sklearn.metrics import classification_report |
| from sklearn.model_selection import StratifiedGroupKFold |
| from sklearn.utils.class_weight import compute_class_weight |
| from submission.utils.utils import ImageData |
| import torchvision.transforms as transforms |
| import numpy as np |
| from tqdm import tqdm |
|
|
| |
# --- Paths --------------------------------------------------------------
BASE_PATH = "/Users/yusufbardolia/Documents/Intelligent System In Medicine/phase_1a"
PATH_TO_IMAGES = os.path.join(BASE_PATH, "images")
PATH_TO_GT = os.path.join(BASE_PATH, "gt_for_classification_multiclass_from_filenames_0_index.csv")

# CSV annotated with the train/validation assignment written by create_honest_split().
PATH_TO_SPLIT_GT = os.path.join(os.getcwd(), "honest_split_gt.csv")
MODEL_SAVE_PATH = os.path.join("submission", "multiclass_model.pth")

# --- Hyperparameters ----------------------------------------------------
MODEL_NAME = 'efficientnet_b3'
IMAGE_SIZE = (300, 300)   # final model input size (train crops to this after a 320x320 resize)
MAX_EPOCHS = 15
BATCH_SIZE = 16
NUM_CLASSES = 3
LEARNING_RATE = 0.0003

# Prefer the Apple-silicon GPU (MPS backend) when available; otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✅ Using Apple M-Series GPU (MPS)")  # plain string: no placeholders, no f-prefix needed
else:
    DEVICE = "cpu"
|
|
def create_honest_split():
    """Create a leakage-aware, stratified train/validation split.

    Rows are grouped by an 8-digit date-like token (``202xxxxx``) extracted
    from each file name — presumably the surgery date, so frames from the
    same surgery stay in the same fold (TODO confirm filename convention).
    One fold of a StratifiedGroupKFold becomes the validation set, marked
    via a ``validation_set`` column in the CSV written to PATH_TO_SPLIT_GT.

    Returns:
        tuple: (path to the split CSV, per-class 'balanced' loss weights as
        a float32 tensor on DEVICE, computed from the *training* rows only
        so the validation split cannot influence the loss).
    """
    print("Creating honest, stratified data split...")
    df = pd.read_csv(PATH_TO_GT)

    # Group key per row: first 202xxxxx token in the file name, or "unknown".
    date_pattern = re.compile(r'(202\d{5})')
    surgery_dates = [
        m.group(1) if (m := date_pattern.search(fname)) else "unknown"
        for fname in df["file_name"]
    ]

    groups = np.array(surgery_dates)
    y = df["category_id"].values

    # Fixed seed keeps the split reproducible across runs.
    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    train_idx, val_idx = next(sgkf.split(df, y, groups=groups))

    df["validation_set"] = 0
    df.loc[val_idx, "validation_set"] = 1
    df.to_csv(PATH_TO_SPLIT_GT, index=False)

    # Class weights from training rows only (classes enumerated over all rows
    # so the weight vector has one entry per known class).
    classes = np.unique(y)
    weights = compute_class_weight(class_weight='balanced', classes=classes, y=y[train_idx])
    return PATH_TO_SPLIT_GT, torch.tensor(weights, dtype=torch.float32).to(DEVICE)
|
|
def _train_one_epoch(model, loader, criterion, optimizer, scheduler, epoch):
    """Run one optimisation pass over *loader*, stepping the LR scheduler per batch."""
    model.train()
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")
    for img, label in pbar:
        img, label = img.to(DEVICE), label.to(DEVICE)

        optimizer.zero_grad()
        loss = criterion(model(img), label)
        loss.backward()
        optimizer.step()
        # OneCycleLR is configured with steps_per_epoch, so it advances every batch.
        scheduler.step()

        pbar.set_postfix({"Loss": f"{loss.item():.4f}"})


def _evaluate(model, loader):
    """Return the macro-averaged F1 score of *model* over *loader*."""
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for img, label in loader:
            img = img.to(DEVICE)
            preds = torch.argmax(model(img), dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(label.numpy())  # labels never need to leave the CPU

    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    return report['macro avg']['f1-score']


def main():
    """Train a multiclass EfficientNet classifier and keep the best checkpoint.

    Uses the split CSV and class weights from create_honest_split(); after
    each epoch the model is evaluated on the validation fold and saved to
    MODEL_SAVE_PATH whenever the macro-F1 improves.
    """
    split_csv_path, class_weights = create_honest_split()

    # ImageNet statistics, matching the pretrained timm backbone.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training: random crop from a slightly larger resize, plus flip /
    # rotation / colour-jitter augmentation.
    train_transforms = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.RandomCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=45),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation: deterministic resize only — no augmentation.
    val_transforms = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=0, transform=train_transforms)
    val_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=1, transform=val_transforms)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Loading {MODEL_NAME}...")
    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES).to(DEVICE)

    # Class-weighted loss with label smoothing to counter class imbalance.
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LEARNING_RATE, steps_per_epoch=len(train_loader), epochs=MAX_EPOCHS
    )

    print("Starting training...")
    best_f1 = 0.0

    for epoch in range(MAX_EPOCHS):
        _train_one_epoch(model, train_loader, criterion, optimizer, scheduler, epoch)

        curr_f1 = _evaluate(model, val_loader)
        print(f"Val F1: {curr_f1:.4f}")

        # Checkpoint only on improvement, so the saved weights are the best epoch's.
        if curr_f1 > best_f1:
            best_f1 = curr_f1
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"🚀 Saved {MODEL_SAVE_PATH}")

    print(f"Done. Best F1: {best_f1:.4f}")
|
|
if __name__ == "__main__":
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs("submission", exist_ok=True)
    main()