Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
| import csv | |
| import json | |
| import threading | |
| import random | |
| from io import BytesIO | |
| from PIL import Image | |
| from concurrent.futures import ThreadPoolExecutor | |
| from datetime import datetime | |
| from filelock import FileLock | |
| from huggingface_hub import HfApi, hf_hub_download | |
# --- Remote dataset (both values can be overridden through the environment) ---
DATASET_REPO_ID = os.environ.get("DATASET_REPO_ID", "fast-stager/property-labels")
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Local working directory; created at import time if it is missing ---
CACHE_DIR = "/tmp/data"
os.makedirs(CACHE_DIR, exist_ok=True)

# URL_FILE is a relative path (resolved against the process working directory);
# the CSV files and the lock file live inside CACHE_DIR.
URL_FILE = "new_urls.json"
LABEL_FILE, VERIFY_FILE, SKIP_FILE, LOCK_FILE = (
    os.path.join(CACHE_DIR, name)
    for name in ("annotations.csv", "verifications.csv", "skipped.csv", "data.lock")
)

# Group ids saved in "fix" mode during this process lifetime; they are hidden
# from the flagged list until the process restarts.
FIXED_IN_SESSION = set()

# Group ids that are never reported as flagged.
MANUAL_EXCLUDE = {
    "075c8bb8a73c45d71788e711edd9e8d5l",
    "07a0544f217db88fe2b06fd5d38f02a6l",
    "6bf16112723de3318c44641958638a56l",
}

# Choices offered by the "Class" dropdown.
ROOM_CLASSES = [
    "living_room",
    "bedroom",
    "kitchen",
    "bathroom",
    "dining_room",
    "outdoor",
    "other",
]

MAX_IMAGES = 6            # image slots shown per property
THUMB_SIZE = (350, 350)   # bounding box passed to PIL's Image.thumbnail
def sync_pull():
    """Refresh the local CSV cache from the Hugging Face dataset repo.

    For each of the three CSV files, any existing local copy in CACHE_DIR is
    removed and a fresh copy is downloaded from DATASET_REPO_ID. The operation
    is best-effort: a failure for one file is logged and the loop continues
    with the next file, so this function never raises for a failed download.
    """
    import logging

    # Tokens of five characters or fewer are treated as "no token".
    token = HF_TOKEN if HF_TOKEN and len(HF_TOKEN) > 5 else None

    for filename in ("annotations.csv", "verifications.csv", "skipped.csv"):
        local_path = os.path.join(CACHE_DIR, filename)
        try:
            # NOTE(review): the local copy is deleted *before* the download, so a
            # failed download leaves no local file for this name.
            if os.path.exists(local_path):
                os.remove(local_path)
            hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=filename,
                repo_type="dataset",
                local_dir=CACHE_DIR,
                token=token,
                force_download=True,
            )
        except Exception as exc:  # best-effort boundary: report, do not crash
            logging.getLogger(__name__).warning(
                "sync_pull: could not refresh %s: %s", filename, exc
            )
def sync_push_background(local_path, remote_filename):
    """Upload a local file to the dataset repo on a background thread.

    Args:
        local_path: Path of the file to upload.
        remote_filename: Destination path inside the dataset repo.

    Does nothing when no usable HF token is configured. Upload errors are
    logged from the worker thread and never propagate to the caller.
    """
    import logging

    # Tokens of five characters or fewer are treated as "no token".
    token = HF_TOKEN if HF_TOKEN and len(HF_TOKEN) > 5 else None
    if not token:
        return

    def _push():
        try:
            api = HfApi(token=token)
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=remote_filename,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
            )
        except Exception as exc:  # best-effort boundary: report, do not crash
            logging.getLogger(__name__).warning(
                "sync_push_background: upload of %s failed: %s", remote_filename, exc
            )

    threading.Thread(target=_push).start()
def init_files():
    """Pull the remote CSVs, then create any that are still missing.

    A missing file is created containing only its header row; files that
    already exist after the pull are left untouched.
    """
    sync_pull()

    # Header row for each CSV, keyed by local path.
    headers_by_path = {
        LABEL_FILE: ["timestamp", "user", "group_id", "url", "score", "label"],
        VERIFY_FILE: ["timestamp", "user", "group_id", "url", "is_correct", "corrected_label", "corrected_score"],
        SKIP_FILE: ["timestamp", "user", "group_id"],
    }
    for path, columns in headers_by_path.items():
        if os.path.exists(path):
            continue
        pd.DataFrame(columns=columns).to_csv(path, index=False)


