Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| from collections import Counter, defaultdict | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| import gradio as gr | |
| from huggingface_hub import HfApi | |
| from PIL import Image | |
# Directory containing this script; every other local path is derived from it.
ROOT = Path(__file__).resolve().parents[0]
# Bundle holding the sample list (and the images it refers to).
AUDIT_PACK = ROOT.joinpath("audit_pack")
DATA_FILE = AUDIT_PACK.joinpath("audit_samples.json")
# One `<annotator>.jsonl` file per annotator is written here.
ANNOTATION_DIR = ROOT.joinpath("annotations")

# Target dataset repo for uploaded annotations; overridable via environment.
DATASET_REPO_ID = os.getenv("ANNOTATION_DATASET_REPO", "nickname-xingxing/filter_data")
HF_TOKEN = os.getenv("HF_TOKEN")
# Without a token no client is created and uploads are skipped.
HF_API = None if not HF_TOKEN else HfApi(token=HF_TOKEN)

# Number of unsynced saves that triggers an automatic upload.
AUTO_SYNC_EVERY = 10
def resolve_image_fs_path(record: dict) -> str:
    """Resolve a record's ``image_path`` to a filesystem path string.

    ``image_path`` in the JSON is relative to ``audit_pack`` (for example
    ``images/foo.png``) and has to be joined onto ``AUDIT_PACK``.

    Resolution order for a relative path:
      1. ``AUDIT_PACK / image_path`` if it exists;
      2. ``ROOT / image_path`` if it exists (legacy layout);
      3. otherwise the ``AUDIT_PACK`` candidate, even though it is missing,
         so the caller can report which path was expected.

    Args:
        record: Sample record; only its ``image_path`` key is read.

    Returns:
        The resolved path as a string, or ``""`` when ``image_path`` is
        absent, ``None`` or blank. Absolute paths are returned unchanged
        without checking that they exist.
    """
    raw = (record.get("image_path") or "").strip()
    if not raw:
        return ""

    p = Path(raw)
    if p.is_absolute():
        # Existence is deliberately not checked here: the result is the same
        # either way and the caller reports missing files.
        return str(p)

    under_pack = AUDIT_PACK / raw
    if under_pack.exists():
        return str(under_pack)

    legacy = ROOT / raw
    if legacy.exists():
        return str(legacy)

    return str(under_pack)
def load_image_for_ui(record: dict) -> tuple[Image.Image | None, str]:
    """Load the image for a record and return ``(image, warning)``.

    Args:
        record: Sample record whose ``image_path`` is resolved through
            ``resolve_image_fs_path``.

    Returns:
        ``(PIL image converted to RGB, "")`` on success. When the record has
        no path, the file is missing, or it cannot be read, the image is
        ``None`` and the warning text is non-empty.
    """
    path_str = resolve_image_fs_path(record)
    if not path_str:
        return None, "⚠️ 记录中无 image_path"

    p = Path(path_str)
    if not p.exists():
        return None, f"⚠️ 图片不存在: `{p}`(请确认 `audit_pack/images/` 已随仓库打包)"

    try:
        # Close the source image deterministically; convert() hands back a
        # separate image object, so the result stays usable afterwards.
        with Image.open(p) as source:
            return source.convert("RGB"), ""
    except OSError as e:
        return None, f"⚠️ 无法读取图片 `{p}`: {e}"
def load_records():
    """Parse and return the contents of ``DATA_FILE`` (UTF-8 encoded JSON)."""
    return json.loads(DATA_FILE.read_text(encoding="utf-8"))
def load_existing_annotations(annotation_file: Path):
    """Read a JSONL annotation file into a dict keyed by ``sample_id``.

    Blank lines are skipped. When a ``sample_id`` occurs more than once the
    last occurrence wins. A missing file yields an empty dict.
    """
    if not annotation_file.exists():
        return {}

    with annotation_file.open("r", encoding="utf-8") as handle:
        # Both generators are lazy, so each line is parsed and keyed before
        # the next one is read.
        stripped = (raw.strip() for raw in handle)
        parsed = (json.loads(text) for text in stripped if text)
        return {row["sample_id"]: row for row in parsed}
