# dataviewer/app.py — Roboflow Dataset Inspector (Gradio app)
import re
import json
import requests
import matplotlib.pyplot as plt
import gradio as gr
from requests.exceptions import HTTPError
def parse_roboflow_url(url):
    """
    Extract workspace, project, and version from a Roboflow Universe URL.

    Returns a ``(workspace, project, version)`` tuple of strings.
    Raises ValueError when the URL does not contain the expected
    ``roboflow.com/<workspace>/<project>/dataset/<version>`` segment.
    """
    found = re.search(r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)", url)
    if found is None:
        raise ValueError(f"Invalid Roboflow dataset URL: {url}")
    return found.groups()
def fetch_metadata(api_key, workspace, project, version, timeout=30):
    """
    Fetch image count and class map for one dataset version from the
    Roboflow API.

    Parameters
    ----------
    api_key : str
        Roboflow API key, sent as a query parameter.
    workspace, project, version : str
        Path components identifying the dataset version.
    timeout : float, optional
        Seconds to wait for the HTTP response (default 30). Previously
        no timeout was set, so a stalled connection could hang forever.

    Returns
    -------
    tuple
        ``(total_images, classes)`` where ``classes`` maps class name
        to image count.

    Raises
    ------
    ValueError
        On any non-2xx HTTP status (401 gets a dedicated message).
    """
    endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
    resp = requests.get(endpoint, params={"api_key": api_key}, timeout=timeout)
    try:
        resp.raise_for_status()
    except HTTPError as err:
        # Chain the original HTTPError so tracebacks keep the root cause.
        if resp.status_code == 401:
            raise ValueError("Unauthorized: check your API key.") from err
        raise ValueError(
            f"Error {resp.status_code} for {workspace}/{project}/{version}"
        ) from err
    data = resp.json()
    # Some responses report the count on the version object, others only
    # on the project object — fall back from one to the other.
    total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
    classes = data.get("project", {}).get("classes", {})
    return total, classes
def aggregate_datasets(api_key, entries):
    """
    Aggregate image counts and class statistics across several datasets.

    Parameters
    ----------
    api_key : str
        Roboflow API key, forwarded to ``fetch_metadata``.
    entries : list of (url, filename, line_number) tuples
        ``filename``/``line_number`` identify where each URL came from
        and are used only for error reporting.

    Returns
    -------
    tuple
        ``(total_images,
           {lowercased class name: aggregated count},
           {lowercased class name: set of source URLs})``

    Raises
    ------
    ValueError
        If any URL fails to parse (annotated with its source file and
        line) or a metadata fetch fails.
    """
    total_images = 0
    class_counts = {}
    class_sources = {}
    for url, fname, lineno in entries:
        try:
            ws, proj, ver = parse_roboflow_url(url)
        except ValueError as err:
            # Re-raise with the source location, chaining the original
            # error so the underlying parse failure stays visible.
            raise ValueError(
                f"Invalid URL '{url}' in file '{fname}', line {lineno}"
            ) from err
        imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
        total_images += imgs
        for cls, cnt in cls_map.items():
            # Case-insensitive merge: "Dog" and "dog" count together.
            norm = cls.strip().lower()
            class_counts[norm] = class_counts.get(norm, 0) + cnt
            class_sources.setdefault(norm, set()).add(url)
    return total_images, class_counts, class_sources
def make_bar_chart(counts):
    """
    Render a {label: value} mapping as a matplotlib bar chart.

    Returns the matplotlib Figure (suitable for ``gr.Plot``).
    """
    names = list(counts)
    heights = [counts[name] for name in names]
    positions = range(len(names))
    fig, axis = plt.subplots()
    axis.bar(positions, heights)
    axis.set_xticks(positions)
    axis.set_xticklabels(names, rotation=45, ha="right")
    axis.set_ylabel("Image Count")
    axis.set_title("Class Distribution")
    fig.tight_layout()
    return fig
