Tomatillo committed on
Commit
00a6dbb
·
verified ·
1 Parent(s): 3372d59

Updated to include count per class

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +53 -19
src/streamlit_app.py CHANGED
@@ -1,4 +1,22 @@
1
  #!/usr/bin/env python3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import streamlit as st
4
  import io
@@ -8,6 +26,7 @@ from segments import SegmentsClient
8
  from datetime import datetime
9
  import sys
10
  import os
 
11
  from get_labels_from_samples import (
12
  get_samples as get_samples_objects,
13
  export_frames_and_annotations,
@@ -49,24 +68,28 @@ def parse_classes(input_str: str) -> list:
49
  return sorted(set(classes))
50
 
51
 
52
- def _count_from_frames(frames, target_set):
53
- """Helper to count frames, total annotations, and matching annotations directly."""
54
  if not frames:
55
- return 0, 0, 0
56
  num_frames = len(frames)
57
  total_annotations = 0
58
  matching_annotations = 0
 
59
  for f in frames:
60
  anns = getattr(f, 'annotations', [])
61
  total_annotations += len(anns)
62
  if target_set:
63
  for ann in anns:
64
- if getattr(ann, 'category_id', None) in target_set:
 
65
  matching_annotations += 1
66
- return num_frames, total_annotations, matching_annotations
 
 
67
 
68
 
69
- def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sensor_select):
70
  """
71
  Fetch label for a single sample and compute metrics.
72
  Returns a list of metric dicts (one per sensor if 'All sensors', otherwise one).
@@ -87,7 +110,7 @@ def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sens
87
  for sensor in sensors:
88
  if getattr(sensor, 'name', None) == sensor_select:
89
  frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
90
- num_frames, total_annotations, matching_annotations = _count_from_frames(frames, target_set)
91
  metrics_rows.append({
92
  'name': getattr(sample, 'name', sample.uuid),
93
  'uuid': sample.uuid,
@@ -96,6 +119,7 @@ def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sens
96
  'num_frames': num_frames,
97
  'total_annotations': total_annotations,
98
  'matching_annotations': matching_annotations,
 
99
  'labeled_by': labeled_by,
100
  'reviewed_by': reviewed_by
101
  })
@@ -105,7 +129,7 @@ def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sens
105
  for sensor in sensors:
106
  sensor_name = getattr(sensor, 'name', 'Unknown')
107
  frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
108
- num_frames, total_annotations, matching_annotations = _count_from_frames(frames, target_set)
109
  metrics_rows.append({
110
  'name': getattr(sample, 'name', sample.uuid),
111
  'uuid': sample.uuid,
@@ -114,13 +138,14 @@ def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sens
114
  'num_frames': num_frames,
115
  'total_annotations': total_annotations,
116
  'matching_annotations': matching_annotations,
 
117
  'labeled_by': labeled_by,
118
  'reviewed_by': reviewed_by
119
  })
120
  else:
121
  # single-sensor dataset
122
  frames = getattr(getattr(label, 'attributes', None), 'frames', [])
123
- num_frames, total_annotations, matching_annotations = _count_from_frames(frames, target_set)
124
  metrics_rows.append({
125
  'name': getattr(sample, 'name', sample.uuid),
126
  'uuid': sample.uuid,
@@ -129,6 +154,7 @@ def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sens
129
  'num_frames': num_frames,
130
  'total_annotations': total_annotations,
131
  'matching_annotations': matching_annotations,
 
132
  'labeled_by': labeled_by,
133
  'reviewed_by': reviewed_by
134
  })
@@ -138,7 +164,7 @@ def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sens
138
  return []
139
 
140
 
141
- def generate_csv(metrics: list, dataset_identifier: str) -> str:
142
  """
143
  Generate CSV content from list of per-sample metrics.
144
  Columns: name, sample_url, sensor, num_frames, total_annotations,
