erukude commited on
Commit
26d4801
·
verified ·
1 Parent(s): 7f22fff

CornViT - A Multi-stage CVT Framework

Browse files
README.md CHANGED
@@ -1,3 +1,97 @@
1
  ---
2
- license: apache-2.0
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: en
3
+ license: mit
4
+ tags:
5
+ - keras
6
+ - tensorflow
7
+ - computer-vision
8
+ - image-processing
9
+ - corn-kernel-classification
10
+ pipeline_tag: image-classification
11
+ library_name: keras
12
  ---
13
+
14
+ # CornViT
15
+
16
+ A Multi-Stage Convolutional Vision Transformer Framework for Corn Kernel Analysis
17
+
18
+ ## Overview
19
+
20
+ Three-stage hierarchical classification pipeline for automated corn kernel quality assessment:
21
+
22
+ - **Stage 1**: Purity detection (Pure vs Impure)
23
+ - **Stage 2**: Shape classification (Flat vs Round)
24
+ - **Stage 3**: Embryo orientation (Up vs Down)
25
+
26
+ ## Architecture
27
+
28
+ - **Model**: CvT-13 (384×384) with ImageNet-22k pretraining
29
+ - **Framework**: PyTorch + Microsoft CvT
30
+ - **Test Accuracy**: 93.8% (Stage 1), 94.1% (Stage 2), 91.1% (Stage 3)
31
+
32
+ ## Setup
33
+
34
+ ```bash
35
+ # Clone the Microsoft CvT repository (the stage scripts import its model code)
36
+ git clone https://github.com/microsoft/CvT.git
37
+
38
+ # Install dependencies
39
+ pip install -r requirements.txt
40
+ ```
41
+
42
+ ## Training
43
+
44
+ Each stage has independent training scripts:
45
+
46
+ ```bash
47
+ python stage1/train_cvt13.py # Purity classification
48
+ python stage2/train_cvt13.py # Shape classification
49
+ python stage3/train_cvt13.py # Embryo orientation
50
+ ```
51
+
52
+ ## Inference
53
+
54
+ ```bash
55
+ python stage1/inference_cvt13.py
56
+ python stage2/inference_cvt13.py
57
+ python stage3/inference_cvt13.py
58
+ ```
59
+
60
+ ## Baselines
61
+
62
+ ResNet50 and DenseNet121 baselines available in `baselines/`.
63
+
64
+ ## Structure
65
+
66
+ ```
67
+ ├── stage1/ # Purity classification
68
+ ├── stage2/ # Shape classification
69
+ ├── stage3/ # Embryo orientation
70
+ └── preprocess/ # Data preprocessing scripts
71
+ ```
72
+
73
+ ## Requirements
74
+
75
+ - Python 3.13+
76
+ - PyTorch 2.9+
77
+ - CUDA (optional, for GPU training)
78
+
79
+ ---
80
+
81
+ ## Citation
82
+ If you use this code, models, or catalog in your research, please cite:
83
+
84
+ ```bibtex
85
+ @Article{computers15010002,
86
+ AUTHOR = {Erukude, Sai Teja and Mascarenhas, Jane and Shamir, Lior},
87
+ TITLE = {CornViT: A Multi-Stage Convolutional Vision Transformer Framework for Hierarchical Corn Kernel Analysis},
88
+ JOURNAL = {Computers},
89
+ VOLUME = {15},
90
+ YEAR = {2026},
91
+ NUMBER = {1},
92
+ ARTICLE-NUMBER = {2},
93
+ URL = {https://www.mdpi.com/2073-431X/15/1/2},
94
+ ISSN = {2073-431X},
95
+ DOI = {10.3390/computers15010002}
96
+ }
97
+ ```
config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "cvt",
3
+ "architectures": [
4
+ "CvTForImageClassification"
5
+ ],
6
+ "paper": {
7
+ "title": "CornViT: A Multi-Stage Convolutional Vision Transformer Framework for Hierarchical Corn Kernel Analysis",
8
+ "year": 2025,
9
+ "year": 2026,
10
+ },
11
+ "pipeline_tag": "image-classification",
12
+ "library_name": "keras",
13
+ "framework": "pytorch",
14
+ "image_size": 384,
15
+ "num_channels": 3,
16
+ "pretrained": "imagenet-22k",
17
+ "backbone": "cvt-13",
18
+ "hierarchical_pipeline": true,
19
+ "stages": [
20
+ {
21
+ "stage": 1,
22
+ "name": "purity_detection",
23
+ "labels": {
24
+ "0": "pure",
25
+ "1": "impure"
26
+ },
27
+ "num_labels": 2,
28
+ "accuracy": 0.938
29
+ },
30
+ {
31
+ "stage": 2,
32
+ "name": "shape_classification",
33
+ "labels": {
34
+ "0": "flat",
35
+ "1": "round"
36
+ },
37
+ "num_labels": 2,
38
+ "accuracy": 0.941
39
+ },
40
+ {
41
+ "stage": 3,
42
+ "name": "embryo_orientation",
43
+ "labels": {
44
+ "0": "up",
45
+ "1": "down"
46
+ },
47
+ "num_labels": 2,
48
+ "accuracy": 0.911
49
+ }
50
+ ],
51
+ "training_framework": "pytorch"
52
+ }
graphical_abstract.png ADDED
preprocess/delete.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from tqdm import tqdm
4
+
5
+
6
+ ######################
7
+ #
8
+ ######################
9
def delete_files(dir1: str, dir2: str) -> bool:
    """
    Desc:
        Deletes every regular file in dir1 whose name also appears in dir2.
        Only dir1 is modified; dir2 is used purely as a list of names.
    Args:
        dir1 (str): Path to the directory files are deleted from.
        dir2 (str): Path to the directory whose entry names select what is deleted.
    Returns:
        True, if the pass completed. False, if either path is not a directory,
        both paths point at the same directory, or an error occurred.
    """
    try:
        if not os.path.isdir(dir1) or not os.path.isdir(dir2):
            print(f"Both '{dir1}' and '{dir2}' must be existing directories.")
            return False

        # Comparing a directory with itself would match, and delete, every file in it.
        if os.path.samefile(dir1, dir2):
            print(f"'{dir1}' and '{dir2}' are the same directory. Nothing deleted.")
            return False

        dir2_names = set(os.listdir(dir2))

        for idx, file in enumerate(os.listdir(dir1)):
            print(f"Processing file {idx}...")

            file_path = os.path.join(dir1, file)
            # Delete the file if it is present in dir2 (sub-directories are left alone)
            if file in dir2_names and os.path.isfile(file_path):
                os.remove(file_path)

        return True

    except Exception as delete_ex:
        print(f"Deletion error: {delete_ex}.")
        return False
41
+
42
+
43
+ ######################
44
+ #
45
+ ######################
46
def delete_n_random_files(dir: str, n: int) -> bool:
    """
    Desc:
        Picks 'n' regular files at random from the given directory and removes them.
        Nothing is removed when the directory holds fewer than 'n' files.
    Args:
        dir (str): Path to the directory.
        n (int): How many randomly chosen files to delete.
    Returns:
        True, if all selected files were deleted, otherwise False.
    """
    try:
        if not os.path.isdir(dir):
            print(f"Directory '{dir}' does not exist.")
            return False

        # Keep regular files only; sub-directories are never candidates
        candidates = []
        for entry in os.listdir(dir):
            if os.path.isfile(os.path.join(dir, entry)):
                candidates.append(entry)

        available = len(candidates)
        if available < n:
            print(f"Cannot delete '{n}' files, directory only contains '{available}' files.")
            return False

        for victim in tqdm(random.sample(candidates, n)):
            os.remove(os.path.join(dir, victim))

        return True

    except Exception as delete_ex:
        print(f"Error occurred while deleting: {str(delete_ex)}.")
        return False
79
+
80
+
81
+ ######################
82
+ #
83
+ ######################
84
def delete_files_name_contains(dir: str, word: str) -> bool:
    """
    Desc:
        Deletes all files in a directory whose filenames contain a specific word (case-insensitive).
    Parameters:
        dir (str): The directory to search for files.
        word (str): Non-empty substring to search for in filenames.
    Returns:
        bool: True if deletion completes (even if no files matched). False if the
        directory does not exist, 'word' is empty, or an error occurred.
    """
    try:
        if not os.path.isdir(dir):
            print(f"Directory '{dir}' does not exist.")
            return False

        # An empty substring is contained in every filename, so it would wipe the directory.
        if not word:
            print("The word to match must be a non-empty string. Nothing deleted.")
            return False

        needle = word.lower()  # lower-case once instead of once per file
        all_files = [f for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))]

        for file in tqdm(all_files, desc="Deleting files"):
            if needle in file.lower():
                os.remove(os.path.join(dir, file))

        return True

    except Exception as delete_ex:
        print(f"Error occurred while deleting: {str(delete_ex)}.")
        return False
111
+
112
+
113
+ ######################
114
+ #
115
+ ######################
116
if __name__ == "__main__":

    # Placeholder paths: fill in both before running. With the empty strings
    # below, delete_files() fails its directory check and returns False.
    dir1 = ""
    dir2 = ""
    delete_files(dir1, dir2)
preprocess/move.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import shutil
4
+ from tqdm import tqdm
5
+
6
+
7
+ ######################
8
+ #
9
+ ######################
10
def move_n_random_files(dir1: str, dir2: str, n: int) -> bool:
    """
    Desc:
        Picks 'n' regular files at random from dir1 and moves them into dir2.
        dir2 is created when missing. Nothing is moved when dir1 holds fewer than 'n' files.
    Args:
        dir1 (str): Path to the source directory.
        dir2 (str): Path to the destination directory.
        n (int): How many randomly chosen files to move.
    Returns:
        True, if every selected file was moved, otherwise False.
    """
    try:
        # The source must exist; the destination is created on demand
        if not os.path.isdir(dir1):
            print(f"Directory '{dir1}' does not exist.")
            return False

        if not os.path.isdir(dir2):
            print(f"Directory '{dir2}' does not exist. Creating it.")
            os.makedirs(dir2)

        # Regular files only; sub-directories are never moved
        candidates = []
        for entry in os.listdir(dir1):
            if os.path.isfile(os.path.join(dir1, entry)):
                candidates.append(entry)

        available = len(candidates)
        if available < n:
            print(f"Cannot move '{n}' files, directory only contains '{available}' files.")
            return False

        for chosen in tqdm(random.sample(candidates, n)):
            shutil.move(os.path.join(dir1, chosen), os.path.join(dir2, chosen))

        return True

    except Exception as move_ex:
        print(f"Error occurred while moving files: {str(move_ex)}.")
        return False
52
+
53
+
54
def copy_n_random_files(dir1: str, dir2: str, n: int) -> bool:
    """
    Desc:
        Copies up to 'n' randomly chosen regular files from dir1 into dir2.
        When dir1 holds fewer than 'n' files, all of them are copied.
    Args:
        dir1 (str): Path to the source directory.
        dir2 (str): Path to the destination directory. Created if it doesn't exist.
        n (int): Number of files to randomly copy.
    Returns:
        bool: True if the operation was successful, False otherwise.
    """
    try:
        # The source must exist; the destination is created on demand
        if not os.path.isdir(dir1):
            print(f"Directory '{dir1}' does not exist.")
            return False

        if not os.path.isdir(dir2):
            print(f"Directory '{dir2}' does not exist. Creating it.")
            os.makedirs(dir2)

        # Regular files only; sub-directories are never copied
        candidates = []
        for entry in os.listdir(dir1):
            if os.path.isfile(os.path.join(dir1, entry)):
                candidates.append(entry)

        # Clamp the request to what is actually available
        if len(candidates) < n:
            n = len(candidates)

        print(f"Copying '{n}' files to '{dir2}'...")

        for chosen in tqdm(random.sample(candidates, n), desc="Copying files"):
            shutil.copy(os.path.join(dir1, chosen), os.path.join(dir2, chosen))

        return True

    except Exception as copy_ex:
        print(f"Error occurred while copying files: {str(copy_ex)}.")
        return False
95
+
96
+
97
+ ######################
98
+ #
99
+ ######################
100
def copy_n_unique_files(dir1: str, dir2: str, output_dir: str, n: int) -> bool:
    """
    Desc:
        Finds the regular files in dir1 whose names do not appear among the regular
        files of dir2, and copies up to 'n' randomly chosen ones into output_dir.
        The source files are copied, not moved.
    Args:
        dir1 (str): Path to directory 1 (source of the files).
        dir2 (str): Path to directory 2 (names found here are excluded).
        output_dir (str): Path to the destination directory. Created if it doesn't exist.
        n (int): The maximum number of random files to copy.
    Returns:
        True, if the operation was successful. False, if there were no unique
        files or an error occurred.
    """
    try:
        # List all files in dir1 and dir2; a set keeps the membership test O(1) per name
        dir1_files = [f for f in os.listdir(dir1) if os.path.isfile(os.path.join(dir1, f))]
        dir2_files = {f for f in os.listdir(dir2) if os.path.isfile(os.path.join(dir2, f))}

        # Filter out files that already exist in dir2
        unique_files = [f for f in dir1_files if f not in dir2_files]

        # If no unique files are found
        if not unique_files:
            print("No unique files to move.")
            return False

        # Same convention as the sibling helpers: create a missing destination
        if not os.path.isdir(output_dir):
            print(f"Directory '{output_dir}' does not exist. Creating it.")
            os.makedirs(output_dir)

        # Randomly select 'n' files to copy (never more than are available)
        files_to_copy = random.sample(unique_files, min(n, len(unique_files)))

        # Copy selected files to output_dir
        files_copied = 0
        for file in files_to_copy:
            shutil.copy(os.path.join(dir1, file), os.path.join(output_dir, file))
            files_copied += 1
            print(f"Copied: {file}")

        print(f"Total files copied: {files_copied}")
        return True

    except Exception as copy_ex:
        print(f"An error occurred while copying: {copy_ex}")
        return False
143
+
144
+
145
+ ######################
146
+ #
147
+ ######################
148
if __name__ == "__main__":
    # Placeholder values: set the source, destination and count before running.
    # With an empty dir1, move_n_random_files() reports that the directory
    # does not exist and returns False.
    dir1 = ""
    dir2 = ""
    n = 0

    move_n_random_files(dir1, dir2, n)
preprocess/rename.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+
4
+
5
def rename_files(input_dir: str, word: str, new_word: str = '') -> bool:
    """
    Desc:
        Renames every entry of input_dir whose name (without extension) contains
        {word}, replacing each occurrence of {word} with {new_word}.
        Entries whose new name is already taken by another entry are skipped.
    Args:
        input_dir (str): Path to the input directory
        word (str): A non-empty word to look for in the filename
        new_word (str): A new word that is used to replace
    Returns:
        True, if the renaming pass completed (individual failures are reported
        and skipped). False, if the directory does not exist, {word} is empty,
        or an unexpected error occurred.
    """
    try:
        if not os.path.isdir(input_dir):
            print(f"The directory {input_dir} does not exist.")
            return False

        # '' is contained in every name, and str.replace('', x) inserts x between
        # all characters, which would mangle every filename in the directory.
        if not word:
            print("The word to look for must be a non-empty string.")
            return False

        for file in tqdm(os.listdir(input_dir)):
            basename, ext = os.path.splitext(file)

            if word not in basename:
                continue

            new_name = f"{basename.replace(word, new_word)}{ext}"

            # Construct full paths for renaming
            old_file_path = os.path.join(input_dir, file)
            new_file_path = os.path.join(input_dir, new_name)

            # os.rename can silently replace an existing file; never clobber another entry.
            # samefile keeps case-only renames working on case-insensitive filesystems.
            if os.path.exists(new_file_path) and not os.path.samefile(old_file_path, new_file_path):
                print(f"Skipping '{file}': '{new_name}' already exists.")
                continue

            try:
                # Rename the file
                os.rename(old_file_path, new_file_path)
            except OSError as rename_ex:
                # Best effort: report the failure and carry on with the remaining entries
                print(f"Could not rename '{file}': {rename_ex}")

        return True

    except Exception as rename_ex:
        print(f"An error occurred while renaming: {rename_ex}")
        return False
42
+
43
+
44
+
45
+
46
if __name__ == "__main__":
    # Placeholder values: set the directory and the words before running.
    # With an empty inp_dir, rename_files() reports that the directory
    # does not exist and returns False.
    inp_dir = ""
    word = ""
    new_word = ""
    rename_files(inp_dir, word, new_word)
51
+
requirements.txt ADDED
Binary file (4.96 kB). View file
 
stage1/inference_cvt13.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import datasets, transforms
4
+ from torch.utils.data import DataLoader
5
+ import sys
6
+ import os
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from pathlib import Path
10
+ import json
11
+ from datetime import datetime
12
+ from sklearn.metrics import (
13
+ accuracy_score, precision_score, recall_score, f1_score,
14
+ confusion_matrix, classification_report, roc_curve, auc,
15
+ precision_recall_curve, average_precision_score, roc_auc_score
16
+ )
17
+ import seaborn as sns
18
+ import pandas as pd
19
+
20
+
21
+ # ============================================================
22
+ # CONFIGURATION
23
+ # ============================================================
24
+
25
# NOTE(review): looks like a placeholder; presumably replaced with the local CornViT path
BASE_DIR = "path_to_CornViT"

# Path to the Microsoft CvT repository
CVT_REPO_PATH = f"{BASE_DIR}/CvT"

