samwaugh commited on
Commit
0ba12ad
·
1 Parent(s): 207c067

Use safetensors for consolidated embeddings; load via config; add ids JSON

Browse files
backend/runner/config.py CHANGED
@@ -66,3 +66,10 @@ WORKS_JSON = JSON_INFO_DIR / "works.json"
66
  TOPICS_JSON = JSON_INFO_DIR / "topics.json"
67
  CREATORS_JSON = JSON_INFO_DIR / "creators.json"
68
  TOPIC_NAMES_JSON = JSON_INFO_DIR / "topic_names.json"
 
 
 
 
 
 
 
 
66
  TOPICS_JSON = JSON_INFO_DIR / "topics.json"
67
  CREATORS_JSON = JSON_INFO_DIR / "creators.json"
68
  TOPIC_NAMES_JSON = JSON_INFO_DIR / "topic_names.json"
69
+
70
+ # Add below existing EMBEDDINGS_DIR constants
71
+ CLIP_EMBEDDINGS_ST = EMBEDDINGS_DIR / "clip_embeddings.safetensors"
72
+ CLIP_SENTENCE_IDS = EMBEDDINGS_DIR / "clip_embeddings_sentence_ids.json"
73
+
74
+ PAINTINGCLIP_EMBEDDINGS_ST = EMBEDDINGS_DIR / "paintingclip_embeddings.safetensors"
75
+ PAINTINGCLIP_SENTENCE_IDS = EMBEDDINGS_DIR / "paintingclip_embeddings_sentence_ids.json"
backend/runner/inference.py CHANGED
@@ -25,6 +25,7 @@ import torch.nn.functional as F
25
  from peft import PeftModel
26
  from PIL import Image
27
  from transformers import CLIPModel, CLIPProcessor
 
28
 
29
  from .filtering import get_filtered_sentence_ids
30
  # on-demand Grad-ECLIP & region-aware ranking
@@ -34,7 +35,9 @@ from .config import (
34
  PAINTINGCLIP_EMBEDDINGS_DIR,
35
  PAINTINGCLIP_MODEL_DIR,
36
  SENTENCES_JSON,
37
- EMBEDDINGS_DIR # ← Add this line
 
 
38
  )
39
 
40
  # ─── Configuration ───────────────────────────────────────────────────────────
@@ -456,33 +459,59 @@ def load_consolidated_embeddings(embedding_file: Path, metadata_file: Path):
456
 
457
  return embeddings, filename_to_index
458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  # Update your embedding loading logic
460
  def load_embeddings_for_model(model_type: str):
461
- """Load embeddings for the specified model type"""
462
- if model_type == "clip":
463
- # Consolidated files are in the root embeddings directory
464
- embedding_file = EMBEDDINGS_DIR / "clip_embeddings_consolidated.pt"
465
- metadata_file = EMBEDDINGS_DIR / "clip_embeddings_metadata.json"
466
- else: # paintingclip
467
- # Consolidated files are in the root embeddings directory
468
- embedding_file = EMBEDDINGS_DIR / "paintingclip_embeddings_consolidated.pt"
469
- metadata_file = EMBEDDINGS_DIR / "paintingclip_embeddings_metadata.json"
470
-
471
- print(f"🔍 Looking for embeddings at: {embedding_file}")
472
- print(f"🔍 Looking for metadata at: {metadata_file}")
473
-
474
- if not embedding_file.exists():
475
- print(f"❌ Consolidated embedding file not found: {embedding_file}")
476
- print(f" Available files in embeddings directory:")
477
- for file in EMBEDDINGS_DIR.iterdir():
478
- print(f" - {file.name}")
479
- return None, None
480
-
481
- if not metadata_file.exists():
482
- print(f"❌ Metadata file not found: {metadata_file}")
483
- return None, None
484
-
485
- print(f"✅ Found embedding file: {embedding_file}")
486
- print(f"✅ Found metadata file: {metadata_file}")
487
-
488
- return load_consolidated_embeddings(embedding_file, metadata_file)
 
 
 
 
 
 
 
