Spaces:
No application file
No application file
Uploaded
Browse files- HAM_10000_CLASSIFICATION-master/.devcontainer/Dockerfile +17 -0
- HAM_10000_CLASSIFICATION-master/.devcontainer/devcontainer.json +47 -0
- HAM_10000_CLASSIFICATION-master/.dockerignore +7 -0
- HAM_10000_CLASSIFICATION-master/.gitignore +8 -0
- HAM_10000_CLASSIFICATION-master/Dockerfile +17 -0
- HAM_10000_CLASSIFICATION-master/README.md +128 -0
- HAM_10000_CLASSIFICATION-master/app.py +169 -0
- HAM_10000_CLASSIFICATION-master/models/convnext.py +25 -0
- HAM_10000_CLASSIFICATION-master/models/efficient_net.py +31 -0
- HAM_10000_CLASSIFICATION-master/models/mobile_net.py +39 -0
- HAM_10000_CLASSIFICATION-master/models/shufflenet.py +25 -0
- HAM_10000_CLASSIFICATION-master/models/vit.py +31 -0
- HAM_10000_CLASSIFICATION-master/requirements.txt +83 -0
- HAM_10000_CLASSIFICATION-master/utils/trainer.py +216 -0
HAM_10000_CLASSIFICATION-master/.devcontainer/Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-bookworm
|
| 2 |
+
|
| 3 |
+
# Install system dependencies (libGL, FFmpeg, X11 tools)
|
| 4 |
+
RUN apt-get update && \
|
| 5 |
+
apt-get install -y libgl1 ffmpeg x11-xserver-utils && \
|
| 6 |
+
apt-get clean && rm -rf /var/lib/apt/lists/*
|
| 7 |
+
|
| 8 |
+
# Install Python dependencies (optimized with --no-cache-dir)
|
| 9 |
+
RUN pip install --no-cache-dir \
|
| 10 |
+
numpy \
|
| 11 |
+
pandas \
|
| 12 |
+
matplotlib \
|
| 13 |
+
opencv-python-headless \
|
| 14 |
+
torch \
|
| 15 |
+
torchvision \
|
| 16 |
+
streamlit \
|
| 17 |
+
tqdm
|
HAM_10000_CLASSIFICATION-master/.devcontainer/devcontainer.json
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
| 2 |
+
// README at: https://github.com/devcontainers/templates/tree/main/src/python
|
| 3 |
+
{
|
| 4 |
+
"name": "Python 3",
|
| 5 |
+
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
| 6 |
+
|
| 7 |
+
"build" : {
|
| 8 |
+
"dockerfile": "Dockerfile"
|
| 9 |
+
},
|
| 10 |
+
|
| 11 |
+
// Features to add to the dev container. More info: https://containers.dev/features.
|
| 12 |
+
// "features": {},
|
| 13 |
+
"privileged" : true,
|
| 14 |
+
|
| 15 |
+
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
| 16 |
+
"forwardPorts": [
|
| 17 |
+
8501
|
| 18 |
+
],
|
| 19 |
+
|
| 20 |
+
"runArgs": [
|
| 21 |
+
"--env=DISPLAY",
|
| 22 |
+
"--env=XAUTHORITY",
|
| 23 |
+
"--env=WAYLAND_DISPLAY", // Added Wayland environment variable
|
| 24 |
+
"--env=QT_QPA_PLATFORM", // Ensures Qt apps use the right display server
|
| 25 |
+
"--volume=/tmp/.X11-unix:/tmp/.X11-unix",
|
| 26 |
+
"--volume=/run/user/1000/.mutter-Xwaylandauth.HSC612",
|
| 27 |
+
"--device=/dev/dri:/dev/dri"
|
| 28 |
+
],
|
| 29 |
+
|
| 30 |
+
"customizations": {
|
| 31 |
+
"vscode": {
|
| 32 |
+
"extensions": [
|
| 33 |
+
"ms-python.python",
|
| 34 |
+
"ms-python.debugpy"
|
| 35 |
+
]
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Use 'postCreateCommand' to run commands after the container is created.
|
| 40 |
+
// "postCreateCommand": "pip3 install --user -r requirements.txt",
|
| 41 |
+
|
| 42 |
+
// Configure tool-specific properties.
|
| 43 |
+
// "customizations": {},
|
| 44 |
+
|
| 45 |
+
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
| 46 |
+
// "remoteUser": "root"
|
| 47 |
+
}
|
HAM_10000_CLASSIFICATION-master/.dockerignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.devcontainer
|
| 2 |
+
dataset
|
| 3 |
+
__pycache__
|
| 4 |
+
utils
|
| 5 |
+
venv
|
| 6 |
+
test.py
|
| 7 |
+
README.md
|
HAM_10000_CLASSIFICATION-master/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset/
|
| 2 |
+
model_weights/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.csv
|
| 5 |
+
*.zip
|
| 6 |
+
test.py
|
| 7 |
+
venv/
|
| 8 |
+
*.jpg
|
HAM_10000_CLASSIFICATION-master/Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-bookworm
|
| 2 |
+
|
| 3 |
+
# Set the working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Copy the current directory contents into the container at /app
|
| 7 |
+
COPY . /app
|
| 8 |
+
|
| 9 |
+
# update ubuntu and pip
|
| 10 |
+
RUN python -m pip install --upgrade pip
|
| 11 |
+
|
| 12 |
+
RUN pip --default-timeout=100 --no-cache-dir install -r requirements.txt
|
| 13 |
+
|
| 14 |
+
EXPOSE 8501
|
| 15 |
+
|
| 16 |
+
CMD ["python", "app.py"]
|
| 17 |
+
|
HAM_10000_CLASSIFICATION-master/README.md
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ham10000-Backend
|
| 2 |
+
|
| 3 |
+
This repo contains the Backend.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
## Installation
|
| 8 |
+
|
| 9 |
+
- create a virtual environment.
|
| 10 |
+
```bash
|
| 11 |
+
$ python -m venv venv
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
- Install the required packages
|
| 15 |
+
```bash
|
| 16 |
+
$ pip install requirements.txt
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
- Run the backend.
|
| 20 |
+
```bash
|
| 21 |
+
$ python app.py
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
- if you see output like this. it means backend is running
|
| 25 |
+
```bash
|
| 26 |
+
INFO: Started server process [1624]
|
| 27 |
+
INFO: Waiting for application startup.
|
| 28 |
+
INFO: Application startup complete.
|
| 29 |
+
INFO: Uvicorn running on http://0.0.0.0:8501 (Press CTRL+C to quit)
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
## API EndPoints
|
| 33 |
+
|
| 34 |
+
- ### **GET** - http://localhost:8501/api/model_performance
|
| 35 |
+
|
| 36 |
+
This endpoint is for details of each model.
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
{
|
| 40 |
+
"models": [
|
| 41 |
+
{
|
| 42 |
+
"name": "Efficient Net V7",
|
| 43 |
+
"description": "A high-efficiency architecture that leverages compound scaling for superior performance across various tasks.",
|
| 44 |
+
"performance_tags": [
|
| 45 |
+
{ "icon": "fa-tachometer-alt", "label": "High Efficiency" },
|
| 46 |
+
{ "icon": "fa-star", "label": "State-of-the-Art" }
|
| 47 |
+
]
|
| 48 |
+
}
|
| 49 |
+
]
|
| 50 |
+
}
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
- ### POST - http://localhost:8501/api/classify
|
| 54 |
+
This endpoint is for making posting images to models and making predictions.
|
| 55 |
+
|
| 56 |
+
**Request Body**
|
| 57 |
+
|
| 58 |
+
- Json:
|
| 59 |
+
```bash
|
| 60 |
+
{
|
| 61 |
+
'image': 'binary file (JPEG, PNG, etc.)',
|
| 62 |
+
'model': 'Efficient Net V7'
|
| 63 |
+
}
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
- CURL:
|
| 67 |
+
```bash
|
| 68 |
+
curl -X POST "http://your-api-domain.com/api/classify" \
|
| 69 |
+
-H "Content-Type: multipart/form-data" \
|
| 70 |
+
-F "image=@path/to/image.jpg" \
|
| 71 |
+
-F "model=Efficient Net V7"
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
**Response**
|
| 75 |
+
|
| 76 |
+
- If no cancer detected
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
{
|
| 80 |
+
"result" : "Cancer Not Detected"
|
| 81 |
+
}
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
- If cancer detected. The class with highest confidence is the prediction. In this case **akiec** has the highest confidence of **0.9999591112136841**.
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
{
|
| 88 |
+
"result": [
|
| 89 |
+
{
|
| 90 |
+
"class": "akiec",
|
| 91 |
+
"confidence": 0.9999591112136841
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"class": "mel",
|
| 95 |
+
"confidence": 0.00003423781890887767
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"class": "bkl",
|
| 99 |
+
"confidence": 0.0000065627805270196404
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"class": "nv",
|
| 103 |
+
"confidence": 1.223657761784125e-7
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"class": "vasc",
|
| 107 |
+
"confidence": 7.100419363581523e-9
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"class": "df",
|
| 111 |
+
"confidence": 6.403973351609693e-9
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"class": "bcc",
|
| 115 |
+
"confidence": 1.5403813780068276e-9
|
| 116 |
+
}
|
| 117 |
+
]
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
- ### GET - http://localhost:8501/api/health
|
| 122 |
+
This endpoint is for aws healthcheck only and doesn't have any partical significance.
|
| 123 |
+
|
| 124 |
+
## Note:
|
| 125 |
+
|
| 126 |
+
**'Efficient Net V7'** is the best model so use that for classification.
|
| 127 |
+
|
| 128 |
+
|
HAM_10000_CLASSIFICATION-master/app.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from models.efficient_net import EfficientNetB7
|
| 4 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
| 5 |
+
from fastapi.responses import JSONResponse
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from torchvision import models, transforms
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch
|
| 11 |
+
import gdown
|
| 12 |
+
import os
|
| 13 |
+
from colorama import Fore, Style
|
| 14 |
+
|
| 15 |
+
####################### DOWNLOAD MODEL WEIGHTS #######################
|
| 16 |
+
if not os.path.exists("model_weights"):
|
| 17 |
+
drive_link = "https://drive.google.com/drive/folders/1JOd3O1c3me5JWE3Elq0MhhShP_XhTYRV"
|
| 18 |
+
output_dir = "model_weights"
|
| 19 |
+
|
| 20 |
+
# Download the folder recursively
|
| 21 |
+
gdown.download_folder(drive_link, output=output_dir, quiet=False, use_cookies=False)
|
| 22 |
+
|
| 23 |
+
# Check if model weights were downloaded
|
| 24 |
+
model_weights_path = os.path.join(output_dir, "model_weights/EfficientNetV7Large_v1/saved_models/best_test_model.pth") # Update with actual filename
|
| 25 |
+
|
| 26 |
+
if os.path.exists(model_weights_path):
|
| 27 |
+
print(Fore.GREEN + "✅ Model weights downloaded successfully!" + Style.RESET_ALL)
|
| 28 |
+
else:
|
| 29 |
+
print(Fore.RED + "❌ Model weights not found. Check the folder or link." + Style.RESET_ALL)
|
| 30 |
+
|
| 31 |
+
print("Download completed!")
|
| 32 |
+
else:
|
| 33 |
+
print(Fore.YELLOW + "⚠️ Model weights already exist. Skipping download." + Style.RESET_ALL)
|
| 34 |
+
print("Download skipped!")
|
| 35 |
+
######################################################################
|
| 36 |
+
|
| 37 |
+
ml_models = {}
|
| 38 |
+
cancer_nocancer_model_1 = None
|
| 39 |
+
cancer_nocancer_model_2 = None
|
| 40 |
+
|
| 41 |
+
################### for cancer no cancer model ################################
|
| 42 |
+
def load_model(model_path, num_classes):
|
| 43 |
+
"""
|
| 44 |
+
Load the EfficientNet-B0 model from the state dict
|
| 45 |
+
"""
|
| 46 |
+
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
|
| 47 |
+
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
|
| 48 |
+
|
| 49 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
| 50 |
+
model.load_state_dict(state_dict)
|
| 51 |
+
model.eval()
|
| 52 |
+
|
| 53 |
+
return model
|
| 54 |
+
|
| 55 |
+
efficient_net_b0_transform = transforms.Compose([
|
| 56 |
+
transforms.Resize((224, 224)),
|
| 57 |
+
transforms.ToTensor(),
|
| 58 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 59 |
+
std=[0.229, 0.224, 0.225])
|
| 60 |
+
])
|
| 61 |
+
cancer_no_caner_class_mapping = {0: 'cancer', 1: 'nocancer'}
|
| 62 |
+
|
| 63 |
+
################################################################################
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# make a lifespan
|
| 67 |
+
@asynccontextmanager
|
| 68 |
+
async def lifespan(app: FastAPI):
|
| 69 |
+
global ml_models
|
| 70 |
+
global cancer_nocancer_model_1
|
| 71 |
+
global cancer_nocancer_model_2
|
| 72 |
+
|
| 73 |
+
cancer_nocancer_1_weight_path = "model_weights/cancer_nocancer.pth"
|
| 74 |
+
cancer_nocancer_2_weight_path = "model_weights/cancer_nocancer_100.pth"
|
| 75 |
+
|
| 76 |
+
efficient_net_model = EfficientNetB7(weights_path="model_weights/EfficientNetV7Large_v1/saved_models/best_test_model.pth")
|
| 77 |
+
|
| 78 |
+
cancer_nocancer_model_1 = load_model(cancer_nocancer_1_weight_path, len(cancer_no_caner_class_mapping))
|
| 79 |
+
cancer_nocancer_model_2 = load_model(cancer_nocancer_2_weight_path, len(cancer_no_caner_class_mapping))
|
| 80 |
+
|
| 81 |
+
ml_models = {
|
| 82 |
+
"EFFICIENT NET V7": efficient_net_model,
|
| 83 |
+
}
|
| 84 |
+
yield
|
| 85 |
+
|
| 86 |
+
ml_models.clear()
|
| 87 |
+
|
| 88 |
+
app = FastAPI(lifespan=lifespan)
|
| 89 |
+
app.add_middleware(
|
| 90 |
+
CORSMiddleware,
|
| 91 |
+
allow_origins=["*"], # Allow all origins (or specify your frontend URL)
|
| 92 |
+
allow_methods=["*"],
|
| 93 |
+
allow_headers=["*"],
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
@app.post("/api/classify", response_class=JSONResponse)
|
| 97 |
+
async def classify(
|
| 98 |
+
image: UploadFile = File(...), # Accepts an image file
|
| 99 |
+
model: str = Form(...) # Accepts a model name as form data
|
| 100 |
+
):
|
| 101 |
+
global ml_models
|
| 102 |
+
global cancer_nocancer_model_1
|
| 103 |
+
global cancer_nocancer_model_2
|
| 104 |
+
|
| 105 |
+
model = model.upper()
|
| 106 |
+
if model not in ml_models.keys():
|
| 107 |
+
raise HTTPException(status_code=400, detail="Invalid model specified")
|
| 108 |
+
|
| 109 |
+
image_data = await image.read()
|
| 110 |
+
img = Image.open(io.BytesIO(image_data)).convert("RGB")
|
| 111 |
+
|
| 112 |
+
# check whether image if of cancer or not
|
| 113 |
+
img_for_check = efficient_net_b0_transform(img).unsqueeze(0)
|
| 114 |
+
check_cancer_1 = F.softmax(cancer_nocancer_model_1(img_for_check), dim=1)
|
| 115 |
+
check_cancer_2 = F.softmax(cancer_nocancer_model_2(img_for_check), dim=1)
|
| 116 |
+
|
| 117 |
+
check_cancer_1_index = check_cancer_1.argmax()
|
| 118 |
+
check_cancer_2_index = check_cancer_2.argmax()
|
| 119 |
+
|
| 120 |
+
print("-" * 25, "Cancer No Cancer", "-" * 25)
|
| 121 |
+
print(check_cancer_1_index)
|
| 122 |
+
print(check_cancer_2_index)
|
| 123 |
+
print("-" * 50)
|
| 124 |
+
|
| 125 |
+
if check_cancer_1_index == 1 and check_cancer_2_index == 1:
|
| 126 |
+
print("Cancer Not Detected")
|
| 127 |
+
return JSONResponse({"result": "Cancer Not Detected"})
|
| 128 |
+
else:
|
| 129 |
+
print("Cancer Detected")
|
| 130 |
+
prediction = predict(ml_models[model], model, img)
|
| 131 |
+
return JSONResponse(prediction)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def predict(model, model_name, image_data):
|
| 135 |
+
classes = ['akiec','bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
|
| 136 |
+
|
| 137 |
+
print("-" * 50)
|
| 138 |
+
print(f"Model Used: {model_name}")
|
| 139 |
+
pred = model.make_prediction(image_data)[0]
|
| 140 |
+
print(pred)
|
| 141 |
+
pred_json = {"result" : [{"class": cls, "confidence": float(pred[i])} for i, cls in enumerate(classes)]}
|
| 142 |
+
pred_json["result"] = sorted(pred_json["result"], key=lambda x: x["confidence"], reverse=True)
|
| 143 |
+
print(pred_json)
|
| 144 |
+
print("-" * 50)
|
| 145 |
+
return pred_json
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@app.get("/api/model_performance", response_class=JSONResponse)
|
| 149 |
+
async def home():
|
| 150 |
+
return {
|
| 151 |
+
"models": [
|
| 152 |
+
{
|
| 153 |
+
"name": "Efficient Net V7",
|
| 154 |
+
"description": "A high-efficiency architecture that leverages compound scaling for superior performance across various tasks.",
|
| 155 |
+
"performance_tags": [
|
| 156 |
+
{ "icon": "fa-tachometer-alt", "label": "High Efficiency" },
|
| 157 |
+
{ "icon": "fa-star", "label": "State-of-the-Art" }
|
| 158 |
+
]
|
| 159 |
+
},
|
| 160 |
+
]
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
@app.get("/api/health", response_class=JSONResponse)
|
| 164 |
+
def health():
|
| 165 |
+
return JSONResponse(status_code=200, content={"status": "Working fine!"})
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
import uvicorn
|
| 169 |
+
uvicorn.run(app, host="0.0.0.0", port=8501)
|
HAM_10000_CLASSIFICATION-master/models/convnext.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
class ConvNextBase:
|
| 7 |
+
def __init__(self, weights_path: str):
|
| 8 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 9 |
+
|
| 10 |
+
self.model = convnext_base()
|
| 11 |
+
self.model.classifier[2] = torch.nn.Linear(self.model.classifier[2].in_features, 7)
|
| 12 |
+
|
| 13 |
+
state_dict = torch.load(weights_path, map_location=self.device)
|
| 14 |
+
self.model.load_state_dict(state_dict)
|
| 15 |
+
self.model.eval()
|
| 16 |
+
|
| 17 |
+
self.transform = ConvNeXt_Base_Weights.IMAGENET1K_V1.transforms()
|
| 18 |
+
|
| 19 |
+
def make_prediction(self, image: Image):
|
| 20 |
+
image = self.transform(image).unsqueeze(0)
|
| 21 |
+
with torch.no_grad():
|
| 22 |
+
pred = F.softmax(self.model(image), dim=1)
|
| 23 |
+
return pred
|
| 24 |
+
|
| 25 |
+
|
HAM_10000_CLASSIFICATION-master/models/efficient_net.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.models import efficientnet_b7
|
| 3 |
+
from torchvision.transforms import InterpolationMode
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
class EfficientNetB7:
|
| 9 |
+
def __init__(self, weights_path: str):
|
| 10 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
|
| 12 |
+
self.transform = transforms.Compose([
|
| 13 |
+
transforms.Resize((600), interpolation=InterpolationMode.BICUBIC),
|
| 14 |
+
transforms.CenterCrop(600),
|
| 15 |
+
transforms.ToTensor(),
|
| 16 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 17 |
+
std=[0.229, 0.224, 0.225])
|
| 18 |
+
])
|
| 19 |
+
|
| 20 |
+
self.model = efficientnet_b7()
|
| 21 |
+
self.model.classifier[1] = torch.nn.Linear(2560, 7)
|
| 22 |
+
|
| 23 |
+
state_dict = torch.load(weights_path, map_location=torch.device(self.device))
|
| 24 |
+
self.model.load_state_dict(state_dict)
|
| 25 |
+
self.model.eval()
|
| 26 |
+
|
| 27 |
+
def make_prediction(self, image: Image):
|
| 28 |
+
image = self.transform(image).unsqueeze(0)
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
pred = F.softmax(self.model(image), dim=1)
|
| 31 |
+
return pred
|
HAM_10000_CLASSIFICATION-master/models/mobile_net.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
|
| 3 |
+
from torchvision.transforms import InterpolationMode
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
class MobileNetV3Large:
|
| 9 |
+
def __init__(self, weights_path: str):
|
| 10 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
|
| 12 |
+
self.transform = transforms.Compose([
|
| 13 |
+
transforms.Resize((232), interpolation=InterpolationMode.BILINEAR),
|
| 14 |
+
transforms.CenterCrop(224),
|
| 15 |
+
transforms.ToTensor(),
|
| 16 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 17 |
+
std=[0.229, 0.224, 0.225])
|
| 18 |
+
])
|
| 19 |
+
|
| 20 |
+
self.model = mobilenet_v3_large()
|
| 21 |
+
self.model.classifier = torch.nn.Sequential(
|
| 22 |
+
torch.nn.Linear(960, 1280),
|
| 23 |
+
torch.nn.ReLU(),
|
| 24 |
+
torch.nn.Dropout(0.2),
|
| 25 |
+
torch.nn.Linear(1280, 1000),
|
| 26 |
+
torch.nn.ReLU(),
|
| 27 |
+
torch.nn.Dropout(0.2),
|
| 28 |
+
torch.nn.Linear(1000, 7)
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
state_dict = torch.load(weights_path, map_location=torch.device(self.device))
|
| 32 |
+
self.model.load_state_dict(state_dict)
|
| 33 |
+
self.model.eval()
|
| 34 |
+
|
| 35 |
+
def make_prediction(self, image: Image):
|
| 36 |
+
image = self.transform(image).unsqueeze(0)
|
| 37 |
+
with torch.no_grad():
|
| 38 |
+
pred = F.softmax(self.model(image), dim=1)
|
| 39 |
+
return pred
|
HAM_10000_CLASSIFICATION-master/models/shufflenet.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.models import shufflenet_v2_x1_5, ShuffleNet_V2_X1_5_Weights
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
class ShuffleNet_V2_X1_5:
|
| 7 |
+
def __init__(self, weights_path: str):
|
| 8 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 9 |
+
|
| 10 |
+
self.model = shufflenet_v2_x1_5()
|
| 11 |
+
self.model.fc = torch.nn.Linear(self.model.fc.in_features, 7)
|
| 12 |
+
|
| 13 |
+
state_dict = torch.load(weights_path, map_location=self.device)
|
| 14 |
+
self.model.load_state_dict(state_dict)
|
| 15 |
+
self.model.eval()
|
| 16 |
+
|
| 17 |
+
self.transform = ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1.transforms()
|
| 18 |
+
|
| 19 |
+
def make_prediction(self, image: Image):
|
| 20 |
+
image = self.transform(image).unsqueeze(0)
|
| 21 |
+
with torch.no_grad():
|
| 22 |
+
pred = F.softmax(self.model(image), dim=1)
|
| 23 |
+
return pred
|
| 24 |
+
|
| 25 |
+
|
HAM_10000_CLASSIFICATION-master/models/vit.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
class VIT_B_16:
|
| 8 |
+
def __init__(self, weights_path: str):
|
| 9 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 10 |
+
|
| 11 |
+
self.model = vit_b_16()
|
| 12 |
+
self.model.heads[0] = torch.nn.Linear(768, 7)
|
| 13 |
+
|
| 14 |
+
state_dict = torch.load(weights_path, map_location=self.device)
|
| 15 |
+
self.model.load_state_dict(state_dict)
|
| 16 |
+
self.model.eval()
|
| 17 |
+
|
| 18 |
+
self.transform = transforms.Compose([
|
| 19 |
+
transforms.Resize(256),
|
| 20 |
+
transforms.CenterCrop(224),
|
| 21 |
+
transforms.ToTensor(),
|
| 22 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 23 |
+
])
|
| 24 |
+
|
| 25 |
+
def make_prediction(self, image: Image):
|
| 26 |
+
image = self.transform(image).unsqueeze(0)
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
pred = F.softmax(self.model(image), dim=1)
|
| 29 |
+
return pred
|
| 30 |
+
|
| 31 |
+
|
HAM_10000_CLASSIFICATION-master/requirements.txt
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
altair==5.5.0
|
| 2 |
+
annotated-types==0.7.0
|
| 3 |
+
anyio==4.9.0
|
| 4 |
+
attrs==25.3.0
|
| 5 |
+
beautifulsoup4==4.13.3
|
| 6 |
+
blinker==1.9.0
|
| 7 |
+
cachetools==5.5.2
|
| 8 |
+
certifi==2025.1.31
|
| 9 |
+
charset-normalizer==3.4.1
|
| 10 |
+
click==8.1.8
|
| 11 |
+
colorama==0.4.6
|
| 12 |
+
contourpy==1.3.1
|
| 13 |
+
cycler==0.12.1
|
| 14 |
+
fastapi==0.115.12
|
| 15 |
+
filelock==3.18.0
|
| 16 |
+
fonttools==4.56.0
|
| 17 |
+
fsspec==2025.3.2
|
| 18 |
+
gdown==5.2.0
|
| 19 |
+
gitdb==4.0.12
|
| 20 |
+
GitPython==3.1.44
|
| 21 |
+
h11==0.14.0
|
| 22 |
+
idna==3.10
|
| 23 |
+
Jinja2==3.1.6
|
| 24 |
+
jsonschema==4.23.0
|
| 25 |
+
jsonschema-specifications==2024.10.1
|
| 26 |
+
kiwisolver==1.4.8
|
| 27 |
+
MarkupSafe==3.0.2
|
| 28 |
+
matplotlib==3.10.1
|
| 29 |
+
mpmath==1.3.0
|
| 30 |
+
narwhals==1.32.0
|
| 31 |
+
networkx==3.4.2
|
| 32 |
+
numpy==2.2.4
|
| 33 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 34 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 35 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 36 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 37 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 38 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 39 |
+
nvidia-curand-cu12==10.3.5.147
|
| 40 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 41 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 42 |
+
nvidia-cusparselt-cu12==0.6.2
|
| 43 |
+
nvidia-nccl-cu12==2.21.5
|
| 44 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 45 |
+
nvidia-nvtx-cu12==12.4.127
|
| 46 |
+
opencv-python-headless==4.11.0.86
|
| 47 |
+
packaging==24.2
|
| 48 |
+
pandas==2.2.3
|
| 49 |
+
pillow==11.1.0
|
| 50 |
+
protobuf==5.29.4
|
| 51 |
+
pyarrow==19.0.1
|
| 52 |
+
pydantic==2.11.1
|
| 53 |
+
pydantic_core==2.33.0
|
| 54 |
+
pydeck==0.9.1
|
| 55 |
+
pyparsing==3.2.3
|
| 56 |
+
PySocks==1.7.1
|
| 57 |
+
python-dateutil==2.9.0.post0
|
| 58 |
+
python-multipart==0.0.20
|
| 59 |
+
pytz==2025.2
|
| 60 |
+
referencing==0.36.2
|
| 61 |
+
requests==2.32.3
|
| 62 |
+
rpds-py==0.24.0
|
| 63 |
+
setuptools==78.1.0
|
| 64 |
+
six==1.17.0
|
| 65 |
+
smmap==5.0.2
|
| 66 |
+
sniffio==1.3.1
|
| 67 |
+
soupsieve==2.6
|
| 68 |
+
starlette==0.46.1
|
| 69 |
+
streamlit==1.44.0
|
| 70 |
+
sympy==1.13.1
|
| 71 |
+
tenacity==9.0.0
|
| 72 |
+
toml==0.10.2
|
| 73 |
+
torch==2.6.0
|
| 74 |
+
torchvision==0.21.0
|
| 75 |
+
tornado==6.4.2
|
| 76 |
+
tqdm==4.67.1
|
| 77 |
+
triton==3.2.0
|
| 78 |
+
typing-inspection==0.4.0
|
| 79 |
+
typing_extensions==4.13.0
|
| 80 |
+
tzdata==2025.2
|
| 81 |
+
urllib3==2.3.0
|
| 82 |
+
uvicorn==0.34.0
|
| 83 |
+
watchdog==6.0.0
|
HAM_10000_CLASSIFICATION-master/utils/trainer.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs, output_folder, device="cpu"):
|
| 2 |
+
"""
|
| 3 |
+
Train a neural network model with specified training, validation, and testing datasets.
|
| 4 |
+
Additionally, plots accuracy and loss per epoch using matplotlib and saves them as images.
|
| 5 |
+
|
| 6 |
+
This function performs a complete training loop, including:
|
| 7 |
+
- Creating DataLoaders for training, validation, and testing datasets
|
| 8 |
+
- Moving the model to the specified device (CPU/GPU)
|
| 9 |
+
- Training the model for a specified number of epochs
|
| 10 |
+
- Tracking and logging training, validation, and testing metrics
|
| 11 |
+
- Saving the best (based on validation performance) and last model weights
|
| 12 |
+
- Plotting and saving accuracy and loss graphs per epoch
|
| 13 |
+
|
| 14 |
+
Parameters:
|
| 15 |
+
-----------
|
| 16 |
+
model : torch.nn.Module
|
| 17 |
+
The neural network model to be trained
|
| 18 |
+
train_loader : torch.utils.data.DataLoader
|
| 19 |
+
Dataset used for training the model
|
| 20 |
+
val_loader : torch.utils.data.DataLoader
|
| 21 |
+
Dataset used for validating the model during training
|
| 22 |
+
test_loader : torch.utils.data.DataLoader
|
| 23 |
+
Dataset used for evaluating the model's performance after training
|
| 24 |
+
optimizer : torch.optim.Optimizer
|
| 25 |
+
Optimization algorithm for updating model weights
|
| 26 |
+
criterion : torch.nn.Module
|
| 27 |
+
Loss function used to compute the model's performance
|
| 28 |
+
epochs : int
|
| 29 |
+
Number of complete passes through the entire training dataset
|
| 30 |
+
output_folder : str
|
| 31 |
+
Folder path where the model weights and plots will be saved
|
| 32 |
+
device : str, optional
|
| 33 |
+
Computing device to use for training (default is "cpu")
|
| 34 |
+
Can be "cpu" or "cuda" for GPU training
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
--------
|
| 38 |
+
None
|
| 39 |
+
|
| 40 |
+
Side Effects:
|
| 41 |
+
-------------
|
| 42 |
+
- Prints training, validation, and testing metrics for each epoch
|
| 43 |
+
- Saves the best performing model (based on validation accuracy) to "weights/best_model.pth"
|
| 44 |
+
- Saves the final model to "weights/last_model.pth"
|
| 45 |
+
- Saves the loss plot as "loss_plot.png" and accuracy plot as "accuracy_plot.png" in the output folder
|
| 46 |
+
|
| 47 |
+
Example:
|
| 48 |
+
--------
|
| 49 |
+
>>> model = MyModel()
|
| 50 |
+
>>> optimizer = torch.optim.Adam(model.parameters())
|
| 51 |
+
>>> criterion = nn.CrossEntropyLoss()
|
| 52 |
+
>>> train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs=10, batch_size=32, output_folder="weights")
|
| 53 |
+
"""
|
| 54 |
+
import os
|
| 55 |
+
import torch
|
| 56 |
+
from torch.utils.data import DataLoader
|
| 57 |
+
from tqdm import tqdm
|
| 58 |
+
import matplotlib.pyplot as plt
|
| 59 |
+
|
| 60 |
+
# Ensure weights folder exists
|
| 61 |
+
os.makedirs(output_folder, exist_ok=True)
|
| 62 |
+
|
| 63 |
+
print(f"Device Found: {device}, Starting Training 🚀")
|
| 64 |
+
|
| 65 |
+
# Move model to the specified device
|
| 66 |
+
model = model.to(device)
|
| 67 |
+
|
| 68 |
+
best_val_accuracy = 0.0 # Initialize best validation accuracy tracker
|
| 69 |
+
|
| 70 |
+
# Lists to store metrics per epoch for plotting
|
| 71 |
+
train_losses, val_losses, test_losses = [], [], []
|
| 72 |
+
train_accuracies, val_accuracies, test_accuracies = [], [], []
|
| 73 |
+
|
| 74 |
+
for epoch in range(epochs):
|
| 75 |
+
# ----------------------
|
| 76 |
+
# Training Phase
|
| 77 |
+
# ----------------------
|
| 78 |
+
model.train() # Set model to training mode
|
| 79 |
+
running_loss = 0.0
|
| 80 |
+
correct = 0
|
| 81 |
+
total = 0
|
| 82 |
+
|
| 83 |
+
train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} (Training)", leave=False)
|
| 84 |
+
for images, labels in train_progress:
|
| 85 |
+
# Move tensors to the specified device
|
| 86 |
+
images, labels = images.to(device), labels.to(device)
|
| 87 |
+
|
| 88 |
+
optimizer.zero_grad() # Reset gradients
|
| 89 |
+
outputs = model(images) # Forward pass
|
| 90 |
+
loss = criterion(outputs, labels) # Compute loss
|
| 91 |
+
loss.backward() # Backpropagation
|
| 92 |
+
optimizer.step() # Update weights
|
| 93 |
+
|
| 94 |
+
running_loss += loss.item()
|
| 95 |
+
_, predicted = torch.max(outputs, 1)
|
| 96 |
+
total += labels.size(0)
|
| 97 |
+
correct += (predicted == labels).sum().item()
|
| 98 |
+
|
| 99 |
+
train_progress.set_postfix({
|
| 100 |
+
'Loss': f'{loss.item():.4f}',
|
| 101 |
+
'Accuracy': f'{100 * correct / total:.2f}%'
|
| 102 |
+
})
|
| 103 |
+
|
| 104 |
+
train_loss = running_loss / len(train_loader)
|
| 105 |
+
train_accuracy = 100 * correct / total
|
| 106 |
+
|
| 107 |
+
# ----------------------
|
| 108 |
+
# Validation Phase
|
| 109 |
+
# ----------------------
|
| 110 |
+
model.eval() # Set model to evaluation mode
|
| 111 |
+
val_loss = 0.0
|
| 112 |
+
correct_val = 0
|
| 113 |
+
total_val = 0
|
| 114 |
+
|
| 115 |
+
val_progress = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} (Validation)", leave=False)
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
for images, labels in val_progress:
|
| 118 |
+
images, labels = images.to(device), labels.to(device)
|
| 119 |
+
outputs = model(images)
|
| 120 |
+
loss = criterion(outputs, labels)
|
| 121 |
+
val_loss += loss.item()
|
| 122 |
+
_, predicted = torch.max(outputs, 1)
|
| 123 |
+
total_val += labels.size(0)
|
| 124 |
+
correct_val += (predicted == labels).sum().item()
|
| 125 |
+
|
| 126 |
+
val_progress.set_postfix({
|
| 127 |
+
'Loss': f'{loss.item():.4f}',
|
| 128 |
+
'Accuracy': f'{100 * correct_val / total_val:.2f}%'
|
| 129 |
+
})
|
| 130 |
+
|
| 131 |
+
val_loss /= len(val_loader)
|
| 132 |
+
val_accuracy = 100 * correct_val / total_val
|
| 133 |
+
|
| 134 |
+
# ----------------------
|
| 135 |
+
# Testing Phase
|
| 136 |
+
# ----------------------
|
| 137 |
+
test_loss = 0.0
|
| 138 |
+
correct_test = 0
|
| 139 |
+
total_test = 0
|
| 140 |
+
|
| 141 |
+
test_progress = tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} (Testing)", leave=False)
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
for images, labels in test_progress:
|
| 144 |
+
images, labels = images.to(device), labels.to(device)
|
| 145 |
+
outputs = model(images)
|
| 146 |
+
loss = criterion(outputs, labels)
|
| 147 |
+
test_loss += loss.item()
|
| 148 |
+
_, predicted = torch.max(outputs, 1)
|
| 149 |
+
total_test += labels.size(0)
|
| 150 |
+
correct_test += (predicted == labels).sum().item()
|
| 151 |
+
|
| 152 |
+
test_progress.set_postfix({
|
| 153 |
+
'Loss': f'{loss.item():.4f}',
|
| 154 |
+
'Accuracy': f'{100 * correct_test / total_test:.2f}%'
|
| 155 |
+
})
|
| 156 |
+
|
| 157 |
+
test_loss /= len(test_loader)
|
| 158 |
+
test_accuracy = 100 * correct_test / total_test
|
| 159 |
+
|
| 160 |
+
# Store metrics for plotting
|
| 161 |
+
train_losses.append(train_loss)
|
| 162 |
+
val_losses.append(val_loss)
|
| 163 |
+
test_losses.append(test_loss)
|
| 164 |
+
train_accuracies.append(train_accuracy)
|
| 165 |
+
val_accuracies.append(val_accuracy)
|
| 166 |
+
test_accuracies.append(test_accuracy)
|
| 167 |
+
|
| 168 |
+
# Log the metrics for this epoch
|
| 169 |
+
print(
|
| 170 |
+
f"Epoch [{epoch+1}/{epochs}]: "
|
| 171 |
+
f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}% | "
|
| 172 |
+
f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% | "
|
| 173 |
+
f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Save the best model based on validation accuracy
|
| 177 |
+
if val_accuracy > best_val_accuracy:
|
| 178 |
+
best_val_accuracy = val_accuracy
|
| 179 |
+
torch.save(model.state_dict(), os.path.join(output_folder, "best_model.pth"))
|
| 180 |
+
|
| 181 |
+
# Save the last model
|
| 182 |
+
torch.save(model.state_dict(), os.path.join(output_folder, "last_model.pth"))
|
| 183 |
+
print("Training completed. Best validation accuracy: {:.2f}%".format(best_val_accuracy))
|
| 184 |
+
|
| 185 |
+
# ----------------------
|
| 186 |
+
# Plotting Metrics with Matplotlib
|
| 187 |
+
# ----------------------
|
| 188 |
+
epochs_range = range(1, epochs + 1)
|
| 189 |
+
|
| 190 |
+
# Plot Losses
|
| 191 |
+
plt.figure()
|
| 192 |
+
plt.plot(epochs_range, train_losses, label='Train Loss')
|
| 193 |
+
plt.plot(epochs_range, val_losses, label='Validation Loss')
|
| 194 |
+
plt.plot(epochs_range, test_losses, label='Test Loss')
|
| 195 |
+
plt.xlabel('Epoch')
|
| 196 |
+
plt.ylabel('Loss')
|
| 197 |
+
plt.title('Loss per Epoch')
|
| 198 |
+
plt.legend()
|
| 199 |
+
loss_plot_path = os.path.join(output_folder, 'loss_plot.png')
|
| 200 |
+
plt.savefig(loss_plot_path)
|
| 201 |
+
plt.close()
|
| 202 |
+
print(f"Loss plot saved to {loss_plot_path}")
|
| 203 |
+
|
| 204 |
+
# Plot Accuracies
|
| 205 |
+
plt.figure()
|
| 206 |
+
plt.plot(epochs_range, train_accuracies, label='Train Accuracy')
|
| 207 |
+
plt.plot(epochs_range, val_accuracies, label='Validation Accuracy')
|
| 208 |
+
plt.plot(epochs_range, test_accuracies, label='Test Accuracy')
|
| 209 |
+
plt.xlabel('Epoch')
|
| 210 |
+
plt.ylabel('Accuracy (%)')
|
| 211 |
+
plt.title('Accuracy per Epoch')
|
| 212 |
+
plt.legend()
|
| 213 |
+
acc_plot_path = os.path.join(output_folder, 'accuracy_plot.png')
|
| 214 |
+
plt.savefig(acc_plot_path)
|
| 215 |
+
plt.close()
|
| 216 |
+
print(f"Accuracy plot saved to {acc_plot_path}")
|