# -----------------------------------------------------------------------------
#
# This file is part of the PantoScanner distribution on:
# https://huggingface.co/spaces/swissrail/PantoScanner
#
# PantoScanner - Analytics and measurement capability for technical objects.
# Copyright (C) 2017-2024 Schweizerische Bundesbahnen SBB
#
# Authors (C) 2024 L. Hofstetter (lukas.hofstetter@sbb.ch)
# Authors (C) 2017 U. Gehrig (urs.gehrig@sbb.ch)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# -----------------------------------------------------------------------------
import streamlit as st
import os
import glob
import cv2
import numpy as np
from strip_measure_4_0 import prepare_networks_for_measurement, measure_strip
import plotly.express as px
import pandas as pd

# Directory (relative to the working directory) that is scanned for PNG images.
IMG_BASE_DIR = 'images'

# Settings below are forwarded unchanged to measure_strip() by measure_image().
CAMERA_MATRIX = [
    [11100, 0, 1604],
    [0, 11100, 1100],
    [0, 0, 1]
]
OBJECT_REFERENCE_POINTS = [
    [347, 0, 42],  # B
    [347, 0, 522],  # D
    [-347, 26, 480],  # F
    [-347, 26, 0]]  # H
LOWER_CONTOUR_QUADRATIC_CONSTANT = 0.00005
CAMERA_PARAMETERS = (50, 0.0045, 2200, 3208)
PLANE_PARAMETERS_CLOSE = ([0, 0, 0], (1, 0, 0), (0, 1, 0))  # Vector from pantograph coordinate frame to plane origin
PLANE_PARAMETERS_FAR = ([0, 0, 480], (1, 0, 0), (0, 1, 0))  # Vector from pantograph coordinate frame to plane origin
BOUNDARY_1 = (300, 92)
BOUNDARY_2 = (650, 1500)
IMAGE_SIZE_SEG = 1408
IMAGE_WIDTH_SEG = 1408
IMAGE_HEIGHT_SEG = 576

# Model weight files are looked up in the 'app' folder of the working directory.
path_yolo_model = os.path.join(os.getcwd(), 'app', 'detection_model.pt')
path_segmentation_model = os.path.join(os.getcwd(), 'app', 'segmentation_model.pth')


# Function: get_image_paths
# Description: Returns a list of image paths in the specified directory.
def get_image_paths(base_dir: str):
    """Return the sorted paths of all ``*.png`` files in ``base_dir``.

    Args:
        base_dir: Directory to scan, relative to the current working directory.

    Returns:
        A list of matching paths in ascending order; empty if nothing matches.
    """
    # The working directory is not chosen by the caller, so escape it: glob
    # metacharacters such as '[' in that path would otherwise be interpreted
    # as pattern syntax and silently match nothing.
    pattern = os.path.join(glob.escape(os.getcwd()), base_dir, '*.png')
    # glob yields entries in arbitrary order; sort so that the navigation
    # order of the images is the same on every run.
    return sorted(glob.glob(pattern))


# Function: get_num_images
# Description: Returns the number of images in the current session state.
def get_num_images():
    """Return the length of the ``image_path_list`` stored in the session state."""
    return len(st.session_state['image_path_list'])


# Function: increment_index
# Description: Increments the current image index by 1, taking into account the maximum index and optional overflow behavior.
def increment_index(index_current: int, max_index: int, overflow=False, min_index=0):
    """Return the index one step after ``index_current``.

    Args:
        index_current: Index to step forward from.
        max_index: Largest index (inclusive) a plain increment may return.
        overflow: If True, stepping past ``max_index`` wraps to ``min_index``.
        min_index: Index returned when wrapping around.

    Returns:
        ``index_current + 1`` if that does not exceed ``max_index``; otherwise
        ``min_index`` when ``overflow`` is set, else ``index_current`` unchanged.
    """
    index_new = index_current + 1
    if index_new <= max_index:
        return index_new
    if overflow:
        return min_index
    return index_current