# Model configuration
IMG_SIZE = 384
NUM_CLASSES = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Name of the run folder under metrics/ (the x's look like a timestamp placeholder to fill in)
RUN = "cvt13_run_2025xxxx_xxxxxx"

# Path to trained model (relative, so it resolves against the current working directory)
MODEL_PATH = f"metrics/{RUN}/train/best_model.pth"

# Test data folder (should have subfolders for each class like train/val structure)
TEST_DATA_DIR = f"{BASE_DIR}/stage1/data/test"

# Class names (update these to match your dataset)
# NOTE(review): torchvision's ImageFolder assigns label indices from the sorted
# sub-folder names; confirm that order matches the order of this list.
CLASS_NAMES = ["Pure", "Impure"]

# Output directory for evaluation results (within the same metrics folder)
# The directory is created as soon as this module is executed or imported.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
EVAL_OUTPUT_DIR = f"metrics/{RUN}/evals/eval_{timestamp}"
os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)
50
+
51
+
52
+ # ============================================================
53
+ # SETUP: Import CvT model
54
+ # ============================================================
55
+
56
# Fix torch._six compatibility
# This rewrites lib/models/cls_cvt.py inside the CvT checkout in place, so that
# its `torch._six` import is replaced by the standard-library equivalent.
# NOTE(review): nesting reconstructed from a flattened listing; confirm that the
# second replace and the write-back are meant to run only when the
# `torch._six` import is still present.
cls_cvt_path = os.path.join(CVT_REPO_PATH, "lib", "models", "cls_cvt.py")
if os.path.exists(cls_cvt_path):
    with open(cls_cvt_path, 'r', encoding='utf-8') as f:
        content = f.read()

    if "from torch._six import container_abcs" in content:
        content = content.replace(
            "from torch._six import container_abcs",
            "import collections.abc as container_abcs"
        )
        # Swap the identity comparison against a string literal for an equality test
        content = content.replace(
            "or pretrained_layers[0] is '*'",
            "or pretrained_layers[0] == '*'"
        )
        with open(cls_cvt_path, 'w', encoding='utf-8') as f:
            f.write(content)

# Must happen before the `lib.*` imports below so they resolve inside the CvT checkout
sys.path.insert(0, CVT_REPO_PATH)

import warnings
# Silence SyntaxWarning for the imports that follow
warnings.filterwarnings('ignore', category=SyntaxWarning)

from lib.models import cls_cvt
from lib.config import config, update_config
81
+
82
+
83
+ # ============================================================
84
+ # MODEL LOADING
85
+ # ============================================================
86
+
87
def load_model(model_path, config_path=None):
    """
    Build the CvT classifier and load trained weights into it.

    Args:
        model_path: Checkpoint file. Either a bare state dict or a dict that
            stores the weights under the 'model_state_dict' key.
        config_path: CvT YAML config. Defaults to the cvt-13-384x384 experiment
            file inside CVT_REPO_PATH.

    Returns:
        The model, moved to DEVICE and switched to eval mode.
    """
    default_yaml = os.path.join(CVT_REPO_PATH, "experiments", "imagenet", "cvt", "cvt-13-384x384.yaml")
    yaml_path = default_yaml if config_path is None else config_path

    # The shared config object is frozen again once it has been adjusted
    config.defrost()
    config.merge_from_file(yaml_path)
    config.MODEL.NUM_CLASSES = NUM_CLASSES
    config.MODEL.PRETRAINED = ''
    config.freeze()

    net = cls_cvt.get_cls_model(config)

    # Accept both checkpoint layouts
    checkpoint = torch.load(model_path, map_location=DEVICE)
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    net.load_state_dict(state_dict)

    net = net.to(DEVICE)
    net.eval()

    print(f"✅ Model loaded from: {model_path}")
    return net
115
+
116
+
117
+ # ============================================================
118
+ # DATA LOADING
119
+ # ============================================================
120
+
121
def get_test_dataloader(test_dir, batch_size=32):
    """
    Build the test dataset and an unshuffled DataLoader over it.

    Images are resized to IMG_SIZE x IMG_SIZE, converted to tensors and
    normalised with the mean/std constants given below.

    Returns:
        (test_loader, test_dataset)
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    preprocessing = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    dataset = datasets.ImageFolder(test_dir, transform=preprocessing)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    print(f"✅ Test dataset loaded: {len(dataset)} images")
    print(f" Classes: {dataset.classes}")
    return loader, dataset
137
+
138
+
139
+ # ============================================================
140
+ # EVALUATION FUNCTIONS
141
+ # ============================================================
142
+
143
def evaluate_model(model, test_loader, test_dataset):
    """
    Run the model over the test set one image at a time.

    `test_loader` is accepted but not used: samples are read from
    `test_dataset` by index so each prediction can be paired with its file path.

    Returns:
        all_preds: Predicted class labels
        all_labels: Ground truth labels
        all_probs: Predicted probabilities for all classes
        all_confidences: Confidence scores (probability of the predicted class)
        image_paths: List of image paths
    """
    model.eval()

    predictions, targets, prob_rows, top_scores, paths = [], [], [], [], []

    print("\n🔍 Running single-image inference on test set...")

    total_images = len(test_dataset)

    for idx in range(total_images):
        image, label = test_dataset[idx]
        img_path, _ = test_dataset.samples[idx]

        # Single image -> batch of one on the target device
        batch = image.unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            logits = model(batch)

            # Some models drop the batch axis for a single sample; restore it
            if logits.dim() == 1:
                logits = logits.unsqueeze(0)

            probabilities = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        predictions.append(predicted.item())
        targets.append(label)
        prob_rows.append(probabilities.cpu().numpy()[0])
        top_scores.append(confidence.item())
        paths.append(img_path)

        # Progress line every 50 images and after the last one
        processed = idx + 1
        if processed % 50 == 0 or processed == total_images:
            print(f" Processed {processed}/{total_images} images...")

    print(f"✅ Inference complete: {len(predictions)} predictions")

    return (np.array(predictions), np.array(targets), np.array(prob_rows),
            np.array(top_scores), paths)
201
+
202
+
203
+ # ============================================================
204
+ # METRICS CALCULATION
205
+ # ============================================================
206
+
207
def calculate_metrics(y_true, y_pred, y_probs):
    """
    Calculate all classification metrics.

    Args:
        y_true: Ground truth class indices.
        y_pred: Predicted class indices.
        y_probs: Per-class probabilities, indexed as [sample, class].

    Returns:
        dict with accuracy, macro/weighted precision, recall and F1, a
        'per_class' mapping keyed by CLASS_NAMES, and ROC-AUC figures
        ('roc_auc' + 'average_precision' for two classes, otherwise
        'roc_auc_ovr' + 'roc_auc_ovo').
    """

    metrics = {}

    # Basic metrics
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['precision_macro'] = precision_score(y_true, y_pred, average='macro', zero_division=0)
    metrics['precision_weighted'] = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    metrics['recall_macro'] = recall_score(y_true, y_pred, average='macro', zero_division=0)
    metrics['recall_weighted'] = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    metrics['f1_macro'] = f1_score(y_true, y_pred, average='macro', zero_division=0)
    metrics['f1_weighted'] = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    # Per-class metrics
    # Pin the label set so the result always has one entry per CLASS_NAMES item,
    # in that order, even when a class is absent from both y_true and y_pred.
    label_ids = list(range(len(CLASS_NAMES)))
    precision_per_class = precision_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)
    recall_per_class = recall_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)
    f1_per_class = f1_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)

    metrics['per_class'] = {}
    for i, class_name in enumerate(CLASS_NAMES):
        metrics['per_class'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i])
        }

    # ROC-AUC (for binary and multi-class)
    if NUM_CLASSES == 2:
        # Column 1 holds the probability of the class with index 1
        metrics['roc_auc'] = roc_auc_score(y_true, y_probs[:, 1])
        metrics['average_precision'] = average_precision_score(y_true, y_probs[:, 1])
    else:
        metrics['roc_auc_ovr'] = roc_auc_score(y_true, y_probs, multi_class='ovr', average='macro')
        metrics['roc_auc_ovo'] = roc_auc_score(y_true, y_probs, multi_class='ovo', average='macro')

    return metrics
243
+
244
+
245
+ # ============================================================
246
+ # VISUALIZATION FUNCTIONS
247
+ # ============================================================
248
+
249
def plot_confusion_matrix(y_true, y_pred, save_path):
    """
    Plot the confusion matrix (raw counts and row-normalised) and save it.

    Args:
        y_true: Ground truth class indices.
        y_pred: Predicted class indices.
        save_path: Output image path.

    Returns:
        The raw-count confusion matrix, with one row/column per CLASS_NAMES entry.
    """
    # Pin the label set so the matrix always lines up with the CLASS_NAMES tick labels
    label_ids = list(range(len(CLASS_NAMES)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Raw counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')

    # Normalized per true class; a class with no samples stays at 0 instead of NaN
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm.astype('float'), row_totals,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_totals != 0)
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Confusion matrix saved to: {save_path}")
    plt.close()

    return cm
278
+
279
+
280
def plot_roc_curve(y_true, y_probs, save_path):
    """
    Plot the ROC curve and save it to `save_path`.

    With two classes a single curve is drawn from the class-1 probabilities;
    otherwise one one-vs-rest curve is drawn per entry of CLASS_NAMES.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    if NUM_CLASSES != 2:
        # Multi-class (one-vs-rest)
        for class_idx, class_name in enumerate(CLASS_NAMES):
            is_class = (y_true == class_idx).astype(int)
            fpr, tpr, _ = roc_curve(is_class, y_probs[:, class_idx])
            roc_auc = auc(fpr, tpr)
            ax.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.3f})')
    else:
        # Binary classification
        fpr, tpr, _ = roc_curve(y_true, y_probs[:, 1])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, color='darkorange', lw=2,
                label=f'ROC curve (AUC = {roc_auc:.3f})')

    # Diagonal reference line, drawn after the curves in both cases
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14, fontweight='bold')
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 ROC curve saved to: {save_path}")
    plt.close()
315
+
316
+
317
def plot_precision_recall_curve(y_true, y_probs, save_path):
    """
    Plot the Precision-Recall curve and save it to `save_path`.

    With two classes a single curve is drawn from the class-1 probabilities;
    otherwise one curve is drawn per entry of CLASS_NAMES.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    if NUM_CLASSES != 2:
        # Multi-class
        for class_idx, class_name in enumerate(CLASS_NAMES):
            is_class = (y_true == class_idx).astype(int)
            scores = y_probs[:, class_idx]
            precision, recall, _ = precision_recall_curve(is_class, scores)
            avg_precision = average_precision_score(is_class, scores)
            ax.plot(recall, precision, lw=2,
                    label=f'{class_name} (AP = {avg_precision:.3f})')
    else:
        # Binary classification
        precision, recall, _ = precision_recall_curve(y_true, y_probs[:, 1])
        avg_precision = average_precision_score(y_true, y_probs[:, 1])
        ax.plot(recall, precision, color='darkorange', lw=2,
                label=f'PR curve (AP = {avg_precision:.3f})')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall', fontsize=12)
    ax.set_ylabel('Precision', fontsize=12)
    ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    ax.legend(loc="lower left", fontsize=10)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Precision-Recall curve saved to: {save_path}")
    plt.close()
350
+
351
+
352
def plot_class_distribution(y_true, y_pred, save_path):
    """
    Plot side-by-side bar charts of the true and predicted label counts
    and save the figure to `save_path`.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # (axis, labels to count, bar colour, panel title): left = truth, right = predictions
    panels = (
        (axes[0], y_true, 'steelblue', 'True Label Distribution'),
        (axes[1], y_pred, 'coral', 'Predicted Label Distribution'),
    )

    for axis, labels, colour, title in panels:
        counts = [np.sum(labels == class_idx) for class_idx in range(NUM_CLASSES)]
        axis.bar(CLASS_NAMES, counts, color=colour, alpha=0.7)
        axis.set_ylabel('Count', fontsize=12)
        axis.set_title(title, fontsize=14, fontweight='bold')
        axis.grid(axis='y', alpha=0.3)
        # Write each count just above its bar
        for pos, count in enumerate(counts):
            axis.text(pos, count + max(counts)*0.01, str(count),
                      ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Class distribution saved to: {save_path}")
    plt.close()
380
+
381
+
382
def plot_per_class_metrics(metrics, save_path):
    """
    Plot grouped bars of precision, recall and F1 for every class found in
    metrics['per_class'] and save the figure to `save_path`.
    """
    per_class = metrics['per_class']
    classes = list(per_class.keys())

    # Pull all three series out before any drawing happens
    precision_vals = [per_class[name]['precision'] for name in classes]
    recall_vals = [per_class[name]['recall'] for name in classes]
    f1_vals = [per_class[name]['f1_score'] for name in classes]

    x = np.arange(len(classes))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 7))

    # One bar group per metric: left of, on, and right of each class position
    bar_groups = [
        ax.bar(x - width, precision_vals, width, label='Precision', color='steelblue', alpha=0.8),
        ax.bar(x, recall_vals, width, label='Recall', color='coral', alpha=0.8),
        ax.bar(x + width, f1_vals, width, label='F1-Score', color='lightgreen', alpha=0.8),
    ]

    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(classes)
    ax.legend(fontsize=11)
    ax.set_ylim([0, 1.1])
    ax.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for group in bar_groups:
        for bar in group:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Per-class metrics saved to: {save_path}")
    plt.close()
421
+
422
+
423
def plot_confidence_distribution(y_true, y_pred, confidences, save_path):
    """
    Plot confidence score distribution for correct vs incorrect predictions.

    Draws a histogram (top) and a box plot (bottom) of `confidences`, split by
    whether the prediction matched the ground truth, and saves the figure.

    Args:
        y_true: Ground truth class indices (NumPy array).
        y_pred: Predicted class indices (NumPy array).
        confidences: Confidence score of each prediction (NumPy array).
        save_path: Output image path.
    """
    # Confidence scores are already extracted; split them by outcome once
    correct = (y_true == y_pred)
    right = confidences[correct]
    wrong = confidences[~correct]

    fig, axes = plt.subplots(2, 1, figsize=(12, 10))

    # Histogram
    axes[0].hist(right, bins=50, alpha=0.7, label='Correct',
                 color='green', edgecolor='black')
    axes[0].hist(wrong, bins=50, alpha=0.7, label='Incorrect',
                 color='red', edgecolor='black')
    axes[0].set_xlabel('Confidence Score', fontsize=12)
    axes[0].set_ylabel('Frequency', fontsize=12)
    axes[0].set_title('Confidence Distribution: Correct vs Incorrect Predictions',
                      fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=11)
    axes[0].grid(alpha=0.3)

    # Box plot
    # Tick labels are set separately: boxplot's `labels=` keyword is deprecated
    # in Matplotlib 3.9, while set_xticklabels works on old and new versions.
    box = axes[1].boxplot([right, wrong], patch_artist=True, showmeans=True)
    axes[1].set_xticklabels(['Correct', 'Incorrect'])
    box['boxes'][0].set_facecolor('lightgreen')
    box['boxes'][1].set_facecolor('lightcoral')
    axes[1].set_ylabel('Confidence Score', fontsize=12)
    axes[1].set_title('Confidence Score Box Plot', fontsize=14, fontweight='bold')
    axes[1].grid(axis='y', alpha=0.3)

    # Add statistics; an empty group has no mean, so its annotation is skipped
    if right.size > 0:
        correct_mean = np.mean(right)
        axes[1].text(1, correct_mean, f'μ={correct_mean:.3f}',
                     ha='right', va='center', fontweight='bold', fontsize=10)
    if wrong.size > 0:
        incorrect_mean = np.mean(wrong)
        axes[1].text(2, incorrect_mean, f'μ={incorrect_mean:.3f}',
                     ha='left', va='center', fontweight='bold', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Confidence distribution saved to: {save_path}")
    plt.close()
465
+
466
+
467
+ # ============================================================
468
+ # RESULTS SAVING
469
+ # ============================================================
470
+
471
def save_predictions_to_csv(image_paths, y_true, y_pred, y_probs, confidences, save_path):
    """Write one CSV row per image and print summary statistics.

    Args:
        image_paths: Path of every evaluated image.
        y_true: Ground-truth class indices (used to index ``CLASS_NAMES``).
        y_pred: Predicted class indices (used to index ``CLASS_NAMES``).
        y_probs: Per-image sequence of class probabilities, indexed in
            ``CLASS_NAMES`` order.
        confidences: Confidence score of each prediction.
        save_path: Destination CSV path.

    Returns:
        The ``pandas.DataFrame`` that was written. It always carries the full
        set of columns, even when no images were supplied.
    """
    rows = []
    for img_path, true_label, pred, probs, conf in zip(image_paths, y_true, y_pred, y_probs, confidences):
        row = {
            'image_path': img_path,
            'image_name': os.path.basename(img_path),
            'true_label': CLASS_NAMES[true_label],
            'true_label_idx': true_label,
            'predicted_label': CLASS_NAMES[pred],
            'predicted_label_idx': pred,
            'confidence': conf,
            'correct': pred == true_label
        }

        # Add probabilities for each class
        for i, class_name in enumerate(CLASS_NAMES):
            row[f'prob_{class_name}'] = probs[i]

        rows.append(row)

    # Explicit columns keep the header (and the 'correct' column callers
    # filter on) present even when there are no rows.
    columns = [
        'image_path', 'image_name', 'true_label', 'true_label_idx',
        'predicted_label', 'predicted_label_idx', 'confidence', 'correct',
    ] + [f'prob_{class_name}' for class_name in CLASS_NAMES]
    df = pd.DataFrame(rows, columns=columns)
    # Guarantee a boolean dtype so `~df['correct']` works on an empty frame.
    df['correct'] = df['correct'].astype(bool)
    df.to_csv(save_path, index=False)
    print(f"💾 Predictions saved to: {save_path}")

    # Print some statistics
    total = len(df)
    print(f"\n📊 Prediction Statistics:")
    print(f" Total images: {total}")
    if total == 0:
        # Nothing to summarise; percentages and means would be undefined.
        return df

    n_correct = df['correct'].sum()
    n_incorrect = (~df['correct']).sum()
    print(f" Correct predictions: {n_correct} ({n_correct/total*100:.2f}%)")
    print(f" Incorrect predictions: {n_incorrect} ({n_incorrect/total*100:.2f}%)")
    print(f" Average confidence: {df['confidence'].mean():.4f}")
    print(f" Confidence on correct: {df[df['correct']]['confidence'].mean():.4f}")
    if n_incorrect > 0:
        print(f" Confidence on incorrect: {df[~df['correct']]['confidence'].mean():.4f}")

    return df
507
+
508
+
509
def save_metrics_json(metrics, save_path):
    """Serialize the metrics mapping to an indented JSON file.

    NumPy scalars and arrays anywhere inside ``metrics`` are converted to
    their plain-Python equivalents. Any other unsupported type still raises
    ``TypeError`` rather than being stringified silently.

    Args:
        metrics: JSON-compatible mapping, optionally containing NumPy values.
        save_path: Destination file path.

    Raises:
        TypeError: If ``metrics`` contains a value that is neither JSON
            serializable nor a NumPy scalar/array.
    """
    def _to_builtin(value):
        # Called by json only for objects it cannot encode itself.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    with open(save_path, 'w') as f:
        json.dump(metrics, f, indent=4, default=_to_builtin)
    print(f"💾 Metrics saved to: {save_path}")
514
+
515
+
516
def generate_classification_report_file(y_true, y_pred, save_path):
    """Generate the sklearn classification report and write it to a text file.

    Args:
        y_true: Ground-truth class indices into ``CLASS_NAMES``.
        y_pred: Predicted class indices into ``CLASS_NAMES``.
        save_path: Destination text file path.
    """
    # Pass the label set explicitly. Without it sklearn infers the labels
    # from the data and rejects `target_names` when a class is absent from
    # both y_true and y_pred.
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(CLASS_NAMES))),
        target_names=CLASS_NAMES,
        digits=4,
    )

    rule = "=" * 60
    with open(save_path, 'w') as f:
        f.write(rule + "\n")
        f.write("CLASSIFICATION REPORT\n")
        f.write(rule + "\n\n")
        f.write(report)

    print(f"📄 Classification report saved to: {save_path}")