# Runs once at import time so the CSV files exist before any handler reads them.
init_files()
def load_all_urls():
    """Return every image URL listed in URL_FILE as one flat list.

    The file is expected to be JSON shaped like
    ``{"groups": [{"images": [url, ...]}, ...]}``; the images of all groups are
    concatenated in file order.

    Returns:
        The list of image entries, or an empty list when the file is missing,
        unreadable, not valid JSON, or not shaped as described above.
    """
    if not os.path.exists(URL_FILE):
        return []
    try:
        with open(URL_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [img for g in data.get("groups", []) for img in g.get("images", [])]
    except (OSError, ValueError, AttributeError, TypeError):
        # OSError: unreadable file; ValueError: invalid JSON or encoding;
        # AttributeError/TypeError: JSON that is not the expected dict/list shape.
        return []
def get_ordered_groups():
    """Return the distinct group ids from the URL list, in first-seen order.

    A group id is the last path segment of whatever precedes the first ``"-m"``
    in the URL. Entries that cannot be split as strings are assigned the group
    id ``"unknown"``.
    """

    def _group_id(url):
        try:
            return url.split("-m")[0].split("/")[-1]
        except (AttributeError, TypeError):
            # Non-string entry in the URL list (e.g. None, a number, bytes).
            return "unknown"

    # dict.fromkeys removes duplicates while keeping the first occurrence order.
    return list(dict.fromkeys(_group_id(u) for u in load_all_urls()))
def get_clean_df(filepath):
    """Load a CSV and normalise it to one cleaned row per URL.

    Cleaning steps (each applied only when the column is present):
      * ``label`` / ``corrected_label``: cast to str, stripped, lower-cased.
      * ``score`` / ``corrected_score``: parsed as numbers; unparseable or
        missing values become 0; result is cast to int.
    Finally rows are de-duplicated on ``url``, keeping the last occurrence, so
    the most recently appended row for a URL wins.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        The cleaned DataFrame. An empty DataFrame is returned when the file does
        not exist, cannot be read or parsed, or has no ``url`` column. A file
        with a header but no rows is returned as read (empty, with columns).
    """
    if not os.path.exists(filepath):
        return pd.DataFrame()
    try:
        df = pd.read_csv(filepath)
        if df.empty:
            return df
        for col in ("label", "corrected_label"):
            if col in df.columns:
                df[col] = df[col].astype(str).str.strip().str.lower()
        for col in ("score", "corrected_score"):
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)
        return df.drop_duplicates(subset=["url"], keep="last")
    except (OSError, ValueError, KeyError, TypeError):
        # OSError: unreadable file. ValueError covers pandas' EmptyDataError and
        # ParserError as well as decoding/cast failures. KeyError: no 'url' column.
        return pd.DataFrame()
def get_flagged_groups():
    """Return group ids whose labels break the "score 10 is living-room only" rule.

    A group is flagged when at least one of its labelled images has score 10
    and a label other than ``living_room``. Groups already saved in fix mode
    during this session, and groups in MANUAL_EXCLUDE, are left out.
    """
    labels = get_clean_df(LABEL_FILE)
    if labels.empty:
        return []

    top_score = labels['score'] == 10
    not_living_room = labels['label'] != 'living_room'
    offending = labels.loc[top_score & not_living_room, 'group_id'].unique().tolist()

    suppressed = FIXED_IN_SESSION | MANUAL_EXCLUDE
    return [gid for gid in offending if gid not in suppressed]
def get_stats_text():
    """Build the Markdown summary line shown at the top of the app.

    Reports the total number of groups, how many distinct groups appear in the
    label and verification files, and either the number of flagged groups or a
    "Clean" marker when nothing is flagged.
    """

    def _distinct_groups(frame):
        # An empty frame may have no 'group_id' column at all.
        return 0 if frame.empty else len(frame['group_id'].unique())

    total = len(get_ordered_groups())
    to_fix = get_flagged_groups()
    labeled = _distinct_groups(get_clean_df(LABEL_FILE))
    verified = _distinct_groups(get_clean_df(VERIFY_FILE))

    if to_fix:
        tail = f" | ⚠️ **Fix:** {len(to_fix)}"
    else:
        tail = " | ✅ Clean"
    return f"**Total:** {total} | **Labeled:** {labeled} | **Verified:** {verified}{tail}"
