Spaces:
Sleeping
Sleeping
Upload 27 files
Browse files
commit to add initial codes including testing codes
- backend/__init__.py +0 -0
- backend/__pycache__/main.cpython-310.pyc +0 -0
- backend/__pycache__/model_loader.cpython-310.pyc +0 -0
- backend/__pycache__/predict.cpython-310.pyc +0 -0
- backend/__pycache__/schemas.cpython-310.pyc +0 -0
- backend/main.py +22 -0
- backend/model_loader.py +42 -0
- backend/predict.py +19 -0
- backend/schemas.py +8 -0
- checkpoints/best_model.pth +3 -0
- requirements.txt +9 -0
- scripts/01_merge_datasets.py +56 -0
- scripts/02_resize_images.py +30 -0
- scripts/03_create_metadata.py +22 -0
- training/__pycache__/dataloader.cpython-310.pyc +0 -0
- training/__pycache__/dataset.cpython-310.pyc +0 -0
- training/__pycache__/model.cpython-310.pyc +0 -0
- training/__pycache__/utils.cpython-310.pyc +0 -0
- training/dataloader.py +47 -0
- training/dataset.py +48 -0
- training/evaluate.py +51 -0
- training/model.py +23 -0
- training/test_dataset.py +13 -0
- training/test_loader.py +12 -0
- training/test_model.py +26 -0
- training/train.py +160 -0
- training/utils.py +13 -0
backend/__init__.py
ADDED
|
File without changes
|
backend/__pycache__/main.cpython-310.pyc
ADDED
|
Binary file (787 Bytes). View file
|
|
|
backend/__pycache__/model_loader.cpython-310.pyc
ADDED
|
Binary file (1.4 kB). View file
|
|
|
backend/__pycache__/predict.cpython-310.pyc
ADDED
|
Binary file (695 Bytes). View file
|
|
|
backend/__pycache__/schemas.cpython-310.pyc
ADDED
|
Binary file (449 Bytes). View file
|
|
|
backend/main.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, UploadFile, File
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import io
|
| 4 |
+
|
| 5 |
+
from backend.predict import predict_image
|
| 6 |
+
from backend.schemas import PredictionResponse
|
| 7 |
+
|
| 8 |
+
app = FastAPI(
    title="X-Ray Classification API",
    description="Deep learning based multi-class X-ray classifier",
    version="1.0",
)


