# app.py
"""
Semi-Auto Image Captioning - Full version for HF Spaces (Gradio)
Features:
- ingest images or ZIP
- preprocess: Laplacian blur (OpenCV), dHash de-dupe
- optional InsightFace filtering (if insightface installed)
- auto caption: BLIP (base / large)
- optional taggers: WD14 / CLIP Interrogator (if installed)
- human edit via Gradio Dataframe & export CSV/JSONL/ZIP
"""
import os
import io
import csv
import shutil
import zipfile
import json
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import gradio as gr
import numpy as np
from PIL import Image
import cv2
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

# Optional: try to import InsightFace and CLIP interrogator style modules
try:
    import insightface
    from insightface.app import FaceAnalysis
    _HAS_INSIGHTFACE = True
except Exception:
    _HAS_INSIGHTFACE = False

# Optional taggers (WD14 or CLIP Interrogator)
# We do a soft import so the Space works even if these are not available.
try:
    from clip_interrogator import ClipInterrogator, Config as CIConfig  # hypothetical
    _HAS_CI = True
except Exception:
    _HAS_CI = False

try:
    # placeholder for WD14 tagger library import
    import wd14_tagger  # hypothetical package name
    _HAS_WD14 = True
except Exception:
    _HAS_WD14 = False

# ---------------- Settings ----------------
DEFAULT_MODEL = "Salesforce/blip-image-captioning-base"  # CPU friendly
BIG_MODEL = "Salesforce/blip-image-captioning-large"     # GPU recommended
BLUR_VAR_THRESHOLD = 100.0

# Work directories inside the Space container
ROOT = Path("workspace")
IMAGES_DIR = ROOT / "images"
EXPORT_DIR = ROOT / "export"
ROOT.mkdir(parents=True, exist_ok=True)
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# ---------------- Utilities ----------------
def clear_workspace():
    """Remove workspace images/export and recreate directories."""
    if IMAGES_DIR.exists():
        shutil.rmtree(IMAGES_DIR)
    if EXPORT_DIR.exists():
        shutil.rmtree(EXPORT_DIR)
    IMAGES_DIR.mkdir(parents=True, exist_ok=True)
    EXPORT_DIR.mkdir(parents=True, exist_ok=True)

def is_image(fname: str) -> bool:
    ext = str(fname).lower().split(".")[-1]
    return ext in ["jpg", "jpeg", "png", "bmp", "webp"]

def laplacian_var_blur(pil_img: Image.Image) -> float:
    """Variance of the Laplacian as a sharpness proxy (low = blurry)."""
    arr = np.array(pil_img.convert("L"))
    if arr.size == 0:
        return 0.0
    fm = cv2.Laplacian(arr, cv2.CV_64F).var()
    return float(fm)

def dhash(pil_img: Image.Image, hash_size: int = 8) -> str:
    """Difference hash: compare adjacent pixels of a downscaled grayscale image."""
    img = pil_img.convert("L").resize((hash_size + 1, hash_size), Image.LANCZOS)
    arr = np.array(img)
    diff = arr[:, 1:] > arr[:, :-1]
    return ''.join('1' if v else '0' for v in diff.flatten())

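# The preprocessing step below de-dupes on exact dHash equality only. As an
# optional extension (our own sketch, not wired into the UI), near-duplicates
# can be caught by comparing hashes via Hamming distance; the helper name
# `dhash_distance` and the 5-bit threshold in the example are assumptions.
def dhash_distance(hash_a: str, hash_b: str) -> int:
    """Count differing bits between two equal-length dHash bit strings."""
    return sum(1 for a, b in zip(hash_a, hash_b) if a != b)
# Example: dhash_distance(dhash(img_a), dhash(img_b)) <= 5 -> likely near-duplicates.
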
def save_uploaded_files(files: List[gr.File]) -> List[str]:
    saved = []
    for f in files:
        if f is None:
            continue
        # gradio file object: f.name is the temporary path on the server
        name = os.path.basename(f.name)
        dst = IMAGES_DIR / name
        shutil.copy(f.name, dst)
        saved.append(str(dst))
    return saved

def unzip_to_images(zbytes: bytes) -> List[str]:
    saved = []
    with zipfile.ZipFile(io.BytesIO(zbytes)) as zf:
        for info in zf.infolist():
            if info.is_dir():
                continue
            if not is_image(info.filename):
                continue
            with zf.open(info) as src:
                data = src.read()
            # basename flattens archive paths (and avoids zip-slip path traversal)
            fname = os.path.basename(info.filename)
            dst = IMAGES_DIR / fname
            with open(dst, 'wb') as out:
                out.write(data)
            saved.append(str(dst))
    return saved

# ---------------- Optional InsightFace wrapper ----------------
_insightface_app = None
if _HAS_INSIGHTFACE:
    try:
        _insightface_app = FaceAnalysis(providers=['CPUExecutionProvider'])  # or CUDA if available
        _insightface_app.prepare(ctx_id=0 if torch.cuda.is_available() else -1, det_size=(640, 640))
    except Exception:
        _insightface_app = None
        _HAS_INSIGHTFACE = False

def insightface_quality_score(pil_img: Image.Image) -> Optional[float]:
    """Return a simple face quality score if InsightFace is available, else None.

    We compute the average detection score as a proxy (if provided by the model).
    """
    if not _HAS_INSIGHTFACE or _insightface_app is None:
        return None
    try:
        arr = np.array(pil_img.convert("RGB"))
        res = _insightface_app.get(arr)
        if not res:
            return 0.0
        # InsightFace result objects may expose the detection confidence under
        # different attribute names, so support both.
        scores = []
        for r in res:
            s = getattr(r, 'det_score', None)
            if s is None:
                s = getattr(r, 'score', None)  # avoid `or`, which would drop a valid 0.0
            if s is not None:
                scores.append(float(s))
        if not scores:
            return 0.0
        return float(np.mean(scores))
    except Exception:
        return None

# ---------------- Captioner ----------------
class BlipCaptioner:
    def __init__(self, model_name: str = DEFAULT_MODEL, device: str = None):
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # load processor & model
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForConditionalGeneration.from_pretrained(model_name)
        if self.device == "cuda":
            try:
                self.model = self.model.half().to(self.device)
            except Exception:
                self.model = self.model.to(self.device)
        else:
            self.model = self.model.to(self.device)
        self.model.eval()

    def caption(self, pil_img: Image.Image, max_new_tokens: int = 40) -> str:
        # cast floating-point inputs to the model's dtype so fp16 weights get fp16 pixels
        inputs = self.processor(images=pil_img, return_tensors="pt")
        inputs = {k: v.to(self.device, self.model.dtype) if v.is_floating_point() else v.to(self.device)
                  for k, v in inputs.items()}
        with torch.no_grad():
            out = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        text = self.processor.decode(out[0], skip_special_tokens=True)
        return text.strip()

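# Quick usage sketch (assumes a local file "sample.jpg"; kept as a comment so it
# does not run on import):
#   cap = BlipCaptioner(DEFAULT_MODEL)
#   print(cap.caption(Image.open("sample.jpg").convert("RGB"), max_new_tokens=40))
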
_captioner_cache: Dict[str, BlipCaptioner] = {}

def get_captioner(model_name: str) -> BlipCaptioner:
    key = model_name
    if key not in _captioner_cache:
        _captioner_cache[key] = BlipCaptioner(model_name=model_name)
    return _captioner_cache[key]

# ---------------- Optional Taggers ----------------
# These are placeholders: if real libs are installed, replace with real calls.
_ci = None
if _HAS_CI:
    try:
        ci_cfg = CIConfig()
        _ci = ClipInterrogator(ci_cfg)
    except Exception:
        _ci = None
        _HAS_CI = False

def clip_interrogate_caption(pil_img: Image.Image) -> Optional[str]:
    if not _HAS_CI or _ci is None:
        return None
    try:
        return _ci.interrogate(pil_img)
    except Exception:
        return None

def wd14_tags(pil_img: Image.Image) -> Optional[List[str]]:
    if not _HAS_WD14:
        return None
    try:
        # hypothetical API, replace if you install a real wd14 tagger
        tags = wd14_tagger.infer_tags(pil_img)
        return tags
    except Exception:
        return None

# ---------------- Pipeline steps ----------------
def step_ingest(files, zip_file):
    """
    Ingest files or a ZIP: clear the workspace, then save the incoming images.
    Returns: gallery, table
      gallery: list of (path, filename)
      table: rows [name, path, status, caption, blur_var, hash]
    """
    clear_workspace()
    saved = []
    if files:
        saved += save_uploaded_files(files)
    if zip_file is not None:
        try:
            with open(zip_file.name, "rb") as f:
                zbytes = f.read()
            saved += unzip_to_images(zbytes)
        except Exception:
            # gradio may provide the zip file as bytes in memory
            try:
                zbytes = zip_file.read()
                saved += unzip_to_images(zbytes)
            except Exception:
                pass
    gallery = [(p, os.path.basename(p)) for p in saved if is_image(p)]
    table = [[os.path.basename(p), p, "", "", 0.0, ""] for p in saved if is_image(p)]
    return gallery, table

def step_preprocess(table, rm_blurry=True, rm_dupes=True, blur_thr=BLUR_VAR_THRESHOLD, use_insightface=False, face_score_thr=0.1):
    """
    table: list of rows [name, path, status, caption, blur_var, hash]
    Returns a new table with statuses set to "kept" or "filtered:<reasons>".
    """
    seen_hashes = set()
    new_table = []
    for row in table:
        try:
            name, path, status, caption, blur_var, dh = row
        except Exception:
            # malformed row, skip
            continue
        try:
            pil = Image.open(path).convert("RGB")
        except Exception:
            new_table.append([name, path, "read_error", caption, blur_var, dh])
            continue
        blur = laplacian_var_blur(pil)
        ph = dhash(pil)
        keep = True
        reason = []
        if rm_blurry and blur < blur_thr:
            keep = False
            reason.append(f"blur<{blur_thr:.0f}")
        if rm_dupes and ph in seen_hashes:
            keep = False
            reason.append("duplicate")
        if use_insightface and _HAS_INSIGHTFACE:
            score = insightface_quality_score(pil)
            # treat a very low score as a filter condition
            if score is not None and score < face_score_thr:
                keep = False
                reason.append("low_face_score")
        if keep:
            seen_hashes.add(ph)
            new_table.append([name, path, "kept", caption, blur, ph])
        else:
            new_table.append([name, path, "filtered:" + ",".join(reason), caption, blur, ph])
    return new_table

def step_autocaption(table, model_choice: str, max_tokens: int, use_ci=False, use_wd14=False):
    """
    For each kept row, generate a caption (BLIP) and optionally append tags from other taggers.
    """
    cap = get_captioner(model_choice)
    new_table = []
    for row in table:
        name, path, status, caption, blur_var, dh = row
        if not os.path.exists(path):
            new_table.append([name, path, "missing", caption, blur_var, dh])
            continue
        # only process kept items (or rows with an empty status)
        if not str(status).startswith("kept") and status != "":
            new_table.append(row)
            continue
        pil = None
        try:
            pil = Image.open(path).convert("RGB")
            auto_cap = cap.caption(pil, max_new_tokens=max_tokens)
        except Exception as e:
            auto_cap = f"<error: {e}>"
        # optional additional interrogator / tagger info; both helpers already
        # swallow their own exceptions, but skip them if the image failed to open
        extras = []
        if use_ci and pil is not None:
            ci_cap = clip_interrogate_caption(pil)
            if ci_cap:
                extras.append(ci_cap)
        if use_wd14 and pil is not None:
            tags = wd14_tags(pil)
            if tags:
                extras.append(", ".join(tags))
        # keep a human-edited caption if one already exists
        final_caption = caption if caption else auto_cap
        if extras:
            # keep extras brief and join them onto the main caption
            final_caption = final_caption + " | " + " | ".join(extras)
        new_table.append([name, path, "kept", final_caption, blur_var, dh])
    return new_table

def step_export(table, file_prefix: str = "dataset") -> Tuple[str, str, str]:
    """
    Build CSV, JSONL and ZIP exports. Returns (csv_path, jsonl_path, zip_path).
    """
    rows = []
    for name, path, status, caption, blur_var, dh in table:
        if str(status).startswith("kept") and caption and len(str(caption).strip()) > 0:
            rows.append({"image": path, "caption": caption})
    csv_path = EXPORT_DIR / f"{file_prefix}.csv"
    jsonl_path = EXPORT_DIR / f"{file_prefix}.jsonl"
    EXPORT_DIR.mkdir(parents=True, exist_ok=True)
    # write CSV
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(["image", "caption"])
        for r in rows:
            w.writerow([r["image"], r["caption"]])
    # write JSONL
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    # ZIP package (images + csv/jsonl)
    zip_path = EXPORT_DIR / f"{file_prefix}.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as z:
        z.write(csv_path, arcname=csv_path.name)
        z.write(jsonl_path, arcname=jsonl_path.name)
        for r in rows:
            src = Path(r["image"])
            if src.exists():
                z.write(src, arcname=f"images/{src.name}")
    return str(csv_path), str(jsonl_path), str(zip_path)

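# Convenience sketch for consuming the export: read the JSONL back into a list
# of {"image": ..., "caption": ...} dicts. The helper name `load_jsonl` is our
# own addition and is not used elsewhere in this Space.
def load_jsonl(path: str) -> List[Dict[str, str]]:
    """Load one JSON object per line, skipping blank lines."""
    rows = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows
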
# ---------------- Gradio UI ----------------
title_md = """
# Semi-Automatic Image Captioning
**Steps**: upload images or a ZIP → preprocess/filter → auto-draft captions → human revision (table) → export CSV/JSONL/ZIP.
"""

with gr.Blocks(title="Semi-Auto Image Captioning") as demo:
    gr.Markdown(title_md)
    with gr.Row():
        with gr.Column():
            files = gr.File(file_count="multiple", file_types=["image"], label="Upload images (multiple allowed)")
            zip_up = gr.File(file_count="single", file_types=[".zip"], label="Or upload a ZIP (containing images)")
            btn_ingest = gr.Button("1) Ingest")
        with gr.Column():
            gallery = gr.Gallery(label="Preview", show_label=True, columns=6, height=260)
    table = gr.Dataframe(
        headers=["name", "path", "status", "caption", "blur_var", "hash"],
        datatype=["str", "str", "str", "str", "number", "str"],
        row_count=(0, "dynamic"),
        col_count=(6, "fixed"),
        type="array",  # pass rows as lists of lists, as the pipeline steps expect
        wrap=True,
        interactive=True,
        label="Data table (the caption column is directly editable)"
    )
    with gr.Row():
        rm_blur = gr.Checkbox(value=True, label="Filter blurry images")
        rm_dup = gr.Checkbox(value=True, label="De-duplicate")
        blur_thr = gr.Slider(10, 500, value=BLUR_VAR_THRESHOLD, step=10, label="Blur threshold (Laplacian variance)")
        use_insight = gr.Checkbox(value=False, label="Use InsightFace face-quality filtering (optional)")
        face_thr = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="InsightFace face-quality threshold (higher is stricter)")
    btn_pre = gr.Button("2) Preprocess / Filter")
    with gr.Row():
        model_choice = gr.Dropdown(choices=[DEFAULT_MODEL, BIG_MODEL], value=DEFAULT_MODEL, label="BLIP model")
        max_toks = gr.Slider(16, 80, value=40, step=4, label="Max new tokens")
        use_ci = gr.Checkbox(value=False, label="Use CLIP Interrogator (optional)")
        use_wd14 = gr.Checkbox(value=False, label="Use WD14 Tagger (optional)")
    btn_caption = gr.Button("3) Auto-draft captions")
    with gr.Row():
        prefix = gr.Textbox(value="dataset", label="Export file prefix")
        btn_export = gr.Button("4) Export CSV / JSONL / ZIP")
    csv_out = gr.File(label="CSV")
    jsonl_out = gr.File(label="JSONL")
    zip_out = gr.File(label="Bundled ZIP")

    # wiring
    btn_ingest.click(fn=step_ingest, inputs=[files, zip_up], outputs=[gallery, table])
    btn_pre.click(fn=step_preprocess, inputs=[table, rm_blur, rm_dup, blur_thr, use_insight, face_thr], outputs=table)
    btn_caption.click(fn=step_autocaption, inputs=[table, model_choice, max_toks, use_ci, use_wd14], outputs=table)
    btn_export.click(fn=step_export, inputs=[table, prefix], outputs=[csv_out, jsonl_out, zip_out])

if __name__ == "__main__":
    demo.launch()