# NOTE: "Spaces: Sleeping" — leftover Hugging Face Spaces page header from
# extraction; not part of the program.
# Standard library
import json
import re

# Third-party (order preserved: requests, matplotlib, gradio)
import requests
from requests.exceptions import HTTPError
import matplotlib.pyplot as plt
import gradio as gr
def parse_roboflow_url(url):
    """
    Extract workspace, project slug, and version from a Roboflow Universe
    dataset URL of the form ``…roboflow.com/<workspace>/<project>/dataset/<n>``.

    Returns:
        tuple: (workspace, project, version) — all strings.

    Raises:
        ValueError: if the URL does not match the expected pattern.
    """
    match = re.search(r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)", url)
    if match is None:
        raise ValueError(f"Invalid Roboflow dataset URL: {url}")
    workspace, project, version = match.groups()
    return workspace, project, version
def fetch_metadata(api_key, workspace, project, version, timeout=30):
    """
    Fetch dataset metadata from the Roboflow REST API.

    Parameters:
        api_key: Roboflow API key.
        workspace, project, version: identifiers parsed from the dataset URL.
        timeout: seconds to wait for the HTTP response (default 30) —
            without it a stalled connection hangs the UI indefinitely.

    Returns:
        tuple: (total_image_count, {class_name: image_count}).

    Raises:
        ValueError: for any non-2xx response, with a friendlier message
            for 401 (bad API key). The original HTTPError is chained.
    """
    endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
    resp = requests.get(endpoint, params={"api_key": api_key}, timeout=timeout)
    try:
        resp.raise_for_status()
    except HTTPError as exc:
        # Chain the HTTPError so tracebacks keep the root cause.
        if resp.status_code == 401:
            raise ValueError("Unauthorized: check your API key.") from exc
        raise ValueError(
            f"Error {resp.status_code} for {workspace}/{project}/{version}"
        ) from exc
    data = resp.json()
    # Prefer the per-version image count; fall back to the project-level one.
    total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
    classes = data.get("project", {}).get("classes", {})
    return total, classes
def aggregate_datasets(api_key, entries):
    """
    Aggregate image and class statistics across several Roboflow datasets.

    Parameters:
        api_key: Roboflow API key, passed through to fetch_metadata.
        entries: iterable of (url, source_filename, line_number) triples.

    Returns:
        tuple of:
            - total_images (int): sum of image counts over all datasets
            - class_counts (dict): lowercased class name -> aggregated count
            - class_sources (dict): lowercased class name -> set of source URLs

    Raises:
        ValueError: if a URL cannot be parsed (message names the file and
            line it came from), or propagated from fetch_metadata.
    """
    total_images = 0
    class_counts = {}
    class_sources = {}
    for url, fname, lineno in entries:
        try:
            ws, proj, ver = parse_roboflow_url(url)
        except ValueError as exc:
            # Re-raise with file/line context, chaining the original cause.
            raise ValueError(
                f"Invalid URL '{url}' in file '{fname}', line {lineno}"
            ) from exc
        imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
        total_images += imgs
        for cls, cnt in cls_map.items():
            # Normalize so "Car" and "car " count as one class.
            norm = cls.strip().lower()
            class_counts[norm] = class_counts.get(norm, 0) + cnt
            class_sources.setdefault(norm, set()).add(url)
    return total_images, class_counts, class_sources
def make_bar_chart(counts):
    """
    Render a {label: value} mapping as a matplotlib bar chart.

    Returns the Figure so the caller (Gradio) can display it.
    """
    labels = list(counts)
    values = [counts[label] for label in labels]
    positions = range(len(labels))
    fig, ax = plt.subplots()
    ax.bar(positions, values)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel("Image Count")
    ax.set_title("Class Distribution")
    fig.tight_layout()
    return fig
def _read_upload(fobj):
    """Return (display_name, decoded_text) for one Gradio upload item.

    Gradio has passed uploads as plain path strings, as dicts with
    "name"/"data" keys, or as file-like objects, depending on version —
    handle each shape explicitly instead of probing with a bare except.
    """
    if isinstance(fobj, str):
        # Plain filesystem path (newer Gradio versions).
        with open(fobj, "rb") as fh:
            raw = fh.read()
        return fobj, raw.decode("utf-8")
    if isinstance(fobj, dict):
        raw = fobj.get("data")
        text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else (raw or "")
        return fobj.get("name", "unknown"), text
    # File-like object (e.g. a tempfile wrapper).
    fname = getattr(fobj, "name", "unknown")
    raw = fobj.read()
    text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
    return fname, text

def load_datasets(api_key, file_objs):
    """
    Load Roboflow dataset URLs from uploaded text files and aggregate stats.

    Steps:
        1) Validate that an API key was provided.
        2) Read and de-duplicate URLs (one per line) from each upload.
        3) Fetch and aggregate metadata for every unique URL.

    Parameters:
        api_key: Roboflow API key string.
        file_objs: uploaded files from Gradio (path strings, dicts, or
            file-like objects); None is treated as no uploads.

    Returns:
        (total_str, table_rows, figure, json_counts, markdown_sources)

    Raises:
        ValueError: if the API key is missing/blank, or a URL is invalid.
    """
    if not api_key or not api_key.strip():
        raise ValueError("Please enter your Roboflow API Key before loading datasets.")
    entries = []
    seen = set()
    # `or []` guards against file_objs being None when nothing was uploaded.
    for fobj in file_objs or []:
        fname, text = _read_upload(fobj)
        for i, line in enumerate(text.splitlines(), start=1):
            url = line.strip()
            if url and url not in seen:  # de-dupe across all uploaded files
                seen.add(url)
                entries.append((url, fname, i))
    total, counts, sources = aggregate_datasets(api_key, entries)
    # Build dataframe rows.
    table_data = [[cls, counts[cls]] for cls in counts]
    # Build clickable markdown per class.
    md_lines = []
    for cls in counts:
        links = ", ".join(f"[{url.split('/')[-1]}]({url})" for url in sources[cls])
        md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
    md_sources = "\n".join(md_lines)
    fig = make_bar_chart(counts)
    return str(total), table_data, fig, json.dumps(counts, indent=2), md_sources
def update_classes(df_data):
    """
    Recompute all outputs after the user edits the class table.

    Accepts a pandas DataFrame, NumPy array, or list-of-lists; merges rows
    whose stripped, lowercased class names collide, then rebuilds every
    UI output from the merged counts.

    Parameters:
        df_data: edited table contents (rows of [class_name, count]).

    Returns:
        (total_str, updated_table, figure, json_counts, markdown_summary)
    """
    # Convert a pandas DataFrame or NumPy array into a list of rows.
    if not isinstance(df_data, list):
        if hasattr(df_data, "to_numpy"):
            df_data = df_data.to_numpy().tolist()
        elif hasattr(df_data, "tolist"):
            df_data = df_data.tolist()
    combined = {}
    for row in df_data:
        if len(row) < 2:
            continue
        name, cnt = row[0], row[1]
        if not name:
            continue
        key = str(name).strip().lower()
        try:
            val = int(cnt)
        except (TypeError, ValueError):
            # Non-numeric cell (e.g. blanked out during editing) counts as 0;
            # narrowed from a bare except that also swallowed system exits.
            val = 0
        combined[key] = combined.get(key, 0) + val
    total = sum(combined.values())
    updated_table = [[k, combined[k]] for k in combined]
    fig = make_bar_chart(combined)
    md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)
    return str(total), updated_table, fig, json.dumps(combined, indent=2), md_summary
def build_ui():
    """
    Assemble the Gradio Blocks interface for the dataset inspector.

    Layout: API-key box and file upload side by side, a load button, then
    the five outputs (total count, editable class table, bar chart, JSON
    dump, markdown source links) and a button to re-apply table edits.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Roboflow Dataset Inspector")
        with gr.Row():
            key_box = gr.Textbox(label="API Key", type="password")
            uploads = gr.Files(label="Upload .txt files", file_types=[".txt"])
        load_button = gr.Button("Load Datasets")
        total_box = gr.Textbox(label="Total Images", interactive=False)
        class_table = gr.Dataframe(
            headers=["Class Name", "Count"],
            row_count=(1, None),
            col_count=2,
            interactive=True,
        )
        chart = gr.Plot()
        json_box = gr.Textbox(label="Counts (JSON)", interactive=False)
        sources_md = gr.Markdown(label="Class Sources")
        apply_button = gr.Button("Apply Class Edits")
        # Both actions refresh the same five output components.
        shared_outputs = [total_box, class_table, chart, json_box, sources_md]
        load_button.click(
            fn=load_datasets,
            inputs=[key_box, uploads],
            outputs=shared_outputs,
        )
        apply_button.click(
            fn=update_classes,
            inputs=[class_table],
            outputs=shared_outputs,
        )
    return demo
if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, not on import.
    demo = build_ui()
    demo.launch()