# Function: decrement_index
# Description: Decrements the current image index by 1, taking into account the minimum index and optional overflow behavior.
def decrement_index(index_current: int, min_index, overflow=False, max_index=-1):
    """Return the index one step before ``index_current``.

    Args:
        index_current: Index to step back from.
        min_index: Smallest index (inclusive) a plain decrement may return.
        overflow: If True, stepping below ``min_index`` wraps to ``max_index``.
        max_index: Index returned when wrapping around.

    Returns:
        ``index_current - 1`` if that is at least ``min_index``; otherwise
        ``max_index`` when ``overflow`` is set, else ``index_current`` unchanged.
    """
    index_new = index_current - 1
    if index_new >= min_index:
        return index_new
    if overflow:
        return max_index
    return index_current


# Function: callback_button_previous
# Description: Callback function for the "previous Image" button. Updates the current image index and calls update_on_index_change.
def callback_button_previous(overflow_index=True):
    """Step to the previous image.

    Args:
        overflow_index: If True, stepping back from the first image selects
            the last one; otherwise the first image stays selected.
    """
    last_index = st.session_state['num_images'] - 1
    new_index = decrement_index(st.session_state['image_index_current'], min_index=0,
                                overflow=overflow_index, max_index=last_index)
    update_on_index_change(new_index)


# Function: callback_button_next
# Description: Callback function for the "next Image" button. Updates the current image index and calls update_on_index_change.
def callback_button_next(overflow_index=True):
    """Step to the next image.

    Args:
        overflow_index: If True, stepping forward from the last image selects
            the first one; otherwise the last image stays selected.
    """
    # The upper bound handed to increment_index is inclusive, so it has to be
    # the last valid index (num_images - 1). Passing num_images itself would let
    # the index run one past the end of the per-image lists.
    last_index = st.session_state['num_images'] - 1
    new_index = increment_index(st.session_state['image_index_current'], last_index,
                                overflow=overflow_index, min_index=0)
    update_on_index_change(new_index)


# Function: update_on_index_change
# Description: Updates the session state variables related to the current image index and the current image array. Calls get_current_image and get_current_measurement.
def update_on_index_change(new_index: int):
    """Select ``new_index`` and refresh the session entries derived from it.

    Args:
        new_index: Index of the image that becomes the current one.
    """
    # The index must be stored first: both getters below read it from the
    # session state.
    st.session_state['image_index_current'] = new_index
    st.session_state['current_image_array'] = get_current_image()
    # Stores the (has_measurement, measurement) pair returned by the getter.
    st.session_state['current_measurement'] = get_current_measurement()


# Function: load_image_array
# Description: Loads an image array from the specified image path.
def load_image_array(image_path: str):
    """Read the image at ``image_path`` with ``cv2.imread`` and return the result."""
    return cv2.imread(image_path)


# Function: get_current_image
# Description: Returns the current image array based on the current image index. If the image array is not already loaded, it loads it using load_image_array.
def get_current_image():
    """Return the image array for the current image index, loading it on demand.

    The per-image cache ``image_data_list`` in the session state is consulted
    first. If the cached entry is not a NumPy array, the image is read via
    load_image_array and the result is written back into the cache.

    Returns:
        The cached or freshly loaded value for the current index.
    """
    index_current = st.session_state['image_index_current']
    cached_image = st.session_state['image_data_list'][index_current]
    if isinstance(cached_image, np.ndarray):
        return cached_image
    loaded_image = load_image_array(st.session_state['image_path_list'][index_current])
    st.session_state['image_data_list'][index_current] = loaded_image
    return loaded_image


# Function: callback_button_measure
# Description: Callback function for the "Measure" button. Calls either display_cached_measurement_data or display_calculate_measurement_data based on whether the current image has a cached measurement.
def callback_button_measure():
    """Show the measurement of the current image, computing it if not cached."""
    # Only the flag is needed here; the display functions fetch the data.
    has_measurement, _ = get_current_measurement()
    if has_measurement:
        display_cached_measurement_data()
    else:
        display_calculate_measurement_data()


