# filter_audit / app.py
# nickname-xingxing's picture
# Update app.py
# 8c9473c verified
from __future__ import annotations
import json
import os
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path
import gradio as gr
from huggingface_hub import HfApi
from PIL import Image
# Directory containing this file; every other path below is derived from it.
ROOT = Path(__file__).resolve().parent
# Packaged audit data; relative record image paths are resolved under this directory.
AUDIT_PACK = ROOT / "audit_pack"
# JSON file holding the sample records to review.
DATA_FILE = AUDIT_PACK / "audit_samples.json"
# Per-annotator JSONL annotation files are written here.
ANNOTATION_DIR = ROOT / "annotations"
# Target Hugging Face dataset repo for uploads; overridable through the environment.
DATASET_REPO_ID = os.environ.get("ANNOTATION_DATASET_REPO", "nickname-xingxing/filter_data")
# Hub token read from the environment; without it no API client is created.
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API = HfApi(token=HF_TOKEN) if HF_TOKEN else None
# Number of unsynced saves after which an upload is attempted automatically.
AUTO_SYNC_EVERY = 10
def resolve_image_fs_path(record: dict) -> str:
    """Resolve a record's ``image_path`` to a filesystem path string.

    ``image_path`` in the JSON is relative to ``audit_pack`` (for example
    ``images/foo.png``), so it is joined onto ``AUDIT_PACK``. A path relative
    to ``ROOT`` is accepted as a fallback.

    Args:
        record: Sample record; only its ``image_path`` key is read.

    Returns:
        ``""`` when the record has no usable ``image_path``. An absolute path
        is returned unchanged. Otherwise the first existing candidate among
        ``AUDIT_PACK / image_path`` and ``ROOT / image_path``; when neither
        exists, the ``AUDIT_PACK`` candidate is returned.
    """
    raw = (record.get("image_path") or "").strip()
    if not raw:
        return ""
    p = Path(raw)
    if p.is_absolute():
        # Returned as-is whether or not it exists; existence is checked by the
        # code that actually opens the file.
        return str(p)
    under_pack = AUDIT_PACK / raw
    for candidate in (under_pack, ROOT / raw):
        if candidate.exists():
            return str(candidate)
    return str(under_pack)
def load_image_for_ui(record: dict) -> tuple[Image.Image | None, str]:
    """Load the image of ``record`` for display.

    Args:
        record: Sample record whose ``image_path`` is resolved with
            ``resolve_image_fs_path``.

    Returns:
        ``(image, warning)``. On success ``image`` is an RGB PIL image and
        ``warning`` is ``""``. When the path is missing from the record, the
        file does not exist, or it cannot be read, ``image`` is ``None`` and
        ``warning`` is a non-empty message.
    """
    path_str = resolve_image_fs_path(record)
    if not path_str:
        return None, "⚠️ 记录中无 image_path"
    p = Path(path_str)
    if not p.exists():
        return None, f"⚠️ 图片不存在: `{p}`(请确认 `audit_pack/images/` 已随仓库打包)"
    try:
        # The context manager releases the underlying file as soon as the
        # converted copy has been produced.
        with Image.open(p) as source:
            return source.convert("RGB"), ""
    except OSError as e:
        return None, f"⚠️ 无法读取图片 `{p}`: {e}"
def load_records():
    """Read ``DATA_FILE`` as UTF-8 text and return the parsed JSON content."""
    raw_text = DATA_FILE.read_text(encoding="utf-8")
    return json.loads(raw_text)
def load_existing_annotations(annotation_file: Path):
    """Load a JSONL annotation file into a dict keyed by ``sample_id``.

    Blank lines are skipped. When the same ``sample_id`` appears more than
    once, the last row wins. A missing file yields an empty dict.
    """
    if not annotation_file.exists():
        return {}
    by_sample_id = {}
    with annotation_file.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        for text in filter(None, stripped_lines):
            entry = json.loads(text)
            by_sample_id[entry["sample_id"]] = entry
    return by_sample_id