@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded X-ray image.

    Args:
        file: Multipart upload containing the image bytes.

    Returns:
        The value returned by ``predict_image`` for the decoded image,
        serialized through ``PredictionResponse``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # Function-scope imports keep this block self-contained; both packages
    # are already dependencies of this module.
    from fastapi import HTTPException
    from PIL import UnidentifiedImageError

    image_bytes = await file.read()

    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open only reads the header; load() forces a full decode so
        # that truncated or corrupt data is rejected here instead of
        # surfacing later as an unhandled 500.
        image.load()
    except (UnidentifiedImageError, OSError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    result = predict_image(image)
    return result
|
backend/model_loader.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from training.model import build_model
|
| 4 |
+
from training.utils import get_device
|
| 5 |
+
|
| 6 |
+
CHECKPOINT_PATH = "checkpoints/best_model.pth"
|
| 7 |
+
LABEL_MAP_PATH = "data_processed/label_map.csv"
|
| 8 |
+
|
| 9 |
+
# Wrapper class for loading the model and making predictions on new data instances
|
| 10 |
+
class ModelWrapper:
    """Load the trained classifier once and run single-image inference.

    Attributes set by ``__init__``:
        device: Device returned by ``get_device()``.
        id_to_label: Mapping from ``label_id`` to ``label`` read from the
            label-map CSV at ``LABEL_MAP_PATH``.
        model: Network built by ``build_model`` with the weights stored at
            ``CHECKPOINT_PATH``, switched to eval mode.
    """

    def __init__(self):
        self.device = get_device()

        label_df = pd.read_csv(LABEL_MAP_PATH)
        # Coerce to plain int/str: pandas yields numpy scalars, and purely
        # numeric class names would otherwise be parsed as integers and fail
        # the ``label_name: str`` field of the response schema.
        self.id_to_label = {
            int(label_id): str(label)
            for label_id, label in zip(label_df["label_id"], label_df["label"])
        }

        num_classes = len(self.id_to_label)

        self.model = build_model(num_classes, self.device)
        # The checkpoint is loaded with load_state_dict, i.e. it is a plain
        # tensor mapping, so weights_only=True suffices and avoids unpickling
        # arbitrary objects from the file.
        state_dict = torch.load(
            CHECKPOINT_PATH,
            map_location=self.device,
            weights_only=True,
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, image_tensor):
        """Return the top-1 prediction for one preprocessed image.

        Args:
            image_tensor: Model input holding exactly one sample along
                dim 0; ``.item()`` below fails for larger batches.

        Returns:
            dict with ``label_id`` (int), ``label_name`` (looked up in
            ``id_to_label``) and ``confidence`` (softmax probability of the
            predicted class, as float).
        """
        with torch.no_grad():
            image_tensor = image_tensor.to(self.device)
            outputs = self.model(image_tensor)
            probs = torch.softmax(outputs, dim=1)

            confidence, pred_id = torch.max(probs, dim=1)

        label_id = int(pred_id.item())
        return {
            "label_id": label_id,
            "label_name": self.id_to_label[label_id],
            "confidence": float(confidence.item()),
        }
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
model_wrapper = ModelWrapper()
|
backend/predict.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from backend.model_loader import model_wrapper
|
| 5 |
+
|
| 6 |
+
# Defining the image transformation pipeline for preprocessing
|
| 7 |
+
# Inference-time preprocessing: fixed 224x224 resize, tensor conversion and
# per-channel normalization. No random augmentation is applied here.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def predict_image(image: Image.Image):
    """Preprocess a PIL image and classify it with the shared model wrapper.

    Args:
        image: PIL image in any mode; it is converted to RGB first.

    Returns:
        Whatever ``model_wrapper.predict`` returns for the one-image batch.
    """
    rgb_image = image.convert("RGB")
    pixels = transform(rgb_image)
    # Add the leading batch dimension the model expects.
    batch = torch.unsqueeze(pixels, 0)
    return model_wrapper.predict(batch)
|
backend/schemas.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Module that defines the schema for prediction responses on the API.
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PredictionResponse(BaseModel):
    """Schema of the JSON body returned by the prediction endpoint."""

    # Integer id of the predicted class.
    label_id: int
    # Class name associated with label_id.
    label_name: str
    # Score attached to the predicted class.
    confidence: float
|
checkpoints/best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d83d0507e2b50d687569e7ccab21287169f4a13ba3cd7441437a2001cefa82b6
|
| 3 |
+
size 94842430
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
fastapi
|
| 3 |
+
torch==2.1.0
|
| 4 |
+
torchvision==0.16.0
|
| 5 |
+
pillow
|
| 6 |
+
numpy
|
| 7 |
+
pandas
|
| 8 |
+
scikit-learn
|
| 9 |
+
python-multipart
|
scripts/01_merge_datasets.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import csv
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
RAW_ROOT = Path("raw_data")
|
| 7 |
+
OUT_IMG = Path("data_merged/images")
|
| 8 |
+
OUT_CSV = Path("data_merged/metadata_raw.csv")
|
| 9 |
+
|
| 10 |
+
OUT_IMG.mkdir(parents=True, exist_ok=True)
|
| 11 |
+
|
| 12 |
+
rows = []
|
| 13 |
+
img_id = 0
|
| 14 |
+
VALID_EXT = (".png", ".jpg", ".jpeg")
|
| 15 |
+
|
| 16 |
+
def merge_any_dataset(dataset_name, base_path):
    """Copy every image under ``base_path`` into ``OUT_IMG`` and log it.

    Each file whose name ends with one of ``VALID_EXT`` is copied under the
    name ``<dataset>__<class>__<id><suffix>``, where the class is the name of
    the directory that directly contains the file. One metadata dict per
    image is appended to the module-level ``rows`` list and the module-level
    ``img_id`` counter is advanced.

    Args:
        dataset_name: Label stored in the ``source`` column and used as the
            filename prefix.
        base_path: Root directory of the dataset to walk.
    """
    global img_id
    for root, dirs, files in os.walk(base_path):
        # os.walk yields entries in filesystem order, which varies between
        # machines. Sorting (in place for dirs, so the walk itself follows
        # the sorted order) makes the id assignment reproducible.
        dirs.sort()
        class_name = Path(root).name

        for f in sorted(files):
            if not f.lower().endswith(VALID_EXT):
                continue

            src = Path(root) / f
            new_name = f"{dataset_name}__{class_name}__{img_id}{src.suffix}"
            dst = OUT_IMG / new_name

            shutil.copy(src, dst)

            rows.append({
                "image_id": img_id,
                "filename": new_name,
                "label": class_name,
                "source": dataset_name,
            })

            img_id += 1
|
| 39 |
+
|
| 40 |
+
# Merging all datasets found in RAW_ROOT
|
| 41 |
+
# Path.iterdir() returns entries in arbitrary order; sorting keeps the
# image_id sequence stable from run to run.
for item in sorted(RAW_ROOT.iterdir()):
    if item.is_dir():
        merge_any_dataset(item.name, item)

# Write one CSV row per copied image.
OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
with open(OUT_CSV, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=["image_id", "filename", "label", "source"],
    )
    writer.writeheader()
    writer.writerows(rows)

print("Merged dataset created")
print("Images:", len(rows))
|
scripts/02_resize_images.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
IN_IMG = Path("data_merged/images")
OUT_IMG = Path("data_processed/images")
IN_CSV = Path("data_merged/metadata_raw.csv")
OUT_CSV = Path("data_processed/metadata_resized.csv")

OUT_IMG.mkdir(parents=True, exist_ok=True)

df = pd.read_csv(IN_CSV)
kept_rows = []

# Resize every listed image to 224x224 and keep only the metadata rows whose
# image was both readable and successfully written.
for _, row in df.iterrows():
    src = IN_IMG / row["filename"]
    dst = OUT_IMG / row["filename"]

    img = cv2.imread(str(src))
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        continue

    img = cv2.resize(img, (224, 224))
    # imwrite reports failure through its return value instead of raising;
    # a row whose image was not written must not reach the metadata.
    if not cv2.imwrite(str(dst), img):
        continue

    kept_rows.append(row)

# Passing the columns explicitly keeps the CSV header even when no image
# survived, so the next pipeline step can still parse the file.
pd.DataFrame(kept_rows, columns=df.columns).to_csv(OUT_CSV, index=False)

print("Images kept:", len(kept_rows))
print("Images skipped:", len(df) - len(kept_rows))
|
scripts/03_create_metadata.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
IN_CSV = Path("data_processed/metadata_resized.csv")
OUT_CSV = Path("data_processed/metadata_final.csv")
MAP_CSV = Path("data_processed/label_map.csv")

df = pd.read_csv(IN_CSV)

# Ids follow the sorted order of the distinct label strings.
labels = sorted(df["label"].unique())
label_to_id = dict(zip(labels, range(len(labels))))

# Final metadata: the resized metadata plus a numeric label_id column.
df["label_id"] = df["label"].map(label_to_id)
df.to_csv(OUT_CSV, index=False)

# Label map: one row per class, in the same order as the ids.
label_map = pd.DataFrame(
    {
        "label": labels,
        "label_id": list(range(len(labels))),
    }
)
label_map.to_csv(MAP_CSV, index=False)

print("Total classes:", len(labels))
|
training/__pycache__/dataloader.cpython-310.pyc
ADDED
|
Binary file (809 Bytes). View file
|
|
|
training/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (1.61 kB). View file
|
|
|
training/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (690 Bytes). View file
|
|
|
training/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (549 Bytes). View file
|
|
|
training/dataloader.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Module for creating data loaders for training and validation datasets.
|
| 2 |
+
from torch.utils.data import DataLoader, random_split
|
| 3 |
+
from dataset import XRayDataset
|
| 4 |
+
|
| 5 |
+
def get_dataloaders(
    csv_path,
    images_dir,
    batch_size=32,
    val_split=0.2,
    seed=42,
):
    """Build the training and validation data loaders.

    Two ``XRayDataset`` instances are created over the same CSV: one with the
    training transform and one with the evaluation transform. Each subset
    indexes into its own instance. Assigning a new transform through
    ``val_ds.dataset.transform`` would instead change the single dataset
    shared by both subsets and switch augmentation off for training as well.

    Args:
        csv_path: Metadata CSV passed to ``XRayDataset``.
        images_dir: Image directory passed to ``XRayDataset``.
        batch_size: Batch size of both loaders.
        val_split: Fraction of samples assigned to validation.
        seed: Seed for the split permutation. A fixed value makes separate
            runs (for example training and evaluation) see the same
            validation samples. Pass ``None`` for a fresh random split.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # Function-scope imports keep this block self-contained.
    import torch
    from torch.utils.data import Subset

    train_dataset = XRayDataset(
        csv_path=csv_path,
        images_dir=images_dir,
        train=True,
    )
    eval_dataset = XRayDataset(
        csv_path=csv_path,
        images_dir=images_dir,
        train=False,
    )

    total = len(train_dataset)
    val_size = int(total * val_split)

    generator = torch.Generator()
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    order = torch.randperm(total, generator=generator).tolist()

    # Disjoint index sets: the first val_size shuffled indices validate,
    # the rest train.
    val_ds = Subset(eval_dataset, order[:val_size])
    train_ds = Subset(train_dataset, order[val_size:])

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )

    return train_loader, val_loader
|
training/dataset.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
# Custom Dataset for X-Ray Images with Augmentations for Training and Standard Transformations for Validation
|
| 8 |
+
|
| 9 |
+
class XRayDataset(Dataset):
    """Image dataset backed by a metadata CSV.

    The CSV must provide a ``filename`` column (resolved against
    ``images_dir``) and a ``label_id`` column. With ``train=True`` the
    transform adds random flip, rotation, resized crop and colour jitter;
    otherwise only resize, tensor conversion and normalization are applied.
    """

    def __init__(self, csv_path, images_dir, train=True):
        self.df = pd.read_csv(csv_path)
        self.images_dir = Path(images_dir)
        self.transform = self._build_transform(train)

    @staticmethod
    def _build_transform(train):
        """Assemble the torchvision pipeline for the requested mode."""
        steps = [transforms.Resize((224, 224))]
        if train:
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(15),
                transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_id)`` for row ``idx``."""
        record = self.df.iloc[idx]
        target = record["label_id"]
        picture = Image.open(self.images_dir / record["filename"]).convert("RGB")
        return self.transform(picture), target
|
training/evaluate.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import seaborn as sns
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from sklearn.metrics import confusion_matrix, classification_report
|
| 6 |
+
|
| 7 |
+
from model import build_model
|
| 8 |
+
from dataloader import get_dataloaders
|
| 9 |
+
from utils import get_device
|
| 10 |
+
|
| 11 |
+
CSV_PATH = "data_processed/metadata_final.csv"
IMG_DIR = "data_processed/images"
CHECKPOINT_PATH = "checkpoints/best_model.pth"

device = get_device()

df = pd.read_csv(CSV_PATH)
num_classes = df["label_id"].nunique()

model = build_model(num_classes, device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
# without it torch.load tries to restore tensors to their original device.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()

# NOTE: the validation subset comes from get_dataloaders. Unless that split
# is seeded it differs from the one used in training, in which case these
# metrics can include images the model was trained on.
_, val_loader = get_dataloaders(
    csv_path=CSV_PATH,
    images_dir=IMG_DIR,
    batch_size=32,
)

y_true, y_pred = [], []

# Collect top-1 predictions and ground-truth labels over the whole loader.
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1).cpu().numpy()

        y_pred.extend(preds)
        y_true.extend(labels.numpy())

cm = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(14, 12))
sns.heatmap(cm, cmap="Blues", xticklabels=False, yticklabels=False)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

