# data-collection / app.py
# Nightfury16 — "Initial commit" (de17047)
import gradio as gr
import pandas as pd
import requests
import os
import csv
from io import BytesIO
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from filelock import FileLock
# Where output files live: the "/data" directory when it exists, otherwise
# the current working directory.
DATA_DIR = "/data" if os.path.exists("/data") else "."

# Input list of image URLs (one per line). Note: NOT placed under DATA_DIR.
URL_FILE = "urls.txt"

# Append-only result CSVs, plus the lock file that serializes writes to them.
LABEL_FILE, VERIFY_FILE, SKIP_FILE, LOCK_FILE = (
    os.path.join(DATA_DIR, file_name)
    for file_name in ("annotations.csv", "verifications.csv", "skipped.csv", "data.lock")
)

# Image slots shown per property, and the bounding box (pixels) thumbnails
# are shrunk to fit.
MAX_IMAGES = 6
THUMB_SIZE = (350, 350)

# Options offered by each image's "Class" dropdown.
ROOM_CLASSES = [
    "living_room",
    "bedroom",
    "kitchen",
    "bathroom",
    "dining_room",
    "outdoor",
    "other",
]
def init_files():
    """Create any data file that does not exist yet; leave existing ones alone.

    Each result CSV is created containing only its header row, and a missing
    URL list is created as an empty file.
    """
    header_for = {
        LABEL_FILE: ["timestamp", "user", "group_id", "url", "score", "label"],
        VERIFY_FILE: ["timestamp", "user", "group_id", "url", "is_correct", "corrected_label"],
        SKIP_FILE: ["timestamp", "user", "group_id"],
    }
    for csv_path, columns in header_for.items():
        if os.path.exists(csv_path):
            continue
        pd.DataFrame(columns=columns).to_csv(csv_path, index=False)

    if not os.path.exists(URL_FILE):
        # Opening in "w" mode is enough to create the empty file.
        with open(URL_FILE, "w"):
            pass


# Runs at import time so every later read finds the files in place.
init_files()
# Browser-like User-Agent sent with every image request.
_REQUEST_HEADERS = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}


def get_image_optimized(url, timeout=3):
    """Download *url* and return a thumbnail, falling back to the URL itself.

    Args:
        url: Image URL. Falsy values (``None``, ``""``) short-circuit.
        timeout: Seconds passed to ``requests.get`` (default 3, as before).

    Returns:
        ``None`` when *url* is falsy. A PIL image shrunk in place to fit within
        ``THUMB_SIZE`` when the download returns HTTP 200 and decodes.
        Otherwise *url* unchanged, so the caller still has something to give
        the image component.
    """
    if not url:
        return None
    try:
        response = requests.get(url, headers=_REQUEST_HEADERS, timeout=timeout)
        if response.status_code == 200:
            img = Image.open(BytesIO(response.content))
            # thumbnail() decodes the image and resizes it in place,
            # preserving the aspect ratio.
            img.thumbnail(THUMB_SIZE, Image.Resampling.LANCZOS)
            return img
    except Exception:
        # Best effort: network, HTTP and decode errors all fall back to the
        # raw URL. Unlike a bare `except:`, this lets KeyboardInterrupt and
        # SystemExit propagate.
        pass
    return url
def get_ordered_groups(path=None):
    """Return the distinct group IDs found in the URL list, in first-seen order.

    A URL's group ID is the last ``/``-separated piece of everything before
    the first ``"-m"``; e.g. ``"https://h/p/12345-m1.jpg"`` -> ``"12345"``.
    Blank and whitespace-only lines are ignored.

    Args:
        path: URL list to read; defaults to ``URL_FILE``.

    Returns:
        List of group-ID strings; empty if the file does not exist.
    """
    if path is None:
        path = URL_FILE
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        gids = (u.split("-m")[0].split("/")[-1] for u in stripped if u)
        # dict keeps insertion order: de-duplicates while preserving the
        # first occurrence of each ID. (str.split never raises, so the old
        # try/except fallback to "unknown" was unreachable.)
        return list(dict.fromkeys(gids))