def write_annotations(annotation_file: Path, annotations):
    """Write ``annotations`` to ``annotation_file`` as JSONL, one row per line.

    Rows are emitted in sorted ``sample_id`` order. The content is first
    written to a sibling ``*.tmp`` file, flushed and fsynced, and then moved
    over the target so the target is replaced in a single step.
    """
    annotation_file.parent.mkdir(parents=True, exist_ok=True)
    scratch = annotation_file.with_suffix(annotation_file.suffix + ".tmp")
    with scratch.open("w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(annotations[key], ensure_ascii=False) + "\n"
            for key in sorted(annotations)
        )
        out.flush()
        os.fsync(out.fileno())
    scratch.replace(annotation_file)
def sync_annotation_to_hf(annotation_file: Path, annotator_name: str) -> str:
    """Upload the annotation file to the Hugging Face dataset repo.

    Requires ``HF_TOKEN`` to be configured (for a Space: in its Secrets).

    Returns:
        A human-readable status message. Upload failures are reported in the
        message rather than raised.
    """
    if not annotation_file.exists():
        return "本地标注文件不存在,未同步到 HF"
    if HF_API is None:
        return "未配置 HF_TOKEN,仅保存在 Space 运行容器中"

    path_in_repo = f"annotations/{annotator_name}.jsonl"
    try:
        HF_API.upload_file(
            path_or_fileobj=str(annotation_file),
            path_in_repo=path_in_repo,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            commit_message=f"Update annotations for {annotator_name}",
        )
    except Exception as e:
        # The local file is already saved at this point; only report the failure.
        return f"本地保存成功,但同步到 HF 失败: {e}"
    else:
        return f"已同步到 HF dataset: {DATASET_REPO_ID}/{path_in_repo}"
def merge_annotations_from_disk(annotation_file: Path, session_annotations: dict | None) -> dict:
    """Merge the on-disk annotations with the in-session ones.

    The JSONL file on disk is the baseline; entries from the current session
    override rows with the same ``sample_id``. This keeps a lost or emptied
    Gradio State from causing the whole file to be overwritten with only the
    latest row.
    """
    combined = load_existing_annotations(annotation_file)
    combined.update(session_annotations or {})
    return combined
def build_stats(records, annotations):
    """Render the annotation progress summary as a Markdown string.

    Args:
        records: Sample records; each needs ``sample_id`` and ``dataset``.
        annotations: Mapping of ``sample_id`` to an annotation row carrying a
            ``judgment`` key.

    Returns:
        Markdown with overall counters followed by a per-dataset table.
    """
    # Only annotations that belong to a current record are counted. Rows for
    # samples that are no longer in `records` would otherwise inflate the
    # totals, disagree with the per-dataset table, and could push the
    # unlabeled count below zero.
    known_ids = {record["sample_id"] for record in records}
    relevant = [row for sample_id, row in annotations.items() if sample_id in known_ids]

    total = len(records)
    labeled = len(relevant)
    decision_counter = Counter(row["judgment"] for row in relevant)

    by_dataset = defaultdict(Counter)
    for record in records:
        per_dataset = by_dataset[record["dataset"]]
        per_dataset["total"] += 1
        if record["sample_id"] in annotations:
            per_dataset["labeled"] += 1
            per_dataset[annotations[record["sample_id"]]["judgment"]] += 1

    lines = [
        "### 标注进度",
        "",
        f"- 总样本数: `{total}`",
        f"- 已标注: `{labeled}`",
        f"- 未标注: `{total - labeled}`",
        f"- `符合 prompt`: `{decision_counter.get('match', 0)}`",
        f"- `不符合 prompt`: `{decision_counter.get('mismatch', 0)}`",
        f"- `不确定`: `{decision_counter.get('unsure', 0)}`",
        "",
        "| 数据源 | 总数 | 已标注 | 符合 | 不符合 | 不确定 |",
        "| --- | ---: | ---: | ---: | ---: | ---: |",
    ]
    for dataset in sorted(by_dataset):
        counter = by_dataset[dataset]
        lines.append(
            f"| {dataset} | {counter['total']} | {counter['labeled']} | "
            f"{counter['match']} | {counter['mismatch']} | {counter['unsure']} |"
        )
    return "\n".join(lines)
