File size: 12,993 Bytes
10a2580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edd5fa8
227564e
a815e4b
c7a2f0c
5fe2dca
3dd62d7
 
 
a815e4b
3dd62d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02437a8
 
3dd62d7
5274a1a
 
3dd62d7
 
 
5274a1a
 
3dd62d7
 
 
5274a1a
 
3dd62d7
 
 
 
 
 
 
 
 
5274a1a
 
3dd62d7
 
 
 
 
 
 
 
 
5274a1a
 
3dd62d7
 
 
 
 
5274a1a
 
3dd62d7
 
 
 
 
5274a1a
 
3dd62d7
 
 
 
 
 
5274a1a
 
3dd62d7
 
4a4e8e6
5274a1a
 
3dd62d7
 
 
 
 
 
 
 
 
4a4e8e6
5274a1a
 
3dd62d7
 
 
 
 
 
4a4e8e6
5274a1a
 
3dd62d7
 
 
4a4e8e6
5274a1a
 
3dd62d7
 
 
 
 
 
 
f2a8508
5274a1a
 
3dd62d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5274a1a
 
3dd62d7
 
 
 
 
 
 
5274a1a
 
3dd62d7
 
 
5274a1a
 
3dd62d7
 
 
3e942ca
3dd62d7
 
5274a1a
 
3dd62d7
 
 
 
 
 
 
 
 
 
 
 
 
 
5274a1a
 
3dd62d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5274a1a
 
3dd62d7
 
 
da61339
3dd62d7
 
f8bc440
 
7e77889
 
 
 
f8bc440
 
3dd62d7
f8bc440
da61339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dd62d7
 
 
1f20319
4a4e8e6
3dd62d7
 
f2a8508
3dd62d7
1f20319
f2a8508
3dd62d7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# -----------------------------------------------------------------------------
#
# This file is part of the PantoScanner distribution on: 
# https://huggingface.co/spaces/swissrail/PantoScanner
#
# PantoScanner - Analytics and measurement capability for technical objects.
# Copyright (C) 2017-2024 Schweizerische Bundesbahnen SBB
#
# Authors (C) 2024 L. Hofstetter (lukas.hofstetter@sbb.ch)
# Authors (C) 2017 U. Gehrig (urs.gehrig@sbb.ch)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# -----------------------------------------------------------------------------

import streamlit as st
import os
import glob
import cv2
import numpy as np
from strip_measure_4_0 import prepare_networks_for_measurement, measure_strip
import plotly.express as px
import pandas as pd

# Directory (relative to the current working directory) scanned for *.png files by get_image_paths.
IMG_BASE_DIR = 'images'
# 3x3 matrix passed to measure_strip as camera_matrix. It looks like a pinhole intrinsic matrix
# [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] in pixels — TODO confirm against strip_measure_4_0.
CAMERA_MATRIX = [
    [11100, 0, 1604],
    [0, 11100, 1100],
    [0, 0, 1]
]

# Four 3-D reference points (labelled B, D, F, H) passed to measure_strip as object_reference_points.
# Presumably coordinates in the pantograph frame, in mm — confirm units/frame in strip_measure_4_0.
OBJECT_REFERENCE_POINTS = [
    [347, 0, 42],  # B
    [347, 0, 522],  # D
    [-347, 26, 480],  # F
    [-347, 26, 0]]  # H

# Forwarded to measure_strip as lower_contour_quadratic_constant; its exact role is defined there.
LOWER_CONTOUR_QUADRATIC_CONSTANT = 0.00005
# Forwarded to measure_strip as camera_parameters. Possibly (focal length mm, pixel pitch mm,
# image height px, image width px): 50 / 0.0045 ≈ 11100 matches fx above, and 1604/1100 are half
# of 3208/2200 — NOTE(review): confirm this interpretation.
CAMERA_PARAMETERS = (50, 0.0045, 2200, 3208)
# Plane descriptions forwarded to measure_strip; each looks like (origin, direction 1, direction 2) — confirm.
PLANE_PARAMETERS_CLOSE = ([0, 0, 0], (1, 0, 0), (0, 1, 0))  # Vector from pantograph coordinate frame to plane origin
PLANE_PARAMETERS_FAR = ([0, 0, 480], (1, 0, 0), (0, 1, 0))  # Vector from pantograph coordinate frame to plane origin
# Two 2-tuples forwarded to measure_strip as boundary_1 / boundary_2; semantics live in strip_measure_4_0.
BOUNDARY_1 = (300, 92)
BOUNDARY_2 = (650, 1500)
# Sizes (presumably in pixels) forwarded to measure_strip for the segmentation step.
IMAGE_SIZE_SEG = 1408
IMAGE_WIDTH_SEG = 1408
IMAGE_HEIGHT_SEG = 576