527
+
528
+
529
+ # ============================================================
530
+ # MAIN EVALUATION PIPELINE
531
+ # ============================================================
532
+
533
def main():
    """Run the evaluation pipeline end to end.

    Steps: load the trained model, run inference over the test set, compute
    metrics, render the plots, and write the CSV/JSON/text reports into
    ``EVAL_OUTPUT_DIR``. The closing summary lists only files that were
    actually written during this run.
    """
    rule = "=" * 60

    def out_path(filename):
        # Every artifact of this run lives directly under EVAL_OUTPUT_DIR.
        return os.path.join(EVAL_OUTPUT_DIR, filename)

    print("\n" + rule)
    print("CvT-13 MODEL EVALUATION PIPELINE")
    print("Single Image Prediction Mode")
    print(rule + "\n")

    # Load model
    print("📦 Loading model...")
    model = load_model(MODEL_PATH)

    # Load test data
    print("\n📂 Loading test data...")
    test_loader, test_dataset = get_test_dataloader(TEST_DATA_DIR, batch_size=1)

    # Run evaluation with single image predictions
    print("\n🔍 Evaluating model (single image predictions)...")
    y_pred, y_true, y_probs, confidences, image_paths = evaluate_model(model, test_loader, test_dataset)

    # Calculate metrics
    print("\n📊 Calculating metrics...")
    metrics = calculate_metrics(y_true, y_pred, y_probs)

    # Print key metrics
    print("\n" + rule)
    print("EVALUATION RESULTS")
    print(rule)
    print(f"Total Images Evaluated: {len(y_pred)}")
    print(f"Accuracy: {metrics['accuracy']*100:.2f}%")
    print(f"Precision (Macro): {metrics['precision_macro']*100:.2f}%")
    print(f"Recall (Macro): {metrics['recall_macro']*100:.2f}%")
    print(f"F1-Score (Macro): {metrics['f1_macro']*100:.2f}%")
    if 'roc_auc' in metrics:
        print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
    print("\nPer-Class Metrics:")
    for class_name, class_metrics in metrics['per_class'].items():
        print(f" {class_name}:")
        print(f" Precision: {class_metrics['precision']*100:.2f}%")
        print(f" Recall: {class_metrics['recall']*100:.2f}%")
        print(f" F1-Score: {class_metrics['f1_score']*100:.2f}%")
    print(rule)

    # Generate all visualizations
    print("\n📊 Generating visualizations...")
    plot_confusion_matrix(y_true, y_pred, out_path("confusion_matrix.png"))
    plot_roc_curve(y_true, y_probs, out_path("roc_curve.png"))
    plot_precision_recall_curve(y_true, y_probs, out_path("precision_recall_curve.png"))
    plot_class_distribution(y_true, y_pred, out_path("class_distribution.png"))
    plot_per_class_metrics(metrics, out_path("per_class_metrics.png"))
    plot_confidence_distribution(y_true, y_pred, confidences,
                                 out_path("confidence_distribution.png"))

    # Save results
    print("\n💾 Saving results...")
    df = save_predictions_to_csv(image_paths, y_true, y_pred, y_probs, confidences,
                                 out_path("predictions.csv"))
    save_metrics_json(metrics, out_path("metrics.json"))
    generate_classification_report_file(y_true, y_pred,
                                        out_path("classification_report.txt"))

    # Save misclassified images list (only written when there are any)
    misclassified = df[~df['correct']]
    wrote_misclassified = len(misclassified) > 0
    if wrote_misclassified:
        misclassified_path = out_path("misclassified_images.csv")
        misclassified.to_csv(misclassified_path, index=False)
        print(f"⚠️ Misclassified images saved to: {misclassified_path}")
        print(f" Total misclassified: {len(misclassified)}")

    # Save low confidence predictions (only written when there are any)
    low_conf_threshold = 0.7
    low_confidence = df[df['confidence'] < low_conf_threshold]
    wrote_low_confidence = len(low_confidence) > 0
    if wrote_low_confidence:
        low_conf_path = out_path("low_confidence_predictions.csv")
        low_confidence.to_csv(low_conf_path, index=False)
        print(f"⚠️ Low confidence predictions saved to: {low_conf_path}")
        print(f" Total with confidence < {low_conf_threshold}: {len(low_confidence)}")

    print("\n" + rule)
    print("✅ Evaluation complete!")
    print(f"📁 All results saved to: {EVAL_OUTPUT_DIR}")
    print(rule + "\n")

    print("Generated files:")
    print(" 📊 confusion_matrix.png - Confusion matrix visualization")
    print(" 📊 roc_curve.png - ROC curve")
    print(" 📊 precision_recall_curve.png - Precision-Recall curve")
    print(" 📊 class_distribution.png - Class distribution comparison")
    print(" 📊 per_class_metrics.png - Per-class performance")
    print(" 📊 confidence_distribution.png - Confidence analysis")
    print(" 💾 predictions.csv - Detailed predictions for each image")
    # The two CSVs below are optional; list them only if they were written.
    if wrote_misclassified:
        print(" 💾 misclassified_images.csv - List of incorrectly classified images")
    if wrote_low_confidence:
        print(" 💾 low_confidence_predictions.csv - Predictions with low confidence")
    print(" 💾 metrics.json - All metrics in JSON format")
    print(" 📄 classification_report.txt - Sklearn classification report")


if __name__ == '__main__':
    main()
stage1/stage1_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:445c5a5b94b86649cab12ef3c2fe4df9461f9879864c43d52a7cc9560204fcc3
3
+ size 78733538
stage1/train_cvt13.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ import sys
7
+ import os
8
+ from timm.loss import SoftTargetCrossEntropy
9
+ from timm.scheduler import CosineLRScheduler
10
+ from timm.utils import accuracy
11
+ import matplotlib.pyplot as plt
12
+ import json
13
+ from datetime import datetime
14
+
15
+
16
+ # ============================================================
17
+ # SETUP: Clone and import from Microsoft CvT repository
18
+ # ============================================================
19
# First, clone the Microsoft CvT repository:
#   git clone https://github.com/microsoft/CvT.git
#   cd CvT
#   pip install -r requirements.txt

BASE_DIR = "path_to_CornViT"

# Add the CvT repo to Python path
CVT_REPO_PATH = f"{BASE_DIR}/CvT"

if not os.path.exists(CVT_REPO_PATH):
    print(f"❌ CvT repository not found at {CVT_REPO_PATH}")
    print("Please clone it: git clone https://github.com/microsoft/CvT.git")
    sys.exit(1)

# Fix torch._six compatibility BEFORE importing
print("Applying compatibility fixes for newer PyTorch versions...")
cls_cvt_path = os.path.join(CVT_REPO_PATH, "lib", "models", "cls_cvt.py")

if not os.path.exists(cls_cvt_path):
    print(f"❌ Could not find cls_cvt.py at {cls_cvt_path}")
    sys.exit(1)

with open(cls_cvt_path, 'r', encoding='utf-8') as f:
    content = f.read()

# (old text, replacement) pairs. Each one is applied independently, so a
# partially patched file still receives whichever fix it is missing.
_CVT_SOURCE_PATCHES = (
    # Fix 1: Replace torch._six import
    ("from torch._six import container_abcs",
     "import collections.abc as container_abcs"),
    # Fix 2: Replace 'is' with '==' for string comparison
    ("or pretrained_layers[0] is '*'",
     "or pretrained_layers[0] == '*'"),
)
_patched_content = content
for _old_text, _new_text in _CVT_SOURCE_PATCHES:
    _patched_content = _patched_content.replace(_old_text, _new_text)

# Rewrite the file only when a patch actually changed it.
if _patched_content != content:
    content = _patched_content
    with open(cls_cvt_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("✅ Applied compatibility patches to cls_cvt.py")
else:
    print("✅ Compatibility patches already applied")

# Now import
sys.path.insert(0, CVT_REPO_PATH)

# Suppress the SyntaxWarning
import warnings
warnings.filterwarnings('ignore', category=SyntaxWarning)

from lib.models import cls_cvt
from lib.config import config, update_config
print("✅ Successfully imported Microsoft CvT models")
76
+
77
+
78
+ # ============================================================
79
+ # CONFIGURATION
80
+ # ============================================================
81
+
82
DATA_DIR = f"{BASE_DIR}/stage1/data"
BATCH_SIZE = 32
IMG_SIZE = 384
NUM_CLASSES = 2
NUM_EPOCHS = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PRETRAINED_PATH = f"{BASE_DIR}/CvT-13-384x384-IN-22k.pth"

# Create output directory for saving results.
# Training artifacts go into the "train" subfolder of the run directory,
# because the inference script loads `metrics/<run>/train/best_model.pth`
# and writes its own results under `metrics/<run>/evals/`.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = f"metrics/cvt13_run_{timestamp}/train"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Metrics will be saved to: {OUTPUT_DIR}")
95
+
96
+
97
+ # ============================================================
98
+ # DATASET & AUGMENTATION
99
+ # ============================================================
100
+
101
# Training pipeline: resize plus light geometric/colour augmentation,
# followed by ImageNet mean/std normalisation.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Validation pipeline: deterministic resize + normalisation only.
val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transforms)
val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=val_transforms)

# Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
_PIN_MEMORY = (DEVICE == "cuda")

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=_PIN_MEMORY,
    drop_last=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=_PIN_MEMORY,
    # Keep the final partial batch: every validation image must be scored,
    # and a validation set smaller than one batch must not become empty.
    drop_last=False
)
138
+
139
+
140
+ # ============================================================
141
+ # MODEL SETUP - Using Microsoft CvT Implementation
142
+ # ============================================================
143
+
144
# Load the CvT-13 config from the repository
cvt_config_path = os.path.join(CVT_REPO_PATH, "experiments", "imagenet", "cvt", "cvt-13-384x384.yaml")

if not os.path.exists(cvt_config_path):
    print(f"⚠️ Config file not found at {cvt_config_path}")
    print("Available configs:")
    config_dir = os.path.join(CVT_REPO_PATH, "experiments", "imagenet", "cvt")
    if os.path.exists(config_dir):
        for f in os.listdir(config_dir):
            if f.endswith('.yaml'):
                print(f" - {f}")
    sys.exit(1)

print(f"Loading config from: {cvt_config_path}")

# Load config directly using merge_from_file
config.defrost()
config.merge_from_file(cvt_config_path)

# Update the number of classes for our task
config.MODEL.NUM_CLASSES = NUM_CLASSES
config.MODEL.PRETRAINED = ''  # We'll load weights manually
config.freeze()

print("Creating CvT-13 model...")
# Create model using the official CvT architecture
model = cls_cvt.get_cls_model(config)
model = model.to(DEVICE)

# Load pretrained weights
if os.path.exists(PRETRAINED_PATH):
    print(f"Loading pretrained weights from {PRETRAINED_PATH}")
    try:
        checkpoint = torch.load(PRETRAINED_PATH, map_location=DEVICE)

        # Handle different checkpoint formats
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Remove a leading 'module.' prefix if present. Only the prefix is
        # stripped, so a key that merely contains "module." further in
        # (e.g. "...submodule.weight") is left intact.
        _prefix = "module."
        new_state_dict = {
            (k[len(_prefix):] if k.startswith(_prefix) else k): v
            for k, v in state_dict.items()
        }

        # Remove head layers from pretrained weights (they have different dimensions)
        filtered_state_dict = {k: v for k, v in new_state_dict.items() if 'head' not in k}

        # Load weights - strict=False will only load matching layers
        missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

        # Count how many weights were actually loaded. The model's key set is
        # built once rather than once per checkpoint key.
        model_keys = set(model.state_dict())
        loaded_keys = [k for k in filtered_state_dict if k in model_keys]
        print(f"✅ Loaded pretrained weights: {len(loaded_keys)} layers from backbone")
        print(f" Head layer initialized randomly for {NUM_CLASSES} classes")

        # Show what's missing (should only be head-related)
        other_missing = [k for k in missing_keys if 'head' not in k]

        if other_missing:
            print(f"⚠️ Warning - Missing non-head keys: {other_missing}")
        if unexpected_keys:
            print(f"⚠️ Unexpected keys: {unexpected_keys}")

    except Exception as e:
        # Deliberate best effort: report the failure in full and fall back
        # to the randomly initialised model instead of aborting.
        print(f"⚠️ Error loading pretrained weights: {e}")
        import traceback
        traceback.print_exc()
        print("Continuing with random initialization...")
else:
    print(f"⚠️ Pretrained weights not found at {PRETRAINED_PATH}")
    print("Training from scratch...")

# Freeze backbone - only train the head for faster training and less overfitting
print("Freezing backbone layers (keeping only head trainable)...")
for name, param in model.named_parameters():
    if "head" not in name:
        param.requires_grad = False

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
232
+
233
+
234
+ # ============================================================
235
+ # OPTIMIZER AND LOSS
236
+ # ============================================================
237
+
238
# Optimise only the parameters that are still trainable (requires_grad).
_trainable_parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.AdamW(_trainable_parameters, lr=1e-4, weight_decay=0.05)

# NOTE(review): timm's SoftTargetCrossEntropy is meant for per-class (soft)
# target distributions - confirm the training loop supplies targets in that
# form.
criterion = SoftTargetCrossEntropy()

# Cosine decay over the whole run, preceded by a 5-step warmup.
_schedule_options = dict(
    t_initial=NUM_EPOCHS,
    lr_min=1e-6,
    warmup_t=5,
    warmup_lr_init=1e-5,
)
lr_scheduler = CosineLRScheduler(optimizer, **_schedule_options)
253
+
254
+
255
+ # ============================================================
256
+ # TRAINING & VALIDATION LOOP
257
+ # ============================================================
258
+
259
def train_one_epoch(epoch, history):
    """Run one optimisation pass over ``train_loader``.

    Targets may be integer class indices (what ``ImageFolder`` yields) or
    per-class soft targets. Index targets are one-hot encoded before being
    passed to ``criterion``, and soft targets are reduced with ``argmax``
    for the accuracy computation.

    Args:
        epoch: Zero-based epoch index (used only for logging).
        history: Dict of lists; ``'train_loss'``, ``'train_acc'`` and
            ``'learning_rate'`` each receive one new entry.

    Returns:
        Tuple ``(avg_loss, avg_acc)`` averaged over the batches, with the
        accuracy expressed in percent.

    Raises:
        RuntimeError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss, total_acc = 0.0, 0.0

    for images, targets in train_loader:
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)

        if targets.dim() == 1:
            # Hard labels: build the (batch, classes) one-hot form that the
            # soft-target loss multiplies against the log-probabilities.
            hard_targets = targets
            soft_targets = nn.functional.one_hot(
                targets, num_classes=outputs.size(1)
            ).to(outputs.dtype)
        else:
            soft_targets = targets
            hard_targets = targets.argmax(dim=1)

        loss = criterion(outputs, soft_targets)
        loss.backward()
        optimizer.step()

        acc1 = accuracy(outputs, hard_targets, topk=(1,))[0]
        total_loss += loss.item()
        total_acc += acc1.item()

    num_batches = len(train_loader)
    if num_batches == 0:
        raise RuntimeError(
            "train_loader produced no batches; the training set is smaller "
            "than one full batch"
        )

    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(avg_loss)
    history['train_acc'].append(avg_acc)
    history['learning_rate'].append(current_lr)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Train Loss: {avg_loss:.4f} | Train Acc: {avg_acc:.2f}% | LR: {current_lr:.6f}")
    return avg_loss, avg_acc
285
+
286
+
287
def validate(epoch, history):
    """Evaluate the model on ``val_loader`` without gradient tracking.

    Loss and accuracy are averaged per sample (each batch is weighted by its
    size), so a smaller final batch does not skew the result.

    Args:
        epoch: Zero-based epoch index (used only for logging).
        history: Dict of lists; ``'val_loss'`` and ``'val_acc'`` each receive
            one new entry.

    Returns:
        Validation top-1 accuracy in percent.

    Raises:
        RuntimeError: If ``val_loader`` yields no samples.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()  # built once, not once per batch
    loss_sum, acc_sum, num_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            outputs = model(images)
            loss = loss_fn(outputs, targets)
            acc1 = accuracy(outputs, targets, topk=(1,))[0]

            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            acc_sum += acc1.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise RuntimeError("val_loader produced no samples to validate on")

    avg_loss = loss_sum / num_samples
    avg_acc = acc_sum / num_samples

    history['val_loss'].append(avg_loss)
    history['val_acc'].append(avg_acc)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Val Loss: {avg_loss:.4f} | Val Acc: {avg_acc:.2f}%")
    return avg_acc
