Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
| import os | |
| import csv | |
| from io import BytesIO | |
| from PIL import Image | |
| from concurrent.futures import ThreadPoolExecutor | |
| from datetime import datetime | |
| from filelock import FileLock | |
# Store collected data under /data when that directory exists (presumably a
# persistent-storage mount on the host — confirm), otherwise in the CWD.
DATA_DIR = "/data" if os.path.exists("/data") else "."

# Input list of image URLs (one per line); read relative to the working directory.
URL_FILE = "urls.txt"

# Output files and the lock guarding appends to them, all inside DATA_DIR.
LABEL_FILE, VERIFY_FILE, SKIP_FILE, LOCK_FILE = (
    os.path.join(DATA_DIR, name)
    for name in ("annotations.csv", "verifications.csv", "skipped.csv", "data.lock")
)

MAX_IMAGES = 6  # number of image slots per property in the UI
THUMB_SIZE = (350, 350)  # bounding box handed to Image.thumbnail
ROOM_CLASSES = [
    "living_room",
    "bedroom",
    "kitchen",
    "bathroom",
    "dining_room",
    "outdoor",
    "other",
]
def init_files():
    """Create any missing data files; existing files are never touched.

    Each output CSV is created with its header row only, and URL_FILE is
    created empty, so later reads can rely on the files being present.
    """
    # Column layout per output file; save_data / skip_group append rows in this order.
    schemas = {
        LABEL_FILE: ["timestamp", "user", "group_id", "url", "score", "label"],
        VERIFY_FILE: ["timestamp", "user", "group_id", "url", "is_correct", "corrected_label"],
        SKIP_FILE: ["timestamp", "user", "group_id"],
    }
    for path, columns in schemas.items():
        if os.path.exists(path):
            continue
        pd.DataFrame(columns=columns).to_csv(path, index=False)

    if not os.path.exists(URL_FILE):
        with open(URL_FILE, "w") as url_handle:
            url_handle.write("")
# Create any missing CSV / URL-list files at import time so later reads find them.
init_files()
def get_image_optimized(url):
    """Download *url* and return a downscaled PIL image for display.

    The image is shrunk in place with ``Image.thumbnail`` so it fits inside
    ``THUMB_SIZE`` (aspect ratio preserved) before it is handed to the UI.

    Args:
        url: Image URL. Falsy values short-circuit.

    Returns:
        ``None`` for an empty *url*; a PIL ``Image`` when the download returns
        HTTP 200 and decodes; otherwise the original *url* string (presumably so
        the gr.Image component can still try to load it itself — confirm).
    """
    if not url:
        return None
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    try:
        response = requests.get(url, headers=headers, timeout=3)
        if response.status_code == 200:
            img = Image.open(BytesIO(response.content))
            img.thumbnail(THUMB_SIZE, Image.Resampling.LANCZOS)
            return img
    except Exception:
        # Best effort: network errors, timeouts and undecodable bytes all fall
        # back to the raw URL. Unlike a bare ``except:``, this no longer
        # swallows KeyboardInterrupt / SystemExit.
        pass
    return url
def get_ordered_groups():
    """List the distinct group (property) ids in URL_FILE, in first-seen order.

    A URL's group id is the text before its first "-m", reduced to the part
    after the last "/" (".../12345-m1.jpg" -> "12345"). Blank lines are
    ignored, and a missing URL_FILE gives an empty list.

    NOTE(review): the "-m" split runs over the whole URL, so a "-m" in the
    host or in an earlier path segment (e.g. "https://my-media.example/...")
    yields a wrong id. Confirm this against the real URL format.
    """
    if not os.path.exists(URL_FILE):
        return []
    # dict keys preserve insertion order -> ordered de-duplication.
    first_seen = {}
    with open(URL_FILE, 'r') as handle:
        for raw_line in handle:
            url = raw_line.strip()
            if not url:
                continue
            group_id = url.split("-m")[0].split("/")[-1]
            first_seen.setdefault(group_id, None)
    return list(first_seen)
