"""Gradio app for manually auditing filtered image/prompt samples.

Records are read from ``audit_pack/audit_samples.json``. Each annotator's
judgments are stored as one JSON object per line in
``annotations/<annotator>.jsonl`` and, when ``HF_TOKEN`` is set, uploaded to
the Hugging Face dataset repo named by ``ANNOTATION_DATASET_REPO``.
"""
from __future__ import annotations

import json
import os
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr
from huggingface_hub import HfApi
from PIL import Image

ROOT = Path(__file__).resolve().parent
AUDIT_PACK = ROOT / "audit_pack"
DATA_FILE = AUDIT_PACK / "audit_samples.json"
ANNOTATION_DIR = ROOT / "annotations"
DATASET_REPO_ID = os.environ.get("ANNOTATION_DATASET_REPO", "nickname-xingxing/filter_data")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API = HfApi(token=HF_TOKEN) if HF_TOKEN else None
# Number of unsynced saves after which save_annotation uploads automatically.
AUTO_SYNC_EVERY = 10
# Prefix of the message returned by sync_annotation_to_hf on success; callers
# test for it to decide whether the unsynced counter may be reset.
SYNC_OK_PREFIX = "已同步到 HF dataset"


def sanitize_annotator_name(name: str | None) -> str:
    """Return *name* stripped and made safe to use as a single file name.

    The annotator name is typed by the user and is interpolated into both a
    local path and a repo path, so path separators and control characters are
    replaced with ``_`` to keep the result inside the annotation directory.
    An empty string is returned for ``None`` or whitespace-only input.
    """
    who = (name or "").strip()
    return "".join("_" if ch in "/\\" or ord(ch) < 32 else ch for ch in who)


def resolve_image_fs_path(record: dict) -> str:
    """Return the filesystem path for ``record["image_path"]``.

    Absolute paths are returned unchanged. Relative paths are looked up under
    ``AUDIT_PACK`` first (e.g. ``images/foo.png``), then under ``ROOT``; if
    neither exists the ``AUDIT_PACK`` candidate is returned so the caller can
    report it. Returns ``""`` when the record has no image path.
    """
    raw = (record.get("image_path") or "").strip()
    if not raw:
        return ""
    p = Path(raw)
    if p.is_absolute():
        return str(p)
    under_pack = AUDIT_PACK / raw
    if under_pack.exists():
        return str(under_pack)
    legacy = ROOT / raw
    if legacy.exists():
        return str(legacy)
    return str(under_pack)


def load_image_for_ui(record: dict) -> tuple[Image.Image | None, str]:
    """Return ``(image, warning)`` for *record*.

    On success the image is an RGB PIL image and the warning is empty. When
    the path is missing, the file does not exist, or reading raises
    ``OSError``, the image is ``None`` and the warning is non-empty.
    """
    path_str = resolve_image_fs_path(record)
    if not path_str:
        return None, "⚠️ 记录中无 image_path"
    p = Path(path_str)
    if not p.exists():
        return None, f"⚠️ 图片不存在: `{p}`(请确认 `audit_pack/images/` 已随仓库打包)"
    try:
        # The context manager closes the underlying file; convert() returns a
        # separate, fully loaded image that stays valid afterwards.
        with Image.open(p) as im:
            return im.convert("RGB"), ""
    except OSError as e:
        return None, f"⚠️ 无法读取图片 `{p}`: {e}"


def load_records():
    """Load and return the parsed JSON content of ``DATA_FILE``."""
    with DATA_FILE.open("r", encoding="utf-8") as f:
        return json.load(f)


def load_existing_annotations(annotation_file: Path):
    """Read a JSONL annotation file into a dict keyed by ``sample_id``.

    Blank lines are skipped; a later row for the same ``sample_id`` replaces
    an earlier one. Returns an empty dict when the file does not exist.
    """
    annotations = {}
    if not annotation_file.exists():
        return annotations
    with annotation_file.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                row = json.loads(line)
                annotations[row["sample_id"]] = row
    return annotations