309
+
310
+
311
def plot_training_history(history, save_path):
    """Plot loss, accuracy, learning rate and the train/val gap to one image.

    Args:
        history: Dict with equally long lists under ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'``, ``'val_acc'`` and
            ``'learning_rate'`` (one entry per completed epoch).
        save_path: Destination path of the rendered image.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    try:
        epochs = range(1, len(history['train_loss']) + 1)

        # Plot 1: Loss
        axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        axes[0, 0].set_xlabel('Epoch', fontsize=12)
        axes[0, 0].set_ylabel('Loss', fontsize=12)
        axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Plot 2: Accuracy
        axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
        axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
        axes[0, 1].set_xlabel('Epoch', fontsize=12)
        axes[0, 1].set_ylabel('Accuracy (%)', fontsize=12)
        axes[0, 1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Plot 3: Learning Rate
        axes[1, 0].plot(epochs, history['learning_rate'], 'g-', linewidth=2)
        axes[1, 0].set_xlabel('Epoch', fontsize=12)
        axes[1, 0].set_ylabel('Learning Rate', fontsize=12)
        axes[1, 0].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True, alpha=0.3)

        # Plot 4: Val Acc vs Train Acc (Overfitting check). The shaded band
        # between the two curves is the train/val accuracy gap.
        axes[1, 1].plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
        axes[1, 1].plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
        axes[1, 1].fill_between(epochs, history['val_acc'], history['train_acc'],
                                alpha=0.3, color='orange', label='Overfitting Gap')
        axes[1, 1].set_xlabel('Epoch', fontsize=12)
        axes[1, 1].set_ylabel('Accuracy (%)', fontsize=12)
        axes[1, 1].set_title('Overfitting Analysis', fontsize=14, fontweight='bold')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"📊 Training plots saved to: {save_path}")
    finally:
        # This function is called repeatedly during training; always release
        # the figure, even when rendering or saving fails.
        plt.close(fig)
359
+
360
+
361
def save_training_summary(history, best_acc, save_path):
    """Write the run configuration, final metrics and full history as JSON.

    Args:
        history: Dict of per-epoch metric lists; the last entries of
            ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
            ``'val_acc'`` are reported as final metrics.
        best_acc: Best validation accuracy reached during the run.
        save_path: Destination JSON file path.
    """
    run_config = {
        'model': 'CvT-13',
        'batch_size': BATCH_SIZE,
        'img_size': IMG_SIZE,
        'num_classes': NUM_CLASSES,
        'num_epochs': NUM_EPOCHS,
        'device': DEVICE,
        'pretrained': PRETRAINED_PATH,
    }

    # Final value of each tracked series, stored under a "final_" prefix.
    final_metrics = {'best_val_accuracy': best_acc}
    for series in ('train_loss', 'train_acc', 'val_loss', 'val_acc'):
        final_metrics[f'final_{series}'] = history[series][-1]

    summary = {
        'config': run_config,
        'final_metrics': final_metrics,
        'history': history,
    }

    with open(save_path, 'w') as handle:
        json.dump(summary, handle, indent=4)

    print(f"💾 Training summary saved to: {save_path}")
387
+
388
+
389
+ # ============================================================
390
+ # MAIN TRAINING LOOP
391
+ # ============================================================
392
+
393
if __name__ == '__main__':
    separator = "=" * 60

    print("\n" + separator)
    print("STARTING TRAINING")
    print(separator + "\n")

    # Per-epoch metric series, filled in by train_one_epoch() / validate().
    history = {
        series: []
        for series in ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'learning_rate')
    }

    best_acc, best_epoch = 0.0, 0

    for epoch in range(NUM_EPOCHS):
        epoch_number = epoch + 1  # one-based, for logging and file names

        train_loss, train_acc = train_one_epoch(epoch, history)
        val_acc = validate(epoch, history)
        lr_scheduler.step(epoch_number)

        # Keep the weights with the best validation accuracy seen so far.
        if val_acc > best_acc:
            best_acc, best_epoch = val_acc, epoch_number
            best_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'history': history,
            }
            torch.save(best_state, os.path.join(OUTPUT_DIR, "best_model.pth"))
            print(f"✅ Saved best model at epoch {epoch+1} with val acc {best_acc:.2f}%\n")

        # Periodic checkpoint every 10 epochs.
        if epoch_number % 10 == 0:
            periodic_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'history': history,
            }
            torch.save(periodic_state, os.path.join(OUTPUT_DIR, f"checkpoint_epoch_{epoch+1}.pth"))
            print(f"💾 Checkpoint saved at epoch {epoch+1}\n")

        # Refresh the metrics plot every 5 epochs and after the last epoch.
        is_last_epoch = (epoch == NUM_EPOCHS - 1)
        if epoch_number % 5 == 0 or is_last_epoch:
            plot_training_history(history, os.path.join(OUTPUT_DIR, "training_metrics.png"))

    # Final summary
    print(separator)
    print(f"🎉 Training complete!")
    print(f"Best validation accuracy: {best_acc:.2f}% at epoch {best_epoch}")
    print(f"Final train accuracy: {history['train_acc'][-1]:.2f}%")
    print(f"Final val accuracy: {history['val_acc'][-1]:.2f}%")
    print(separator)

    # Persist the summary and the final plot alongside the checkpoints.
    save_training_summary(history, best_acc, os.path.join(OUTPUT_DIR, "training_summary.json"))
    plot_training_history(history, os.path.join(OUTPUT_DIR, "final_training_metrics.png"))

    print(f"\n📁 All outputs saved to: {OUTPUT_DIR}")
stage2/inference_cvt13.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import datasets, transforms
4
+ from torch.utils.data import DataLoader
5
+ import sys
6
+ import os
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from pathlib import Path
10
+ import json
11
+ from datetime import datetime
12
+ from sklearn.metrics import (
13
+ accuracy_score, precision_score, recall_score, f1_score,
14
+ confusion_matrix, classification_report, roc_curve, auc,
15
+ precision_recall_curve, average_precision_score, roc_auc_score
16
+ )
17
+ import seaborn as sns
18
+ import pandas as pd
19
+
20
+
21
+ # ============================================================
22
+ # CONFIGURATION
23
+ # ============================================================
24
+
25
# Root of the CornViT checkout (placeholder - set before running).
BASE_DIR = "path_to_CornViT"

# Location of the cloned Microsoft CvT repository.
CVT_REPO_PATH = f"{BASE_DIR}/CvT"

# Model configuration
IMG_SIZE = 384
NUM_CLASSES = 2
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Name of the training run to evaluate (placeholder - set before running).
RUN = "cvt13_run_2025xxxx_xxxxxx"

# Trained weights produced by that run.
MODEL_PATH = f"metrics/{RUN}/train/best_model.pth"

# Test images, one subfolder per class (same layout as train/val).
TEST_DATA_DIR = f"{BASE_DIR}/stage2/data/test"

# Class names (update these to match your dataset)
CLASS_NAMES = ["Flat", "Round"]

# Each evaluation gets its own timestamped folder inside the run's metrics
# directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
EVAL_OUTPUT_DIR = f"metrics/{RUN}/evals/eval_{timestamp}"
os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)
50
+
51
+
52
+ # ============================================================
53
+ # SETUP: Import CvT model
54
+ # ============================================================
55
+
56
# Fix torch._six compatibility
cls_cvt_path = os.path.join(CVT_REPO_PATH, "lib", "models", "cls_cvt.py")
if os.path.exists(cls_cvt_path):
    with open(cls_cvt_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # (old text, replacement) pairs. Each one is applied independently, so a
    # partially patched file still receives whichever fix it is missing.
    _patched_content = content
    for _old_text, _new_text in (
        ("from torch._six import container_abcs",
         "import collections.abc as container_abcs"),
        ("or pretrained_layers[0] is '*'",
         "or pretrained_layers[0] == '*'"),
    ):
        _patched_content = _patched_content.replace(_old_text, _new_text)

    # Rewrite the file only when a patch actually changed it.
    if _patched_content != content:
        content = _patched_content
        with open(cls_cvt_path, 'w', encoding='utf-8') as f:
            f.write(content)

# Make the CvT repository importable.
sys.path.insert(0, CVT_REPO_PATH)

import warnings
warnings.filterwarnings('ignore', category=SyntaxWarning)

from lib.models import cls_cvt
from lib.config import config, update_config
81
+
82
+
83
+ # ============================================================
84
+ # MODEL LOADING
85
+ # ============================================================
86
+
87
def load_model(model_path, config_path=None):
    """Build the CvT classifier and restore trained weights into it.

    Args:
        model_path: Checkpoint file. It is either a dict with a
            ``'model_state_dict'`` entry or a bare state dict.
        config_path: CvT YAML config; defaults to the CvT-13 384x384 config
            inside ``CVT_REPO_PATH``.

    Returns:
        The model on ``DEVICE``, switched to evaluation mode.
    """
    yaml_path = config_path
    if yaml_path is None:
        yaml_path = os.path.join(CVT_REPO_PATH, "experiments", "imagenet", "cvt", "cvt-13-384x384.yaml")

    # The shared config object is frozen; unlock it, apply the overrides for
    # this task, then lock it again.
    config.defrost()
    config.merge_from_file(yaml_path)
    config.MODEL.NUM_CLASSES = NUM_CLASSES
    config.MODEL.PRETRAINED = ''
    config.freeze()

    net = cls_cvt.get_cls_model(config)

    # Accept both the wrapped checkpoint layout and a bare state dict.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    if 'model_state_dict' in checkpoint:
        weights = checkpoint['model_state_dict']
    else:
        weights = checkpoint
    net.load_state_dict(weights)

    net = net.to(DEVICE)
    net.eval()

    print(f"✅ Model loaded from: {model_path}")
    return net
115
+
116
+
117
+ # ============================================================
118
+ # DATA LOADING
119
+ # ============================================================
120
+
121
def get_test_dataloader(test_dir, batch_size=32):
    """Create the test dataset and an unshuffled loader over it.

    Args:
        test_dir: Folder with one subfolder of images per class.
        batch_size: Number of images per batch.

    Returns:
        Tuple ``(test_loader, test_dataset)``.
    """
    # Deterministic preprocessing: resize, tensor conversion, ImageNet
    # mean/std normalisation.
    preprocessing_steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ]
    dataset = datasets.ImageFolder(
        test_dir, transform=transforms.Compose(preprocessing_steps)
    )

    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    loader = DataLoader(dataset, **loader_options)

    print(f"✅ Test dataset loaded: {len(dataset)} images")
    print(f" Classes: {dataset.classes}")
    return loader, dataset
137
+
138
+
139
+ # ============================================================
140
+ # EVALUATION FUNCTIONS
141
+ # ============================================================
142
+
143
def evaluate_model(model, test_loader, test_dataset):
    """Run the model on every dataset item, one image per forward pass.

    Args:
        model: Classifier to evaluate; it is switched to eval mode here.
        test_loader: Accepted for interface compatibility; this function
            indexes ``test_dataset`` directly and never reads the loader.
        test_dataset: Indexable dataset yielding ``(image, label)`` and
            exposing ``samples`` as a list of ``(path, label)`` pairs.

    Returns:
        Tuple ``(preds, labels, probs, confidences, image_paths)``. The first
        four are NumPy arrays; ``probs`` holds the softmax output per image,
        ``confidences`` the maximum softmax value, and ``image_paths`` is a
        plain list of file paths.
    """
    model.eval()

    pred_list, label_list, prob_list, conf_list, path_list = [], [], [], [], []

    print("\n🔍 Running single-image inference on test set...")

    total_images = len(test_dataset)

    for idx in range(total_images):
        tensor, label = test_dataset[idx]
        img_path = test_dataset.samples[idx][0]

        # Add the batch dimension the model expects and move to the device.
        batch = tensor.unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            logits = model(batch)

            # Restore the batch axis if the model returned a 1-D tensor.
            if logits.dim() == 1:
                logits = logits.unsqueeze(0)

            probabilities = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        pred_list.append(predicted.item())
        label_list.append(label)
        prob_list.append(probabilities.cpu().numpy()[0])
        conf_list.append(confidence.item())
        path_list.append(img_path)

        done = idx + 1
        if done % 50 == 0 or done == total_images:
            print(f" Processed {done}/{total_images} images...")

    print(f"✅ Inference complete: {len(pred_list)} predictions")

    return (np.array(pred_list), np.array(label_list), np.array(prob_list),
            np.array(conf_list), path_list)
201
+
202
+
203
+ # ============================================================
204
+ # METRICS CALCULATION
205
+ # ============================================================
206
+
207
def calculate_metrics(y_true, y_pred, y_probs):
    """Compute aggregate and per-class classification metrics.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        y_probs: Per-class probabilities, one row per sample.

    Returns:
        Dict with accuracy, macro/weighted precision, recall and F1, a
        ``'per_class'`` sub-dict keyed by class name, and ROC-AUC entries
        (``'roc_auc'`` and ``'average_precision'`` for two classes,
        ``'roc_auc_ovr'`` and ``'roc_auc_ovo'`` otherwise).
    """
    metrics = {}

    # Aggregate metrics
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['precision_macro'] = precision_score(y_true, y_pred, average='macro', zero_division=0)
    metrics['precision_weighted'] = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    metrics['recall_macro'] = recall_score(y_true, y_pred, average='macro', zero_division=0)
    metrics['recall_weighted'] = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    metrics['f1_macro'] = f1_score(y_true, y_pred, average='macro', zero_division=0)
    metrics['f1_weighted'] = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    # Per-class metrics. Passing ``labels`` explicitly keeps position i of
    # each returned array tied to CLASS_NAMES[i]; without it scikit-learn only
    # reports labels that occur in the data, so a class absent from both
    # y_true and y_pred would shift or shorten the arrays.
    class_indices = list(range(len(CLASS_NAMES)))
    precision_per_class = precision_score(
        y_true, y_pred, labels=class_indices, average=None, zero_division=0)
    recall_per_class = recall_score(
        y_true, y_pred, labels=class_indices, average=None, zero_division=0)
    f1_per_class = f1_score(
        y_true, y_pred, labels=class_indices, average=None, zero_division=0)

    metrics['per_class'] = {}
    for i, class_name in enumerate(CLASS_NAMES):
        metrics['per_class'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i])
        }

    # ROC-AUC (binary uses the probability of class index 1)
    if NUM_CLASSES == 2:
        metrics['roc_auc'] = roc_auc_score(y_true, y_probs[:, 1])
        metrics['average_precision'] = average_precision_score(y_true, y_probs[:, 1])
    else:
        metrics['roc_auc_ovr'] = roc_auc_score(y_true, y_probs, multi_class='ovr', average='macro')
        metrics['roc_auc_ovo'] = roc_auc_score(y_true, y_probs, multi_class='ovo', average='macro')

    return metrics
243
+
244
+
245
+ # ============================================================
246
+ # VISUALIZATION FUNCTIONS
247
+ # ============================================================
248
+
249
def plot_confusion_matrix(y_true, y_pred, save_path):
    """Plot raw and row-normalised confusion matrices side by side.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        save_path: Destination image file.

    Returns:
        The raw-count confusion matrix, with one row and one column per
        entry of ``CLASS_NAMES``.
    """
    # Fix the label set so the matrix always matches the tick labels, even
    # when a class never appears in y_true or y_pred.
    class_indices = list(range(len(CLASS_NAMES)))
    cm = confusion_matrix(y_true, y_pred, labels=class_indices)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Raw counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')

    # Row-normalised. Rows with no true samples stay at 0 instead of
    # producing NaN through a division by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Confusion matrix saved to: {save_path}")
    plt.close()

    return cm
278
+
279
+
280
def plot_roc_curve(y_true, y_probs, save_path):
    """Draw ROC curve(s) with AUC in the legend and save the figure.

    With two classes a single curve is drawn from the probability of class
    index 1; otherwise one one-vs-rest curve is drawn per ``CLASS_NAMES``
    entry.

    Args:
        y_true: Ground-truth class indices.
        y_probs: Per-class probabilities, one row per sample.
        save_path: Destination image file.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    if NUM_CLASSES == 2:
        fpr, tpr, _ = roc_curve(y_true, y_probs[:, 1])
        area = auc(fpr, tpr)
        ax.plot(fpr, tpr, color='darkorange', lw=2,
                label=f'ROC curve (AUC = {area:.3f})')
    else:
        for class_idx, class_name in enumerate(CLASS_NAMES):
            one_vs_rest = (y_true == class_idx).astype(int)
            fpr, tpr, _ = roc_curve(one_vs_rest, y_probs[:, class_idx])
            area = auc(fpr, tpr)
            ax.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {area:.3f})')

    # Chance-level diagonal, drawn after the curves in both modes.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14, fontweight='bold')
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 ROC curve saved to: {save_path}")
    plt.close()