def get_group_urls(target_gid, path=None, limit=None):
    """Return up to *limit* URLs belonging to group *target_gid*, in file order.

    Args:
        target_gid: Group ID to collect.
        path: URL list to read; defaults to ``URL_FILE``.
        limit: Maximum number of URLs returned; defaults to ``MAX_IMAGES``.

    Returns:
        List of stripped URL strings; empty if none match or the file is
        missing.
    """
    if path is None:
        path = URL_FILE
    if limit is None:
        limit = MAX_IMAGES
    if not os.path.exists(path):
        return []
    urls = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            u = line.strip()
            # Compare the parsed group ID exactly, using the same parsing
            # rule as get_ordered_groups. A substring test (`target_gid in u`)
            # would also pull in e.g. group "123" when asked for "12".
            if u and u.split("-m")[0].split("/")[-1] == target_gid:
                urls.append(u)
                if len(urls) == limit:
                    break  # enough for the UI; no need to read further
    return urls[:limit]
def get_saved_values(gid, mode, path=None):
    """Return the previously saved per-image answers for group *gid*.

    Args:
        gid: Group ID, as a string (the form ``get_ordered_groups`` yields).
        mode: ``"label"`` reads the label CSV; any other value reads the
            verification CSV.
        path: CSV to read instead of the mode's default file.

    Returns:
        ``{url: {"score": ..., "label": ...}}`` in label mode, or
        ``{url: {"is_correct": ..., "label": ...}}`` otherwise. Rows are
        appended over time, so a later row for a URL overwrites an earlier
        one. Returns ``{}`` when the file is missing or unreadable.
    """
    if path is None:
        path = LABEL_FILE if mode == "label" else VERIFY_FILE
    saved_data = {}
    if not os.path.exists(path):
        return saved_data
    try:
        # Read group_id as text. For all-digit IDs pandas would otherwise
        # infer int64, and int64 == "123" is always False, so nothing would
        # ever match.
        df = pd.read_csv(path, dtype={"group_id": str})
        rows = df[df["group_id"] == str(gid)]
        for _, row in rows.iterrows():
            if mode == "label":
                saved_data[row["url"]] = {"score": row["score"], "label": row["label"]}
            else:
                saved_data[row["url"]] = {"is_correct": row["is_correct"], "label": row["corrected_label"]}
    except (OSError, ValueError, KeyError):
        # Best effort: an unreadable/corrupt file (pandas parser errors are
        # ValueErrors) or missing columns means "nothing saved yet".
        pass
    return saved_data
def get_stats_text():
    """Return the Markdown summary shown in the top bar.

    "Total Properties" is the number of groups in the URL list. "Labeled" is
    the number of distinct group IDs in the label CSV, or 0 if that file
    cannot be read.
    """
    all_gids = get_ordered_groups()
    try:
        # group_id as text, so IDs like "007" and "7" stay distinct.
        labeled = len(pd.read_csv(LABEL_FILE, dtype={"group_id": str})["group_id"].unique())
    except (OSError, ValueError, KeyError):
        labeled = 0
    return f"**Total Properties:** {len(all_gids)} | **Labeled:** {labeled}"
def _read_group_ids(path):
    """Return the set of group IDs (as strings) recorded in CSV *path*.

    IDs are read as text so they compare equal to the string IDs from
    get_ordered_groups. pandas would otherwise parse all-digit IDs as int64,
    and ``"123" in {123}`` is False. A missing or unreadable file yields an
    empty set.
    """
    try:
        return set(pd.read_csv(path, dtype={"group_id": str})["group_id"].unique())
    except (OSError, ValueError, KeyError):
        return set()


def _fetch_thumbnails(urls):
    """Fetch up to MAX_IMAGES *urls* concurrently; result padded with None to MAX_IMAGES.

    Entry i is whatever get_image_optimized returned for urls[i].
    """
    images = [None] * MAX_IMAGES
    if urls:
        with ThreadPoolExecutor(max_workers=MAX_IMAGES) as executor:
            # executor.map yields results in input order.
            for i, result in enumerate(executor.map(get_image_optimized, urls[:MAX_IMAGES])):
                images[i] = result
    return images


