"""Offline DePlot (google/deplot) batch pipeline for ChartQA visual_fact_deplot."""

from __future__ import annotations

import json
import os
from typing import Any, Optional

from data_utils.paths import resolve_image_path

DEFAULT_MODEL_ID = "google/deplot"
DEFAULT_PROMPT = "Generate underlying data table of the figure below:"
PLACEHOLDER_SOURCE = "deplot_placeholder"
REAL_SOURCE = "google/deplot"
# "source" values accepted as genuine DePlot output. A tuple (not a set) so that
# membership tests never fail on unhashable JSON values such as lists.
_REAL_SOURCES = (REAL_SOURCE, "google/deplot", "deplot")


def _parse_vf(raw: Any) -> Optional[dict[str, Any]]:
    """Decode a visual_fact_deplot value into a dict.

    Accepts an already-decoded dict or a JSON-object string. Returns None for
    None, blank strings, invalid JSON, JSON that is not an object, and any
    other input type.
    """
    if raw is None:
        return None
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        text = raw.strip()
        if not text:
            return None
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return None
        return data if isinstance(data, dict) else None
    return None


def _table_text(data: dict[str, Any]) -> str:
    """Return the stripped ``parsed_table`` string of *data*, or "" if absent/non-string."""
    table = data.get("parsed_table")
    # Placeholders store a note dict here; only a string counts as a table.
    return table.strip() if isinstance(table, str) else ""


def is_deplot_placeholder(vf: Any) -> bool:
    """Return True if *vf* parses to a dict whose ``source`` is the placeholder marker."""
    data = _parse_vf(vf)
    return data is not None and data.get("source") == PLACEHOLDER_SOURCE


def has_real_deplot(vf: Any) -> bool:
    """Return True if *vf* holds a non-empty table string from an accepted DePlot source."""
    data = _parse_vf(vf)
    if data is None or data.get("source") == PLACEHOLDER_SOURCE:
        return False
    return bool(_table_text(data)) and data.get("source") in _REAL_SOURCES


def format_deplot_for_teacher(vf: Any) -> str:
    """Teacher-facing text from visual_fact_deplot; empty if missing/placeholder.

    Returns the stripped ``parsed_table`` string of any non-placeholder value
    (the source is not checked here), or "" when there is no string table.
    """
    data = _parse_vf(vf)
    if data is None or data.get("source") == PLACEHOLDER_SOURCE:
        return ""
    return _table_text(data)


def placeholder_deplot_table(entry: dict[str, Any], error: Optional[str] = None) -> str:
    """Serialize a placeholder visual_fact_deplot JSON string for *entry*.

    ``parsed_table`` holds a note dict instead of a table string. ``error``,
    when truthy, is recorded so callers can see why no real table exists.
    """
    question = entry.get("question", entry.get("question_wo_prompt", ""))
    payload: dict[str, Any] = {
        "source": PLACEHOLDER_SOURCE,
        "question": question,
        "parsed_table": {"note": "DePlot unavailable or image missing"},
    }
    if error:
        payload["error"] = error
    return json.dumps(payload, ensure_ascii=False)


def build_deplot_visual_fact(
    entry: dict[str, Any],
    parsed_table: str,
    *,
    model_id: str = DEFAULT_MODEL_ID,
) -> str:
    """Serialize a real visual_fact_deplot JSON string for *entry* from a DePlot table."""
    question = entry.get("question", entry.get("question_wo_prompt", ""))
    payload = {
        "source": REAL_SOURCE,
        "model_id": model_id,
        "question": question,
        "parsed_table": parsed_table.strip(),
    }
    return json.dumps(payload, ensure_ascii=False)


def cache_key_for_entry(entry: dict[str, Any]) -> str:
    """Return the absolute resolved image path for *entry*, or "" if it has no image."""
    image = entry.get("image", "")
    return os.path.abspath(resolve_image_path(image)) if image else ""


def load_deplot_cache(path: str) -> dict[str, str]:
    """Load a JSON cache mapping image path -> DePlot table string.

    Returns {} when *path* is empty, missing, unreadable, not valid UTF-8/JSON,
    or not a JSON object. Entries whose value is not a string (e.g. null) are
    dropped so they are never mistaken for a cached table.
    """
    if not path or not os.path.isfile(path):
        return {}
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        return {}
    if not isinstance(data, dict):
        return {}
    return {str(k): v for k, v in data.items() if isinstance(v, str)}


def save_deplot_cache(path: str, cache: dict[str, str]) -> None:
    """Write *cache* to *path* as JSON via a temp file and ``os.replace``.

    No-op when *path* is empty. Creates the parent directory if needed. On
    failure the temp file is removed and the exception re-raised, leaving any
    previous cache file at *path* untouched.
    """
    if not path:
        return
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    tmp = f"{path}.tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(cache, f, ensure_ascii=False, indent=2)
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise


def needs_deplot_processing(
    entry: dict[str, Any],
    *,
    replace_placeholder: bool = True,
    only_missing: bool = False,
) -> bool:
    """Decide whether *entry* should get a (new) visual_fact_deplot value.

    - no value (missing/empty): always True;
    - placeholder: True if ``replace_placeholder`` or ``only_missing``;
    - real DePlot table: True only if ``replace_placeholder and not only_missing``;
    - anything else (unparseable, empty table, unknown source): True if either
      flag is set.

    NOTE(review): with the defaults, entries that already have a real table are
    reprocessed too (a cache hit usually avoids re-inference); pass
    ``only_missing=True`` to keep existing real tables.
    """
    vf = entry.get("visual_fact_deplot")
    if not vf:
        return True
    if is_deplot_placeholder(vf):
        return replace_placeholder or only_missing
    if has_real_deplot(vf):
        return replace_placeholder and not only_missing
    return only_missing or replace_placeholder


class DePlotRunner:
    """Lazy-loaded batched DePlot inference."""

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
        dtype: Optional[str] = None,
        prompt: str = DEFAULT_PROMPT,
        max_new_tokens: int = 384,
    ):
        self.model_id = model_id
        self.prompt = prompt
        self.max_new_tokens = max_new_tokens
        self._device = device  # None or "auto" -> CUDA if available, else CPU
        self._dtype = dtype  # "float32" / "float16" / "bfloat16"; else device default
        self._processor = None
        self._model = None
        self._device_obj = None  # torch.device chosen by a successful load()

    def _resolve_device(self):
        """Return the explicit torch device, or CUDA when available, else CPU."""
        import torch

        if self._device and self._device != "auto":
            return torch.device(self._device)
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _resolve_dtype(self, device):
        """Map the configured dtype name to a torch dtype.

        Unrecognized or unset names fall back to bfloat16 on CUDA (float16 if
        bf16 is unsupported) and float32 on any other device.
        """
        import torch

        if self._dtype == "float32":
            return torch.float32
        if self._dtype == "float16":
            return torch.float16
        if self._dtype == "bfloat16":
            return torch.bfloat16
        if device.type == "cuda":
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return torch.float32

    def load(self) -> bool:
        """Load processor and model once; return True on success.

        Any exception (missing torch/transformers, download failure, ...) is
        printed and reported as False so callers can fall back to placeholders.
        """
        if self._model is not None:
            return True
        try:
            from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

            device = self._resolve_device()
            dtype = self._resolve_dtype(device)
            processor = Pix2StructProcessor.from_pretrained(self.model_id)
            model = Pix2StructForConditionalGeneration.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
            ).to(device)
            model.eval()
        except Exception as exc:
            print(f"[DePlot] model load failed: {exc}")
            self._processor = None
            self._model = None
            return False
        self._processor = processor
        self._model = model
        self._device_obj = device
        return True

    def generate_batch(self, image_paths: list[str]) -> list[str]:
        """Run DePlot on *image_paths* as one batch.

        Returns one stripped table string per input path; "" for empty,
        missing or unreadable paths, and for every path if the model cannot be
        loaded. RuntimeErrors raised by inference (e.g. out of memory) propagate.
        """
        if not image_paths:
            return []
        if not self.load():
            return [""] * len(image_paths)
        import torch
        from PIL import Image

        results: list[str] = [""] * len(image_paths)
        images = []
        valid_indices: list[int] = []
        for i, path in enumerate(image_paths):
            if not path or not os.path.isfile(path):
                continue
            try:
                # The context manager closes the file; convert() returns a loaded copy.
                with Image.open(path) as img:
                    images.append(img.convert("RGB"))
            except OSError:
                continue
            valid_indices.append(i)
        if not images:
            return results

        device = self._device_obj
        # The processor emits float32 tensors; cast floating inputs to the model's
        # dtype so fp16/bf16 weights do not hit a dtype mismatch in the forward pass.
        model_dtype = self._model.dtype
        texts = [self.prompt] * len(images)
        with torch.inference_mode():
            inputs = self._processor(images=images, text=texts, return_tensors="pt")
            inputs = {
                k: (v.to(device=device, dtype=model_dtype) if v.is_floating_point() else v.to(device))
                for k, v in inputs.items()
            }
            outputs = self._model.generate(**inputs, max_new_tokens=self.max_new_tokens)
        decoded = self._processor.batch_decode(outputs, skip_special_tokens=True)
        for idx, text in zip(valid_indices, decoded):
            results[idx] = (text or "").strip()
        return results

    def generate_batch_with_oom_retry(
        self,
        image_paths: list[str],
        batch_size: int = 8,
        max_retries: int = 3,
    ) -> list[str]:
        """Run ``generate_batch`` over *image_paths* in chunks, halving the chunk on OOM.

        Returns one result per input path, in order. A chunk whose RuntimeError
        is not an out-of-memory error, or that still fails at batch size 1 or
        after ``max_retries`` consecutive halvings, yields "" for its paths.
        The retry budget is restored after each successful chunk.
        """
        if not image_paths:
            return []
        import torch

        bs = max(1, batch_size)
        out: list[str] = []
        pos = 0
        retries_left = max_retries
        while pos < len(image_paths):
            chunk_paths = image_paths[pos : pos + bs]
            try:
                chunk_out = self.generate_batch(chunk_paths)
            except RuntimeError as exc:
                is_oom = "out of memory" in str(exc).lower()
            else:
                out.extend(chunk_out)
                pos += len(chunk_paths)
                retries_left = max_retries
                continue
            if not is_oom or bs <= 1 or retries_left <= 0:
                # Give up on this chunk and move on.
                out.extend([""] * len(chunk_paths))
                pos += len(chunk_paths)
                continue
            # Free cached blocks only after the exception (and its frames) are released,
            # then retry from the same position with half the batch size.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            bs = max(1, bs // 2)
            retries_left -= 1
        return out