# Model weight files are resolved against the current working directory, so the app must be
# started from the directory that contains the 'app' folder.
path_yolo_model = os.path.join(os.getcwd(), 'app', 'detection_model.pt')
path_segmentation_model = os.path.join(os.getcwd(), 'app', 'segmentation_model.pth')

# Function: get_image_paths
# Description: Returns the .png files located directly inside base_dir, resolved against the
#              current working directory, sorted by path.
#              glob.glob yields entries in arbitrary (filesystem-dependent) order; sorting keeps
#              the "previous"/"next" navigation order stable across machines and restarts.
# Parameters: base_dir - directory name/path relative to os.getcwd() (an absolute path also works).
# Returns: sorted list of path strings; empty list if the directory is missing or has no .png files.
def get_image_paths(base_dir: str):
    pattern = os.path.join(os.getcwd(), base_dir, '*.png')
    return sorted(glob.glob(pattern))

# Function: get_num_images
# Description: Counts the image paths stored in the session state under 'image_path_list'.
# Returns: int, the length of that list.
def get_num_images():
    image_paths = st.session_state['image_path_list']
    return len(image_paths)

# Function: increment_index
# Description: Moves an index one step forward. max_index is inclusive: the step is taken while the
#              result stays <= max_index. Past that bound the result wraps to min_index when
#              overflow is True, otherwise the index is left unchanged.
# Returns: the new index.
def increment_index(index_current: int, max_index: int, overflow=False, min_index=0):
    candidate = index_current + 1
    if candidate <= max_index:
        return candidate
    # Out of range: wrap around or stay put.
    return min_index if overflow else index_current

# Function: decrement_index
# Description: Moves an index one step backward. min_index is inclusive: the step is taken while
#              the result stays >= min_index. Below that bound the result wraps to max_index when
#              overflow is True, otherwise the index is left unchanged.
# Returns: the new index.
def decrement_index(index_current: int, min_index, overflow=False, max_index=-1):
    candidate = index_current - 1
    if candidate >= min_index:
        return candidate
    # Out of range: wrap around or stay put.
    return max_index if overflow else index_current

# Function: callback_button_previous
# Description: on_click handler of the "Previous image" button. Steps the current image index one
#              position back (wrapping to the last image when overflow_index is True) and refreshes
#              the derived session state via update_on_index_change.
def callback_button_previous(overflow_index=True):
    last_index = st.session_state['num_images'] - 1
    update_on_index_change(
        decrement_index(st.session_state['image_index_current'], min_index=0,
                        overflow=overflow_index, max_index=last_index)
    )

# Function: callback_button_next
# Description: on_click handler of the "Next image" button. Steps the current image index one
#              position forward (wrapping to 0 when overflow_index is True) and refreshes the
#              derived session state via update_on_index_change.
def callback_button_next(overflow_index=True):
    # increment_index treats max_index as inclusive, so the last valid position is num_images - 1.
    # Passing num_images would allow the index to step one past the end of the image lists, and
    # get_current_image would then raise IndexError.
    last_index = st.session_state['num_images'] - 1
    new_index = increment_index(st.session_state['image_index_current'], last_index,
                                overflow=overflow_index, min_index=0)
    update_on_index_change(new_index)

# Function: update_on_index_change
# Description: Makes new_index the current image position and refreshes the session-state entries
#              derived from it: 'current_image_array' (via get_current_image) and
#              'current_measurement' (the (has_measurement, data) tuple from get_current_measurement).
# Parameters: new_index - position in 'image_path_list' to switch to.
def update_on_index_change(new_index: int):
    state = st.session_state
    # Store the index first: both helpers below read 'image_index_current'.
    state['image_index_current'] = new_index
    state['current_image_array'] = get_current_image()
    state['current_measurement'] = get_current_measurement()

# Function: load_image_array
# Description: Reads the image at image_path with OpenCV and returns it as a numpy array
#              (OpenCV's default colour order is BGR).
# Raises: OSError if OpenCV cannot read the file (missing, unreadable or unsupported format).
def load_image_array(image_path: str):
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising; fail loudly here rather
        # than letting None propagate into the image cache and st.image.
        raise OSError(f'cv2.imread could not read image: {image_path}')
    return image