def render_workspace(mode, history, specific_index=None, move_back=False):
    """Pick a property group and build the Gradio updates that display it.

    The target is chosen in this order:
      1. *specific_index* (0-based into get_ordered_groups()). If it is out
         of range, only the log box is updated with "End of list.".
      2. *move_back* with at least two history entries: drop the current
         group and show the previous one.
      3. Otherwise the first non-skipped group still pending for *mode*:
         "label" means not yet labeled; "verify" means labeled but not yet
         verified. If none remain, switch back to the menu screen.

    *history* (the list held in ``state_hist``) is mutated in place. It is
    popped when moving back, and the shown group is appended unless it is
    already the last entry.

    Returns:
        dict mapping output components to values / ``gr.update`` objects.
    """
    all_groups = get_ordered_groups()
    total_groups = len(all_groups)
    labeled = _read_group_ids(LABEL_FILE)
    verified = _read_group_ids(VERIFY_FILE)
    skipped = _read_group_ids(SKIP_FILE)

    target_gid = None
    target_idx = -1
    if specific_index is not None:
        if not 0 <= specific_index < total_groups:
            return {log_box: "End of list."}
        target_idx = specific_index
        target_gid = all_groups[target_idx]
    elif move_back and len(history) > 1:
        history.pop()
        target_gid = history[-1]
        try:
            target_idx = all_groups.index(target_gid)
        except ValueError:
            # Group no longer in the URL list; fall back to index 0 for the header.
            target_idx = 0
    else:
        for i, gid in enumerate(all_groups):
            if gid in skipped:
                continue
            if mode == "label":
                pending = gid not in labeled
            elif mode == "verify":
                pending = gid in labeled and gid not in verified
            else:
                pending = False
            if pending:
                target_gid, target_idx = gid, i
                break
        else:
            return {screen_menu: gr.update(visible=True), screen_work: gr.update(visible=False), log_box: "No more tasks found."}

    urls = get_group_urls(target_gid)
    if not history or history[-1] != target_gid:
        history.append(target_gid)
    saved_vals = get_saved_values(target_gid, mode)
    # Round-1 labels are shown alongside each image while verifying.
    r1_vals = get_saved_values(target_gid, "label") if mode == "verify" else {}
    processed_images = _fetch_thumbnails(urls)

    header = f"# Property #{target_idx + 1} <span style='font-size:14px;color:gray;'>(ID: {target_gid})</span>"
    updates = {
        screen_menu: gr.update(visible=False),
        screen_work: gr.update(visible=True),
        header_md: header,
        state_urls: urls,
        state_hist: history,
        state_idx: target_idx,
        top_stats: get_stats_text(),
        log_box: f"Loaded group {target_gid}"
    }
    for i in range(MAX_IMAGES):
        img_c = img_objs[i]
        # input_objs stores four components per slot: slider, dropdown, checkbox, markdown.
        c_sld, c_drp, c_chk, c_lbl = input_objs[i * 4:i * 4 + 4]
        if i >= len(urls):
            # Unused slot: hide the image and all of its inputs.
            updates[img_c] = gr.update(value=None, visible=False)
            for comp in (c_sld, c_drp, c_chk, c_lbl):
                updates[comp] = gr.update(visible=False)
            continue

        u = urls[i]
        saved = saved_vals.get(u, {})
        updates[img_c] = gr.update(value=processed_images[i], visible=True)
        if mode == "label":
            updates[c_sld] = gr.update(visible=True, value=saved.get("score", 5))
            updates[c_drp] = gr.update(visible=True, value=saved.get("label", ROOM_CLASSES[0]))
            updates[c_chk] = gr.update(visible=False)
            updates[c_lbl] = gr.update(visible=False)
        else:
            prev = r1_vals.get(u, {}).get("label", "Unknown")
            # Until this image has been verified, pre-select its round-1
            # label. Then saving with "Correct?" ticked records a matching
            # corrected_label rather than ROOM_CLASSES[0].
            default_lbl = prev if prev in ROOM_CLASSES else ROOM_CLASSES[0]
            updates[c_sld] = gr.update(visible=False)
            updates[c_lbl] = gr.update(visible=True, value=f"**Labeled:** {prev}")
            updates[c_drp] = gr.update(visible=True, value=saved.get("label", default_lbl))
            updates[c_chk] = gr.update(visible=True, value=saved.get("is_correct", True))
    return updates