def get_group_urls(target_gid):
    """Return up to MAX_IMAGES URLs from URL_FILE that belong to group *target_gid*.

    A URL belongs to a group when its id equals *target_gid*. The id is the
    text before the URL's first "-m", reduced to the part after the last "/",
    which is the same rule get_ordered_groups uses. The comparison is exact, so
    id "abc" does not pick up "abcd-m1.jpg" or a URL that merely contains "abc"
    somewhere (e.g. in the host).

    Returns:
        The matching URLs in file order, truncated to MAX_IMAGES. Returns an
        empty list if URL_FILE does not exist.
    """
    if not os.path.exists(URL_FILE):
        return []
    urls = []
    with open(URL_FILE, 'r') as f:
        for line in f:
            u = line.strip()
            if u and u.split("-m")[0].split("/")[-1] == target_gid:
                urls.append(u)
    return urls[:MAX_IMAGES]
def get_saved_values(gid, mode):
    """Return previously saved answers for group *gid*, keyed by image URL.

    Args:
        gid: Group id string, as produced by get_ordered_groups.
        mode: "label" reads LABEL_FILE; any other value reads VERIFY_FILE.

    Returns:
        ``{url: {"score": ..., "label": ...}}`` in label mode, or
        ``{url: {"is_correct": ..., "label": ...}}`` otherwise, where "label"
        holds the file's corrected_label column. When a URL was saved several
        times, the row that appears last in the file wins. A missing or
        unreadable file yields an empty dict.
    """
    saved_data = {}
    filename = LABEL_FILE if mode == "label" else VERIFY_FILE
    if not os.path.exists(filename):
        return saved_data
    try:
        # Force group_id to str. Otherwise pandas parses an all-digit id column
        # as int64, and comparing it with the string gid never matches.
        df = pd.read_csv(filename, dtype={"group_id": str})
        rows = df[df['group_id'] == str(gid)]
        for _, row in rows.iterrows():
            if mode == "label":
                saved_data[row['url']] = {"score": row['score'], "label": row['label']}
            else:
                saved_data[row['url']] = {"is_correct": row['is_correct'], "label": row['corrected_label']}
    except (OSError, ValueError, KeyError):
        # Best effort: an empty/corrupt CSV or missing column just means
        # "nothing (more) to restore". pandas' EmptyDataError/ParserError are
        # ValueError subclasses.
        pass
    return saved_data
def get_stats_text():
    """Return the Markdown progress line shown in the top bar.

    "Total Properties" is the number of distinct groups listed in URL_FILE.
    "Labeled" is the number of distinct non-empty group ids with at least one
    row in LABEL_FILE, or 0 if that file is missing or unreadable.
    """
    total = len(get_ordered_groups())
    try:
        # Read ids as str so all-digit ids keep leading zeros ("0123" and
        # "123" stay distinct instead of both becoming 123).
        labeled = pd.read_csv(LABEL_FILE, dtype={"group_id": str})["group_id"].nunique()
    except (OSError, ValueError, KeyError):
        labeled = 0
    return f"**Total Properties:** {total} | **Labeled:** {labeled}"
def _read_group_ids(path):
    """Return the set of group ids recorded in the CSV at *path* (empty set if unreadable).

    ``group_id`` is read as ``str`` so the ids compare equal to the strings
    returned by get_ordered_groups. Left alone, pandas parses an all-digit
    column as int64, and every membership test against those ids fails.
    """
    try:
        ids = pd.read_csv(path, dtype={"group_id": str})["group_id"]
    except (OSError, ValueError, KeyError):
        return set()
    return set(ids.dropna())


def _next_pending_group(all_groups, mode):
    """Return ``(index, gid)`` of the first non-skipped group that still needs work in *mode*.

    "label" wants a group with no row in LABEL_FILE. "verify" wants a group that
    is labeled but has no row in VERIFY_FILE. Returns ``(-1, None)`` when no
    group qualifies.
    """
    labeled = _read_group_ids(LABEL_FILE)
    verified = _read_group_ids(VERIFY_FILE)
    skipped = _read_group_ids(SKIP_FILE)
    for i, gid in enumerate(all_groups):
        if gid in skipped:
            continue
        if mode == "label" and gid not in labeled:
            return i, gid
        if mode == "verify" and gid in labeled and gid not in verified:
            return i, gid
    return -1, None


