PantoScanner / app.py
ursgehrig's picture
Update app.py
3e942ca verified
# -----------------------------------------------------------------------------
#
# This file is part of the PantoScanner distribution on:
# https://huggingface.co/spaces/swissrail/PantoScanner
#
# PantoScanner - Analytics and measurement capability for technical objects.
# Copyright (C) 2017-2024 Schweizerische Bundesbahnen SBB
#
# Authors (C) 2024 L. Hofstetter (lukas.hofstetter@sbb.ch)
# Authors (C) 2017 U. Gehrig (urs.gehrig@sbb.ch)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# -----------------------------------------------------------------------------
import streamlit as st
import os
import glob
import cv2
import numpy as np
from strip_measure_4_0 import prepare_networks_for_measurement, measure_strip
import plotly.express as px
import pandas as pd
# Folder, relative to the current working directory, that get_image_paths scans for *.png images.
IMG_BASE_DIR = 'images'
# 3x3 camera matrix passed to measure_strip. Its layout [[f, 0, cx], [0, f, cy], [0, 0, 1]]
# looks like pinhole intrinsics expressed in pixels.
# NOTE(review): cx, cy = 1604, 1100 are exactly half of 3208 and 2200 in CAMERA_PARAMETERS, and
# f = 11100 is close to 50 / 0.0045 (~11111); presumably derived from those values — confirm.
CAMERA_MATRIX = [
[11100, 0, 1604],
[0, 11100, 1100],
[0, 0, 1]
]
# 3D coordinates of four labelled reference points (B, D, F, H), passed to measure_strip as
# object_reference_points. Presumably millimetres in the pantograph coordinate frame — TODO confirm.
OBJECT_REFERENCE_POINTS = [
[347, 0, 42], # B
[347, 0, 522], # D
[-347, 26, 480], # F
[-347, 26, 0]] # H
# Passed to measure_strip as lower_contour_quadratic_constant; its exact role is defined there.
LOWER_CONTOUR_QUADRATIC_CONSTANT = 0.00005
# NOTE(review): presumably (focal length [mm], pixel pitch [mm], image height [px], image width [px]).
# This is consistent with CAMERA_MATRIX above, but confirm against strip_measure_4_0.
CAMERA_PARAMETERS = (50, 0.0045, 2200, 3208)
# Measurement planes given as (origin, direction_1, direction_2); the two direction vectors
# presumably span the plane. The far plane's origin is offset by 480 along the third axis.
PLANE_PARAMETERS_CLOSE = ([0, 0, 0], (1, 0, 0), (0, 1, 0)) # Vector from pantograph coordinate frame to plane origin
PLANE_PARAMETERS_FAR = ([0, 0, 480], (1, 0, 0), (0, 1, 0)) # Vector from pantograph coordinate frame to plane origin
# Passed to measure_strip as boundary_1 / boundary_2; their semantics are defined there
# (presumably pixel limits of a region of interest — TODO confirm).
BOUNDARY_1 = (300, 92)
BOUNDARY_2 = (650, 1500)
# Image dimensions used for segmentation (per their names), passed to measure_strip.
# NOTE(review): IMAGE_SIZE_SEG has the same value as IMAGE_WIDTH_SEG; check whether both are needed.
IMAGE_SIZE_SEG = 1408
IMAGE_WIDTH_SEG = 1408
IMAGE_HEIGHT_SEG = 576
# Model weight files are resolved against the current working directory (<cwd>/app/...),
# so the app has to be launched from the directory that contains app/.
path_yolo_model = os.path.join(os.getcwd(), 'app', 'detection_model.pt')
path_segmentation_model = os.path.join(os.getcwd(), 'app', 'segmentation_model.pth')
# Function: get_image_paths
# Description: Returns a sorted list of the PNG image paths in the specified directory.
def get_image_paths(base_dir: str):
    """Return the paths of all ``*.png`` files directly inside *base_dir*.

    *base_dir* is resolved relative to the current working directory; an
    absolute path is used as-is. The result is sorted because ``glob``
    returns matches in arbitrary order. Sorting keeps the image order, and
    therefore the Previous/Next navigation, stable across runs.
    """
    # Escape the directory part so glob metacharacters in it ('[', '*', '?')
    # are matched literally; only the '*.png' suffix is a pattern.
    search_dir = glob.escape(os.path.join(os.getcwd(), base_dir))
    return sorted(glob.glob(os.path.join(search_dir, '*.png')))
