File size: 2,711 Bytes
2bc3168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations

import argparse
from dataclasses import asdict

from tqdm.auto import tqdm

from .jepa_adapter import VisionEmbeddingAdapter, VisionEncoderConfig
from .utils import detect_label_names, embeddings_to_frame, load_image_split, save_embeddings


def extract_embeddings(
    dataset_name: str,
    split: str,
    image_column: str,
    label_column: str,
    model_name: str,
    max_samples: int,
    output_dir: str,
    batch_size: int = 8,
    device: str | None = None,
) -> dict:
    """Embed the images of one dataset split and save them to ``output_dir``.

    The split is loaded with ``load_image_split``, every image in
    ``image_column`` is encoded by a ``VisionEmbeddingAdapter`` built for
    ``model_name``, and the embeddings are combined with their labels into a
    frame (via ``embeddings_to_frame``) that ``save_embeddings`` writes out.

    Args:
        dataset_name: Dataset identifier passed to ``load_image_split``.
        split: Split name passed to ``load_image_split``.
        image_column: Column holding the images to embed.
        label_column: Column holding the per-sample labels.
        model_name: Encoder model name, stored in ``VisionEncoderConfig``.
        max_samples: Passed to ``load_image_split``; presumably caps how many
            rows are loaded (TODO confirm against ``load_image_split``).
        output_dir: Destination directory passed to ``save_embeddings``.
        batch_size: Encoder batch size, stored in ``VisionEncoderConfig``.
        device: Optional device string, stored in ``VisionEncoderConfig``.

    Returns:
        A summary dict with keys:
        - ``"output"``: the path returned by ``save_embeddings``, as a string.
        - ``"samples"``: the number of rows in the saved frame.
        - ``"embedding_dim"``: ``embeddings.shape[1]``, or 0 when the
          embedding array is empty.
        - ``"config"``: the encoder config converted with ``asdict``.

    Raises:
        ValueError: If ``image_column`` or ``label_column`` is not among the
            loaded split's columns.
    """
    dataset = load_image_split(dataset_name, split, max_samples)
    if image_column not in dataset.column_names:
        raise ValueError(f"Image column '{image_column}' not found. Available columns: {dataset.column_names}")
    if label_column not in dataset.column_names:
        raise ValueError(f"Label column '{label_column}' not found. Available columns: {dataset.column_names}")

    config = VisionEncoderConfig(model_name=model_name, batch_size=batch_size, device=device)
    adapter = VisionEmbeddingAdapter(config)

    # Collect images and labels in ONE pass over the dataset. Iterating twice
    # (once for images, once for labels) materialises every row twice; for
    # Hugging Face image datasets that presumably means decoding every image
    # twice just to read its label.
    images = []
    labels = []
    for row in tqdm(dataset, desc="Loading samples"):
        images.append(row[image_column])
        labels.append(row[label_column])
    # Sample ids are simply the row positions within the loaded split.
    sample_ids = [str(i) for i in range(len(images))]

    embeddings = adapter.embed_images(images)
    label_names = detect_label_names(dataset, label_column)
    df = embeddings_to_frame(embeddings, labels, label_names, sample_ids, model_name, dataset_name, split)
    parquet_path = save_embeddings(df, output_dir)

    return {
        "output": str(parquet_path),
        "samples": len(df),
        # `.size` is 0 for an empty array, in which case shape[1] may not exist.
        "embedding_dim": int(embeddings.shape[1]) if embeddings.size else 0,
        "config": asdict(config),
    }


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the embedding-extraction script.

    Each flag's destination name (``dataset_name``, ``max_samples``, ...)
    matches a keyword parameter of ``extract_embeddings``, so the parsed
    namespace can be unpacked straight into that function.
    """
    cli = argparse.ArgumentParser(description="Extract image embeddings from a Hugging Face dataset.")

    # (flag, add_argument keyword options) in the order they should appear
    # in --help output.
    option_specs = (
        ("--dataset-name", {"default": "beans"}),
        ("--split", {"default": "train"}),
        ("--image-column", {"default": "image"}),
        ("--label-column", {"default": "labels"}),
        ("--model-name", {"default": "facebook/dinov2-base"}),
        ("--max-samples", {"type": int, "default": 50}),
        ("--output-dir", {"default": "outputs/beans"}),
        ("--batch-size", {"type": int, "default": 8}),
        ("--device", {"default": None}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli


def main() -> None:
    """Command-line entry point: parse flags, run the extraction, print a summary."""
    namespace = build_parser().parse_args()
    summary = extract_embeddings(**vars(namespace))
    # One print call, one summary line per field.
    print(
        f"Saved embeddings: {summary['output']}",
        f"Samples: {summary['samples']}",
        f"Embedding dimension: {summary['embedding_dim']}",
        sep="\n",
    )


# Run the CLI when this module is executed directly. Because the module uses
# relative imports (`from .jepa_adapter ...`), it presumably needs to be run as
# part of its package, e.g. `python -m <package>.<module>`.
if __name__ == "__main__":
    main()