def save_data(mode, history, urls, *args):
    """Append one CSV row per displayed image of the current group, then advance.

    Args:
        mode: ``"label"`` appends (score, label) rows to LABEL_FILE; any other
            value appends (is_correct, corrected label) rows to VERIFY_FILE.
        history: Visited group IDs; the last entry is the group being saved.
        urls: Image URLs of that group, in display order.
        *args: Values of ``input_objs``, four per image slot in the order
            (score slider, class dropdown, correct checkbox, label markdown).

    Returns:
        The render_workspace updates for the next pending group, or a
        log-only update when no group is loaded.
    """
    if not history:
        # A bare None does not provide values for this event's many outputs.
        # A non-empty dict updates only the listed component.
        return {log_box: "Nothing to save: no property loaded."}
    gid = history[-1]
    ts = datetime.now().isoformat()
    rows = []
    for i, u in enumerate(urls):
        score, label, is_correct = args[i * 4:i * 4 + 3]
        # Column 5 is `score` in the label CSV and `is_correct` in the
        # verification CSV; column 6 is the (corrected) class in both.
        value = score if mode == "label" else is_correct
        rows.append([ts, "user", gid, u, value, label])
    fname = LABEL_FILE if mode == "label" else VERIFY_FILE
    with FileLock(LOCK_FILE):
        with open(fname, "a", newline="") as f:
            csv.writer(f).writerows(rows)
    return render_workspace(mode, history)
def skip_group(idx, history, mode):
    """Mark the current group as skipped and move to the next catalog position.

    When *history* is non-empty, its last entry is appended to SKIP_FILE while
    holding the shared file lock. Navigation then goes to index ``idx + 1``,
    whatever that group's status.
    """
    if history:
        with FileLock(LOCK_FILE), open(SKIP_FILE, "a", newline="") as handle:
            csv.writer(handle).writerow([datetime.now().isoformat(), "user", history[-1]])
    next_index = idx + 1
    return render_workspace(mode, history, specific_index=next_index)
def refresh_cat():
    """Build the Catalog table: one row per group with its review status.

    Status is "✅ Verified" if the group appears in VERIFY_FILE. Otherwise it
    is "🔵 Labeled" if the group appears in LABEL_FILE, else "⚪ Pending".
    Rows are numbered from 1 in URL-list order, matching the "Property #"
    input.
    """
    all_gids = get_ordered_groups()

    def _ids(path):
        # Read group_id as text so all-digit IDs still match the string IDs
        # from get_ordered_groups; pandas would otherwise infer int64 and
        # every group would show as Pending.
        try:
            return set(pd.read_csv(path, dtype={"group_id": str})["group_id"].unique())
        except (OSError, ValueError, KeyError):
            return set()

    l_set = _ids(LABEL_FILE)
    v_set = _ids(VERIFY_FILE)
    data = []
    for number, gid in enumerate(all_gids, start=1):
        if gid in v_set:
            status = "✅ Verified"
        elif gid in l_set:
            status = "🔵 Labeled"
        else:
            status = "⚪ Pending"
        data.append([number, status, gid])
    return pd.DataFrame(data, columns=["#", "Status", "ID"])