# Function: get_num_images
# Description: Returns the number of images in the current session state.
def get_num_images():
    """Count the image paths stored in ``st.session_state['image_path_list']``."""
    image_paths = st.session_state['image_path_list']
    return len(image_paths)
# Function: increment_index
# Description: Increments the current image index by 1, taking into account the maximum index and optional overflow behavior.
def increment_index(index_current: int, max_index: int, overflow=False, min_index=0):
    """Return the index that follows *index_current*.

    *max_index* is an inclusive upper bound. When the next index would
    exceed it, the result wraps to *min_index* if *overflow* is true;
    otherwise *index_current* is returned unchanged.
    """
    candidate = index_current + 1
    if candidate > max_index:
        return min_index if overflow else index_current
    return candidate
# Function: decrement_index
# Description: Decrements the current image index by 1, taking into account the minimum index and optional overflow behavior.
def decrement_index(index_current: int, min_index, overflow=False, max_index=-1):
    """Return the index that precedes *index_current*.

    *min_index* is an inclusive lower bound. When the previous index would
    fall below it, the result wraps to *max_index* if *overflow* is true;
    otherwise *index_current* is returned unchanged.
    """
    candidate = index_current - 1
    if candidate < min_index:
        return max_index if overflow else index_current
    return candidate
# Function: callback_button_previous
# Description: Callback function for the "previous Image" button. Updates the current image index and calls update_on_index_change.
def callback_button_previous(overflow_index=True):
    """Step back one image; with *overflow_index* the first image wraps to the last one."""
    state = st.session_state
    last_index = state['num_images'] - 1
    update_on_index_change(
        decrement_index(state['image_index_current'], min_index=0,
                        overflow=overflow_index, max_index=last_index))
# Function: callback_button_next
# Description: Callback function for the "next Image" button. Updates the current image index and calls update_on_index_change.
def callback_button_next(overflow_index=True):
    """Advance one image; with *overflow_index* the last image wraps to the first one.

    increment_index treats its max_index argument as an inclusive bound, so
    the highest valid list position, num_images - 1, must be passed.
    Passing num_images would let the index step one past the end of the
    image list, which makes get_current_image raise IndexError instead of
    wrapping around.
    """
    last_index = st.session_state['num_images'] - 1
    new_index = increment_index(st.session_state['image_index_current'], last_index,
                                overflow=overflow_index, min_index=0)
    update_on_index_change(new_index)
# Function: update_on_index_change
# Description: Updates the session state variables related to the current image index and the current image array. Calls get_current_image and get_current_measurement.
def update_on_index_change(new_index: int):
    """Make *new_index* the current image and refresh the state derived from it.

    The index is stored first because get_current_image and
    get_current_measurement both read ``st.session_state['image_index_current']``.
    """
    state = st.session_state
    state['image_index_current'] = new_index
    # Image array of the newly selected image (loaded and cached on first access).
    state['current_image_array'] = get_current_image()
    # (has_measurement, measurement) tuple for the newly selected image;
    # the measurement is None if this image has not been measured yet.
    state['current_measurement'] = get_current_measurement()
