File size: 5,662 Bytes
51c36ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""

Upload MMB-style dataset (CSV + images) to Hugging Face Hub.



Usage:

  pip install datasets pillow

  huggingface-cli login   # or set HF_TOKEN



  python scripts/upload_to_huggingface.py hf_dataset/image_mapping_with_questions.csv

  python scripts/upload_to_huggingface.py hf_dataset/image_mapping_with_questions.csv --repo-id scholo/MMB_dataset

"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Add project root for imports
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
sys.path.insert(0, str(PROJECT_ROOT))

import pandas as pd
from datasets import Dataset, Image


def resolve_image_path(csv_path: Path, fname: str) -> Path | None:
    """Locate *fname* near the CSV, checking a few conventional spots.

    Search order: ``<csv dir>/images``, ``<csv dir>``, ``<parent>/images``,
    then ``<parent>``. Returns the first existing regular file, or None
    when no candidate matches.
    """
    anchor = csv_path.resolve().parent
    search_dirs = (
        anchor / "images",
        anchor,
        anchor.parent / "images",
        anchor.parent,
    )
    for directory in search_dirs:
        candidate = directory / fname
        # is_file() already implies existence, so one check suffices.
        if candidate.is_file():
            return candidate
    return None


def load_cf_type_from_scene(csv_path: Path, scene_id: str, variant: str) -> str | None:
    """Return the counterfactual type recorded for a scene variant.

    Looks for ``scenes/{scene_id}_{variant}.json`` next to the CSV (or one
    directory up) and reads ``cf_metadata.cf_type`` from it.

    Args:
        csv_path: Path to the mapping CSV; scene dirs are resolved relative to it.
        scene_id: Scene identifier (lowercased to form the filename).
        variant: Variant suffix, e.g. "cf1" or "cf2".

    Returns:
        The cf_type as a string, or None when no scene file is found or a
        candidate file cannot be read/parsed.
    """
    base = csv_path.resolve().parent
    for scenes_dir in (base / "scenes", base.parent / "scenes"):
        if not scenes_dir.is_dir():
            continue
        path = scenes_dir / f"{scene_id.lower()}_{variant}.json"
        if not path.exists():
            continue
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            meta = data.get("cf_metadata") or {}
            t = meta.get("cf_type")
            return str(t) if t is not None else None
        except (OSError, ValueError, AttributeError):
            # Narrowed from a bare ``except Exception``: OSError covers read
            # failures, ValueError covers malformed JSON (JSONDecodeError and
            # UnicodeDecodeError are subclasses), AttributeError covers a
            # top-level JSON value that is not an object (no ``.get``).
            # Best-effort by design: fall through to the next candidate dir.
            continue
    return None


def scene_id_from_image_name(fname: str) -> str:
    """Strip a known variant/extension suffix from an image filename.

    E.g. ``"scene01_cf1.png"`` -> ``"scene01"``. Only the first matching
    suffix (case-insensitive) is removed; if stripping would leave an
    empty string, the original *fname* is returned unchanged.
    """
    name = str(fname).strip()
    lowered = name.lower()
    known_suffixes = (
        "_original.png", "_original",
        "_cf1.png", "_cf1",
        "_cf2.png", "_cf2",
        ".png",
    )
    for suffix in known_suffixes:
        if lowered.endswith(suffix):
            name = name[: -len(suffix)]
            break
    stripped = name.strip()
    return stripped if stripped else fname


def build_dataset(csv_path: Path) -> tuple[list[dict], list[str]]:
    """Build list of row dicts from CSV + images. Returns (rows, image_cols).

    Rows whose image files are missing are skipped (with a warning on
    stderr). Image columns are stored as ``{"bytes": ...}`` dicts so the
    caller can cast them to the ``datasets`` Image feature.

    Args:
        csv_path: Path to the mapping CSV; image and scene files are
            resolved relative to it.

    Returns:
        A tuple of (row dicts, names of image columns present in the CSV).
    """
    df = pd.read_csv(csv_path)
    expected_cols = ["original_image", "counterfactual1_image", "counterfactual2_image"]
    image_cols = [c for c in expected_cols if c in df.columns]

    rows: list[dict] = []
    for i, row in df.iterrows():
        rec = {}
        missing = False
        for col in df.columns:
            v = row[col]
            if pd.isna(v):
                v = ""
            rec[col] = str(v)

        for col in image_cols:
            fname = str(row.get(col, "") or "").strip()
            if not fname:
                missing = True
                break
            fp = resolve_image_path(csv_path, fname)
            if fp is None:
                print(f"Warning: missing image {fname} for row {i}", file=sys.stderr)
                missing = True
                break
            # Store image bytes so Hub viewer can display inline (paths don't work on server)
            rec[col] = {"bytes": fp.read_bytes()}

        if missing:
            continue

        # Attach cf_type metadata per counterfactual column. Guarded by an
        # explicit column check: the old code indexed image_cols[1]/[2],
        # which raised IndexError (or silently mis-associated cf1/cf2) when
        # the CSV lacked one of the expected image columns.
        for cf_col, variant, out_key in (
            ("counterfactual1_image", "cf1", "counterfactual1_type"),
            ("counterfactual2_image", "cf2", "counterfactual2_type"),
        ):
            if cf_col not in df.columns:
                continue
            sid = scene_id_from_image_name(str(row.get(cf_col, "")))
            cf_type = load_cf_type_from_scene(csv_path, sid, variant)
            if cf_type is not None:
                rec[out_key] = cf_type

        rows.append(rec)

    return rows, image_cols


def main():
    """CLI entry point: parse arguments, build the dataset, and push it to the Hub."""
    parser = argparse.ArgumentParser(description="Upload MMB dataset to Hugging Face Hub")
    parser.add_argument("csv_path", type=Path, help="Path to image_mapping_with_questions.csv")
    parser.add_argument("--repo-id", default=None, help="Hub repo ID (e.g. username/dataset-name)")
    parser.add_argument("--private", action="store_true", help="Create private dataset")
    parser.add_argument("--dry-run", action="store_true", help="Build dataset but don't push")
    args = parser.parse_args()

    csv_file = args.csv_path.resolve()
    if not csv_file.exists():
        print(f"Error: {csv_file} not found", file=sys.stderr)
        sys.exit(1)

    # Re-import locally to give a friendly hint when the deps are absent.
    try:
        from datasets import Dataset, Image
    except ImportError:
        print("Install datasets and pillow: pip install datasets pillow", file=sys.stderr)
        sys.exit(1)

    print("Building dataset from", csv_file)
    records, img_columns = build_dataset(csv_file)
    if not records:
        print("Error: no valid rows (check image paths)", file=sys.stderr)
        sys.exit(1)

    print(f"Loaded {len(records)} rows")

    dataset = Dataset.from_list(records)
    # Cast the raw byte dicts to the Image feature so the Hub renders previews.
    for column in img_columns:
        if column in dataset.column_names:
            dataset = dataset.cast_column(column, Image())

    if args.dry_run:
        print("Dry run: dataset built, not pushing.")
        print("Columns:", dataset.column_names)
        return

    target_repo = args.repo_id
    if not target_repo:
        # Derive a repo name from the CSV's parent directory.
        target_repo = csv_file.parent.name.replace(" ", "-").lower()
        target_repo = f"mmb-{target_repo}"
        print(f"No --repo-id given, using: {target_repo}")

    print(f"Pushing to Hugging Face Hub: {target_repo}")
    dataset.push_to_hub(target_repo, private=args.private)
    print("Done. View at: https://huggingface.co/datasets/" + target_repo)


if __name__ == "__main__":
    main()