Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- LICENCE +21 -0
- README.md +134 -19
- __pycache__/utils.cpython-311.pyc +0 -0
- app.py +31 -0
- demo/ReSegNet-demo.mp4 +3 -0
- demo/after.jpg +0 -0
- demo/befire.png +3 -0
- models/model_weights.pth +3 -0
- requirements.txt +7 -3
- retina-blood-vessel-segmentation-f1-score-of-80.ipynb +1 -0
- run.py +2 -0
- utils.py +32 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
demo/befire.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
demo/ReSegNet-demo.mp4 filter=lfs diff=lfs merge=lfs -text
|
LICENCE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Eslam Tarek
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,19 +1,134 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Retina Blood Vessel Segmentation
|
| 2 |
+
|
| 3 |
+
## About the Project
|
| 4 |
+
|
| 5 |
+
This project focuses on the automatic segmentation of blood vessels in retinal fundus images using deep learning. Accurate vessel segmentation is crucial for diagnosing and monitoring various ophthalmic diseases, such as diabetic retinopathy, glaucoma, and hypertensive retinopathy. The project leverages state-of-the-art convolutional neural networks to perform pixel-wise classification, distinguishing vessel structures from the background. The solution is designed for both research and practical clinical applications, providing robust and efficient segmentation results.
|
| 6 |
+
|
| 7 |
+
The repository contains:
|
| 8 |
+
- A Jupyter notebook for end-to-end training, evaluation, and visualization.
|
| 9 |
+
- A Streamlit web application for interactive inference on new images.
|
| 10 |
+
- Pretrained model weights and demo media for quick experimentation.
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## About the Dataset
|
| 15 |
+
|
| 16 |
+
The model is trained and evaluated on the [Retina Blood Vessel dataset](https://www.kaggle.com/datasets/abdallahwagih/retina-blood-vessel/data) from Kaggle. This dataset consists of high-resolution color fundus images and their corresponding binary masks, where vessel pixels are annotated by experts.
|
| 17 |
+
|
| 18 |
+
**Dataset Structure:**
|
| 19 |
+
- `image/`: Contains original RGB fundus images.
|
| 20 |
+
- `mask/`: Contains ground truth binary masks for vessel segmentation.
|
| 21 |
+
|
| 22 |
+
**Key Characteristics:**
|
| 23 |
+
- Images vary in illumination, contrast, and vessel visibility.
|
| 24 |
+
- Vessel pixels are a small fraction of the total image area, leading to class imbalance.
|
| 25 |
+
- The dataset is split into training and testing sets for model development and evaluation.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## Notebook Summary
|
| 30 |
+
|
| 31 |
+
The provided notebook (`retina-blood-vessel-segmentation-f1-score-of-80.ipynb`) guides users through the entire workflow:
|
| 32 |
+
1. **Problem Definition:** Outlines the clinical motivation and technical challenges.
|
| 33 |
+
2. **Data Preparation:** Loads images and masks, applies preprocessing (resizing, normalization), and splits data into training and validation sets.
|
| 34 |
+
3. **Model Selection:** Utilizes a U-Net architecture with a ResNet34 encoder pretrained on ImageNet for effective feature extraction.
|
| 35 |
+
4. **Loss Function & Optimizer:** Combines Binary Cross Entropy and Dice Loss to address class imbalance and improve segmentation accuracy.
|
| 36 |
+
5. **Training:** Implements training and validation loops with progress monitoring and checkpointing.
|
| 37 |
+
6. **Evaluation:** Computes metrics (F1, IoU, Precision, Recall, Accuracy) and visualizes predictions alongside ground truth.
|
| 38 |
+
7. **Saving:** Exports the trained model for deployment.
|
| 39 |
+
|
| 40 |
+
The notebook is modular, well-commented, and suitable for both educational and research purposes.
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Model Results
|
| 45 |
+
|
| 46 |
+
### Preprocessing
|
| 47 |
+
|
| 48 |
+
- **Image Normalization:** All images are scaled to [0, 1] and resized to 512x512 pixels to standardize input dimensions.
|
| 49 |
+
- **Mask Processing:** Masks are binarized and reshaped to match the model's output.
|
| 50 |
+
- **Augmentation:** (Optional) Techniques such as flipping, rotation, and brightness adjustment can be applied to improve generalization.
|
| 51 |
+
|
| 52 |
+
### Training
|
| 53 |
+
|
| 54 |
+
- **Architecture:** U-Net with a ResNet34 encoder, leveraging pretrained weights for faster convergence and better feature extraction.
|
| 55 |
+
- **Loss Function:** A combination of Binary Cross Entropy and Dice Loss is used to handle class imbalance and encourage overlap between predicted and true vessel regions.
|
| 56 |
+
- **Optimizer:** Adam optimizer with a learning rate scheduler (ReduceLROnPlateau) to adaptively reduce learning rate on validation loss plateaus.
|
| 57 |
+
- **Epochs:** Trained for 50 epochs with early stopping based on validation loss.
|
| 58 |
+
|
| 59 |
+
### Evaluation
|
| 60 |
+
|
| 61 |
+
- **Metrics:** The model is evaluated using F1 Score, Jaccard Index (IoU), Precision, Recall, and Accuracy.
|
| 62 |
+
- **Results:** Achieved an F1 score of **80%** on the test set, indicating strong performance in segmenting fine vessel structures.
|
| 63 |
+
- **Visualization:** The notebook provides side-by-side comparisons of original images, ground truth masks, and model predictions for qualitative assessment.
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## How to Install
|
| 68 |
+
|
| 69 |
+
Follow these steps to set up the environment using Python's `venv`:
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
# Clone the repository
|
| 73 |
+
git clone https://github.com/DeepActionPotential/ReSegNet
|
| 74 |
+
cd ReSegNet
|
| 75 |
+
|
| 76 |
+
# Create a virtual environment
|
| 77 |
+
python -m venv venv
|
| 78 |
+
|
| 79 |
+
# Activate the virtual environment
|
| 80 |
+
# On Windows:
|
| 81 |
+
venv\Scripts\activate
|
| 82 |
+
# On macOS/Linux:
|
| 83 |
+
source venv/bin/activate
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Install required packages
|
| 87 |
+
pip install -r requirements.txt
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
---
|
| 91 |
+
|
| 92 |
+
## How to Use the Software
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
### Web Demo
|
| 96 |
+
|
| 97 |
+
1. Ensure the trained model weights are available in the `models/` directory.
|
| 98 |
+
2. Run the Streamlit app:
|
| 99 |
+
```bash
|
| 100 |
+
streamlit run app.py
|
| 101 |
+
```
|
| 102 |
+
3. Upload a retinal image through the web interface and click "Run Segmentation" to see the predicted vessel mask.
|
| 103 |
+
|
| 104 |
+
### Demo Media
|
| 105 |
+
|
| 106 |
+
## [demo-video](demo/ReSegNet-demo.mp4)
|
| 107 |
+
|
| 108 |
+

|
| 109 |
+

|
| 110 |
+
|
| 111 |
+
---
|
| 112 |
+
|
| 113 |
+
## Technologies Used
|
| 114 |
+
|
| 115 |
+
### Model Training
|
| 116 |
+
|
| 117 |
+
- **PyTorch:** Core deep learning framework for model definition, training, and evaluation.
|
| 118 |
+
- **segmentation-models-pytorch:** Provides high-level implementations of popular segmentation architectures (e.g., U-Net, FPN) with pretrained encoders.
|
| 119 |
+
- **OpenCV & NumPy:** For image processing, augmentation, and efficient data handling.
|
| 120 |
+
- **Matplotlib:** Visualization of images, masks, and results.
|
| 121 |
+
- **scikit-learn:** Calculation of evaluation metrics (F1, IoU, Precision, Recall, Accuracy).
|
| 122 |
+
|
| 123 |
+
### Deployment
|
| 124 |
+
|
| 125 |
+
- **Streamlit:** Rapid development of interactive web applications for model inference and visualization.
|
| 126 |
+
- **Pillow:** Image loading and preprocessing in the web app.
|
| 127 |
+
|
| 128 |
+
These technologies ensure a robust, reproducible, and user-friendly workflow from model development to deployment.
|
| 129 |
+
|
| 130 |
+
---
|
| 131 |
+
|
| 132 |
+
## License
|
| 133 |
+
|
| 134 |
+
This project is licensed under the MIT License. See the [LICENCE](LICENCE) file for details.
|
__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.77 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st
import torch
from PIL import Image
import numpy as np
import cv2
from utils import load_model, preprocess_image, predict_mask, postprocess_mask


# Model configuration.
MODEL_PATH = "./models/model_weights.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# st.set_page_config must be the first Streamlit command executed on a run.
st.set_page_config(page_title="Let Me Detect - Retina", layout="centered")


@st.cache_resource
def get_model():
    """Load the segmentation model exactly once.

    Streamlit re-executes the whole script on every user interaction, so
    without caching the ~98 MB checkpoint would be reloaded on each rerun.
    """
    return load_model(MODEL_PATH, DEVICE)


model = get_model()

st.title("🧠 Let Me Segment - Retinal Segmentation")

uploaded_file = st.file_uploader("Upload a retinal image", type=["jpg", "png", "jpeg"])

if uploaded_file:
    # Force RGB so downstream preprocessing always sees 3 channels.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Run Segmentation"):
        with st.spinner("Segmenting..."):
            input_tensor = preprocess_image(image).to(DEVICE)
            pred_mask = predict_mask(model, input_tensor)
            final_mask = postprocess_mask(pred_mask)

        st.image(final_mask, caption="Predicted Mask", use_column_width=True)
        st.success("Segmentation complete!")
|
demo/ReSegNet-demo.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59a61c48849911f97cee6949fca55048125b1b1734848dd0ec9decba3db9c7c0
|
| 3 |
+
size 2674969
|
demo/after.jpg
ADDED
|
demo/befire.png
ADDED
|
Git LFS Details
|
models/model_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cdd4e30f9cd9aeacc0e7bdd00c67160d36004f6771f62bf92f86448f93666bba
|
| 3 |
+
size 97920229
|
requirements.txt
CHANGED
|
@@ -1,3 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit==1.34.0
|
| 2 |
+
torch==2.3.0
|
| 3 |
+
torchvision==0.18.0
|
| 4 |
+
numpy==1.26.4
|
| 5 |
+
opencv-python==4.9.0.80
|
| 6 |
+
Pillow==10.3.0
|
| 7 |
+
segmentation_models_pytorch==0.3.3
|
retina-blood-vessel-segmentation-f1-score-of-80.ipynb
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"none","dataSources":[{"sourceId":6318833,"sourceType":"datasetVersion","datasetId":3636171}],"dockerImageVersionId":30528,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":false}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"\n### **1. Problem Definition**\n\n* **Identify the Task**: Define the vision taskβe.g., image classification, object detection, or segmentation.\n* **Determine the Goal**: Establish what success looks likeβe.g., target accuracy, latency, or generalization.\n* **Specify Constraints**: Consider hardware limitations, deployment environment, and dataset availability.\n* **Use Case Examples**: Medical imaging, autonomous vehicles, retail analytics, etc.\n\n---\n\n### **2. Data Preparation**\n\n* Dataset loading, preprocessing, augmentation, and train/validation/test splitting.\n\n---\n\n### **3. Choose or Define Model**\n\n* Select pretrained architectures or design a custom model for your task.\n\n---\n\n### **4. Define Loss Function and Optimizer**\n\n* Match loss functions and optimizers to the problem type (classification, detection, etc.).\n\n---\n\n### **5. Train the Model**\n\n* Setup of training loops, optimization steps, and progress monitoring.\n\n---\n\n### **6. Evaluate the Model**\n\n* Test set performance metrics and visualizations.\n\n---\n\n### **7. Save**\n\n* Model serialization, export to deployment-friendly formats, and integration into applications.\n\n\n\n","metadata":{}},{"cell_type":"markdown","source":"\n## **1. 
Problem Definition: Retina Blood Vessel Segmentation**\n\nThe goal is to develop a deep learning model that segments blood vessels from retinal fundus images using the [Retina Blood Vessel dataset](https://www.kaggle.com/datasets/abdallahwagih/retina-blood-vessel/data). Accurate segmentation of retinal vessels is a critical step in diagnosing and monitoring eye diseases such as diabetic retinopathy, glaucoma, and hypertensive retinopathy.\n\nThe dataset provides color fundus images along with corresponding ground truth masks highlighting the vascular structure. This is a pixel-wise binary classification task, where the model must distinguish vessel pixels from the background.\n\n### Key Challenges:\n\n* **Class Imbalance**: Blood vessels cover a small fraction of each image, making it easy for the model to be biased toward predicting background.\n* **Fine Structural Detail**: Vessels are thin, branching, and vary in intensity, requiring high-resolution feature extraction and spatial precision.\n* **Image Variability**: Differences in illumination, contrast, and noise between samples increase the complexity of generalization.\n\n### Success Criteria:\n\n* High segmentation quality measured by **Dice coefficient**, **IoU**, **Precision**, and **Recall**.\n* Robust generalization to unseen data, especially across varying image qualities.\n* Efficient inference for potential integration in screening tools or clinical workflows.\n\n","metadata":{}},{"cell_type":"markdown","source":"#### Tools\n","metadata":{}},{"cell_type":"code","source":"!pip install segmentation-models-pytorch --quiet\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:12.940169Z","iopub.execute_input":"2025-05-03T15:25:12.940550Z","iopub.status.idle":"2025-05-03T15:25:22.959425Z","shell.execute_reply.started":"2025-05-03T15:25:12.940523Z","shell.execute_reply":"2025-05-03T15:25:22.958333Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Standard 
Library\nimport os\nimport time\nimport random\nfrom glob import glob\nfrom operator import add\nfrom pathlib import Path\n\n\n# Third-Party Libraries\nimport cv2\nimport numpy as np\nfrom tqdm import tqdm\nimport matplotlib.pyplot as plt\nfrom sklearn.metrics import (\n accuracy_score,\n f1_score,\n jaccard_score,\n precision_score,\n recall_score\n)\n\n# PyTorch and Related Modules\nimport torch\nimport torch.optim as optim\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset, DataLoader\nimport segmentation_models_pytorch as smp\n\n\n","metadata":{"execution":{"iopub.status.busy":"2025-05-03T15:25:22.961196Z","iopub.execute_input":"2025-05-03T15:25:22.961513Z","iopub.status.idle":"2025-05-03T15:25:28.468576Z","shell.execute_reply.started":"2025-05-03T15:25:22.961488Z","shell.execute_reply":"2025-05-03T15:25:28.467881Z"},"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### **2. Data Preparation**\n","metadata":{}},{"cell_type":"code","source":"class ImageMaskDataset(Dataset):\n def __init__(self, image_paths, mask_paths):\n self.image_paths = [Path(p) for p in image_paths]\n self.mask_paths = [Path(p) for p in mask_paths]\n\n def __len__(self):\n return len(self.image_paths)\n\n def __getitem__(self, idx):\n img = self._load_image(self.image_paths[idx])\n msk = self._load_mask(self.mask_paths[idx])\n return img, msk\n\n def _load_image(self, path: Path):\n arr = cv2.imread(str(path), cv2.IMREAD_COLOR)\n arr = arr.astype(np.float32) / 255.0\n arr = np.transpose(arr, (2, 0, 1))\n return torch.from_numpy(arr)\n\n def _load_mask(self, path: Path):\n arr = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)\n arr = arr.astype(np.float32) / 255.0\n arr = np.expand_dims(arr, 0)\n return 
torch.from_numpy(arr)\n","metadata":{"execution":{"iopub.status.busy":"2025-05-03T15:25:28.469533Z","iopub.execute_input":"2025-05-03T15:25:28.469770Z","iopub.status.idle":"2025-05-03T15:25:28.476479Z","shell.execute_reply.started":"2025-05-03T15:25:28.469749Z","shell.execute_reply":"2025-05-03T15:25:28.475526Z"},"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# βββ Configuration βββ\nconfig = {\n \"seed\": 42,\n \"data_root\": Path(\"/kaggle/input/retina-blood-vessel/Data\"),\n \"img_size\": (512, 512),\n \"batch_size\": 2,\n \"lr\": 1e-4,\n \"checkpoint_dir\": Path(\"files\") / \"checkpoint.pth\",\n \"epochs\": 50,\n}","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.478204Z","iopub.execute_input":"2025-05-03T15:25:28.478476Z","iopub.status.idle":"2025-05-03T15:25:28.486957Z","shell.execute_reply.started":"2025-05-03T15:25:28.478455Z","shell.execute_reply":"2025-05-03T15:25:28.486282Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def create_dir(path):\n \"\"\"\n Ensure that a directory exists (creates it if necessary).\n Accepts either a string or Path.\n \"\"\"\n Path(path).mkdir(parents=True, exist_ok=True)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.487957Z","iopub.execute_input":"2025-05-03T15:25:28.488503Z","iopub.status.idle":"2025-05-03T15:25:28.496632Z","shell.execute_reply.started":"2025-05-03T15:25:28.488472Z","shell.execute_reply":"2025-05-03T15:25:28.495891Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# βββ Setup βββ\ncreate_dir(config[\"checkpoint_dir\"].parent)\ncheckpoint_path = 
\"files/checkpoint.pth\"","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.497617Z","iopub.execute_input":"2025-05-03T15:25:28.497918Z","iopub.status.idle":"2025-05-03T15:25:28.505710Z","shell.execute_reply.started":"2025-05-03T15:25:28.497889Z","shell.execute_reply":"2025-05-03T15:25:28.505024Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# βββ Helpers βββ\ndef get_paths(root: Path, split: str, kind: str):\n return sorted((root / split / kind).glob(\"*\"))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.506793Z","iopub.execute_input":"2025-05-03T15:25:28.507092Z","iopub.status.idle":"2025-05-03T15:25:28.518396Z","shell.execute_reply.started":"2025-05-03T15:25:28.507063Z","shell.execute_reply":"2025-05-03T15:25:28.517677Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# βββ Data Paths βββ\ntrain_x = get_paths(config[\"data_root\"], \"train\", \"image\")\ntrain_y = get_paths(config[\"data_root\"], \"train\", \"mask\")\nvalid_x = get_paths(config[\"data_root\"], \"test\", \"image\")\nvalid_y = get_paths(config[\"data_root\"], \"test\", \"mask\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.519453Z","iopub.execute_input":"2025-05-03T15:25:28.519766Z","iopub.status.idle":"2025-05-03T15:25:28.597651Z","shell.execute_reply.started":"2025-05-03T15:25:28.519729Z","shell.execute_reply":"2025-05-03T15:25:28.596991Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"print(\n f\"Dataset Size:\\n\"\n f\" Train: {len(train_x)} samples\\n\"\n f\" Valid: {len(valid_x)} 
samples\\n\"\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.598710Z","iopub.execute_input":"2025-05-03T15:25:28.599030Z","iopub.status.idle":"2025-05-03T15:25:28.603703Z","shell.execute_reply.started":"2025-05-03T15:25:28.599001Z","shell.execute_reply":"2025-05-03T15:25:28.602945Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# βββ Datasets & Loaders βββ\ntrain_dataset = ImageMaskDataset(train_x, train_y)\nvalid_dataset = ImageMaskDataset(valid_x, valid_y)\n\ntrain_loader = DataLoader(\n train_dataset,\n batch_size=config[\"batch_size\"],\n shuffle=True,\n num_workers=2,\n)\nvalid_loader = DataLoader(\n valid_dataset,\n batch_size=config[\"batch_size\"],\n shuffle=False,\n num_workers=2,\n)\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.606313Z","iopub.execute_input":"2025-05-03T15:25:28.606547Z","iopub.status.idle":"2025-05-03T15:25:28.618622Z","shell.execute_reply.started":"2025-05-03T15:25:28.606527Z","shell.execute_reply":"2025-05-03T15:25:28.617872Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"train_dataset[0][0].shape, valid_dataset[0][0].shape","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.619564Z","iopub.execute_input":"2025-05-03T15:25:28.619870Z","iopub.status.idle":"2025-05-03T15:25:28.710432Z","shell.execute_reply.started":"2025-05-03T15:25:28.619839Z","shell.execute_reply":"2025-05-03T15:25:28.709479Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### **3. 
Choose or Define Model**\n","metadata":{}},{"cell_type":"code","source":"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.711562Z","iopub.execute_input":"2025-05-03T15:25:28.712236Z","iopub.status.idle":"2025-05-03T15:25:28.739724Z","shell.execute_reply.started":"2025-05-03T15:25:28.712183Z","shell.execute_reply":"2025-05-03T15:25:28.738665Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"model = smp.Unet(\n encoder_name=\"resnet34\",\n encoder_weights=\"imagenet\",\n in_channels=3,\n classes=1, \n activation=None\n).to(device)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:28.740722Z","iopub.execute_input":"2025-05-03T15:25:28.740999Z","iopub.status.idle":"2025-05-03T15:25:31.026920Z","shell.execute_reply.started":"2025-05-03T15:25:28.740965Z","shell.execute_reply":"2025-05-03T15:25:31.025934Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### **4. 
Define Loss Function and Optimizer**\n","metadata":{}},{"cell_type":"code","source":"bce_loss = nn.BCEWithLogitsLoss()\ndice_loss = smp.losses.DiceLoss(mode=\"binary\")\n\n\ndef loss_fn(preds, targets):\n return bce_loss(preds, targets) + dice_loss(preds, targets)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:31.028050Z","iopub.execute_input":"2025-05-03T15:25:31.028342Z","iopub.status.idle":"2025-05-03T15:25:31.033001Z","shell.execute_reply.started":"2025-05-03T15:25:31.028318Z","shell.execute_reply":"2025-05-03T15:25:31.032112Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"optimizer = optim.Adam(model.parameters(), lr=config['lr'])\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(\n optimizer, mode=\"min\", patience=3, factor=0.5\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:31.034053Z","iopub.execute_input":"2025-05-03T15:25:31.034333Z","iopub.status.idle":"2025-05-03T15:25:31.045882Z","shell.execute_reply.started":"2025-05-03T15:25:31.034311Z","shell.execute_reply":"2025-05-03T15:25:31.045001Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### **5. 
Train the Model**\n","metadata":{}},{"cell_type":"code","source":"def train_one_epoch(loader):\n model.train()\n running_loss = 0.0\n for images, masks in tqdm(loader, desc=\"Train\"):\n images, masks = images.to(device), masks.to(device)\n preds = model(images)\n loss = loss_fn(preds, masks)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n running_loss += loss.item()\n return running_loss / len(loader)\n\n\n\ndef validate(loader):\n model.eval()\n val_loss = 0.0\n with torch.no_grad():\n for images, masks in tqdm(loader, desc=\"Validate\"):\n images, masks = images.to(device), masks.to(device)\n preds = model(images)\n val_loss += loss_fn(preds, masks).item()\n return val_loss / len(loader)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:31.047252Z","iopub.execute_input":"2025-05-03T15:25:31.047549Z","iopub.status.idle":"2025-05-03T15:25:31.057779Z","shell.execute_reply.started":"2025-05-03T15:25:31.047521Z","shell.execute_reply":"2025-05-03T15:25:31.057114Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"best_val = float(\"inf\")\nfor epoch in range(1, config['epochs']):\n train_loss = train_one_epoch(train_loader)\n val_loss = validate(valid_loader)\n scheduler.step(val_loss)\n\n print(f\"Epoch {epoch:02d} β train: {train_loss:.4f}, val: {val_loss:.4f}\")\n if val_loss < best_val:\n best_val = val_loss\n torch.save(model.state_dict(), checkpoint_path)\n print(\" β checkpoint saved\")\n\n# Inference Example\nmodel.load_state_dict(torch.load(checkpoint_path))\nmodel.eval()\nwith torch.no_grad():\n img, _ = valid_dataset[0]\n pred = model(img.unsqueeze(0).to(device))\n mask = torch.sigmoid(pred).cpu().squeeze().numpy() > 
0.5\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:25:31.058886Z","iopub.execute_input":"2025-05-03T15:25:31.059121Z","iopub.status.idle":"2025-05-03T15:28:30.163032Z","shell.execute_reply.started":"2025-05-03T15:25:31.059098Z","shell.execute_reply":"2025-05-03T15:28:30.162023Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"valid_dataset[0]","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:28:30.164362Z","iopub.execute_input":"2025-05-03T15:28:30.165307Z","iopub.status.idle":"2025-05-03T15:28:30.202923Z","shell.execute_reply.started":"2025-05-03T15:28:30.165274Z","shell.execute_reply":"2025-05-03T15:28:30.202054Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### **6. Evaluate the Model**\n","metadata":{}},{"cell_type":"code","source":"# Helpers -------------------------------------------------------------------\n\ndef ensure_dir_exists(path: str):\n os.makedirs(path, exist_ok=True)\n\ndef tensor_to_numpy_image(tensor: torch.Tensor) -> np.ndarray:\n arr = tensor.cpu().numpy().transpose(1, 2, 0)\n return (arr * 255).astype(np.uint8)\n\ndef tensor_to_binary_mask(tensor: torch.Tensor, threshold: float = 0.5) -> np.ndarray:\n arr = tensor.cpu().numpy().squeeze()\n return (arr > threshold).astype(np.uint8)\n\ndef expand_mask_to_rgb(mask: np.ndarray) -> np.ndarray:\n return np.stack([mask]*3, axis=-1)\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:28:30.203946Z","iopub.execute_input":"2025-05-03T15:28:30.204185Z","iopub.status.idle":"2025-05-03T15:28:30.210106Z","shell.execute_reply.started":"2025-05-03T15:28:30.204164Z","shell.execute_reply":"2025-05-03T15:28:30.209257Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Metrics -------------------------------------------------------------------\n\nfrom sklearn.metrics import (\n jaccard_score, f1_score,\n recall_score, precision_score,\n accuracy_score\n)\n\ndef 
compute_metrics_for_sample(y_true: torch.Tensor, y_pred: torch.Tensor):\n \"\"\"\n Returns [jaccard, f1, recall, precision, accuracy] for a single sample.\n \"\"\"\n y_t = (y_true.cpu().numpy().ravel() > 0.5).astype(np.uint8)\n y_p = (y_pred.cpu().numpy().ravel() > 0.5).astype(np.uint8)\n\n return [\n jaccard_score(y_t, y_p),\n f1_score(y_t, y_p),\n recall_score(y_t, y_p),\n precision_score(y_t, y_p),\n accuracy_score(y_t, y_p),\n ]","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:28:30.211193Z","iopub.execute_input":"2025-05-03T15:28:30.211498Z","iopub.status.idle":"2025-05-03T15:28:30.221165Z","shell.execute_reply.started":"2025-05-03T15:28:30.211475Z","shell.execute_reply":"2025-05-03T15:28:30.220265Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# I/O -----------------------------------------------------------------------\n\ndef save_comparison_image(\n orig_img: np.ndarray,\n gt_mask: np.ndarray,\n pred_mask: np.ndarray,\n save_dir: str,\n filename: str,\n img_size: tuple\n):\n height, width = img_size\n separator = np.ones((height, 10, 3), dtype=np.uint8) * 128\n\n gt_rgb = expand_mask_to_rgb(gt_mask)\n pred_rgb = expand_mask_to_rgb(pred_mask)\n\n composite = np.concatenate(\n [orig_img, separator, gt_rgb, separator, pred_rgb],\n axis=1\n )\n cv2.imwrite(os.path.join(save_dir, filename), composite)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:28:30.222133Z","iopub.execute_input":"2025-05-03T15:28:30.222388Z","iopub.status.idle":"2025-05-03T15:28:30.236284Z","shell.execute_reply.started":"2025-05-03T15:28:30.222366Z","shell.execute_reply":"2025-05-03T15:28:30.235490Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Main Evaluation -----------------------------------------------------------\n\ndef evaluate_model(\n model: torch.nn.Module,\n data_loader: torch.utils.data.DataLoader,\n device: torch.device,\n results_dir: str,\n img_size: tuple\n):\n 
ensure_dir_exists(results_dir)\n model.to(device).eval()\n\n total_metrics = np.zeros(5, dtype=float)\n n_samples = len(data_loader.dataset)\n sample_idx = 0\n\n with torch.no_grad():\n for imgs, masks in tqdm(data_loader, desc=\"Evaluating\", total=len(data_loader)):\n imgs = imgs.to(device)\n masks = masks.to(device)\n\n preds = torch.sigmoid(model(imgs))\n\n for img_t, mask_t, pred_t in zip(imgs, masks, preds):\n metrics = compute_metrics_for_sample(mask_t, pred_t)\n total_metrics += np.array(metrics)\n\n orig = tensor_to_numpy_image(img_t)\n gt = tensor_to_binary_mask(mask_t) * 255\n pr = tensor_to_binary_mask(pred_t) * 255\n filename = f\"sample_{sample_idx:04d}.png\"\n save_comparison_image(orig, gt, pr, results_dir, filename, img_size)\n sample_idx += 1\n\n avg_metrics = total_metrics / n_samples\n jaccard, f1, recall, precision, accuracy = avg_metrics\n \n print(f\"Accuracy: {accuracy:.4f}\")\n print(f\"F1 Score: {f1:.4f}\")\n print(f\"Recall: {recall:.4f}\")\n print(f\"Precision:{precision:.4f}\")\n print(f\"Jaccard: {jaccard:.4f}\")\n\n# Usage ---------------------------------------------------------------------\nmodel.load_state_dict(torch.load(checkpoint_path))\nevaluate_model(model, valid_loader, device, \"results\", config[\"img_size\"])\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:28:30.237372Z","iopub.execute_input":"2025-05-03T15:28:30.237623Z","iopub.status.idle":"2025-05-03T15:28:36.198983Z","shell.execute_reply.started":"2025-05-03T15:28:30.237602Z","shell.execute_reply":"2025-05-03T15:28:36.197783Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"model.eval()\n\n# Number of examples to display\nnum_examples = 10\n\nfor idx in range(num_examples):\n # Get image & mask from your dataset\n img_t, mask_t = valid_dataset[idx] # img_t: Tensor [3,H,W], mask_t: Tensor [1,H,W]\n \n # Run the model\n with torch.no_grad():\n pred_t = torch.sigmoid(model(img_t.unsqueeze(0).to(device)))\n pred_t = 
pred_t.cpu().squeeze(0) # [1,H,W]\n \n # Convert to numpy uint8 for plotting\n img_np = img_t.cpu().numpy().transpose(1,2,0) # [H,W,3], floats in [0,1]\n img_np = (img_np * 255).astype(np.uint8)\n img_np = img_np[..., ::-1]\n \n gt_mask = (mask_t.cpu().numpy().squeeze() > 0.5).astype(np.uint8) * 255\n pr_mask = (pred_t.cpu().numpy().squeeze() > 0.5).astype(np.uint8) * 255\n \n # Make 3-channel versions of the masks\n gt_rgb = np.stack([gt_mask]*3, axis=-1)\n pr_rgb = np.stack([pr_mask]*3, axis=-1)\n \n # Build a separator and composite image\n h, w, _ = img_np.shape\n sep = np.ones((h, 10, 3), dtype=np.uint8) * 128\n composite = np.concatenate([img_np, sep, gt_rgb, sep, pr_rgb], axis=1)\n \n # Plot\n plt.figure(figsize=(12, 6))\n plt.axis('off')\n plt.imshow(composite)\n plt.title(f\"Sample {idx}: Original | Ground Truth | Prediction\")\n plt.show()\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:28:36.200790Z","iopub.execute_input":"2025-05-03T15:28:36.201658Z","iopub.status.idle":"2025-05-03T15:28:39.509896Z","shell.execute_reply.started":"2025-05-03T15:28:36.201618Z","shell.execute_reply":"2025-05-03T15:28:39.508884Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### **7. Save the model**","metadata":{}},{"cell_type":"code","source":"torch.save(model, \"/kaggle/working/model_weights.pth\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-03T15:32:47.344924Z","iopub.execute_input":"2025-05-03T15:32:47.345285Z","iopub.status.idle":"2025-05-03T15:32:47.569946Z","shell.execute_reply.started":"2025-05-03T15:32:47.345253Z","shell.execute_reply":"2025-05-03T15:32:47.568922Z"}},"outputs":[],"execution_count":null}]}
|
run.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
subprocess.run(["streamlit", "run", "app.py"])
|
utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision.transforms as T
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import cv2
|
| 6 |
+
from segmentation_models_pytorch import Unet
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def preprocess_image(image, size=(512, 512)):
    """Convert a PIL image into a model-ready float tensor.

    The image is resized to ``size``, scaled to [0, 1], rearranged to
    channel-first (C, H, W) layout, and given a leading batch dimension,
    yielding a tensor of shape (1, C, H, W).
    """
    resized = image.resize(size)
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    # HWC -> CHW, as expected by the segmentation network.
    chw = pixels.transpose(2, 0, 1)
    return torch.tensor(chw).unsqueeze(0)
|
| 15 |
+
|
| 16 |
+
def predict_mask(model, tensor):
    """Run ``model`` on ``tensor`` and return per-pixel probabilities.

    Inference runs without gradient tracking; the raw logits are passed
    through a sigmoid and returned as a squeezed numpy array on the CPU.
    """
    with torch.no_grad():
        logits = model(tensor)
        probabilities = torch.sigmoid(logits)
    return probabilities.squeeze().cpu().numpy()
|
| 20 |
+
|
| 21 |
+
def postprocess_mask(mask_array, threshold=0.5):
    """Threshold a probability mask and return it as a displayable RGB image.

    Pixels above ``threshold`` become 255, the rest 0; the single channel is
    replicated to three channels so Streamlit renders it as a normal image.
    """
    binary = np.where(mask_array > threshold, 255, 0).astype(np.uint8)
    rgb = np.repeat(binary[..., np.newaxis], 3, axis=-1)
    return Image.fromarray(rgb)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_model(path: str, device: torch.device):
    """Load the trained segmentation model from ``path`` onto ``device``.

    Handles both checkpoint formats this project produces:
    - a plain ``state_dict`` (``torch.save(model.state_dict(), ...)``), and
    - a fully pickled model — the training notebook calls
      ``torch.save(model, ...)``, which pickles the whole ``nn.Module`` and
      would make a bare ``load_state_dict(torch.load(path))`` fail.

    The model is moved to ``device`` and put into eval mode before return.
    """
    # weights_only=False: required to unpickle a full model; the checkpoint is
    # a trusted local artifact shipped with the repo (note: pickle is unsafe
    # for untrusted files). This also matches torch==2.3 default behavior.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    if isinstance(checkpoint, dict):
        # State-dict checkpoint: build the architecture, then load weights.
        # encoder_weights=None — the checkpoint overwrites every parameter
        # anyway, so skip downloading the ImageNet encoder weights.
        model = Unet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)
        model.load_state_dict(checkpoint)
    else:
        # Fully pickled model object.
        model = checkpoint

    model.to(device)
    model.eval()
    return model
|