@@ -146,19 +172,26 @@ def generate_csv(metrics: list, dataset_identifier: str) -> str:
146
  """
147
  output = io.StringIO()
148
  writer = csv.writer(output)
149
- writer.writerow([
150
  'name', 'sample_url', 'sensor', 'num_frames',
151
- 'total_annotations', 'matching_annotations',
152
- 'labeled_by', 'reviewed_by'
153
- ])
 
 
 
154
  for m in metrics:
155
  url = f"https://app.segments.ai/{dataset_identifier}/samples/{m['uuid']}/{m['labelset']}"
156
- writer.writerow([
157
  m['name'], url, m['sensor'],
158
  m['num_frames'], m['total_annotations'],
159
- m['matching_annotations'], m['labeled_by'],
160
- m['reviewed_by']
161
- ])
 
 
 
 
162
  content = output.getvalue()
163
  output.close()
164
  return content
@@ -234,6 +267,7 @@ if run_button:
234
  sample,
235
  api_key,
236
  target_set,
 
237
  is_multisensor,
238
  sensor_select,
239
  )
@@ -249,7 +283,7 @@ if run_button:
249
  if not metrics:
250
  st.session_state.error = "No metrics could be generated for the dataset."
251
  else:
252
- st.session_state.csv_content = generate_csv(metrics, dataset_identifier)
253
  if status_ctx is not None:
254
  status_ctx.update(label="CSV generated!", state="complete")
255
  except Exception as e:
 
1
  #!/usr/bin/env python3
2
+ """
3
+ File: count_by_class.py
4
+ Location: 6_Data_metrics/count_by_class.py
5
+
6
+ Description:
7
+ Streamlit application to count and report metrics per sample for specified classes.
8
+ For each sample, outputs:
9
+ - Sample name
10
+ - Sample URL (including labelset)
11
+ - Number of frames (ignoring a 31st frame if present)
12
+ - Total number of annotations
13
+ - Number of annotations matching any of the user-specified classes
14
+ - Labeled by (from label data)
15
+ - Reviewed by (from label data)
16
+
17
+ Usage:
18
+ streamlit run 6_Data_metrics/count_by_class.py
19
+ """
20
 
21
  import streamlit as st
22
  import io
 
26
  from datetime import datetime
27
  import sys
28
  import os
29
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../0_label_scripts")))
30
  from get_labels_from_samples import (
31
  get_samples as get_samples_objects,
32
  export_frames_and_annotations,
 
68
  return sorted(set(classes))
69
 
70
 
71
+ def _count_from_frames(frames, target_set, class_ids):
72
+ """Helper to count frames, totals, and per-class counts directly."""
73
  if not frames:
74
+ return 0, 0, 0, {cid: 0 for cid in class_ids}
75
  num_frames = len(frames)
76
  total_annotations = 0
77
  matching_annotations = 0
78
+ class_counts = {cid: 0 for cid in class_ids}
79
  for f in frames:
80
  anns = getattr(f, 'annotations', [])
81
  total_annotations += len(anns)
82
  if target_set:
83
  for ann in anns:
84
+ cid = getattr(ann, 'category_id', None)
85
+ if cid in target_set:
86
  matching_annotations += 1
87
+ if cid in class_counts:
88
+ class_counts[cid] += 1
89
+ return num_frames, total_annotations, matching_annotations, class_counts
90
 
91
 
92
+ def compute_metrics_for_sample(sample, api_key, target_set, class_ids, is_multisensor, sensor_select):
93
  """
94
  Fetch label for a single sample and compute metrics.
95
  Returns a list of metric dicts (one per sensor if 'All sensors', otherwise one).
 
110
  for sensor in sensors:
111
  if getattr(sensor, 'name', None) == sensor_select:
112
  frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
