yusufbardolia committed on
Commit
9946dd2
·
verified ·
1 Parent(s): 55db5f1

Upload 9 files

Browse files
multiclass_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e9df2fa575a0d20999bffa87c6ed668352a8e7a3f141ebeeb66092280b19f6f
3
+ size 43349799
phase_1b_sample_solution_multiclass.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import os
3
+ import re
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import DataLoader
8
+ import timm
9
+ from sklearn.metrics import classification_report
10
+ from sklearn.model_selection import StratifiedGroupKFold
11
+ from sklearn.utils.class_weight import compute_class_weight
12
+ from submission.utils.utils import ImageData
13
+ import torchvision.transforms as transforms
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
# --- CONFIGURATION ---
# NOTE(review): machine-specific absolute path — will not exist on other
# machines/CI; consider an env var or CLI argument. TODO confirm intent.
BASE_PATH = "/Users/yusufbardolia/Documents/Intelligent System In Medicine/phase_1a"
PATH_TO_IMAGES = os.path.join(BASE_PATH, "images")
# Ground-truth CSV with 0-indexed multiclass labels derived from filenames.
PATH_TO_GT = os.path.join(BASE_PATH, "gt_for_classification_multiclass_from_filenames_0_index.csv")

# Split CSV is written next to the current working directory at run time.
PATH_TO_SPLIT_GT = os.path.join(os.getcwd(), "honest_split_gt.csv")
# Best checkpoint (by validation macro F1) is saved here.
MODEL_SAVE_PATH = os.path.join("submission", "multiclass_model.pth")

# --- UPGRADES ---
MODEL_NAME = 'efficientnet_b3'  # Larger, more powerful model (timm id)
IMAGE_SIZE = (300, 300)  # EfficientNet-B3 native resolution
MAX_EPOCHS = 15
BATCH_SIZE = 16  # Smaller batch for larger model
NUM_CLASSES = 3
LEARNING_RATE = 0.0003

# Prefer Apple-silicon GPU when present; otherwise fall back to CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print(f"✅ Using Apple M-Series GPU (MPS)")
else:
    DEVICE = "cpu"
38
+
39
def create_honest_split():
    """Build a leakage-free, stratified train/validation split.

    An 8-digit surgery date (``202xxxxx``) is parsed out of each file name
    and used as the grouping key, so every frame from the same surgery ends
    up on the same side of the split. The first fold of a 5-fold
    StratifiedGroupKFold marks validation rows (``validation_set`` = 1) and
    the augmented CSV is written to ``PATH_TO_SPLIT_GT``.

    Returns:
        tuple: (path to the split CSV, balanced class-weight tensor computed
        on the training rows only, float32, moved to ``DEVICE``).
    """
    print("Creating honest, stratified data split...")
    frame = pd.read_csv(PATH_TO_GT)

    def _surgery_date(fname):
        # Group key: the embedded YYYYMMDD token; "unknown" when absent.
        found = re.search(r'(202\d{5})', fname)
        return found.group(1) if found else "unknown"

    groups = np.array([_surgery_date(name) for name in frame["file_name"]])
    labels = frame["category_id"].values

    splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    train_rows, val_rows = next(splitter.split(frame, labels, groups=groups))

    # Persist the split so training and any re-run see identical folds.
    frame["validation_set"] = 0
    frame.loc[val_rows, "validation_set"] = 1
    frame.to_csv(PATH_TO_SPLIT_GT, index=False)

    # Class weights from the TRAINING portion only — no validation leakage.
    weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(labels),
        y=labels[train_rows],
    )
    return PATH_TO_SPLIT_GT, torch.tensor(weights, dtype=torch.float32).to(DEVICE)
