lhofstetter commited on
Commit
3dd62d7
·
verified ·
1 Parent(s): f903374

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -42
app.py CHANGED
@@ -3,60 +3,221 @@ import os
3
  import glob
4
  import cv2
5
  import numpy as np
 
 
 
6
 
7
- image_emoji = '📷'
8
- model_emoji = '⚙️'
9
- profile_emoji = '📈'
10
- st.title('PantoScanner')
11
- st.header('Example')
12
- st.subheader('Thickness measurement of sliding element')
13
 
14
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- def generate_data(slope, intercept, num_points):
17
- """
18
- Generates data points with a linear degression and a +/- 5% tolerance.
19
 
20
- Args:
21
- slope: The slope of the linear degression.
22
- intercept: The y-intercept of the linear degression.
23
- num_points: The number of data points to generate.
 
 
 
 
 
24
 
25
- Returns:
26
- A numpy array of size (num_points, 1) containing the data points.
27
- """
28
- x = np.linspace(0, 1, num_points) # Creates evenly spaced x-values
29
- y = slope * x + intercept # Generates linear function values
30
 
31
- # Add random noise with +/- 5% tolerance
32
- noise = np.random.uniform(low=-0.07, high=0.01, size=num_points)
33
- y += noise * y # Scale noise by original y value for percentage variation
 
 
 
34
 
35
- return y.reshape(-1, 1) # Reshape to column vector
36
 
 
 
 
37
 
38
- tab1, tab2, tab3 = st.tabs([f' {image_emoji} Image', f' {model_emoji} Mask', f' {profile_emoji} Measurement'])
39
 
40
- with tab1:
41
- st.header(f'Source Image')
42
- img_array = cv2.imread(glob.glob(f'{os.getcwd()}/*.png')[0])
43
- st.image(img_array)
 
 
 
44
 
45
- with tab2:
46
- st.header(f'Model Output')
47
- #st.image("https://static.streamlit.io/examples/dog.jpg", width=200)
48
 
49
- with tab3:
50
- st.header(f'Profile Height')
51
- # data = np.random.randn(10, 1)
52
- # Example usage
53
- # Use 'data' for your chart with linear degression and +/- 5% tolerance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- slope = -2 # Example slope for linear degression
56
- intercept = 45
57
- num_points = 20
58
-
59
- data = generate_data(slope, intercept, num_points)
60
- st.line_chart(data)
61
 
 
 
62
 
 
 
 
3
  import glob
4
  import cv2
5
  import numpy as np
6
+ from strip_measure_4_0 import prepare_networks_for_measurement, measure_strip
7
+ import plotly.express as px
8
+ import pandas as pd
9
 
 
 
 
 
 
 
10
 