def write_annotations(annotation_file: Path, annotations):
    """Write *annotations* to *annotation_file* as JSONL, sorted by sample id.

    The data is written to a sibling ``.tmp`` file, flushed and fsynced, and
    then moved over the target so a failed write does not truncate the
    existing file.
    """
    annotation_file.parent.mkdir(parents=True, exist_ok=True)
    tmp = annotation_file.with_suffix(annotation_file.suffix + ".tmp")
    with tmp.open("w", encoding="utf-8") as f:
        for sample_id in sorted(annotations):
            f.write(json.dumps(annotations[sample_id], ensure_ascii=False) + "\n")
        f.flush()
        os.fsync(f.fileno())
    tmp.replace(annotation_file)


def sync_annotation_to_hf(annotation_file: Path, annotator_name: str) -> str:
    """Upload *annotation_file* to the HF dataset repo and return a status message.

    Requires ``HF_TOKEN`` (configured in the Space secrets); without it, or
    when the local file is missing, nothing is uploaded. The returned message
    starts with ``SYNC_OK_PREFIX`` only when the upload succeeded.
    """
    if not annotation_file.exists():
        return "本地标注文件不存在,未同步到 HF"
    if HF_API is None:
        return "未配置 HF_TOKEN,仅保存在 Space 运行容器中"
    path_in_repo = f"annotations/{annotator_name}.jsonl"
    try:
        HF_API.upload_file(
            path_or_fileobj=str(annotation_file),
            path_in_repo=path_in_repo,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            commit_message=f"Update annotations for {annotator_name}",
        )
    except Exception as e:
        # Boundary with the network: report any failure to the UI instead of
        # crashing the handler; the local file has already been saved.
        return f"本地保存成功,但同步到 HF 失败: {e}"
    return f"{SYNC_OK_PREFIX}: {DATASET_REPO_ID}/{path_in_repo}"


def merge_annotations_from_disk(annotation_file: Path, session_annotations: dict | None) -> dict:
    """Return the on-disk annotations overlaid with the session's annotations.

    The JSONL file is the baseline and session entries win for the same
    ``sample_id``. This guards against the Gradio State dict being lost or
    emptied, which would otherwise overwrite the whole file with a single row.
    """
    disk = load_existing_annotations(annotation_file)
    sess = dict(session_annotations or {})
    return {**disk, **sess}


def build_stats(records, annotations):
    """Return a Markdown progress report for *records* given *annotations*.

    Only annotations whose ``sample_id`` appears in *records* are counted, so
    the headline numbers always agree with the per-dataset table and the
    unlabeled count cannot become negative.
    """
    total = len(records)
    labeled = 0
    decision_counter = Counter()
    by_dataset = defaultdict(Counter)
    for record in records:
        per_dataset = by_dataset[record["dataset"]]
        per_dataset["total"] += 1
        annotation = annotations.get(record["sample_id"])
        if annotation is None:
            continue
        verdict = annotation["judgment"]
        labeled += 1
        decision_counter[verdict] += 1
        per_dataset["labeled"] += 1
        per_dataset[verdict] += 1
    lines = [
        "### 标注进度",
        "",
        f"- 总样本数: `{total}`",
        f"- 已标注: `{labeled}`",
        f"- 未标注: `{total - labeled}`",
        f"- `符合 prompt`: `{decision_counter.get('match', 0)}`",
        f"- `不符合 prompt`: `{decision_counter.get('mismatch', 0)}`",
        f"- `不确定`: `{decision_counter.get('unsure', 0)}`",
        "",
        "| 数据源 | 总数 | 已标注 | 符合 | 不符合 | 不确定 |",
        "| --- | ---: | ---: | ---: | ---: | ---: |",
    ]
    for dataset in sorted(by_dataset):
        counter = by_dataset[dataset]
        lines.append(
            f"| {dataset} | {counter['total']} | {counter['labeled']} | "
            f"{counter['match']} | {counter['mismatch']} | {counter['unsure']} |"
        )
    return "\n".join(lines)