25
  from peft import PeftModel
26
  from PIL import Image
27
  from transformers import CLIPModel, CLIPProcessor
28
+ from safetensors.torch import load_file as st_load_file
29
 
30
  from .filtering import get_filtered_sentence_ids
31
  # on-demand Grad-ECLIP & region-aware ranking
 
35
  PAINTINGCLIP_EMBEDDINGS_DIR,
36
  PAINTINGCLIP_MODEL_DIR,
37
  SENTENCES_JSON,
38
+ EMBEDDINGS_DIR,
39
+ CLIP_EMBEDDINGS_ST, CLIP_SENTENCE_IDS,
40
+ PAINTINGCLIP_EMBEDDINGS_ST, PAINTINGCLIP_SENTENCE_IDS,
41
  )
42
 
43
  # ─── Configuration ───────────────────────────────────────────────────────────
 
459
 
460
  return embeddings, filename_to_index
461
 
462
+ def load_consolidated_embeddings_st(embedding_st_file: Path, ids_json_file: Path):
463
+ print(f"Loading safetensors embeddings from {embedding_st_file}")
464
+ if not embedding_st_file.exists():
465
+ raise FileNotFoundError(f"Missing {embedding_st_file}")
466
+ if not ids_json_file.exists():
467
+ raise FileNotFoundError(f"Missing {ids_json_file}")
468
+
469
+ data = st_load_file(str(embedding_st_file))
470
+ if "embeddings" not in data:
471
+ raise KeyError(f"'embeddings' tensor missing in {embedding_st_file}")
472
+ embeddings = data["embeddings"].to(dtype=torch.float32, device="cpu").contiguous()
473
+
474
+ with open(ids_json_file, "r", encoding="utf-8") as f:
475
+ sentence_ids = json.load(f)
476
+ if not isinstance(sentence_ids, list):
477
+ raise ValueError(f"IDs file malformed: {ids_json_file}")
478
+
479
+ print(f"Loaded {len(sentence_ids)} embeddings with dim {embeddings.shape[1]}")
480
+ return embeddings, sentence_ids
481
+
482
  # Update your embedding loading logic
483
  def load_embeddings_for_model(model_type: str):
484
+ """Load embeddings for the specified model type with safetensors-first strategy."""
485
+ if model_type == "clip":
486
+ st_file = CLIP_EMBEDDINGS_ST
487
+ ids_file = CLIP_SENTENCE_IDS
488
+ # Legacy PT locations for fallback (if repo still has them)
489
+ pt_file = EMBEDDINGS_DIR / "clip_embeddings_consolidated.pt"
490
+ meta_file = EMBEDDINGS_DIR / "clip_embeddings_metadata.json"
491
+ indiv_dir = CLIP_EMBEDDINGS_DIR
492
+ else:
493
+ st_file = PAINTINGCLIP_EMBEDDINGS_ST
494
+ ids_file = PAINTINGCLIP_SENTENCE_IDS
495
+ pt_file = EMBEDDINGS_DIR / "paintingclip_embeddings_consolidated.pt"
496
+ meta_file = EMBEDDINGS_DIR / "paintingclip_embeddings_metadata.json"
497
+ indiv_dir = PAINTINGCLIP_EMBEDDINGS_DIR
498
+
499
+ # 1) safetensors
500
+ if st_file.exists() and ids_file.exists():
501
+ try:
502
+ return load_consolidated_embeddings_st(st_file, ids_file)
503
+ except Exception as e:
504
+ print(f"⚠️ Safetensors load failed: {e}")
505
+
506
+ # 2) legacy PT (if present)
507
+ if pt_file.exists() and meta_file.exists():
508
+ try:
509
+ return load_consolidated_embeddings(pt_file, meta_file)
510
+ except Exception as e:
511
+ print(f"⚠️ Legacy PT load failed: {e}")
512
+
513
+ # 3) final fallback: refuse to scan 10k files here (HF Spaces file count limits)
514
+ print("❌ No valid consolidated embeddings found. Make sure you committed:")
515
+ print(f" - {st_file.name}")
516
+ print(f" - {ids_file.name}")
517
+ return None, None
consolidate_embeddings.py CHANGED
@@ -1,156 +1,81 @@
1
  #!/usr/bin/env python3