61
+
62
def main():
    """Train the multiclass classifier and save the best checkpoint.

    Pipeline: honest grouped split -> augmented/plain datasets -> timm
    EfficientNet-B3 -> AdamW + OneCycleLR training with weighted,
    label-smoothed cross-entropy -> per-epoch validation by macro F1,
    checkpointing only on improvement to MODEL_SAVE_PATH.
    """
    split_csv_path, class_weights = create_honest_split()

    # 2. Transforms (Heavy Augmentation)
    train_transforms = transforms.Compose([
        transforms.Resize((320, 320)),  # Resize larger first
        transforms.RandomCrop(IMAGE_SIZE),  # Then random crop (better data aug)
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=45),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        # ImageNet statistics — required to match the pretrained weights.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Validation: deterministic resize only, same normalization.
    val_transforms = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # validation_set flag selects the CSV rows each dataset sees.
    train_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=0, transform=train_transforms)
    val_dataset = ImageData(PATH_TO_IMAGES, split_csv_path, validation_set=1, transform=val_transforms)

    # num_workers=0: single-process loading (safest with MPS on macOS).
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Loading {MODEL_NAME}...")
    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    # Class weights counter class imbalance; label smoothing regularizes.
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # OneCycleLR is stepped per BATCH, hence steps_per_epoch * epochs total.
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE, steps_per_epoch=len(train_loader), epochs=MAX_EPOCHS)

    print(f"Starting training...")
    best_f1 = 0.0

    for epoch in range(MAX_EPOCHS):
        model.train()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for img, label in pbar:
            img, label = img.to(DEVICE), label.to(DEVICE)

            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            scheduler.step()  # per-batch step, required by OneCycleLR

            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

        # Validation
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for img, label in val_loader:
                img, label = img.to(DEVICE), label.to(DEVICE)
                output = model(img)
                preds = torch.argmax(output, dim=1).cpu().numpy()
                all_preds.extend(preds)
                all_labels.extend(label.cpu().numpy())

        # Macro F1 treats all classes equally, matching the weighted loss intent.
        report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
        curr_f1 = report['macro avg']['f1-score']

        print(f"Val F1: {curr_f1:.4f}")

        # Checkpoint only on improvement — best weights survive later epochs.
        if curr_f1 > best_f1:
            best_f1 = curr_f1
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"🚀 Saved {MODEL_SAVE_PATH}")

    print(f"Done. Best F1: {best_f1:.4f}")
140
+
141
if __name__ == "__main__":
    # Ensure the checkpoint directory exists before training starts.
    # exist_ok=True is atomic in intent: it removes the check-then-create
    # race of `os.path.exists(...)` followed by `os.makedirs(...)`.
    os.makedirs("submission", exist_ok=True)
    main()
script.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pandas as pd
4
+ import timm
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from scipy.stats import mode
8
+ import torch.nn.functional as F
9
+
10
# CONFIG MUST MATCH TRAINING
MODEL_NAME = 'efficientnet_b3'  # timm architecture id used at train time
IMAGE_SIZE = (300, 300)  # EfficientNet-B3 native input resolution
NUM_CLASSES = 3
# Prefer Apple-silicon GPU when available, else CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
15
+
16
def apply_temporal_smoothing(predictions, window_size=5):
    """Majority-vote smoothing over a sliding window of frame predictions.

    Each position is replaced by the mode of the ``window_size``-wide window
    centred on it; the window is truncated at the sequence boundaries. Ties
    resolve to the smallest label (scipy.stats.mode semantics). The input
    sequence is not mutated — a smoothed copy is returned.
    """
    half = window_size // 2
    total = len(predictions)
    smoothed = predictions.copy()
    for idx in range(total):
        lo = max(0, idx - half)
        hi = min(total, idx + half + 1)
        smoothed[idx] = mode(predictions[lo:hi], keepdims=False)[0]
    return smoothed