def write_annotations(annotation_file: Path, annotations):
    """Write ``annotations`` to ``annotation_file`` as JSONL, sorted by key.

    The rows are first written to a sibling ``<name>.tmp`` file, flushed and
    fsynced, and that file is then moved over the target path. Parent
    directories are created when missing.
    """
    annotation_file.parent.mkdir(parents=True, exist_ok=True)
    staging = annotation_file.with_suffix(annotation_file.suffix + ".tmp")

    with staging.open("w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(annotations[key], ensure_ascii=False) + "\n"
            for key in sorted(annotations)
        )
        out.flush()
        os.fsync(out.fileno())

    staging.replace(annotation_file)
def sync_annotation_to_hf(annotation_file: Path, annotator_name: str) -> str:
    """Upload an annotation file to the HF dataset repo and report the outcome.

    ``HF_TOKEN`` has to be configured in the Space secrets for the upload to
    happen. The file is stored in the repo as
    ``annotations/<annotator_name>.jsonl``.

    Returns:
        A status message. Upload errors are caught and described in the
        message instead of being raised.
    """
    if not annotation_file.exists():
        return "本地标注文件不存在,未同步到 HF"
    if HF_API is None:
        return "未配置 HF_TOKEN,仅保存在 Space 运行容器中"

    remote_path = f"annotations/{annotator_name}.jsonl"
    try:
        HF_API.upload_file(
            path_or_fileobj=str(annotation_file),
            path_in_repo=remote_path,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            commit_message=f"Update annotations for {annotator_name}",
        )
    except Exception as exc:
        return f"本地保存成功,但同步到 HF 失败: {exc}"
    else:
        return f"已同步到 HF dataset: {DATASET_REPO_ID}/{remote_path}"
def merge_annotations_from_disk(annotation_file: Path, session_annotations: dict | None) -> dict:
    """Overlay the session's annotations on top of what is stored on disk.

    The JSONL file on disk is the baseline; entries from the in-session dict
    replace those with the same ``sample_id``. This guards against the Gradio
    State losing its dict (or arriving empty), which would otherwise cause
    the whole file to be overwritten with a single entry.
    """
    merged = load_existing_annotations(annotation_file)
    merged.update(dict(session_annotations or {}))
    return merged
def build_stats(records, annotations):
    """Render annotation progress as Markdown.

    The output has an overall summary (total, labeled, unlabeled and one
    count per judgment) followed by a table with one row per dataset,
    ordered by dataset name.
    """
    overall = Counter(row["judgment"] for row in annotations.values())

    per_dataset = defaultdict(Counter)
    for record in records:
        bucket = per_dataset[record["dataset"]]
        bucket["total"] += 1
        sample_id = record["sample_id"]
        if sample_id not in annotations:
            continue
        bucket["labeled"] += 1
        bucket[annotations[sample_id]["judgment"]] += 1

    n_total = len(records)
    n_labeled = len(annotations)
    summary = [
        "### 标注进度",
        "",
        f"- 总样本数: `{n_total}`",
        f"- 已标注: `{n_labeled}`",
        f"- 未标注: `{n_total - n_labeled}`",
        f"- `符合 prompt`: `{overall['match']}`",
        f"- `不符合 prompt`: `{overall['mismatch']}`",
        f"- `不确定`: `{overall['unsure']}`",
        "",
        "| 数据源 | 总数 | 已标注 | 符合 | 不符合 | 不确定 |",
        "| --- | ---: | ---: | ---: | ---: | ---: |",
    ]

    table_rows = []
    for name in sorted(per_dataset):
        bucket = per_dataset[name]
        cells = [name] + [
            bucket[column]
            for column in ("total", "labeled", "match", "mismatch", "unsure")
        ]
        table_rows.append("| " + " | ".join(str(cell) for cell in cells) + " |")

    return "\n".join(summary + table_rows)
