agentic-rl-main / data_utils /chart /deplot_pipeline.py
Jack04810's picture
Add files using upload-large-folder tool
f5ab8aa verified
Raw
History Blame Contribute Delete
11.5 kB
"""Offline DePlot (google/deplot) batch pipeline for ChartQA visual_fact_deplot."""
from __future__ import annotations
import json
import os
from typing import Any, Optional
from data_utils.paths import resolve_image_path
# Model identifier used by default wherever this module loads or labels a DePlot model.
DEFAULT_MODEL_ID = "google/deplot"
# Text prompt sent to the processor together with every chart image.
DEFAULT_PROMPT = "Generate underlying data table of the figure below:"
# "source" tag of payloads that stand in for a missing DePlot result.
PLACEHOLDER_SOURCE = "deplot_placeholder"
# "source" tag of payloads that carry a table produced by DePlot.
REAL_SOURCE = "google/deplot"
def _parse_vf(raw: Any) -> Optional[dict[str, Any]]:
if raw is None:
return None
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
text = raw.strip()
if not text:
return None
try:
data = json.loads(text)
except json.JSONDecodeError:
return None
return data if isinstance(data, dict) else None
return None
def is_deplot_placeholder(vf: Any) -> bool:
    """Return True when *vf* parses to a payload tagged as a placeholder.

    *vf* may be a dict or a JSON string; anything that does not parse to a
    dict is not considered a placeholder.
    """
    payload = _parse_vf(vf)
    return payload is not None and payload.get("source") == PLACEHOLDER_SOURCE
def has_real_deplot(vf: Any) -> bool:
    """Return True when *vf* holds a non-empty table produced by DePlot.

    *vf* may be a dict or a JSON string. The payload qualifies only when

    * it parses to a dict,
    * its ``"source"`` is not the placeholder tag,
    * its ``"parsed_table"`` is a string with non-whitespace content, and
    * its ``"source"`` is one of the accepted DePlot tags.

    A ``"parsed_table"`` that is not a string (for example a dict or list
    coming from a differently shaped payload) is treated as "no table"
    instead of raising ``AttributeError`` on ``.strip()``.
    """
    data = _parse_vf(vf)
    if data is None:
        return False

    source = data.get("source")
    if source == PLACEHOLDER_SOURCE:
        return False

    table = data.get("parsed_table")
    if not isinstance(table, str) or not table.strip():
        return False

    # The literal "google/deplot" is kept next to REAL_SOURCE so payloads
    # written with that tag stay recognised even if the constant changes.
    return source in (REAL_SOURCE, "google/deplot", "deplot")
def format_deplot_for_teacher(vf: Any) -> str:
    """Return the teacher-facing table text from a ``visual_fact_deplot`` value.

    *vf* may be a dict or a JSON string. The stripped ``"parsed_table"`` is
    returned when the payload parses, is not a placeholder, and the table is
    a string. In every other case the result is ``""``.

    A non-string ``"parsed_table"`` (for example a dict) yields ``""``
    rather than raising ``AttributeError`` on ``.strip()``.
    """
    data = _parse_vf(vf)
    if data is None or data.get("source") == PLACEHOLDER_SOURCE:
        return ""

    table = data.get("parsed_table")
    if not isinstance(table, str):
        return ""
    return table.strip()
def placeholder_deplot_table(entry: dict[str, Any], error: Optional[str] = None) -> str:
    """Serialise a placeholder payload for *entry* as a JSON string.

    The payload carries the placeholder source tag, the entry's question
    (``"question"``, falling back to ``"question_wo_prompt"``, then ``""``)
    and a fixed note in ``"parsed_table"``. A truthy *error* is added under
    the ``"error"`` key.
    """
    fallback_question = entry.get("question_wo_prompt", "")
    # An empty/None error contributes no key at all.
    error_field = {"error": error} if error else {}
    return json.dumps(
        {
            "source": PLACEHOLDER_SOURCE,
            "question": entry.get("question", fallback_question),
            "parsed_table": {"note": "DePlot unavailable or image missing"},
            **error_field,
        },
        ensure_ascii=False,
    )