315
+
316
+
317
def plot_precision_recall_curve(y_true, y_probs, save_path):
    """Draw precision-recall curve(s) with average precision in the legend.

    With two classes a single curve is drawn from the probability of class
    index 1; otherwise one one-vs-rest curve is drawn per ``CLASS_NAMES``
    entry.

    Args:
        y_true: Ground-truth class indices.
        y_probs: Per-class probabilities, one row per sample.
        save_path: Destination image file.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    if NUM_CLASSES == 2:
        positive_scores = y_probs[:, 1]
        precision, recall, _ = precision_recall_curve(y_true, positive_scores)
        ap = average_precision_score(y_true, positive_scores)
        ax.plot(recall, precision, color='darkorange', lw=2,
                label=f'PR curve (AP = {ap:.3f})')
    else:
        for class_idx, class_name in enumerate(CLASS_NAMES):
            one_vs_rest = (y_true == class_idx).astype(int)
            class_scores = y_probs[:, class_idx]
            precision, recall, _ = precision_recall_curve(one_vs_rest, class_scores)
            ap = average_precision_score(one_vs_rest, class_scores)
            ax.plot(recall, precision, lw=2,
                    label=f'{class_name} (AP = {ap:.3f})')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall', fontsize=12)
    ax.set_ylabel('Precision', fontsize=12)
    ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    ax.legend(loc="lower left", fontsize=10)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Precision-Recall curve saved to: {save_path}")
    plt.close()
350
+
351
+
352
def plot_class_distribution(y_true, y_pred, save_path):
    """Plot true and predicted per-class counts as two bar charts.

    Args:
        y_true: Ground-truth class indices (NumPy array).
        y_pred: Predicted class indices (NumPy array).
        save_path: Destination image file.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # One panel per label source: (axis, labels, bar colour, panel title).
    panels = (
        (axes[0], y_true, 'steelblue', 'True Label Distribution'),
        (axes[1], y_pred, 'coral', 'Predicted Label Distribution'),
    )
    for axis, labels, colour, title in panels:
        counts = [np.sum(labels == class_idx) for class_idx in range(NUM_CLASSES)]
        axis.bar(CLASS_NAMES, counts, color=colour, alpha=0.7)
        axis.set_ylabel('Count', fontsize=12)
        axis.set_title(title, fontsize=14, fontweight='bold')
        axis.grid(axis='y', alpha=0.3)

        # Put each count slightly above its bar (1% of the tallest bar).
        lift = max(counts) * 0.01
        for position, count in enumerate(counts):
            axis.text(position, count + lift, str(count),
                      ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Class distribution saved to: {save_path}")
    plt.close()
380
+
381
+
382
def plot_per_class_metrics(metrics, save_path):
    """Plot precision, recall and F1 per class as a grouped bar chart.

    Args:
        metrics: Dict with a ``'per_class'`` mapping of class name to a dict
            holding ``'precision'``, ``'recall'`` and ``'f1_score'``.
        save_path: Destination image file.
    """
    per_class = metrics['per_class']
    classes = list(per_class)

    x = np.arange(len(classes))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 7))

    # (metric key, horizontal offset, legend label, colour), left to right.
    series = (
        ('precision', -width, 'Precision', 'steelblue'),
        ('recall', 0, 'Recall', 'coral'),
        ('f1_score', width, 'F1-Score', 'lightgreen'),
    )
    bar_groups = []
    for key, shift, legend_label, colour in series:
        values = [per_class[name][key] for name in classes]
        bar_groups.append(
            ax.bar(x + shift, values, width, label=legend_label, color=colour, alpha=0.8)
        )

    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(classes)
    ax.legend(fontsize=11)
    ax.set_ylim([0, 1.1])
    ax.grid(axis='y', alpha=0.3)

    # Write each bar's value just above it.
    for group in bar_groups:
        for bar in group:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Per-class metrics saved to: {save_path}")
    plt.close()
421
+
422
+
423
def plot_confidence_distribution(y_true, y_pred, confidences, save_path):
    """Plot confidence scores of correct versus incorrect predictions.

    The top panel overlays two histograms and the bottom panel shows a box
    plot per group, annotated with the group mean when the group is
    non-empty.

    Args:
        y_true: Ground-truth class indices (NumPy array).
        y_pred: Predicted class indices (NumPy array).
        confidences: Confidence score per prediction (NumPy array).
        save_path: Destination image file.
    """
    correct = (y_true == y_pred)
    incorrect = ~correct

    fig, axes = plt.subplots(2, 1, figsize=(12, 10))

    # Histogram
    axes[0].hist(confidences[correct], bins=50, alpha=0.7, label='Correct',
                 color='green', edgecolor='black')
    axes[0].hist(confidences[incorrect], bins=50, alpha=0.7, label='Incorrect',
                 color='red', edgecolor='black')
    axes[0].set_xlabel('Confidence Score', fontsize=12)
    axes[0].set_ylabel('Frequency', fontsize=12)
    axes[0].set_title('Confidence Distribution: Correct vs Incorrect Predictions',
                      fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=11)
    axes[0].grid(alpha=0.3)

    # Box plot. Tick labels are set on the axis afterwards rather than through
    # boxplot's ``labels`` keyword, which is deprecated in newer matplotlib
    # releases; this form works on old and new versions alike.
    data_to_plot = [confidences[correct], confidences[incorrect]]
    box = axes[1].boxplot(data_to_plot, patch_artist=True, showmeans=True)
    axes[1].set_xticklabels(['Correct', 'Incorrect'])
    box['boxes'][0].set_facecolor('lightgreen')
    box['boxes'][1].set_facecolor('lightcoral')
    axes[1].set_ylabel('Confidence Score', fontsize=12)
    axes[1].set_title('Confidence Score Box Plot', fontsize=14, fontweight='bold')
    axes[1].grid(axis='y', alpha=0.3)

    # Annotate group means. Both groups are guarded so that an empty group
    # does not yield a NaN mean (and a NumPy warning) or a 'μ=nan' label.
    if correct.sum() > 0:
        correct_mean = np.mean(confidences[correct])
        axes[1].text(1, correct_mean, f'μ={correct_mean:.3f}',
                     ha='right', va='center', fontweight='bold', fontsize=10)
    if incorrect.sum() > 0:
        incorrect_mean = np.mean(confidences[incorrect])
        axes[1].text(2, incorrect_mean, f'μ={incorrect_mean:.3f}',
                     ha='left', va='center', fontweight='bold', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Confidence distribution saved to: {save_path}")
    plt.close()
465
+
466
+
467
+ # ============================================================
468
+ # RESULTS SAVING
469
+ # ============================================================
470
+
471
def save_predictions_to_csv(image_paths, y_true, y_pred, y_probs, confidences, save_path):
    """Write one CSV row per prediction and print summary statistics.

    Args:
        image_paths: File path of each evaluated image.
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        y_probs: Per-class probabilities, one row per image.
        confidences: Confidence score of each prediction.
        save_path: Destination CSV file.

    Returns:
        The ``pandas.DataFrame`` that was written. It is empty, with no
        columns, when there were no predictions.
    """
    results = []

    for img_path, true_label, pred, probs, conf in zip(image_paths, y_true, y_pred, y_probs, confidences):
        result = {
            'image_path': img_path,
            'image_name': os.path.basename(img_path),
            'true_label': CLASS_NAMES[true_label],
            'true_label_idx': true_label,
            'predicted_label': CLASS_NAMES[pred],
            'predicted_label_idx': pred,
            'confidence': conf,
            'correct': pred == true_label
        }

        # One probability column per class, in CLASS_NAMES order.
        for i, class_name in enumerate(CLASS_NAMES):
            result[f'prob_{class_name}'] = probs[i]

        results.append(result)

    df = pd.DataFrame(results)
    df.to_csv(save_path, index=False)
    print(f"💾 Predictions saved to: {save_path}")

    print(f"\n📊 Prediction Statistics:")
    print(f" Total images: {len(df)}")

    # With no rows the frame has no 'correct' column and the percentages
    # would divide by zero, so stop after reporting the total.
    if df.empty:
        return df

    total = len(df)
    is_correct = df['correct']
    n_correct = is_correct.sum()
    n_incorrect = (~is_correct).sum()

    print(f" Correct predictions: {n_correct} ({n_correct/total*100:.2f}%)")
    print(f" Incorrect predictions: {n_incorrect} ({n_incorrect/total*100:.2f}%)")
    print(f" Average confidence: {df['confidence'].mean():.4f}")
    print(f" Confidence on correct: {df[is_correct]['confidence'].mean():.4f}")
    # Skipped entirely (rather than printing a blank line) when every
    # prediction was correct.
    if n_incorrect > 0:
        print(f" Confidence on incorrect: {df[~is_correct]['confidence'].mean():.4f}")

    return df
507
+
508
+
509
def save_metrics_json(metrics, save_path):
    """Serialise the metrics dictionary to an indented JSON file.

    NumPy scalars and arrays that the standard encoder rejects (for example
    ``np.int64`` or ``np.float32``) are converted to plain Python values via
    their ``tolist()`` method. Any other unsupported type still raises
    ``TypeError``.

    Args:
        metrics: JSON-compatible (possibly nested) dictionary of metrics.
        save_path: Destination JSON file.
    """
    def _to_builtin(value):
        # Called by json only for objects it cannot encode natively.
        if hasattr(value, "tolist"):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # Explicit encoding, matching the other open() calls in this file.
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=4, default=_to_builtin)
    print(f"💾 Metrics saved to: {save_path}")
514
+
515
+
516
def generate_classification_report_file(y_true, y_pred, save_path):
    """Write scikit-learn's text classification report to a file.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        save_path: Destination text file.
    """
    # ``labels`` pins the report rows to CLASS_NAMES. Without it, a class that
    # is absent from the data makes the number of detected labels differ from
    # len(target_names) and scikit-learn raises ValueError.
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(CLASS_NAMES))),
        target_names=CLASS_NAMES,
        digits=4,
        zero_division=0,
    )

    separator = "=" * 60
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(separator + "\n")
        f.write("CLASSIFICATION REPORT\n")
        f.write(separator + "\n\n")
        f.write(report)

    print(f"📄 Classification report saved to: {save_path}")
527
+
528
+
529
+ # ============================================================
530
+ # MAIN EVALUATION PIPELINE
531
+ # ============================================================
532
+
533
def main():
    """Run the full evaluation: inference, metrics, plots and result files.

    Loads the model from ``MODEL_PATH`` and the images under
    ``TEST_DATA_DIR``, then writes every artefact into ``EVAL_OUTPUT_DIR``.
    The closing summary lists only files that were actually written; the
    misclassified and low-confidence CSVs are produced only when they would
    contain at least one row.
    """
    print("\n" + "="*60)
    print("CvT-13 MODEL EVALUATION PIPELINE")
    print("Single Image Prediction Mode")
    print("="*60 + "\n")

    # Load model
    print("📦 Loading model...")
    model = load_model(MODEL_PATH)

    # Load test data
    print("\n📂 Loading test data...")
    test_loader, test_dataset = get_test_dataloader(TEST_DATA_DIR, batch_size=1)

    # Run evaluation with single image predictions
    print("\n🔍 Evaluating model (single image predictions)...")
    y_pred, y_true, y_probs, confidences, image_paths = evaluate_model(model, test_loader, test_dataset)

    # Calculate metrics
    print("\n📊 Calculating metrics...")
    metrics = calculate_metrics(y_true, y_pred, y_probs)

    # Print key metrics
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"Total Images Evaluated: {len(y_pred)}")
    print(f"Accuracy: {metrics['accuracy']*100:.2f}%")
    print(f"Precision (Macro): {metrics['precision_macro']*100:.2f}%")
    print(f"Recall (Macro): {metrics['recall_macro']*100:.2f}%")
    print(f"F1-Score (Macro): {metrics['f1_macro']*100:.2f}%")
    if 'roc_auc' in metrics:
        print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
    print("\nPer-Class Metrics:")
    for class_name, class_metrics in metrics['per_class'].items():
        print(f" {class_name}:")
        print(f" Precision: {class_metrics['precision']*100:.2f}%")
        print(f" Recall: {class_metrics['recall']*100:.2f}%")
        print(f" F1-Score: {class_metrics['f1_score']*100:.2f}%")
    print("="*60)

    # Generate all visualizations
    print("\n📊 Generating visualizations...")
    plot_confusion_matrix(y_true, y_pred,
                          os.path.join(EVAL_OUTPUT_DIR, "confusion_matrix.png"))
    plot_roc_curve(y_true, y_probs,
                   os.path.join(EVAL_OUTPUT_DIR, "roc_curve.png"))
    plot_precision_recall_curve(y_true, y_probs,
                                os.path.join(EVAL_OUTPUT_DIR, "precision_recall_curve.png"))
    plot_class_distribution(y_true, y_pred,
                            os.path.join(EVAL_OUTPUT_DIR, "class_distribution.png"))
    plot_per_class_metrics(metrics,
                           os.path.join(EVAL_OUTPUT_DIR, "per_class_metrics.png"))
    plot_confidence_distribution(y_true, y_pred, confidences,
                                 os.path.join(EVAL_OUTPUT_DIR, "confidence_distribution.png"))

    # Save results
    print("\n💾 Saving results...")
    df = save_predictions_to_csv(image_paths, y_true, y_pred, y_probs, confidences,
                                 os.path.join(EVAL_OUTPUT_DIR, "predictions.csv"))
    save_metrics_json(metrics,
                      os.path.join(EVAL_OUTPUT_DIR, "metrics.json"))
    generate_classification_report_file(y_true, y_pred,
                                        os.path.join(EVAL_OUTPUT_DIR, "classification_report.txt"))

    # Save misclassified images list (only when there is at least one)
    misclassified = df[~df['correct']]
    wrote_misclassified = len(misclassified) > 0
    if wrote_misclassified:
        misclassified_path = os.path.join(EVAL_OUTPUT_DIR, "misclassified_images.csv")
        misclassified.to_csv(misclassified_path, index=False)
        print(f"⚠️ Misclassified images saved to: {misclassified_path}")
        print(f" Total misclassified: {len(misclassified)}")

    # Save low confidence predictions (only when there is at least one)
    low_conf_threshold = 0.7
    low_confidence = df[df['confidence'] < low_conf_threshold]
    wrote_low_confidence = len(low_confidence) > 0
    if wrote_low_confidence:
        low_conf_path = os.path.join(EVAL_OUTPUT_DIR, "low_confidence_predictions.csv")
        low_confidence.to_csv(low_conf_path, index=False)
        print(f"⚠️ Low confidence predictions saved to: {low_conf_path}")
        print(f" Total with confidence < {low_conf_threshold}: {len(low_confidence)}")

    print("\n" + "="*60)
    print(f"✅ Evaluation complete!")
    print(f"📁 All results saved to: {EVAL_OUTPUT_DIR}")
    print("="*60 + "\n")

    # Summarise what was produced. The two optional CSVs are listed only when
    # they were written above, so the summary never points at missing files.
    generated = [
        " 📊 confusion_matrix.png - Confusion matrix visualization",
        " 📊 roc_curve.png - ROC curve",
        " 📊 precision_recall_curve.png - Precision-Recall curve",
        " 📊 class_distribution.png - Class distribution comparison",
        " 📊 per_class_metrics.png - Per-class performance",
        " 📊 confidence_distribution.png - Confidence analysis",
        " 💾 predictions.csv - Detailed predictions for each image",
    ]
    if wrote_misclassified:
        generated.append(" 💾 misclassified_images.csv - List of incorrectly classified images")
    if wrote_low_confidence:
        generated.append(" 💾 low_confidence_predictions.csv - Predictions with low confidence")
    generated.append(" 💾 metrics.json - All metrics in JSON format")
    generated.append(" 📄 classification_report.txt - Sklearn classification report")

    print("Generated files:")
    for line in generated:
        print(line)


if __name__ == '__main__':
    main()
stage2/stage2_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:445c5a5b94b86649cab12ef3c2fe4df9461f9879864c43d52a7cc9560204fcc3
3
+ size 78733538
stage2/train_cvt13.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ import sys
7
+ import os
8
+ from timm.loss import SoftTargetCrossEntropy
9
+ from timm.scheduler import CosineLRScheduler
10
+ from timm.utils import accuracy
11
+ import matplotlib.pyplot as plt
12
+ import json
13
+ from datetime import datetime
14
+
15
+
16
+ # ============================================================
17
+ # SETUP: Clone and import from Microsoft CvT repository
18
+ # ============================================================
19
+ """
20
+ First, clone the Microsoft CvT repository:
21
+ git clone https://github.com/microsoft/CvT.git
22
+ cd CvT
23
+ pip install -r requirements.txt
24
+ """
25
+
26
BASE_DIR = "path_to_CornViT"

# Location of the cloned Microsoft CvT repository
CVT_REPO_PATH = f"{BASE_DIR}/CvT"

if not os.path.exists(CVT_REPO_PATH):
    print(f"❌ CvT repository not found at {CVT_REPO_PATH}")
    print("Please clone it: git clone https://github.com/microsoft/CvT.git")
    sys.exit(1)

# Patch cls_cvt.py in place BEFORE importing it, so it imports on newer
# PyTorch versions.
print("Applying compatibility fixes for newer PyTorch versions...")
cls_cvt_path = os.path.join(CVT_REPO_PATH, "lib", "models", "cls_cvt.py")

if not os.path.exists(cls_cvt_path):
    print(f"❌ Could not find cls_cvt.py at {cls_cvt_path}")
    sys.exit(1)

with open(cls_cvt_path, 'r', encoding='utf-8') as f:
    content = f.read()

# The two fixes are applied independently of each other, so a file that
# already has the import fixed still gets the comparison fixed:
#   1. replace the `torch._six` import with collections.abc
#   2. replace the `is '*'` string comparison with `== '*'`
patched_content = content.replace(
    "from torch._six import container_abcs",
    "import collections.abc as container_abcs"
).replace(
    "or pretrained_layers[0] is '*'",
    "or pretrained_layers[0] == '*'"
)

# Rewrite the file only when something actually changed.
if patched_content != content:
    with open(cls_cvt_path, 'w', encoding='utf-8') as f:
        f.write(patched_content)
    content = patched_content
    print("✅ Applied compatibility patches to cls_cvt.py")
else:
    print("✅ Compatibility patches already applied")

# Make the CvT repository importable
sys.path.insert(0, CVT_REPO_PATH)

# Suppress the SyntaxWarning
import warnings
warnings.filterwarnings('ignore', category=SyntaxWarning)

from lib.models import cls_cvt
from lib.config import config, update_config
print("✅ Successfully imported Microsoft CvT models")
76
+
77
+
78
+ # ============================================================
79
+ # CONFIGURATION
80
+ # ============================================================
81
+
82
# Dataset location and pretrained checkpoint
DATA_DIR = f"{BASE_DIR}/stage2/data"
PRETRAINED_PATH = f"{BASE_DIR}/CvT-13-384x384-IN-22k.pth"

# Training hyper-parameters
BATCH_SIZE = 32
IMG_SIZE = 384
NUM_CLASSES = 2
NUM_EPOCHS = 100

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Each run writes its metrics into a fresh, timestamped directory
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
OUTPUT_DIR = f"metrics/cvt13_run_{timestamp}"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Metrics will be saved to: {OUTPUT_DIR}")
95
+
96
+
97
+ # ============================================================
98
+ # DATASET & AUGMENTATION
99
+ # ============================================================
100
+
101
# Training pipeline: resize plus random flips, colour jitter and rotation,
# then tensor conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Validation pipeline: deterministic resize, tensor conversion, normalisation.
val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transforms)
val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=val_transforms)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    # Keep the final partial batch: dropping it would silently exclude up to
    # BATCH_SIZE - 1 validation images from every evaluation, and would leave
    # the loader empty for a validation set smaller than BATCH_SIZE.
    drop_last=False
)
138
+
139
+
140
+ # ============================================================
141
+ # MODEL SETUP - Using Microsoft CvT Implementation
142
+ # ============================================================
143
+
144
+ # Load the CvT-13 config from the repository
145
# Locate the CvT-13 config shipped with the repository
cvt_config_path = os.path.join(CVT_REPO_PATH, "experiments", "imagenet", "cvt", "cvt-13-384x384.yaml")