def resolve_filters(records, dataset_filter, split_filter, status_filter, annotations):
    """Return the indices into ``records`` that pass the three filters.

    ``"all"`` disables the dataset or split filter. ``status_filter`` keeps
    only annotated records for ``"labeled"``, only unannotated ones for
    ``"unlabeled"``, and everything for any other value.
    """

    def is_selected(item) -> bool:
        if dataset_filter != "all" and item["dataset"] != dataset_filter:
            return False
        if split_filter != "all" and item["split"] != split_filter:
            return False
        has_label = item["sample_id"] in annotations
        if status_filter == "labeled":
            return has_label
        if status_filter == "unlabeled":
            return not has_label
        return True

    return [position for position, item in enumerate(records) if is_selected(item)]
def render_record(record, annotation, position_text):
    """Build the UI values for one record.

    Returns:
        ``(image, prompt, judgment, note, position)``. Without an annotation
        the judgment defaults to ``"unsure"`` and the note to ``""``. When the
        image cannot be loaded, the warning is prepended to the position text.
    """
    if annotation:
        judgment, note = annotation["judgment"], annotation["note"]
    else:
        judgment, note = "unsure", ""

    img, warn = load_image_for_ui(record)
    # Prefer the display prompt, then the Chinese prompt, then the raw prompt.
    prompt_text = (
        record.get("prompt_display")
        or record.get("prompt_zh")
        or record.get("prompt", "")
    )
    position = position_text if not warn else f"{warn}\n{position_text}"
    return img, prompt_text, judgment, note, position
def coerce_record_index(cursor) -> int | None:
    """Convert a cursor value coming back from the UI into an ``int``.

    The Gradio frontend may hand a State value back as a float or a string
    such as ``"5.0"``; ``int("5.0")`` fails in Python, so the value goes
    through ``float`` first and is then truncated to ``int``.

    Returns:
        The integer index, or ``None`` when ``cursor`` is ``None``, a bool,
        or cannot be converted (including NaN and infinities).
    """
    if cursor is None or isinstance(cursor, bool):
        return None
    try:
        return int(float(cursor))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")); ValueError: NaN or unparsable text.
        return None
def normalize_cursor(cursor, pool: list[int]) -> int:
    """Clamp ``cursor`` (a global index into records) to the current pool.

    Returns ``0`` for an empty pool, the cursor itself when it is a member of
    ``pool``, and ``pool[0]`` in every other case.
    """
    if not pool:
        return 0
    candidate = coerce_record_index(cursor)
    if candidate is not None and candidate in pool:
        return candidate
    return pool[0]
