Flamekizer11 commited on
Commit
64d0ccc
·
verified ·
1 Parent(s): 8260a91

Upload 27 files

Browse files

commit to add initial codes including testing codes

backend/__init__.py ADDED
File without changes
backend/__pycache__/main.cpython-310.pyc ADDED
Binary file (787 Bytes). View file
 
backend/__pycache__/model_loader.cpython-310.pyc ADDED
Binary file (1.4 kB). View file
 
backend/__pycache__/predict.cpython-310.pyc ADDED
Binary file (695 Bytes). View file
 
backend/__pycache__/schemas.cpython-310.pyc ADDED
Binary file (449 Bytes). View file
 
backend/main.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from PIL import Image
3
+ import io
4
+
5
+ from backend.predict import predict_image
6
+ from backend.schemas import PredictionResponse
7
+
8
+ app = FastAPI(
9
+ title="X-Ray Classification API",
10
+ description="Deep learning based multi-class X-ray classifier",
11
+ version="1.0"
12
+ )
13
+ # Endpoint for predicting the class of an uploaded X-ray image.
14
+ # Accepts an image file and returns the predicted class label and confidence score.
15
+
16
+ @app.post("/predict", response_model=PredictionResponse)
17
+ async def predict(file: UploadFile = File(...)):
18
+ image_bytes = await file.read()
19
+ image = Image.open(io.BytesIO(image_bytes))
20
+
21
+ result = predict_image(image)
22
+ return result
backend/model_loader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ from training.model import build_model
4
+ from training.utils import get_device
5
+
6
+ CHECKPOINT_PATH = "checkpoints/best_model.pth"
7
+ LABEL_MAP_PATH = "data_processed/label_map.csv"
8
+
9
+ # Wrapper class for loading the model and making predictions on new data instances
10
+ class ModelWrapper:
11
+ def __init__(self):
12
+ self.device = get_device()
13
+
14
+ label_df = pd.read_csv(LABEL_MAP_PATH)
15
+ self.id_to_label = dict(
16
+ zip(label_df["label_id"], label_df["label"])
17
+ )
18
+
19
+ num_classes = len(self.id_to_label)
20
+
21
+ self.model = build_model(num_classes, self.device)
22
+ self.model.load_state_dict(
23
+ torch.load(CHECKPOINT_PATH, map_location=self.device)
24
+ )
25
+ self.model.eval()
26
+
27
+ def predict(self, image_tensor):
28
+ with torch.no_grad():
29
+ image_tensor = image_tensor.to(self.device)
30
+ outputs = self.model(image_tensor)
31
+ probs = torch.softmax(outputs, dim=1)
32
+
33
+ confidence, pred_id = torch.max(probs, dim=1)
34
+
35
+ return {
36
+ "label_id": int(pred_id.item()),
37
+ "label_name": self.id_to_label[int(pred_id.item())],
38
+ "confidence": float(confidence.item())
39
+ }
40
+
41
+
42
+ model_wrapper = ModelWrapper()
backend/predict.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ from torchvision import transforms
4
+ from backend.model_loader import model_wrapper
5
+
6
+ # Defining the image transformation pipeline for preprocessing
7
+ transform = transforms.Compose([
8
+ transforms.Resize((224, 224)),
9
+ transforms.ToTensor(),
10
+ transforms.Normalize(
11
+ mean=[0.485, 0.456, 0.406],
12
+ std=[0.229, 0.224, 0.225]
13
+ )
14
+ ])
15
+ # Function to predict the class of an input image using the loaded best model
16
+ def predict_image(image: Image.Image):
17
+ image = image.convert("RGB")
18
+ tensor = transform(image).unsqueeze(0)
19
+ return model_wrapper.predict(tensor)
backend/schemas.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #Module that defines the schema for prediction responses on the API.
2
+ from pydantic import BaseModel
3
+
4
+
5
+ class PredictionResponse(BaseModel):
6
+ label_id: int
7
+ label_name: str
8
+ confidence: float
checkpoints/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d83d0507e2b50d687569e7ccab21287169f4a13ba3cd7441437a2001cefa82b6
3
+ size 94842430
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ fastapi
3
+ torch==2.1.0
4
+ torchvision==0.16.0
5
+ pillow
6
+ numpy
7
+ pandas
8
+ scikit-learn
9
+ python-multipart
scripts/01_merge_datasets.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import csv
4
+ from pathlib import Path
5
+
6
+ RAW_ROOT = Path("raw_data")
7
+ OUT_IMG = Path("data_merged/images")
8
+ OUT_CSV = Path("data_merged/metadata_raw.csv")
9
+
10
+ OUT_IMG.mkdir(parents=True, exist_ok=True)
11
+
12
+ rows = []
13
+ img_id = 0
14
+ VALID_EXT = (".png", ".jpg", ".jpeg")
15
+
16
+ def merge_any_dataset(dataset_name, base_path):
17
+ global img_id
18
+ for root, _, files in os.walk(base_path):
19
+ for f in files:
20
+ if not f.lower().endswith(VALID_EXT):
21
+ continue
22
+
23
+ src = Path(root) / f
24
+ class_name = Path(root).name
25
+
26
+ new_name = f"{dataset_name}__{class_name}__{img_id}{src.suffix}"
27
+ dst = OUT_IMG / new_name
28
+
29
+ shutil.copy(src, dst)
30
+
31
+ rows.append({
32
+ "image_id": img_id,
33
+ "filename": new_name,
34
+ "label": class_name,
35
+ "source": dataset_name
36
+ })
37
+
38
+ img_id += 1
39
+
40
+ # Merging all datasets found in RAW_ROOT
41
+ for item in RAW_ROOT.iterdir():
42
+ if item.is_dir():
43
+ merge_any_dataset(item.name, item)
44
+
45
+ # writing out the CSV
46
+ OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
47
+ with open(OUT_CSV, "w", newline="", encoding="utf-8") as f:
48
+ writer = csv.DictWriter(
49
+ f,
50
+ fieldnames=["image_id", "filename", "label", "source"]
51
+ )
52
+ writer.writeheader()
53
+ writer.writerows(rows)
54
+
55
+ print("Merged dataset created")
56
+ print("Images:", len(rows))
scripts/02_resize_images.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import pandas as pd
3
+ from pathlib import Path
4
+
5
+ IN_IMG = Path("data_merged/images")
6
+ OUT_IMG = Path("data_processed/images")
7
+ IN_CSV = Path("data_merged/metadata_raw.csv")
8
+ OUT_CSV = Path("data_processed/metadata_resized.csv")
9
+
10
+ OUT_IMG.mkdir(parents=True, exist_ok=True)
11
+
12
+ df = pd.read_csv(IN_CSV)
13
+ kept_rows = []
14
+
15
+ for _, row in df.iterrows():
16
+ src = IN_IMG / row["filename"]
17
+ dst = OUT_IMG / row["filename"]
18
+
19
+ img = cv2.imread(str(src))
20
+ if img is None:
21
+ continue
22
+
23
+ img = cv2.resize(img, (224, 224))
24
+ cv2.imwrite(str(dst), img)
25
+
26
+ kept_rows.append(row)
27
+
28
+ pd.DataFrame(kept_rows).to_csv(OUT_CSV, index=False)
29
+
30
+ print("Images kept:", len(kept_rows))
scripts/03_create_metadata.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from pathlib import Path
3
+
4
+ IN_CSV = Path("data_processed/metadata_resized.csv")
5
+ OUT_CSV = Path("data_processed/metadata_final.csv")
6
+ MAP_CSV = Path("data_processed/label_map.csv")
7
+
8
+ df = pd.read_csv(IN_CSV)
9
+
10
+ labels = sorted(df["label"].unique())
11
+ label_to_id = {label: i for i, label in enumerate(labels)}
12
+
13
+ df["label_id"] = df["label"].map(label_to_id)
14
+
15
+ df.to_csv(OUT_CSV, index=False)
16
+
17
+ pd.DataFrame({
18
+ "label": labels,
19
+ "label_id": [label_to_id[l] for l in labels]
20
+ }).to_csv(MAP_CSV, index=False)
21
+
22
+ print("Total classes:", len(labels))
training/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (809 Bytes). View file
 
