# data-collection / app.py — by Nightfury16
# Commit message: update app.py
# Commit: 6b53e51
import os
import gradio as gr
import pandas as pd
import requests
import csv
import json
import threading
import random
from io import BytesIO
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from filelock import FileLock
from huggingface_hub import HfApi, hf_hub_download
# --- Remote dataset / credentials -------------------------------------------
DATASET_REPO_ID = os.environ.get("DATASET_REPO_ID", "fast-stager/property-labels")
HF_TOKEN = os.environ.get("HF_TOKEN")  # may be None when the secret is not configured

# --- Local cache ------------------------------------------------------------
CACHE_DIR = "/tmp/data"
os.makedirs(CACHE_DIR, exist_ok=True)
URL_FILE = "new_urls.json"  # relative to the working directory, not inside CACHE_DIR
LABEL_FILE, VERIFY_FILE, SKIP_FILE, LOCK_FILE = (
    os.path.join(CACHE_DIR, name)
    for name in ("annotations.csv", "verifications.csv", "skipped.csv", "data.lock")
)

# --- Review state -----------------------------------------------------------
# Group ids fixed since this process started. It is a module-level set, so it
# is shared by every user of the running app rather than kept per browser tab.
FIXED_IN_SESSION = set()
# Group ids that are never offered in "fix" mode, even when flagged.
MANUAL_EXCLUDE = {
    "075c8bb8a73c45d71788e711edd9e8d5l",
    "07a0544f217db88fe2b06fd5d38f02a6l",
    "6bf16112723de3318c44641958638a56l",
}

# --- UI configuration -------------------------------------------------------
ROOM_CLASSES = [
    "living_room",
    "bedroom",
    "kitchen",
    "bathroom",
    "dining_room",
    "outdoor",
    "other",
]
MAX_IMAGES = 6  # image slots rendered per property group
THUMB_SIZE = (350, 350)  # bounding box passed to PIL's Image.thumbnail()
def sync_pull():
    """Refresh the local CSV cache from the Hub dataset repo (best effort).

    Each CSV is re-downloaded with ``force_download=True``. The existing local
    copy is first moved aside to ``<name>.bak``. It is restored if the download
    fails (network error, file absent remotely, bad token), so a failed pull
    never destroys local data. Failures are printed and otherwise ignored.
    """
    # Very short values are treated as placeholders rather than real tokens.
    token = HF_TOKEN if HF_TOKEN and len(HF_TOKEN) > 5 else None
    for filename in ["annotations.csv", "verifications.csv", "skipped.csv"]:
        local_path = os.path.join(CACHE_DIR, filename)
        backup_path = local_path + ".bak"
        had_local = os.path.exists(local_path)
        if had_local:
            try:
                os.replace(local_path, backup_path)
            except OSError as exc:
                print(f"[sync_pull] could not set aside {local_path}: {exc}")
                continue
        try:
            hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=filename,
                repo_type="dataset",
                local_dir=CACHE_DIR,
                token=token,
                force_download=True,
            )
        except Exception as exc:  # best effort: keep whatever we had locally
            print(f"[sync_pull] could not download {filename}: {exc}")
            if had_local:
                os.replace(backup_path, local_path)
        else:
            if had_local:
                os.remove(backup_path)
# Serializes uploads: every upload reads the file fresh while holding this lock,
# so the last upload to finish always carries the newest contents.
_PUSH_LOCK = threading.Lock()


def sync_push_background(local_path, remote_filename):
    """Upload ``local_path`` to the dataset repo as ``remote_filename`` on a new thread.

    Does nothing when no usable ``HF_TOKEN`` is configured. The file is read
    under the same ``FileLock`` that writers hold. That way a half-appended CSV
    is never uploaded, and the bytes (not the path) are handed to the Hub.
    Upload failures are printed and otherwise ignored (best effort).
    """
    token = HF_TOKEN if HF_TOKEN and len(HF_TOKEN) > 5 else None
    if not token:
        return

    def _push():
        with _PUSH_LOCK:
            try:
                with FileLock(LOCK_FILE):
                    with open(local_path, "rb") as fh:
                        payload = fh.read()
                HfApi(token=token).upload_file(
                    path_or_fileobj=payload,
                    path_in_repo=remote_filename,
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                )
            except Exception as exc:  # best effort: never take down the worker thread noisily
                print(f"[sync_push] upload of {remote_filename} failed: {exc}")

    threading.Thread(target=_push, name=f"push-{remote_filename}").start()
