from __future__ import annotations import argparse from dataclasses import asdict from tqdm.auto import tqdm from .jepa_adapter import VisionEmbeddingAdapter, VisionEncoderConfig from .utils import detect_label_names, embeddings_to_frame, load_image_split, save_embeddings def extract_embeddings( dataset_name: str, split: str, image_column: str, label_column: str, model_name: str, max_samples: int, output_dir: str, batch_size: int = 8, device: str | None = None, ) -> dict: dataset = load_image_split(dataset_name, split, max_samples) if image_column not in dataset.column_names: raise ValueError(f"Image column '{image_column}' not found. Available columns: {dataset.column_names}") if label_column not in dataset.column_names: raise ValueError(f"Label column '{label_column}' not found. Available columns: {dataset.column_names}") config = VisionEncoderConfig(model_name=model_name, batch_size=batch_size, device=device) adapter = VisionEmbeddingAdapter(config) images = [row[image_column] for row in tqdm(dataset, desc="Loading samples")] labels = [row[label_column] for row in dataset] sample_ids = [str(i) for i in range(len(dataset))] embeddings = adapter.embed_images(images) label_names = detect_label_names(dataset, label_column) df = embeddings_to_frame(embeddings, labels, label_names, sample_ids, model_name, dataset_name, split) parquet_path = save_embeddings(df, output_dir) return { "output": str(parquet_path), "samples": len(df), "embedding_dim": int(embeddings.shape[1]) if embeddings.size else 0, "config": asdict(config), } def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Extract image embeddings from a Hugging Face dataset.") parser.add_argument("--dataset-name", default="beans") parser.add_argument("--split", default="train") parser.add_argument("--image-column", default="image") parser.add_argument("--label-column", default="labels") parser.add_argument("--model-name", default="facebook/dinov2-base") parser.add_argument("--max-samples", type=int, default=50) parser.add_argument("--output-dir", default="outputs/beans") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--device", default=None) return parser def main() -> None: args = build_parser().parse_args() result = extract_embeddings(**vars(args)) print(f"Saved embeddings: {result['output']}") print(f"Samples: {result['samples']}") print(f"Embedding dimension: {result['embedding_dim']}") if __name__ == "__main__": main()