def resolve_filters(records, dataset_filter, split_filter, status_filter, annotations):
    """Return the indices into ``records`` that pass all three filters.

    ``"all"`` disables the dataset or split filter. ``status_filter`` keeps
    only annotated records for ``"labeled"``, only unannotated ones for
    ``"unlabeled"``, and everything for any other value.
    """

    def keep(record) -> bool:
        if dataset_filter != "all" and record["dataset"] != dataset_filter:
            return False
        if split_filter != "all" and record["split"] != split_filter:
            return False
        is_labeled = record["sample_id"] in annotations
        if status_filter == "labeled":
            return is_labeled
        if status_filter == "unlabeled":
            return not is_labeled
        return True

    return [position for position, record in enumerate(records) if keep(record)]
def render_record(record, annotation, position_text):
    """Build the widget values for one record.

    Returns:
        ``(image, prompt, judgment, note, position)``. The judgment and note
        come from ``annotation`` when one is given and default to
        ``"unsure"`` and ``""`` otherwise. The prompt is the first non-empty
        value among ``prompt_display``, ``prompt_zh`` and ``prompt``. Any
        image-loading warning is placed on its own line ahead of
        ``position_text``.
    """
    if annotation:
        judgment, note = annotation["judgment"], annotation["note"]
    else:
        judgment, note = "unsure", ""

    img, warn = load_image_for_ui(record)
    prompt_text = (
        record.get("prompt_display")
        or record.get("prompt_zh")
        or record.get("prompt", "")
    )
    position = position_text if not warn else f"{warn}\n{position_text}"
    return img, prompt_text, judgment, note, position
| def coerce_record_index(cursor) -> int | None: | |
| """ | |
| Gradio 前端可能把 State 传成 float / 字符串(如 "5.0"),Python 里 int("5.0") 会报错。 | |
| 统一先 float 再 int;解析失败返回 None。 | |
| """ | |
| if cursor is None: | |
| return None | |
| if isinstance(cursor, bool): | |
| return None | |
| try: | |
| return int(float(cursor)) | |
| except (TypeError, ValueError): | |
| return None | |
def normalize_cursor(cursor, pool: list[int]) -> int:
    """Return ``cursor`` as a global record index that belongs to ``pool``.

    ``cursor`` is an index into the full record list. If it cannot be parsed
    or is not part of the current pool, the first index of the pool is
    returned. An empty pool yields ``0``.
    """
    if not pool:
        return 0
    index = coerce_record_index(cursor)
    if index is not None and index in pool:
        return index
    return pool[0]