def load_datasets(api_key, file_objs):
    """
    Load and aggregate Roboflow metadata for every URL in the uploads.

    Steps:
      1) Validate that an API key was supplied.
      2) Read each uploaded .txt file and collect de-duplicated URLs,
         remembering source file and line for error messages.
      3) Fetch and aggregate metadata across all datasets.

    Returns
    -------
    tuple
        ``(total_images_str, table_rows, figure, json_counts,
        markdown_sources)`` matching the Gradio outputs
        ``[total_out, df, plot, json_out, md_out]``.

    Raises
    ------
    ValueError
        If the API key is missing or any URL is invalid.
    """
    if not api_key or not api_key.strip():
        raise ValueError("Please enter your Roboflow API Key before loading datasets.")
    entries = []
    seen = set()
    for fobj in file_objs:
        fname = _upload_name(fobj)
        text = _read_upload_text(fobj)
        for lineno, line in enumerate(text.splitlines(), start=1):
            url = line.strip()
            if url and url not in seen:
                seen.add(url)
                entries.append((url, fname, lineno))
    total, counts, sources = aggregate_datasets(api_key, entries)
    # One [class, count] row per class for the editable Dataframe.
    table_data = [[cls, counts[cls]] for cls in counts]
    # Per-class markdown with clickable links back to each source dataset.
    md_lines = []
    for cls in counts:
        links = ", ".join(f"[{url.split('/')[-1]}]({url})" for url in sources[cls])
        md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
    md_sources = "\n".join(md_lines)
    fig = make_bar_chart(counts)
    return str(total), table_data, fig, json.dumps(counts, indent=2), md_sources


def _upload_name(fobj):
    """Best-effort display name for a Gradio upload payload."""
    if isinstance(fobj, str):
        # A bare filesystem path IS the name; the old code crashed here
        # by calling .get() on a str.
        return fobj
    if isinstance(fobj, dict):
        return fobj.get("name", "unknown")
    return getattr(fobj, "name", None) or "unknown"


def _read_upload_text(fobj):
    """
    Extract text from a Gradio upload, which may be a file-like object,
    a ``{"name": ..., "data": ...}`` dict, or a plain filesystem path.
    """
    try:
        raw = fobj.read()
    except (AttributeError, TypeError, ValueError, OSError):
        # Not a readable file object (e.g. a dict payload or a path);
        # narrowed from a bare except that hid real bugs.
        raw = fobj.get("data") if isinstance(fobj, dict) else None
    if raw is None and isinstance(fobj, str):
        with open(fobj, "rb") as fh:
            raw = fh.read()
    return raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
def update_classes(df_data):
    """
    Recompute all outputs after the user edits the class table.

    Accepts the Dataframe payload as a list-of-lists, pandas DataFrame,
    or NumPy array; merges rows whose stripped, lowercased class names
    collide; and rebuilds total, table, chart, JSON, and summary.

    Returns
    -------
    tuple
        ``(total_str, updated_table, figure, json_counts,
        markdown_summary)`` matching the Gradio outputs.
    """
    # Normalize pandas / NumPy payloads to a plain list of rows.
    if not isinstance(df_data, list):
        if hasattr(df_data, "to_numpy"):
            df_data = df_data.to_numpy().tolist()
        elif hasattr(df_data, "tolist"):
            df_data = df_data.tolist()
    combined = {}
    for row in df_data:
        if len(row) < 2:
            continue
        name, cnt = row[0], row[1]
        if not name:
            continue
        key = str(name).strip().lower()
        try:
            val = int(cnt)
        except (TypeError, ValueError):
            # Non-numeric edits count as zero instead of crashing the UI;
            # narrowed from a bare except that also hid KeyboardInterrupt.
            val = 0
        combined[key] = combined.get(key, 0) + val
    total = sum(combined.values())
    updated_table = [[k, combined[k]] for k in combined]
    fig = make_bar_chart(combined)
    md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)
    return str(total), updated_table, fig, json.dumps(combined, indent=2), md_summary
def build_ui():
    """
    Assemble the Gradio Blocks interface and wire up both buttons.

    Returns the (unlaunched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Roboflow Dataset Inspector")
        with gr.Row():
            api_key_box = gr.Textbox(label="API Key", type="password")
            upload_box = gr.Files(label="Upload .txt files", file_types=[".txt"])
        load_button = gr.Button("Load Datasets")
        total_box = gr.Textbox(label="Total Images", interactive=False)
        class_table = gr.Dataframe(
            headers=["Class Name", "Count"],
            row_count=(1, None),
            col_count=2,
            interactive=True,
        )
        chart = gr.Plot()
        json_box = gr.Textbox(label="Counts (JSON)", interactive=False)
        sources_md = gr.Markdown(label="Class Sources")
        apply_button = gr.Button("Apply Class Edits")

        # Both handlers refresh the same five output components.
        shared_outputs = [total_box, class_table, chart, json_box, sources_md]
        load_button.click(
            fn=load_datasets,
            inputs=[api_key_box, upload_box],
            outputs=shared_outputs,
        )
        apply_button.click(
            fn=update_classes,
            inputs=[class_table],
            outputs=shared_outputs,
        )
    return demo
# Launch the Gradio app locally when this file is executed directly
# (imports of this module do not start a server).
if __name__ == "__main__":
    build_ui().launch()