# Function: load_image_array
# Description: Loads an image array from the specified image path and fails loudly if it cannot be read.
def load_image_array(image_path: str):
    """Read the image at *image_path* with OpenCV.

    Returns the array produced by ``cv2.imread`` in OpenCV's default colour
    mode, i.e. with BGR channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising. That
            ``None`` would otherwise be cached and passed on as if it were an
            image, producing an unrelated error later.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    return image
# Function: get_current_image
# Description: Returns the current image array based on the current image index. If the image array is not already loaded, it loads it using load_image_array.
def get_current_image():
    """Return the array of the current image, loading and caching it on first use."""
    position = st.session_state['image_index_current']
    cache = st.session_state['image_data_list']
    cached = cache[position]
    if not isinstance(cached, np.ndarray):
        # Not loaded yet (slot still holds a placeholder): read it from disk and remember it.
        cached = load_image_array(st.session_state['image_path_list'][position])
        cache[position] = cached
    return cached
# Function: callback_button_measure
# Description: Callback function for the "Measure" button. Calls either display_cached_measurement_data or display_calculate_measurement_data based on whether the current image has a cached measurement.
def callback_button_measure():
    """Show the stored measurement for the current image, computing it first if needed."""
    has_measurement, _ = get_current_measurement()
    handler = display_cached_measurement_data if has_measurement else display_calculate_measurement_data
    handler()
# Function: display_cached_measurement_data
# Description: Displays the cached measurement data for the current image.
def display_cached_measurement_data():
    """Show an info banner, then plot the measurement already stored for the current image."""
    notice, notice_icon = 'Getting cached measurement', "ℹ️"
    st.info(notice, icon=notice_icon)
    display_measurement()
# Function: display_calculate_measurement_data
# Description: Calculates the measurement data for the current image and updates the session state. Displays the measurement data.
def display_calculate_measurement_data():
    """Measure the current image (with a spinner), store the result and plot it."""
    state = st.session_state
    with st.spinner('Calculating Profile Height....'):
        current_index = state['image_index_current']
        result = measure_image(state['image_path_list'][current_index])
        update_measurements(result, state['image_index_current'])
    st.success('Measurement is done !')
    display_measurement()
# Function: measure_image
# Description: Calls the measure_strip function from the strip_measure_4_0 module to measure the strip in the current image. Returns the measurement result.
def measure_image(image_path: str):
    """Run measure_strip on *image_path* with the session's networks and the module constants.

    Returns a tuple of the first two arrays produced by measure_strip. Column 0
    of each array is replaced, in place, by its absolute value.
    """
    models = st.session_state['models']
    raw_result = measure_strip(img_path=image_path,
                               model_yolo=models['detection'],
                               segmentation_model=models['segmentation'],
                               camera_matrix=CAMERA_MATRIX,
                               object_reference_points=OBJECT_REFERENCE_POINTS,
                               camera_parameters=CAMERA_PARAMETERS,
                               plane_parameters_close=PLANE_PARAMETERS_CLOSE,
                               plane_parameters_far=PLANE_PARAMETERS_FAR,
                               lower_contour_quadratic_constant=LOWER_CONTOUR_QUADRATIC_CONSTANT,
                               boundary_1=BOUNDARY_1,
                               boundary_2=BOUNDARY_2,
                               image_size_seg=IMAGE_SIZE_SEG,
                               image_width_seg=IMAGE_WIDTH_SEG,
                               image_height_seg=IMAGE_HEIGHT_SEG)
    profile_a, profile_b = raw_result[0], raw_result[1]
    for profile in (profile_a, profile_b):
        # Column 0 is later plotted as the profile height; keep only its magnitude.
        profile[:, 0] = np.abs(profile[:, 0])
    return profile_a, profile_b
# Function: get_current_measurement
# Description: Returns the current measurement data based on the current image index. If the measurement data is not available, returns False and None.
def get_current_measurement():
    """Return ``(True, measurement)`` for the current image, or ``(False, None)`` if unmeasured."""
    measurements = st.session_state['measurement_data_list']
    stored = measurements[st.session_state['image_index_current']]
    if stored is None:
        return False, None
    return True, stored
# Function: update_measurements
# Description: Updates the measurement data for the specified index in the session state.
def update_measurements(measurement, index_measurement):
    """Store *measurement* in the session's measurement list at position *index_measurement*."""
    measurement_store = st.session_state['measurement_data_list']
    measurement_store[index_measurement] = measurement
# Function: display_measurement
# Description: Displays the measurement data for the current image.
def display_measurement():
    """Plot the stored profiles of the current image; render nothing if it is unmeasured."""
    has_measurement, measurement_data = get_current_measurement()
    if not has_measurement:
        return
    st.subheader('Profile height (mm)')
    profile_a, profile_b = measurement_data[0], measurement_data[1]
    measurement_to_streamlit_chart(profile_a, profile_b)
