wuhp committed on
Commit
5067213
·
verified ·
1 Parent(s): 7451f43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -60
app.py CHANGED
@@ -3,50 +3,53 @@ import json
3
  import requests
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
 
6
 
7
 
8
  def parse_roboflow_url(url):
9
  """
10
  Extract workspace/project and version from a Roboflow Universe URL.
11
- Example: https://universe.roboflow.com/airborne-object-detection/airborne-object-detection-4-aod4/dataset/6
12
  Returns (workspace, project, version)
13
  """
14
  pattern = r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)"
15
  match = re.search(pattern, url)
16
  if not match:
17
  raise ValueError(f"Invalid Roboflow dataset URL: {url}")
18
- return match.groups() # (workspace, project, version)
19
 
20
 
21
  def fetch_metadata(api_key, workspace, project, version):
22
  """
23
  Fetch metadata for a given project version from Roboflow API.
24
- Returns total image count and class->count mapping.
25
  """
26
  endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
27
- resp = requests.get(endpoint, params={"api_key": api_key})
28
- resp.raise_for_status()
 
 
 
 
 
 
29
  data = resp.json()
30
- total = data.get('version', {}).get('images') or data.get('project', {}).get('images', 0)
31
- classes = data.get('project', {}).get('classes', {})
32
  return total, classes
33
 
34
 
35
  def aggregate_datasets(api_key, entries):
36
  """
37
- Given API key and list of (url, file_name, line_no) tuples,
38
- returns total images, aggregated lowercase class counts,
39
  and per-class source URLs.
40
- Raises ValueError with file and line for invalid URLs.
41
  """
42
  total_images = 0
43
  class_counts = {}
44
  class_sources = {}
45
  for url, fname, lineno in entries:
46
- try:
47
- ws, proj, ver = parse_roboflow_url(url)
48
- except ValueError:
49
- raise ValueError(f"Invalid URL '{url}' in file '{fname}', line {lineno}")
50
  imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
51
  total_images += imgs
52
  for cls, cnt in cls_map.items():
@@ -61,38 +64,36 @@ def make_bar_chart(counts):
61
  Return a matplotlib figure showing a bar chart of counts dict.
62
  """
63
  fig, ax = plt.subplots()
64
- ax.bar(counts.keys(), counts.values())
65
- ax.set_xticklabels(counts.keys(), rotation=45, ha='right')
66
- ax.set_ylabel('Image Count')
67
- ax.set_title('Class Distribution')
 
 
 
68
  fig.tight_layout()
69
  return fig
70
 
71
 
72
  def load_datasets(api_key, file_objs):
73
  """
74
- Read multiple .txt uploads, parse URLs with file/line info,
75
- dedupe URLs, and aggregate metadata. Reports precise errors.
76
- Returns: total_images, dataframe_data, plot_fig, json_counts, markdown_sources.
77
  """
78
  entries = []
79
  seen = set()
80
-
81
  for fobj in file_objs:
82
- # Determine filename for error reporting
83
- fname = getattr(fobj, 'name', None) or fobj.get('name', 'unknown')
84
- # Attempt to read raw bytes or retrieve .data
85
- raw = None
86
  try:
87
  raw = fobj.read()
88
- except Exception:
89
- raw = fobj.get('data') if isinstance(fobj, dict) else None
90
  if raw is None and isinstance(fobj, str):
91
- with open(fobj, 'rb') as f:
92
- raw = f.read()
93
- content = raw.decode('utf-8') if isinstance(raw, (bytes, bytearray)) else raw
94
-
95
- for i, line in enumerate(content.splitlines(), start=1):
96
  url = line.strip()
97
  if url and url not in seen:
98
  seen.add(url)
@@ -100,10 +101,10 @@ def load_datasets(api_key, file_objs):
100
 
101
  total, counts, sources = aggregate_datasets(api_key, entries)
102
 
103
- # Prepare DataFrame data
104
  df_data = [[cls, counts[cls]] for cls in counts]
105
 
106
- # Prepare clickable sources markdown
107
  md_lines = []
108
  for cls in counts:
109
  links = ", ".join(f"[{s.split('/')[-1]}]({s})" for s in sources[cls])
@@ -116,26 +117,24 @@ def load_datasets(api_key, file_objs):
116
 
117
  def update_classes(df_data):
118
  """