def enrich_entries_with_deplot(
    entries: list[dict[str, Any]],
    *,
    enabled: bool = True,
    model_id: str = DEFAULT_MODEL_ID,
    batch_size: int = 8,
    max_new_tokens: int = 384,
    cache_path: str = "",
    replace_placeholder: bool = True,
    only_missing: bool = False,
    max_samples: int = 0,
    device: Optional[str] = None,
) -> dict[str, int]:
    """
    Fill visual_fact_deplot on entries in-place.

    Only the first ``max_samples`` entries are touched when it is positive.
    Entries that share an image are inferred once. The model is loaded only
    if some entry actually needs inference. New tables are written to
    ``cache_path`` (if given) even when the run is interrupted.

    Returns stats dict: real, placeholder, skipped, failed, cached.
    ``cached`` entries are also counted in ``real``; ``failed`` entries are
    also counted in ``placeholder``.
    """
    stats = {"real": 0, "placeholder": 0, "skipped": 0, "failed": 0, "cached": 0}
    work_entries = entries[:max_samples] if max_samples > 0 else entries

    if not enabled:
        for entry in work_entries:
            if not needs_deplot_processing(
                entry, replace_placeholder=replace_placeholder, only_missing=only_missing
            ):
                stats["skipped"] += 1
                continue
            entry["visual_fact_deplot"] = placeholder_deplot_table(entry, error="deplot_disabled")
            stats["placeholder"] += 1
        return stats

    cache = load_deplot_cache(cache_path)

    # image key -> indices into work_entries of the entries that need that image inferred.
    pending: dict[str, list[int]] = {}
    for idx, entry in enumerate(work_entries):
        if not needs_deplot_processing(
            entry, replace_placeholder=replace_placeholder, only_missing=only_missing
        ):
            stats["skipped"] += 1
            continue
        key = cache_key_for_entry(entry)
        cached = cache.get(key, "") if key else ""
        if cached.strip():
            entry["visual_fact_deplot"] = build_deplot_visual_fact(entry, cached, model_id=model_id)
            stats["cached"] += 1
            stats["real"] += 1
            continue
        if not key or not os.path.isfile(key):
            entry["visual_fact_deplot"] = placeholder_deplot_table(entry, error="image_missing")
            stats["placeholder"] += 1
            continue
        pending.setdefault(key, []).append(idx)

    if not pending:
        return stats

    runner = DePlotRunner(model_id=model_id, device=device, max_new_tokens=max_new_tokens)
    if not runner.load():
        for indices in pending.values():
            for idx in indices:
                entry = work_entries[idx]
                entry["visual_fact_deplot"] = placeholder_deplot_table(
                    entry, error="model_load_failed"
                )
                stats["placeholder"] += 1
        return stats

    bs = max(1, batch_size)
    keys = list(pending)
    dirty = False
    try:
        for start in range(0, len(keys), bs):
            chunk_keys = keys[start : start + bs]
            tables = runner.generate_batch_with_oom_retry(chunk_keys, batch_size=bs)
            for key, table in zip(chunk_keys, tables):
                if table:
                    cache[key] = table
                    dirty = True
                for idx in pending[key]:
                    entry = work_entries[idx]
                    if table:
                        entry["visual_fact_deplot"] = build_deplot_visual_fact(
                            entry, table, model_id=model_id
                        )
                        stats["real"] += 1
                    else:
                        entry["visual_fact_deplot"] = placeholder_deplot_table(
                            entry, error="inference_failed"
                        )
                        stats["failed"] += 1
                        stats["placeholder"] += 1
    finally:
        # Persist newly inferred tables even if the loop was interrupted.
        if cache_path and dirty:
            save_deplot_cache(cache_path, cache)
    return stats