def _choose_target_group(mode, history, all_ordered, specific_index, move_back):
    """Pick the group id to display next, or None when there is nothing to show."""
    if specific_index is not None:
        # Explicit jump from the catalog; out-of-range indexes select nothing.
        if 0 <= specific_index < len(all_ordered):
            return all_ordered[specific_index]
        return None

    if move_back and len(history) > 1:
        history.pop()  # drop the current group, land on the previous one
        return history[-1]

    current_gid = history[-1] if history else None
    if mode == "fix":
        flagged_pool = get_flagged_groups()
        # Prefer a flagged group other than the current one; fall back to the
        # whole pool when the current group is the only one flagged.
        candidates = [g for g in flagged_pool if g != current_gid] or flagged_pool
        return candidates[0] if candidates else None

    df_l, df_v = get_clean_df(LABEL_FILE), get_clean_df(VERIFY_FILE)
    l_done = set(df_l['group_id'].unique()) if not df_l.empty else set()
    v_done = set(df_v['group_id'].unique()) if not df_v.empty else set()
    candidates = [
        g for g in all_ordered
        if (mode == "label" and g not in l_done)
        or (mode == "verify" and g in l_done and g not in v_done)
    ]
    return random.choice(candidates) if candidates else None


def _saved_values_for(mode, gid):
    """Return ``{url: {...}}`` with the values previously saved for a group.

    Label/fix mode: the saved score and label. Any other mode (verify): the
    labeler's score and label as a baseline, overridden per URL by a saved
    verification (which additionally carries ``is_correct``).
    """
    saved = {}
    df_labels = get_clean_df(LABEL_FILE)
    if not df_labels.empty:
        for _, r in df_labels[df_labels['group_id'] == gid].iterrows():
            saved[r['url']] = {"score": r['score'], "label": r['label']}
    if mode in ["label", "fix"]:
        return saved

    df_verify = get_clean_df(VERIFY_FILE)
    if not df_verify.empty:
        for _, r in df_verify[df_verify['group_id'] == gid].iterrows():
            saved[r['url']] = {
                "is_correct": r['is_correct'],
                "label": r['corrected_label'],
                "score": r['corrected_score'],
            }
    return saved


