fazzam commited on
Commit
e972242
·
verified ·
1 Parent(s): a891910

Upload 17 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Streamlit_app.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: PetroSeg
3
- emoji: 🐨
4
- colorFrom: blue
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.33.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
 
 
 
11
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # Unsupervised Segmentation App with Streamlit and PyTorch
2
+
3
+ ## Table of Contents
4
+ 1. [Introduction](#introduction)
5
+ 2. [Acknowledgments](#acknowledgments)
6
+ 3. [Requirements](#requirements)
7
+ 4. [Installation](#installation)
8
+ 5. [How to Run](#how-to-run)
9
+ 6. [Code Explanation](#code-explanation)
10
+ 7. [Contributing](#contributing)
11
+ 8. [License](#license)
12
+
13
+ ---
14
+
15
+ ## Introduction 🌟
16
+ This project is a web application built using Streamlit and PyTorch. It performs unsupervised segmentation on uploaded images. The segmented image can be downloaded, and the colors of the segments can be customized.
17
+
18
+ ---
19
+
20
+ ## Acknowledgments 🙏
21
+ This code is inspired by the project [pytorch-unsupervised-segmentation](https://github.com/kanezaki/pytorch-unsupervised-segmentation) by kanezaki. The original project is based on the paper "Unsupervised Image Segmentation by Backpropagation" presented at IEEE ICASSP 2018. The code is optimized for thin section images and microscopy analysis.
22
+
23
+ ---
24
+
25
+ ## Requirements 📋
26
+ - Python 3.x
27
+ - Streamlit
28
+ - PyTorch
29
+ - OpenCV
30
+ - NumPy
31
+ - scikit-image
32
+ - PIL
33
+ - base64 (Python standard library — no installation required)
34
+
35
+ ---
36
+
37
+ ## Installation 🛠️
38
+
39
+ 1. **Clone the repository**
40
+ ```bash
41
+ git clone https://github.com/your-repo/unsupervised-segmentation.git
42
+ ```
43
+ 2. **Navigate to the project directory**
44
+ ```bash
45
+ cd unsupervised-segmentation
46
+ ```
47
+ 3. **Install the required packages**
48
+ ```bash
49
+ pip install -r requirements.txt
50
+ ```
51
+
52
  ---
53
+
54
+ ## How to Run 🚀
55
+
56
+ 1. **Navigate to the project directory**
57
+ ```bash
58
+ cd unsupervised-segmentation
59
+ ```
60
+ 2. **Run the Streamlit app**
61
+ ```bash
62
+ streamlit run app.py
63
+ ```
64
+
65
  ---
66
+ ![Streamlit App Screenshot](https://github.com/fazzam12345/Unsupervised-Segmentation-App/blob/master/Streamlit_app.png?raw=true)
67
+
68
+
69
+ ---
70
+
71
+ ## Contributing 🤝
72
+ Feel free to open issues and pull requests!
73
+
74
+ ---
75
+
76
+ ## License 📜
77
+ This project is licensed under the MIT License.
78
+
79
 
 
Streamlit_app.png ADDED

Git LFS Details

  • SHA256: 785889315b69e61423e57a77494056fa0f15ca125219e558be8750ffc25c32b9
  • Pointer size: 132 Bytes
  • Size of remote file: 5.18 MB
app.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
from src import init
from src.interface import main

# Entry point for `streamlit run app.py`: seed the Streamlit session state
# with the keys the UI expects, then hand control to the interface module.
if __name__ == "__main__":
    init.initialize_session_state()
    main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit==1.27.0
2
+ opencv-python-headless
3
+ numpy==1.24
4
+ torch==2.0
5
+ Pillow==9.5
6
+ scikit-image==0.18
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (196 Bytes). View file
 
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (212 Bytes). View file
 
src/__pycache__/config.cpython-310.pyc ADDED
Binary file (370 Bytes). View file
 
src/__pycache__/init.cpython-310.pyc ADDED
Binary file (424 Bytes). View file
 
src/__pycache__/interface.cpython-310.pyc ADDED
Binary file (5.96 kB). View file
 
src/__pycache__/interface.cpython-311.pyc ADDED
Binary file (7.97 kB). View file
 
src/__pycache__/models.cpython-310.pyc ADDED
Binary file (3.85 kB). View file
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.52 kB). View file
 
src/init.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
import streamlit as st


def initialize_session_state():
    """Seed Streamlit session state with the keys the app relies on.

    `segmented_image` holds the most recent segmentation result (None until
    one is produced); `new_colors` maps original label colors to the user's
    chosen replacements. `setdefault` leaves values already present from a
    previous rerun untouched.
    """
    for key, default in (('segmented_image', None), ('new_colors', {})):
        st.session_state.setdefault(key, default)
src/interface.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import cv2
4
+ import numpy as np
5
+ from src.models import perform_custom_segmentation
6
+ from src.utils import resize_image, download_image
7
+ import os
8
+ import torch
9
+
10
+ # Constants
11
+ TARGET_SIZE = (750, 750)
12
+
13
def get_parameters_from_sidebar() -> dict:
    """Collect segmentation hyper-parameters from the Streamlit sidebar.

    Returns a dict with the five slider-controlled training parameters plus
    a 'target_size' (width, height) tuple taken from two number inputs.
    """
    st.sidebar.header("Segmentation Parameters")
    # (name, (min, max, default)) for each slider widget.
    slider_specs = [
        ('train_epoch', (1, 200, 43)),
        ('mod_dim1', (1, 128, 67)),
        ('mod_dim2', (1, 128, 63)),
        ('min_label_num', (1, 20, 3)),
        ('max_label_num', (1, 200, 25)),
    ]
    params = {}
    for name, (lo, hi, default) in slider_specs:
        label = name.replace('_', ' ').title()
        params[name] = st.sidebar.slider(label, lo, hi, default)

    # The uploaded image is resized to this (width, height) before segmentation.
    width = st.sidebar.number_input("Target Size Width", 100, 1200, 750)
    height = st.sidebar.number_input("Target Size Height", 100, 1200, 750)
    params['target_size'] = (width, height)

    return params
26
def display_segmentation_results() -> None:
    """Render the current segmented image stored in session state."""
    st.image(
        st.session_state.segmented_image,
        caption='Updated Segmented Image',
        use_column_width=True,
    )
29
+
30
def randomize_colors() -> None:
    """Assign a fresh random color to every segment of the segmented image.

    BUGFIX: masks are now computed against a snapshot of the image taken
    before any recoloring. The previous version matched against the array
    it was mutating, so a freshly assigned random color that happened to
    equal a not-yet-processed original label caused those pixels to be
    recolored a second time.
    """
    original = st.session_state.segmented_image.copy()
    unique_labels = np.unique(original.reshape(-1, 3), axis=0)
    random_colors = {
        tuple(label): tuple(np.random.randint(0, 256, size=3))
        for label in unique_labels
    }

    for old_color, new_color in random_colors.items():
        # Match the unmodified snapshot, not the partially recolored image.
        mask = np.all(original == np.array(old_color), axis=-1)
        st.session_state.segmented_image[mask] = new_color

    # Remember the mapping and force the displayed image to refresh.
    st.session_state.new_colors.update(random_colors)
    st.session_state.image_update_trigger += 1  # Trigger image update
42
+
43
def handle_color_picking() -> None:
    """Show one color picker per segment and apply the chosen colors.

    Chosen colors are stored in ``st.session_state.new_colors`` keyed by
    the original label color. BUGFIX: the dead hex-conversion bookkeeping
    (`new_colors_hex`, `old_color_hex`, `new_color_hex`) was computed and
    never used — removed. Masks are now computed against a snapshot of the
    image so an earlier recoloring cannot alias a later label's mask.
    """
    original = st.session_state.segmented_image.copy()
    unique_labels = np.unique(original.reshape(-1, 3), axis=0)
    for i, label in enumerate(unique_labels):
        hex_label = f'#{label[0]:02x}{label[1]:02x}{label[2]:02x}'
        new_color = st.color_picker(
            f"Choose a new color for label {i}", value=hex_label, key=f"label_{i}"
        )
        # '#rrggbb' -> (r, g, b) integer tuple.
        new_color_rgb = tuple(int(new_color.lstrip('#')[j:j + 2], 16) for j in (0, 2, 4))
        st.session_state.new_colors[tuple(label)] = new_color_rgb

    for old_color, new_color in st.session_state.new_colors.items():
        # Match the snapshot so a freshly applied color is never confused
        # with a still-unprocessed original label.
        mask = np.all(original == np.array(old_color), axis=-1)
        st.session_state.segmented_image[mask] = new_color

    # After updating colors, trigger an update to the segmented image display.
    st.session_state.image_update_trigger += 1
66
+
67
def calculate_and_display_label_percentages() -> None:
    """Show each segment's share of the image alongside a color swatch.

    Segments are identified by their grayscale value after a BGR-to-gray
    conversion; percentages are per-segment pixel counts over the total.
    """
    gray = cv2.cvtColor(st.session_state.segmented_image, cv2.COLOR_BGR2GRAY)
    labels, counts = np.unique(gray, return_counts=True)
    total = np.sum(counts)

    # Map each grayscale label to the hex color of one of its pixels
    # (segments are uniform, so the first pixel is representative).
    swatch_colors = {}
    for label in labels:
        rgb = st.session_state.segmented_image[gray == label][0]
        swatch_colors[int(label)] = f'#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}'

    st.write("Label Percentages:")
    for label, count in zip(labels, counts):
        pct = (count / total) * 100
        hex_color = swatch_colors[int(label)]
        color_box = (
            f'<div style="display: inline-block; width: 20px; height: 20px; '
            f'background-color: {hex_color}; margin-right: 10px;"></div>'
        )
        st.markdown(f'{color_box} Label {int(label)}: {pct:.2f}%', unsafe_allow_html=True)
87
+
88
def main() -> None:
    """Streamlit entry point: upload an image, segment it, recolor, download."""
    st.title("PetroSeg")
    st.info("""
    - **Training Epochs**: Higher values will lead to fewer segments but may take more time.
    - **Image Size**: For better efficiency, upload small-sized images.
    - **Cache**: For best results, clear the cache between different image uploads. You can do this from the menu in the top-right corner.
    """)

    if torch.cuda.is_available():
        # BUGFIX: the original read `args.gpu_id` here, but no `args` exists
        # in this scope — a guaranteed NameError on any CUDA machine.
        # Default to GPU 0 while respecting a value already set by the deployer.
        os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

    # Initialize session state if not already initialized.
    st.session_state.setdefault('segmented_image', None)
    st.session_state.setdefault('new_colors', {})
    st.session_state.setdefault('image_update_trigger', 0)

    # Sidebar controls must be created before they are read below.
    params = get_parameters_from_sidebar()

    uploaded_image = st.sidebar.file_uploader(
        "Upload an image", type=["jpg", "png", "jpeg", "bmp", "tiff", "webp"]
    )
    if uploaded_image:
        file_bytes = np.asarray(bytearray(uploaded_image.read()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, 1)

        if image is None:
            st.error("Error loading image. Please check the file and try again.")
            return

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        st.image(image_rgb, caption='Original Image', use_column_width=True)

        # Resize to the user-selected (width, height) before segmenting.
        image_resized = resize_image(image_rgb, params['target_size'])

        if st.sidebar.button("Start Segmentation"):
            st.session_state.segmented_image = perform_custom_segmentation(image_resized, params)

        if st.sidebar.button("Change Colors"):
            randomize_colors()

        if st.session_state.segmented_image is not None:
            handle_color_picking()
            display_segmentation_results()
            calculate_and_display_label_percentages()
            download_image(st.session_state.segmented_image, 'segmented_image.png')


if __name__ == "__main__":
    main()
src/models.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import streamlit as st
5
+ import os
6
+ from skimage import segmentation
7
+
8
+
9
def perform_custom_segmentation(image, params):
    """Run unsupervised CNN segmentation on an RGB image.

    Based on Kanezaki, "Unsupervised Image Segmentation by Backpropagation"
    (IEEE ICASSP 2018): a small conv net is trained on the single input
    image so its per-pixel argmax labels agree within Felzenszwalb
    superpixels, which progressively merges regions.

    Args:
        image: H x W x 3 uint8 RGB array.
        params: dict of hyper-parameters; recognized keys are 'train_epoch',
            'mod_dim1', 'mod_dim2', 'gpu_id', 'min_label_num' and
            'max_label_num' (missing keys fall back to defaults).

    Returns:
        H x W x 3 array in which every final segment is painted with the
        mean color of its pixels — or the input image unchanged if the
        label count never dropped below 'max_label_num'.
    """
    class Args(object):
        # Bundle hyper-parameters with defaults matching the original paper code.
        def __init__(self, params):
            self.train_epoch = params.get('train_epoch', 2 ** 3)
            self.mod_dim1 = params.get('mod_dim1', 64)
            self.mod_dim2 = params.get('mod_dim2', 32)
            self.gpu_id = params.get('gpu_id', 0)
            self.min_label_num = params.get('min_label_num', 6)
            self.max_label_num = params.get('max_label_num', 256)

    args = Args(params)

    class MyNet(nn.Module):
        # Small fully-convolutional net; the final mod_dim2 channels are
        # interpreted as per-pixel class logits.
        def __init__(self, inp_dim, mod_dim1, mod_dim2):
            super(MyNet, self).__init__()
            self.seq = nn.Sequential(
                nn.Conv2d(inp_dim, mod_dim1, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(mod_dim1),
                nn.ReLU(inplace=True),
                nn.Conv2d(mod_dim1, mod_dim2, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(mod_dim2),
                nn.ReLU(inplace=True),
                nn.Conv2d(mod_dim2, mod_dim1, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(mod_dim1),
                nn.ReLU(inplace=True),
                nn.Conv2d(mod_dim1, mod_dim2, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(mod_dim2),
            )

        def forward(self, x):
            return self.seq(x)

    # BUGFIX: seed the CPU RNG as well — the original only called
    # torch.cuda.manual_seed_all, so CPU-only runs were non-deterministic.
    torch.manual_seed(1943)
    torch.cuda.manual_seed_all(1943)
    np.random.seed(1943)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    '''segmentation ML'''
    # Felzenszwalb superpixels supply the grouping constraint for training.
    seg_map = segmentation.felzenszwalb(image, scale=15, sigma=0.06, min_size=14)
    seg_map = seg_map.flatten()
    seg_lab = [np.where(seg_map == u_label)[0]
               for u_label in np.unique(seg_map)]

    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # HWC uint8 -> 1 x C x H x W float tensor scaled to [0, 1].
    tensor = image.transpose((2, 0, 1))
    tensor = tensor.astype(np.float32) / 255.0
    tensor = tensor[np.newaxis, :, :, :]
    tensor = torch.from_numpy(tensor).to(device)

    model = MyNet(inp_dim=3, mod_dim1=args.mod_dim1, mod_dim2=args.mod_dim2).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-2, momentum=0.9)

    image_flatten = image.reshape((-1, 3))
    color_avg = np.random.randint(255, size=(args.max_label_num, 3))
    show = image  # fallback result if labels never drop below max_label_num

    progress_bar = st.progress(0)

    for batch_idx in range(args.train_epoch):
        optimizer.zero_grad()
        output = model(tensor)[0]
        output = output.permute(1, 2, 0).view(-1, args.mod_dim2)
        target = torch.argmax(output, 1)
        im_target = target.data.cpu().numpy()

        # Within each superpixel, force every pixel to the majority label.
        for inds in seg_lab:
            u_labels, hist = np.unique(im_target[inds], return_counts=True)
            im_target[inds] = u_labels[np.argmax(hist)]

        target = torch.from_numpy(im_target)
        target = target.to(device)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Rebuild the preview once the label count is small enough: paint
        # each label with the mean color of its pixels.
        un_label, lab_inverse = np.unique(im_target, return_inverse=True)
        if un_label.shape[0] < args.max_label_num:
            img_flatten = image_flatten.copy()
            if len(color_avg) != un_label.shape[0]:
                color_avg = [np.mean(img_flatten[im_target == label], axis=0, dtype=int)
                             for label in un_label]
            for lab_id, color in enumerate(color_avg):
                img_flatten[lab_inverse == lab_id] = color
            show = img_flatten.reshape(image.shape)

        progress = (batch_idx + 1) / args.train_epoch
        progress_bar.progress(progress)

    return show
src/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import tempfile
3
+ from io import BytesIO
4
+ import streamlit as st
5
+ import numpy as np
6
+
7
def resize_image(image, size):
    """Resize `image` to `size` (width, height) with area interpolation,
    which is the recommended OpenCV mode for shrinking."""
    resized = cv2.resize(image, size, interpolation=cv2.INTER_AREA)
    return resized
9
+
10
def automatically_change_segment_colors(segmented_image):
    """Recolor every segment of `segmented_image` with a fresh random color.

    The array is modified in place and also returned. BUGFIX: masks are now
    computed against a snapshot of the original pixels — the previous
    version matched against the array while mutating it, so a random color
    that happened to equal a not-yet-processed label caused those pixels to
    be recolored a second time.

    Args:
        segmented_image: H x W x 3 uint8 array where each segment is a
            uniform color.

    Returns:
        The same array, recolored in place.
    """
    original = segmented_image.copy()
    unique_labels = np.unique(original.reshape(-1, 3), axis=0)
    new_colors = np.random.randint(0, 256, (len(unique_labels), 3), dtype=np.uint8)

    for i, label in enumerate(unique_labels):
        # Match the snapshot so earlier assignments cannot alias later labels.
        mask = np.all(original == label, axis=-1)
        segmented_image[mask] = new_colors[i]

    return segmented_image
21
+
22
def download_image(image_array, file_name):
    """Offer `image_array` for download as a PNG via a Streamlit button.

    The channels are swapped (RGB <-> BGR) before encoding so the saved PNG
    matches what the app displays. BUGFIX: the PNG is now produced entirely
    in memory with cv2.imencode — the previous implementation wrote a
    NamedTemporaryFile with delete=False and never removed it, leaking one
    multi-megabyte file per download.

    Args:
        image_array: H x W x 3 uint8 image.
        file_name: file name suggested to the browser for the download.
    """
    try:
        image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
        success, encoded = cv2.imencode('.png', image_array)
        if not success:
            st.error("Could not save image.")
            return
        st.download_button(
            label="Download Image",
            data=BytesIO(encoded.tobytes()),
            file_name=file_name,
            mime='image/png',
        )
    except Exception as e:
        st.error(f"An error occurred: {e}")