# Function: display_cached_measurement_data
# Description: Displays the cached measurement data for the current image.
def display_cached_measurement_data():
    """Announce that a cached result is used and display it."""
    # The icon has to be a real emoji character; the previous literal was a
    # mis-encoded byte sequence of the "information" emoji.
    st.info('Getting cached measurement', icon="ℹ️")
    display_measurement()


# Function: display_calculate_measurement_data
# Description: Calculates the measurement data for the current image and updates the session state. Displays the measurement data.
def display_calculate_measurement_data():
    """Measure the current image, cache the result and display it."""
    # Read the index once so the measured path and the cache slot that
    # receives the result are guaranteed to belong to the same image.
    index_current = st.session_state['image_index_current']
    with st.spinner('Calculating Profile Height....'):
        this_image_path = st.session_state['image_path_list'][index_current]
        measurement_result = measure_image(this_image_path)
        update_measurements(measurement_result, index_current)
    st.success('Measurement is done !')
    display_measurement()


# Function: measure_image
# Description: Calls the measure_strip function from the strip_measure_4_0 module to measure the strip in the current image. Returns the measurement result.
def measure_image(image_path: str):
    """Run measure_strip on one image and return its first two result arrays.

    The detection and segmentation models are taken from the session state;
    every other setting comes from the module-level constants.

    Args:
        image_path: Path of the image to measure.

    Returns:
        A tuple of the first two entries of the measure_strip result, each
        with column 0 replaced by its absolute values (modified in place).
    """
    models = st.session_state['models']
    profiles = measure_strip(
        img_path=image_path,
        model_yolo=models['detection'],
        segmentation_model=models['segmentation'],
        camera_matrix=CAMERA_MATRIX,
        object_reference_points=OBJECT_REFERENCE_POINTS,
        camera_parameters=CAMERA_PARAMETERS,
        plane_parameters_close=PLANE_PARAMETERS_CLOSE,
        plane_parameters_far=PLANE_PARAMETERS_FAR,
        lower_contour_quadratic_constant=LOWER_CONTOUR_QUADRATIC_CONSTANT,
        boundary_1=BOUNDARY_1,
        boundary_2=BOUNDARY_2,
        image_size_seg=IMAGE_SIZE_SEG,
        image_width_seg=IMAGE_WIDTH_SEG,
        image_height_seg=IMAGE_HEIGHT_SEG,
    )
    profile_a, profile_b = profiles[0], profiles[1]
    # Column 0 is overwritten with its magnitude, first for A and then for B.
    for profile in (profile_a, profile_b):
        profile[:, 0] = np.abs(profile[:, 0])
    return profile_a, profile_b


# Function: get_current_measurement
# Description: Returns the current measurement data based on the current image index. If the measurement data is not available, returns False and None.
def get_current_measurement():
    """Look up the cached measurement of the current image.

    Returns:
        ``(True, measurement)`` if an entry is cached for the current index,
        ``(False, None)`` otherwise.
    """
    all_measurements = st.session_state['measurement_data_list']
    cached = all_measurements[st.session_state['image_index_current']]
    # A slot holding None means "not measured yet".
    return cached is not None, cached


# Function: update_measurements
# Description: Updates the measurement data for the specified index in the session state.
def update_measurements(measurement, index_measurement):
    """Store ``measurement`` in slot ``index_measurement`` of the session cache."""
    measurement_slots = st.session_state['measurement_data_list']
    measurement_slots[index_measurement] = measurement


# Function: display_measurement
# Description: Displays the measurement data for the current image.
def display_measurement():
    """Render the chart for the current image; do nothing without a measurement."""
    has_measurement, measurement_data = get_current_measurement()
    if not has_measurement:
        return
    st.subheader('Profile height (mm)')
    measurement_to_streamlit_chart(measurement_data[0], measurement_data[1])