2
- """
3
- Consolidate individual embedding .pt files into larger consolidated files.
4
- This solves the Hugging Face Spaces 10,000 files per directory limit.
5
- """
6
-
7
- import torch
8
- import os
9
  import json
 
10
  from pathlib import Path
11
- from typing import Dict, List, Tuple
12
- import argparse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- def consolidate_embeddings(
15
- input_dir: Path,
16
- output_file: Path,
17
- metadata_file: Path,
18
- batch_size: int = 1000
19
- ) -> Dict[str, int]:
20
- """
21
- Merge individual .pt files into one large tensor file with metadata.
22
-
23
- Args:
24
- input_dir: Directory containing individual .pt files
25
- output_file: Path to save consolidated tensor
26
- metadata_file: Path to save file mapping metadata
27
- batch_size: Process files in batches to manage memory
28
-
29
- Returns:
30
- Dict with statistics about the consolidation
31
- """
32
- embedding_files = sorted(list(input_dir.glob("*.pt")))
33
-
34
- if not embedding_files:
35
- raise ValueError(f"No .pt files found in {input_dir}")
36
-
37
- print(f"Found {len(embedding_files)} embedding files in {input_dir}")
38
-
39
- # Load first file to get embedding dimension
40
- print("Loading first embedding to determine dimensions...")
41
- first_embedding = torch.load(embedding_files[0])
42
- embedding_dim = first_embedding.shape[0]
43
- print(f"Embedding dimension: {embedding_dim}")
44
-
45
- # Pre-allocate tensor
46
- all_embeddings = torch.zeros(len(embedding_files), embedding_dim, dtype=first_embedding.dtype)
47
- file_mapping = []
48
-
49
- print(f"Consolidating {len(embedding_files)} embeddings...")
50
-
51
- for i, file_path in enumerate(embedding_files):
52
- if i % 1000 == 0:
53
- print(f"Processing {i}/{len(embedding_files)} ({i/len(embedding_files)*100:.1f}%)")
54
-
55
- try:
56
- embedding = torch.load(file_path)
57
- all_embeddings[i] = embedding
58
-
59
- # Store file mapping for later lookup
60
- file_mapping.append({
61
- 'index': i,
62
- 'filename': file_path.name,
63
- 'stem': file_path.stem,
64
- 'file_size': file_path.stat().st_size
65
- })
66
-
67
- except Exception as e:
68
- print(f"Error loading {file_path}: {e}")
69
- # Fill with zeros if file is corrupted
70
- all_embeddings[i] = torch.zeros(embedding_dim, dtype=first_embedding.dtype)
71
-
72
- # Save consolidated data
73
- print(f"Saving consolidated embeddings to {output_file}...")
74
- consolidated_data = {
75
- 'embeddings': all_embeddings,
76
- 'embedding_dim': embedding_dim,
77
- 'num_embeddings': len(embedding_files),
78
- 'dtype': str(first_embedding.dtype)
79
- }
80
-
81
- torch.save(consolidated_data, output_file)
82
-
83
- # Save metadata for lookup
84
- print(f"Saving metadata to {metadata_file}...")
85
- metadata = {
86
- 'input_directory': str(input_dir),
87
- 'output_file': str(output_file),
88
- 'num_embeddings': len(embedding_files),
89
- 'embedding_dim': embedding_dim,
90
- 'dtype': str(first_embedding.dtype),
91
- 'file_mapping': file_mapping
92
- }
93
-
94
- with open(metadata_file, 'w', encoding='utf-8') as f:
95
- json.dump(metadata, f, indent=2, ensure_ascii=False)
96
-
97
- # Calculate file sizes
98
- original_size = sum(f.stat().st_size for f in embedding_files)
99
- consolidated_size = output_file.stat().st_size
100
- metadata_size = metadata_file.stat().st_size
101
-
102
- stats = {
103
- 'num_files_processed': len(embedding_files),
104
- 'original_size_mb': original_size / (1024 * 1024),
105
- 'consolidated_size_mb': consolidated_size / (1024 * 1024),
106
- 'metadata_size_kb': metadata_size / 1024,
107
- 'compression_ratio': original_size / consolidated_size if consolidated_size > 0 else 0
108
- }
109
-
110
- print(f"\nConsolidation complete!")
111
- print(f"Files processed: {stats['num_files_processed']}")
112
- print(f"Original size: {stats['original_size_mb']:.1f} MB")
113
- print(f"Consolidated size: {stats['consolidated_size_mb']:.1f} MB")
114
- print(f"Metadata size: {stats['metadata_size_kb']:.1f} KB")
115
- print(f"Compression ratio: {stats['compression_ratio']:.2f}x")
116
-
117
- return stats
118
 