print("\nClassification Report:")
print(classification_report(y_true, y_pred))
|
training/model.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision import models
|
| 4 |
+
|
| 5 |
+
def build_model(num_classes: int, device: torch.device):
    """Create a ResNet-50 classifier with a partially frozen backbone.

    The network starts from the ``IMAGENET1K_V2`` weights. Every parameter
    whose name does not start with ``layer3``, ``layer4`` or ``fc`` is
    frozen, and the final fully connected layer is replaced by a new
    ``num_classes``-way linear layer.

    Args:
        num_classes: Number of output classes.
        device: Device the model is moved to.

    Returns:
        The model on ``device``.
    """
    network = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

    # str.startswith accepts a tuple, so one call covers all trainable parts.
    trainable_prefixes = ("layer3", "layer4", "fc")
    for param_name, parameter in network.named_parameters():
        if param_name.startswith(trainable_prefixes):
            continue
        parameter.requires_grad = False

    # New classification head sized for this task.
    network.fc = nn.Linear(network.fc.in_features, num_classes)

    return network.to(device)
|
training/test_dataset.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Testing the XRayDataset class for correctness and functionality
|
| 2 |
+
from dataset import XRayDataset
|
| 3 |
+
|
| 4 |
+
def test_dataset_sample():
    """Smoke-test ``XRayDataset`` on the processed data.

    The check lives in a function so that importing this module no longer
    touches the disk, and a test runner gets real assertions instead of
    print statements.
    """
    ds = XRayDataset(
        csv_path="data_processed/metadata_final.csv",
        images_dir="data_processed/images",
        train=True,
    )

    print("Total samples:", len(ds))
    assert len(ds) > 0

    img, label = ds[0]
    print("Image shape:", img.shape)
    print("Label:", label)
    # The training transform converts to RGB and resizes/crops to 224x224.
    assert tuple(img.shape) == (3, 224, 224)