if not os.path.exists(cvt_config_path):
    print(f"⚠️ Config file not found at {cvt_config_path}")
    print("Available configs:")
    config_dir = os.path.join(CVT_REPO_PATH, "experiments", "imagenet", "cvt")
    if os.path.exists(config_dir):
        for f in os.listdir(config_dir):
            if f.endswith('.yaml'):
                print(f" - {f}")
    sys.exit(1)

print(f"Loading config from: {cvt_config_path}")

# Load config directly using merge_from_file
config.defrost()
config.merge_from_file(cvt_config_path)

# Update the number of classes for our task
config.MODEL.NUM_CLASSES = NUM_CLASSES
config.MODEL.PRETRAINED = ''  # We'll load weights manually
config.freeze()

print("Creating CvT-13 model...")
model = cls_cvt.get_cls_model(config)
model = model.to(DEVICE)

# Load pretrained weights
if os.path.exists(PRETRAINED_PATH):
    print(f"Loading pretrained weights from {PRETRAINED_PATH}")
    try:
        checkpoint = torch.load(PRETRAINED_PATH, map_location=DEVICE)

        # Handle different checkpoint formats
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Strip a leading 'module.' only. Using str.replace here would also
        # rewrite keys that contain 'module.' somewhere in the middle.
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k[len("module."):] if k.startswith("module.") else k
            new_state_dict[name] = v

        # Remove head layers from pretrained weights (they have different dimensions)
        filtered_state_dict = {k: v for k, v in new_state_dict.items() if 'head' not in k}

        # Load weights - strict=False will only load matching layers
        missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

        # Count how many weights were actually loaded. The model's key set is
        # built once instead of rebuilding state_dict() for every key.
        model_keys = set(model.state_dict())
        loaded_keys = [k for k in filtered_state_dict if k in model_keys]
        print(f"✅ Loaded pretrained weights: {len(loaded_keys)} layers from backbone")
        print(f" Head layer initialized randomly for {NUM_CLASSES} classes")

        # Show what's missing (should only be head-related)
        head_missing = [k for k in missing_keys if 'head' in k]
        other_missing = [k for k in missing_keys if 'head' not in k]

        if other_missing:
            print(f"⚠️ Warning - Missing non-head keys: {other_missing}")
        if unexpected_keys:
            print(f"⚠️ Unexpected keys: {unexpected_keys}")

    except Exception as e:
        # Deliberate best effort: report the failure and keep going.
        print(f"⚠️ Error loading pretrained weights: {e}")
        import traceback
        traceback.print_exc()
        print("Continuing with random initialization...")
else:
    print(f"⚠️ Pretrained weights not found at {PRETRAINED_PATH}")
    print("Training from scratch...")

# Freeze backbone - only train the head for faster training and less overfitting
print("Freezing backbone layers (keeping only head trainable)...")
for name, param in model.named_parameters():
    if "head" not in name:
        param.requires_grad = False

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
232
+
233
+
234
+ # ============================================================
235
+ # OPTIMIZER AND LOSS
236
+ # ============================================================
237
+
238
# Optimise only the parameters that still require gradients.
optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-4,
    weight_decay=0.05,
)

# Training loss: soft-target cross entropy from timm.
criterion = SoftTargetCrossEntropy()

# Cosine learning-rate schedule spanning NUM_EPOCHS, with a 5-step warm-up
# that starts at 1e-5 and a floor of 1e-6.
lr_scheduler = CosineLRScheduler(
    optimizer,
    t_initial=NUM_EPOCHS,
    lr_min=1e-6,
    warmup_t=5,
    warmup_lr_init=1e-5,
)
253
+
254
+
255
+ # ============================================================
256
+ # TRAINING & VALIDATION LOOP
257
+ # ============================================================
258
+
259
def train_one_epoch(epoch, history):
    """Train the model for one pass over ``train_loader``.

    Appends the epoch's mean loss, mean top-1 accuracy and current learning
    rate to ``history`` and prints a one-line summary.

    NOTE(review): ``mixup_fn`` is not defined in the part of the file
    reviewed here; it is assumed to be created at module level before this
    function is first called. Confirm.

    Args:
        epoch: Zero-based epoch index, used only in the printed summary.
        history: Dict with ``'train_loss'``, ``'train_acc'`` and
            ``'learning_rate'`` lists; one value is appended to each.

    Returns:
        Tuple ``(avg_loss, avg_acc)`` averaged over batches.
    """
    model.train()
    loss_sum = 0
    acc_sum = 0

    for batch_images, batch_targets in train_loader:
        batch_images = batch_images.to(DEVICE)
        batch_targets = batch_targets.to(DEVICE)
        # mixup_fn returns mixed images together with soft targets.
        batch_images, batch_targets = mixup_fn(batch_images, batch_targets)

        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_targets)
        loss.backward()
        optimizer.step()

        # Accuracy is measured against the arg-max of the soft targets.
        top1, _ = accuracy(logits, batch_targets.argmax(dim=1), topk=(1, 5))
        loss_sum += loss.item()
        acc_sum += top1.item()

    num_batches = len(train_loader)
    avg_loss = loss_sum / num_batches
    avg_acc = acc_sum / num_batches
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(avg_loss)
    history['train_acc'].append(avg_acc)
    history['learning_rate'].append(current_lr)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Train Loss: {avg_loss:.4f} | Train Acc: {avg_acc:.2f}% | LR: {current_lr:.6f}")
    return avg_loss, avg_acc
286
+
287
+
288
def validate(epoch, history):
    """Evaluate the module-level ``model`` on ``val_loader``.

    Loss and top-1 accuracy are accumulated per sample (each batch is weighted
    by its size), so a smaller final batch does not skew the epoch metrics.
    The returned accuracy drives best-checkpoint selection in the main loop.

    Args:
        epoch: Zero-based epoch index; used only in the progress message.
        history: Dict of metric lists; ``val_loss`` and ``val_acc`` each get
            one value appended.

    Returns:
        The sample-weighted top-1 validation accuracy, in percent.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no samples.
    """
    model.eval()
    # Built once per call instead of once per batch.
    loss_fn = nn.CrossEntropyLoss()
    loss_sum = 0.0
    acc_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            outputs = model(images)
            loss = loss_fn(outputs, targets)
            acc1, _ = accuracy(outputs, targets, topk=(1, 5))

            # Both values are batch means; scale back to per-sample sums.
            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            acc_sum += acc1.item() * batch_size
            num_samples += batch_size

    avg_loss = loss_sum / num_samples
    avg_acc = acc_sum / num_samples

    history['val_loss'].append(avg_loss)
    history['val_acc'].append(avg_acc)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Val Loss: {avg_loss:.4f} | Val Acc: {avg_acc:.2f}%")
    return avg_acc
310
+
311
+
312
def plot_training_history(history, save_path):
    """Plot the recorded training metrics on a 2x2 grid and save the figure.

    Panels: loss, accuracy, learning rate (log scale), and train-vs-val
    accuracy with the gap between the two curves shaded.

    Args:
        history: Dict with equally long lists under ``train_loss``,
            ``val_loss``, ``train_acc``, ``val_acc`` and ``learning_rate``.
        save_path: Destination image path passed to ``plt.savefig``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    def _finish(ax, ylabel, title, legend=True):
        # Shared labelling/styling for every panel.
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        if legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    def _train_val_curves(ax, train_key, val_key, train_label, val_label):
        # Train curve in blue, validation curve in red.
        ax.plot(epochs, history[train_key], 'b-', label=train_label, linewidth=2)
        ax.plot(epochs, history[val_key], 'r-', label=val_label, linewidth=2)

    # Plot 1: Loss
    loss_ax = axes[0, 0]
    _train_val_curves(loss_ax, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss')
    _finish(loss_ax, 'Loss', 'Training and Validation Loss')

    # Plot 2: Accuracy
    acc_ax = axes[0, 1]
    _train_val_curves(acc_ax, 'train_acc', 'val_acc', 'Train Acc', 'Val Acc')
    _finish(acc_ax, 'Accuracy (%)', 'Training and Validation Accuracy')

    # Plot 3: Learning Rate
    lr_ax = axes[1, 0]
    lr_ax.plot(epochs, history['learning_rate'], 'g-', linewidth=2)
    lr_ax.set_yscale('log')
    _finish(lr_ax, 'Learning Rate', 'Learning Rate Schedule', legend=False)

    # Plot 4: Val Acc vs Train Acc (Overfitting check)
    gap_ax = axes[1, 1]
    _train_val_curves(gap_ax, 'train_acc', 'val_acc', 'Train Acc', 'Val Acc')
    gap_ax.fill_between(epochs, history['val_acc'], history['train_acc'],
                        alpha=0.3, color='orange', label='Overfitting Gap')
    _finish(gap_ax, 'Accuracy (%)', 'Overfitting Analysis')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Training plots saved to: {save_path}")
    plt.close()
360
+
361
+
362
def save_training_summary(history, best_acc, save_path):
    """Write the run configuration, final metrics and full history as JSON.

    Args:
        history: Dict of metric lists; the last entry of each train/val list
            is reported as the final value, so the lists must be non-empty.
        best_acc: Best validation accuracy seen during the run.
        save_path: Path of the JSON file to create.
    """
    run_config = {
        'model': 'CvT-13',
        'batch_size': BATCH_SIZE,
        'img_size': IMG_SIZE,
        'num_classes': NUM_CLASSES,
        'num_epochs': NUM_EPOCHS,
        'device': DEVICE,
        'pretrained': PRETRAINED_PATH,
    }

    final_metrics = {'best_val_accuracy': best_acc}
    for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc'):
        final_metrics[f'final_{metric}'] = history[metric][-1]

    summary = {
        'config': run_config,
        'final_metrics': final_metrics,
        'history': history,
    }

    with open(save_path, 'w') as f:
        json.dump(summary, f, indent=4)

    print(f"💾 Training summary saved to: {save_path}")
388
+
389
+
390
+ # ============================================================
391
+ # MAIN TRAINING LOOP
392
+ # ============================================================
393
+
394
if __name__ == '__main__':
    rule = "=" * 60
    print("\n" + rule)
    print("STARTING TRAINING")
    print(rule + "\n")

    # Per-epoch metric lists, filled in by train_one_epoch / validate.
    history = {
        metric: []
        for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'learning_rate')
    }

    best_acc = 0.0
    best_epoch = 0

    for epoch in range(NUM_EPOCHS):
        epoch_number = epoch + 1  # one-based, for messages and file names

        train_loss, train_acc = train_one_epoch(epoch, history)
        val_acc = validate(epoch, history)
        lr_scheduler.step(epoch_number)

        # Keep the checkpoint with the highest validation accuracy so far.
        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch_number
            best_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'history': history,
            }
            torch.save(best_state, os.path.join(OUTPUT_DIR, "best_model.pth"))
            print(f"✅ Saved best model at epoch {epoch_number} with val acc {best_acc:.2f}%\n")

        # Periodic checkpoint every 10 epochs.
        if epoch_number % 10 == 0:
            periodic_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'history': history,
            }
            torch.save(periodic_state, os.path.join(OUTPUT_DIR, f"checkpoint_epoch_{epoch_number}.pth"))
            print(f"💾 Checkpoint saved at epoch {epoch_number}\n")

        # Refresh the metrics plot every 5 epochs and after the last epoch.
        if epoch_number % 5 == 0 or epoch_number == NUM_EPOCHS:
            plot_training_history(history, os.path.join(OUTPUT_DIR, "training_metrics.png"))

    # Final summary
    print(rule)
    print("🎉 Training complete!")
    print(f"Best validation accuracy: {best_acc:.2f}% at epoch {best_epoch}")
    print(f"Final train accuracy: {history['train_acc'][-1]:.2f}%")
    print(f"Final val accuracy: {history['val_acc'][-1]:.2f}%")
    print(rule)

    # Persist the summary and a final copy of the plots.
    save_training_summary(history, best_acc, os.path.join(OUTPUT_DIR, "training_summary.json"))
    plot_training_history(history, os.path.join(OUTPUT_DIR, "final_training_metrics.png"))

    print(f"\n📁 All outputs saved to: {OUTPUT_DIR}")
stage3/inference_cvt13.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import datasets, transforms
4
+ from torch.utils.data import DataLoader
5
+ import sys
6
+ import os
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from pathlib import Path
10
+ import json
11
+ from datetime import datetime
12
+ from sklearn.metrics import (
13
+ accuracy_score, precision_score, recall_score, f1_score,
14
+ confusion_matrix, classification_report, roc_curve, auc,
15
+ precision_recall_curve, average_precision_score, roc_auc_score
16
+ )
17
+ import seaborn as sns
18
+ import pandas as pd
19
+
20
+
21
+ # ============================================================
22
+ # CONFIGURATION
23
+ # ============================================================
24
+
25
# Root of the CornViT checkout (the value looks like a placeholder to edit).
BASE_DIR = "path_to_CornViT"

# Path to the Microsoft CvT repository
CVT_REPO_PATH = f"{BASE_DIR}/CvT"

# Model configuration
IMG_SIZE = 384
NUM_CLASSES = 2
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Name of the training run whose outputs are evaluated
RUN = "cvt13_run_2025xxxx_xxxxxx"

# Path to trained model
MODEL_PATH = f"metrics/{RUN}/train/best_model.pth"

# Test data folder (should have subfolders for each class like train/val structure)
TEST_DATA_DIR = f"{BASE_DIR}/stage3/data/test"

# Class names (update these to match your dataset)
# NOTE(review): entries are used as names for label indices 0, 1, ...; the
# indices come from ImageFolder, so this order must match the order in which
# ImageFolder enumerates the class folders - confirm against the printed
# "Classes:" line at load time.
CLASS_NAMES = ["Up", "Down"]

# Output directory for evaluation results (within the same metrics folder);
# a fresh timestamped folder is created every time this module is executed.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
EVAL_OUTPUT_DIR = f"metrics/{RUN}/evals/eval_{timestamp}"
os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)
50
+
51
+
52
+ # ============================================================
53
+ # SETUP: Import CvT model
54
+ # ============================================================
55
+
56
# Patch the cloned CvT sources in place so they import on current PyTorch/Python:
#   * `torch._six` is imported by cls_cvt.py; swap it for collections.abc.
#   * `is '*'` compares a string by identity; rewrite it as an equality test.
# Each fix is applied independently, and the file is rewritten only when its
# text actually changed, so repeated runs leave an already-patched file alone.
cls_cvt_path = os.path.join(CVT_REPO_PATH, "lib", "models", "cls_cvt.py")
if os.path.exists(cls_cvt_path):
    with open(cls_cvt_path, 'r', encoding='utf-8') as f:
        content = f.read()

    source_fixes = (
        ("from torch._six import container_abcs",
         "import collections.abc as container_abcs"),
        ("or pretrained_layers[0] is '*'",
         "or pretrained_layers[0] == '*'"),
    )
    patched_content = content
    for outdated, replacement in source_fixes:
        patched_content = patched_content.replace(outdated, replacement)

    if patched_content != content:
        with open(cls_cvt_path, 'w', encoding='utf-8') as f:
            f.write(patched_content)

# Make `lib.*` from the CvT checkout importable.
sys.path.insert(0, CVT_REPO_PATH)
75
+
76
+ import warnings
77
+ warnings.filterwarnings('ignore', category=SyntaxWarning)
78
+
79
+ from lib.models import cls_cvt
80
+ from lib.config import config, update_config
81
+
82
+
83
+ # ============================================================
84
+ # MODEL LOADING
85
+ # ============================================================
86
+
87
def load_model(model_path, config_path=None):
    """Build a CvT classifier and load trained weights into it.

    Args:
        model_path: Checkpoint file. Either a dict holding the weights under
            ``'model_state_dict'`` or a bare state dict.
        config_path: CvT YAML config; defaults to the ``cvt-13-384x384.yaml``
            experiment file inside ``CVT_REPO_PATH``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    if config_path is None:
        config_path = os.path.join(
            CVT_REPO_PATH, "experiments", "imagenet", "cvt", "cvt-13-384x384.yaml"
        )

    # The shared config object is modified in place: file values first, then
    # the overrides for this task (class count, no pretrained weights path).
    config.defrost()
    config.merge_from_file(config_path)
    config.MODEL.NUM_CLASSES = NUM_CLASSES
    config.MODEL.PRETRAINED = ''
    config.freeze()

    model = cls_cvt.get_cls_model(config)

    # Accept both the training checkpoint layout and a plain state dict.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    model = model.to(DEVICE)
    model.eval()

    print(f"✅ Model loaded from: {model_path}")
    return model
115
+
116
+
117
+ # ============================================================
118
+ # DATA LOADING
119
+ # ============================================================
120
+
121
def get_test_dataloader(test_dir, batch_size=32):
    """Create the test dataset and a non-shuffling loader over it.

    Args:
        test_dir: Root folder read by ``datasets.ImageFolder``.
        batch_size: Batch size of the returned loader.

    Returns:
        Tuple ``(test_loader, test_dataset)``.
    """
    # Resize to the model input size, then normalise per channel.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocessing = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    test_dataset = datasets.ImageFolder(test_dir, transform=preprocessing)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    print(f"✅ Test dataset loaded: {len(test_dataset)} images")
    print(f" Classes: {test_dataset.classes}")
    return test_loader, test_dataset
