Spaces:
Runtime error
Runtime error
File size: 12,993 Bytes
10a2580 edd5fa8 227564e a815e4b c7a2f0c 5fe2dca 3dd62d7 a815e4b 3dd62d7 02437a8 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 4a4e8e6 5274a1a 3dd62d7 4a4e8e6 5274a1a 3dd62d7 4a4e8e6 5274a1a 3dd62d7 4a4e8e6 5274a1a 3dd62d7 f2a8508 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 3e942ca 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 5274a1a 3dd62d7 da61339 3dd62d7 f8bc440 7e77889 f8bc440 3dd62d7 f8bc440 da61339 3dd62d7 1f20319 4a4e8e6 3dd62d7 f2a8508 3dd62d7 1f20319 f2a8508 3dd62d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
# -----------------------------------------------------------------------------
#
# This file is part of the PantoScanner distribution on:
# https://huggingface.co/spaces/swissrail/PantoScanner
#
# PantoScanner - Analytics and measurement capability for technical objects.
# Copyright (C) 2017-2024 Schweizerische Bundesbahnen SBB
#
# Authors (C) 2024 L. Hofstetter (lukas.hofstetter@sbb.ch)
# Authors (C) 2017 U. Gehrig (urs.gehrig@sbb.ch)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# -----------------------------------------------------------------------------
import streamlit as st
import os
import glob
import cv2
import numpy as np
from strip_measure_4_0 import prepare_networks_for_measurement, measure_strip
import plotly.express as px
import pandas as pd
# Directory searched for *.png input images; get_image_paths resolves it
# against the current working directory.
IMG_BASE_DIR = 'images'
# 3x3 matrix passed to measure_strip as camera_matrix. Presumably pinhole
# intrinsics [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] in pixels -- TODO confirm.
CAMERA_MATRIX = [
    [11100, 0, 1604],
    [0, 11100, 1100],
    [0, 0, 1]
]
# Labelled 3-D reference points passed to measure_strip as
# object_reference_points. Units presumably mm in the pantograph coordinate
# frame -- confirm against strip_measure_4_0.
OBJECT_REFERENCE_POINTS = [
    [347, 0, 42],  # B
    [347, 0, 522],  # D
    [-347, 26, 480],  # F
    [-347, 26, 0]]  # H
# Passed to measure_strip as lower_contour_quadratic_constant; its meaning is
# defined in strip_measure_4_0 -- TODO document.
LOWER_CONTOUR_QUADRATIC_CONSTANT = 0.00005
# Passed to measure_strip as camera_parameters. NOTE(review): read as
# (focal length mm, pixel pitch mm, image height px, image width px) these are
# consistent with CAMERA_MATRIX: 50 / 0.0045 ~= 11111 ~= 11100,
# 3208 / 2 = 1604 and 2200 / 2 = 1100 -- confirm, and keep both in sync when
# recalibrating.
CAMERA_PARAMETERS = (50, 0.0045, 2200, 3208)
# Plane definitions passed as plane_parameters_close / plane_parameters_far.
# The two direction tuples presumably span the plane -- confirm. The far
# plane's offset 480 equals the third coordinate of reference point F.
PLANE_PARAMETERS_CLOSE = ([0, 0, 0], (1, 0, 0), (0, 1, 0)) # Vector from pantograph coordinate frame to plane origin
PLANE_PARAMETERS_FAR = ([0, 0, 480], (1, 0, 0), (0, 1, 0)) # Vector from pantograph coordinate frame to plane origin
# Passed to measure_strip as boundary_1 / boundary_2; semantics are defined in
# strip_measure_4_0 -- TODO document.
BOUNDARY_1 = (300, 92)
BOUNDARY_2 = (650, 1500)
# Segmentation input size passed to measure_strip (image_size_seg,
# image_width_seg, image_height_seg); presumably pixels -- confirm.
IMAGE_SIZE_SEG = 1408
IMAGE_WIDTH_SEG = 1408
IMAGE_HEIGHT_SEG = 576
# Model weight files are looked up at <current working directory>/app/..., so
# the app has to be started from the directory that contains app/.
path_yolo_model = os.path.join(os.getcwd(), 'app', 'detection_model.pt')
path_segmentation_model = os.path.join(os.getcwd(), 'app', 'segmentation_model.pth')
# Function: get_image_paths
# Description: Returns a sorted list of the PNG image paths in the specified directory.
def get_image_paths(base_dir: str):
    """Return the ``*.png`` files directly inside *base_dir*, sorted by path.

    *base_dir* is resolved relative to the current working directory; an
    absolute *base_dir* is used as-is. The result is sorted because ``glob``
    returns matches in arbitrary, filesystem-dependent order, which would make
    the Previous/Next image order unpredictable.

    Args:
        base_dir: Directory to search (not recursive).

    Returns:
        List of matching file paths; empty when nothing matches.
    """
    pattern = os.path.join(os.getcwd(), base_dir, '*.png')
    return sorted(glob.glob(pattern))