def render_workspace(mode, history, specific_index=None, move_back=False):
    """Choose the property group to show and build the Gradio updates for it.

    The target group is chosen by the first rule that applies:
      * ``specific_index``: the group at that 0-based position in URL_FILE.
        If it is out of range, only the log box is updated ("End of list.").
      * ``move_back`` with at least two history entries: pop the current group
        and show the previous one.
      * Otherwise: the first group in file order that is not skipped and still
        needs work in *mode*. If there is none, switch back to the menu screen.

    *history* (the list held in ``state_hist``) is mutated in place. The shown
    group is appended unless it is already the last entry.

    Args:
        mode: "label" or "verify".
        history: List of visited group ids; the last entry is the current one.
        specific_index: Optional 0-based group position to jump to.
        move_back: Go back one step in *history*.

    Returns:
        A dict mapping Gradio components to new values / ``gr.update`` objects.
    """
    all_groups = get_ordered_groups()

    if specific_index is not None:
        if not 0 <= specific_index < len(all_groups):
            return {log_box: "End of list."}
        target_idx = specific_index
        target_gid = all_groups[specific_index]
    elif move_back and len(history) > 1:
        history.pop()
        target_gid = history[-1]
        try:
            target_idx = all_groups.index(target_gid)
        except ValueError:
            # The group is no longer listed in URL_FILE.
            target_idx = 0
    else:
        target_idx, target_gid = _next_pending_group(all_groups, mode)
        if target_gid is None:
            return {
                screen_menu: gr.update(visible=True),
                screen_work: gr.update(visible=False),
                log_box: "No more tasks found.",
            }

    urls = get_group_urls(target_gid)
    if not history or history[-1] != target_gid:
        history.append(target_gid)

    saved_vals = get_saved_values(target_gid, mode)
    # Verify mode also shows what was chosen in the labeling round.
    r1_vals = get_saved_values(target_gid, "label") if mode == "verify" else {}

    # Download thumbnails in parallel; map() yields results in URL order.
    with ThreadPoolExecutor(max_workers=MAX_IMAGES) as executor:
        images = list(executor.map(get_image_optimized, urls))

    header = (
        f"# Property #{target_idx + 1} "
        f"<span style='font-size:14px;color:gray;'>(ID: {target_gid})</span>"
    )
    updates = {
        screen_menu: gr.update(visible=False),
        screen_work: gr.update(visible=True),
        header_md: header,
        state_urls: urls,
        state_hist: history,
        state_idx: target_idx,
        top_stats: get_stats_text(),
        log_box: f"Loaded group {target_gid}",
    }

    for i in range(MAX_IMAGES):
        img_c = img_objs[i]
        # Each slot owns four inputs, laid out as (slider, dropdown, checkbox, markdown).
        c_sld, c_drp, c_chk, c_lbl = input_objs[i * 4:i * 4 + 4]

        if i >= len(urls):
            # This group has fewer images than slots: hide the spare slot.
            updates[img_c] = gr.update(value=None, visible=False)
            for comp in (c_sld, c_drp, c_chk, c_lbl):
                updates[comp] = gr.update(visible=False)
            continue

        u = urls[i]
        saved = saved_vals.get(u, {})
        updates[img_c] = gr.update(value=images[i], visible=True)

        if mode == "label":
            updates[c_sld] = gr.update(visible=True, value=saved.get('score', 5))
            updates[c_drp] = gr.update(visible=True, value=saved.get('label', ROOM_CLASSES[0]))
            updates[c_chk] = gr.update(visible=False)
            updates[c_lbl] = gr.update(visible=False)
        else:
            prev = r1_vals.get(u, {}).get('label', "Unknown")
            # Pre-select the round-1 label (when it is a known class) instead of
            # ROOM_CLASSES[0]. save_data stores the dropdown value as
            # corrected_label even when "Correct?" stays ticked.
            default_lbl = prev if prev in ROOM_CLASSES else ROOM_CLASSES[0]
            updates[c_sld] = gr.update(visible=False)
            updates[c_lbl] = gr.update(visible=True, value=f"**Labeled:** {prev}")
            updates[c_drp] = gr.update(visible=True, value=saved.get('label', default_lbl))
            updates[c_chk] = gr.update(visible=True, value=saved.get('is_correct', True))

    return updates