def init_files():
    """Pull the latest CSVs from the Hub, then create any still-missing CSV with only its header row."""
    sync_pull()
    schemas = {
        LABEL_FILE: ["timestamp", "user", "group_id", "url", "score", "label"],
        VERIFY_FILE: ["timestamp", "user", "group_id", "url", "is_correct", "corrected_label", "corrected_score"],
        SKIP_FILE: ["timestamp", "user", "group_id"],
    }
    for path, columns in schemas.items():
        if os.path.exists(path):
            continue
        pd.DataFrame(columns=columns).to_csv(path, index=False)


init_files()
def load_all_urls(path=None):
    """Return every image URL in the URL JSON file, flattened across groups.

    The file is expected to look like ``{"groups": [{"images": [...]}, ...]}``.

    Args:
        path: JSON file to read; defaults to ``URL_FILE``.

    Returns:
        List of URLs in file order, or ``[]`` when the file is missing,
        unreadable, not valid JSON, or not shaped as expected.
    """
    if path is None:
        path = URL_FILE
    if not os.path.exists(path):
        return []
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [img for g in data.get("groups", []) for img in g.get("images", [])]
    # OSError: unreadable; ValueError: bad JSON/encoding;
    # AttributeError/TypeError: JSON present but not the expected dict/list shape.
    except (OSError, ValueError, AttributeError, TypeError):
        return []
def get_ordered_groups(urls=None):
    """Return the distinct property-group ids in first-seen order.

    A group id is the last path segment of an image URL before the first
    ``-m``. Non-string entries are grouped under ``"unknown"``.

    Args:
        urls: iterable of image URLs; defaults to ``load_all_urls()``.
    """
    if urls is None:
        urls = load_all_urls()
    ordered = {}  # dict preserves insertion order -> ordered de-duplication
    for u in urls:
        try:
            gid = u.split("-m")[0].split("/")[-1]
        except AttributeError:  # non-string entry in the URL file
            gid = "unknown"
        ordered.setdefault(gid, None)
    return list(ordered)
def get_clean_df(filepath):
    """Load a CSV and normalize its label/score columns.

    - ``label`` / ``corrected_label``: stringified, stripped, lower-cased.
    - ``score`` / ``corrected_score``: coerced to int, unparseable values -> 0.
    - If a ``url`` column exists, only the last row per URL is kept (the most
      recent answer wins). Files without it (e.g. the skip log) are returned
      un-deduplicated instead of failing.

    Returns an empty DataFrame when the file is missing, empty, or unparseable.
    """
    if not os.path.exists(filepath):
        return pd.DataFrame()
    try:
        df = pd.read_csv(filepath)
    except (OSError, ValueError) as exc:  # EmptyDataError/ParserError subclass ValueError
        print(f"[get_clean_df] could not read {filepath}: {exc}")
        return pd.DataFrame()
    if df.empty:
        return df
    for col in ("label", "corrected_label"):
        if col in df.columns:
            df[col] = df[col].astype(str).str.strip().str.lower()
    for col in ("score", "corrected_score"):
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)
    if 'url' in df.columns:
        df = df.drop_duplicates(subset=['url'], keep='last')
    return df
def get_flagged_groups():
    """Return group ids that have a score-10 image labelled as something other than living_room.

    Groups already fixed in this process (FIXED_IN_SESSION) or listed in
    MANUAL_EXCLUDE are left out. Order follows first appearance in the labels CSV.
    """
    labels = get_clean_df(LABEL_FILE)
    if labels.empty:
        return []
    bad_rows = labels.loc[labels['score'].eq(10) & labels['label'].ne('living_room')]
    ignored = FIXED_IN_SESSION | MANUAL_EXCLUDE
    return [gid for gid in bad_rows['group_id'].unique().tolist() if gid not in ignored]
def get_stats_text():
    """Return the Markdown status line: total, labeled and verified group counts plus pending fixes."""
    def distinct_groups(path):
        frame = get_clean_df(path)
        return 0 if frame.empty else len(frame['group_id'].unique())

    total = len(get_ordered_groups())
    flagged = get_flagged_groups()
    labeled = distinct_groups(LABEL_FILE)
    verified = distinct_groups(VERIFY_FILE)
    if flagged:
        suffix = f" | ⚠️ **Fix:** {len(flagged)}"
    else:
        suffix = " | ✅ Clean"
    return f"**Total:** {total} | **Labeled:** {labeled} | **Verified:** {verified}{suffix}"
