File size: 2,411 Bytes
23680f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
"""Run HyperView demo with CIFAR-10 dataset."""

import argparse
import os
import sys
from pathlib import Path

# Add src to path for development: the "src" directory that sits next to this
# script's parent directory is placed first on sys.path (presumably where the
# hyperview package lives in a checkout -- confirm against the repo layout).
# resolve() makes the entry absolute, so it does not depend on the current
# working directory when __file__ happens to be a relative path.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))


def _build_parser():
    """Build the argument parser for the HyperView demo command line."""
    parser = argparse.ArgumentParser(description="Run HyperView demo")
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10_demo",
        help="Dataset name to use for persistence (default: cifar10_demo)",
    )
    parser.add_argument(
        "--samples", type=int, default=50000, help="Number of samples to load (default: 50000)"
    )
    parser.add_argument(
        "--port", type=int, default=6262, help="Port to run server on (default: 6262)"
    )
    parser.add_argument(
        "--no-browser", action="store_true", help="Don't open browser automatically"
    )
    parser.add_argument(
        "--no-persist", action="store_true", help="Don't persist to database (use in-memory)"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="openai/clip-vit-base-patch32",
        help=(
            "Embedding model_id to use (default: openai/clip-vit-base-patch32). "
            "This is passed to Dataset.compute_embeddings(model=...)."
        ),
    )
    # "--database-dir" is accepted as an alias; both spellings store into
    # args.datasets_dir because argparse derives dest from the first long option.
    parser.add_argument(
        "--datasets-dir",
        "--database-dir",
        type=str,
        default=None,
        help="Override persistence directory (sets HYPERVIEW_DATASETS_DIR)",
    )
    parser.add_argument(
        "--no-server",
        action="store_true",
        help="Don't start the web server (useful for CI / DB checks)",
    )
    return parser


def main(argv=None):
    """Run the HyperView demo on the CIFAR-10 training split.

    Parses the command line, creates (or opens) the named dataset, loads up to
    ``--samples`` images from the ``uoft-cs/cifar10`` Hugging Face dataset,
    computes embeddings with ``--model``, computes the euclidean and poincare
    visualizations, and finally launches the web server unless ``--no-server``
    was given.

    Args:
        argv: Optional list of command-line argument strings, without the
            program name. When ``None`` (the default) argparse falls back to
            ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.

    Raises:
        SystemExit: Raised by argparse for ``--help`` (exit code 0) or for
            invalid arguments (exit code 2).
    """
    args = _build_parser().parse_args(argv)

    # Export the override before importing hyperview; presumably the package
    # reads HYPERVIEW_DATASETS_DIR when it is imported or when the dataset is
    # created -- keep this ordering.
    if args.datasets_dir:
        os.environ["HYPERVIEW_DATASETS_DIR"] = args.datasets_dir

    # Imported lazily so that argument errors and --help do not require the
    # package, and so the environment override above is in place first.
    import hyperview as hv

    dataset = hv.Dataset(args.dataset, persist=not args.no_persist)

    dataset.add_from_huggingface(
        "uoft-cs/cifar10",
        split="train",
        image_key="img",
        label_key="label",
        max_samples=args.samples,
    )

    dataset.compute_embeddings(model=args.model, show_progress=True)

    # Compute both euclidean and poincare layouts
    dataset.compute_visualization(geometry="euclidean")
    dataset.compute_visualization(geometry="poincare")

    # Data-preparation-only mode: stop before starting the server.
    if args.no_server:
        return

    hv.launch(dataset, port=args.port, open_browser=not args.no_browser)


# Run the demo only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()