# File size: 4,459 Bytes
# 534c64f
import json
import os
import sys
import tempfile
from unittest.mock import patch
import pytest
from PIL import Image
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data_utils.chart.deplot_pipeline import (
build_deplot_visual_fact,
enrich_entries_with_deplot,
format_deplot_for_teacher,
has_real_deplot,
is_deplot_placeholder,
load_deplot_cache,
placeholder_deplot_table,
save_deplot_cache,
)
from opsd_utils.privileged.providers import VisualFactsProvider
def test_is_deplot_placeholder():
    """Check placeholder detection against a placeholder and a built table.

    The output of ``placeholder_deplot_table`` must be flagged by
    ``is_deplot_placeholder``; the output of ``build_deplot_visual_fact``
    must not be flagged and must satisfy ``has_real_deplot``.
    """
    placeholder_fact = placeholder_deplot_table({"question": "q"})
    assert is_deplot_placeholder(placeholder_fact)

    built_fact = build_deplot_visual_fact({"question": "q"}, "A | 1\nB | 2")
    assert not is_deplot_placeholder(built_fact)
    assert has_real_deplot(built_fact)
def test_format_deplot_for_teacher():
    """Check the teacher-facing text produced for each kind of input.

    A built visual fact formats back to the exact table text it was built
    from; a placeholder, ``None`` and the empty string all format to "".
    """
    table_text = "Year | Value\n2020 | 10"
    built_fact = build_deplot_visual_fact({"question": "q"}, table_text)
    assert format_deplot_for_teacher(built_fact) == table_text

    placeholder_fact = placeholder_deplot_table({"question": "q"})
    assert format_deplot_for_teacher(placeholder_fact) == ""

    # Inputs carrying no table at all are expected to yield an empty string.
    for blank_input in (None, ""):
        assert format_deplot_for_teacher(blank_input) == ""
def test_visual_facts_provider_skips_placeholder_and_missing():
    """Check that the DePlot section is left out when no real table exists.

    Two samples are exercised: one whose ``visual_fact_deplot`` is a
    placeholder and one without that key. In both cases the suffix must
    contain the hint section header and must not contain the DePlot header.
    """
    hint_header = "Visual Facts - Hint"
    deplot_header = "Visual Facts - DePlot"
    facts_provider = VisualFactsProvider()

    # Case 1: the DePlot entry is present but is only a placeholder.
    placeholder_sample = {
        "visual_fact_deplot": placeholder_deplot_table({"question": "q"}),
        "visual_fact_hint": "hint text",
    }
    placeholder_suffix = facts_provider.build_teacher_suffix(placeholder_sample)
    assert hint_header in placeholder_suffix
    assert deplot_header not in placeholder_suffix

    # Case 2: the DePlot key is absent altogether.
    hint_only_sample = {"visual_fact_hint": "only hint"}
    hint_only_suffix = facts_provider.build_teacher_suffix(hint_only_sample)
    assert deplot_header not in hint_only_suffix
    assert hint_header in hint_only_suffix
def test_visual_facts_provider_real_deplot_table():
    """Check that a built DePlot table is rendered as plain text.

    The suffix must contain the DePlot section header and the raw table
    text, and must not contain the literal ``"parsed_table"`` (quotes
    included).
    """
    raw_table = "Category | 2019 | 2020\nA | 1 | 2"
    deplot_fact = build_deplot_visual_fact({"question": "q"}, raw_table)

    facts_provider = VisualFactsProvider()
    teacher_suffix = facts_provider.build_teacher_suffix(
        {"visual_fact_deplot": deplot_fact}
    )

    assert "Visual Facts - DePlot" in teacher_suffix
    assert raw_table in teacher_suffix
    assert '"parsed_table"' not in teacher_suffix
def test_deplot_cache_roundtrip():
    """Check that a cache mapping survives a save followed by a load.

    A single-entry mapping is written to ``cache.json`` in a temporary
    directory and read back; the stored value must be unchanged.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        cache_file = os.path.join(work_dir, "cache.json")
        save_deplot_cache(cache_file, {"/a.png": "table text"})

        # Reload while the temporary directory still exists.
        restored = load_deplot_cache(cache_file)
        assert restored["/a.png"] == "table text"
def test_enrich_disabled_uses_placeholder():
    """Check that enrichment with ``enabled=False`` writes a placeholder.

    The returned stats must count one placeholder and the entry must carry
    a ``visual_fact_deplot`` value recognised as a placeholder.
    """
    samples = [{"question": "q1", "image": "missing.png"}]

    result_stats = enrich_entries_with_deplot(samples, enabled=False)

    assert result_stats["placeholder"] == 1
    first_sample = samples[0]
    assert is_deplot_placeholder(first_sample["visual_fact_deplot"])
def test_enrich_with_mock_runner():
    """Check enabled enrichment using a stub in place of ``DePlotRunner``.

    A small image is written to a temporary directory and ``DePlotRunner``
    in ``data_utils.chart.deplot_pipeline`` is patched to return a stub
    that yields a fixed table for every path. The test then checks that one
    entry is counted as real, the entry carries a real table containing the
    stub text, and the cache file has been written.
    """

    class _StubRunner:
        """Stand-in runner that yields one canned table per requested path."""

        def load(self):
            return True

        def generate_batch_with_oom_retry(self, paths, batch_size=8):
            return ["Col | Val\nA | 1" for _ in paths]

    with tempfile.TemporaryDirectory() as work_dir:
        chart_file = os.path.join(work_dir, "chart.png")
        Image.new("RGB", (32, 32)).save(chart_file)
        cache_file = os.path.join(work_dir, "deplot_cache.json")
        samples = [{"question": "What?", "image": chart_file}]

        runner_patch = patch(
            "data_utils.chart.deplot_pipeline.DePlotRunner",
            return_value=_StubRunner(),
        )
        with runner_patch:
            result_stats = enrich_entries_with_deplot(
                samples,
                enabled=True,
                cache_path=cache_file,
            )

        # Checked inside the temp-dir context so the cache file still exists.
        assert result_stats["real"] == 1
        deplot_fact = samples[0]["visual_fact_deplot"]
        assert has_real_deplot(deplot_fact)
        assert "Col | Val" in format_deplot_for_teacher(deplot_fact)
        assert os.path.isfile(cache_file)
def test_build_script_disabled(tmp_path):
    """Run the build script with ``--no-enabled`` and inspect its output.

    A one-entry JSON input is written under ``tmp_path`` and
    ``scripts/build_visual_facts_chartqa_deplot.py`` is run with the current
    interpreter from the directory two levels above this file. The first
    output entry must carry a placeholder ``visual_fact_deplot``.
    """
    import subprocess

    this_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(this_dir)
    script_path = os.path.join(
        project_root, "scripts", "build_visual_facts_chartqa_deplot.py"
    )

    input_file = tmp_path / "in.json"
    output_file = tmp_path / "out.json"
    payload = json.dumps([{"question": "q", "hint": "h"}])
    input_file.write_text(payload, encoding="utf-8")

    command = [
        sys.executable,
        script_path,
        "--input",
        str(input_file),
        "--output",
        str(output_file),
        "--no-enabled",
    ]
    # check=True makes a non-zero exit status fail the test.
    subprocess.run(command, check=True, cwd=project_root)

    produced = json.loads(output_file.read_text(encoding="utf-8"))
    assert is_deplot_placeholder(produced[0]["visual_fact_deplot"])
|