# Function: get_current_image
# Description: Returns the image array for 'image_index_current'. Arrays are cached lazily in
#              'image_data_list': any entry that is not yet a numpy array is loaded from the
#              matching 'image_path_list' entry with load_image_array and stored back.
def get_current_image():
    position = st.session_state['image_index_current']
    cache = st.session_state['image_data_list']
    image = cache[position]
    if not isinstance(image, np.ndarray):
        # Cache miss: read from disk and remember the result for later navigation.
        image = load_image_array(st.session_state['image_path_list'][position])
        cache[position] = image
    return image

# Function: callback_button_measure
# Description: Handler for the "Measure" button. Re-displays the cached measurement of the current
#              image if one exists; otherwise computes it first (display_calculate_measurement_data).
def callback_button_measure():
    already_measured, _ = get_current_measurement()
    if not already_measured:
        display_calculate_measurement_data()
        return
    display_cached_measurement_data()

# Function: display_cached_measurement_data
# Description: Shows an info banner stating that a stored result is being reused, then renders the
#              current image's measurement chart.
def display_cached_measurement_data():
    banner_text = 'Getting cached measurement'
    st.info(banner_text, icon="ℹ️")
    display_measurement()

# Function: display_calculate_measurement_data
# Description: Runs measure_image on the current image while a spinner is shown, stores the
#              result in 'measurement_data_list', then reports success and renders the chart.
def display_calculate_measurement_data():
    with st.spinner('Calculating Profile Height....'):
        position = st.session_state['image_index_current']
        image_path = st.session_state['image_path_list'][position]
        update_measurements(measure_image(image_path), position)
    st.success('Measurement is done !')
    display_measurement()

# Function: measure_image
# Description: Runs strip_measure_4_0.measure_strip on image_path using the session's detection and
#              segmentation models and this module's geometry constants.
# Returns: a 2-tuple with the first two items of measure_strip's result. Column 0 of each item is
#          replaced in place by its absolute value. Presumably each item is a 2-D array with heights
#          in column 0 and positions in column 1, as plotted by measurement_to_streamlit_chart —
#          confirm in strip_measure_4_0.
def measure_image(image_path: str):
    models = st.session_state['models']
    raw_result = measure_strip(img_path=image_path,
                               model_yolo=models['detection'],
                               segmentation_model=models['segmentation'],
                               camera_matrix=CAMERA_MATRIX,
                               object_reference_points=OBJECT_REFERENCE_POINTS,
                               camera_parameters=CAMERA_PARAMETERS,
                               plane_parameters_close=PLANE_PARAMETERS_CLOSE,
                               plane_parameters_far=PLANE_PARAMETERS_FAR,
                               lower_contour_quadratic_constant=LOWER_CONTOUR_QUADRATIC_CONSTANT,
                               boundary_1=BOUNDARY_1,
                               boundary_2=BOUNDARY_2,
                               image_size_seg=IMAGE_SIZE_SEG,
                               image_width_seg=IMAGE_WIDTH_SEG,
                               image_height_seg=IMAGE_HEIGHT_SEG)
    profile_a, profile_b = raw_result[0], raw_result[1]
    # Make column 0 non-negative in place (the sign is discarded for display).
    for profile in (profile_a, profile_b):
        profile[:, 0] = np.abs(profile[:, 0])
    return profile_a, profile_b

# Function: get_current_measurement
# Description: Looks up the stored measurement for 'image_index_current' in 'measurement_data_list'.
# Returns: (True, measurement) when an entry is stored, otherwise (False, None).
def get_current_measurement():
    position = st.session_state['image_index_current']
    stored = st.session_state['measurement_data_list'][position]
    if stored is None:
        return False, None
    return True, stored

# Function: update_measurements
# Description: Stores measurement at position index_measurement of the session's
#              'measurement_data_list' (mutates that list in place).
def update_measurements(measurement, index_measurement):
    measurements = st.session_state['measurement_data_list']
    measurements[index_measurement] = measurement

# Function: display_measurement
# Description: Renders a "Profile height (mm)" subheader and the profile chart for the current
#              image. Does nothing when no measurement is stored for it yet.
def display_measurement():
    found, profiles = get_current_measurement()
    if not found:
        return
    st.subheader('Profile height (mm)')
    measurement_to_streamlit_chart(profiles[0], profiles[1])