def build_app():
    """Assemble the Gradio annotation UI and wire up its event handlers.

    The sample records are loaded once; the handlers are closures over that
    list. Per-session data (annotations, current record index, number of
    unsynced saves) lives in ``gr.State`` components.

    NOTE(review): the widget nesting (which components sit inside which
    Row/Column) was reconstructed because the original indentation was not
    available — confirm the layout against the deployed app.

    Returns:
        The ``gr.Blocks`` application (not launched).
    """
    records = load_records()
    dataset_choices = ["all"] + sorted({record["dataset"] for record in records})
    split_choices = ["all", "retained", "rejected"]
    status_choices = ["all", "unlabeled", "labeled"]

    with gr.Blocks(title="过滤数据人工审核", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 过滤数据人工审核")
        gr.Markdown(
            "请先填写标注人名称,然后判断图片是否符合给定 prompt,全部标注完成后请同步到HF上面"
        )
        annotator = gr.Textbox(label="标注人名称", value="", placeholder="请先输入姓名/昵称")

        # Index of the current sample within `records` (global), not its
        # position inside the currently filtered list.
        cursor_record_idx = gr.State(0)
        annotations_state = gr.State({})
        unsynced_count_state = gr.State(0)

        with gr.Row():
            dataset_filter = gr.Dropdown(dataset_choices, value="all", label="数据源")
            split_filter = gr.Dropdown(split_choices, value="all", label="集合")
            status_filter = gr.Dropdown(status_choices, value="all", label="标注状态")
            refresh_btn = gr.Button("应用筛选", variant="primary")

        progress_md = gr.Markdown()
        sync_md = gr.Markdown("### 同步状态\n\n- 未同步修改: `0`\n- 建议标注一批后再点击一次“同步到HF”。")

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(label="图片", type="pil", height=480)
                position_box = gr.Textbox(label="当前位置", interactive=False)
            with gr.Column(scale=1):
                prompt_box = gr.Textbox(label="Prompt(中英双语)", lines=7, interactive=False)
                with gr.Row():
                    judgment = gr.Radio(choices=["match", "mismatch", "unsure"], value="unsure", label="人工判断")
                note = gr.Textbox(label="备注", lines=3, placeholder="可选备注")
                with gr.Row():
                    prev_btn = gr.Button("上一条")
                    save_btn = gr.Button("保存标注", variant="primary")
                    sync_btn = gr.Button("同步到HF", variant="secondary")
                    next_btn = gr.Button("下一条")
                save_status = gr.Textbox(label="保存状态", interactive=False)

        # ------------------------------------------------------------------
        # Small helpers shared by the handlers
        # ------------------------------------------------------------------

        def sync_status_text(n: int) -> str:
            """Markdown for the sync panel showing ``n`` unsynced edits."""
            return f"### 同步状态\n\n- 未同步修改: `{n}`\n- 系统会在累计达到 `{AUTO_SYNC_EVERY}` 条后自动同步一次。"

        def validate_annotator_name(name: str) -> str:
            """Return the stripped annotator name, or ``""`` if it is unusable.

            The name is interpolated into a local file path and a repo path,
            so a name containing a path separator (or NUL) is rejected: it
            could otherwise point outside the annotations directory. Callers
            treat ``""`` as "no name entered".
            """
            who = (name or "").strip()
            if any(ch in who for ch in ("/", "\\", "\0")):
                return ""
            return who

        def annotation_path(annotator_name: str) -> Path:
            """Path of the JSONL file that stores one annotator's labels."""
            return ANNOTATION_DIR / f"{annotator_name}.jsonl"

        def position_text(cursor, pool) -> str:
            """Format the "k / N" position of ``cursor`` within ``pool``."""
            if cursor in pool:
                return f"{pool.index(cursor) + 1} / {len(pool)}"
            # The record on screen is not part of the active filter.
            return f"- / {len(pool)}"

        def blank_view(annotations, unsynced, prompt_text, stats_md, status):
            """The 11 handler outputs for a state with no sample on screen."""
            return (
                annotations,
                0,
                unsynced,
                None,
                prompt_text,
                "unsure",
                "",
                "0 / 0",
                stats_md,
                sync_status_text(unsynced),
                status,
            )

        def record_view(annotations, cursor, pool, unsynced, status):
            """The 11 handler outputs that show ``records[cursor]``."""
            record = records[cursor]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), position_text(cursor, pool)
            )
            return (
                annotations,
                int(cursor),
                unsynced,
                *rendered,
                build_stats(records, annotations),
                sync_status_text(unsynced),
                status,
            )

        # ------------------------------------------------------------------
        # Event handlers
        # ------------------------------------------------------------------

        def bootstrap(annotator_name):
            """Load an annotator's saved labels and show the first record."""
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return blank_view(
                    {},
                    0,
                    "请先输入标注人名称",
                    "### 标注进度\n\n- 请先输入标注人名称",
                    "请先输入标注人名称后再开始标注",
                )
            annotations = load_existing_annotations(annotation_path(annotator_name))
            pool = resolve_filters(records, "all", "all", "all", annotations)
            if not pool:
                return blank_view(annotations, 0, "", build_stats(records, annotations), "")
            return record_view(annotations, pool[0], pool, 0, "")

        def refresh_pool(
            dataset_value,
            split_value,
            status_value,
            annotations,
            annotator_name,
            unsynced_count=0,
        ):
            """Apply the selected filters and jump to the first match.

            The unsynced counter is carried through unchanged: changing the
            filter does not upload anything, so resetting it would make the
            manual sync button report that there is nothing to sync.
            """
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return blank_view(
                    {},
                    0,
                    "请先输入标注人名称",
                    "### 标注进度\n\n- 请先输入标注人名称",
                    "请先输入标注人名称后再开始标注",
                )
            pending = int(unsynced_count or 0)
            annotations = merge_annotations_from_disk(
                annotation_path(annotator_name), annotations
            )
            pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            if not pool:
                return blank_view(
                    annotations, pending, "", build_stats(records, annotations), "没有可显示的样本"
                )
            return record_view(annotations, pool[0], pool, pending, "筛选条件已更新")

        def move(delta, cursor, annotations, dataset_value, split_value, status_value, annotator_name):
            """Step ``delta`` positions within the filtered pool (clamped)."""
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return 0, None, "请先输入标注人名称", "unsure", "", "0 / 0", "请先输入标注人名称后再开始标注"
            annotations = merge_annotations_from_disk(
                annotation_path(annotator_name), annotations
            )
            pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            if not pool:
                return 0, None, "", "unsure", "", "0 / 0", "没有可显示的样本"
            c = normalize_cursor(cursor, pool)
            pos = pool.index(c)
            new_pos = max(0, min(pos + delta, len(pool) - 1))
            new_c = pool[new_pos]
            record = records[new_c]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), f"{new_pos + 1} / {len(pool)}"
            )
            # Explicit int so no float lingers in State and breaks later
            # membership checks against the pool.
            return int(new_c), *rendered, ""

        def on_prev(c, ann, d, s, st, who):
            return move(-1, c, ann, d, s, st, who)

        def on_next(c, ann, d, s, st, who):
            return move(1, c, ann, d, s, st, who)

        def save_annotation(
            cursor,
            decision,
            note_value,
            annotations,
            unsynced_count,
            annotator_name,
            dataset_value,
            split_value,
            status_value,
        ):
            """Persist the judgment for the record on screen, then re-render.

            After a successful write the unsynced counter is incremented and
            an upload is attempted once it reaches ``AUTO_SYNC_EVERY`` (or
            when the filtered pool becomes empty).
            """
            annotator_name = validate_annotator_name(annotator_name)
            pending = unsynced_count or 0
            if not annotator_name:
                held = annotations or {}
                return blank_view(
                    held,
                    pending,
                    "请先输入标注人名称",
                    build_stats(records, held),
                    "请先输入标注人名称后再保存",
                )

            annotation_file = annotation_path(annotator_name)
            merged = merge_annotations_from_disk(annotation_file, annotations)
            old_pool = resolve_filters(records, dataset_value, split_value, status_value, merged)
            if not old_pool:
                return blank_view(
                    merged, pending, "", build_stats(records, merged), "没有可保存的样本"
                )

            # Save against the record the cursor points at, even if the
            # dropdowns were changed without applying them and that record is
            # no longer in the pool; snapping to pool[0] here would attach the
            # judgment to a sample the annotator never looked at.
            record_idx = coerce_record_index(cursor)
            if record_idx is None or not 0 <= record_idx < len(records):
                record_idx = normalize_cursor(cursor, old_pool)
            record = records[record_idx]
            pos = old_pool.index(record_idx) if record_idx in old_pool else None

            row = {
                "sample_id": record["sample_id"],
                "annotator": annotator_name,
                "judgment": decision,
                "note": note_value,
                "updated_at": datetime.now(timezone.utc).isoformat(),
                "dataset": record["dataset"],
                "split": record["split"],
            }
            to_write = {**merged, record["sample_id"]: row}
            try:
                write_annotations(annotation_file, to_write)
            except OSError as e:
                # Nothing was persisted: keep showing the same record with
                # the previous annotations.
                return record_view(
                    merged,
                    record_idx,
                    old_pool,
                    pending,
                    f"写入失败(请检查磁盘是否可写): {e}",
                )

            annotations = to_write
            next_unsynced = int(unsynced_count or 0) + 1
            sync_msg = ""
            if next_unsynced >= AUTO_SYNC_EVERY:
                sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
                if sync_msg.startswith("已同步到 HF dataset"):
                    next_unsynced = 0

            new_pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            if not new_pool:
                # The filtered pool is exhausted: push whatever is pending.
                if next_unsynced > 0:
                    final_sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
                    if final_sync_msg.startswith("已同步到 HF dataset"):
                        next_unsynced = 0
                    sync_msg = (sync_msg + ";" + final_sync_msg).strip(";")
                return blank_view(
                    annotations,
                    next_unsynced,
                    "",
                    build_stats(records, annotations),
                    f"已保存 {record['sample_id']};{sync_msg or '仅本地缓存,尚未同步到HF'}",
                )

            # Stay on the saved record when it is still in the pool;
            # otherwise fall back to its previous neighbour, then to the
            # first entry of the pool.
            if record_idx in new_pool:
                new_cursor = record_idx
            elif pos is not None and pos > 0 and old_pool[pos - 1] in new_pool:
                new_cursor = old_pool[pos - 1]
            else:
                new_cursor = new_pool[0]

            return record_view(
                annotations,
                new_cursor,
                new_pool,
                next_unsynced,
                f"已保存 {record['sample_id']};{sync_msg or '仅本地缓存,尚未同步到HF'}",
            )

        def sync_current_annotations(annotations, unsynced_count, annotator_name):
            """Upload the annotator's file when there are unsynced edits."""
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                held = annotations or {}
                pending = unsynced_count or 0
                return (
                    held,
                    pending,
                    build_stats(records, held),
                    sync_status_text(pending),
                    "请先输入标注人名称后再同步",
                )
            annotation_file = annotation_path(annotator_name)
            merged = merge_annotations_from_disk(annotation_file, annotations)
            if not merged:
                return merged, 0, build_stats(records, merged), sync_status_text(0), "当前没有可同步的标注"
            if not int(unsynced_count or 0):
                return merged, 0, build_stats(records, merged), sync_status_text(0), "当前没有未同步修改"
            sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
            next_unsynced = 0 if sync_msg.startswith("已同步到 HF dataset") else int(unsynced_count or 0)
            return merged, next_unsynced, build_stats(records, merged), sync_status_text(next_unsynced), sync_msg

        # ------------------------------------------------------------------
        # Event wiring
        # ------------------------------------------------------------------

        # Outputs shared by every handler that refreshes the whole view.
        full_outputs = [
            annotations_state,
            cursor_record_idx,
            unsynced_count_state,
            image,
            prompt_box,
            judgment,
            note,
            position_box,
            progress_md,
            sync_md,
            save_status,
        ]
        nav_inputs = [
            cursor_record_idx,
            annotations_state,
            dataset_filter,
            split_filter,
            status_filter,
            annotator,
        ]
        nav_outputs = [
            cursor_record_idx,
            image,
            prompt_box,
            judgment,
            note,
            position_box,
            save_status,
        ]

        annotator.change(bootstrap, inputs=[annotator], outputs=full_outputs)
        refresh_btn.click(
            refresh_pool,
            inputs=[
                dataset_filter,
                split_filter,
                status_filter,
                annotations_state,
                annotator,
                unsynced_count_state,
            ],
            outputs=full_outputs,
        )
        prev_btn.click(on_prev, inputs=nav_inputs, outputs=nav_outputs)
        next_btn.click(on_next, inputs=nav_inputs, outputs=nav_outputs)
        save_btn.click(
            save_annotation,
            inputs=[
                cursor_record_idx,
                judgment,
                note,
                annotations_state,
                unsynced_count_state,
                annotator,
                dataset_filter,
                split_filter,
                status_filter,
            ],
            outputs=full_outputs,
        )
        sync_btn.click(
            sync_current_annotations,
            inputs=[annotations_state, unsynced_count_state, annotator],
            outputs=[annotations_state, unsynced_count_state, progress_md, sync_md, save_status],
        )
        # NOTE(review): the initial view is populated with the labels of
        # "annotator_1" although the name box starts empty — confirm this
        # preview is intended.
        demo.load(lambda: bootstrap("annotator_1"), outputs=full_outputs)

    return demo
# Build the UI at import time so that `demo` exists as a module-level name.
demo = build_app()

if __name__ == "__main__":
    # When run as a script, serve on all interfaces at port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)