# Function: measurement_to_streamlit_chart
# Description: Converts the measurement data into a Pandas DataFrame and plots a line chart using Plotly.
def measurement_to_streamlit_chart(profile_array_1, profile_array_2):
    """Plot two profiles as one Plotly line chart inside the Streamlit page.

    For each array, column 0 supplies the y values and column 1 the x values.
    The first array is labelled 'Profile A' and the second 'Profile B'.

    Args:
        profile_array_1: 2-D array for the first profile.
        profile_array_2: 2-D array for the second profile.
    """
    height_list = []
    coord_list = []
    indicator_list = []
    for label, profile in (('Profile A', profile_array_1), ('Profile B', profile_array_2)):
        height_list.extend(profile[:, 0].tolist())
        coord_list.extend(profile[:, 1].tolist())
        indicator_list.extend([label] * len(profile))
    df = pd.DataFrame(dict(x=coord_list, y=height_list, indicator=indicator_list))
    fig = px.line(df, x='x', y='y', color='indicator', symbol="indicator")
    st.plotly_chart(fig, use_container_width=True)


# Page configuration
# Issued before any other Streamlit call so that nothing in the initialization
# below (e.g. the error shown for an empty image directory) can precede it.
st.set_page_config(layout='wide')

# Initialization section
if 'image_path_list' not in st.session_state:
    st.session_state['image_path_list'] = get_image_paths(IMG_BASE_DIR)
if 'num_images' not in st.session_state:
    st.session_state['num_images'] = get_num_images()
if st.session_state['num_images'] == 0:
    # Without images every per-image list is empty and indexing it would fail.
    # Drop the cached listing so the next rerun scans the directory again,
    # then end this run with a readable message.
    del st.session_state['image_path_list']
    del st.session_state['num_images']
    st.error(f'No PNG images found in "{os.path.join(os.getcwd(), IMG_BASE_DIR)}".')
    st.stop()
if 'image_data_list' not in st.session_state:
    st.session_state['image_data_list'] = [None for _ in range(st.session_state['num_images'])]
if 'image_index_current' not in st.session_state:
    st.session_state['image_index_current'] = 0
if 'current_image_array' not in st.session_state:
    st.session_state['current_image_array'] = get_current_image()
if 'measurement_data_list' not in st.session_state:
    st.session_state['measurement_data_list'] = [None for _ in range(st.session_state['num_images'])]
if 'current_measurement' not in st.session_state:
    st.session_state['current_measurement'] = get_current_measurement()
if 'models' not in st.session_state:
    # Loaded once per session and reused by every measurement.
    seg_dplv3_model, yolo_nn_model = prepare_networks_for_measurement(model_yolo_path=path_yolo_model,
                                                                      model_segmentation_path=path_segmentation_model)
    st.session_state['models'] = {'segmentation': seg_dplv3_model, 'detection': yolo_nn_model}

# Display section
image_emoji = '📷'
model_emoji = '⚙️'
profile_emoji = '📈'

st.title('PantoScanner')
#st.subheader(f'Source Image')
multi = '''This app processes the
detection and segmentation of a pantograph sliding element from a train - and - demonstrates the extraction of the thickness by displaying it in a chart. To build a time series of thicknesses per sliding element, e.g. by selecting the thickness in the middle.'''
st.markdown(multi)

# The image array is read with OpenCV, so its colour channels are in BGR order.
st.image(st.session_state['current_image_array'], channels='BGR')

# NOTE(review): the markup passed here is blank in this revision - presumably
# it once carried inline HTML/CSS; confirm against the deployed app.
st.markdown(
    """ """, unsafe_allow_html=True
)

col1, col2, col3 = st.columns(3)
# Navigation buttons; with overflow_index=True the selection wraps around at
# the first and last image.
with col1:
    button_previous = st.button("Previous image", on_click=callback_button_previous,
                                kwargs={'overflow_index': True})
with col2:
    button_measure = st.button("Measure")
with col3:
    button_next = st.button("Next image", on_click=callback_button_next,
                            kwargs={'overflow_index': True})

# The measurement is rendered below the button row, hence no on_click callback.
if button_measure:
    callback_button_measure()