# LabelingApp / app.py
# Last commit: "Update app.py" (6b20f6d, verified) — MichaelDeutges
# app.py
import os
import glob
from pathlib import Path
from datetime import datetime, timezone
import streamlit as st
import pandas as pd
from PIL import Image
from filelock import FileLock
from huggingface_hub import HfApi
# ==========================================
# Configuration (read from the environment)
# ==========================================
# Folder that is scanned for images, and the CSV that collects the labels.
IMAGE_DIR = os.environ.get("IMAGE_DIR", "images")
LABELS_CSV = os.environ.get("LABELS_CSV", "labels.csv")

# Filename suffixes (lower-case) that count as images.
SUPPORTED_EXTS = (
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".gif",
    ".tif",
    ".tiff",
    ".webp",
)

# Label values stored in the CSV; the first three are also shown on buttons.
LABEL_NONBLAST = os.environ.get("LABEL_NONBLAST", "NON-BLAST")
LABEL_BLAST = os.environ.get("LABEL_BLAST", "BLAST")
LABEL_UNCERTAIN = os.environ.get("LABEL_UNCERTAIN", "UNCERTAIN")
LABEL_TRASH = os.environ.get("LABEL_TRASH", "LOW_QUALITY")

# Optional Hub sync settings.
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # set in Space Secrets
SPACE_ID = os.environ.get("SPACE_ID", "")  # e.g. "org-or-user/your-space"
DATASET_REPO = os.environ.get("DATASET_REPO", "")  # e.g. "org-or-user/blast-labels" (recommended)

# Shared Hub client; the token is supplied per upload call.
api = HfApi()

st.set_page_config(page_title="Blast Cell Labeling", layout="centered")
# ==========
# Utilities
# ==========
def list_images(
    image_dir: "str | None" = None,
    exts: "tuple[str, ...] | None" = None,
) -> list[str]:
    """Return the sorted paths of all image files below a directory.

    The directory is searched recursively. A path qualifies when it is a
    regular file and its lower-cased name ends with one of the accepted
    suffixes.

    Args:
        image_dir: Directory to scan. Defaults to ``IMAGE_DIR``.
        exts: Lower-case filename suffixes to accept. Defaults to
            ``SUPPORTED_EXTS``.

    Returns:
        The matching paths, sorted as strings. Empty when the directory does
        not exist or holds no matching file.
    """
    root = IMAGE_DIR if image_dir is None else image_dir
    suffixes = SUPPORTED_EXTS if exts is None else tuple(exts)
    pattern = os.path.join(root, "**", "*")
    return sorted(
        path
        for path in glob.iglob(pattern, recursive=True)
        # The glob also yields directories; one named like "scans.png" must
        # not be offered as an image.
        if path.lower().endswith(suffixes) and os.path.isfile(path)
    )