119
- Combine edited classes (merge duplicates, lowercase) and recalc.
120
- Returns: total_images, updated_dataframe, plot_fig, json_counts, markdown_summary.
121
  """
122
  combined = {}
123
- for row in df_data:
124
- if not row[0]:
125
  continue
126
- name = row[0].strip().lower()
127
  try:
128
- cnt = int(row[1])
129
- except Exception:
130
- cnt = 0
131
- combined[name] = combined.get(name, 0) + cnt
132
 
133
  total = sum(combined.values())
134
- # Build updated DataFrame
135
- updated_df = [[cls, combined[cls]] for cls in combined]
136
-
137
  fig = make_bar_chart(combined)
138
- md_summary = "\n".join(f"- **{cls}** ({combined[cls]} images)" for cls in combined)
139
  return str(total), updated_df, fig, json.dumps(combined, indent=2), md_summary
140
 
141
 
@@ -144,27 +143,30 @@ def build_ui():
144
  gr.Markdown("## Roboflow Dataset Inspector")
145
 
146
  with gr.Row():
147
- api_input = gr.Textbox(label="Roboflow API Key", type="password")
148
- files = gr.Files(label="Upload .txt files of Roboflow URLs", file_types=[".txt"])
149
 
150
  load_btn = gr.Button("Load Datasets")
151
  total_out = gr.Textbox(label="Total Images", interactive=False)
152
- df = gr.Dataframe(headers=["Class Name", "Count"], row_count=(1, None), col_count=2, interactive=True)
 
 
153
  plot = gr.Plot()
154
  json_out = gr.Textbox(label="Counts (JSON)", interactive=False)
155
  md_out = gr.Markdown(label="Class Sources")
156
 
157
  update_btn = gr.Button("Apply Class Edits")
158
 
159
- # Load datasets
160
- load_btn.click(fn=load_datasets,
161
- inputs=[api_input, files],
162
- outputs=[total_out, df, plot, json_out, md_out])
163
-
164
- # Apply edits and refresh all outputs (including table)
165
- update_btn.click(fn=update_classes,
166
- inputs=[df],
167
- outputs=[total_out, df, plot, json_out, md_out])
 
168
 
169
  return demo
170
 
 
3
  import requests
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
+ from requests.exceptions import HTTPError
7
 
8
 
9
def parse_roboflow_url(url):
    """Pull the (workspace, project, version) triple out of a Roboflow Universe URL.

    Expected shape: https://universe.roboflow.com/<workspace>/<project>/dataset/<version>

    Raises:
        ValueError: if the URL does not contain a recognizable dataset path.
    """
    found = re.search(r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)", url)
    if found is None:
        raise ValueError(f"Invalid Roboflow dataset URL: {url}")
    return found.groups()
20
 
21
 
22
def fetch_metadata(api_key, workspace, project, version):
    """Fetch metadata for one project version from the Roboflow API.

    Args:
        api_key: Roboflow API key, sent as a query parameter.
        workspace, project, version: path components of the dataset endpoint.

    Returns:
        (total_image_count, class_name -> count mapping).

    Raises:
        ValueError: on HTTP errors; 401 gets a dedicated "check your key" message.
    """
    endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
    try:
        # Timeout keeps a stalled request from hanging the UI indefinitely.
        resp = requests.get(endpoint, params={"api_key": api_key}, timeout=30)
        resp.raise_for_status()
    except HTTPError as err:
        # Chain the original HTTPError so the root cause survives in tracebacks.
        if resp.status_code == 401:
            raise ValueError("Unauthorized: check your API key.") from err
        else:
            raise ValueError(f"Error fetching {workspace}/{project}/{version}: {resp.status_code}") from err
    data = resp.json()
    # Prefer the version-level image count; fall back to the project-level one.
    total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
    classes = data.get("project", {}).get("classes", {})
    return total, classes
40
 
41
 
42
  def aggregate_datasets(api_key, entries):
43
  """
44
+ Given API key and list of (url, file, line) tuples,
45
+ returns total_images, aggregated lowercase class counts,
46
  and per-class source URLs.
 