def save_data(mode, history, urls, *args):
    """Append the current group's answers to the mode's CSV, then load the next group.

    Args:
        mode: "label" appends (timestamp, user, group_id, url, score, label)
            rows to LABEL_FILE. Any other value appends (timestamp, user,
            group_id, url, is_correct, corrected_label) rows to VERIFY_FILE.
        history: state_hist list; its last entry is the group being saved.
        urls: URLs currently displayed, in slot order.
        *args: Values of the per-slot inputs, four per slot, in the order
            (score slider, class dropdown, "Correct?" checkbox, label markdown).

    Returns:
        The update dict from render_workspace for the next group. If no group
        is loaded, a dict that only updates the log box.
    """
    if not history:
        # Nothing is loaded (e.g. right after "Home" cleared the history).
        # Answer with a log message rather than None: this handler is wired to
        # many outputs, and a bare None may not be accepted by Gradio there.
        return {log_box: "Nothing to save: no property is loaded."}

    gid = history[-1]
    ts = datetime.now().isoformat()
    rows = []
    for i, u in enumerate(urls):
        sc, lbl, chk = args[i * 4:i * 4 + 3]
        if mode == "label":
            rows.append([ts, "user", gid, u, sc, lbl])
        else:
            rows.append([ts, "user", gid, u, chk, lbl])

    fname = LABEL_FILE if mode == "label" else VERIFY_FILE
    # Serialize appends with every other writer that takes LOCK_FILE.
    with FileLock(LOCK_FILE):
        with open(fname, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerows(rows)
    return render_workspace(mode, history)
def skip_group(idx, history, mode):
    """Mark the current group as skipped, then open the group at position idx + 1.

    When *history* is non-empty, a (timestamp, "user", group_id) row for its
    last entry is appended to SKIP_FILE while LOCK_FILE is held. Navigation
    always goes to the next position in file order, whether or not that group
    still needs work in *mode*.
    """
    if history:
        with FileLock(LOCK_FILE), open(SKIP_FILE, "a", newline="") as skip_handle:
            csv.writer(skip_handle).writerow([datetime.now().isoformat(), "user", history[-1]])
    next_index = idx + 1
    return render_workspace(mode, history, specific_index=next_index)
def refresh_cat():
    """Build the table shown in the "Catalog" tab.

    There is one row per group from URL_FILE, in file order, with its 1-based
    position, status and id. The status is "✅ Verified" when the id appears in
    VERIFY_FILE, otherwise "🔵 Labeled" when it appears in LABEL_FILE,
    otherwise "⚪ Pending".
    """
    def read_ids(path):
        # Read group_id as str. All-digit ids would otherwise be parsed as
        # int64 and never match the string ids from get_ordered_groups(),
        # leaving every row "Pending".
        try:
            return set(pd.read_csv(path, dtype={"group_id": str})["group_id"].dropna())
        except (OSError, ValueError, KeyError):
            return set()

    all_gids = get_ordered_groups()
    l_set = read_ids(LABEL_FILE)
    v_set = read_ids(VERIFY_FILE)
    data = []
    for i, gid in enumerate(all_gids):
        if gid in v_set:
            s = "✅ Verified"
        elif gid in l_set:
            s = "🔵 Labeled"
        else:
            s = "⚪ Pending"
        data.append([i + 1, s, gid])
    return pd.DataFrame(data, columns=["#", "Status", "ID"])
# ---- UI layout and event wiring ---------------------------------------------
with gr.Blocks(title="Fast Labeler") as demo:
    # Values shared between the event handlers below.
    state_mode = gr.State("label")  # current task: "label" or "verify"
    state_hist = gr.State([])       # visited group ids; the last one is the group on screen
    state_urls = gr.State([])       # URLs displayed for the current group
    state_idx = gr.State(0)         # 0-based position of the current group in URL_FILE
    with gr.Row(variant="panel"):
        top_stats = gr.Markdown("Loading stats...")
        btn_home = gr.Button("🏠 Home", size="sm", scale=0)
    with gr.Tabs():
        with gr.Tab("Workspace", id=0):
            # Start screen; render_workspace hides it once a group is loaded.
            with gr.Group() as screen_menu:
                gr.Markdown("# Welcome! 👋")
                gr.Markdown("Server-side compression enabled.")
                with gr.Row():
                    b_start_l = gr.Button("Start Labeling", variant="primary")
                    b_start_v = gr.Button("Start Verification")
            # Work screen: MAX_IMAGES fixed slots, each an image plus its inputs.
            # render_workspace shows/hides the inputs depending on the mode.
            with gr.Group(visible=False) as screen_work:
                header_md = gr.Markdown()
                img_objs, input_objs = [], []
                with gr.Row():
                    for i in range(MAX_IMAGES):
                        with gr.Column(min_width=250):
                            img = gr.Image(interactive=False, height=280, show_label=False)
                            with gr.Group():
                                sld = gr.Slider(1, 10, step=1, label="Score", visible=False)
                                lbl = gr.Markdown(visible=False)
                                drp = gr.Dropdown(ROOM_CLASSES, label="Class", visible=False)
                                chk = gr.Checkbox(label="Correct?", value=True, visible=False)
                            img_objs.append(img)
                            # Order matters: save_data and render_workspace read slot i
                            # as input_objs[4*i : 4*i+4] == (slider, dropdown, checkbox, markdown).
                            input_objs.extend([sld, drp, chk, lbl])
                with gr.Row():
                    b_back = gr.Button("⬅ Back")
                    b_skip = gr.Button("Skip ➡")
                    b_save = gr.Button("💾 Save & Next", variant="primary")
            log_box = gr.Textbox(label="Log", interactive=False, max_lines=1)
        with gr.Tab("Catalog", id=1):
            with gr.Row():
                num_in = gr.Number(value=1, label="Property #", precision=0)
                b_go_l = gr.Button("Go (Label)")
                b_go_v = gr.Button("Go (Verify)")
            df_cat = gr.Dataframe(interactive=False)
            b_ref_cat = gr.Button("Refresh")
    # Every component a workspace handler may update. The handlers return
    # dicts keyed by a subset of these.
    ALL_IO = [screen_menu, screen_work, header_md, state_urls, state_hist, state_idx, top_stats, log_box] + img_objs + input_objs
    # Start buttons: set the mode first, then load the first pending group.
    b_start_l.click(lambda: "label", None, state_mode).then(render_workspace, [state_mode, state_hist], ALL_IO)
    b_start_v.click(lambda: "verify", None, state_mode).then(render_workspace, [state_mode, state_hist], ALL_IO)
    b_save.click(save_data, [state_mode, state_hist, state_urls] + input_objs, ALL_IO)
    b_skip.click(skip_group, [state_idx, state_hist, state_mode], ALL_IO)
    b_back.click(lambda m, h: render_workspace(m, h, move_back=True), [state_mode, state_hist], ALL_IO)
    # Home: show the start screen and forget the navigation history.
    btn_home.click(lambda: {screen_menu: gr.update(visible=True), screen_work: gr.update(visible=False), state_hist: []}, None, [screen_menu, screen_work, state_hist])
    # Catalog jump: set the mode, then open the 1-based property number from num_in.
    # NOTE(review): int(n) raises TypeError if the Number field is cleared (n is None);
    # consider guarding before converting.
    b_go_l.click(lambda: "label", None, state_mode).then(lambda n, m, h: render_workspace(m, h, specific_index=int(n)-1), [num_in, state_mode, state_hist], ALL_IO)
    b_go_v.click(lambda: "verify", None, state_mode).then(lambda n, m, h: render_workspace(m, h, specific_index=int(n)-1), [num_in, state_mode, state_hist], ALL_IO)
    b_ref_cat.click(refresh_cat, None, df_cat)
    # On page load: fill the catalog table, then the stats line.
    demo.load(refresh_cat, None, df_cat).then(get_stats_text, None, top_stats)
# NOTE(review): `theme` is a launch() argument only in newer Gradio releases (6.x).
# In Gradio 4/5 it belongs to gr.Blocks(...), and passing it here raises TypeError
# at startup. Check the pinned gradio version.
demo.queue().launch(theme=gr.themes.Soft())