119
  def main():
120
- parser = argparse.ArgumentParser(description='Consolidate embedding files')
121
- parser.add_argument('--input-dir', type=str, required=True,
122
- help='Input directory containing .pt files')
123
- parser.add_argument('--output-file', type=str, required=True,
124
- help='Output consolidated .pt file')
125
- parser.add_argument('--metadata-file', type=str, required=True,
126
- help='Output metadata JSON file')
127
- parser.add_argument('--batch-size', type=int, default=1000,
128
- help='Batch size for processing (default: 1000)')
129
-
130
- args = parser.parse_args()
131
-
132
- input_dir = Path(args.input_dir)
133
- output_file = Path(args.output_file)
134
- metadata_file = Path(args.metadata_file)
135
-
136
- if not input_dir.exists():
137
- print(f"Error: Input directory {input_dir} does not exist")
138
- return 1
139
-
140
- # Create output directory if it doesn't exist
141
- output_file.parent.mkdir(parents=True, exist_ok=True)
142
-
143
- try:
144
- stats = consolidate_embeddings(
145
- input_dir=input_dir,
146
- output_file=output_file,
147
- metadata_file=metadata_file,
148
- batch_size=args.batch_size
149
- )
150
- return 0
151
- except Exception as e:
152
- print(f"Error during consolidation: {e}")
153
- return 1
154
 
155
  if __name__ == "__main__":
156
- exit(main())
 
1
  #!/usr/bin/env python3
 
 
 
 
 
 
 
2
  import json
3
+ import sys
4
  from pathlib import Path
5
+ from typing import List, Tuple
6
+
7
+ import torch
8
+ from safetensors.torch import save_file
9
+
10
+ ROOT = Path(__file__).resolve().parent
11
+ DATA_DIR = ROOT / "data" / "embeddings"
12
+ CLIP_DIR = DATA_DIR / "CLIP_Embeddings"
13
+ PAINTINGCLIP_DIR = DATA_DIR / "PaintingCLIP_Embeddings"
14
+
15
+ def load_one(pt_path: Path) -> torch.Tensor:
16
+ """Load a single .pt embedding, handling dict-or-tensor variants."""
17
+ obj = torch.load(pt_path, map_location="cpu", weights_only=True)
18
+ if isinstance(obj, torch.Tensor):
19
+ return obj
20
+ if isinstance(obj, dict):
21
+ for k in ("embedding", "embeddings", "features"):
22
+ if k in obj:
23
+ t = obj[k]
24
+ if isinstance(t, torch.Tensor):
25
+ return t
26
+ raise ValueError(f"Unsupported .pt content in {pt_path}")
27
 
