Upload folder using huggingface_hub
- .gitattributes +2 -0
- .github/workflows/update_space.yml +28 -0
- .gitignore +176 -0
- LICENSE +21 -0
- README.md +133 -8
- app.py +300 -0
- bird.jpeg +0 -0
- enhaned_kmeans_segmented.png +0 -0
- experiments/SegNet/efficient_b0_backbone/architecture.py +178 -0
- experiments/SegNet/efficient_b0_backbone/train.py +81 -0
- experiments/SegNet/vgg_backbone/SegNet_with_VGG16_backbone.ipynb +0 -0
- experiments/SegNet/vgg_backbone/model.py +48 -0
- experiments/enhanced_kmeans_segmenter.py +100 -0
- experiments/ensemble_method.py +148 -0
- experiments/felzenszwalb_segmentation/__init__.py +1 -0
- experiments/felzenszwalb_segmentation/disjoint_set.py +39 -0
- experiments/felzenszwalb_segmentation/segmentation.py +83 -0
- experiments/felzenszwalb_segmentation/utils/__init__.py +2 -0
- experiments/felzenszwalb_segmentation/utils/filter_utils.py +38 -0
- experiments/felzenszwalb_segmentation/utils/utils.py +25 -0
- experiments/kmeans_segmenter.py +95 -0
- experiments/otsu_segmenter.py +95 -0
- experiments/watershed_segmenter.py +208 -0
- kmeans_comparison.png +3 -0
- kmeans_segmented.png +0 -0
- requirements.txt +11 -0
- saved_models/segnet_efficientnet_camvid.pth +3 -0
- saved_models/segnet_vgg.pth +3 -0
- segnet_efficientnet_voc.pth +3 -0
- watershed_output.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+kmeans_comparison.png filter=lfs diff=lfs merge=lfs -text
+watershed_output.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/update_space.yml
ADDED
@@ -0,0 +1,28 @@
+name: Run Python script
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Checkout
+      uses: actions/checkout@v2
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.9'
+
+    - name: Install Gradio
+      run: python -m pip install gradio
+
+    - name: Log in to Hugging Face
+      run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+    - name: Deploy to Spaces
+      run: gradio deploy
.gitignore
ADDED
@@ -0,0 +1,176 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+./saved_models/
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Akshat Jain
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,137 @@
 ---
-title:
-emoji: 📈
-colorFrom: red
-colorTo: pink
-sdk: gradio
-sdk_version: 5.25.0
+title: Image_Segmentation_
 app_file: app.py
-
+sdk: gradio
+sdk_version: 5.23.1
 ---
-
+# Image Segmentation Toolkit
+
+## Overview
+This project implements a comprehensive image segmentation toolkit that combines classical computer vision techniques with deep learning-based approaches. The application provides an interactive interface to compare different segmentation algorithms on user-provided images.
+
+## Features
+- **Classical Segmentation Methods**:
+  - Otsu's Thresholding: Optimal global thresholding for binary segmentation
+  - K-means Clustering: Color-based segmentation with adjustable clusters
+  - SLIC (Simple Linear Iterative Clustering): Superpixel segmentation
+  - Watershed Algorithm: Gradient-based segmentation for separating touching objects
+  - Felzenszwalb Algorithm: Graph-based segmentation with adaptive thresholding
+
+- **Deep Learning Models**:
+  - SegNet with EfficientNet B0 backbone: Pretrained semantic segmentation model
+  - SegNet with VGG backbone: Alternative architecture for comparison
+
+- **Ensemble Methods**:
+  - Otsu + SegNet: Combining boundary information from Otsu with semantic labels from SegNet
+  - Custom ensemble segmentation with adjustable parameters
+
+## Installation
+
+### Prerequisites
+- Python 3.8+
+- PyTorch 1.10+
+- CUDA-compatible GPU (recommended)
+
+### Setup
+1. Clone the repository:
+```bash
+git clone https://github.com/yourusername/CSL7360_Project.git
+cd CSL7360_Project
+```
+
+2. Create and activate a virtual environment (optional but recommended):
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows: venv\Scripts\activate
+```
+
+3. Install required packages:
+```bash
+pip install -r requirements.txt
+```
+
+4. Download pretrained models:
+```bash
+python download_models.py
+```
+The application will also automatically download models when first launched.
+
+## Usage
+
+### Running the Application
+Start the Gradio web interface:
+```bash
+python app.py
+```
+
+The interface will be available at http://127.0.0.1:7860 in your web browser.
+
+### Using the Interface
+1. Select a segmentation method from the tabs at the top
+2. Upload an image using the file picker
+3. Adjust algorithm parameters if available
+4. Click the "Segment this image" button
+5. View the results in the display area
+
+### Algorithm Parameters
+
+#### Otsu's Method
+- No parameters, fully automatic threshold selection
+
+#### K-means Segmentation
+- **Number of Clusters (K)**: Controls how many color groups to segment into
+
+#### SLIC Segmentation
+- **Number of superpixels**: Controls the granularity of segmentation
+- **Compactness factor**: Controls how much superpixels adhere to boundaries
+- **Number of iterations**: Controls refinement of superpixel boundaries
+
+#### Felzenszwalb Algorithm
+- **Sigma**: Gaussian pre-processing smoothing parameter
+- **K value**: Controls segment size preference
+- **Min Size Factor**: Minimum component size
+
+#### Ensemble Segmentation
+- **Boundary Refinement Weight**: Controls influence of classical methods on deep learning boundaries
+
+## Project Structure
+```
+CSL7360_Project/
+├── app.py                            # Main application with pretrained models
+├── experiments/                      # Implementation of segmentation algorithms
+│   ├── ensemble_method.py            # Ensemble segmentation implementation
+│   ├── felzenszwalb_segmentation/    # Felzenszwalb algorithm implementation
+│   ├── kmeans_segmenter.py           # K-means segmentation implementation
+│   ├── enhanced_kmeans_segmenter.py  # SLIC implementation
+│   ├── otsu_segmenter.py             # Otsu thresholding implementation
+│   ├── watershed_segmenter.py        # Watershed algorithm implementation
+│   └── SegNet/                       # Deep learning models
+│       ├── efficient_b0_backbone/    # EfficientNet backbone for SegNet
+│       └── vgg_backbone/             # VGG backbone for SegNet
+├── saved_models/                     # Directory for pretrained weights
+└── requirements.txt                  # Package dependencies
+```
+
+## Examples
+The application works well on a variety of images:
+- Natural scenes
+- Urban environments
+- Medical images
+- Aerial/satellite imagery
+- Objects with clear boundaries
+
+## Technologies Used
+- **PyTorch**: Deep learning framework
+- **OpenCV**: Classical computer vision algorithms
+- **NumPy**: Numerical computations
+- **PIL/Pillow**: Image loading and manipulation
+- **Gradio**: Interactive web interface
+- **Matplotlib**: Visualization of results
+
+## Credits
+- Built as part of CSL7360 course project
+- Uses pretrained models based on Pascal VOC and CamVid datasets
+- Implements algorithms from classical computer vision literature
 
+## License
+This project is available under the MIT License.
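For readers who want to call the segmenters from Python rather than through the Gradio UI, here is a minimal sketch. It assumes only the function signatures visible in the imports and calls of app.py below; the return tuples follow the unpacking done there:

```python
from experiments.kmeans_segmenter import generate_kmeans_segmented_image
from experiments.enhanced_kmeans_segmenter import slic_kmeans

# K-means: app.py unpacks four return values, discarding the third
original, segmented, _, info = generate_kmeans_segmented_image("bird.jpeg", 3)

# SLIC-style K-means: returns (original, segmented, labels, centers)
image, seg_img, labels, centers = slic_kmeans("bird.jpeg", K=100, m=10, max_iter=10)
seg_img.save("slic_output.png")  # segmented output is a PIL Image
```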
app.py
ADDED
@@ -0,0 +1,300 @@
+import gradio as gr
+import torch
+from torchvision import transforms
+from experiments.otsu_segmenter import generate_segmented_image
+from experiments.kmeans_segmenter import generate_kmeans_segmented_image
+from experiments.enhanced_kmeans_segmenter import slic_kmeans
+from experiments.watershed_segmenter import generate_watershed
+from experiments.felzenszwalb_segmentation import segment
+from experiments.SegNet.efficient_b0_backbone.architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE
+from experiments.SegNet.vgg_backbone.model import SegNet
+# from experiments.ensemble_method import generate_ensemble_segmentation
+import numpy as np
+from PIL import Image
+from matplotlib import cm
+import gdown
+import os
+
+# Check if the saved_models directory exists, if not create it
+if not os.path.exists("saved_models"):
+    os.makedirs("saved_models")
+
+# Check if the model file already exists before downloading
+if not os.path.exists("saved_models/segnet_vgg.pth"):
+    print("Downloading SegNet VGG weights...")
+    segnet_vgg_weights = "https://drive.google.com/file/d/1EFXKQ_3bDW9FbZCqOLdrE0DOI0V4W82o/view?usp=sharing"
+    gdown.download(segnet_vgg_weights, "saved_models/segnet_vgg.pth", fuzzy=True)
+    print("Download complete!")
+else:
+    print("SegNet VGG weights already exist, skipping download.")
+
+def generate_segnet_vgg(image_path):
+    model = SegNet(32).to(DEVICE)
+    model.load_state_dict(torch.load("saved_models/segnet_vgg.pth", map_location=DEVICE))
+    # Set model to evaluation mode
+    model.eval()
+
+    # Load and preprocess the image
+    image = Image.open(image_path).convert('RGB')
+    original_image = image.copy()
+
+    # Apply same preprocessing as during training
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),  # Adjust size to match your model's expected input
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
+
+    # Get prediction
+    with torch.no_grad():
+        output = model(input_tensor)
+        pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
+
+    # Convert prediction to visualization
+    # Option 1: Use a colormap for visualization
+    colormap = cm.get_cmap('nipy_spectral')
+    colored_mask = colormap(pred_mask / (pred_mask.max() or 1))  # Normalize, handle case where max is 0
+    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)  # Drop alpha and convert to uint8
+    segmented_image = Image.fromarray(colored_mask)
+
+    # Resize segmented image to match original image size
+    segmented_image = segmented_image.resize(original_image.size, Image.NEAREST)
+
+    return original_image, segmented_image
+
+def generate_kmeans(image_path, k):
+    kmeans_image_output, kmeans_segmented_image_output, _, kmeans_threshold_text = generate_kmeans_segmented_image(image_path, k)
+    return kmeans_image_output, kmeans_segmented_image_output, kmeans_threshold_text
+
+def generate_slic(image_path, k, m, max_iter):
+    image, seg_img, labels, centers = slic_kmeans(image_path, K=k, m=m, max_iter=max_iter)
+    return image, seg_img
+
+def generate_felzenszwalb(image_path, sigma, k, min_size_factor):
+    image = Image.open(image_path).convert("RGB")
+    image_np = np.array(image)
+    segments_fz = segment(image_np, sigma=sigma, k=k, min_size=min_size_factor)
+    segments_fz = segments_fz.astype(np.uint8)
+
+    return image, segments_fz
+
+def SegNet_efficient_b0(image_path):
+    model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
+    model.load_state_dict(torch.load("saved_models/segnet_efficientnet_camvid.pth", map_location=DEVICE))
+    model.eval()
+    transform = transforms.Compose([
+        transforms.Resize((360, 480)),  # Or larger if needed
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    image = Image.open(image_path).convert("RGB")
+    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
+
+    with torch.no_grad():
+        output = model(input_tensor)
+        pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
+
+    # Convert original image for Gradio display
+    original_image_resized = image
+
+    # Convert predicted mask to a color image using a colormap
+    colormap = cm.get_cmap('nipy_spectral')
+    colored_mask = colormap(pred_mask / pred_mask.max())  # Normalize
+    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)  # Drop alpha and convert to uint8
+    mask_pil = Image.fromarray(colored_mask)
+
+    return original_image_resized, mask_pil
+
+def ensemble_segmentation(image_path):
+    """
+    Ensemble segmentation combining SegNet and Otsu,
+    assuming Otsu produces a mask with the foreground as black (value 0)
+    and background as white (value 255).
+
+    In this ensemble, we force the SegNet prediction to background (class 0)
+    where Otsu indicates background (after inversion, i.e., where otsu_bin==0).
+
+    Parameters:
+        image_path (str): Path to the input image.
+
+    Returns:
+        original_image: The original resized image used for segmentation.
+        segnet_mask_pil: SegNet multi-class segmentation output (PIL image).
+        otsu_mask_pil: The original Otsu binary segmentation mask (PIL image).
+        ensemble_mask_pil: Final ensemble segmentation mask (PIL image).
+    """
+    # Run SegNet segmentation (model outputs a multi-class mask).
+    segnet_orig, segnet_mask_pil = SegNet_efficient_b0(image_path)
+    # Convert SegNet output to a NumPy array (assumed grayscale labeling, e.g., background=0).
+    segnet_mask_np = np.array(segnet_mask_pil.convert("L"))
+
+    # Run Otsu segmentation. (generate_segmented_image returns several outputs.)
+    _, otsu_segmented_pil, _, _, _ = generate_segmented_image(image_path)
+
+    # Resize Otsu mask to match SegNet output shape, e.g., (480, 360) if SegNet works in that resolution.
+    resized_shape = (segnet_mask_np.shape[1], segnet_mask_np.shape[0])
+    otsu_mask_resized = otsu_segmented_pil.resize(resized_shape, Image.NEAREST)
+    otsu_mask_np = np.array(otsu_mask_resized)
+
+    # Invert Otsu's binary mask:
+    # Assuming that in otsu_mask_np, foreground is black (0) and background is white (255),
+    # we build a binary mask where "1" represents the object's area.
+    otsu_bin = (otsu_mask_np == 0).astype(np.uint8)  # Now, foreground is 1 and background is 0.
+
+    # Create the ensemble segmentation:
+    # Where Otsu indicates foreground (otsu_bin==1), keep SegNet's prediction;
+    # where Otsu is background (otsu_bin==0), force it to background class (0).
+    ensemble_seg = np.where(otsu_bin == 1, segnet_mask_np, 0)
+
+    # Convert back to a PIL image.
+    ensemble_mask_pil = Image.fromarray(ensemble_seg.astype(np.uint8))
+
+    return segnet_orig, segnet_mask_pil, otsu_segmented_pil, ensemble_mask_pil
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Image Segmentation using Classical CV")
+
+    with gr.Tabs() as tabs:
+        with gr.TabItem("Otsu's Method"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    file_input = gr.File(label="Upload Image File")
+                    display_btn = gr.Button("Segment this image")
+                    threshold_text = gr.Textbox(label="Threshold Comparison", value="", interactive=False)
+
+                with gr.Column(scale=2):
+                    image_output = gr.Image(label="Original Image")
+                    histogram_output = gr.Image(label="Histogram")
+                    segmented_image_output = gr.Image(label="Our Segmented Image")
+                    opencv_segmented_image_output = gr.Image(label="OpenCV Segmented Image")
+            display_btn.click(
+                fn=generate_segmented_image,
+                inputs=file_input,
+                outputs=[image_output, segmented_image_output, opencv_segmented_image_output, histogram_output, threshold_text]
+            )
+        with gr.TabItem("K-means Segmentation"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    kmeans_file_input = gr.File(label="Upload Image File")
+                    kmeans_k_value = gr.Slider(minimum=2, maximum=10, value=3, step=1, label="Number of Clusters (K)")
+                    kmeans_display_btn = gr.Button("Segment this image")
+                    kmeans_threshold_text = gr.Textbox(label="K-means Info", value="", interactive=False)
+
+                with gr.Column(scale=2):
+                    kmeans_image_output = gr.Image(label="Original Image")
+                    kmeans_segmented_image_output = gr.Image(label="K-means Segmented Image")
+
+            kmeans_display_btn.click(
+                fn=generate_kmeans,
+                inputs=[kmeans_file_input, kmeans_k_value],
+                outputs=[kmeans_image_output, kmeans_segmented_image_output, kmeans_threshold_text]
+            )
+        with gr.TabItem("SLIC Segmentation"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    slic_file_input = gr.File(label="Upload Image File")
+                    slic_k_value = gr.Slider(minimum=2, maximum=200, value=3, step=1, label="Number of superpixels")
+                    slic_m_value = gr.Slider(minimum=1, maximum=40, value=3, step=1, label="Compactness factor")
+                    slic_max_iter_value = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of iterations")
+                    slic_display_btn = gr.Button("Segment this image")
+
+                with gr.Column(scale=2):
+                    slic_image_output = gr.Image(label="Original Image", container=True)
+                    slic_segmented_image_output = gr.Image(label="SLIC Segmented Image", container=True)
+
+            slic_display_btn.click(
+                fn=generate_slic,
+                inputs=[slic_file_input, slic_k_value, slic_m_value, slic_max_iter_value],
+                outputs=[slic_image_output, slic_segmented_image_output]
+            )
+
+        with gr.TabItem("Watershed"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    watershed_file = gr.File(label="Upload Image")
+                    watershed_btn = gr.Button("Run Watershed")
+
+                with gr.Column(scale=2):
+                    original_img = gr.Image(label="1. Original")
+                    blurred_img = gr.Image(label="2. Blurred")
+                    threshold_img = gr.Image(label="3. Threshold")
+
+            watershed_btn.click(
+                fn=generate_watershed,
+                inputs=[watershed_file],
+                outputs=[original_img, blurred_img, threshold_img]
+            )
+        with gr.TabItem("Felzenszwalb Algorithm Segmentation"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    felzenszwalb_file_input = gr.File(label="Upload Image File")
+                    sigma_value = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.1, label="Sigma")
+                    K_value = gr.Slider(minimum=2, maximum=1000, value=2, step=1, label="K value")
+                    min_size_value = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Min Size Factor")
+                    felzenszwalb_display_btn = gr.Button("Segment this image")
+
+                with gr.Column(scale=2):
+                    felzenszwalb_image_output = gr.Image(label="Original Image", container=True)
+                    felzenszwalb_segmented_image_output = gr.Image(label="felzenszwalb Segmented Image", container=True)
+
+            felzenszwalb_display_btn.click(
+                fn=generate_felzenszwalb,
+                inputs=[felzenszwalb_file_input, sigma_value, K_value, min_size_value],
+                outputs=[felzenszwalb_image_output, felzenszwalb_segmented_image_output]
+            )
+        with gr.TabItem("SegNet EfficientNet B0 Segmentation"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    segnet_file_input = gr.File(label="Upload Image File")
+                    segnet_display_btn = gr.Button("Segment this image")
+
+                with gr.Column(scale=2):
+                    segnet_image_output = gr.Image(label="Original Image")
+                    segnet_segmented_image_output = gr.Image(label="SegNet Segmented Image")
+
+            segnet_display_btn.click(
+                fn=SegNet_efficient_b0,
+                inputs=[segnet_file_input],
+                outputs=[segnet_image_output, segnet_segmented_image_output]
+            )
+        with gr.TabItem("SegNet VGG Segmentation"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    segnet_file_input = gr.File(label="Upload Image File")
+                    segnet_display_btn = gr.Button("Segment this image")
+
+                with gr.Column(scale=2):
+                    segnet_image_output = gr.Image(label="Original Image")
+                    segnet_segmented_image_output = gr.Image(label="SegNet VGG Segmented Image")
+
+            segnet_display_btn.click(
+                fn=generate_segnet_vgg,
+                inputs=[segnet_file_input],
+                outputs=[segnet_image_output, segnet_segmented_image_output]
+            )
+        # In app.py
+        with gr.TabItem("Ensemble Segmentation"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    ensemble_file_input = gr.File(label="Upload Image File")
+                    ensemble_display_btn = gr.Button("Segment with Ensemble Method")
+
+                with gr.Column(scale=2):
+                    ensemble_image_output = gr.Image(label="Original Image")
+                    ensemble_mask = gr.Image(label="Ensemble Segmented Image")
+                    ensemble_segnet_segmented_output = gr.Image(label="SegNet Efficient B0 Segmented Image")
+                    ensemble_otsu_segmented_output = gr.Image(label="Otsu Segmented Image")
+
+            ensemble_display_btn.click(
+                fn=ensemble_segmentation,
+                inputs=[ensemble_file_input],
+                outputs=[ensemble_image_output, ensemble_segnet_segmented_output, ensemble_otsu_segmented_output, ensemble_mask]
+            )
+
+if __name__ == "__main__":
+    demo.launch()
+
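The core of `ensemble_segmentation` above is a single masking rule. A toy illustration (a sketch with hypothetical arrays):

```python
import numpy as np

# Keep SegNet's class label where Otsu marks foreground (1), else force background (0),
# exactly as ensemble_seg = np.where(otsu_bin == 1, segnet_mask_np, 0) does above.
segnet_mask = np.array([[3, 3, 0],
                        [5, 5, 5]])   # hypothetical class labels
otsu_bin = np.array([[1, 0, 0],
                     [1, 1, 0]])      # 1 = foreground after inverting Otsu's mask
print(np.where(otsu_bin == 1, segnet_mask, 0))
# [[3 0 0]
#  [5 5 0]]
```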
bird.jpeg
ADDED
enhaned_kmeans_segmented.png
ADDED
experiments/SegNet/efficient_b0_backbone/architecture.py
ADDED
@@ -0,0 +1,178 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models, transforms
+from torchvision.datasets import VOCSegmentation
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+import glob
+from PIL import Image
+import numpy as np
+import wandb
+import pandas as pd
+import os
+import matplotlib.pyplot as plt
+import opendatasets as opd
+import zipfile
+
+torch.manual_seed(42)
+np.random.seed(42)
+
+# wandb.login(key="your_wandb_api_key_here")
+
+EPOCHS = 25
+BATCH_SIZE = 8
+LR = 1e-3
+NUM_CLASSES = 32
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# wandb.init(project="segnet-efficientnet-camvid", config={
+#     "epochs": EPOCHS,
+#     "batch_size": BATCH_SIZE,
+#     "learning_rate": LR,
+#     "architecture": "SegNet-EfficientNet",
+#     "dataset": "CamVid"
+# })
+
+class SegNetEfficientNet(nn.Module):
+    def __init__(self, num_classes=32):
+        super(SegNetEfficientNet, self).__init__()
+        base_model = models.efficientnet_b0(pretrained=True)
+        features = list(base_model.features.children())
+
+        # EfficientNet-B0 backbone (output channels gradually increase to 1280)
+        self.encoder = nn.Sequential(*features)  # Output: [B, 1280, H/32, W/32]
+
+        # Decoder blocks (mirroring encoder with ConvTranspose2d)
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
+            nn.BatchNorm2d(512),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+        )
+
+        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)
+
+    def forward(self, x):
+        x = self.encoder(x)  # Downsampled features from EfficientNet
+        x = self.decoder(x)  # Upsampled
+        x = self.classifier(x)
+        x = F.interpolate(x, size=(360, 480), mode='bilinear', align_corners=False)
+
+        return x
+
+class CamVidDataset(Dataset):
+    """
+    CamVid dataset loader with RGB mask to class index conversion.
+    Expects directory structure:
+    camvid/
+        train/
+        train_labels/
+        val/
+        val_labels/
+        test/
+        test_labels/
+    """
+    def __init__(self, root, split='train', transform=None, image_size=(360, 480), target_transform=None, class_dict_path='camvid/CamVid/class_dict.csv'):
+        self.root = root
+        self.split = split
+        self.transform = transform
+        self.target_transform = target_transform
+
+        self.image_dir = os.path.join(root, split)
+        self.label_dir = os.path.join(root, f"{split}_labels")
+
+        self.image_paths = sorted(glob.glob(os.path.join(self.image_dir, '*.png')))
+        self.label_paths = sorted(glob.glob(os.path.join(self.label_dir, '*.png')))
+        self.label_resize = transforms.Resize(image_size, interpolation=Image.NEAREST)
+        self.image_resize = transforms.Resize(image_size, interpolation=Image.BILINEAR)
+        assert len(self.image_paths) == len(self.label_paths), "Mismatch between images and labels."
+
+        # Load class_dict.csv and build color-to-class mapping
+        df = pd.read_csv(class_dict_path)
+        self.color_to_class = {
+            (row['r'], row['g'], row['b']): idx for idx, row in df.iterrows()
+        }
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def rgb_to_class(self, mask):
+        """Convert an RGB mask (PIL.Image) to a 2D class index mask."""
+        mask_np = np.array(mask)
+        h, w, _ = mask_np.shape
+        class_mask = np.zeros((h, w), dtype=np.uint8)
+
+        for rgb, class_idx in self.color_to_class.items():
+            matches = (mask_np == rgb).all(axis=2)
+            class_mask[matches] = class_idx
+
+        return class_mask
+
+    def __getitem__(self, idx):
+        image = Image.open(self.image_paths[idx]).convert('RGB')
+        label = Image.open(self.label_paths[idx]).convert('RGB')
+
+        # Resize both to 360x480
+        image = self.image_resize(image)
+        label = self.label_resize(label)
+
+        if self.transform:
+            image = self.transform(image)
+
+        label = self.rgb_to_class(label)
+        label = torch.from_numpy(label).long()
+
+        return image, label
+
+if __name__ == "__main__":
+    dataset_url = "https://www.kaggle.com/datasets/carlolepelaars/camvid"
+    opd.download(dataset_url)
+
+    # Set dataset folder (adjust path if needed)
+    dataset_folder = "camvid"
+    print("Dataset directory contents:")
+    print(os.listdir(dataset_folder))
+    input_transform = transforms.Compose([
+        transforms.Resize((360, 480)),  # Or larger if needed
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    def label_transform(label):
+        # Resize using nearest neighbor so that labels are not interpolated
+        label = label.resize((480, 360), Image.NEAREST)
+        label = np.array(label, dtype=np.int64)
+        return torch.from_numpy(label)
+
+    num_classes = 32
+    data_root = 'camvid/CamVid/'  # make sure this matches your structure
+
+    # Load datasets and dataloaders (assuming CamVidDataset is already defined)
+    train_dataset = CamVidDataset(root=data_root, split='train',
+                                  transform=input_transform, target_transform=label_transform)
+    val_dataset = CamVidDataset(root=data_root, split='val',
+                                transform=input_transform, target_transform=label_transform)
+    test_dataset = CamVidDataset(root=data_root, split='test',
+                                 transform=input_transform, target_transform=label_transform)
+
+    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)
+    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=4)
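A quick shape sanity check for the encoder-decoder above (a sketch; it assumes architecture.py is importable from the working directory and that torchvision can fetch the EfficientNet-B0 weights):

```python
import torch
from architecture import SegNetEfficientNet  # assumed import path

model = SegNetEfficientNet(num_classes=32).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 360, 480))
# The final F.interpolate pins the output to the CamVid resolution
# regardless of the decoder's intermediate spatial size.
print(out.shape)  # torch.Size([1, 32, 360, 480])
```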
experiments/SegNet/efficient_b0_backbone/train.py
ADDED
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models, transforms
+from torchvision.datasets import VOCSegmentation
+from torch.utils.data import DataLoader
+from PIL import Image
+import numpy as np
+import wandb
+import os
+import matplotlib.pyplot as plt
+from .architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE, LR, EPOCHS, train_loader, val_loader, IMAGE_SIZE
+from tqdm import tqdm
+
+model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
+optimizer = torch.optim.Adam(model.parameters(), lr=LR)
+criterion = nn.CrossEntropyLoss(ignore_index=255)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+def pixel_accuracy(preds, labels):
+    _, preds = torch.max(preds, 1)
+    correct = (preds == labels).float()
+    acc = correct.sum() / correct.numel()
+    return acc
+
+# def mean_iou(preds, labels, num_classes=NUM_CLASSES):
+#     _, preds = torch.max(preds, 1)
+#     ious = []
+#     for cls in range(num_classes):
+#         intersection = ((preds == cls) & (labels == cls)).float().sum()
+#         union = ((preds == cls) | (labels == cls)).float().sum()
+#         if union > 0:
+#             ious.append(intersection / union)
+#     return sum(ious) / len(ious) if ious else 0
+
+for epoch in tqdm(range(EPOCHS)):
+    model.train()
+    train_loss, train_acc = 0.0, 0.0
+
+    for images, masks in train_loader:
+        images, masks = images.to(DEVICE), masks.to(DEVICE)
+        optimizer.zero_grad()
+        outputs = model(images)
+        loss = criterion(outputs, masks)
+        loss.backward()
+        optimizer.step()
+
+        train_loss += loss.item()
+        train_acc += pixel_accuracy(outputs, masks).item()
+
+    train_loss /= len(train_loader)
+    train_acc /= len(train_loader)
+
+    # Validation
+    model.eval()
+    val_loss, val_acc = 0.0, 0.0
+    with torch.no_grad():
+        for images, masks in val_loader:
+            images, masks = images.to(DEVICE), masks.to(DEVICE)
+            outputs = model(images)
+            loss = criterion(outputs, masks)
+
+            val_loss += loss.item()
+            val_acc += pixel_accuracy(outputs, masks).item()
+
+    val_loss /= len(val_loader)
+    val_acc /= len(val_loader)
+
+    # wandb.log({
+    #     "epoch": epoch + 1,
+    #     "train_loss": train_loss,
+    #     "train_accuracy": train_acc,
+    #     "val_loss": val_loss,
+    #     "val_accuracy": val_acc
+    # })
+
+    print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
+
+torch.save(model.state_dict(), "segnet_efficientnet_camvid.pth")
+# wandb.finish()
+
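A tiny sanity check for `pixel_accuracy` above (a sketch with hypothetical tensors):

```python
import torch

# Logits for 2 classes over a 2x2 image; argmax over dim=1 gives [[1, 0], [0, 0]]
preds = torch.tensor([[[[0.1, 0.9], [0.8, 0.7]],
                       [[0.6, 0.2], [0.3, 0.1]]]])  # shape (1, 2, 2, 2)
labels = torch.tensor([[[1, 0], [1, 0]]])           # shape (1, 2, 2)
_, top = torch.max(preds, 1)
print(top)                             # tensor([[[1, 0], [0, 0]]])
print((top == labels).float().mean())  # tensor(0.7500) -- 3 of 4 pixels correct
```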
experiments/SegNet/vgg_backbone/SegNet_with_VGG16_backbone.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
experiments/SegNet/vgg_backbone/model.py
ADDED
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+class SegNet(nn.Module):
+    def __init__(self, num_classes=32):
+        super(SegNet, self).__init__()
+        vgg16 = models.vgg16_bn(pretrained=True)
+        self.pool = nn.MaxPool2d(2, 2, return_indices=True)
+        self.unpool = nn.MaxUnpool2d(2, 2)
+        self.enc1 = nn.Sequential(*vgg16.features[:6])
+        self.enc2 = nn.Sequential(*vgg16.features[7:13])
+        self.enc3 = nn.Sequential(*vgg16.features[14:23])
+        self.enc4 = nn.Sequential(*vgg16.features[24:33])
+        self.dec4 = self.decoder_block(512, 256)
+        self.dec3 = self.decoder_block(256, 128)
+        self.dec2 = self.decoder_block(128, 64)
+        self.dec1 = self.decoder_block(64, 64)
+        self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)
+
+    def decoder_block(self, in_channels, out_channels):
+        return nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, 3, padding=1),
+            nn.BatchNorm2d(in_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels, out_channels, 3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x1 = self.enc1(x)
+        x1p, ind1 = self.pool(x1)
+        x2 = self.enc2(x1p)
+        x2p, ind2 = self.pool(x2)
+        x3 = self.enc3(x2p)
+        x3p, ind3 = self.pool(x3)
+        x4 = self.enc4(x3p)
+        x4p, ind4 = self.pool(x4)
+        d4 = self.unpool(x4p, ind4, output_size=x4.size())
+        d4 = self.dec4(d4)
+        d3 = self.unpool(d4, ind3, output_size=x3.size())
+        d3 = self.dec3(d3)
+        d2 = self.unpool(d3, ind2, output_size=x2.size())
+        d2 = self.dec2(d2)
+        d1 = self.unpool(d2, ind1, output_size=x1.size())
+        d1 = self.dec1(d1)
+        return self.classifier(d1)
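The VGG SegNet above relies on paired pooling/unpooling: `MaxPool2d(return_indices=True)` records argmax positions so `MaxUnpool2d` can restore values to exactly those locations in the decoder, preserving boundary detail. A minimal sketch of that mechanism:

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, 2, return_indices=True)
unpool = nn.MaxUnpool2d(2, 2)

x = torch.arange(16.0).reshape(1, 1, 4, 4)
y, idx = pool(x)                              # y: pooled maxima, idx: where they came from
x_up = unpool(y, idx, output_size=x.size())   # maxima return to their positions, zeros elsewhere
print(y.shape, x_up.shape)  # torch.Size([1, 1, 2, 2]) torch.Size([1, 1, 4, 4])
```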
experiments/enhanced_kmeans_segmenter.py
ADDED
@@ -0,0 +1,100 @@
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+from PIL import Image
+
+def slic_kmeans(image_path, K=100, m=10, max_iter=10):
+    """
+    Perform superpixel segmentation using enhanced K-means with LAB+XY.
+    Args:
+        image (np.ndarray): RGB input image.
+        K (int): Number of superpixels.
+        m (float): Compactness factor.
+        max_iter (int): Number of iterations.
+    Returns:
+        segmented_img: The segmented image with cluster colors.
+        labels: Cluster label for each pixel.
+    """
+    jpg_image = Image.open(image_path)
+    image = np.array(jpg_image)
+    h, w = image.shape[:2]
+    S = int(np.sqrt(h * w / K))  # grid interval
+
+    # Convert to LAB color space
+    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32)
+
+    # Create 5D feature vector [L, a, b, x, y]
+    X, Y = np.meshgrid(np.arange(w), np.arange(h))
+    features = np.dstack((lab, X, Y)).reshape((-1, 5))
+
+    # Initialize cluster centers on grid
+    centers = []
+    for y in range(S // 2, h, S):
+        for x in range(S // 2, w, S):
+            center = features[y * w + x]
+            centers.append(center)
+    centers = np.array(centers)
+
+    labels = np.full((h * w,), -1, dtype=np.int32)
+    distances = np.full((h * w,), np.inf)
+
+    for iteration in tqdm(range(max_iter)):
+        for idx, center in enumerate(centers):
+            l, a, b, cx, cy = center
+            x_start, x_end = max(0, int(cx - S)), min(w, int(cx + S))
+            y_start, y_end = max(0, int(cy - S)), min(h, int(cy + S))
+
+            for y in range(y_start, y_end):
+                for x in range(x_start, x_end):
+                    i = y * w + x
+                    fp = features[i]
+                    dc = np.linalg.norm(fp[:3] - center[:3])  # LAB distance
+                    ds = np.linalg.norm(fp[3:] - center[3:])  # XY distance
+                    D = np.sqrt(dc**2 + (ds / S)**2 * m**2)
+
+                    if D < distances[i]:
+                        distances[i] = D
+                        labels[i] = idx
+
+        # Update cluster centers
+        new_centers = np.zeros_like(centers)
+        count = np.zeros(len(centers))
+        for i in range(h * w):
+            lbl = labels[i]
+            new_centers[lbl] += features[i]
+            count[lbl] += 1
+        for i in range(len(centers)):
+            if count[i] > 0:
+                new_centers[i] /= count[i]
+        centers = new_centers
+
+    # Recolor image based on cluster centers
+    segmented_img = np.zeros((h, w, 3), dtype=np.uint8)
+    for i in range(h * w):
+        lbl = labels[i]
+        lab_val = centers[lbl][:3]
+        lab_pixel = np.uint8([[lab_val]])
+        rgb_pixel = cv2.cvtColor(lab_pixel, cv2.COLOR_LAB2RGB)[0][0]
+        segmented_img[i // w, i % w] = rgb_pixel
+
+    return jpg_image, Image.fromarray(segmented_img), labels.reshape((h, w)), centers
+
+# img_path = "/home/akshat/projects/CSL7360_Project/bird.jpeg"
+# image = cv2.imread(img_path)
+# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+# _,seg_img, labels, centers = slic_kmeans(image, K=2, m=20)
+# seg_img.save("enhaned_kmeans_segmented.png")
+# plt.figure(figsize=(10, 5))
+# plt.subplot(1, 2, 1)
+# plt.imshow(image)
+# plt.title("Original Image")
+# plt.axis("off")
+
+# plt.subplot(1, 2, 2)
+# plt.imshow(seg_img)
+# plt.title("SLIC-like K-Means Segmentation")
+# plt.axis("off")
+# plt.tight_layout()
+# plt.show()
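The distance used in `slic_kmeans` combines the LAB color distance d_c and the XY spatial distance d_s, with the spatial term normalized by the grid interval S and weighted by the compactness m: D = sqrt(d_c^2 + (d_s / S)^2 * m^2). A one-pixel numeric check (values are hypothetical):

```python
import numpy as np

dc, ds, S, m = 10.0, 20.0, 25, 10  # color dist, spatial dist, grid interval, compactness
D = np.sqrt(dc**2 + (ds / S)**2 * m**2)
print(round(D, 2))  # 12.81 -- raising m makes spatial proximity dominate, yielding compacter superpixels
```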
experiments/ensemble_method.py
ADDED
@@ -0,0 +1,148 @@

import torch
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
from experiments.otsu_segmenter import otsu_threshold
from experiments.SegNet.efficient_b0_backbone.architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE

def ensemble_segmentation(image_path, model_path="segnet_efficientnet_voc.pth", boundary_weight=0.3):
    """
    Ensemble segmentation combining Otsu thresholding and SegNet.

    Args:
        image_path: Path to input image
        model_path: Path to SegNet model weights
        boundary_weight: Weight for boundary refinement (0-1)

    Returns:
        original_image: Original input image (PIL)
        ensemble_result: Ensemble segmentation result (PIL)
        method_comparison: Visualization of all methods side by side (PIL)
    """
    # 1. Load the image
    image = Image.open(image_path).convert('RGB')
    original = image.copy()
    image_np = np.array(image)

    # 2. Run Otsu thresholding for boundary detection:
    # grayscale conversion (RGB2GRAY replaces the former no-op
    # RGB->BGR->GRAY round trip) followed by a Gaussian blur
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    otsu_threshold_value, otsu_mask = otsu_threshold(blurred)

    # 3. Run SegNet for semantic segmentation
    model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((360, 480)),  # Or larger if needed
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    input_tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)
        segnet_pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # 4. Create edge map from Otsu result
    edges = cv2.Canny(otsu_mask, 50, 150)

    # Resize to match SegNet output size
    edges_resized = cv2.resize(edges, (segnet_pred.shape[1], segnet_pred.shape[0]),
                               interpolation=cv2.INTER_NEAREST)

    # 5. Ensemble: use Otsu edges to refine SegNet boundaries.
    # Create a distance transform from the edges
    dist_transform = cv2.distanceTransform(255 - edges_resized, cv2.DIST_L2, 5)
    dist_transform = dist_transform / dist_transform.max()  # Normalize to 0-1

    # Areas close to edges get more influence from Otsu
    edge_weight_map = np.exp(-dist_transform * 5) * boundary_weight

    # Create binary mask from SegNet (foreground = any class other than background)
    segnet_binary = (segnet_pred > 0).astype(np.uint8) * 255

    # Resize Otsu mask to match SegNet output
    otsu_resized = cv2.resize(otsu_mask, (segnet_pred.shape[1], segnet_pred.shape[0]),
                              interpolation=cv2.INTER_NEAREST)

    # Combine: keep SegNet classes but refine boundaries with Otsu.
    # In boundary regions, blend the two masks according to the edge weight.
    refined_binary = segnet_binary.copy()
    boundary_region = edge_weight_map > 0.1
    refined_binary[boundary_region] = (
        (1 - edge_weight_map[boundary_region]) * segnet_binary[boundary_region] +
        edge_weight_map[boundary_region] * otsu_resized[boundary_region]
    ).astype(np.uint8)

    # Apply the refined binary mask to the original SegNet prediction
    ensemble_result = segnet_pred.copy()
    # Where the refined binary is 0, set to background class (0)
    ensemble_result[refined_binary < 128] = 0

    # 6. Visualize results
    import matplotlib
    import matplotlib.pyplot as plt
    import io

    # Convert semantic maps to color visualizations.
    # (cm.get_cmap was removed in Matplotlib 3.9; requirements.txt pins 3.10,
    # so use the colormap registry instead.)
    colormap = matplotlib.colormaps['nipy_spectral']

    segnet_colored = colormap(segnet_pred / (NUM_CLASSES - 1))
    segnet_colored = (segnet_colored[:, :, :3] * 255).astype(np.uint8)

    ensemble_colored = colormap(ensemble_result / (NUM_CLASSES - 1))
    ensemble_colored = (ensemble_colored[:, :, :3] * 255).astype(np.uint8)

    # Create side-by-side comparison
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    # Resize original image to match the segmentation size
    original_resized = original.resize((segnet_pred.shape[1], segnet_pred.shape[0]))

    axes[0].imshow(original_resized)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    axes[1].imshow(otsu_mask, cmap='gray')
    axes[1].set_title(f"Otsu (t={otsu_threshold_value})")
    axes[1].axis('off')

    axes[2].imshow(segnet_colored)
    axes[2].set_title("SegNet Prediction")
    axes[2].axis('off')

    axes[3].imshow(ensemble_colored)
    axes[3].set_title("Ensemble Result")
    axes[3].axis('off')

    plt.tight_layout()

    # Convert the plot to an image
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    comparison_image = Image.open(buf)
    plt.close(fig)

    # Return results
    ensemble_pil = Image.fromarray(ensemble_colored)
    ensemble_pil = ensemble_pil.resize(original.size, Image.NEAREST)

    return original, ensemble_pil, comparison_image

# Wrapper for the Gradio app (see app.py)
def generate_ensemble_segmentation(image_path, boundary_weight=0.3):
    """Wrapper for Gradio interface"""
    original, ensemble_result, comparison = ensemble_segmentation(
        image_path,
        model_path="saved_models/segnet_efficientnet_camvid.pth",
        boundary_weight=boundary_weight
    )
    return original, ensemble_result, comparison
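Editor's note: a minimal sketch of how the wrapper above could be wired into a Gradio interface. This is an assumption about the app layout; the actual app.py in this commit may organize the UI differently.

import gradio as gr
from experiments.ensemble_method import generate_ensemble_segmentation

# The image input is passed as a file path, matching the wrapper's signature.
demo = gr.Interface(
    fn=generate_ensemble_segmentation,
    inputs=[
        gr.Image(type="filepath"),
        gr.Slider(0.0, 1.0, value=0.3, label="Boundary weight"),
    ],
    outputs=[
        gr.Image(label="Original"),
        gr.Image(label="Ensemble"),
        gr.Image(label="Comparison"),
    ],
)

if __name__ == "__main__":
    demo.launch()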
experiments/felzenszwalb_segmentation/__init__.py
ADDED
@@ -0,0 +1 @@

from .segmentation import segment
experiments/felzenszwalb_segmentation/disjoint_set.py
ADDED
@@ -0,0 +1,39 @@

import numpy as np


class DisjointSet:

    def __init__(self, n_elements):
        self.num = n_elements
        # Each element stores (rank, component size, parent index)
        self.elements = np.empty(
            shape=(n_elements, 3),
            dtype=int
        )
        for i in range(n_elements):
            self.elements[i, 0] = 0  # rank
            self.elements[i, 1] = 1  # size
            self.elements[i, 2] = i  # parent (each element starts as its own root)

    def size(self, x):
        return self.elements[x, 1]

    def num_sets(self):
        return self.num

    def find(self, x):
        y = int(x)
        while y != self.elements[y, 2]:
            y = self.elements[y, 2]
        self.elements[x, 2] = y  # path compression for the queried element
        return y

    def join(self, x, y):
        # Union by rank: attach the lower-rank root under the higher-rank one
        if self.elements[x, 0] > self.elements[y, 0]:
            self.elements[y, 2] = x
            self.elements[x, 1] += self.elements[y, 1]
        else:
            self.elements[x, 2] = y
            self.elements[y, 1] += self.elements[x, 1]
            if self.elements[x, 0] == self.elements[y, 0]:
                self.elements[y, 0] += 1
        self.num -= 1
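Editor's note: a quick sanity check of the structure above (a hypothetical standalone snippet, not part of the commit). As segment_graph does via find(), join() must be called on roots:

from experiments.felzenszwalb_segmentation.disjoint_set import DisjointSet

ds = DisjointSet(5)
a, b = ds.find(0), ds.find(1)
ds.join(a, b)                  # merge components {0} and {1}
assert ds.find(0) == ds.find(1)
assert ds.size(ds.find(0)) == 2
assert ds.num_sets() == 4      # 5 elements, one merge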
experiments/felzenszwalb_segmentation/segmentation.py
ADDED
@@ -0,0 +1,83 @@

import numpy as np
from .disjoint_set import DisjointSet
from .utils import smoothen, difference, get_random_rgb_image


def segment_graph(num_vertices, num_edges, edges, c):
    # Sort edges by weight (non-decreasing)
    edges[0:num_edges, :] = edges[edges[0:num_edges, 2].argsort()]
    u = DisjointSet(num_vertices)
    threshold = np.zeros(shape=num_vertices, dtype=float)
    for i in range(num_vertices):
        threshold[i] = c
    for i in range(num_edges):
        pedge = edges[i, :]
        a = u.find(pedge[0])
        b = u.find(pedge[1])
        if a != b:
            if (pedge[2] <= threshold[a]) and (pedge[2] <= threshold[b]):
                u.join(a, b)
                a = u.find(a)
                threshold[a] = pedge[2] + (c / u.size(a))
    return u


def segment(in_image, sigma, k, min_size):
    height, width, band = in_image.shape
    smooth_red_band = smoothen(in_image[:, :, 0], sigma)
    smooth_green_band = smoothen(in_image[:, :, 1], sigma)
    smooth_blue_band = smoothen(in_image[:, :, 2], sigma)
    # Build the 8-connected grid graph
    edges_size = width * height * 4
    edges = np.zeros(shape=(edges_size, 3), dtype=object)
    num = 0
    for y in range(height):
        for x in range(width):
            if x < width - 1:
                edges[num, 0] = int(y * width + x)
                edges[num, 1] = int(y * width + (x + 1))
                edges[num, 2] = difference(
                    smooth_red_band, smooth_green_band,
                    smooth_blue_band, x, y, x + 1, y
                )
                num += 1
            if y < height - 1:
                edges[num, 0] = int(y * width + x)
                edges[num, 1] = int((y + 1) * width + x)
                edges[num, 2] = difference(
                    smooth_red_band, smooth_green_band,
                    smooth_blue_band, x, y, x, y + 1
                )
                num += 1
            # Diagonal neighbour; the bound is height - 1, as in the reference
            # implementation (height - 2 skipped the last valid row)
            if (x < width - 1) and (y < height - 1):
                edges[num, 0] = int(y * width + x)
                edges[num, 1] = int((y + 1) * width + (x + 1))
                edges[num, 2] = difference(
                    smooth_red_band, smooth_green_band,
                    smooth_blue_band, x, y, x + 1, y + 1
                )
                num += 1
            # Anti-diagonal neighbour
            if (x < width - 1) and (y > 0):
                edges[num, 0] = int(y * width + x)
                edges[num, 1] = int((y - 1) * width + (x + 1))
                edges[num, 2] = difference(
                    smooth_red_band, smooth_green_band,
                    smooth_blue_band, x, y, x + 1, y - 1
                )
                num += 1
    u = segment_graph(width * height, num, edges, k)
    # Merge components smaller than min_size
    for i in range(num):
        a = u.find(edges[i, 0])
        b = u.find(edges[i, 1])
        if (a != b) and ((u.size(a) < min_size) or (u.size(b) < min_size)):
            u.join(a, b)
    num_cc = u.num_sets()
    output = np.zeros(shape=(height, width, 3))

    # Assign a random colour to every component root
    colors = np.zeros(shape=(height * width, 3))
    for i in range(height * width):
        colors[i, :] = get_random_rgb_image()
    for y in range(height):
        for x in range(width):
            comp = u.find(y * width + x)
            output[y, x, :] = colors[comp, :]
    return output
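Editor's note: a minimal usage sketch for the segmenter above (hypothetical, not part of the commit). The bird.jpeg path mirrors the demo paths used elsewhere in this commit; the sigma/k/min_size values are illustrative assumptions.

import cv2
import numpy as np
from experiments.felzenszwalb_segmentation import segment

image = cv2.imread("bird.jpeg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(float)

# Felzenszwalb-Huttenlocher parameters: smoothing sigma, scale k, minimum
# component size. Larger k favours larger components.
out = segment(image, sigma=0.8, k=300, min_size=50)
cv2.imwrite("felzenszwalb_output.png",
            cv2.cvtColor(out.astype(np.uint8), cv2.COLOR_RGB2BGR))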
experiments/felzenszwalb_segmentation/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@

from .utils import *
from .filter_utils import *
experiments/felzenszwalb_segmentation/utils/filter_utils.py
ADDED
@@ -0,0 +1,38 @@

import numpy as np
from math import ceil, exp, pow


def convolve(src, mask):
    """Convolve each row of src with a symmetric one-sided 1-D mask."""
    output = np.zeros(shape=src.shape, dtype=float)
    height, width = src.shape
    length = len(mask)
    for y in range(height):
        for x in range(width):
            total = float(mask[0] * src[y, x])
            for i in range(1, length):
                total += mask[i] * (
                    src[y, max(x - i, 0)] + src[y, min(x + i, width - 1)])
            output[y, x] = total
    return output


def normalize(mask):
    # The mask is one-sided: every tap is applied twice except the centre
    # one, so count the centre tap only once (the previous formula counted
    # it three times)
    total = 2 * np.sum(np.absolute(mask[1:])) + abs(mask[0])
    return np.divide(mask, total)


def smoothen(src, sigma):
    mask = make_gaussian_filter(sigma)
    mask = normalize(mask)
    # Separable Gaussian: smooth rows, transpose, smooth rows again, and
    # transpose back. Without the transposes (as previously written) the
    # image was only ever blurred horizontally.
    tmp = convolve(src, mask).T
    dst = convolve(tmp, mask).T
    return dst


def make_gaussian_filter(sigma):
    sigma = max(sigma, 0.01)
    length = int(ceil(sigma * 4.0)) + 1
    mask = np.zeros(shape=length, dtype=float)
    for i in range(length):
        # Gaussian: exp(-0.5 * (i / sigma)^2); the exponent is the constant 2
        # (the previous code raised to the power i / sigma)
        mask[i] = exp(-0.5 * pow(i / sigma, 2))
    return mask
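Editor's note: a small check (hypothetical, not part of the commit) that the one-sided mask, mirrored into a full kernel, behaves like a unit-sum Gaussian with the normalization as corrected above:

import numpy as np
from experiments.felzenszwalb_segmentation.utils.filter_utils import (
    make_gaussian_filter, normalize)

half = normalize(make_gaussian_filter(sigma=1.0))
# Mirror the one-sided mask into the full symmetric kernel
full = np.concatenate([half[:0:-1], half])
assert abs(full.sum() - 1.0) < 1e-6     # unit sum after normalize()
assert np.argmax(full) == len(half) - 1  # peak at the centre tap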
experiments/felzenszwalb_segmentation/utils/utils.py
ADDED
@@ -0,0 +1,25 @@

import numpy as np
from math import sqrt
from random import randint


def difference(red_band, green_band, blue_band, x1, y1, x2, y2):
    return sqrt(
        (red_band[y1, x1] - red_band[y2, x2]) ** 2 +
        (green_band[y1, x1] - green_band[y2, x2]) ** 2 +
        (blue_band[y1, x1] - blue_band[y2, x2]) ** 2
    )


def get_random_rgb_image():
    rgb = np.zeros(3, dtype=int)
    rgb[0] = randint(0, 255)
    rgb[1] = randint(0, 255)
    rgb[2] = randint(0, 255)
    return rgb


def get_random_gray_image():
    gray = np.zeros(1, dtype=int)
    gray[0] = randint(0, 255)
    return gray
experiments/kmeans_segmenter.py
ADDED
@@ -0,0 +1,95 @@

import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import io

def initialize_centroids(data, K):
    """Randomly choose K data points as initial centroids."""
    indices = np.random.choice(data.shape[0], K, replace=False)
    return data[indices]

def compute_distances(data, centroids):
    """Compute the Euclidean distance between each data point and each centroid."""
    return np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)

def update_centroids(data, labels, K):
    """Update centroids as the mean of the points assigned to each cluster."""
    new_centroids = np.zeros((K, data.shape[1]))
    for k in range(K):
        cluster_points = data[labels == k]
        if len(cluster_points) > 0:
            new_centroids[k] = np.mean(cluster_points, axis=0)
    return new_centroids

def kmeans_from_scratch(image, K=4, max_iters=100, tol=1e-4):
    """Apply K-means clustering from scratch to segment the image."""
    data = image.reshape((-1, 3)).astype(np.float32)

    centroids = initialize_centroids(data, K)

    for i in range(max_iters):
        distances = compute_distances(data, centroids)
        labels = np.argmin(distances, axis=1)

        new_centroids = update_centroids(data, labels, K)
        shift = np.linalg.norm(new_centroids - centroids)

        if shift < tol:
            break
        centroids = new_centroids

    segmented_data = centroids[labels].astype(np.uint8)
    segmented_image = segmented_data.reshape(image.shape)

    return segmented_image, labels.reshape(image.shape[:2]), centroids.astype(np.uint8)

def generate_kmeans_segmented_image(image_path, k=3):
    """Process image with K-means for Gradio app"""
    image = Image.open(image_path)
    image_np = np.array(image)

    if len(image_np.shape) == 3:
        # PIL already yields RGB; the former RGB->BGR->RGB round trip was a no-op
        image_rgb = image_np
    else:
        image_rgb = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)

    seg_img, labels, centers = kmeans_from_scratch(image_rgb, K=k)

    # Swatch strip showing each cluster's mean colour
    colors_image = np.zeros((50 * k, 100, 3), dtype=np.uint8)
    for i, color in enumerate(centers):
        colors_image[i*50:(i+1)*50, :] = color

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    axes[0].imshow(image_rgb)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    axes[1].imshow(seg_img)
    axes[1].set_title(f"K-Means (K={k})")
    axes[1].axis('off')

    axes[2].imshow(colors_image)
    axes[2].set_title("Cluster Colors")
    axes[2].axis('off')

    plt.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    comparison_image = Image.open(buf)
    plt.close(fig)

    return image, Image.fromarray(seg_img), comparison_image, f"K-Means clustering with K={k}"

if __name__ == "__main__":
    image_path = "/home/akshat/projects/CSL7360_Project/bird.jpeg"
    original, segmented, comparison, text = generate_kmeans_segmented_image(image_path, k=3)

    # Save output images instead of displaying them
    segmented.save("kmeans_segmented.png")
    comparison.save("kmeans_comparison.png")
    print(text)
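Editor's note: as a sanity check (a hypothetical snippet, not part of the commit), the from-scratch clustering can be compared against cv2.kmeans on the same pixel data; up to label permutation, the two should converge to similar centroids.

import cv2
import numpy as np
from experiments.kmeans_segmenter import kmeans_from_scratch

image = cv2.cvtColor(cv2.imread("bird.jpeg"), cv2.COLOR_BGR2RGB)
_, _, ours = kmeans_from_scratch(image, K=3)

data = image.reshape((-1, 3)).astype(np.float32)
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 1e-4)
_, _, ref = cv2.kmeans(data, 3, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

# Compare centroid sets sorted by brightness (labels are permutation-invariant)
print(np.sort(ours.sum(axis=1)), np.sort(ref.sum(axis=1)))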
experiments/otsu_segmenter.py
ADDED
@@ -0,0 +1,95 @@

import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import io

def otsu_threshold(image):
    hist, bin_edges = np.histogram(image.flatten(), bins=256, range=[0, 256])
    hist = hist.astype(float)
    total_pixels = image.size
    pixel_probability = hist / total_pixels

    max_variance = 0
    optimal_threshold = 0

    # Exhaustively search for the threshold maximizing between-class variance
    for threshold in range(1, 256):
        weight_background = np.sum(pixel_probability[:threshold])
        weight_foreground = np.sum(pixel_probability[threshold:])

        if weight_background == 0 or weight_foreground == 0:
            continue

        mean_background = np.sum(np.arange(threshold) * pixel_probability[:threshold]) / weight_background
        mean_foreground = np.sum(np.arange(threshold, 256) * pixel_probability[threshold:]) / weight_foreground

        variance = weight_background * weight_foreground * (mean_background - mean_foreground) ** 2

        if variance > max_variance:
            max_variance = variance
            optimal_threshold = threshold

    segmented_image = np.zeros_like(image)
    segmented_image[image >= optimal_threshold] = 255

    return optimal_threshold, segmented_image

def generate_segmented_image(image_path):
    print(f"Image path: {image_path}")
    image = Image.open(image_path)
    image_np = np.array(image)

    # Convert to grayscale; guard against single-channel inputs, for which
    # the former unconditional RGB2BGR conversion would raise an error
    if image_np.ndim == 3:
        gray_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    else:
        gray_image = image_np.copy()

    blurred = cv2.GaussianBlur(gray_image, (5, 5), 0)

    # Our implementation
    our_threshold, our_segmented = otsu_threshold(blurred)

    # OpenCV's implementation
    opencv_threshold, opencv_segmented = cv2.threshold(
        blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )

    # Create histogram figure
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(gray_image.ravel(), 256, [0, 256], color='gray')
    ax.axvline(x=our_threshold, color='red', linestyle='--', label=f'Ours: {our_threshold}')
    ax.axvline(x=opencv_threshold, color='green', linestyle='--', label=f'OpenCV: {opencv_threshold}')
    ax.set_title("Histogram with Thresholds")
    ax.legend()

    # Convert Matplotlib figure to image
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    hist_image = Image.open(buf)
    plt.close(fig)  # Close the figure to free memory

    return (
        image,
        Image.fromarray(our_segmented),
        Image.fromarray(opencv_segmented),
        hist_image,
        f"Our Threshold: {our_threshold}\nOpenCV Threshold: {opencv_threshold}",
    )

if __name__ == "__main__":
    # Example usage: the function expects a path, not a decoded image array
    image_path = '/home/akshat/projects/CSL7360_Project/bird.jpeg'
    generate_segmented_image(image_path)

    # # Optionally, save results to files
    # cv2.imwrite("our_segmented.png", our_segmented)
    # cv2.imwrite("opencv_segmented.png", opencv_segmented)
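Editor's note: a quick verification sketch (hypothetical, not part of the commit). On a synthetic bimodal image, the threshold found by otsu_threshold should fall between the two modes and agree closely with OpenCV:

import numpy as np
import cv2
from experiments.otsu_segmenter import otsu_threshold

# Two populations centred around 60 and 180
rng = np.random.default_rng(0)
img = np.clip(np.concatenate([
    rng.normal(60, 10, 5000), rng.normal(180, 10, 5000)
]), 0, 255).astype(np.uint8).reshape(100, 100)

t_ours, _ = otsu_threshold(img)
t_cv, _ = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
print(t_ours, t_cv)  # both should land roughly midway between 60 and 180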
experiments/watershed_segmenter.py
ADDED
@@ -0,0 +1,208 @@

import numpy as np
import cv2
import heapq
import matplotlib.pyplot as plt

# 1. Compute local minima as markers
def get_local_minima(gray):
    kernel = np.ones((3, 3), np.uint8)
    eroded = cv2.erode(gray, kernel)
    minima = (gray == eroded)
    return minima.astype(np.uint8)

# 2. Label each connected component (marker)
def label_markers(minima):
    num_labels, markers = cv2.connectedComponents(minima)
    return markers, num_labels

# 3. Watershed from scratch
# NOTE: this first version is superseded by the refined definition of the
# same name further below, which is the one actually bound at runtime.
def watershed_from_scratch(gray, markers):
    h, w = gray.shape
    # Constants
    WATERSHED = -1
    INIT = -2

    # Initialize label and visited map
    label_map = np.full((h, w), INIT, dtype=np.int32)
    label_map[markers > 0] = markers[markers > 0]

    # Priority queue for pixels: (intensity, y, x)
    pq = []

    # Populate queue with boundary of initial markers
    for y in range(h):
        for x in range(w):
            if markers[y, x] > 0:
                for dy in [-1, 0, 1]:
                    for dx in [-1, 0, 1]:
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w:
                            if markers[ny, nx] == 0 and label_map[ny, nx] == INIT:
                                heapq.heappush(pq, (gray[ny, nx], ny, nx))
                                label_map[ny, nx] = 0  # Mark as in queue

    # Flooding
    while pq:
        intensity, y, x = heapq.heappop(pq)

        neighbor_labels = set()
        for dy in [-1, 0, 1]:
            for dx in [-1, 0, 1]:
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w:
                    lbl = label_map[ny, nx]
                    if lbl > 0:
                        neighbor_labels.add(lbl)

        if len(neighbor_labels) == 1:
            label_map[y, x] = neighbor_labels.pop()
        elif len(neighbor_labels) > 1:
            label_map[y, x] = WATERSHED

        # Add unvisited neighbors to the queue
        for dy in [-1, 0, 1]:
            for dx in [-1, 0, 1]:
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w:
                    if label_map[ny, nx] == INIT:
                        heapq.heappush(pq, (gray[ny, nx], ny, nx))
                        label_map[ny, nx] = 0  # Mark as in queue

    return label_map


def improved_watershed(image_path):
    # Load and preprocess image
    original = cv2.imread(image_path)
    gray = cv2.cvtColor(original, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (9, 9), 2)

    # Step 1: Better marker detection using adaptive thresholding
    thresh = cv2.adaptiveThreshold(blurred, 255,
                                   cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, 21, 4)

    # Step 2: Noise removal and sure background area
    kernel = np.ones((3, 3), np.uint8)
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

    # Step 3: Distance transform for better foreground detection
    dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
    _, sure_fg = cv2.threshold(dist_transform, 0.5 * dist_transform.max(), 255, 0)
    sure_fg = np.uint8(sure_fg)

    # Step 4: Create markers using connected components
    _, markers = cv2.connectedComponents(sure_fg)
    markers += 1  # Add 1 to all labels so background is 1

    # Step 5: Apply custom watershed algorithm
    label_map = watershed_from_scratch(blurred, markers)

    # Enhanced visualization
    output = original.copy()
    boundaries = (label_map == -1).astype(np.uint8) * 255
    contours, _ = cv2.findContours(boundaries, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(output, contours, -1, (0, 0, 255), 1)

    # Create intermediate step visualization
    process_steps = {
        "Original": original,
        "Blurred": cv2.cvtColor(blurred, cv2.COLOR_GRAY2BGR),
        "Threshold": cv2.cvtColor(thresh, cv2.COLOR_GRAY2BGR),
        "Foreground Markers": cv2.cvtColor(sure_fg, cv2.COLOR_GRAY2BGR),
        "Final Segmentation": output
    }

    return process_steps

# Refined watershed; this definition overrides the one above and treats
# marker label 1 as background.
def watershed_from_scratch(gray, markers):
    h, w = gray.shape
    WATERSHED = -1
    INIT = -2

    label_map = np.full((h, w), INIT, dtype=np.int32)
    label_map[markers > 1] = markers[markers > 1]  # Skip background marker

    pq = []
    # Initialize queue with marker boundaries
    for y in range(h):
        for x in range(w):
            if label_map[y, x] > 0:
                for dy in [-1, 0, 1]:
                    for dx in [-1, 0, 1]:
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w:
                            if label_map[ny, nx] == INIT:
                                heapq.heappush(pq, (gray[ny, nx], ny, nx))
                                label_map[ny, nx] = 0  # Queued

    # Improved flooding with gradient consideration
    while pq:
        intensity, y, x = heapq.heappop(pq)
        neighbors = []

        # Check 8 neighbors
        for dy in [-1, 0, 1]:
            for dx in [-1, 0, 1]:
                if dy == 0 and dx == 0:
                    continue
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w:
                    neighbors.append(label_map[ny, nx])

        # Find unique labels excluding watershed and background
        unique = set(n for n in neighbors if n > 0)

        if len(unique) == 0:
            label_map[y, x] = 1  # Background
        elif len(unique) == 1:
            label_map[y, x] = unique.pop()
        else:
            label_map[y, x] = WATERSHED

        # Add neighbors to queue
        for dy in [-1, 0, 1]:
            for dx in [-1, 0, 1]:
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w:
                    if label_map[ny, nx] == INIT:
                        heapq.heappush(pq, (gray[ny, nx], ny, nx))
                        label_map[ny, nx] = 0

    return label_map

# Gradio integration would use:
def generate_watershed(image_path):
    results = improved_watershed(image_path)
    return (
        results["Original"],
        results["Blurred"],
        results["Threshold"],
    )


if __name__ == "__main__":
    # Run the process on a grayscale image
    image = cv2.imread("/home/akshat/projects/CSL7360_Project/bird.jpeg", cv2.IMREAD_GRAYSCALE)
    image = cv2.GaussianBlur(image, (5, 5), 0)
    minima = get_local_minima(image)
    markers, num_labels = label_markers(minima)
    result = watershed_from_scratch(image, markers)

    # Visualization
    output = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
    output[result == -1] = [255, 0, 0]  # Watershed lines in red
    output[result > 0] = [0, 255, 0]    # Segments in green
    output[markers > 0] = [0, 0, 255]   # Original minima in blue

    # Save the original grayscale and the output image
    cv2.imwrite("original_grayscale.png", image)
    cv2.imwrite("watershed_output.png", output)

    print("Images saved as 'original_grayscale.png' and 'watershed_output.png'")
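Editor's note: for reference (a hypothetical snippet, not part of the commit), the same marker pipeline can be fed to OpenCV's built-in watershed, which likewise writes -1 on ridge lines, making the two outputs directly comparable:

import cv2
import numpy as np

original = cv2.imread("bird.jpeg")
gray = cv2.cvtColor(original, cv2.COLOR_BGR2GRAY)
blurred = cv2.GaussianBlur(gray, (9, 9), 2)
thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                               cv2.THRESH_BINARY_INV, 21, 4)
opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN,
                           np.ones((3, 3), np.uint8), iterations=2)
dist = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
_, sure_fg = cv2.threshold(dist, 0.5 * dist.max(), 255, 0)
_, markers = cv2.connectedComponents(np.uint8(sure_fg))
markers += 1

ref = cv2.watershed(original, markers.astype(np.int32))  # boundaries become -1
print("OpenCV boundary pixels:", np.sum(ref == -1))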
kmeans_comparison.png
ADDED
Git LFS Details
kmeans_segmented.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,11 @@

torch==2.5.1
torchvision==0.20.1
gradio==5.23.1
pillow==10.4.0
numpy==2.2.2
opencv-python==4.10.0.84
matplotlib==3.10.0
wandb==0.19.6
tqdm==4.67.1
gdown==5.2.0
opendatasets==0.1.22
saved_models/segnet_efficientnet_camvid.pth
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:2f1e96df359eb0e1c153627880dc93e662b2ae5f998f9ed946ec71e726739481
size 29641657
saved_models/segnet_vgg.pth
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:3ac7681151184571d468e4c408c30107dd8b44170b602a06b97a24240f0fb83b
size 49538462
segnet_efficientnet_voc.pth
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:5225a079173dc4b5b1f786e79a474d64c2d17a9aa8f35bbb0908cfbb0f2b9baa
size 29583954
watershed_output.png
ADDED
Git LFS Details