+
26
def run_inference(TEST_IMAGE_PATH, model, SUBMISSION_CSV_SAVE_PATH):
    """Predict a class per test image and write a submission CSV.

    Per image: horizontal-flip test-time augmentation (average of softmax
    probabilities over original + flipped), then argmax; the per-image
    predictions are temporally smoothed before writing ``file_name`` /
    ``category_id`` columns to SUBMISSION_CSV_SAVE_PATH.

    Args:
        TEST_IMAGE_PATH: Directory containing the test images.
        model: Classifier already on DEVICE; set to eval mode here.
        SUBMISSION_CSV_SAVE_PATH: Output CSV path.
    """
    model.eval()
    test_images = os.listdir(TEST_IMAGE_PATH)
    # Lexicographic sort — assumes filenames sort in temporal order, which
    # the smoothing below relies on. TODO confirm naming scheme guarantees this.
    test_images.sort()

    # Must mirror validation preprocessing used at training time.
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    raw_predictions = []
    print(f"Inference with TTA on {len(test_images)} images...")

    with torch.no_grad():
        for img_name in test_images:
            img_path = os.path.join(TEST_IMAGE_PATH, img_name)
            try:
                # Load Original
                img_pil = Image.open(img_path).convert("RGB")
                img_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)

                # Load Flipped (TTA)
                img_flip = transform(img_pil.transpose(Image.FLIP_LEFT_RIGHT)).unsqueeze(0).to(DEVICE)

                # Predict both
                out1 = model(img_tensor)
                out2 = model(img_flip)

                # Average probabilities (softmax first, then mean)
                avg_out = (F.softmax(out1, dim=1) + F.softmax(out2, dim=1)) / 2

                pred = torch.argmax(avg_out, dim=1).item()
                raw_predictions.append(pred)
            except Exception as e:
                # Best-effort: an unreadable file gets class 0 rather than
                # aborting the whole submission (keeps row count aligned).
                print(f"Error {img_name}: {e}")
                raw_predictions.append(0)

    # Exploit temporal coherence of consecutive frames to remove flicker.
    final_predictions = apply_temporal_smoothing(raw_predictions, window_size=5)

    df = pd.DataFrame({"file_name": test_images, "category_id": final_predictions})
    df.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
    print(f"Saved to {SUBMISSION_CSV_SAVE_PATH}")
69
+
70
if __name__ == "__main__":
    # Resolve model/output paths relative to this script so it works
    # regardless of the current working directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # NOTE(review): presumably the grader-mounted test-data path — confirm.
    TEST_PATH = "/tmp/data/test_images"
    MODEL_PATH = os.path.join(current_dir, "multiclass_model.pth")
    SUB_PATH = os.path.join(current_dir, "submission.csv")

    # Rebuild the architecture locally (no weight download) and load the
    # trained state dict onto the selected device.
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — acceptable for our own checkpoint, but consider
    # weights_only=True on torch versions that support it.
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model = model.to(DEVICE)

    run_inference(TEST_PATH, model, SUB_PATH)
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (198 Bytes). View file
 
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (171 Bytes). View file
 
utils/__pycache__/utils.cpython-314.pyc ADDED
Binary file (2.79 kB). View file
 
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.27 kB). View file
 
utils/utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+
7
class ImageData(Dataset):
    """Image-classification dataset driven by a split-annotated CSV.

    The annotation CSV must provide ``file_name``, ``category_id`` and a
    ``validation_set`` column (0 = training row, 1 = validation row); only
    rows matching the requested split are kept.
    """

    def __init__(self, img_dir, annotation_file, validation_set, transform=None):
        """
        Args:
            img_dir: Directory containing the image files.
            annotation_file: Path to the split CSV described above.
            validation_set: Truthy selects validation rows (== 1),
                falsy selects training rows (== 0).
            transform: Optional callable applied to each PIL image.
        """
        self.img_dir = img_dir
        self.transform = transform

        try:
            gt = pd.read_csv(annotation_file)
        except Exception as e:
            # Best-effort fallback: report the problem and expose an empty
            # dataset instead of crashing the caller during construction.
            print(f"Error reading CSV {annotation_file}: {e}")
            self.img_labels = pd.DataFrame()
            self.images = []
            self.labels = []
            return

        # One filter covers both branches: 1 -> validation, 0 -> training.
        wanted = 1 if validation_set else 0
        self.img_labels = gt[gt["validation_set"] == wanted].reset_index(drop=True)

        # Cache filenames and labels as arrays for fast indexed access.
        self.images = self.img_labels["file_name"].values
        self.labels = self.img_labels["category_id"].values

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load one sample as (transformed image, int label)."""
        img_path = os.path.join(self.img_dir, self.images[idx])

        # Force RGB so grayscale/RGBA files still yield 3 channels.
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Plain int so CrossEntropyLoss targets collate to a LongTensor.
        return image, int(self.labels[idx])