28
+ def derive_id_from_filename(stem: str) -> str:
29
+ """
30
+ - CLIP: Wxxxx_sYYYY_clip → Wxxxx_sYYYY
31
+ - PaintingCLIP: Wxxxx_sYYYY_painting_clip → Wxxxx_sYYYY
32
+ """
33
+ if stem.endswith("_painting_clip"):
34
+ return stem[: -len("_painting_clip")]
35
+ if stem.endswith("_clip"):
36
+ return stem[: -len("_clip")]
37
+ return stem # fallback
38
+
39
+ def consolidate_dir(indir: Path) -> Tuple[torch.Tensor, List[str]]:
40
+ pt_files = sorted(indir.glob("*.pt"))
41
+ if not pt_files:
42
+ raise RuntimeError(f"No .pt files found under {indir}")
43
+
44
+ embs: List[torch.Tensor] = []
45
+ ids: List[str] = []
46
+
47
+ for i, p in enumerate(pt_files, 1):
48
+ e = load_one(p).float()
49
+ if e.ndim > 1:
50
+ e = e.squeeze()
51
+ if e.ndim != 1:
52
+ raise ValueError(f"Embedding is not 1D in {p}: shape={tuple(e.shape)}")
53
+ embs.append(e)
54
+ ids.append(derive_id_from_filename(p.stem))
55
+ if i % 1000 == 0:
56
+ print(f"... processed {i} files from {indir}")
57
+
58
+ # Stack to [N, D]
59
+ embeddings = torch.stack(embs, dim=0).contiguous()
60
+ return embeddings, ids
61
+
62
+ def save_as_safetensors(embeddings: torch.Tensor, ids: List[str], out_prefix: Path) -> None:
63
+ out_st = out_prefix.with_suffix(".safetensors")
64
+ out_json = out_prefix.with_name(out_prefix.name + "_sentence_ids.json")
65
+ save_file({"embeddings": embeddings}, str(out_st))
66
+ with open(out_json, "w", encoding="utf-8") as f:
67
+ json.dump(ids, f, ensure_ascii=False, indent=2)
68
+ print(f"Saved embeddings: {out_st} [{tuple(embeddings.shape)}]")
69
+ print(f"Saved sentence IDs: {out_json} [{len(ids)} ids]")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def main():
72
+ print("Consolidating CLIP...")
73
+ clip_emb, clip_ids = consolidate_dir(CLIP_DIR)
74
+ save_as_safetensors(clip_emb, clip_ids, DATA_DIR / "clip_embeddings")
75
+
76
+ print("Consolidating PaintingCLIP...")
77
+ pclip_emb, pclip_ids = consolidate_dir(PAINTINGCLIP_DIR)
78
+ save_as_safetensors(pclip_emb, pclip_ids, DATA_DIR / "paintingclip_embeddings")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  if __name__ == "__main__":
81
+ main()
data/embeddings/clip_embeddings.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0f8443adf8f749c0dc80339a40817dbcb4b0d23eca505ca106096b8af9b89b7
3
+ size 30052440
data/embeddings/clip_embeddings_sentence_ids.json ADDED
The diff for this file is too large to render. See raw diff
 
data/embeddings/paintingclip_embeddings.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38c1dc2984813f5ea242b83e27b32ff19f827e517d985b03d30604ea775bd97a
3
+ size 30052440
data/embeddings/paintingclip_embeddings_sentence_ids.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -19,3 +19,5 @@ numpy>=1.24.0
19
 
20
  # Optional: GPU acceleration (if available)
21
  # torchvision>=0.15.0 # Uncomment if you need additional vision models
 
 
 
19
 
20
  # Optional: GPU acceleration (if available)
21
  # torchvision>=0.15.0 # Uncomment if you need additional vision models
22
+
23
+ safetensors>=0.4.0