def resolve_filters(records, dataset_filter, split_filter, status_filter, annotations):
    """Return the indices into *records* that pass the three filters.

    ``"all"`` disables the dataset or split filter. ``status_filter`` keeps
    only annotated records for ``"labeled"``, only unannotated ones for
    ``"unlabeled"``, and everything for any other value.
    """
    indices = []
    for idx, record in enumerate(records):
        if dataset_filter != "all" and record["dataset"] != dataset_filter:
            continue
        if split_filter != "all" and record["split"] != split_filter:
            continue
        labeled = record["sample_id"] in annotations
        if status_filter == "labeled" and not labeled:
            continue
        if status_filter == "unlabeled" and labeled:
            continue
        indices.append(idx)
    return indices


def render_record(record, annotation, position_text):
    """Return ``(image, prompt, judgment, note, position)`` for the UI widgets.

    The prompt is the first non-empty of ``prompt_display``, ``prompt_zh`` and
    ``prompt``. Without an annotation the judgment defaults to ``"unsure"``
    and the note to ``""``. Any image warning is prepended to the position
    text.
    """
    judgment = annotation["judgment"] if annotation else "unsure"
    note = annotation["note"] if annotation else ""
    img, warn = load_image_for_ui(record)
    pos = f"{warn}\n{position_text}" if warn else position_text
    return (
        img,
        record.get("prompt_display") or record.get("prompt_zh") or record.get("prompt", ""),
        judgment,
        note,
        pos,
    )


def coerce_record_index(cursor) -> int | None:
    """Convert a cursor value coming back from the front end to ``int``.

    The Gradio front end may hand State back as a float or a string such as
    ``"5.0"``, and ``int("5.0")`` raises, so the value goes through ``float``
    first. Returns ``None`` for ``None``, booleans and unparseable values
    (including infinities and NaN).
    """
    if cursor is None or isinstance(cursor, bool):
        return None
    try:
        return int(float(cursor))
    except (TypeError, ValueError, OverflowError):
        return None


def normalize_cursor(cursor, pool: list[int]) -> int:
    """Return *cursor* as a global record index that is a member of *pool*.

    Falls back to ``pool[0]`` when the cursor cannot be parsed or is not in
    the pool, and to ``0`` when the pool is empty.
    """
    if not pool:
        return 0
    c = coerce_record_index(cursor)
    if c is None or c not in pool:
        return pool[0]
    return c