training/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (1.61 kB). View file
 
training/__pycache__/model.cpython-310.pyc ADDED
Binary file (690 Bytes). View file
 
training/__pycache__/utils.cpython-310.pyc ADDED
Binary file (549 Bytes). View file
 
training/dataloader.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Module for creating data loaders for training and validation datasets.
2
+ from torch.utils.data import DataLoader, random_split
3
+ from dataset import XRayDataset
4
+
5
+ def get_dataloaders(
6
+ csv_path,
7
+ images_dir,
8
+ batch_size=32,
9
+ val_split=0.2
10
+ ):
11
+ full_dataset = XRayDataset(
12
+ csv_path=csv_path,
13
+ images_dir=images_dir,
14
+ train=True
15
+ )
16
+
17
+ val_size = int(len(full_dataset) * val_split)
18
+ train_size = len(full_dataset) - val_size
19
+
20
+ train_ds, val_ds = random_split(
21
+ full_dataset,
22
+ [train_size, val_size]
23
+ )
24
+
25
+ # Disable augmentation for validation dataset so that we only apply normalization
26
+ val_ds.dataset.transform = XRayDataset(
27
+ csv_path,
28
+ images_dir,
29
+ train=False
30
+ ).transform
31
+
32
+ train_loader = DataLoader(
33
+ train_ds,
34
+ batch_size=batch_size,
35
+ shuffle=True,
36
+ num_workers=0
37
+ )
38
+
39
+ val_loader = DataLoader(
40
+ val_ds,
41
+ batch_size=batch_size,
42
+ shuffle=False,
43
+ num_workers=0
44
+ )
45
+
46
+
47
+ return train_loader, val_loader
training/dataset.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import pandas as pd
5
+ from pathlib import Path
6
+
7
+ # Custom Dataset for X-Ray Images with Augmentations for Training and Standard Transformations for Validation
8
+
9
+ class XRayDataset(Dataset):
10
+ def __init__(self, csv_path, images_dir, train=True):
11
+ self.df = pd.read_csv(csv_path)
12
+ self.images_dir = Path(images_dir)
13
+
14
+ if train:
15
+ self.transform = transforms.Compose([
16
+ transforms.Resize((224, 224)),
17
+ transforms.RandomHorizontalFlip(p=0.5),
18
+ transforms.RandomRotation(15),
19
+ transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
20
+ transforms.ColorJitter(brightness=0.1, contrast=0.1),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize(
23
+ mean=[0.485, 0.456, 0.406],
24
+ std=[0.229, 0.224, 0.225]
25
+ )
26
+ ])
27
+ else:
28
+ self.transform = transforms.Compose([
29
+ transforms.Resize((224, 224)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(
32
+ mean=[0.485, 0.456, 0.406],
33
+ std=[0.229, 0.224, 0.225]
34
+ )
35
+ ])
36
+
37
+ def __len__(self):
38
+ return len(self.df)
39
+
40
+ def __getitem__(self, idx):
41
+ row = self.df.iloc[idx]
42
+ img_path = self.images_dir / row["filename"]
43
+ label = row["label_id"]
44
+
45
+ image = Image.open(img_path).convert("RGB")
46
+ image = self.transform(image)
47
+
48
+ return image, label
training/evaluate.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import seaborn as sns
4
+ import matplotlib.pyplot as plt
5
+ from sklearn.metrics import confusion_matrix, classification_report
6
+
7
+ from model import build_model
8
+ from dataloader import get_dataloaders
9
+ from utils import get_device
10
+
11
+ CSV_PATH = "data_processed/metadata_final.csv"
12
+ IMG_DIR = "data_processed/images"
13
+ CHECKPOINT_PATH = "checkpoints/best_model.pth"
14
+
15
+ device = get_device()
16
+
17
+ df = pd.read_csv(CSV_PATH)
18
+ num_classes = df["label_id"].nunique()
19
+
20
+ model = build_model(num_classes, device)
21
+ model.load_state_dict(torch.load(CHECKPOINT_PATH))
22
+ model.eval()
23
+
24
+ _, val_loader = get_dataloaders(
25
+ csv_path=CSV_PATH,
26
+ images_dir=IMG_DIR,
27
+ batch_size=32
28
+ )
29
+
30
+ y_true, y_pred = [], []
31
+
32
+ with torch.no_grad():
33
+ for images, labels in val_loader:
34
+ images = images.to(device)
35
+ outputs = model(images)
36
+ preds = outputs.argmax(dim=1).cpu().numpy()
37
+
38
+ y_pred.extend(preds)
39
+ y_true.extend(labels.numpy())
40
+
41
+ cm = confusion_matrix(y_true, y_pred)
42
+
43
+ plt.figure(figsize=(14, 12))
44
+ sns.heatmap(cm, cmap="Blues", xticklabels=False, yticklabels=False)
45
+ plt.title("Confusion Matrix")
46
+ plt.xlabel("Predicted")
47
+ plt.ylabel("True")
48
+ plt.show()
49
+
50
+ print("\nClassification Report:")
51
+ print(classification_report(y_true, y_pred))
training/model.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+
5
+ def build_model(num_classes: int, device: torch.device):
6
+ model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
7
+
8
+ # Freeze early layers
9
+ for name, param in model.named_parameters():
10
+ if not (
11
+ name.startswith("layer3") or
12
+ name.startswith("layer4") or
13
+ name.startswith("fc")
14
+ ):
15
+ param.requires_grad = False
16
+
17
+
18
+
19
+ # Replace classifier
20
+ in_features = model.fc.in_features
21
+ model.fc = nn.Linear(in_features, num_classes)
22
+
23
+ return model.to(device)
training/test_dataset.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Testing the XRayDataset class for correctness and functionality
2
+ from dataset import XRayDataset
3
+
4
+ ds = XRayDataset(
5
+ csv_path="data_processed/metadata_final.csv",
6
+ images_dir="data_processed/images",
7
+ train=True
8
+ )
9
+
10
+ print("Total samples:", len(ds))
11
+ img, label = ds[0]
12
+ print("Image shape:", img.shape)
13
+ print("Label:", label)
training/test_loader.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Testing the dataloader functionality
2
+ from dataloader import get_dataloaders
3
+
4
+ train_loader, val_loader = get_dataloaders(
5
+ csv_path="data_processed/metadata_final.csv",
6
+ images_dir="data_processed/images",
7
+ batch_size=32
8
+ )
9
+
10
+ images, labels = next(iter(train_loader))
11
+ print("Batch image shape:", images.shape)
12
+ print("Batch labels shape:", labels.shape)
training/test_model.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torchvision import models
3
+
4
+
5
+ def build_model(num_classes, device):
6
+ model = models.resnet18(
7
+ weights=models.ResNet18_Weights.IMAGENET1K_V1
8
+ )
9
+
10
+ # Freezing everything
11
+ for param in model.parameters():
12
+ param.requires_grad = False
13
+
14
+ # Unfreezing deeper layers
15
+ for param in model.layer3.parameters():
16
+ param.requires_grad = True
17
+
18
+ for param in model.layer4.parameters():
19
+ param.requires_grad = True
20
+
21
+ # Replacing classifier for our number of classes
22
+ in_features = model.fc.in_features
23
+ model.fc = nn.Linear(in_features, num_classes)
24
+
25
+ model = model.to(device)
26
+ return model
training/train.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import pandas as pd
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from dataloader import get_dataloaders
10
+ from model import build_model
11
+ from utils import get_device, accuracy
12
+
13
+
14
+ def compute_class_weights(csv_path):
15
+ df = pd.read_csv(csv_path)
16
+
17
+ class_counts = df["label_id"].value_counts().sort_index()
18
+ total_samples = class_counts.sum()
19
+
20
+ class_counts = torch.tensor(class_counts.values, dtype=torch.float32)
21
+
22
+ # Soft inverse-frequency weighting
23
+ weights = total_samples / class_counts
24
+
25
+ # Log-scale to reduce extremes
26
+ weights = torch.log1p(weights)
27
+
28
+ # Normalize
29
+ weights = weights / weights.mean()
30
+
31
+ # 🔒 Cap extreme weights (critical)
32
+ weights = torch.clamp(weights, max=3.0)
33
+
34
+ return weights
35
+
36
+
37
+
38
+ # Train and validation functions for one epoch each
39
+
40
+ def train_one_epoch(model, loader, criterion, optimizer, device):
41
+ model.train()
42
+ total_loss, total_acc = 0.0, 0.0
43
+
44
+ for images, labels in tqdm(loader, desc="Training", leave=False):
45
+ images, labels = images.to(device), labels.to(device)
46
+
47
+ optimizer.zero_grad()
48
+ outputs = model(images)
49
+ loss = criterion(outputs, labels)
50
+
51
+ loss.backward()
52
+ optimizer.step()
53
+
54
+ total_loss += loss.item()
55
+ total_acc += accuracy(outputs, labels)
56
+
57
+ return total_loss / len(loader), total_acc / len(loader)
58
+
59
+
60
+ def validate_one_epoch(model, loader, criterion, device):
61
+ model.eval()
62
+ total_loss, total_acc = 0.0, 0.0
63
+
64
+ with torch.no_grad():
65
+ for images, labels in tqdm(loader, desc="Validation", leave=False):
66
+ images, labels = images.to(device), labels.to(device)
67
+ outputs = model(images)
68
+ loss = criterion(outputs, labels)
69
+
70
+ total_loss += loss.item()
71
+ total_acc += accuracy(outputs, labels)
72
+
73
+ return total_loss / len(loader), total_acc / len(loader)
74
+
75
+
76
+ def main():
77
+ #Hyperparameters and paths
78
+ BATCH_SIZE = 32
79
+ EPOCHS = 20
80
+ LR = 1e-4
81
+ PATIENCE = 4
82
+
83
+ CSV_PATH = "data_processed/metadata_final.csv"
84
+ IMG_DIR = "data_processed/images"
85
+ CHECKPOINT_DIR = "checkpoints"
86
+ CHECKPOINT_PATH = f"{CHECKPOINT_DIR}/best_model.pth"
87
+
88
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
89
+
90
+ #Setup
91
+ device = get_device()
92
+ print("Using device:", device)
93
+
94
+ df = pd.read_csv(CSV_PATH)
95
+ num_classes = df["label_id"].nunique()
96
+
97
+ train_loader, val_loader = get_dataloaders(
98
+ csv_path=CSV_PATH,
99
+ images_dir=IMG_DIR,
100
+ batch_size=BATCH_SIZE
101
+ )
102
+
103
+ model = build_model(num_classes, device)
104
+
105
+ class_weights = compute_class_weights(CSV_PATH).to(device)
106
+ criterion = nn.CrossEntropyLoss(
107
+ weight=class_weights,
108
+ label_smoothing=0.02
109
+ )
110
+
111
+
112
+ optimizer = torch.optim.AdamW(
113
+ model.parameters(),
114
+ lr=LR,
115
+ weight_decay=1e-4
116
+ )
117
+
118
+ # Learning rate scheduler so that lr reduces if val loss plateaus
119
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
120
+ optimizer, mode="min", patience=2, factor=0.5
121
+ )
122
+
123
+ best_val_loss = float("inf")
124
+ epochs_without_improvement = 0
125
+
126
+ # Training loop with early stopping to prevent overfitting
127
+ for epoch in range(EPOCHS):
128
+ print(f"\nEpoch [{epoch + 1}/{EPOCHS}]")
129
+
130
+ train_loss, train_acc = train_one_epoch(
131
+ model, train_loader, criterion, optimizer, device
132
+ )
133
+
134
+ val_loss, val_acc = validate_one_epoch(
135
+ model, val_loader, criterion, device
136
+ )
137
+
138
+ scheduler.step(val_loss)
139
+
140
+ print(
141
+ f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
142
+ f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
143
+ )
144
+
145
+ if val_loss < best_val_loss:
146
+ best_val_loss = val_loss
147
+ epochs_without_improvement = 0
148
+ torch.save(model.state_dict(), CHECKPOINT_PATH)
149
+ print("Best model saved")
150
+ else:
151
+ epochs_without_improvement += 1
152
+ if epochs_without_improvement >= PATIENCE:
153
+ print("Early stopping triggered")
154
+ break
155
+
156
+ print("\nTraining is complete.")
157
+
158
+
159
+ if __name__ == "__main__":
160
+ main()
training/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utility functions for training machine learning models using PyTorch and calculating accuracy.
2
+ import torch
3
+
4
+ def get_device():
5
+ if torch.cuda.is_available():
6
+ return torch.device("cuda")
7
+ return torch.device("cpu")
8
+
9
+
10
+ def accuracy(outputs, labels):
11
+ preds = outputs.argmax(dim=1)
12
+ correct = (preds == labels).sum().item()
13
+ return correct / labels.size(0)