if __name__ == "__main__":
    test_dataset_sample()
|
training/test_loader.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Testing the dataloader functionality
|
| 2 |
+
from dataloader import get_dataloaders
|
| 3 |
+
|
| 4 |
+
def test_loader_batch():
    """Smoke-test ``get_dataloaders`` by pulling one training batch.

    The check lives in a function so that importing this module no longer
    loads data, and a test runner gets real assertions instead of print
    statements.
    """
    train_loader, val_loader = get_dataloaders(
        csv_path="data_processed/metadata_final.csv",
        images_dir="data_processed/images",
        batch_size=32,
    )

    images, labels = next(iter(train_loader))
    print("Batch image shape:", images.shape)
    print("Batch labels shape:", labels.shape)

    # One label per image, and never more images than the batch size.
    assert images.shape[0] == labels.shape[0]
    assert 0 < images.shape[0] <= 32
    assert tuple(images.shape[1:]) == (3, 224, 224)


if __name__ == "__main__":
    test_loader_batch()
|
training/test_model.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torchvision import models
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def build_model(num_classes, device):
    """Create a ResNet-18 classifier with only its deeper stages trainable.

    The network starts from the ``IMAGENET1K_V1`` weights. All parameters
    are frozen, then ``layer3`` and ``layer4`` are unfrozen, and the final
    fully connected layer is replaced by a new ``num_classes``-way linear
    layer.

    Args:
        num_classes: Number of output classes.
        device: Device the model is moved to.

    Returns:
        The model on ``device``.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # Freeze the whole network, then re-enable gradients stage by stage.
    net.requires_grad_(False)
    for stage in (net.layer3, net.layer4):
        stage.requires_grad_(True)

    # A freshly created head is trainable by default.
    net.fc = nn.Linear(net.fc.in_features, num_classes)

    return net.to(device)
|
training/train.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from dataloader import get_dataloaders
|
| 10 |
+
from model import build_model
|
| 11 |
+
from utils import get_device, accuracy
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def compute_class_weights(csv_path):
    """Derive per-class loss weights from the label frequencies in a CSV.

    The weight of a class is ``log1p(total / count)``, divided by the mean
    over all classes and finally capped at 3.0. Classes are ordered by
    ascending ``label_id``.

    Args:
        csv_path: CSV file with a ``label_id`` column.

    Returns:
        1-D float32 tensor holding one weight per distinct ``label_id``.
    """
    frame = pd.read_csv(csv_path)

    per_class = frame["label_id"].value_counts().sort_index()
    n_samples = per_class.sum()
    counts = torch.tensor(per_class.values, dtype=torch.float32)

    # Inverse frequency, compressed with log1p so rare classes do not dominate.
    scaled = torch.log1p(n_samples / counts)

    # Rescale to mean 1, then cap whatever is still extreme.
    normalized = scaled / scaled.mean()
    return torch.clamp(normalized, max=3.0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Train and validation functions for one epoch each
|
| 39 |
+
|
| 40 |
+
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Loss and accuracy are averaged per sample, not per batch. Dividing the
    sum of batch means by the number of batches over-weights a final batch
    that is smaller than the others.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` over all samples seen.
    """
    model.train()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Weight each batch by its size so the epoch figures are true
        # per-sample averages.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        acc_sum += accuracy(outputs, labels) * batch_size
        seen += batch_size

    return loss_sum / seen, acc_sum / seen
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Loss and accuracy are averaged per sample, not per batch, so a short
    final batch does not skew the figures that drive the learning-rate
    scheduler and early stopping.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` over all samples seen.
    """
    model.eval()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight each batch by its size (see the docstring).
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            acc_sum += accuracy(outputs, labels) * batch_size
            seen += batch_size

    return loss_sum / seen, acc_sum / seen
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main():
    """Train the classifier with LR scheduling and early stopping.

    Reads the processed metadata CSV, builds the loaders, model, weighted
    cross-entropy loss, AdamW optimizer and a reduce-on-plateau scheduler,
    then trains for up to 20 epochs. The state dict with the lowest
    validation loss is saved to ``checkpoints/best_model.pth``. Training
    stops early after 4 consecutive epochs without improvement.
    """
    # Hyperparameters
    batch_size = 32
    epochs = 20
    learning_rate = 1e-4
    patience = 4

    # Paths
    csv_path = "data_processed/metadata_final.csv"
    img_dir = "data_processed/images"
    checkpoint_dir = "checkpoints"
    checkpoint_path = f"{checkpoint_dir}/best_model.pth"
    os.makedirs(checkpoint_dir, exist_ok=True)

    device = get_device()
    print("Using device:", device)

    num_classes = pd.read_csv(csv_path)["label_id"].nunique()

    train_loader, val_loader = get_dataloaders(
        csv_path=csv_path,
        images_dir=img_dir,
        batch_size=batch_size,
    )

    model = build_model(num_classes, device)

    criterion = nn.CrossEntropyLoss(
        weight=compute_class_weights(csv_path).to(device),
        label_smoothing=0.02,
    )
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=1e-4,
    )
    # Halve the learning rate once validation loss stalls for 2 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=2, factor=0.5
    )

    best_val_loss = float("inf")
    stale_epochs = 0

    for epoch in range(epochs):
        print(f"\nEpoch [{epoch + 1}/{epochs}]")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = validate_one_epoch(
            model, val_loader, criterion, device
        )
        scheduler.step(val_loss)

        print(
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
        )

        if not val_loss < best_val_loss:
            # No improvement: count it and stop once patience is used up.
            stale_epochs += 1
            if stale_epochs >= patience:
                print("Early stopping triggered")
                break
            continue

        # New best validation loss: reset the counter and checkpoint.
        best_val_loss = val_loss
        stale_epochs = 0
        torch.save(model.state_dict(), checkpoint_path)
        print("Best model saved")

    print("\nTraining is complete.")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
main()
|
training/utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Utility functions for training machine learning models using PyTorch and calculating accuracy.
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        outputs: Score tensor; the predicted class is the arg-max along
            dim 1.
        labels: Tensor of target class indices, one per row of ``outputs``.

    Returns:
        Top-1 accuracy as a Python float in ``[0, 1]``. An empty batch gives
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total = labels.size(0)
    if total == 0:
        return 0.0

    preds = outputs.argmax(dim=1)
    correct = (preds == labels).sum().item()
    return correct / total
|