# Function: measurement_to_streamlit_chart
# Description: Converts the measurement data into a Pandas DataFrame and plots a line chart using Plotly.
def measurement_to_streamlit_chart(profile_array_1, profile_array_2):
    """Plot two profiles as one Plotly line chart in Streamlit.

    For each array, column 1 is used as the x coordinate and column 0 as the
    height (y). The first array is labelled 'Profile A', the second 'Profile B'.
    """
    labelled_profiles = (('Profile A', profile_array_1), ('Profile B', profile_array_2))
    heights, coords, labels = [], [], []
    for label, profile in labelled_profiles:
        heights += profile[:, 0].tolist()
        coords += profile[:, 1].tolist()
        labels += [label] * len(profile)
    frame = pd.DataFrame({'x': coords, 'y': heights, 'indicator': labels})
    figure = px.line(frame, x='x', y='y', color='indicator', symbol="indicator")
    st.plotly_chart(figure, use_container_width=True)
# Initialization section
# Streamlit re-executes this script on every interaction. Each session-state key
# is only populated while it is missing, so this work runs once per session.
# Order matters: later entries read earlier ones (e.g. num_images, the image index).
if 'image_path_list' not in st.session_state:
    st.session_state['image_path_list'] = get_image_paths(IMG_BASE_DIR)
if 'num_images' not in st.session_state:
    st.session_state['num_images'] = get_num_images()
if st.session_state['num_images'] == 0:
    # With no images, get_current_image below (and every navigation callback)
    # would fail with IndexError; report the cause in the UI and end this run.
    st.error(f"No PNG images found in '{os.path.join(os.getcwd(), IMG_BASE_DIR)}'.")
    st.stop()
if 'image_data_list' not in st.session_state:
    # One slot per image, filled lazily by get_current_image.
    st.session_state['image_data_list'] = [None for _ in range(st.session_state['num_images'])]
if 'image_index_current' not in st.session_state:
    st.session_state['image_index_current'] = 0
if 'current_image_array' not in st.session_state:
    st.session_state['current_image_array'] = get_current_image()
if 'measurement_data_list' not in st.session_state:
    # One slot per image; None means "not measured yet".
    st.session_state['measurement_data_list'] = [None for _ in range(st.session_state['num_images'])]
if 'current_measurement' not in st.session_state:
    st.session_state['current_measurement'] = get_current_measurement()
if 'models' not in st.session_state:
    # Networks are loaded once per session and reused by measure_image.
    seg_dplv3_model, yolo_nn_model = prepare_networks_for_measurement(model_yolo_path=path_yolo_model,
                                                                      model_segmentation_path=path_segmentation_model)
    st.session_state['models'] = {'segmentation': seg_dplv3_model, 'detection': yolo_nn_model}
# Display section
image_emoji = '📷'
model_emoji = '⚙️'
profile_emoji = '📈'
st.set_page_config(layout='wide')
st.title('PantoScanner')
multi = '''This app processes the detection and segmentation of a
pantograph sliding element from a train and demonstrates the
extraction of its thickness by displaying it in a chart. A time
series of thicknesses per sliding element can be built from this,
e.g. by selecting the thickness in the middle.'''
st.markdown(multi)
# The array comes from cv2.imread, which uses BGR channel order, while st.image
# assumes RGB by default; without channels='BGR' red and blue appear swapped.
st.image(st.session_state['current_image_array'], channels='BGR')
# Align the three button columns: left, centre and right.
st.markdown(
    """
<style>
div[data-testid="column"]:nth-of-type(1)
{
text-align: start;
}
div[data-testid="column"]:nth-of-type(2)
{
text-align: center;
}
div[data-testid="column"]:nth-of-type(3)
{
text-align: end;
}
</style>
""", unsafe_allow_html=True
)
col1, col2, col3 = st.columns(3)
# Navigation buttons: overflow_index=True requests wrap-around at the first and
# last image instead of stopping there.
with col1:
    button_previous = st.button("Previous image", on_click=callback_button_previous, kwargs={'overflow_index': True})
with col2:
    button_measure = st.button("Measure")
with col3:
    button_next = st.button("Next image", on_click=callback_button_next, kwargs={'overflow_index': True})
# st.button returns True only on the rerun triggered by its click.
if button_measure:
    callback_button_measure()