# Function: get_num_images
# Description: Counts the image paths stored in the current session state.
def get_num_images():
    """Return how many image paths are stored under ``image_path_list``."""
    image_paths = st.session_state['image_path_list']
    return len(image_paths)
# Function: increment_index
# Description: Steps an index forward by one, either staying put at max_index or wrapping to min_index.
def increment_index(index_current: int, max_index: int, overflow=False, min_index=0):
    """Return ``index_current + 1`` when it does not exceed *max_index*.

    Past *max_index* the result is *min_index* if *overflow* is true (wrap
    around) and *index_current* unchanged otherwise (clamp).
    """
    candidate = index_current + 1
    if candidate <= max_index:
        return candidate
    return min_index if overflow else index_current
# Function: decrement_index
# Description: Steps an index back by one, either staying put at min_index or wrapping to max_index.
def decrement_index(index_current: int, min_index, overflow=False, max_index=-1):
    """Return ``index_current - 1`` when it is not below *min_index*.

    Below *min_index* the result is *max_index* if *overflow* is true (wrap
    around) and *index_current* unchanged otherwise (clamp).
    """
    candidate = index_current - 1
    if candidate >= min_index:
        return candidate
    return max_index if overflow else index_current
# Function: callback_button_previous
# Description: on_click handler of the "Previous image" button. Steps back one image and refreshes the dependent session state.
def callback_button_previous(overflow_index=True):
    """Select the previous image; from index 0 wrap to the last image when *overflow_index* is true."""
    state = st.session_state
    last_index = state['num_images'] - 1
    target = decrement_index(state['image_index_current'], min_index=0,
                             overflow=overflow_index, max_index=last_index)
    update_on_index_change(target)
# Function: callback_button_next
# Description: on_click handler of the "Next image" button. Advances one image and refreshes the dependent session state.
def callback_button_next(overflow_index=True):
    """Select the next image; from the last image wrap to index 0 when *overflow_index* is true.

    The upper bound handed to ``increment_index`` is the last *valid* index
    (``num_images - 1``), mirroring ``callback_button_previous``. Passing
    ``num_images`` instead would step one past the end of the image lists,
    and ``get_current_image`` would then raise ``IndexError``.
    """
    last_index = st.session_state['num_images'] - 1
    new_index = increment_index(st.session_state['image_index_current'], last_index,
                                overflow=overflow_index, min_index=0)
    update_on_index_change(new_index)
# Function: update_on_index_change
# Description: Stores the new image index, then refreshes the cached image array and the measurement lookup for that index.
def update_on_index_change(new_index: int):
    """Make *new_index* the current image and refresh the session values derived from it.

    ``current_image_array`` receives the result of ``get_current_image`` and
    ``current_measurement`` receives the ``(has_measurement, data)`` tuple
    returned by ``get_current_measurement``.
    """
    session = st.session_state
    session['image_index_current'] = new_index
    # Both lookups read 'image_index_current', so it has to be stored first.
    session['current_image_array'] = get_current_image()
    session['current_measurement'] = get_current_measurement()