113
+ num_frames, total_annotations, matching_annotations, class_counts = _count_from_frames(frames, target_set, class_ids)
114
  metrics_rows.append({
115
  'name': getattr(sample, 'name', sample.uuid),
116
  'uuid': sample.uuid,
 
119
  'num_frames': num_frames,
120
  'total_annotations': total_annotations,
121
  'matching_annotations': matching_annotations,
122
+ 'class_counts': class_counts,
123
  'labeled_by': labeled_by,
124
  'reviewed_by': reviewed_by
125
  })
 
129
  for sensor in sensors:
130
  sensor_name = getattr(sensor, 'name', 'Unknown')
131
  frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
132
+ num_frames, total_annotations, matching_annotations, class_counts = _count_from_frames(frames, target_set, class_ids)
133
  metrics_rows.append({
134
  'name': getattr(sample, 'name', sample.uuid),
135
  'uuid': sample.uuid,
 
138
  'num_frames': num_frames,
139
  'total_annotations': total_annotations,
140
  'matching_annotations': matching_annotations,
141
+ 'class_counts': class_counts,
142
  'labeled_by': labeled_by,
143
  'reviewed_by': reviewed_by
144
  })
145
  else:
146
  # single-sensor dataset
147
  frames = getattr(getattr(label, 'attributes', None), 'frames', [])
148
+ num_frames, total_annotations, matching_annotations, class_counts = _count_from_frames(frames, target_set, class_ids)
149
  metrics_rows.append({
150
  'name': getattr(sample, 'name', sample.uuid),
151
  'uuid': sample.uuid,
 
154
  'num_frames': num_frames,
155
  'total_annotations': total_annotations,
156
  'matching_annotations': matching_annotations,
157
+ 'class_counts': class_counts,
158
  'labeled_by': labeled_by,
159
  'reviewed_by': reviewed_by
160
  })
 
164
  return []
165
 
166
 
167
def generate_csv(metrics: list, dataset_identifier: str, target_classes: list[int]) -> str:
    """Render per-sample metrics as CSV text.

    Fixed columns (name, sample_url, sensor, num_frames,
    total_annotations, matching_annotations) are followed by one
    ``class_<id>`` column per entry in ``target_classes``, then
    labeled_by and reviewed_by.

    Args:
        metrics: list of per-sample metric dicts (see the keys used below).
        dataset_identifier: ``owner/dataset`` slug used to build sample URLs.
        target_classes: class ids, in the column order to emit.

    Returns:
        The full CSV document as a single string.
    """
    buffer = io.StringIO()
    writer = csv.writer(buffer)

    fixed_columns = ['name', 'sample_url', 'sensor', 'num_frames',
                     'total_annotations', 'matching_annotations']
    class_columns = [f'class_{cid}' for cid in target_classes]
    writer.writerow(fixed_columns + class_columns + ['labeled_by', 'reviewed_by'])

    for metric in metrics:
        sample_url = f"https://app.segments.ai/{dataset_identifier}/samples/{metric['uuid']}/{metric['labelset']}"
        # Per-class counts emitted in the same order as the header columns;
        # missing ids default to 0.
        counts = metric.get('class_counts', {})
        writer.writerow(
            [metric['name'], sample_url, metric['sensor'],
             metric['num_frames'], metric['total_annotations'],
             metric['matching_annotations']]
            + [counts.get(cid, 0) for cid in target_classes]
            + [metric['labeled_by'], metric['reviewed_by']]
        )

    content = buffer.getvalue()
    buffer.close()
    return content
 
267
  sample,
268
  api_key,
269
  target_set,
270
+ target_classes,
271
  is_multisensor,
272
  sensor_select,
273
  )
 
283
  if not metrics:
284
  st.session_state.error = "No metrics could be generated for the dataset."
285
  else:
286
+ st.session_state.csv_content = generate_csv(metrics, dataset_identifier, target_classes)
287
  if status_ctx is not None:
288
  status_ctx.update(label="CSV generated!", state="complete")
289
  except Exception as e: