File size: 6,500 Bytes
fc7318a
 
 
 
 
5067213
fc7318a
 
 
 
 
 
 
 
 
 
 
5067213
fc7318a
 
 
 
02f8610
fc7318a
 
02f8610
5067213
 
 
 
 
 
02f8610
fc7318a
5067213
 
ae07410
fc7318a
 
e5beb8a
fc7318a
02f8610
 
 
 
fc7318a
 
ae07410
 
 
02f8610
 
 
 
ae07410
fc7318a
ae07410
 
 
 
 
fc7318a
 
 
 
02f8610
fc7318a
 
02f8610
 
 
 
 
5067213
 
fc7318a
 
 
 
88491ae
 
02f8610
 
 
 
88491ae
02f8610
 
 
e5beb8a
 
ae07410
5067213
02f8610
594bb51
ae07410
5067213
 
ae07410
5067213
 
 
02f8610
5067213
e5beb8a
ae07410
e5beb8a
ae07410
 
 
 
02f8610
 
e5beb8a
02f8610
ae07410
 
02f8610
ae07410
 
 
 
02f8610
fc7318a
 
 
ae07410
02f8610
 
 
ae07410
02f8610
 
 
 
 
 
 
ae07410
02f8610
 
 
 
5067213
ae07410
02f8610
ae07410
5067213
 
 
 
ae07410
 
02f8610
ae07410
5067213
02f8610
 
fc7318a
 
 
 
 
ae07410
fc7318a
5067213
 
ae07410
7451f43
fc7318a
5067213
02f8610
 
 
 
5067213
ae07410
 
 
fc7318a
7451f43
 
5067213
 
 
02f8610
5067213
 
 
 
02f8610
5067213
fc7318a
 
 
 
 
ae07410
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import re
import json
import requests
import matplotlib.pyplot as plt
import gradio as gr
from requests.exceptions import HTTPError


def parse_roboflow_url(url):
    """
    Pull (workspace, project, version) out of a Roboflow Universe dataset URL.

    Raises:
        ValueError: if the URL does not contain the expected
            ``roboflow.com/<workspace>/<project>/dataset/<version>`` path.
    """
    found = re.search(r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)", url)
    if found is None:
        raise ValueError(f"Invalid Roboflow dataset URL: {url}")
    return found.groups()


def fetch_metadata(api_key, workspace, project, version, timeout=30):
    """
    Fetch image/class metadata for one dataset version from the Roboflow API.

    Args:
        api_key: Roboflow API key, sent as a query parameter.
        workspace, project, version: path components of the dataset endpoint.
        timeout: seconds before the HTTP request is aborted. `requests` has
            no default timeout, so without this a hung server would block
            the UI forever.

    Returns:
        (total_images, classes) where `classes` maps class name -> count.

    Raises:
        ValueError: on any HTTP error status (401 gets a dedicated message).
    """
    endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
    resp = requests.get(endpoint, params={"api_key": api_key}, timeout=timeout)
    try:
        resp.raise_for_status()
    except HTTPError as exc:
        if resp.status_code == 401:
            raise ValueError("Unauthorized: check your API key.") from exc
        raise ValueError(
            f"Error {resp.status_code} for {workspace}/{project}/{version}"
        ) from exc
    data = resp.json()
    # Prefer the per-version image count; fall back to the project-level one.
    total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
    classes = data.get("project", {}).get("classes", {})
    return total, classes


def aggregate_datasets(api_key, entries):
    """
    Aggregate image and class metadata across several Roboflow datasets.

    Args:
        api_key: Roboflow API key, passed through to fetch_metadata.
        entries: list of (url, filename, line_number) triples; filename and
            line number are used only to build error messages.

    Returns:
        (total_images, class_counts, class_sources) where class_counts maps
        a lowercased/stripped class name to its summed count, and
        class_sources maps the same key to the set of contributing URLs.

    Raises:
        ValueError: when a URL cannot be parsed (with file/line context and
            the original parse error chained as the cause).
    """
    total_images = 0
    class_counts = {}
    class_sources = {}
    for url, fname, lineno in entries:
        try:
            ws, proj, ver = parse_roboflow_url(url)
        except ValueError as exc:
            # Re-raise with file/line context; chain the cause so the
            # underlying parse error is not silently discarded.
            raise ValueError(
                f"Invalid URL '{url}' in file '{fname}', line {lineno}"
            ) from exc
        imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
        total_images += imgs
        for cls, cnt in cls_map.items():
            # Normalize so "Dog" and "dog " from different datasets merge.
            norm = cls.strip().lower()
            class_counts[norm] = class_counts.get(norm, 0) + cnt
            class_sources.setdefault(norm, set()).add(url)
    return total_images, class_counts, class_sources


def make_bar_chart(counts):
    """
    Render a {label: value} mapping as a matplotlib bar-chart Figure.

    Labels are drawn on the x-axis (rotated for readability) and values
    become bar heights; the Figure is returned for Gradio's Plot output.
    """
    names = list(counts)
    heights = [counts[name] for name in names]
    positions = range(len(names))

    fig, ax = plt.subplots()
    ax.bar(positions, heights)
    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=45, ha="right")
    ax.set_ylabel("Image Count")
    ax.set_title("Class Distribution")
    fig.tight_layout()
    return fig