# Gradio UI: layout and event wiring. Component creation order defines the
# on-screen layout.
with gr.Blocks(title="Fast Labeler") as demo:
    # Per-session state: current mode ("label"/"verify"), visited group IDs,
    # URLs on screen, and the shown group's index into get_ordered_groups().
    state_mode = gr.State("label")
    state_hist = gr.State([])
    state_urls = gr.State([])
    state_idx = gr.State(0)
    with gr.Row(variant="panel"):
        top_stats = gr.Markdown("Loading stats...")
        btn_home = gr.Button("🏠 Home", size="sm", scale=0)
    with gr.Tabs():
        with gr.Tab("Workspace", id=0):
            # Start menu; hidden while a property is being worked on.
            with gr.Group() as screen_menu:
                gr.Markdown("# Welcome! 👋")
                gr.Markdown("Server-side compression enabled.")
                with gr.Row():
                    b_start_l = gr.Button("Start Labeling", variant="primary")
                    b_start_v = gr.Button("Start Verification")
            # Work screen: one column per image slot (MAX_IMAGES of them).
            with gr.Group(visible=False) as screen_work:
                header_md = gr.Markdown()
                img_objs, input_objs = [], []
                with gr.Row():
                    for i in range(MAX_IMAGES):
                        with gr.Column(min_width=250):
                            img = gr.Image(interactive=False, height=280, show_label=False)
                            with gr.Group():
                                sld = gr.Slider(1, 10, step=1, label="Score", visible=False)
                                lbl = gr.Markdown(visible=False)
                                drp = gr.Dropdown(ROOM_CLASSES, label="Class", visible=False)
                                chk = gr.Checkbox(label="Correct?", value=True, visible=False)
                            img_objs.append(img)
                            # Four inputs per slot in this fixed order; save_data
                            # and render_workspace index input_objs as i * 4 + k.
                            input_objs.extend([sld, drp, chk, lbl])
                with gr.Row():
                    b_back = gr.Button("⬅ Back")
                    b_skip = gr.Button("Skip ➡")
                    b_save = gr.Button("💾 Save & Next", variant="primary")
            # NOTE(review): nesting reconstructed from stripped source. The log
            # is kept outside screen_work so "No more tasks found." is still
            # visible on the menu screen — confirm intended placement.
            log_box = gr.Textbox(label="Log", interactive=False, max_lines=1)
        with gr.Tab("Catalog", id=1):
            with gr.Row():
                num_in = gr.Number(value=1, label="Property #", precision=0)
                b_go_l = gr.Button("Go (Label)")
                b_go_v = gr.Button("Go (Verify)")
            df_cat = gr.Dataframe(interactive=False)
            b_ref_cat = gr.Button("Refresh")
    # Every component that render_workspace may put in its returned dict must
    # appear in this outputs list.
    ALL_IO = [screen_menu, screen_work, header_md, state_urls, state_hist, state_idx, top_stats, log_box] + img_objs + input_objs
    # Set the mode first, then render the first pending group for it.
    b_start_l.click(lambda: "label", None, state_mode).then(render_workspace, [state_mode, state_hist], ALL_IO)
    b_start_v.click(lambda: "verify", None, state_mode).then(render_workspace, [state_mode, state_hist], ALL_IO)
    b_save.click(save_data, [state_mode, state_hist, state_urls] + input_objs, ALL_IO)
    b_skip.click(skip_group, [state_idx, state_hist, state_mode], ALL_IO)
    b_back.click(lambda m, h: render_workspace(m, h, move_back=True), [state_mode, state_hist], ALL_IO)
    # Home: show the menu and reset the visited-history state.
    btn_home.click(lambda: {screen_menu: gr.update(visible=True), screen_work: gr.update(visible=False), state_hist: []}, None, [screen_menu, screen_work, state_hist])
    # Catalog jumps: "Property #" is 1-based, render_workspace indexes from 0.
    # NOTE(review): int(n) raises if the Number box is cleared (n is None) — confirm.
    b_go_l.click(lambda: "label", None, state_mode).then(lambda n, m, h: render_workspace(m, h, specific_index=int(n)-1), [num_in, state_mode, state_hist], ALL_IO)
    b_go_v.click(lambda: "verify", None, state_mode).then(lambda n, m, h: render_workspace(m, h, specific_index=int(n)-1), [num_in, state_mode, state_hist], ALL_IO)
    b_ref_cat.click(refresh_cat, None, df_cat)
    # On page load: fill the catalog table, then the top-bar stats.
    demo.load(refresh_cat, None, df_cat).then(get_stats_text, None, top_stats)
# NOTE(review): passing `theme` to launch() assumes a Gradio version that
# accepts it there (older versions take it on gr.Blocks) — confirm.
demo.queue().launch(theme=gr.themes.Soft())