11
+ IMG_BASE_DIR = 'images'
12
+ CAMERA_MATRIX = [
13
+ [11100, 0, 1604],
14
+ [0, 11100, 1100],
15
+ [0, 0, 1]
16
+ ]
17
+
18
+ OBJECT_REFERENCE_POINTS = [
19
+ [347, 0, 42], # B
20
+ [347, 0, 522], # D
21
+ [-347, 26, 480], # F
22
+ [-347, 26, 0]] # H
23
+
24
+ LOWER_CONTOUR_QUADRATIC_CONSTANT = 0.00005
25
+ CAMERA_PARAMETERS = (50, 0.0045, 2200, 3208)
26
+ PLANE_PARAMETERS_CLOSE = ([0, 0, 0], (1, 0, 0), (0, 1, 0)) # Vector from pantograph coordinate frame to plane origin
27
+ PLANE_PARAMETERS_FAR = ([0, 0, 480], (1, 0, 0), (0, 1, 0)) # Vector from pantograph coordinate frame to plane origin
28
+ BOUNDARY_1 = (300, 92)
29
+ BOUNDARY_2 = (650, 1500)
30
+ IMAGE_SIZE_SEG = 1408
31
+ IMAGE_WIDTH_SEG = 1408
32
+ IMAGE_HEIGHT_SEG = 576
33
+
34
+ path_yolo_model = os.path.join(os.getcwd(), 'app', 'best.pt')
35
+ path_segmentation_model = os.path.join(os.getcwd(), 'app', '31_best_model.pth')
36
+
37
+
38
+ def get_image_paths(base_dir: str):
39
+ return glob.glob(f'{os.getcwd()}/{base_dir}/*.png')
40
+
41
+
42
+ def get_num_images():
43
+ return len(st.session_state['image_path_list'])
44
+
45
+
46
+ def increment_index(index_current: int, max_index: int, overflow=False, min_index=0):
47
+ index_new = index_current + 1
48
+ if index_new <= max_index:
49
+ return index_new
50
+ elif overflow:
51
+ return min_index
52
+ else:
53
+ return index_current
54
+
55
+
56
+ def decrement_index(index_current: int, min_index, overflow=False, max_index=-1):
57
+ index_new = index_current - 1
58
+ if index_new >= min_index:
59
+ return index_new
60
+ elif overflow:
61
+ return max_index
62
+ else:
63
+ return index_current
64
+
65
+
66
+ def callback_button_previous(overflow_index=True):
67
+ new_index = decrement_index(st.session_state['image_index_current'], min_index=0,
68
+ overflow=overflow_index, max_index=st.session_state['num_images']-1)
69
+ update_on_index_change(new_index)
70
+
71
+
72
+ def callback_button_next(overflow_index=True):
73
+ new_index = increment_index(st.session_state['image_index_current'], st.session_state['num_images'],
74
+ overflow=overflow_index, min_index=0)
75
+ update_on_index_change(new_index)
76
+
77
+
78
+ def update_on_index_change(new_index: int):
79
+ st.session_state['image_index_current'] = new_index
80
+ st.session_state['current_image_array'] = get_current_image()
81
+ # put the current bale boundaries into the list, regardless of whether they have been stored to the database
82
+ st.session_state['current_measurement'] = get_current_measurement()
83
+
84
+
85
+ def load_image_array(image_path: str):
86
+ return cv2.imread(image_path)
87
 
 
 
 
88
 
89
+ def get_current_image():
90
+ index_current = st.session_state['image_index_current']
91
+ this_img_current = st.session_state['image_data_list'][index_current]
92
+ if isinstance(this_img_current, np.ndarray):
93
+ return this_img_current
94
+ else:
95
+ this_img_current = load_image_array(st.session_state['image_path_list'][index_current])
96
+ st.session_state['image_data_list'][index_current] = this_img_current
97
+ return this_img_current
98
 
 
 
 
 
 
99
 
100
+ def callback_button_measure():
101
+ has_measurement, measurement_result = get_current_measurement()
102
+ if has_measurement:
103
+ display_cached_measurement_data()
104
+ else:
105
+ display_calculate_measurement_data()
106
 
 
107
 
108
+ def display_cached_measurement_data():
109
+ st.info('Getting cached measurement', icon="ℹ️")
110
+ display_measurement()
111
 
 
112
 
113
+ def display_calculate_measurement_data():
114
+ with st.spinner('Calculating Profile Height....'):
115
+ this_image_path = st.session_state['image_path_list'][st.session_state['image_index_current']]
116
+ measurement_result = measure_image(this_image_path)
117
+ update_measurements(measurement_result, st.session_state['image_index_current'])
118
+ st.success('Measurement is done !')
119
+ display_measurement()
120
 
 
 
 
121
 