# Function: measurement_to_streamlit_chart
# Description: Stacks both profiles into one long-format DataFrame (x = column 1, y = column 0,
#              indicator = 'Profile A' / 'Profile B') and renders it as a Plotly line chart with
#              one coloured, symbol-marked line per profile.
def measurement_to_streamlit_chart(profile_array_1, profile_array_2):
    labelled_profiles = (('Profile A', profile_array_1), ('Profile B', profile_array_2))
    heights, coords, labels = [], [], []
    for label, profile in labelled_profiles:
        heights += profile[:, 0].tolist()
        coords += profile[:, 1].tolist()
        labels += [label] * len(profile)
    frame = pd.DataFrame({'x': coords, 'y': heights, 'indicator': labels})
    figure = px.line(frame, x='x', y='y', color='indicator', symbol="indicator")
    st.plotly_chart(figure, use_container_width=True)

# Initialization section
# Each session-state key is created only if it is missing, so the work below is done once per
# session. Order matters: later entries read earlier ones ('num_images' sizes the per-image lists,
# 'image_index_current' selects the image/measurement used for the 'current_*' entries).

if 'image_path_list' not in st.session_state:
    st.session_state['image_path_list'] = get_image_paths(IMG_BASE_DIR)

if 'num_images' not in st.session_state:
    st.session_state['num_images'] = get_num_images()

# With no images, get_current_image below (and the navigation callbacks) would raise IndexError.
# Report the problem in the UI and end this script run instead of crashing with a traceback.
if st.session_state['num_images'] == 0:
    st.error(f"No .png images found in '{IMG_BASE_DIR}' (relative to {os.getcwd()}).")
    st.stop()

if 'image_data_list' not in st.session_state:
    st.session_state['image_data_list'] = [None for _ in range(st.session_state['num_images'])]

if 'image_index_current' not in st.session_state:
    st.session_state['image_index_current'] = 0

if 'current_image_array' not in st.session_state:
    st.session_state['current_image_array'] = get_current_image()

if 'measurement_data_list' not in st.session_state:
    st.session_state['measurement_data_list'] = [None for _ in range(st.session_state['num_images'])]

if 'current_measurement' not in st.session_state:
    st.session_state['current_measurement'] = get_current_measurement()

# Networks are loaded once per session and kept under 'models' for measure_image.
if 'models' not in st.session_state:
    seg_dplv3_model, yolo_nn_model = prepare_networks_for_measurement(model_yolo_path=path_yolo_model,
                                                                      model_segmentation_path=path_segmentation_model)
    st.session_state['models'] = {'segmentation': seg_dplv3_model, 'detection': yolo_nn_model}

# Display section
    
image_emoji = '📷'
model_emoji = '⚙️'
profile_emoji = '📈'
st.set_page_config(layout='wide')
st.title('PantoScanner')

multi = '''This app processes the detection and segmentation of a 
pantograph sliding element from a train and demonstrates the 
extraction of its thickness by displaying it in a chart. A time 
series of thicknesses per sliding element can be built from these 
measurements, e.g. by selecting the thickness in the middle.'''
st.markdown(multi)

# Images are loaded with cv2.imread, which returns BGR channel order, while st.image assumes RGB
# by default; declaring channels='BGR' prevents red and blue from being swapped on screen.
st.image(st.session_state['current_image_array'], channels='BGR')

# Align the three button columns: left, centre, right.
st.markdown(
    """
    <style>
        div[data-testid="column"]:nth-of-type(1)
        {
            text-align: start;
        } 

        div[data-testid="column"]:nth-of-type(2)
        {
            text-align: center;
        } 
        
        div[data-testid="column"]:nth-of-type(3)
        {
            text-align: end;
        } 
    </style>
    """,unsafe_allow_html=True
)

col1, col2, col3 = st.columns(3)
# Previous/next navigation wraps around at both ends (overflow_index=True); their on_click
# callbacks update the session state before the next script run renders the new image.
with col1:
    button_previous = st.button("Previous image", on_click=callback_button_previous, kwargs={'overflow_index': True})

with col2:
    button_measure = st.button("Measure")

with col3:
    button_next = st.button("Next image", on_click=callback_button_next, kwargs={'overflow_index': True})

# "Measure" has no on_click callback; it is handled inline so its spinner and chart render in this run.
if button_measure:
    callback_button_measure()