def build_app():
    """Build and return the Gradio ``Blocks`` app with all event handlers wired."""
    records = load_records()
    dataset_choices = ["all"] + sorted({record["dataset"] for record in records})
    split_choices = ["all", "retained", "rejected"]
    status_choices = ["all", "unlabeled", "labeled"]

    with gr.Blocks(title="过滤数据人工审核", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 过滤数据人工审核")
        gr.Markdown(
            "请先填写标注人名称,然后判断图片是否符合给定 prompt,全部标注完成后请同步到HF上面"
        )
        annotator = gr.Textbox(label="标注人名称", value="", placeholder="请先输入姓名/昵称")
        # Global index of the current sample in `records`, not its position
        # within the currently filtered list.
        cursor_record_idx = gr.State(0)
        annotations_state = gr.State({})
        unsynced_count_state = gr.State(0)

        with gr.Row():
            dataset_filter = gr.Dropdown(dataset_choices, value="all", label="数据源")
            split_filter = gr.Dropdown(split_choices, value="all", label="集合")
            status_filter = gr.Dropdown(status_choices, value="all", label="标注状态")
            refresh_btn = gr.Button("应用筛选", variant="primary")

        progress_md = gr.Markdown()
        sync_md = gr.Markdown(
            "### 同步状态\n\n- 未同步修改: `0`\n- 建议标注一批后再点击一次“同步到HF”。"
        )

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(label="图片", type="pil", height=480)
                position_box = gr.Textbox(label="当前位置", interactive=False)
            with gr.Column(scale=1):
                prompt_box = gr.Textbox(label="Prompt(中英双语)", lines=7, interactive=False)

        with gr.Row():
            judgment = gr.Radio(
                choices=["match", "mismatch", "unsure"], value="unsure", label="人工判断"
            )
            note = gr.Textbox(label="备注", lines=3, placeholder="可选备注")

        with gr.Row():
            prev_btn = gr.Button("上一条")
            save_btn = gr.Button("保存标注", variant="primary")
            sync_btn = gr.Button("同步到HF", variant="secondary")
            next_btn = gr.Button("下一条")

        save_status = gr.Textbox(label="保存状态", interactive=False)

        def sync_status_text(n: int) -> str:
            """Return the Markdown shown in the sync-status panel for *n* pending edits."""
            return (
                f"### 同步状态\n\n- 未同步修改: `{n}`\n"
                f"- 系统会在累计达到 `{AUTO_SYNC_EVERY}` 条后自动同步一次。"
            )

        def validate_annotator_name(name: str) -> str:
            """Return the cleaned annotator name, or ``""`` when none was given."""
            return sanitize_annotator_name(name)

        def bootstrap(annotator_name):
            """Load an annotator's saved annotations and show the first record.

            Returns values for the 11 full-refresh outputs; the unsynced
            counter is always reset to 0.
            """
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return (
                    {},
                    0,
                    0,
                    None,
                    "请先输入标注人名称",
                    "unsure",
                    "",
                    "0 / 0",
                    "### 标注进度\n\n- 请先输入标注人名称",
                    sync_status_text(0),
                    "请先输入标注人名称后再开始标注",
                )
            annotation_file = ANNOTATION_DIR / f"{annotator_name}.jsonl"
            annotations = load_existing_annotations(annotation_file)
            pool = resolve_filters(records, "all", "all", "all", annotations)
            if not pool:
                return (
                    annotations,
                    0,
                    0,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, annotations),
                    sync_status_text(0),
                    "",
                )
            cursor = pool[0]
            record = records[cursor]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), f"1 / {len(pool)}"
            )
            return (
                annotations,
                int(cursor),
                0,
                *rendered,
                build_stats(records, annotations),
                sync_status_text(0),
                "",
            )

        def refresh_pool(
            dataset_value, split_value, status_value, annotations, unsynced_count, annotator_name
        ):
            """Apply the current filters and jump to the first matching record.

            The unsynced counter is passed through unchanged: filtering does
            not upload anything, so pending edits must stay pending or the
            manual sync button would refuse to upload them.
            """
            pending = int(unsynced_count or 0)
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return (
                    {},
                    0,
                    pending,
                    None,
                    "请先输入标注人名称",
                    "unsure",
                    "",
                    "0 / 0",
                    "### 标注进度\n\n- 请先输入标注人名称",
                    sync_status_text(pending),
                    "请先输入标注人名称后再开始标注",
                )
            ann_file = ANNOTATION_DIR / f"{annotator_name}.jsonl"
            annotations = merge_annotations_from_disk(ann_file, annotations)
            pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            if not pool:
                return (
                    annotations,
                    0,
                    pending,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, annotations),
                    sync_status_text(pending),
                    "没有可显示的样本",
                )
            cursor = pool[0]
            record = records[cursor]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), f"1 / {len(pool)}"
            )
            return (
                annotations,
                int(cursor),
                pending,
                *rendered,
                build_stats(records, annotations),
                sync_status_text(pending),
                "筛选条件已更新",
            )

        def move(
            delta, cursor, annotations, dataset_value, split_value, status_value, annotator_name
        ):
            """Move *delta* positions within the filtered pool, clamped to its ends.

            Returns values for the 7 navigation outputs.
            """
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return (
                    0,
                    None,
                    "请先输入标注人名称",
                    "unsure",
                    "",
                    "0 / 0",
                    "请先输入标注人名称后再开始标注",
                )
            ann_file = ANNOTATION_DIR / f"{annotator_name}.jsonl"
            annotations = merge_annotations_from_disk(ann_file, annotations)
            pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            if not pool:
                return 0, None, "", "unsure", "", "0 / 0", "没有可显示的样本"
            c = normalize_cursor(cursor, pool)
            pos = pool.index(c)
            new_pos = max(0, min(pos + delta, len(pool) - 1))
            new_c = pool[new_pos]
            record = records[new_c]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), f"{new_pos + 1} / {len(pool)}"
            )
            # Explicit int so a float left in State cannot break later
            # `in pool` membership checks.
            return int(new_c), *rendered, ""

        def on_prev(c, ann, d, s, st, who):
            """Handler for the "previous" button."""
            return move(-1, c, ann, d, s, st, who)

        def on_next(c, ann, d, s, st, who):
            """Handler for the "next" button."""
            return move(1, c, ann, d, s, st, who)

        def save_annotation(
            cursor,
            decision,
            note_value,
            annotations,
            unsynced_count,
            annotator_name,
            dataset_value,
            split_value,
            status_value,
        ):
            """Persist the judgment for the current record and refresh the view.

            The row is merged with the on-disk file and written back. After
            ``AUTO_SYNC_EVERY`` unsynced saves, or when the filtered pool
            becomes empty, the file is uploaded to HF. Returns values for the
            11 full-refresh outputs.
            """
            pending = int(unsynced_count or 0)
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return (
                    annotations or {},
                    0,
                    pending,
                    None,
                    "请先输入标注人名称",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, annotations or {}),
                    sync_status_text(pending),
                    "请先输入标注人名称后再保存",
                )
            annotation_file = ANNOTATION_DIR / f"{annotator_name}.jsonl"
            merged = merge_annotations_from_disk(annotation_file, annotations)
            old_pool = resolve_filters(records, dataset_value, split_value, status_value, merged)
            if not old_pool:
                return (
                    merged,
                    0,
                    pending,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, merged),
                    sync_status_text(pending),
                    "没有可保存的样本",
                )

            record_idx = normalize_cursor(cursor, old_pool)
            record = records[record_idx]
            pos = old_pool.index(record_idx)
            row = {
                "sample_id": record["sample_id"],
                "annotator": annotator_name,
                "judgment": decision,
                "note": note_value,
                "updated_at": datetime.now(timezone.utc).isoformat(),
                "dataset": record["dataset"],
                "split": record["split"],
            }
            to_write = {**merged, record["sample_id"]: row}
            try:
                write_annotations(annotation_file, to_write)
            except OSError as e:
                # Nothing was saved: keep showing the same record with the
                # previously stored annotation and leave the counter alone.
                rendered = render_record(
                    record,
                    merged.get(record["sample_id"]),
                    f"{pos + 1} / {len(old_pool)}",
                )
                return (
                    merged,
                    int(record_idx),
                    pending,
                    *rendered,
                    build_stats(records, merged),
                    sync_status_text(pending),
                    f"写入失败(请检查磁盘是否可写): {e}",
                )

            annotations = to_write
            next_unsynced = pending + 1
            sync_msg = ""
            if next_unsynced >= AUTO_SYNC_EVERY:
                sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
                if sync_msg.startswith(SYNC_OK_PREFIX):
                    next_unsynced = 0

            new_pool = resolve_filters(
                records, dataset_value, split_value, status_value, annotations
            )
            stats = build_stats(records, annotations)

            if not new_pool:
                # The filtered list is exhausted; push whatever is still
                # pending so finished work is not left only on local disk.
                if next_unsynced > 0:
                    final_sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
                    if final_sync_msg.startswith(SYNC_OK_PREFIX):
                        next_unsynced = 0
                    sync_msg = (sync_msg + ";" + final_sync_msg).strip(";")
                return (
                    annotations,
                    0,
                    next_unsynced,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    stats,
                    sync_status_text(next_unsynced),
                    f"已保存 {record['sample_id']};{sync_msg or '仅本地缓存,尚未同步到HF'}",
                )

            # Stay on the saved record if it is still in the pool; otherwise
            # fall back to its predecessor in the old pool, then to the first
            # record of the new pool.
            if record_idx in new_pool:
                new_cursor = record_idx
            elif pos > 0:
                prev_idx = old_pool[pos - 1]
                new_cursor = prev_idx if prev_idx in new_pool else new_pool[0]
            else:
                new_cursor = new_pool[0]

            rec = records[new_cursor]
            rendered = render_record(
                rec,
                annotations.get(rec["sample_id"]),
                f"{new_pool.index(new_cursor) + 1} / {len(new_pool)}",
            )
            return (
                annotations,
                int(new_cursor),
                next_unsynced,
                *rendered,
                stats,
                sync_status_text(next_unsynced),
                f"已保存 {record['sample_id']};{sync_msg or '仅本地缓存,尚未同步到HF'}",
            )

        def sync_current_annotations(annotations, unsynced_count, annotator_name):
            """Upload the annotator's file to HF when there are pending edits.

            Returns values for (annotations, unsynced counter, progress
            Markdown, sync Markdown, status message).
            """
            pending = int(unsynced_count or 0)
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return (
                    annotations or {},
                    pending,
                    build_stats(records, annotations or {}),
                    sync_status_text(pending),
                    "请先输入标注人名称后再同步",
                )
            annotation_file = ANNOTATION_DIR / f"{annotator_name}.jsonl"
            merged = merge_annotations_from_disk(annotation_file, annotations)
            if not merged:
                return (
                    merged,
                    0,
                    build_stats(records, merged),
                    sync_status_text(0),
                    "当前没有可同步的标注",
                )
            if not pending:
                return (
                    merged,
                    0,
                    build_stats(records, merged),
                    sync_status_text(0),
                    "当前没有未同步修改",
                )
            sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
            next_unsynced = 0 if sync_msg.startswith(SYNC_OK_PREFIX) else pending
            return (
                merged,
                next_unsynced,
                build_stats(records, merged),
                sync_status_text(next_unsynced),
                sync_msg,
            )

        # Output groups shared by several events; the order must match the
        # tuples returned by the handlers above.
        full_outputs = [
            annotations_state,
            cursor_record_idx,
            unsynced_count_state,
            image,
            prompt_box,
            judgment,
            note,
            position_box,
            progress_md,
            sync_md,
            save_status,
        ]
        nav_inputs = [
            cursor_record_idx,
            annotations_state,
            dataset_filter,
            split_filter,
            status_filter,
            annotator,
        ]
        nav_outputs = [
            cursor_record_idx,
            image,
            prompt_box,
            judgment,
            note,
            position_box,
            save_status,
        ]

        annotator.change(bootstrap, inputs=[annotator], outputs=full_outputs)
        refresh_btn.click(
            refresh_pool,
            inputs=[
                dataset_filter,
                split_filter,
                status_filter,
                annotations_state,
                unsynced_count_state,
                annotator,
            ],
            outputs=full_outputs,
        )
        prev_btn.click(on_prev, inputs=nav_inputs, outputs=nav_outputs)
        next_btn.click(on_next, inputs=nav_inputs, outputs=nav_outputs)
        save_btn.click(
            save_annotation,
            inputs=[
                cursor_record_idx,
                judgment,
                note,
                annotations_state,
                unsynced_count_state,
                annotator,
                dataset_filter,
                split_filter,
                status_filter,
            ],
            outputs=full_outputs,
        )
        sync_btn.click(
            sync_current_annotations,
            inputs=[annotations_state, unsynced_count_state, annotator],
            outputs=[annotations_state, unsynced_count_state, progress_md, sync_md, save_status],
        )
        # NOTE(review): the initial view is bootstrapped for the hard-coded
        # name "annotator_1" while the name box starts empty — confirm this
        # preview is intended rather than a leftover default.
        demo.load(lambda: bootstrap("annotator_1"), outputs=full_outputs)

    return demo


demo = build_app()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)