122
+ def measure_image(image_path: str):
123
+ measurement_result = measure_strip(img_path=image_path,
124
+ model_yolo=st.session_state['models']['detection'],
125
+ segmentation_model=st.session_state['models']['segmentation'],
126
+ camera_matrix=CAMERA_MATRIX,
127
+ object_reference_points=OBJECT_REFERENCE_POINTS,
128
+ camera_parameters=CAMERA_PARAMETERS,
129
+ plane_parameters_close=PLANE_PARAMETERS_CLOSE,
130
+ plane_parameters_far=PLANE_PARAMETERS_FAR,
131
+ lower_contour_quadratic_constant=LOWER_CONTOUR_QUADRATIC_CONSTANT,
132
+ boundary_1=BOUNDARY_1,
133
+ boundary_2=BOUNDARY_2,
134
+ image_size_seg=IMAGE_SIZE_SEG,
135
+ image_width_seg=IMAGE_WIDTH_SEG,
136
+ image_height_seg=IMAGE_HEIGHT_SEG)
137
+ arr_0 = measurement_result[0]
138
+ arr_1 = measurement_result[1]
139
+ arr_0[:, 0] = np.abs(arr_0[:, 0])
140
+ arr_1[:, 0] = np.abs(arr_1[:, 0])
141
+ return arr_0, arr_1
142
+
143
+
144
+ def get_current_measurement():
145
+ this_measurement = st.session_state['measurement_data_list'][st.session_state['image_index_current']]
146
+ if this_measurement is not None:
147
+ return True, this_measurement
148
+ else:
149
+ return False, None
150
+
151
+
152
+ def update_measurements(measurement, index_measurement):
153
+ st.session_state['measurement_data_list'][index_measurement] = measurement
154
+
155
+
156
+ def display_measurement():
157
+ has_measurement, measurement_data = get_current_measurement()
158
+ if has_measurement:
159
+ st.subheader(f'Profile Height')
160
+ measurement_to_streamlit_chart(measurement_data[0], measurement_data[1])
161
+
162
+
163
+ def measurement_to_streamlit_chart(profile_array_1, profile_array_2):
164
+ height_list = []
165
+ coord_list = []
166
+ indicator_list = []
167
+ height_list.extend(profile_array_1[:, 0].tolist())
168
+ coord_list.extend(profile_array_1[:, 1].tolist())
169
+ indicator_list.extend(['Profile A' for _ in range(len(profile_array_1))])
170
+ height_list.extend(profile_array_2[:, 0].tolist())
171
+ coord_list.extend(profile_array_2[:, 1].tolist())
172
+ indicator_list.extend(['Profile B' for _ in range(len(profile_array_2))])
173
+ df = pd.DataFrame(dict(x=coord_list, y=height_list, indicator=indicator_list))
174
+ fig = px.line(df, x='x', y='y', color='indicator', symbol="indicator")
175
+ st.plotly_chart(fig, use_container_width=True)
176
+
177
+
178
+ if 'image_path_list' not in st.session_state:
179
+ st.session_state['image_path_list'] = get_image_paths(IMG_BASE_DIR)
180
+
181
+ if 'num_images' not in st.session_state:
182
+ st.session_state['num_images'] = get_num_images()
183
+
184
+ if 'image_data_list' not in st.session_state:
185
+ st.session_state['image_data_list'] = [None for _ in range(st.session_state['num_images'])]
186
+
187
+ if 'image_index_current' not in st.session_state:
188
+ st.session_state['image_index_current'] = 0
189
+
190
+ if 'current_image_array' not in st.session_state:
191
+ st.session_state['current_image_array'] = get_current_image()
192
+
193
+ if 'measurement_data_list' not in st.session_state:
194
+ st.session_state['measurement_data_list'] = [None for _ in range(st.session_state['num_images'])]
195
+
196
+ if 'current_measurement' not in st.session_state:
197
+ st.session_state['current_measurement'] = get_current_measurement()
198
+
199
+ if 'models' not in st.session_state:
200
+ seg_dplv3_model, yolo_nn_model = prepare_networks_for_measurement(model_yolo_path=path_yolo_model,
201
+ model_segmentation_path=path_segmentation_model)
202
+ st.session_state['models'] = {'segmentation': seg_dplv3_model, 'detection': yolo_nn_model}
203
+
204
+
205
+ image_emoji = '📷'
206
+ model_emoji = '⚙️'
207
+ profile_emoji = '📈'
208
+ st.title('PantoScanner')
209
+ #st.subheader(f'Source Image')
210
+ st.image(st.session_state['current_image_array'])
211
+ col1, col2, col3 = st.columns(3)
212
+ # insert prev button --> decrement image_selected_index i = min(i -= 1, 0) % or just overflow to last image
213
+ with col1:
214
+ button_previous = st.button("previous Image", on_click=callback_button_previous, kwargs={'overflow_index': True})
215
 
216
+ with col2:
217
+ button_measure = st.button("Measure")
 
 
 
 
218
 
219
+ with col3:
220
+ button_next = st.button("next Image", on_click=callback_button_next, kwargs={'overflow_index': True})
221
 
222
+ if button_measure:
223
+ callback_button_measure()