samwaugh committed on
Commit
5b7bc70
·
1 Parent(s): 8e05ec6

New approach

Browse files
Files changed (1) hide show
  1. backend/runner/inference.py +56 -24
backend/runner/inference.py CHANGED
@@ -68,27 +68,64 @@ TOP_K = 25 # Number of results to return
68
  # ─────────────────────────────────────────────────────────────────────────────
69
 
70
  def load_embeddings_from_hf():
71
- """Load embeddings from HF dataset using streaming"""
72
  try:
73
  print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
74
 
75
- if not EMBEDDINGS_DATASETS:
76
- print("❌ No embeddings datasets loaded")
77
- return None
78
-
79
- # Check if we're using streaming
80
- if EMBEDDINGS_DATASETS.get('use_streaming', False):
81
- print("βœ… Using streaming embeddings dataset")
82
- return {
83
- "streaming": True,
84
- "dataset": EMBEDDINGS_DATASETS['streaming_dataset'],
85
- "repo_id": EMBEDDINGS_DATASETS['repo_id']
86
- }
87
- else:
88
- # Fallback to old method if not streaming
89
- print("⚠️ Using fallback embedding loading method")
90
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
 
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
  print(f"❌ Failed to load embeddings from HF: {e}")
94
  return None
@@ -174,7 +211,7 @@ def _initialize_pipeline():
174
  # The calling code will need to handle this case
175
  return processor, model, "STREAMING", "STREAMING", "STREAMING", device
176
  else:
177
- # Old code path for non-streaming
178
  if MODEL_TYPE == "clip":
179
  embeddings, sentence_ids = embeddings_data["clip"]
180
  else:
@@ -773,12 +810,7 @@ def process_embedding_batch_streaming(
773
  # Debug: show first few items to understand the data structure
774
  for i, item in enumerate(batch[:3]):
775
  print(f" Item {i}: keys = {list(item.keys())}")
776
- if 'clip_embedding' in item:
777
- print(f"πŸ” Item {i}: clip_embedding shape = {len(item['clip_embedding'])}")
778
- if 'paintingclip_embedding' in item:
779
- print(f" Item {i}: paintingclip_embedding shape = {len(item['paintingclip_embedding'])}")
780
- if 'sentence_id' in item:
781
- print(f" Item {i}: sentence_id = {item['sentence_id']}")
782
 
783
  for item in batch:
784
  try:
 
68
  # ─────────────────────────────────────────────────────────────────────────────
69
 
70
def load_embeddings_from_hf():
    """Load embeddings from HF dataset using safetensors files.

    Downloads the CLIP and PaintingCLIP embedding matrices (safetensors)
    and their companion sentence-id JSON files from the
    ARTEFACT_EMBEDDINGS_DATASET dataset repo, then loads them into torch
    tensors. Relies on a module-level `import json` — TODO confirm it is
    present at the top of this file.

    Returns:
        dict | None: {"clip": (embeddings_tensor, sentence_ids),
        "paintingclip": (embeddings_tensor, sentence_ids)} on success,
        or None if any download/load step fails.
    """
    try:
        print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")

        from huggingface_hub import hf_hub_download
        # Bug fix: `import safetensors` does NOT import the
        # `safetensors.torch` submodule, so the previous
        # `safetensors.torch.load_file(...)` raised AttributeError
        # (silently swallowed by the except below). Import it explicitly.
        from safetensors.torch import load_file

        def _fetch(filename):
            # One file from the embeddings dataset repo; returns a local path.
            return hf_hub_download(
                repo_id=ARTEFACT_EMBEDDINGS_DATASET,
                filename=filename,
                repo_type="dataset"
            )

        # Download CLIP embeddings
        print("πŸ” Downloading CLIP embeddings...")
        clip_embeddings_path = _fetch("clip_embeddings.safetensors")
        clip_ids_path = _fetch("clip_embeddings_sentence_ids.json")

        # Download PaintingCLIP embeddings
        print("πŸ” Downloading PaintingCLIP embeddings...")
        paintingclip_embeddings_path = _fetch("paintingclip_embeddings.safetensors")
        paintingclip_ids_path = _fetch("paintingclip_embeddings_sentence_ids.json")

        # Load the embedding tensors (each file stores one "embeddings" key)
        print("πŸ” Loading CLIP embeddings...")
        clip_embeddings = load_file(clip_embeddings_path)['embeddings']

        print("πŸ” Loading PaintingCLIP embeddings...")
        paintingclip_embeddings = load_file(paintingclip_embeddings_path)['embeddings']

        # Load the sentence IDs that index the rows of each matrix
        with open(clip_ids_path, 'r') as f:
            clip_sentence_ids = json.load(f)

        with open(paintingclip_ids_path, 'r') as f:
            paintingclip_sentence_ids = json.load(f)

        print(f"βœ… Loaded CLIP embeddings: {clip_embeddings.shape}")
        print(f"βœ… Loaded PaintingCLIP embeddings: {paintingclip_embeddings.shape}")

        return {
            "clip": (clip_embeddings, clip_sentence_ids),
            "paintingclip": (paintingclip_embeddings, paintingclip_sentence_ids)
        }

    except Exception as e:
        # Best-effort loader: any failure is reported and mapped to None
        # so callers can fall back to another code path.
        print(f"❌ Failed to load embeddings from HF: {e}")
        return None
 
211
  # The calling code will need to handle this case
212
  return processor, model, "STREAMING", "STREAMING", "STREAMING", device
213
  else:
214
+ # New code path for safetensors files
215
  if MODEL_TYPE == "clip":
216
  embeddings, sentence_ids = embeddings_data["clip"]
217
  else:
 
810
  # Debug: show first few items to understand the data structure
811
  for i, item in enumerate(batch[:3]):
812
  print(f" Item {i}: keys = {list(item.keys())}")
813
+ print(f" Item {i}: full item = {item}")
 
 
 
 
 
814
 
815
  for item in batch:
816
  try: