Tomatillo committed on
Commit
341fabf
·
verified ·
1 Parent(s): 44eaf1f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +121 -67
src/streamlit_app.py CHANGED
@@ -1,8 +1,12 @@
1
  import streamlit as st
2
  import io
3
  import csv
4
- from datetime import datetime
5
  from segments import SegmentsClient
 
 
 
 
6
  from get_labels_from_samples import (
7
  get_samples as get_samples_objects,
8
  export_frames_and_annotations,
@@ -44,6 +48,95 @@ def parse_classes(input_str: str) -> list:
44
  return sorted(set(classes))
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def generate_csv(metrics: list, dataset_identifier: str) -> str:
48
  """
49
  Generate CSV content from list of per-sample metrics.
@@ -102,6 +195,9 @@ if api_key and dataset_identifier:
102
  if is_multisensor:
103
  sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)
104
 
 
 
 
105
  if run_button:
106
  st.session_state.csv_content = None
107
  st.session_state.error = None
@@ -122,75 +218,33 @@ if run_button:
122
  st.info("Checking dataset type...")
123
  try:
124
  target_classes = parse_classes(classes_input)
125
- client = init_client(api_key)
126
  metrics = []
127
  # Update loader after dataset type check
128
  if status_ctx is not None:
129
  status_ctx.update(label="Dataset type checked. Processing samples...", state="running")
130
- for sample in samples_objects:
131
- try:
132
- label = client.get_label(sample.uuid)
133
- labelset = getattr(label, 'labelset', '') or ''
134
- labeled_by = getattr(label, 'created_by', '') or ''
135
- reviewed_by = getattr(label, 'reviewed_by', '') or ''
136
- if is_multisensor and sensor_select and sensor_select != 'All sensors':
137
- frames_list = export_sensor_frames_and_annotations(label, sensor_select)
138
- sensor_val = sensor_select
139
- num_frames = len(frames_list)
140
- total_annotations = sum(len(f['annotations']) for f in frames_list)
141
- matching_annotations = sum(
142
- 1
143
- for f in frames_list
144
- for ann in f['annotations']
145
- if getattr(ann, 'category_id', None) in target_classes
146
- )
147
- elif is_multisensor and (not sensor_select or sensor_select == 'All sensors'):
148
- all_sensor_frames = export_all_sensor_frames_and_annotations(label)
149
- for sensor_name, frames_list in all_sensor_frames.items():
150
- num_frames = len(frames_list)
151
- total_annotations = sum(len(f['annotations']) for f in frames_list)
152
- matching_annotations = sum(
153
- 1
154
- for f in frames_list
155
- for ann in f['annotations']
156
- if getattr(ann, 'category_id', None) in target_classes
157
- )
158
- metrics.append({
159
- 'name': getattr(sample, 'name', sample.uuid),
160
- 'uuid': sample.uuid,
161
- 'labelset': labelset,
162
- 'sensor': sensor_name,
163
- 'num_frames': num_frames,
164
- 'total_annotations': total_annotations,
165
- 'matching_annotations': matching_annotations,
166
- 'labeled_by': labeled_by,
167
- 'reviewed_by': reviewed_by
168
- })
169
- continue
170
- else:
171
- frames_list = export_frames_and_annotations(label)
172
- sensor_val = ''
173
- num_frames = len(frames_list)
174
- total_annotations = sum(len(f['annotations']) for f in frames_list)
175
- matching_annotations = sum(
176
- 1
177
- for f in frames_list
178
- for ann in f['annotations']
179
- if getattr(ann, 'category_id', None) in target_classes
180
- )
181
- metrics.append({
182
- 'name': getattr(sample, 'name', sample.uuid),
183
- 'uuid': sample.uuid,
184
- 'labelset': labelset,
185
- 'sensor': sensor_val if is_multisensor else '',
186
- 'num_frames': num_frames,
187
- 'total_annotations': total_annotations,
188
- 'matching_annotations': matching_annotations,
189
- 'labeled_by': labeled_by,
190
- 'reviewed_by': reviewed_by
191
- })
192
- except Exception as e:
193
- continue
194
  if not metrics:
195
  st.session_state.error = "No metrics could be generated for the dataset."
196
  else:
@@ -213,4 +267,4 @@ if st.session_state.csv_content:
213
  data=st.session_state.csv_content,
214
  file_name=filename,
215
  mime="text/csv"
216
- )
 
1
  import streamlit as st
2
  import io
3
  import csv
4
+ import concurrent.futures
5
  from segments import SegmentsClient
6
+ from datetime import datetime
7
+ import sys
8
+ import os
9
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../0_label_scripts")))
10
  from get_labels_from_samples import (
11
  get_samples as get_samples_objects,
12
  export_frames_and_annotations,
 
48
  return sorted(set(classes))
49
 
50
 
51
+ def _count_from_frames(frames, target_set):
52
+ """Helper to count frames, total annotations, and matching annotations directly."""
53
+ if not frames:
54
+ return 0, 0, 0
55
+ num_frames = len(frames)
56
+ total_annotations = 0
57
+ matching_annotations = 0
58
+ for f in frames:
59
+ anns = getattr(f, 'annotations', [])
60
+ total_annotations += len(anns)
61
+ if target_set:
62
+ for ann in anns:
63
+ if getattr(ann, 'category_id', None) in target_set:
64
+ matching_annotations += 1
65
+ return num_frames, total_annotations, matching_annotations
66
+
67
+
68
def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sensor_select):
    """
    Fetch the label for a single sample and compute its metrics.

    Designed to be called from a thread pool: it builds its own client per
    call (via init_client) rather than sharing one across threads.

    Parameters
    ----------
    sample : object
        Sample with at least a ``uuid`` (and optionally ``name``) attribute.
    api_key : str
        API key used to construct the client.
    target_set : set
        Category ids counted as "matching" annotations.
    is_multisensor : bool
        Whether the dataset is multi-sensor.
    sensor_select : str or None
        A sensor name to restrict to, or 'All sensors' / falsy for all.

    Returns
    -------
    list of dict
        One metrics row per sensor for 'All sensors', otherwise one row;
        [] when the label cannot be fetched or processed.
    """
    def _row(sensor_name, frames):
        # Single place that defines the output schema — previously this dict
        # literal was duplicated three times in the branches below.
        num_frames, total_annotations, matching_annotations = _count_from_frames(frames, target_set)
        return {
            'name': getattr(sample, 'name', sample.uuid),
            'uuid': sample.uuid,
            'labelset': labelset,
            'sensor': sensor_name,
            'num_frames': num_frames,
            'total_annotations': total_annotations,
            'matching_annotations': matching_annotations,
            'labeled_by': labeled_by,
            'reviewed_by': reviewed_by,
        }

    try:
        client = init_client(api_key)
        label = client.get_label(sample.uuid)
        labelset = getattr(label, 'labelset', '') or ''
        labeled_by = getattr(label, 'created_by', '') or ''
        reviewed_by = getattr(label, 'reviewed_by', '') or ''

        if not is_multisensor:
            # Single-sensor dataset: frames hang directly off the label.
            frames = getattr(getattr(label, 'attributes', None), 'frames', [])
            return [_row('', frames)]

        sensors = getattr(getattr(label, 'attributes', None), 'sensors', None) or []
        if sensor_select and sensor_select != 'All sensors':
            # Restrict to the one selected sensor; no row if it is absent.
            for sensor in sensors:
                if getattr(sensor, 'name', None) == sensor_select:
                    frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
                    return [_row(sensor_select, frames)]
            return []

        # 'All sensors': one row per sensor.
        return [
            _row(getattr(sensor, 'name', 'Unknown'),
                 getattr(getattr(sensor, 'attributes', None), 'frames', []))
            for sensor in sensors
        ]
    except Exception:
        # Best-effort by design: samples whose labels fail to fetch or parse
        # are skipped silently (the caller aggregates whatever succeeds).
        # NOTE(review): consider surfacing a per-sample warning upstream.
        return []
138
+
139
+
140
  def generate_csv(metrics: list, dataset_identifier: str) -> str:
141
  """
142
  Generate CSV content from list of per-sample metrics.
 
195
  if is_multisensor:
196
  sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)