# Function: load_image_array
# Description: Loads an image array from the specified image path, raising an error if it cannot be read.
def load_image_array(image_path: str):
    """Read the image at *image_path* with ``cv2.imread``.

    ``cv2.imread`` does not raise on a missing or undecodable file; it returns
    ``None``. That ``None`` would be cached by ``get_current_image`` and passed
    on to ``st.image``, far from the cause, so fail here with a clear message.

    Returns:
        The decoded image array as returned by OpenCV (BGR channel order for
        colour images).

    Raises:
        OSError: if the file cannot be read or decoded.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f'Could not read image file: {image_path}')
    return image
# Function: get_current_image
# Description: Returns the image array for the current index, loading it from disk and caching it in the session state on first access.
def get_current_image():
    """Return the image at ``image_index_current``, loading it lazily.

    An entry of ``image_data_list`` that is not yet an ``np.ndarray`` is loaded
    with ``load_image_array`` from the matching ``image_path_list`` entry and
    written back into ``image_data_list``.
    """
    idx = st.session_state['image_index_current']
    image_cache = st.session_state['image_data_list']
    cached = image_cache[idx]
    if not isinstance(cached, np.ndarray):
        cached = load_image_array(st.session_state['image_path_list'][idx])
        image_cache[idx] = cached
    return cached
# Function: callback_button_measure
# Description: Handler for the "Measure" button. Shows the stored measurement of the current image if there is one, otherwise computes it first.
def callback_button_measure():
    """Display the current image's measurement, computing it only when none is stored yet."""
    already_measured = get_current_measurement()[0]
    handler = display_cached_measurement_data if already_measured else display_calculate_measurement_data
    handler()
# Function: display_cached_measurement_data
# Description: Shows an info banner, then renders the already-stored measurement of the current image.
def display_cached_measurement_data():
    """Render the stored measurement of the current image, announcing that it comes from the cache."""
    notice = 'Getting cached measurement'
    st.info(notice, icon="ℹ️")
    display_measurement()
# Function: display_calculate_measurement_data
# Description: Measures the current image behind a spinner, stores the result in the session state and renders it.
def display_calculate_measurement_data():
    """Measure the current image, store the result under its index and display it."""
    with st.spinner('Calculating Profile Height....'):
        current_index = st.session_state['image_index_current']
        image_path = st.session_state['image_path_list'][current_index]
        result = measure_image(image_path)
        update_measurements(result, current_index)
        st.success('Measurement is done !')
    display_measurement()
# Function: measure_image
# Description: Runs strip_measure_4_0.measure_strip on one image with this module's calibration constants and the session's models, then takes the absolute value of column 0 of both returned profiles.
def measure_image(image_path: str):
    """Measure the strip in *image_path* and return its two profile arrays.

    Column 0 of each returned array is overwritten in place with its absolute
    value; measurement_to_streamlit_chart plots that column as the height.

    Returns:
        ``(profile_a, profile_b)``: the first two elements of the
        ``measure_strip`` result.
    """
    models = st.session_state['models']
    raw_result = measure_strip(
        img_path=image_path,
        model_yolo=models['detection'],
        segmentation_model=models['segmentation'],
        camera_matrix=CAMERA_MATRIX,
        object_reference_points=OBJECT_REFERENCE_POINTS,
        camera_parameters=CAMERA_PARAMETERS,
        plane_parameters_close=PLANE_PARAMETERS_CLOSE,
        plane_parameters_far=PLANE_PARAMETERS_FAR,
        lower_contour_quadratic_constant=LOWER_CONTOUR_QUADRATIC_CONSTANT,
        boundary_1=BOUNDARY_1,
        boundary_2=BOUNDARY_2,
        image_size_seg=IMAGE_SIZE_SEG,
        image_width_seg=IMAGE_WIDTH_SEG,
        image_height_seg=IMAGE_HEIGHT_SEG,
    )
    profiles = (raw_result[0], raw_result[1])
    for profile in profiles:
        profile[:, 0] = np.abs(profile[:, 0])
    return profiles
# Function: get_current_measurement
# Description: Looks up the stored measurement of the current image. Returns (True, measurement) if one exists, otherwise (False, None).
def get_current_measurement():
    """Return ``(True, measurement)`` for the current image, or ``(False, None)`` if nothing is stored."""
    stored = st.session_state['measurement_data_list'][st.session_state['image_index_current']]
    if stored is None:
        return False, None
    return True, stored
# Function: update_measurements
# Description: Stores a measurement at the given index of the session's measurement list.
def update_measurements(measurement, index_measurement):
    """Write *measurement* into ``measurement_data_list`` at position *index_measurement*."""
    measurements = st.session_state['measurement_data_list']
    measurements[index_measurement] = measurement
# Function: display_measurement
# Description: Renders a subheader and the profile chart for the current image, if it has a stored measurement.
def display_measurement():
    """Render the current image's profile chart; do nothing when no measurement is stored."""
    has_measurement, measurement_data = get_current_measurement()
    if not has_measurement:
        return
    st.subheader('Profile height (mm)')
    profile_a, profile_b = measurement_data[0], measurement_data[1]
    measurement_to_streamlit_chart(profile_a, profile_b)
# Function: measurement_to_streamlit_chart
# Description: Flattens both profiles into one long-format Pandas DataFrame and plots them as a Plotly line chart.
def measurement_to_streamlit_chart(profile_array_1, profile_array_2):
    """Plot two profiles as labelled line series in Streamlit.

    For each 2-D profile array, column 1 goes on the x axis and column 0 (the
    height) on the y axis. Rows from *profile_array_1* are labelled
    'Profile A' and rows from *profile_array_2* 'Profile B'.
    """
    xs, ys, labels = [], [], []
    for label, profile in (('Profile A', profile_array_1), ('Profile B', profile_array_2)):
        ys += profile[:, 0].tolist()
        xs += profile[:, 1].tolist()
        labels += [label] * len(profile)
    frame = pd.DataFrame({'x': xs, 'y': ys, 'indicator': labels})
    chart = px.line(frame, x='x', y='y', color='indicator', symbol="indicator")
    st.plotly_chart(chart, use_container_width=True)
# Page configuration comes first: Streamlit requires set_page_config to be the
# first Streamlit command of the script, and the initialization below may
# already write to the page (missing-images error).
st.set_page_config(layout='wide')

# Initialization section
# Each key is set only if missing, so values survive reruns of this script
# within one session.
if 'image_path_list' not in st.session_state:
    st.session_state['image_path_list'] = get_image_paths(IMG_BASE_DIR)
if 'num_images' not in st.session_state:
    st.session_state['num_images'] = get_num_images()
if st.session_state['num_images'] == 0:
    # Every index-based lookup below (get_current_image etc.) would raise
    # IndexError on an empty image list, so stop with a readable message.
    st.error(f"No PNG images found in '{IMG_BASE_DIR}'.")
    st.stop()
if 'image_data_list' not in st.session_state:
    st.session_state['image_data_list'] = [None for _ in range(st.session_state['num_images'])]
if 'image_index_current' not in st.session_state:
    st.session_state['image_index_current'] = 0
if 'current_image_array' not in st.session_state:
    st.session_state['current_image_array'] = get_current_image()
if 'measurement_data_list' not in st.session_state:
    st.session_state['measurement_data_list'] = [None for _ in range(st.session_state['num_images'])]
if 'current_measurement' not in st.session_state:
    st.session_state['current_measurement'] = get_current_measurement()
if 'models' not in st.session_state:
    seg_dplv3_model, yolo_nn_model = prepare_networks_for_measurement(model_yolo_path=path_yolo_model,
                                                                      model_segmentation_path=path_segmentation_model)
    st.session_state['models'] = {'segmentation': seg_dplv3_model, 'detection': yolo_nn_model}

# Display section
image_emoji = '📷'
model_emoji = '⚙️'
profile_emoji = '📈'
st.title('PantoScanner')
multi = '''This app processes the detection and segmentation of a
pantograph sliding element from a train and demonstrates the
extraction of its thickness by displaying it in a chart. These
thicknesses can be used to build a time series per sliding element,
e.g. by selecting the thickness in the middle.'''
st.markdown(multi)
# The array comes from cv2.imread (load_image_array), which yields BGR data;
# st.image assumes RGB unless told otherwise, which would swap red and blue.
st.image(st.session_state['current_image_array'], channels='BGR')
st.markdown(
"""
<style>
div[data-testid="column"]:nth-of-type(1)
{
text-align: start;
}
div[data-testid="column"]:nth-of-type(2)
{
text-align: center;
}
div[data-testid="column"]:nth-of-type(3)
{
text-align: end;
}
</style>
""",unsafe_allow_html=True
)
col1, col2, col3 = st.columns(3)
# Previous/Next change the index through on_click callbacks and wrap around at
# both ends (overflow_index=True). Measure is handled after the columns so its
# output is rendered below the buttons.
with col1:
    button_previous = st.button("Previous image", on_click=callback_button_previous, kwargs={'overflow_index': True})
with col2:
    button_measure = st.button("Measure")
with col3:
    button_next = st.button("Next image", on_click=callback_button_next, kwargs={'overflow_index': True})
if button_measure:
    callback_button_measure()
|