def read_labels(path: "str | None" = None) -> pd.DataFrame:
    """Load the recorded labels, with every column read as text.

    Args:
        path: CSV file to read. Defaults to ``LABELS_CSV``.

    Returns:
        The CSV contents as a DataFrame of strings. When the file is missing
        or cannot be read, an empty frame with the columns ``image``,
        ``label``, ``annotator`` and ``timestamp`` is returned instead.
    """
    import warnings  # local import keeps this block self-contained

    csv_path = LABELS_CSV if path is None else path
    if os.path.exists(csv_path):
        try:
            # No type inference: it would turn an annotator called "007" into
            # the number 7 and one called "NA" into NaN, so later comparisons
            # against the typed name would never match.
            return pd.read_csv(csv_path, dtype=str, keep_default_na=False)
        except (OSError, ValueError) as exc:
            # pandas' parser and empty-file errors are ValueError subclasses.
            # Stay best-effort, but leave a trace instead of failing silently.
            warnings.warn(
                f"Could not read {csv_path!r}: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
    return pd.DataFrame(columns=["image", "label", "annotator", "timestamp"])
def rel_to_image_dir(p: str) -> str:
    """Express *p* relative to ``IMAGE_DIR`` when possible.

    Both paths are resolved before one is made relative to the other. If any
    step fails (for instance because *p* does not lie below ``IMAGE_DIR``),
    the input is handed back unchanged.
    """
    try:
        candidate = Path(p).resolve()
        base = Path(IMAGE_DIR).resolve()
        inside = candidate.relative_to(base)
    except Exception:
        # Best effort: fall back to exactly what the caller passed in.
        return p
    return str(inside)
def write_label(image_path: str, label: str, annotator: str) -> None:
    """Append one label record to ``LABELS_CSV`` (local file only).

    The row holds the image path relative to ``IMAGE_DIR``, the label, the
    annotator and a UTC ISO-8601 timestamp. The append runs under a file lock
    (``LABELS_CSV + ".lock"``, 10 second timeout). Afterwards the session is
    flagged as having unsynced changes.

    Args:
        image_path: Path of the image that was labeled.
        label: Label value to store.
        annotator: Name stored with the record.
    """
    parent = os.path.dirname(LABELS_CSV)
    if parent:
        os.makedirs(parent, exist_ok=True)

    record = {
        "image": rel_to_image_dir(image_path),
        "label": label,
        "annotator": annotator,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    with FileLock(LABELS_CSV + ".lock", timeout=10):
        # A missing *or zero-length* file has no header row yet. Checking only
        # for existence would leave an empty file headerless, and the first
        # record would later be read back as the column names.
        needs_header = (
            not os.path.exists(LABELS_CSV) or os.path.getsize(LABELS_CSV) == 0
        )
        pd.DataFrame([record]).to_csv(
            LABELS_CSV, mode="a", header=needs_header, index=False
        )

    st.session_state["unsynced"] = True  # mark as needing sync
def sync_to_hub() -> tuple[bool, str]:
    """Upload ``LABELS_CSV`` to the Hub on request.

    The dataset repo (``DATASET_REPO``) is used when configured; pushing there
    is preferred because it avoids Space restarts. Otherwise the Space repo
    (``SPACE_ID``) is used as a fallback.

    Returns:
        ``(ok, message)``: ``ok`` is True only after a successful upload, and
        ``message`` is the text to show to the user.
    """
    if not os.path.exists(LABELS_CSV):
        return False, "labels.csv not found — nothing to sync."
    if not HF_TOKEN:
        return False, "HF_TOKEN not set in Space secrets."

    if DATASET_REPO:
        # Recommended target: does NOT restart the Space.
        destination = (DATASET_REPO, "dataset")
    elif SPACE_ID:
        # Fallback target: can trigger a restart.
        destination = (SPACE_ID, "space")
    else:
        return False, "Set DATASET_REPO (preferred) or SPACE_ID to enable uploads."

    repo_id, repo_type = destination
    try:
        api.upload_file(
            path_or_fileobj=LABELS_CSV,
            path_in_repo="labels.csv",
            repo_id=repo_id,
            repo_type=repo_type,
            token=HF_TOKEN,
        )
        st.session_state["unsynced"] = False
    except Exception as e:
        return False, f"Sync failed: {e}"
    return True, "Synced labels.csv to the Hub."
# ================
# Header / Sidebar
# ================
st.title("🏷️ Blast Cell Labeling App")
intro_text = (
    "Enter your name, click **Start**, then classify each image. "
    f"Use **{LABEL_UNCERTAIN}** if you’re not sure. "
    f"Use the 🗑️ icon for **{LABEL_TRASH}** images."
)
st.write(intro_text)

# NOTE(review): the toggle and the sync button are assumed to live in the
# sidebar together with the name box — confirm against the deployed layout.
with st.sidebar:
    # Pre-fill the name box with the annotator remembered in this session.
    annotator = st.text_input(
        "Your name*",
        value=st.session_state.get("annotator", ""),
        placeholder="e.g., Dr. Smith",
    )

    # Start and Reset sit side by side.
    start_col, reset_col = st.columns(2)
    with start_col:
        start_btn = st.button("Start", type="primary")
    with reset_col:
        reset_btn = st.button("🔄 Reset Session")

    continue_by_name = st.toggle(
        "Only show images I haven't labeled yet",
        value=True,
        help="Continue where you left off, based on your name in labels.csv",
    )
    sync_now = st.button("📤 Sync to Hub", help="Upload labels.csv to the Hub")
# ======================
# Session state defaults
# ======================
# Seed every key the script reads further down; keys that already hold a
# value are left untouched.
_SESSION_DEFAULTS = {
    "order": [],  # image paths queued for this run
    "idx": 0,  # position of the current image in "order"
    "total": 0,  # number of queued images
    "started": False,  # True once Start succeeded
    "unsynced": False,  # True after a label was written and not yet synced
}
for _key, _initial in _SESSION_DEFAULTS.items():
    st.session_state.setdefault(_key, _initial)
# =========
# Actions
# =========
# Manual upload of labels.csv; the outcome is reported in the sidebar.
if sync_now:
    ok, msg = sync_to_hub()
    (st.sidebar.success if ok else st.sidebar.error)(msg)

# Reset puts the session back into its pre-Start state and reruns the script.
if reset_btn:
    st.session_state.update({"started": False, "order": [], "idx": 0, "total": 0})
    st.rerun()

# Start validates the name, builds the image queue and opens the session.
if start_btn:
    annotator_name = annotator.strip()
    if not annotator_name:
        st.sidebar.error("Please enter your name.")
    else:
        st.session_state["annotator"] = annotator_name
        imgs = list_images()
        labels_df = read_labels()
        has_columns = {"annotator", "image"}.issubset(labels_df.columns)
        if continue_by_name and has_columns and not labels_df.empty:
            # Labels are written under the stripped name, so the lookup must
            # use the stripped name too; comparing against the raw text box
            # value would miss every row when the input carries stray spaces.
            # Comparing as text also works if the column was parsed as numbers.
            mine = labels_df["annotator"].astype(str) == annotator_name
            already = set(labels_df.loc[mine, "image"].astype(str))
            imgs = [p for p in imgs if rel_to_image_dir(p) not in already]
        if not imgs:
            st.warning("No images found (or all labeled). Upload images to the `images/` folder.")
        else:
            st.session_state.update(
                {"order": imgs, "idx": 0, "total": len(imgs), "started": True}
            )
# =========
# Main area
# =========
if not st.session_state.started:
    st.info("Fill your name on the left and press **Start**.")
else:
    idx = st.session_state.idx
    total = st.session_state.total
    if idx >= total:
        st.success("All done 🎉 Thank you!")
    else:
        current_image = st.session_state.order[idx]

        # Progress is shown against the whole image set, so images skipped as
        # "already labeled" count as done. The size is counted from disk
        # rather than assumed, and recounted at the start of every pass.
        if idx == 0 or "dataset_size" not in st.session_state:
            st.session_state["dataset_size"] = len(list_images())
        dataset_size = max(st.session_state["dataset_size"], total)
        st.caption(f"{dataset_size - total + idx + 1} / {dataset_size}")

        # Name stored with each label: the sidebar value, or, if the box was
        # cleared mid-session, the name captured at Start, so that no record
        # is written with an empty annotator.
        labeler = annotator.strip() or st.session_state.get("annotator", "")

        # Image display.
        try:
            # The context manager closes the underlying file after rendering.
            with Image.open(current_image) as img:
                if getattr(img, "n_frames", 1) > 1:
                    img.seek(0)  # multi-frame file: show the first frame
                frame = img if img.mode in ("RGB", "RGBA") else img.convert("RGB")
                st.image(frame, use_container_width=True)
        except Exception as e:
            st.warning(f"Could not display image: {current_image}\n{e}")

        # One equally wide column per choice:
        # NON-BLAST | UNCERTAIN | BLAST | low-quality (trash).
        choices = (
            (f"⬅️ {LABEL_NONBLAST}", LABEL_NONBLAST, None),
            (f"❓ {LABEL_UNCERTAIN}", LABEL_UNCERTAIN, None),
            (f"{LABEL_BLAST} ➡️", LABEL_BLAST, None),
            ("🗑️", LABEL_TRASH, f"Mark as {LABEL_TRASH}"),
        )
        columns = st.columns([1] * len(choices))
        for column, (caption, label_value, tooltip) in zip(columns, choices):
            with column:
                if st.button(caption, help=tooltip, use_container_width=True):
                    # Record the decision, advance to the next image, redraw.
                    write_label(current_image, label_value, labeler)
                    st.session_state.idx += 1
                    st.rerun()
# ======
# Footer
# ======
st.divider()
# Reminder that labels only leave the local CSV when the user syncs.
st.caption("Use **📤 Sync to Hub** to save your progress.")