def load_datasets(api_key, file_objs):
    """
    Load, dedupe, and aggregate Roboflow dataset URLs from uploaded files.

    1) Ensure the API key is present.
    2) Read & dedupe URLs from each uploaded .txt (file-like object,
       Gradio dict payload, or plain filesystem path).
    3) Fetch & aggregate metadata per dataset.

    Returns:
        (total_str, table_data, figure, json_counts, markdown_sources)

    Raises:
        ValueError: on a missing API key, no uploaded files, or an
            unreadable upload.
    """
    if not api_key or not api_key.strip():
        raise ValueError("Please enter your Roboflow API Key before loading datasets.")
    # Gradio passes None when nothing was uploaded; fail with a clear
    # message instead of a TypeError on iteration.
    if not file_objs:
        raise ValueError("Please upload at least one .txt file of dataset URLs.")

    entries = []
    seen = set()
    for fobj in file_objs:
        # Resolve a display name for error messages. The original
        # `getattr(...) or fobj.get(...)` crashed with AttributeError
        # when fobj was a plain path string.
        if isinstance(fobj, dict):
            fname = fobj.get("name", "unknown")
        elif isinstance(fobj, str):
            fname = fobj
        else:
            fname = getattr(fobj, "name", None) or "unknown"

        # Read raw content: file-like object, dict payload, or file path.
        raw = None
        try:
            raw = fobj.read()
        except (AttributeError, ValueError, OSError):
            # Not file-like (or closed); fall through to the other shapes.
            if isinstance(fobj, dict):
                raw = fobj.get("data")
        if raw is None and isinstance(fobj, str):
            with open(fobj, "rb") as fh:
                raw = fh.read()
        if raw is None:
            # Previously this fell through and crashed on None.splitlines().
            raise ValueError(f"Could not read uploaded file '{fname}'.")
        text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw

        for lineno, line in enumerate(text.splitlines(), start=1):
            url = line.strip()
            if url and url not in seen:
                seen.add(url)
                entries.append((url, fname, lineno))

    total, counts, sources = aggregate_datasets(api_key, entries)

    # One [class, count] row per class for the editable dataframe.
    table_data = [[cls, counts[cls]] for cls in counts]

    # One markdown bullet per class, linking each contributing dataset.
    md_lines = []
    for cls in counts:
        links = ", ".join(f"[{url.split('/')[-1]}]({url})" for url in sources[cls])
        md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
    md_sources = "\n".join(md_lines)

    fig = make_bar_chart(counts)
    return str(total), table_data, fig, json.dumps(counts, indent=2), md_sources


def update_classes(df_data):
    """
    Merge edited dataframe rows and recompute every output widget.

    Converts df_data into a list-of-lists (accepting a pandas DataFrame or
    NumPy array), merges duplicate/lowercased class names, and rebuilds
    the totals, table, chart, JSON, and markdown summary.

    Returns:
        (total_str, updated_table, figure, json_counts, markdown_summary)
    """
    # Normalize pandas DataFrame / NumPy array input into list-of-lists.
    if not isinstance(df_data, list):
        if hasattr(df_data, "to_numpy"):
            df_data = df_data.to_numpy().tolist()
        elif hasattr(df_data, "tolist"):
            df_data = df_data.tolist()

    combined = {}
    for row in df_data:
        if len(row) < 2:
            continue
        name, cnt = row[0], row[1]
        if not name:
            continue
        key = str(name).strip().lower()
        try:
            val = int(cnt)
        except (TypeError, ValueError):
            # Dataframe cells may arrive as "5.0"-style strings; accept
            # those too before giving up and treating the cell as 0.
            try:
                val = int(float(cnt))
            except (TypeError, ValueError):
                val = 0
        combined[key] = combined.get(key, 0) + val

    total = sum(combined.values())
    updated_table = [[k, combined[k]] for k in combined]
    fig = make_bar_chart(combined)
    md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)

    return str(total), updated_table, fig, json.dumps(combined, indent=2), md_summary


def build_ui():
    """
    Assemble the Gradio Blocks interface and return it (not yet launched).

    Wires two buttons: one loads and aggregates datasets from uploaded
    URL lists, the other re-applies manual edits made in the class table.
    Both write to the same five output widgets.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Roboflow Dataset Inspector")

        with gr.Row():
            key_box = gr.Textbox(label="API Key", type="password")
            upload_box = gr.Files(label="Upload .txt files", file_types=[".txt"])

        load_button = gr.Button("Load Datasets")
        total_box = gr.Textbox(label="Total Images", interactive=False)
        class_table = gr.Dataframe(
            headers=["Class Name", "Count"],
            row_count=(1, None),
            col_count=2,
            interactive=True,
        )
        chart = gr.Plot()
        json_box = gr.Textbox(label="Counts (JSON)", interactive=False)
        sources_md = gr.Markdown(label="Class Sources")

        edit_button = gr.Button("Apply Class Edits")

        # Both handlers target the same output widgets, in the same order.
        shared_outputs = [total_box, class_table, chart, json_box, sources_md]
        load_button.click(
            fn=load_datasets,
            inputs=[key_box, upload_box],
            outputs=shared_outputs,
        )
        edit_button.click(
            fn=update_classes,
            inputs=[class_table],
            outputs=shared_outputs,
        )

    return demo


if __name__ == "__main__":
    # Build the interface and start the Gradio server when run as a script.
    app = build_ui()
    app.launch()