137
+
138
+
139
+ # ============================================================
140
+ # EVALUATION FUNCTIONS
141
+ # ============================================================
142
+
143
def evaluate_model(model, test_loader, test_dataset):
    """Classify every image of ``test_dataset`` one at a time.

    ``test_loader`` is accepted for interface compatibility but not used;
    samples are pulled from ``test_dataset`` by index.

    Returns:
        Tuple ``(preds, labels, probs, confidences, image_paths)`` where the
        first four are NumPy arrays (predicted class index, ground-truth
        index, softmax probabilities per class, and the probability of the
        predicted class) and ``image_paths`` is a list of file paths.
    """
    model.eval()

    predictions, truths = [], []
    probability_rows, confidence_scores = [], []
    image_paths = []

    print("\n🔍 Running single-image inference on test set...")

    total_images = len(test_dataset)
    with torch.no_grad():
        for idx in range(total_images):
            image, label = test_dataset[idx]
            img_path = test_dataset.samples[idx][0]

            # Add the batch dimension the model expects.
            batch = image.unsqueeze(0).to(DEVICE)
            logits = model(batch)
            if logits.dim() == 1:
                # Restore the batch axis if the model returned a flat vector.
                logits = logits.unsqueeze(0)

            probabilities = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

            predictions.append(predicted.item())
            truths.append(label)
            probability_rows.append(probabilities.cpu().numpy()[0])
            confidence_scores.append(confidence.item())
            image_paths.append(img_path)

            done = idx + 1
            if done % 50 == 0 or done == total_images:
                print(f" Processed {done}/{total_images} images...")

    print(f"✅ Inference complete: {len(predictions)} predictions")

    return (np.array(predictions), np.array(truths), np.array(probability_rows),
            np.array(confidence_scores), image_paths)
201
+
202
+
203
+ # ============================================================
204
+ # METRICS CALCULATION
205
+ # ============================================================
206
+
207
def calculate_metrics(y_true, y_pred, y_probs):
    """Compute aggregate and per-class classification metrics.

    Per-class scores are requested for every index of ``CLASS_NAMES``
    explicitly, so each score lines up with its class name even when a class
    is missing from both ``y_true`` and ``y_pred`` (it then scores 0 because
    of ``zero_division=0``).

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        y_probs: Array with one probability column per class; column 1 is
            treated as the positive-class score in the two-class case.

    Returns:
        Dict with accuracy, macro/weighted precision, recall and F1, a
        ``per_class`` sub-dict keyed by class name, and ROC-AUC values
        (``roc_auc`` and ``average_precision`` for two classes, otherwise
        ``roc_auc_ovr`` and ``roc_auc_ovo``).
    """
    metrics = {}

    # Aggregate metrics
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    scorers = (('precision', precision_score), ('recall', recall_score), ('f1', f1_score))
    for prefix, scorer in scorers:
        for averaging in ('macro', 'weighted'):
            metrics[f'{prefix}_{averaging}'] = scorer(
                y_true, y_pred, average=averaging, zero_division=0
            )

    # Per-class metrics: pin the label set so index i always means CLASS_NAMES[i].
    class_ids = list(range(len(CLASS_NAMES)))
    precision_per_class = precision_score(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0)
    recall_per_class = recall_score(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0)
    f1_per_class = f1_score(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0)

    metrics['per_class'] = {}
    for i, class_name in enumerate(CLASS_NAMES):
        metrics['per_class'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i])
        }

    # ROC-AUC (for binary and multi-class)
    if NUM_CLASSES == 2:
        metrics['roc_auc'] = roc_auc_score(y_true, y_probs[:, 1])
        metrics['average_precision'] = average_precision_score(y_true, y_probs[:, 1])
    else:
        metrics['roc_auc_ovr'] = roc_auc_score(y_true, y_probs, multi_class='ovr', average='macro')
        metrics['roc_auc_ovo'] = roc_auc_score(y_true, y_probs, multi_class='ovo', average='macro')

    return metrics
243
+
244
+
245
+ # ============================================================
246
+ # VISUALIZATION FUNCTIONS
247
+ # ============================================================
248
+
249
def plot_confusion_matrix(y_true, y_pred, save_path):
    """Plot the confusion matrix as raw counts and row-normalised, and save it.

    The matrix is always built over every index of ``CLASS_NAMES`` so its
    shape matches the tick labels even if a class never occurs. Rows without
    any true samples are shown as zeros in the normalised view.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        save_path: Destination image path.

    Returns:
        The raw-count confusion matrix.
    """
    class_ids = list(range(len(CLASS_NAMES)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Raw counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')

    # Normalized by true-label row; skip the division where a row sums to zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm, row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Confusion matrix saved to: {save_path}")
    plt.close()

    return cm
278
+
279
+
280
def plot_roc_curve(y_true, y_probs, save_path):
    """Draw ROC curve(s) with their AUC and save the figure.

    With two classes a single curve is drawn from probability column 1;
    otherwise one one-vs-rest curve is drawn per entry of ``CLASS_NAMES``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    if NUM_CLASSES == 2:
        # Binary classification: score is the probability of class index 1.
        fpr, tpr, _ = roc_curve(y_true, y_probs[:, 1])
        ax.plot(fpr, tpr, color='darkorange', lw=2,
                label=f'ROC curve (AUC = {auc(fpr, tpr):.3f})')
    else:
        # Multi-class (one-vs-rest)
        for class_idx, class_name in enumerate(CLASS_NAMES):
            is_this_class = (y_true == class_idx).astype(int)
            fpr, tpr, _ = roc_curve(is_this_class, y_probs[:, class_idx])
            ax.plot(fpr, tpr, lw=2,
                    label=f'{class_name} (AUC = {auc(fpr, tpr):.3f})')

    # Chance-level diagonal, drawn after the curve(s) in both modes.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14, fontweight='bold')
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 ROC curve saved to: {save_path}")
    plt.close()
315
+
316
+
317
def plot_precision_recall_curve(y_true, y_probs, save_path):
    """Draw precision-recall curve(s) with average precision and save the figure.

    With two classes a single curve is drawn from probability column 1;
    otherwise one one-vs-rest curve is drawn per entry of ``CLASS_NAMES``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    if NUM_CLASSES != 2:
        # Multi-class
        for class_idx, class_name in enumerate(CLASS_NAMES):
            is_this_class = (y_true == class_idx).astype(int)
            class_scores = y_probs[:, class_idx]
            precision, recall, _ = precision_recall_curve(is_this_class, class_scores)
            avg_precision = average_precision_score(is_this_class, class_scores)
            ax.plot(recall, precision, lw=2,
                    label=f'{class_name} (AP = {avg_precision:.3f})')
    else:
        # Binary classification
        positive_scores = y_probs[:, 1]
        precision, recall, _ = precision_recall_curve(y_true, positive_scores)
        avg_precision = average_precision_score(y_true, positive_scores)
        ax.plot(recall, precision, color='darkorange', lw=2,
                label=f'PR curve (AP = {avg_precision:.3f})')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall', fontsize=12)
    ax.set_ylabel('Precision', fontsize=12)
    ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    ax.legend(loc="lower left", fontsize=10)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Precision-Recall curve saved to: {save_path}")
    plt.close()
350
+
351
+
352
def plot_class_distribution(y_true, y_pred, save_path):
    """Plot side-by-side bar charts of true and predicted label counts."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # (labels, bar colour, panel title) for the left and right panel.
    panels = (
        (y_true, 'steelblue', 'True Label Distribution'),
        (y_pred, 'coral', 'Predicted Label Distribution'),
    )

    for ax, (labels, bar_color, title) in zip(axes, panels):
        counts = [np.sum(labels == class_idx) for class_idx in range(NUM_CLASSES)]
        ax.bar(CLASS_NAMES, counts, color=bar_color, alpha=0.7)
        ax.set_ylabel('Count', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(axis='y', alpha=0.3)

        # Write each count just above its bar.
        label_offset = max(counts) * 0.01
        for position, count in enumerate(counts):
            ax.text(position, count + label_offset, str(count),
                    ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Class distribution saved to: {save_path}")
    plt.close()
380
+
381
+
382
def plot_per_class_metrics(metrics, save_path):
    """Plot grouped bars of precision, recall and F1 for every class.

    Args:
        metrics: Dict with a ``per_class`` mapping of class name to a dict
            holding ``precision``, ``recall`` and ``f1_score``.
        save_path: Destination image path.
    """
    per_class = metrics['per_class']
    classes = list(per_class.keys())

    x = np.arange(len(classes))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 7))

    # (metric key, legend label, colour, horizontal shift of the bar group)
    series = (
        ('precision', 'Precision', 'steelblue', -width),
        ('recall', 'Recall', 'coral', 0),
        ('f1_score', 'F1-Score', 'lightgreen', width),
    )
    bar_groups = []
    for key, legend_label, bar_color, shift in series:
        values = [per_class[name][key] for name in classes]
        bar_groups.append(
            ax.bar(x + shift, values, width, label=legend_label, color=bar_color, alpha=0.8)
        )

    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(classes)
    ax.legend(fontsize=11)
    ax.set_ylim([0, 1.1])
    ax.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for group in bar_groups:
        for bar in group:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Per-class metrics saved to: {save_path}")
    plt.close()
421
+
422
+
423
def plot_confidence_distribution(y_true, y_pred, confidences, save_path):
    """Compare confidence scores of correct and incorrect predictions.

    Draws an overlaid histogram (top) and a box plot (bottom), annotates the
    mean confidence of each group, and saves the figure to ``save_path``.
    """
    is_right = (y_true == y_pred)
    is_wrong = ~is_right
    right_conf = confidences[is_right]
    wrong_conf = confidences[is_wrong]
    has_errors = is_wrong.sum() > 0

    fig, axes = plt.subplots(2, 1, figsize=(12, 10))
    hist_ax, box_ax = axes[0], axes[1]

    # Histogram
    hist_ax.hist(right_conf, bins=50, alpha=0.7, label='Correct',
                 color='green', edgecolor='black')
    hist_ax.hist(wrong_conf, bins=50, alpha=0.7, label='Incorrect',
                 color='red', edgecolor='black')
    hist_ax.set_xlabel('Confidence Score', fontsize=12)
    hist_ax.set_ylabel('Frequency', fontsize=12)
    hist_ax.set_title('Confidence Distribution: Correct vs Incorrect Predictions',
                      fontsize=14, fontweight='bold')
    hist_ax.legend(fontsize=11)
    hist_ax.grid(alpha=0.3)

    # Box plot
    # NOTE(review): newer Matplotlib releases may warn about the `labels`
    # keyword of boxplot - confirm against the installed version.
    box = box_ax.boxplot([right_conf, wrong_conf], labels=['Correct', 'Incorrect'],
                         patch_artist=True, showmeans=True)
    for patch, face_color in zip(box['boxes'][:2], ('lightgreen', 'lightcoral')):
        patch.set_facecolor(face_color)
    box_ax.set_ylabel('Confidence Score', fontsize=12)
    box_ax.set_title('Confidence Score Box Plot', fontsize=14, fontweight='bold')
    box_ax.grid(axis='y', alpha=0.3)

    # Annotate the group means next to their boxes (positions 1 and 2).
    correct_mean = np.mean(right_conf)
    box_ax.text(1, correct_mean, f'μ={correct_mean:.3f}',
                ha='right', va='center', fontweight='bold', fontsize=10)
    if has_errors:
        incorrect_mean = np.mean(wrong_conf)
        box_ax.text(2, incorrect_mean, f'μ={incorrect_mean:.3f}',
                    ha='left', va='center', fontweight='bold', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Confidence distribution saved to: {save_path}")
    plt.close()
465
+
466
+
467
+ # ============================================================
468
+ # RESULTS SAVING
469
+ # ============================================================
470
+
471
def save_predictions_to_csv(image_paths, y_true, y_pred, y_probs, confidences, save_path):
    """Write one CSV row per image and print summary statistics.

    Each row holds the image path and file name, the true and predicted label
    (name and index), the confidence, a ``correct`` flag, and one
    ``prob_<class>`` column per entry of ``CLASS_NAMES``.

    Returns:
        The DataFrame that was written to ``save_path``.
    """
    rows = []
    samples = zip(image_paths, y_true, y_pred, y_probs, confidences)
    for path, actual, predicted, class_probs, score in samples:
        row = {
            'image_path': path,
            'image_name': os.path.basename(path),
            'true_label': CLASS_NAMES[actual],
            'true_label_idx': actual,
            'predicted_label': CLASS_NAMES[predicted],
            'predicted_label_idx': predicted,
            'confidence': score,
            'correct': predicted == actual,
        }
        # Add probabilities for each class
        row.update({
            f'prob_{class_name}': class_probs[col]
            for col, class_name in enumerate(CLASS_NAMES)
        })
        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(save_path, index=False)
    print(f"💾 Predictions saved to: {save_path}")

    # Print some statistics
    is_correct = df['correct']
    total = len(df)
    num_correct = is_correct.sum()
    num_wrong = (~is_correct).sum()

    print("\n📊 Prediction Statistics:")
    print(f" Total images: {total}")
    print(f" Correct predictions: {num_correct} ({num_correct/total*100:.2f}%)")
    print(f" Incorrect predictions: {num_wrong} ({num_wrong/total*100:.2f}%)")
    print(f" Average confidence: {df['confidence'].mean():.4f}")
    print(f" Confidence on correct: {df.loc[is_correct, 'confidence'].mean():.4f}")
    if num_wrong > 0:
        print(f" Confidence on incorrect: {df.loc[~is_correct, 'confidence'].mean():.4f}")
    else:
        # The original prints an empty line in this case; keep that output.
        print("")

    return df
507
+
508
+
509
def save_metrics_json(metrics, save_path):
    """Serialise ``metrics`` to ``save_path`` as JSON indented by 4 spaces."""
    with open(save_path, 'w') as out_file:
        json.dump(metrics, out_file, indent=4)
    print(f"💾 Metrics saved to: {save_path}")
514
+
515
+
516
def generate_classification_report_file(y_true, y_pred, save_path):
    """Write sklearn's text classification report, under a banner, to a file."""
    report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, digits=4)

    rule = "=" * 60
    header_lines = (
        rule + "\n",
        "CLASSIFICATION REPORT\n",
        rule + "\n\n",
    )

    with open(save_path, 'w') as out_file:
        for line in header_lines:
            out_file.write(line)
        out_file.write(report)

    print(f"📄 Classification report saved to: {save_path}")
527
+
528
+
529
+ # ============================================================
530
+ # MAIN EVALUATION PIPELINE
531
+ # ============================================================
532
+
533
def main():
    """Run the full evaluation: load, predict, score, plot and save results.

    Everything is written into ``EVAL_OUTPUT_DIR``.
    """
    rule = "=" * 60

    def out_file(name):
        # Location of an output artefact inside the evaluation folder.
        return os.path.join(EVAL_OUTPUT_DIR, name)

    print("\n" + rule)
    print("CvT-13 MODEL EVALUATION PIPELINE")
    print("Single Image Prediction Mode")
    print(rule + "\n")

    # Load model
    print("📦 Loading model...")
    model = load_model(MODEL_PATH)

    # Load test data
    print("\n📂 Loading test data...")
    test_loader, test_dataset = get_test_dataloader(TEST_DATA_DIR, batch_size=1)

    # Run evaluation with single image predictions
    print("\n🔍 Evaluating model (single image predictions)...")
    y_pred, y_true, y_probs, confidences, image_paths = evaluate_model(model, test_loader, test_dataset)

    # Calculate metrics
    print("\n📊 Calculating metrics...")
    metrics = calculate_metrics(y_true, y_pred, y_probs)

    # Print key metrics
    print("\n" + rule)
    print("EVALUATION RESULTS")
    print(rule)
    print(f"Total Images Evaluated: {len(y_pred)}")
    print(f"Accuracy: {metrics['accuracy']*100:.2f}%")
    print(f"Precision (Macro): {metrics['precision_macro']*100:.2f}%")
    print(f"Recall (Macro): {metrics['recall_macro']*100:.2f}%")
    print(f"F1-Score (Macro): {metrics['f1_macro']*100:.2f}%")
    if 'roc_auc' in metrics:
        print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
    print("\nPer-Class Metrics:")
    for class_name, class_metrics in metrics['per_class'].items():
        print(f" {class_name}:")
        print(f" Precision: {class_metrics['precision']*100:.2f}%")
        print(f" Recall: {class_metrics['recall']*100:.2f}%")
        print(f" F1-Score: {class_metrics['f1_score']*100:.2f}%")
    print(rule)

    # Generate all visualizations
    print("\n📊 Generating visualizations...")
    plot_confusion_matrix(y_true, y_pred, out_file("confusion_matrix.png"))
    plot_roc_curve(y_true, y_probs, out_file("roc_curve.png"))
    plot_precision_recall_curve(y_true, y_probs, out_file("precision_recall_curve.png"))
    plot_class_distribution(y_true, y_pred, out_file("class_distribution.png"))
    plot_per_class_metrics(metrics, out_file("per_class_metrics.png"))
    plot_confidence_distribution(y_true, y_pred, confidences,
                                 out_file("confidence_distribution.png"))

    # Save results
    print("\n💾 Saving results...")
    df = save_predictions_to_csv(image_paths, y_true, y_pred, y_probs, confidences,
                                 out_file("predictions.csv"))
    save_metrics_json(metrics, out_file("metrics.json"))
    generate_classification_report_file(y_true, y_pred, out_file("classification_report.txt"))

    # Save misclassified images list (only when there is at least one)
    misclassified = df[~df['correct']]
    if len(misclassified) > 0:
        misclassified_path = out_file("misclassified_images.csv")
        misclassified.to_csv(misclassified_path, index=False)
        print(f"⚠️ Misclassified images saved to: {misclassified_path}")
        print(f" Total misclassified: {len(misclassified)}")

    # Save low confidence predictions (only when there is at least one)
    low_conf_threshold = 0.7
    low_confidence = df[df['confidence'] < low_conf_threshold]
    if len(low_confidence) > 0:
        low_conf_path = out_file("low_confidence_predictions.csv")
        low_confidence.to_csv(low_conf_path, index=False)
        print(f"⚠️ Low confidence predictions saved to: {low_conf_path}")
        print(f" Total with confidence < {low_conf_threshold}: {len(low_confidence)}")

    print("\n" + rule)
    print("✅ Evaluation complete!")
    print(f"📁 All results saved to: {EVAL_OUTPUT_DIR}")
    print(rule + "\n")

    # Static overview of the artefacts this pipeline can produce.
    generated_files = (
        " 📊 confusion_matrix.png - Confusion matrix visualization",
        " 📊 roc_curve.png - ROC curve",
        " 📊 precision_recall_curve.png - Precision-Recall curve",
        " 📊 class_distribution.png - Class distribution comparison",
        " 📊 per_class_metrics.png - Per-class performance",
        " 📊 confidence_distribution.png - Confidence analysis",
        " 💾 predictions.csv - Detailed predictions for each image",
        " 💾 misclassified_images.csv - List of incorrectly classified images",
        " 💾 low_confidence_predictions.csv - Predictions with low confidence",
        " 💾 metrics.json - All metrics in JSON format",
        " 📄 classification_report.txt - Sklearn classification report",
    )
    print("Generated files:")
    for description in generated_files:
        print(description)


if __name__ == '__main__':
    main()
stage3/stage3_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:445c5a5b94b86649cab12ef3c2fe4df9461f9879864c43d52a7cc9560204fcc3
3
+ size 78733538
stage3/train_cvt13.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ import sys
7
+ import os
8
+ from timm.loss import SoftTargetCrossEntropy
9
+ from timm.scheduler import CosineLRScheduler
10
+ from timm.utils import accuracy
11
+ import matplotlib.pyplot as plt
12
+ import json
13
+ from datetime import datetime
14
+
15
+
16
# ============================================================
# SETUP: Clone and import from Microsoft CvT repository
# ============================================================
# Prerequisite (run once):
#     git clone https://github.com/microsoft/CvT.git
#     cd CvT
#     pip install -r requirements.txt

BASE_DIR = "path_to_CornViT"

# Location of the cloned CvT repository; it is put on sys.path further down.
CVT_REPO_PATH = f"{BASE_DIR}/CvT"

if not os.path.exists(CVT_REPO_PATH):
    print(f"❌ CvT repository not found at {CVT_REPO_PATH}")
    print("Please clone it: git clone https://github.com/microsoft/CvT.git")
    sys.exit(1)

# The upstream model file must be patched BEFORE it is imported.
print("Applying compatibility fixes for newer PyTorch versions...")
cls_cvt_path = os.path.join(CVT_REPO_PATH, "lib", "models", "cls_cvt.py")

if not os.path.exists(cls_cvt_path):
    print(f"❌ Could not find cls_cvt.py at {cls_cvt_path}")
    sys.exit(1)

with open(cls_cvt_path, 'r', encoding='utf-8') as f:
    content = f.read()

# (old text, new text) pairs. Each patch is applied on its own so that a file
# which already has one fix still receives the other.
_COMPAT_PATCHES = (
    # Fix 1: replace the torch._six import
    ("from torch._six import container_abcs",
     "import collections.abc as container_abcs"),
    # Fix 2: replace 'is' with '==' for string comparison
    ("or pretrained_layers[0] is '*'",
     "or pretrained_layers[0] == '*'"),
)

_patched_content = content
for _old_text, _new_text in _COMPAT_PATCHES:
    _patched_content = _patched_content.replace(_old_text, _new_text)

if _patched_content != content:
    # Only rewrite the third-party file when something actually changed.
    content = _patched_content
    with open(cls_cvt_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("✅ Applied compatibility patches to cls_cvt.py")
else:
    print("✅ Compatibility patches already applied")

# Now import: the repository root must be on sys.path for `lib.*` to resolve.
sys.path.insert(0, CVT_REPO_PATH)

# Suppress the SyntaxWarning
import warnings
warnings.filterwarnings('ignore', category=SyntaxWarning)

from lib.models import cls_cvt
from lib.config import config, update_config
print("✅ Successfully imported Microsoft CvT models")
76
+
77
+
78
# ============================================================
# CONFIGURATION
# ============================================================

# Input data and pretrained checkpoint, both resolved relative to BASE_DIR.
DATA_DIR = f"{BASE_DIR}/stage3/data"
PRETRAINED_PATH = f"{BASE_DIR}/CvT-13-384x384-IN-22k.pth"

# Training hyper-parameters.
BATCH_SIZE = 32
IMG_SIZE = 384
NUM_CLASSES = 2
NUM_EPOCHS = 100

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Every run writes into its own timestamped directory under ./metrics.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = f"metrics/cvt13_run_{timestamp}"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Metrics will be saved to: {OUTPUT_DIR}")
95
+
96
+
97
# ============================================================
# DATASET & AUGMENTATION
# ============================================================

# Channel statistics shared by the training and validation pipelines.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize plus random flips, colour jitter and rotation.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Validation pipeline: deterministic resize and normalisation only.
val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

train_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transforms)
val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=val_transforms)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    # Keep the final partial batch: dropping it would silently exclude
    # validation images from the reported metrics, and would leave the loader
    # empty when the validation split is smaller than BATCH_SIZE.
    drop_last=False,
)
138
+
139
+
140
# ============================================================
# MODEL SETUP - Using Microsoft CvT Implementation
# ============================================================