def build_deplot_visual_fact(
    entry: dict[str, Any],
    parsed_table: str,
    *,
    model_id: str = DEFAULT_MODEL_ID,
) -> str:
    """Serialise a real DePlot payload for *entry* as a JSON string.

    The payload records the real-source tag, *model_id*, the entry's question
    (``"question"``, falling back to ``"question_wo_prompt"``, then ``""``)
    and *parsed_table* with surrounding whitespace removed.
    """
    fallback_question = entry.get("question_wo_prompt", "")
    question_text = entry.get("question", fallback_question)
    return json.dumps(
        {
            "source": REAL_SOURCE,
            "model_id": model_id,
            "question": question_text,
            "parsed_table": parsed_table.strip(),
        },
        ensure_ascii=False,
    )
def cache_key_for_entry(entry: dict[str, Any]) -> str:
    """Return the cache key for *entry*: the absolute resolved image path.

    Entries without a truthy ``"image"`` value get the empty string.
    """
    image_ref = entry.get("image", "")
    if not image_ref:
        return ""
    resolved = resolve_image_path(image_ref)
    return os.path.abspath(resolved)
def load_deplot_cache(path: str) -> dict[str, str]:
    """Load the image-path -> table cache from a JSON file, best effort.

    Returns ``{}`` when *path* is empty, is not an existing file, cannot be
    read or decoded, or does not contain a JSON object.

    Only string values are kept. Coercing other values with ``str()`` would
    turn a JSON ``null`` into the truthy text ``"None"``, which callers would
    then mistake for a real table.
    """
    if not path or not os.path.isfile(path):
        return {}

    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError,
        # so a corrupt cache never aborts the pipeline.
        print(f"[DePlot] ignoring unreadable cache {path}: {exc}")
        return {}

    if not isinstance(data, dict):
        return {}
    return {str(key): value for key, value in data.items() if isinstance(value, str)}
def save_deplot_cache(path: str, cache: dict[str, str]) -> None:
    """Write *cache* to *path* as JSON, replacing the file atomically.

    Does nothing when *path* is empty. The parent directory is created if
    needed. Data is written to ``<path>.tmp`` first and then moved over
    *path* with ``os.replace`` so readers never see a half-written file.

    If serialisation or the replace fails, the temporary file is removed and
    the original exception is re-raised; an existing *path* is left as it was.
    """
    if not path:
        return

    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    tmp = f"{path}.tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(cache, f, ensure_ascii=False, indent=2)
        os.replace(tmp, path)
    except BaseException:
        # Cleanup only: do not leave a partial temp file behind, then let the
        # original error propagate unchanged.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
def needs_deplot_processing(
    entry: dict[str, Any],
    *,
    replace_placeholder: bool = True,
    only_missing: bool = False,
) -> bool:
    """Decide whether *entry* should be (re)processed by DePlot.

    Based on the current ``"visual_fact_deplot"`` value:

    * missing or empty: always process;
    * placeholder: process when either flag is set;
    * real DePlot table: process only when *replace_placeholder* is set and
      *only_missing* is not;
    * anything else: process when either flag is set.
    """
    existing = entry.get("visual_fact_deplot")
    if not existing:
        return True

    if is_deplot_placeholder(existing):
        decision = replace_placeholder or only_missing
    elif has_real_deplot(existing):
        decision = replace_placeholder and not only_missing
    else:
        decision = only_missing or replace_placeholder
    return decision