47
  """
48
  total_images = 0
49
  class_counts = {}
50
  class_sources = {}
51
  for url, fname, lineno in entries:
52
+ ws, proj, ver = parse_roboflow_url(url)
 
 
 
53
  imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
54
  total_images += imgs
55
  for cls, cnt in cls_map.items():
 
def make_bar_chart(counts):
    """Render the class->count mapping as a bar chart and return the figure."""
    fig, ax = plt.subplots()
    labels = list(counts)
    positions = range(len(labels))
    ax.bar(positions, [counts[label] for label in labels])
    # Pin tick positions before labeling so labels line up with the bars.
    ax.set_xticks(list(positions))
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel("Image Count")
    ax.set_title("Class Distribution")
    fig.tight_layout()
    return fig
76
 
77
 
78
  def load_datasets(api_key, file_objs):
79
  """
80
+ Read uploaded .txt files, dedupe URLs, fetch metadata,
81
+ and return all outputs for the UI.
 
82
  """
83
  entries = []
84
  seen = set()
 
85
  for fobj in file_objs:
86
+ fname = getattr(fobj, "name", None) or fobj.get("name", "unknown")
87
+ # read raw content
 
 
88
  try:
89
  raw = fobj.read()
90
+ except:
91
+ raw = fobj.get("data") if isinstance(fobj, dict) else None
92
  if raw is None and isinstance(fobj, str):
93
+ with open(fobj, "rb") as fh:
94
+ raw = fh.read()
95
+ text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
96
+ for i, line in enumerate(text.splitlines(), start=1):
 
97
  url = line.strip()
98
  if url and url not in seen:
99
  seen.add(url)
 
101
 
102
  total, counts, sources = aggregate_datasets(api_key, entries)
103
 
104
+ # build dataframe list
105
  df_data = [[cls, counts[cls]] for cls in counts]
106
 
107
+ # build markdown of sources
108
  md_lines = []
109
  for cls in counts:
110
  links = ", ".join(f"[{s.split('/')[-1]}]({s})" for s in sources[cls])
 
117
 
118
def update_classes(df_data):
    """Merge the edited table rows (case-insensitive) and rebuild all outputs.

    Args:
        df_data: iterable of [class_name, count] rows from the editable table.

    Returns:
        (total_as_str, updated_rows, figure, counts_json, markdown_summary).

    Rows with an empty/whitespace-only name are skipped; counts that cannot be
    parsed as integers contribute 0 instead of raising.
    """
    combined = {}
    for name, cnt in df_data:
        if not name:
            continue
        # str() guards against non-string cells the table widget may hand back.
        key = str(name).strip().lower()
        if not key:
            # Whitespace-only names would otherwise collapse into a "" class.
            continue
        try:
            val = int(cnt)
        except (TypeError, ValueError):
            # Blank or non-numeric cell — treat as zero rather than crash.
            val = 0
        combined[key] = combined.get(key, 0) + val

    total = sum(combined.values())
    updated_df = [[k, combined[k]] for k in combined]
    fig = make_bar_chart(combined)
    md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)
    return str(total), updated_df, fig, json.dumps(combined, indent=2), md_summary
139
 
140
 
 
143
  gr.Markdown("## Roboflow Dataset Inspector")
144
 
145
  with gr.Row():
146
+ api_input = gr.Textbox(label="API Key", type="password")
147
+ files = gr.Files(label="Upload .txt files", file_types=[".txt"])
148
 
149
  load_btn = gr.Button("Load Datasets")
150
  total_out = gr.Textbox(label="Total Images", interactive=False)
151
+ df = gr.Dataframe(
152
+ headers=["Class Name", "Count"], row_count=(1, None), col_count=2, interactive=True
153
+ )
154
  plot = gr.Plot()
155
  json_out = gr.Textbox(label="Counts (JSON)", interactive=False)
156
  md_out = gr.Markdown(label="Class Sources")
157
 
158
  update_btn = gr.Button("Apply Class Edits")
159
 
160
+ load_btn.click(
161
+ fn=load_datasets,
162
+ inputs=[api_input, files],
163
+ outputs=[total_out, df, plot, json_out, md_out],
164
+ )
165
+ update_btn.click(
166
+ fn=update_classes,
167
+ inputs=[df],
168
+ outputs=[total_out, df, plot, json_out, md_out],
169
+ )
170
 
171
  return demo
172