# The CvT-13 (384x384) experiment file lives in the cloned repository.
config_dir = os.path.join(CVT_REPO_PATH, "experiments", "imagenet", "cvt")
cvt_config_path = os.path.join(config_dir, "cvt-13-384x384.yaml")

if not os.path.exists(cvt_config_path):
    # Help the user by listing whichever YAML configs are present, then stop.
    print(f"⚠️ Config file not found at {cvt_config_path}")
    print("Available configs:")
    if os.path.exists(config_dir):
        yaml_files = [entry for entry in os.listdir(config_dir) if entry.endswith('.yaml')]
        for f in yaml_files:
            print(f" - {f}")
    sys.exit(1)

print(f"Loading config from: {cvt_config_path}")

# The shared config object has to be unfrozen before it can be modified.
config.defrost()
config.merge_from_file(cvt_config_path)
config.MODEL.NUM_CLASSES = NUM_CLASSES  # size the classifier for this task
config.MODEL.PRETRAINED = ''  # We'll load weights manually
config.freeze()

print("Creating CvT-13 model...")
# Build the network from the official CvT code and move it to the target device.
model = cls_cvt.get_cls_model(config).to(DEVICE)
172
+
173
# Load pretrained weights (best effort: any failure falls back to the
# randomly initialised model instead of aborting the run).
if os.path.exists(PRETRAINED_PATH):
    print(f"Loading pretrained weights from {PRETRAINED_PATH}")
    try:
        checkpoint = torch.load(PRETRAINED_PATH, map_location=DEVICE)

        # Handle different checkpoint formats: the weights may be nested under
        # a 'model' or 'state_dict' key, or be the top-level mapping itself.
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Strip a leading 'module.' (as added by DataParallel-style wrappers).
        # Only the prefix is removed, so a key that merely contains the text
        # 'module.' further in is left untouched.
        _WRAPPER_PREFIX = "module."
        new_state_dict = {
            (k[len(_WRAPPER_PREFIX):] if k.startswith(_WRAPPER_PREFIX) else k): v
            for k, v in state_dict.items()
        }

        # Remove head layers from pretrained weights (they have different dimensions)
        filtered_state_dict = {k: v for k, v in new_state_dict.items() if 'head' not in k}

        # Load weights - strict=False will only load matching layers
        missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

        # Count how many weights were actually loaded. The model's key set is
        # built once rather than once per checkpoint entry.
        model_keys = set(model.state_dict())
        loaded_keys = [k for k in filtered_state_dict if k in model_keys]
        print(f"✅ Loaded pretrained weights: {len(loaded_keys)} layers from backbone")
        print(f" Head layer initialized randomly for {NUM_CLASSES} classes")

        # Missing head keys are expected (filtered out above); anything else
        # missing means part of the backbone was not initialised.
        other_missing = [k for k in missing_keys if 'head' not in k]

        if other_missing:
            print(f"⚠️ Warning - Missing non-head keys: {other_missing}")
        if unexpected_keys:
            print(f"⚠️ Unexpected keys: {unexpected_keys}")

    except Exception as e:
        # Deliberately broad: report the full traceback and keep going.
        print(f"⚠️ Error loading pretrained weights: {e}")
        import traceback
        traceback.print_exc()
        print("Continuing with random initialization...")
else:
    print(f"⚠️ Pretrained weights not found at {PRETRAINED_PATH}")
    print("Training from scratch...")
222
# Freeze backbone - every parameter whose name does not contain "head" stops
# receiving gradients, so only the classification head is trained.
print("Freezing backbone layers (keeping only head trainable)...")
for param_name, parameter in model.named_parameters():
    if "head" in param_name:
        continue
    parameter.requires_grad = False

# Tally parameter counts in a single pass for the summary below.
total_params = 0
trainable_params = 0
for parameter in model.parameters():
    element_count = parameter.numel()
    total_params += element_count
    if parameter.requires_grad:
        trainable_params += element_count

print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
232
+
233
+
234
# ============================================================
# OPTIMIZER AND LOSS
# ============================================================

# Hand the optimizer only the parameters that still require gradients.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_parameters, lr=1e-4, weight_decay=0.05)

# NOTE(review): timm documents this loss for dense (soft / one-hot) targets -
# confirm the training loop supplies targets in that form.
criterion = SoftTargetCrossEntropy()

# Cosine schedule over the whole run with a 5-step warm-up.
scheduler_options = dict(
    t_initial=NUM_EPOCHS,
    lr_min=1e-6,
    warmup_t=5,
    warmup_lr_init=1e-5,
)
lr_scheduler = CosineLRScheduler(optimizer, **scheduler_options)
253
+
254
+
255
+ # ============================================================
256
+ # TRAINING & VALIDATION LOOP
257
+ # ============================================================
258
+
259
def train_one_epoch(epoch, history):
    """Run one optimisation pass over ``train_loader`` and record the metrics.

    Args:
        epoch: Zero-based epoch index; only used in the progress message.
        history: Dict of metric lists. One value is appended to each of
            ``'train_loss'``, ``'train_acc'`` and ``'learning_rate'``.

    Returns:
        Tuple ``(avg_loss, avg_acc)``: the loss and the top-1 accuracy (in
        percent), each averaged over the batches of the epoch.
    """
    model.train()
    total_loss, total_acc = 0, 0

    for images, targets in train_loader:
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)

        # The soft-target loss needs a dense (batch, classes) target, while
        # the accuracy metric needs class indices. The loader yields 1-D
        # integer labels, so derive whichever form is missing; 2-D targets
        # (already soft / one-hot) are still accepted unchanged.
        if targets.dim() == 1:
            hard_targets = targets
            soft_targets = nn.functional.one_hot(
                targets, num_classes=outputs.size(1)
            ).to(outputs.dtype)
        else:
            soft_targets = targets
            hard_targets = targets.argmax(dim=1)

        loss = criterion(outputs, soft_targets)
        loss.backward()
        optimizer.step()

        # Only top-1 is requested: top-5 is meaningless for a 2-class head.
        acc1 = accuracy(outputs, hard_targets, topk=(1,))[0]
        total_loss += loss.item()
        total_acc += acc1.item()

    avg_loss = total_loss / len(train_loader)
    avg_acc = total_acc / len(train_loader)
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(avg_loss)
    history['train_acc'].append(avg_acc)
    history['learning_rate'].append(current_lr)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Train Loss: {avg_loss:.4f} | Train Acc: {avg_acc:.2f}% | LR: {current_lr:.6f}")
    return avg_loss, avg_acc
285
+
286
+
287
def validate(epoch, history):
    """Evaluate the model on ``val_loader`` and record the metrics.

    Loss and accuracy are weighted by batch size, so the result is the true
    per-sample average even when the last batch is smaller than the others.

    Args:
        epoch: Zero-based epoch index; only used in the progress message.
        history: Dict of metric lists. One value is appended to each of
            ``'val_loss'`` and ``'val_acc'``.

    Returns:
        The top-1 validation accuracy in percent.

    Raises:
        RuntimeError: If ``val_loader`` yields no samples.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()  # built once, not once per batch
    loss_sum, acc_sum, sample_count = 0.0, 0.0, 0

    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            outputs = model(images)
            batch_size = targets.size(0)

            loss = loss_fn(outputs, targets)
            # Only top-1 is requested: top-5 is meaningless for a 2-class head.
            acc1 = accuracy(outputs, targets, topk=(1,))[0]

            # Both values are batch means; scale back to per-sample sums.
            loss_sum += loss.item() * batch_size
            acc_sum += acc1.item() * batch_size
            sample_count += batch_size

    if sample_count == 0:
        raise RuntimeError(
            "Validation loader yielded no samples; check the validation data "
            "directory and the loader's batch_size/drop_last settings."
        )

    avg_loss = loss_sum / sample_count
    avg_acc = acc_sum / sample_count

    history['val_loss'].append(avg_loss)
    history['val_acc'].append(avg_acc)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Val Loss: {avg_loss:.4f} | Val Acc: {avg_acc:.2f}%")
    return avg_acc
309
+
310
+
311
def plot_training_history(history, save_path):
    """Plot the recorded training metrics on a 2x2 grid and save the figure.

    Panels: loss, accuracy, learning rate (log scale) and an overfitting view
    that shades the area between the train and validation accuracy curves.

    Args:
        history: Dict with the lists ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'``, ``'val_acc'`` and ``'learning_rate'``, one entry
            per completed epoch.
        save_path: File path the figure is written to.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    epochs = range(1, len(history['train_loss']) + 1)

    def _decorate(ax, ylabel, title, legend=True):
        """Apply the labels, title, optional legend and grid shared by all panels."""
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        if legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Plot 1: Loss
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    _decorate(loss_ax, 'Loss', 'Training and Validation Loss')

    # Plot 2: Accuracy
    acc_ax = axes[0, 1]
    acc_ax.plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
    acc_ax.plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
    _decorate(acc_ax, 'Accuracy (%)', 'Training and Validation Accuracy')

    # Plot 3: Learning Rate
    lr_ax = axes[1, 0]
    lr_ax.plot(epochs, history['learning_rate'], 'g-', linewidth=2)
    lr_ax.set_yscale('log')
    _decorate(lr_ax, 'Learning Rate', 'Learning Rate Schedule', legend=False)

    # Plot 4: Val Acc vs Train Acc (Overfitting check)
    gap_ax = axes[1, 1]
    gap_ax.plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
    gap_ax.plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
    gap_ax.fill_between(epochs, history['val_acc'], history['train_acc'],
                        alpha=0.3, color='orange', label='Overfitting Gap')
    _decorate(gap_ax, 'Accuracy (%)', 'Overfitting Analysis')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"📊 Training plots saved to: {save_path}")
    # Close the figure so repeated calls during training do not pile up.
    plt.close()
359
+
360
+
361
def save_training_summary(history, best_acc, save_path):
    """Write the run configuration, final metrics and full history as JSON.

    Args:
        history: Dict of per-epoch metric lists; the last entry of
            ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
            ``'val_acc'`` is reported as the final value.
        best_acc: Best validation accuracy reached during the run.
        save_path: Destination path of the JSON file.
    """
    run_config = {
        'model': 'CvT-13',
        'batch_size': BATCH_SIZE,
        'img_size': IMG_SIZE,
        'num_classes': NUM_CLASSES,
        'num_epochs': NUM_EPOCHS,
        'device': DEVICE,
        'pretrained': PRETRAINED_PATH,
    }

    # Most recent value of each tracked metric.
    latest = {
        metric: history[metric][-1]
        for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc')
    }
    final_metrics = {
        'best_val_accuracy': best_acc,
        'final_train_loss': latest['train_loss'],
        'final_train_acc': latest['train_acc'],
        'final_val_loss': latest['val_loss'],
        'final_val_acc': latest['val_acc'],
    }

    summary = {
        'config': run_config,
        'final_metrics': final_metrics,
        'history': history,
    }

    with open(save_path, 'w') as summary_file:
        json.dump(summary, summary_file, indent=4)

    print(f"💾 Training summary saved to: {save_path}")
387
+
388
+
389
+ # ============================================================
390
+ # MAIN TRAINING LOOP
391
+ # ============================================================
392
+
393
if __name__ == '__main__':
    banner = "=" * 60

    print("\n" + banner)
    print("STARTING TRAINING")
    print(banner + "\n")

    # One list per tracked metric, appended to by the train/validate helpers.
    history = {
        metric: []
        for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'learning_rate')
    }

    best_acc = 0.0
    best_epoch = 0

    for epoch in range(NUM_EPOCHS):
        epoch_number = epoch + 1  # one-based, for the scheduler and file names

        train_loss, train_acc = train_one_epoch(epoch, history)
        val_acc = validate(epoch, history)
        lr_scheduler.step(epoch_number)

        # Keep the weights of the best-scoring epoch seen so far.
        if val_acc > best_acc:
            best_acc, best_epoch = val_acc, epoch_number
            best_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'history': history,
            }
            torch.save(best_state, os.path.join(OUTPUT_DIR, "best_model.pth"))
            print(f"✅ Saved best model at epoch {epoch+1} with val acc {best_acc:.2f}%\n")

        # Periodic checkpoint every 10 epochs.
        if epoch_number % 10 == 0:
            periodic_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'history': history,
            }
            torch.save(periodic_state, os.path.join(OUTPUT_DIR, f"checkpoint_epoch_{epoch+1}.pth"))
            print(f"💾 Checkpoint saved at epoch {epoch+1}\n")

        # Refresh the metrics plot every 5 epochs and after the last epoch.
        if epoch_number % 5 == 0 or epoch_number == NUM_EPOCHS:
            plot_training_history(history, os.path.join(OUTPUT_DIR, "training_metrics.png"))

    # Final summary
    print(banner)
    print("🎉 Training complete!")
    print(f"Best validation accuracy: {best_acc:.2f}% at epoch {best_epoch}")
    print(f"Final train accuracy: {history['train_acc'][-1]:.2f}%")
    print(f"Final val accuracy: {history['val_acc'][-1]:.2f}%")
    print(banner)

    # Persist the JSON summary and the final version of the plots.
    save_training_summary(history, best_acc, os.path.join(OUTPUT_DIR, "training_summary.json"))
    plot_training_history(history, os.path.join(OUTPUT_DIR, "final_training_metrics.png"))

    print(f"\n📁 All outputs saved to: {OUTPUT_DIR}")