class DePlotRunner:
    """Lazy-loaded batched DePlot inference.

    ``torch``, ``transformers`` and ``PIL`` are imported inside the methods
    that need them, and the processor/model are only created by :meth:`load`
    (called directly or through :meth:`generate_batch`).
    """

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
        dtype: Optional[str] = None,
        prompt: str = DEFAULT_PROMPT,
        max_new_tokens: int = 384,
    ):
        """Store the configuration; nothing is loaded yet.

        Args:
            model_id: Identifier passed to ``from_pretrained``.
            device: Device string; ``None`` or ``"auto"`` selects CUDA when
                available, otherwise CPU.
            dtype: ``"float32"``, ``"float16"`` or ``"bfloat16"``; any other
                value selects a default based on the device.
            prompt: Text sent with every image.
            max_new_tokens: Generation length limit.
        """
        self.model_id = model_id
        self.prompt = prompt
        self.max_new_tokens = max_new_tokens
        self._device = device
        self._dtype = dtype
        self._processor = None
        self._model = None
        # Resolved torch.device / torch.dtype, filled in by a successful load().
        self._device_obj = None
        self._dtype_obj = None

    def _resolve_device(self):
        """Return the ``torch.device`` to run on."""
        import torch

        if self._device and self._device != "auto":
            return torch.device(self._device)
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _resolve_dtype(self, device):
        """Return the ``torch.dtype`` to load the model in for *device*."""
        import torch

        if self._dtype == "float32":
            return torch.float32
        if self._dtype == "float16":
            return torch.float16
        if self._dtype == "bfloat16":
            return torch.bfloat16
        # No explicit choice: half precision on CUDA, full precision elsewhere.
        if device.type == "cuda":
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return torch.float32

    def load(self) -> bool:
        """Load processor and model once; return whether the model is usable.

        Any failure (missing packages, download error, device error) is
        reported on stdout and turned into ``False`` so callers can fall back
        to placeholders. A later call tries again.
        """
        if self._model is not None:
            return True
        try:
            from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

            device = self._resolve_device()
            dtype = self._resolve_dtype(device)
            processor = Pix2StructProcessor.from_pretrained(self.model_id)
            model = Pix2StructForConditionalGeneration.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
            ).to(device)
            model.eval()
        except Exception as exc:
            print(f"[DePlot] model load failed: {exc}")
            # Leave no half-initialised state behind.
            self._processor = None
            self._model = None
            return False

        self._processor = processor
        self._model = model
        self._device_obj = device
        self._dtype_obj = dtype
        return True

    def generate_batch(self, image_paths: list[str]) -> list[str]:
        """Run DePlot on *image_paths* in a single forward batch.

        Returns one string per input path, in order. Paths that are empty,
        not files, or cannot be opened as images get ``""``; so does every
        path when the model cannot be loaded. Errors raised by the model
        itself (for example CUDA out-of-memory) propagate to the caller.
        """
        if not image_paths:
            return []
        if not self.load():
            return [""] * len(image_paths)

        import torch
        from PIL import Image

        images = []
        valid_indices: list[int] = []
        results: list[str] = [""] * len(image_paths)
        for i, path in enumerate(image_paths):
            if not path or not os.path.isfile(path):
                continue
            try:
                # convert() returns an independent in-memory copy, so the file
                # handle can be released right away.
                with Image.open(path) as img:
                    images.append(img.convert("RGB"))
            except OSError:
                continue
            valid_indices.append(i)
        if not images:
            return results

        device = self._device_obj
        dtype = self._dtype_obj
        texts = [self.prompt] * len(images)
        with torch.inference_mode():
            encoded = self._processor(images=images, text=texts, return_tensors="pt")
            # The processor emits float32 tensors while the model may have been
            # loaded in half precision; floating-point inputs must be cast to
            # the model dtype or the forward pass fails with a dtype mismatch.
            inputs = {}
            for name, tensor in encoded.items():
                if dtype is not None and tensor.is_floating_point():
                    inputs[name] = tensor.to(device=device, dtype=dtype)
                else:
                    inputs[name] = tensor.to(device)
            outputs = self._model.generate(**inputs, max_new_tokens=self.max_new_tokens)
        decoded = self._processor.batch_decode(outputs, skip_special_tokens=True)
        for idx, text in zip(valid_indices, decoded):
            results[idx] = (text or "").strip()
        return results

    def generate_batch_with_oom_retry(
        self,
        image_paths: list[str],
        batch_size: int = 8,
        max_retries: int = 3,
    ) -> list[str]:
        """Run :meth:`generate_batch` chunk by chunk, shrinking on OOM.

        *image_paths* is processed in chunks of at most *batch_size*. When a
        chunk raises an out-of-memory ``RuntimeError`` the batch size is
        halved and the same position is retried, up to *max_retries* times
        per position. The reduced size is kept for the following chunks.
        A chunk that still fails (or fails for another ``RuntimeError``) is
        reported on stdout and yields ``""`` for each of its paths.

        Returns one string per input path, in order.
        """
        if not image_paths:
            return []

        bs = max(1, batch_size)
        out: list[str] = []
        pos = 0
        retries_left = max_retries
        while pos < len(image_paths):
            chunk_paths = image_paths[pos : pos + bs]
            try:
                chunk_out = self.generate_batch(chunk_paths)
            except RuntimeError as exc:
                if "out of memory" in str(exc).lower():
                    # Imported here so the success path never needs torch itself.
                    import torch

                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    if bs > 1 and retries_left > 0:
                        bs = max(1, bs // 2)
                        retries_left -= 1
                        continue
                print(f"[DePlot] inference failed for {len(chunk_paths)} image(s): {exc}")
                chunk_out = [""] * len(chunk_paths)
            out.extend(chunk_out)
            pos += len(chunk_paths)
            # The retry budget is per position, whether the chunk succeeded
            # or was given up on.
            retries_left = max_retries
        return out
def enrich_entries_with_deplot(
    entries: list[dict[str, Any]],
    *,
    enabled: bool = True,
    model_id: str = DEFAULT_MODEL_ID,
    batch_size: int = 8,
    max_new_tokens: int = 384,
    cache_path: str = "",
    replace_placeholder: bool = True,
    only_missing: bool = False,
    max_samples: int = 0,
    device: Optional[str] = None,
) -> dict[str, int]:
    """Fill ``visual_fact_deplot`` on entries in-place.

    Args:
        entries: Entries to enrich; each is mutated in place.
        enabled: When False no model is used and every entry that needs
            processing receives a ``deplot_disabled`` placeholder.
        model_id: Model identifier used for inference and recorded in payloads.
        batch_size: Maximum number of images per forward batch (at least 1).
        max_new_tokens: Generation length limit passed to the runner.
        cache_path: JSON cache file mapping image path to table; ``""``
            disables persistence.
        replace_placeholder: Forwarded to ``needs_deplot_processing``.
        only_missing: Forwarded to ``needs_deplot_processing``.
        max_samples: When positive, only the first *max_samples* entries are
            considered.
        device: Device string forwarded to the runner.

    Returns:
        Per-entry counters: ``real`` (table written, including cache hits),
        ``cached`` (table taken from the cache), ``placeholder`` (placeholder
        written, including failed inference), ``failed`` (inference returned
        nothing) and ``skipped`` (entry left untouched).
    """
    stats = {"real": 0, "placeholder": 0, "skipped": 0, "failed": 0, "cached": 0}
    work_entries = entries[:max_samples] if max_samples > 0 else entries

    if not enabled:
        for entry in work_entries:
            if not needs_deplot_processing(
                entry, replace_placeholder=replace_placeholder, only_missing=only_missing
            ):
                stats["skipped"] += 1
                continue
            entry["visual_fact_deplot"] = placeholder_deplot_table(entry, error="deplot_disabled")
            stats["placeholder"] += 1
        return stats

    cache = load_deplot_cache(cache_path)

    # Image path -> indices of the entries waiting for that image. Several
    # entries may share one chart, so each image is run through the model once.
    pending: dict[str, list[int]] = {}
    for idx, entry in enumerate(work_entries):
        if not needs_deplot_processing(
            entry, replace_placeholder=replace_placeholder, only_missing=only_missing
        ):
            stats["skipped"] += 1
            continue
        key = cache_key_for_entry(entry)
        cached_table = cache.get(key, "") if key else ""
        if cached_table.strip():
            entry["visual_fact_deplot"] = build_deplot_visual_fact(
                entry, cached_table, model_id=model_id
            )
            stats["cached"] += 1
            stats["real"] += 1
            continue
        if not key or not os.path.isfile(key):
            entry["visual_fact_deplot"] = placeholder_deplot_table(entry, error="image_missing")
            stats["placeholder"] += 1
            continue
        pending.setdefault(key, []).append(idx)

    if not pending:
        # Nothing needs inference, so the model is never loaded.
        return stats

    runner = DePlotRunner(model_id=model_id, device=device, max_new_tokens=max_new_tokens)
    if not runner.load():
        for indices in pending.values():
            for entry_idx in indices:
                entry = work_entries[entry_idx]
                entry["visual_fact_deplot"] = placeholder_deplot_table(
                    entry, error="model_load_failed"
                )
                stats["placeholder"] += 1
        return stats

    unique_paths = list(pending)
    # A single call lets the runner keep a batch size it had to reduce after
    # an out-of-memory error instead of starting over for every chunk.
    tables = runner.generate_batch_with_oom_retry(unique_paths, batch_size=max(1, batch_size))

    cache_changed = False
    for path, table in zip(unique_paths, tables):
        if table:
            cache[path] = table
            cache_changed = True
        for entry_idx in pending[path]:
            entry = work_entries[entry_idx]
            if table:
                entry["visual_fact_deplot"] = build_deplot_visual_fact(
                    entry, table, model_id=model_id
                )
                stats["real"] += 1
            else:
                entry["visual_fact_deplot"] = placeholder_deplot_table(
                    entry, error="inference_failed"
                )
                stats["failed"] += 1
                stats["placeholder"] += 1

    # Rewrite the cache file only when this run added something to it.
    if cache_path and cache_changed:
        save_deplot_cache(cache_path, cache)
    return stats