def _url_group_id(url):
    """Return the group id embedded in an image URL.

    Same rule as ``get_ordered_groups``: the last path segment before the first
    ``-m``; non-string entries map to ``"unknown"``.
    """
    try:
        return url.split("-m")[0].split("/")[-1]
    except AttributeError:
        return "unknown"


def _pick_target(mode, history, all_ordered, specific_index, move_back):
    """Choose the group id to show next, or None when nothing is left.

    May pop ``history`` when moving back.
    """
    if specific_index is not None:
        if 0 <= specific_index < len(all_ordered):
            return all_ordered[specific_index]
        return None
    if move_back and history:
        # With nothing earlier to go back to, stay on the current group instead
        # of jumping forward to a new candidate.
        if len(history) > 1:
            history.pop()
        return history[-1]
    current_gid = history[-1] if history else None
    if mode == "fix":
        flagged_pool = get_flagged_groups()
        # Prefer a different flagged group; fall back to the current one if it is the only one.
        candidates = [g for g in flagged_pool if g != current_gid] or flagged_pool
        return candidates[0] if candidates else None
    df_l, df_v = get_clean_df(LABEL_FILE), get_clean_df(VERIFY_FILE)
    l_done = set(df_l['group_id'].unique()) if not df_l.empty else set()
    v_done = set(df_v['group_id'].unique()) if not df_v.empty else set()
    if mode == "label":
        candidates = [g for g in all_ordered if g not in l_done]
    elif mode == "verify":
        candidates = [g for g in all_ordered if g in l_done and g not in v_done]
    else:
        candidates = []
    return random.choice(candidates) if candidates else None


def _saved_values(mode, gid):
    """Return ``{url: {"score", "label"[, "is_correct"]}}`` previously stored for ``gid``.

    Every mode starts from the annotations CSV. In verify mode those are the
    values being checked. Verify mode then overlays any earlier verification of
    the same URL.
    """
    values = {}
    labels = get_clean_df(LABEL_FILE)
    if not labels.empty:
        for _, row in labels[labels['group_id'] == gid].iterrows():
            values[row['url']] = {"score": row['score'], "label": row['label']}
    if mode not in ["label", "fix"]:
        checks = get_clean_df(VERIFY_FILE)
        if not checks.empty:
            for _, row in checks[checks['group_id'] == gid].iterrows():
                values[row['url']] = {"is_correct": row['is_correct'], "label": row['corrected_label'], "score": row['corrected_score']}
    return values


def _fetch_thumbnail(url):
    """Download ``url`` and return a PIL thumbnail bounded by THUMB_SIZE, or None on any failure."""
    try:
        res = requests.get(url, timeout=3, headers={'User-Agent': 'Mozilla/5.0'})
        res.raise_for_status()
        img = Image.open(BytesIO(res.content))
        img.thumbnail(THUMB_SIZE)
        return img
    except Exception:  # best effort: one broken image must not break the page
        return None