def build_app():
    """Build the Gradio Blocks UI for manually auditing filtered samples.

    The UI shows one sample (image + prompt) at a time, lets the annotator
    pick a judgment and an optional note, stores annotations in a
    per-annotator JSONL file under ``ANNOTATION_DIR``, and uploads that file
    to the Hugging Face dataset repo on request or automatically after
    ``AUTO_SYNC_EVERY`` unsynced saves.

    Returns:
        The configured ``gr.Blocks`` instance (not launched).
    """
    records = load_records()
    dataset_choices = ["all"] + sorted({record["dataset"] for record in records})
    split_choices = ["all", "retained", "rejected"]
    status_choices = ["all", "unlabeled", "labeled"]

    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a flattened listing; confirm against the intended UI.
    with gr.Blocks(title="过滤数据人工审核", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 过滤数据人工审核")
        gr.Markdown(
            "请先填写标注人名称,然后判断图片是否符合给定 prompt,全部标注完成后请同步到HF上面"
        )
        annotator = gr.Textbox(label="标注人名称", value="", placeholder="请先输入姓名/昵称")
        # Index of the current sample within `records` (global), not its
        # position inside the currently filtered list.
        cursor_record_idx = gr.State(0)
        annotations_state = gr.State({})
        unsynced_count_state = gr.State(0)

        with gr.Row():
            dataset_filter = gr.Dropdown(dataset_choices, value="all", label="数据源")
            split_filter = gr.Dropdown(split_choices, value="all", label="集合")
            status_filter = gr.Dropdown(status_choices, value="all", label="标注状态")
            refresh_btn = gr.Button("应用筛选", variant="primary")

        progress_md = gr.Markdown()
        sync_md = gr.Markdown("### 同步状态\n\n- 未同步修改: `0`\n- 建议标注一批后再点击一次“同步到HF”。")

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(label="图片", type="pil", height=480)
                position_box = gr.Textbox(label="当前位置", interactive=False)
            with gr.Column(scale=1):
                prompt_box = gr.Textbox(label="Prompt(中英双语)", lines=7, interactive=False)
                with gr.Row():
                    judgment = gr.Radio(choices=["match", "mismatch", "unsure"], value="unsure", label="人工判断")
                note = gr.Textbox(label="备注", lines=3, placeholder="可选备注")
                with gr.Row():
                    prev_btn = gr.Button("上一条")
                    save_btn = gr.Button("保存标注", variant="primary")
                    sync_btn = gr.Button("同步到HF", variant="secondary")
                    next_btn = gr.Button("下一条")
                save_status = gr.Textbox(label="保存状态", interactive=False)

        def sync_status_text(n: int) -> str:
            """Markdown for the sync panel showing ``n`` unsynced changes."""
            return f"### 同步状态\n\n- 未同步修改: `{n}`\n- 系统会在累计达到 `{AUTO_SYNC_EVERY}` 条后自动同步一次。"

        def validate_annotator_name(name: str) -> str:
            """Return the stripped annotator name, or ``""`` when unusable.

            The name becomes part of a local file path and of the path inside
            the dataset repo, so anything that could leave the annotation
            directory (path separators, NUL, ``.``/``..``) is rejected.
            """
            who = (name or "").strip()
            if who in {".", ".."} or any(ch in who for ch in "/\\\0"):
                return ""
            return who

        def annotation_path(annotator_name: str) -> Path:
            """Location of the JSONL file for an already validated name."""
            return ANNOTATION_DIR / f"{annotator_name}.jsonl"

        def position_text(record_idx, pool) -> str:
            """``"<1-based position> / <pool size>"``; position 0 if outside the pool."""
            place = pool.index(record_idx) + 1 if record_idx in pool else 0
            return f"{place} / {len(pool)}"

        def missing_name_view():
            """Full set of outputs shown while no valid annotator name is set."""
            return (
                {},
                0,
                0,
                None,
                "请先输入标注人名称",
                "unsure",
                "",
                "0 / 0",
                "### 标注进度\n\n- 请先输入标注人名称",
                sync_status_text(0),
                "请先输入标注人名称后再开始标注",
            )

        def saved_message(sample_id, sync_msg: str) -> str:
            """Status line after a successful local save."""
            return f"已保存 {sample_id};{sync_msg or '仅本地缓存,尚未同步到HF'}"

        def bootstrap(annotator_name):
            """Load an annotator's saved work and show the first record."""
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return missing_name_view()
            annotations = load_existing_annotations(annotation_path(annotator_name))
            pool = resolve_filters(records, "all", "all", "all", annotations)
            if not pool:
                return (
                    annotations,
                    0,
                    0,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, annotations),
                    sync_status_text(0),
                    "",
                )
            cursor = pool[0]
            record = records[cursor]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), f"1 / {len(pool)}"
            )
            return (
                annotations,
                int(cursor),
                0,
                *rendered,
                build_stats(records, annotations),
                sync_status_text(0),
                "",
            )

        def refresh_pool(
            dataset_value,
            split_value,
            status_value,
            annotations,
            annotator_name,
            unsynced_count=0,
        ):
            """Apply the filter dropdowns and jump to the first matching record.

            The unsynced counter is passed through unchanged: applying a
            filter does not upload anything, so pending changes must remain
            visible to the manual and automatic sync logic.
            """
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return missing_name_view()
            pending = int(unsynced_count or 0)
            annotations = merge_annotations_from_disk(
                annotation_path(annotator_name), annotations
            )
            pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            if not pool:
                return (
                    annotations,
                    0,
                    pending,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, annotations),
                    sync_status_text(pending),
                    "没有可显示的样本",
                )
            cursor = pool[0]
            record = records[cursor]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), f"1 / {len(pool)}"
            )
            return (
                annotations,
                int(cursor),
                pending,
                *rendered,
                build_stats(records, annotations),
                sync_status_text(pending),
                "筛选条件已更新",
            )

        def move(delta, cursor, annotations, dataset_value, split_value, status_value, annotator_name):
            """Step ``delta`` positions inside the filtered pool, clamped to its ends."""
            annotator_name = validate_annotator_name(annotator_name)
            if not annotator_name:
                return 0, None, "请先输入标注人名称", "unsure", "", "0 / 0", "请先输入标注人名称后再开始标注"
            annotations = merge_annotations_from_disk(
                annotation_path(annotator_name), annotations
            )
            pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            if not pool:
                return 0, None, "", "unsure", "", "0 / 0", "没有可显示的样本"
            c = normalize_cursor(cursor, pool)
            pos = pool.index(c)
            new_pos = max(0, min(pos + delta, len(pool) - 1))
            new_c = pool[new_pos]
            record = records[new_c]
            rendered = render_record(
                record, annotations.get(record["sample_id"]), f"{new_pos + 1} / {len(pool)}"
            )
            # Explicit int so a float left in the State cannot break later
            # `in pool` membership checks.
            return (int(new_c), *rendered, "")

        def on_prev(c, ann, d, s, st, who):
            """Handler for the "previous" button."""
            return move(-1, c, ann, d, s, st, who)

        def on_next(c, ann, d, s, st, who):
            """Handler for the "next" button."""
            return move(1, c, ann, d, s, st, who)

        def save_annotation(
            cursor,
            decision,
            note_value,
            annotations,
            unsynced_count,
            annotator_name,
            dataset_value,
            split_value,
            status_value,
        ):
            """Persist the judgment for the current record and refresh the view.

            After a successful write the unsynced counter is incremented and
            an upload is attempted once it reaches ``AUTO_SYNC_EVERY``, or
            when the filtered pool becomes empty.
            """
            annotator_name = validate_annotator_name(annotator_name)
            pending = int(unsynced_count or 0)
            if not annotator_name:
                session = annotations or {}
                return (
                    session,
                    0,
                    pending,
                    None,
                    "请先输入标注人名称",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, session),
                    sync_status_text(pending),
                    "请先输入标注人名称后再保存",
                )
            annotation_file = annotation_path(annotator_name)
            merged = merge_annotations_from_disk(annotation_file, annotations)
            old_pool = resolve_filters(records, dataset_value, split_value, status_value, merged)
            if not old_pool:
                return (
                    merged,
                    0,
                    pending,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    build_stats(records, merged),
                    sync_status_text(pending),
                    "没有可保存的样本",
                )

            # Save against the record the cursor points at. The dropdowns may
            # have been changed without pressing the filter button, in which
            # case the cursor is not in the pool computed above; falling back
            # to pool[0] there would label a sample that is not on screen.
            record_idx = coerce_record_index(cursor)
            if record_idx is None or not 0 <= record_idx < len(records):
                record_idx = normalize_cursor(cursor, old_pool)
            record = records[record_idx]
            pos = old_pool.index(record_idx) if record_idx in old_pool else 0

            row = {
                "sample_id": record["sample_id"],
                "annotator": annotator_name,
                "judgment": decision,
                "note": note_value,
                "updated_at": datetime.now(timezone.utc).isoformat(),
                "dataset": record["dataset"],
                "split": record["split"],
            }
            to_write = {**merged, record["sample_id"]: row}
            try:
                write_annotations(annotation_file, to_write)
            except OSError as e:
                # Nothing was saved: keep showing the same record with the
                # previously stored annotation (if any).
                rendered = render_record(
                    record,
                    merged.get(record["sample_id"]),
                    position_text(record_idx, old_pool),
                )
                return (
                    merged,
                    int(record_idx),
                    pending,
                    *rendered,
                    build_stats(records, merged),
                    sync_status_text(pending),
                    f"写入失败(请检查磁盘是否可写): {e}",
                )

            annotations = to_write
            next_unsynced = pending + 1
            sync_msg = ""
            if next_unsynced >= AUTO_SYNC_EVERY:
                sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
                if sync_msg.startswith("已同步到 HF dataset"):
                    next_unsynced = 0

            new_pool = resolve_filters(records, dataset_value, split_value, status_value, annotations)
            stats = build_stats(records, annotations)
            if not new_pool:
                # The filtered list is exhausted; push whatever is still pending.
                if next_unsynced > 0:
                    final_sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
                    if final_sync_msg.startswith("已同步到 HF dataset"):
                        next_unsynced = 0
                    sync_msg = (sync_msg + ";" + final_sync_msg).strip(";")
                return (
                    annotations,
                    0,
                    next_unsynced,
                    None,
                    "",
                    "unsure",
                    "",
                    "0 / 0",
                    stats,
                    sync_status_text(next_unsynced),
                    saved_message(record["sample_id"], sync_msg),
                )

            # Stay on the saved record if it is still visible; otherwise fall
            # back to its predecessor in the old pool, then to the first entry.
            if record_idx in new_pool:
                new_cursor = record_idx
            elif pos > 0 and old_pool[pos - 1] in new_pool:
                new_cursor = old_pool[pos - 1]
            else:
                new_cursor = new_pool[0]
            rec = records[new_cursor]
            rendered = render_record(
                rec,
                annotations.get(rec["sample_id"]),
                position_text(new_cursor, new_pool),
            )
            return (
                annotations,
                int(new_cursor),
                next_unsynced,
                *rendered,
                stats,
                sync_status_text(next_unsynced),
                saved_message(record["sample_id"], sync_msg),
            )

        def sync_current_annotations(annotations, unsynced_count, annotator_name):
            """Upload the annotator's file when there are unsynced changes."""
            annotator_name = validate_annotator_name(annotator_name)
            pending = int(unsynced_count or 0)
            if not annotator_name:
                session = annotations or {}
                return (
                    session,
                    pending,
                    build_stats(records, session),
                    sync_status_text(pending),
                    "请先输入标注人名称后再同步",
                )
            annotation_file = annotation_path(annotator_name)
            merged = merge_annotations_from_disk(annotation_file, annotations)
            if not merged:
                return merged, 0, build_stats(records, merged), sync_status_text(0), "当前没有可同步的标注"
            if not pending:
                return merged, 0, build_stats(records, merged), sync_status_text(0), "当前没有未同步修改"
            sync_msg = sync_annotation_to_hf(annotation_file, annotator_name)
            next_unsynced = 0 if sync_msg.startswith("已同步到 HF dataset") else pending
            return (
                merged,
                next_unsynced,
                build_stats(records, merged),
                sync_status_text(next_unsynced),
                sync_msg,
            )

        # Components refreshed by handlers that redraw the whole view.
        full_view_outputs = [
            annotations_state,
            cursor_record_idx,
            unsynced_count_state,
            image,
            prompt_box,
            judgment,
            note,
            position_box,
            progress_md,
            sync_md,
            save_status,
        ]
        # Navigation only changes the cursor and the record widgets.
        nav_inputs = [
            cursor_record_idx,
            annotations_state,
            dataset_filter,
            split_filter,
            status_filter,
            annotator,
        ]
        nav_outputs = [
            cursor_record_idx,
            image,
            prompt_box,
            judgment,
            note,
            position_box,
            save_status,
        ]

        annotator.change(bootstrap, inputs=[annotator], outputs=full_view_outputs)
        refresh_btn.click(
            refresh_pool,
            inputs=[
                dataset_filter,
                split_filter,
                status_filter,
                annotations_state,
                annotator,
                unsynced_count_state,
            ],
            outputs=full_view_outputs,
        )
        prev_btn.click(on_prev, inputs=nav_inputs, outputs=nav_outputs)
        next_btn.click(on_next, inputs=nav_inputs, outputs=nav_outputs)
        save_btn.click(
            save_annotation,
            inputs=[
                cursor_record_idx,
                judgment,
                note,
                annotations_state,
                unsynced_count_state,
                annotator,
                dataset_filter,
                split_filter,
                status_filter,
            ],
            outputs=full_view_outputs,
        )
        sync_btn.click(
            sync_current_annotations,
            inputs=[annotations_state, unsynced_count_state, annotator],
            outputs=[annotations_state, unsynced_count_state, progress_md, sync_md, save_status],
        )
        # NOTE(review): the initial view is populated from the fixed name
        # "annotator_1" even though the name box starts empty; confirm this
        # preview is intended.
        demo.load(
            lambda: bootstrap("annotator_1"),
            outputs=full_view_outputs,
        )
    return demo
# Built at import time so the Blocks app is available as the module-level `demo` object.
demo = build_app()
# When executed as a script, serve on all interfaces at port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)