def _as_checkbox_value(value, default=True):
    """Convert a stored ``is_correct`` cell to a plain bool for the checkbox."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes")
    try:
        if pd.isna(value):
            return default
    except (TypeError, ValueError):
        pass
    return bool(value)


def _fetch_thumbnail(url):
    """Download one image and shrink it to THUMB_SIZE; return None on any failure."""
    try:
        res = requests.get(url, timeout=3, headers={'User-Agent': 'Mozilla/5.0'})
        res.raise_for_status()
        img = Image.open(BytesIO(res.content))
        img.thumbnail(THUMB_SIZE)
        return img
    except Exception:
        # Per-image best-effort boundary: a broken image must not break the page.
        return None


def render_workspace(mode, history, specific_index=None, move_back=False):
    """Select a property group and build the Gradio updates that display it.

    Args:
        mode: "label", "verify" or "fix".
        history: List of group ids visited so far. It is mutated in place: the
            last entry is popped when moving back, and the selected group is
            appended when it is not already the last entry.
        specific_index: Zero-based position in the ordered group list to jump
            to. When given it takes precedence over everything else.
        move_back: When true (and history has more than one entry) return to
            the previously visited group.

    Returns:
        A dict mapping Gradio components to their new values. When no group can
        be selected the dict only switches back to the menu screen and sets the
        status box.
    """
    all_ordered = get_ordered_groups()
    target_gid = _choose_target_group(mode, history, all_ordered, specific_index, move_back)

    if not target_gid:
        return {screen_menu: gr.update(visible=True), screen_work: gr.update(visible=False), log_box: "Done!"}

    try:
        target_idx = all_ordered.index(target_gid)
    except ValueError:
        # A history entry that is no longer present in the URL list.
        return {
            screen_menu: gr.update(visible=True),
            screen_work: gr.update(visible=False),
            log_box: f"Unknown property: {target_gid}",
        }

    # NOTE(review): this is a substring match on the URL, not an exact group-id
    # comparison — confirm that group ids can never be substrings of other URLs.
    urls = [u for u in load_all_urls() if target_gid in u][:MAX_IMAGES]
    if not history or history[-1] != target_gid:
        history.append(target_gid)

    saved_vals = _saved_values_for(mode, target_gid)

    with ThreadPoolExecutor(max_workers=MAX_IMAGES) as ex:
        processed_images = list(ex.map(_fetch_thumbnail, urls))

    updates = {
        screen_menu: gr.update(visible=False), screen_work: gr.update(visible=True),
        header_md: f"# {mode.upper()} - Prop #{target_idx + 1} ({target_gid})",
        state_urls: urls, state_hist: history, state_idx: target_idx,
        top_stats: get_stats_text(), log_box: f"Viewing: {target_gid}"
    }

    for i in range(MAX_IMAGES):
        # input_objs holds four widgets per slot: slider, dropdown, checkbox, markdown.
        base = i * 4
        c_sld, c_drp, c_chk, c_lbl = input_objs[base:base + 4]

        if i >= len(urls):
            # Unused slot: hide the image and all of its inputs.
            updates[img_objs[i]] = gr.update(visible=False)
            for obj in [c_sld, c_drp, c_chk, c_lbl]:
                updates[obj] = gr.update(visible=False)
            continue

        saved = saved_vals.get(urls[i], {})
        updates[img_objs[i]] = gr.update(value=processed_images[i], visible=True)
        v_sc = int(saved.get('score', 5))
        v_lbl = str(saved.get('label', "living_room")).strip().lower()
        # The dropdown only accepts values from ROOM_CLASSES.
        dropdown_lbl = v_lbl if v_lbl in ROOM_CLASSES else "living_room"

        if mode in ["label", "fix"]:
            is_err = (v_sc == 10 and v_lbl != "living_room")
            updates[c_sld] = gr.update(visible=True, value=v_sc, interactive=True)
            updates[c_drp] = gr.update(visible=True, value=dropdown_lbl, interactive=True)
            updates[c_chk] = gr.update(visible=False)
            updates[c_lbl] = gr.update(visible=is_err, value="<span style='color:red'>⚠️ Score 10 Only for Living Room</span>")
        else:
            updates[c_sld] = gr.update(visible=True, value=v_sc)
            updates[c_drp] = gr.update(visible=True, value=dropdown_lbl)
            updates[c_chk] = gr.update(visible=True, value=_as_checkbox_value(saved.get('is_correct', True)))
            updates[c_lbl] = gr.update(visible=True, value=f"Prev: {v_lbl}")
    return updates
def save_data(mode, history, urls, *args):
    """Append the current property's inputs to the matching CSV and move on.

    Args:
        mode: "label" or "fix" write to the label file; any other value writes
            to the verification file.
        history: Visited group ids; the last entry is the group being saved.
        urls: Image URLs currently displayed, in slot order.
        *args: Flattened widget values, four per slot in the order
            (score, label, checkbox, markdown). The fourth value is ignored.

    Returns:
        None when history is empty; otherwise the updates produced by
        ``render_workspace(mode, history)`` for the next property.
    """
    if not history:
        return

    gid = history[-1]
    writes_labels = mode in ["label", "fix"]
    if mode == "fix":
        # Hide this group from the flagged list for the rest of the session.
        FIXED_IN_SESSION.add(gid)

    ts = datetime.now().isoformat()
    rows = []
    for slot, url in enumerate(urls):
        base = slot * 4
        sc, lbl, chk = args[base], args[base + 1], args[base + 2]
        clean_lbl = str(lbl).strip().lower()
        if writes_labels:
            rows.append([ts, "user", gid, url, int(sc), clean_lbl])
        else:
            rows.append([ts, "user", gid, url, chk, clean_lbl, int(sc)])

    target_path = LABEL_FILE if writes_labels else VERIFY_FILE
    # The file lock serialises appends with other writers that use LOCK_FILE.
    with FileLock(LOCK_FILE), open(target_path, "a", newline="") as f:
        csv.writer(f).writerows(rows)

    sync_push_background(target_path, os.path.basename(target_path))
    return render_workspace(mode, history)
def refresh_cat():
    """Build the catalog table: one row per group with its position and status.

    Status precedence is: flagged, then verified, then labeled, then pending.
    """
    all_gids = get_ordered_groups()
    flagged = set(get_flagged_groups())
    df_l, df_v = get_clean_df(LABEL_FILE), get_clean_df(VERIFY_FILE)
    l_set = set(df_l['group_id'].unique()) if not df_l.empty else set()
    v_set = set(df_v['group_id'].unique()) if not df_v.empty else set()

    def _status(gid):
        if gid in flagged:
            return "⚠️ Fix Needed"
        if gid in v_set:
            return "✅ Verified"
        if gid in l_set:
            return "🔵 Labeled"
        return "⚪ Pending"

    rows = [[position, _status(gid), gid] for position, gid in enumerate(all_gids, start=1)]
    return pd.DataFrame(rows, columns=["#", "Status", "ID"])
# NOTE(review): the source this was recovered from had lost its indentation, so
# the nesting of the layout containers below is reconstructed — confirm it
# against the running app (in particular where `log_box` and the Back/Save row sit).


def _mode_setter(mode_name):
    """Return a zero-argument callback that yields ``mode_name`` for state_mode."""
    return lambda: mode_name


with gr.Blocks(theme=gr.themes.Soft(), title="Labeler Pro") as demo:
    # Per-session state: active mode, visited group ids, displayed URLs, group index.
    state_mode = gr.State("label")
    state_hist = gr.State([])
    state_urls = gr.State([])
    state_idx = gr.State(0)

    with gr.Row():
        top_stats = gr.Markdown("Loading...")
        btn_home = gr.Button("🏠 Home", size="sm", scale=0)

    with gr.Tabs():
        with gr.Tab("Workspace"):
            # Start menu: choose which pass to run.
            with gr.Group() as screen_menu:
                gr.Markdown("# Property Labeler Pro")
                with gr.Row():
                    b_start_l = gr.Button("Label", variant="primary")
                    b_start_v = gr.Button("Verify")
                    b_start_f = gr.Button("🛠 Fix Errors", variant="secondary")

            # Working screen: MAX_IMAGES slots, each with an image and four inputs.
            with gr.Group(visible=False) as screen_work:
                header_md = gr.Markdown()
                img_objs, input_objs = [], []
                with gr.Row():
                    for _slot in range(MAX_IMAGES):
                        with gr.Column(min_width=200):
                            slot_img = gr.Image(interactive=False, height=240)
                            slot_score = gr.Slider(1, 10, step=1, label="Score")
                            slot_class = gr.Dropdown(ROOM_CLASSES, label="Class")
                            slot_check = gr.Checkbox(label="Correct?")
                            slot_note = gr.Markdown()
                            img_objs.append(slot_img)
                            # Order matters: handlers read these four-at-a-time.
                            input_objs.extend([slot_score, slot_class, slot_check, slot_note])
                with gr.Row():
                    b_back = gr.Button("⬅ Back")
                    b_save = gr.Button("💾 Save & Next", variant="primary")

            log_box = gr.Textbox(label="Status", interactive=False)

        with gr.Tab("Catalog"):
            with gr.Row():
                num_in = gr.Number(value=1, label="Prop #", precision=0)
                b_go_l = gr.Button("Go Label")
                b_go_v = gr.Button("Go Verify")
                b_go_f = gr.Button("Go Fix")
            df_cat = gr.Dataframe(interactive=False)
            b_ref_cat = gr.Button("Refresh Catalog")

    # Every component a workspace render may update.
    ALL_IO = [screen_menu, screen_work, header_md, state_urls, state_hist, state_idx, top_stats, log_box] + img_objs + input_objs

    # Menu buttons: set the mode, then render the next property for it.
    for _btn, _mode in ((b_start_l, "label"), (b_start_v, "verify"), (b_start_f, "fix")):
        _btn.click(_mode_setter(_mode), None, state_mode).then(
            render_workspace, [state_mode, state_hist], ALL_IO
        )

    b_save.click(save_data, [state_mode, state_hist, state_urls] + input_objs, ALL_IO)
    b_back.click(
        lambda m, h: render_workspace(m, h, move_back=True),
        [state_mode, state_hist],
        ALL_IO,
    )
    # Home: show the menu again and clear the visit history.
    btn_home.click(
        lambda: {
            screen_menu: gr.update(visible=True),
            screen_work: gr.update(visible=False),
            state_hist: [],
        },
        None,
        [screen_menu, screen_work, state_hist],
    )

    # Catalog buttons: set the mode, then jump to the 1-based property number.
    for _btn, _mode in ((b_go_l, "label"), (b_go_v, "verify"), (b_go_f, "fix")):
        _btn.click(_mode_setter(_mode), None, state_mode).then(
            lambda n, m, h: render_workspace(m, h, int(n) - 1),
            [num_in, state_mode, state_hist],
            ALL_IO,
        )

    b_ref_cat.click(refresh_cat, None, df_cat).then(get_stats_text, None, top_stats)
    demo.load(refresh_cat, None, df_cat).then(get_stats_text, None, top_stats)

demo.queue().launch(server_name="0.0.0.0", server_port=7860)