def render_workspace(mode, history, specific_index=None, move_back=False):
    """Pick a property group and build the Gradio updates that display it.

    Args:
        mode: "label", "verify" or "fix".
        history: list of group ids visited this session; mutated in place
            (popped on Back, appended when a new group is shown).
        specific_index: 0-based position in ``get_ordered_groups()`` to open
            directly (catalog "Go" buttons); out of range shows "Done!".
        move_back: return to the previous group in ``history``.

    Returns:
        dict mapping output components to values / ``gr.update`` objects. When
        no group is available, only the menu, workspace and log box are updated.
    """
    all_ordered = get_ordered_groups()
    target_gid = _pick_target(mode, history, all_ordered, specific_index, move_back)
    if not target_gid:
        return {screen_menu: gr.update(visible=True), screen_work: gr.update(visible=False), log_box: "Done!"}

    # Exact group-id match: a substring test could pull in another group's URLs.
    urls = [u for u in load_all_urls() if _url_group_id(u) == target_gid][:MAX_IMAGES]
    if not history or history[-1] != target_gid:
        history.append(target_gid)
    saved_vals = _saved_values(mode, target_gid)
    with ThreadPoolExecutor(max_workers=MAX_IMAGES) as ex:
        processed_images = list(ex.map(_fetch_thumbnail, urls))

    # A flagged or historical group may no longer be in the URL file; .index() would raise.
    if target_gid in all_ordered:
        target_idx = all_ordered.index(target_gid)
        prop_no = target_idx + 1
    else:
        target_idx, prop_no = 0, "?"
    updates = {
        screen_menu: gr.update(visible=False), screen_work: gr.update(visible=True),
        header_md: f"# {mode.upper()} - Prop #{prop_no} ({target_gid})",
        state_urls: urls, state_hist: history, state_idx: target_idx,
        top_stats: get_stats_text(), log_box: f"Viewing: {target_gid}",
    }
    for i in range(MAX_IMAGES):
        # input_objs is flat: [slider, dropdown, checkbox, markdown] per image slot.
        c_sld, c_drp, c_chk, c_lbl = input_objs[i * 4:i * 4 + 4]
        if i >= len(urls):
            updates[img_objs[i]] = gr.update(visible=False)
            for obj in (c_sld, c_drp, c_chk, c_lbl):
                updates[obj] = gr.update(visible=False)
            continue
        saved = saved_vals.get(urls[i], {})
        v_sc = int(saved.get('score', 5))
        v_lbl = str(saved.get('label', "living_room")).strip().lower()
        drp_val = v_lbl if v_lbl in ROOM_CLASSES else "living_room"  # dropdown only accepts known classes
        updates[img_objs[i]] = gr.update(value=processed_images[i], visible=True)
        if mode in ["label", "fix"]:
            is_err = v_sc == 10 and v_lbl != "living_room"
            updates[c_sld] = gr.update(visible=True, value=v_sc, interactive=True)
            updates[c_drp] = gr.update(visible=True, value=drp_val, interactive=True)
            updates[c_chk] = gr.update(visible=False)
            updates[c_lbl] = gr.update(visible=is_err, value="<span style='color:red'>⚠️ Score 10 Only for Living Room</span>")
        else:
            # Re-show an earlier verification answer; default to "correct" otherwise.
            chk_val = str(saved.get('is_correct', True)).strip().lower() != "false"
            updates[c_sld] = gr.update(visible=True, value=v_sc)
            updates[c_drp] = gr.update(visible=True, value=drp_val)
            updates[c_chk] = gr.update(visible=True, value=chk_val)
            updates[c_lbl] = gr.update(visible=True, value=f"Prev: {v_lbl}")
    return updates