197
 
198
+ # Concurrency control
199
+ parallel_workers = st.slider("Parallel requests", min_value=1, max_value=32, value=8, help="Increase to speed up processing; lower if you hit API limits.")
200
+
201
  if run_button:
202
  st.session_state.csv_content = None
203
  st.session_state.error = None
 
218
  st.info("Checking dataset type...")
219
  try:
220
  target_classes = parse_classes(classes_input)
221
+ target_set = set(target_classes)
222
  metrics = []
223
  # Update loader after dataset type check
224
  if status_ctx is not None:
225
  status_ctx.update(label="Dataset type checked. Processing samples...", state="running")
226
+ progress = st.progress(0)
227
+ total = len(samples_objects)
228
+ done = 0
229
+ with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_workers) as executor:
230
+ futures = [
231
+ executor.submit(
232
+ compute_metrics_for_sample,
233
+ sample,
234
+ api_key,
235
+ target_set,
236
+ is_multisensor,
237
+ sensor_select,
238
+ )
239
+ for sample in samples_objects
240
+ ]
241
+ for future in concurrent.futures.as_completed(futures):
242
+ rows = future.result()
243
+ if rows:
244
+ metrics.extend(rows)
245
+ done += 1
246
+ if total:
247
+ progress.progress(min(done / total, 1.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  if not metrics:
249
  st.session_state.error = "No metrics could be generated for the dataset."
250
  else:
 
267
  data=st.session_state.csv_content,
268
  file_name=filename,
269
  mime="text/csv"
270
+ )