def save_data(mode, history, urls, *args):
    """Append the on-screen answers to the matching CSV, push it, and render the next group.

    Args:
        mode: "label"/"fix" write to the annotations CSV; "verify" writes to
            the verifications CSV.
        history: visited group ids; the last one is the group being saved.
        urls: image URLs currently displayed (one per filled slot).
        *args: flat input values, four per slot in the order
            (score, class, correct-checkbox, markdown).

    Returns:
        The component-update dict from ``render_workspace``, or a status
        message when there is no current group.
    """
    if not history:
        # A bare None does not satisfy the multi-output handler; update just the log box.
        return {log_box: "Nothing to save."}
    gid = history[-1]
    is_label_mode = mode in ["label", "fix"]
    target_file = LABEL_FILE if is_label_mode else VERIFY_FILE
    if mode == "fix":
        FIXED_IN_SESSION.add(gid)  # stop re-offering this group in fix mode
    ts = datetime.now().isoformat()
    rows = []
    for i, u in enumerate(urls):
        sc, lbl, chk = args[i * 4], args[i * 4 + 1], args[i * 4 + 2]
        clean_lbl = str(lbl).strip().lower()
        if is_label_mode:
            rows.append([ts, "user", gid, u, int(sc), clean_lbl])
        else:
            rows.append([ts, "user", gid, u, chk, clean_lbl, int(sc)])
    with FileLock(LOCK_FILE):
        with open(target_file, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerows(rows)
    sync_push_background(target_file, os.path.basename(target_file))
    return render_workspace(mode, history)
def refresh_cat():
    """Return the catalog table: one row per group with its 1-based number, status and id.

    Status precedence: needs fixing, then verified, then labeled, else pending.
    """
    flagged = set(get_flagged_groups())
    labeled_df, verified_df = get_clean_df(LABEL_FILE), get_clean_df(VERIFY_FILE)
    labeled = set() if labeled_df.empty else set(labeled_df['group_id'].unique())
    verified = set() if verified_df.empty else set(verified_df['group_id'].unique())

    def status_of(gid):
        if gid in flagged:
            return "⚠️ Fix Needed"
        if gid in verified:
            return "✅ Verified"
        if gid in labeled:
            return "🔵 Labeled"
        return "⚪ Pending"

    rows = [[pos, status_of(gid), gid] for pos, gid in enumerate(get_ordered_groups(), start=1)]
    return pd.DataFrame(rows, columns=["#", "Status", "ID"])
def _open_catalog_entry(number, mode, history):
    """Open catalog entry ``number`` (1-based "Prop #") in ``mode``.

    A cleared Number box delivers ``None``. Report that instead of letting
    ``int(None)`` raise inside the event handler.
    """
    if number is None:
        return {log_box: "Enter a property number first."}
    return render_workspace(mode, history, int(number) - 1)


with gr.Blocks(theme=gr.themes.Soft(), title="Labeler Pro") as demo:
    # Per-session state: active mode, visited group ids, URLs on screen, catalog index.
    state_mode, state_hist, state_urls, state_idx = gr.State("label"), gr.State([]), gr.State([]), gr.State(0)

    with gr.Row():
        top_stats = gr.Markdown("Loading...")
        btn_home = gr.Button("🏠 Home", size="sm", scale=0)

    with gr.Tabs():
        with gr.Tab("Workspace"):
            # Start menu; hidden while a group is on screen.
            with gr.Group() as screen_menu:
                gr.Markdown("# Property Labeler Pro")
                with gr.Row():
                    b_start_l, b_start_v, b_start_f = gr.Button("Label", variant="primary"), gr.Button("Verify"), gr.Button("🛠 Fix Errors", variant="secondary")
            # Work screen: MAX_IMAGES slots, each an image plus four inputs.
            with gr.Group(visible=False) as screen_work:
                header_md = gr.Markdown()
                # input_objs is flat: [slider, dropdown, checkbox, markdown] per slot,
                # matching the i * 4 indexing in render_workspace and save_data.
                img_objs, input_objs = [], []
                with gr.Row():
                    for i in range(MAX_IMAGES):
                        with gr.Column(min_width=200):
                            img = gr.Image(interactive=False, height=240)
                            sld, drp, chk, lbl = gr.Slider(1, 10, step=1, label="Score"), gr.Dropdown(ROOM_CLASSES, label="Class"), gr.Checkbox(label="Correct?"), gr.Markdown()
                            img_objs.append(img)
                            input_objs.extend([sld, drp, chk, lbl])
                with gr.Row():
                    b_back, b_save = gr.Button("⬅ Back"), gr.Button("💾 Save & Next", variant="primary")
            # Kept outside screen_work so messages such as "Done!" stay visible on the menu.
            log_box = gr.Textbox(label="Status", interactive=False)
        with gr.Tab("Catalog"):
            with gr.Row():
                num_in = gr.Number(value=1, label="Prop #", precision=0)
                b_go_l, b_go_v, b_go_f = gr.Button("Go Label"), gr.Button("Go Verify"), gr.Button("Go Fix")
            df_cat = gr.Dataframe(interactive=False)
            b_ref_cat = gr.Button("Refresh Catalog")

    # Every component render_workspace may update, in a fixed order.
    ALL_IO = [screen_menu, screen_work, header_md, state_urls, state_hist, state_idx, top_stats, log_box] + img_objs + input_objs

    # Mode buttons store the mode first, then render with it.
    b_start_l.click(lambda: "label", None, state_mode).then(render_workspace, [state_mode, state_hist], ALL_IO)
    b_start_v.click(lambda: "verify", None, state_mode).then(render_workspace, [state_mode, state_hist], ALL_IO)
    b_start_f.click(lambda: "fix", None, state_mode).then(render_workspace, [state_mode, state_hist], ALL_IO)
    b_save.click(save_data, [state_mode, state_hist, state_urls] + input_objs, ALL_IO)
    b_back.click(lambda m, h: render_workspace(m, h, move_back=True), [state_mode, state_hist], ALL_IO)
    btn_home.click(lambda: {screen_menu: gr.update(visible=True), screen_work: gr.update(visible=False), state_hist: []}, None, [screen_menu, screen_work, state_hist])
    b_go_l.click(lambda: "label", None, state_mode).then(_open_catalog_entry, [num_in, state_mode, state_hist], ALL_IO)
    b_go_v.click(lambda: "verify", None, state_mode).then(_open_catalog_entry, [num_in, state_mode, state_hist], ALL_IO)
    b_go_f.click(lambda: "fix", None, state_mode).then(_open_catalog_entry, [num_in, state_mode, state_hist], ALL_IO)
    b_ref_cat.click(refresh_cat, None, df_cat).then(get_stats_text, None, top_stats)
    demo.load(refresh_cat, None, df_cat).then(get_stats_text, None, top_stats)

demo.queue().launch(server_name="0.0.0.0", server_port=7860)