samwaugh committed on
Commit
267162d
·
1 Parent(s): a02b702
backend/runner/config.py CHANGED
@@ -6,7 +6,7 @@ All runner modules should import from this module instead of defining their own
6
  import os
7
  import json
8
  from pathlib import Path
9
- from typing import Any, Dict, Optional
10
 
11
  # Try to import required libraries
12
  try:
@@ -155,47 +155,25 @@ def load_json_datasets() -> Optional[Dict[str, Any]]:
155
  return None
156
 
157
  def load_embeddings_datasets() -> Optional[Dict[str, Any]]:
158
- """Load embeddings datasets from Hugging Face"""
159
- if not HF_HUB_AVAILABLE:
160
- print("⚠️ huggingface_hub library not available - skipping HF embeddings loading")
161
  return None
162
 
163
  try:
164
- print(f" Loading embeddings files from {ARTEFACT_EMBEDDINGS_DATASET}...")
165
 
166
- # Download the files to local paths
167
- clip_embeddings_path = hf_hub_download(
168
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
169
- filename='clip_embeddings.safetensors',
170
- repo_type="dataset"
171
- )
172
- paintingclip_embeddings_path = hf_hub_download(
173
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
174
- filename='paintingclip_embeddings.safetensors',
175
- repo_type="dataset"
176
- )
177
- clip_sentence_ids_path = hf_hub_download(
178
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
179
- filename='clip_embeddings_sentence_ids.json',
180
- repo_type="dataset"
181
- )
182
- paintingclip_sentence_ids_path = hf_hub_download(
183
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
184
- filename='paintingclip_embeddings_sentence_ids.json',
185
- repo_type="dataset"
186
- )
187
 
188
- print(f"βœ… Successfully downloaded embeddings files:")
189
- print(f" CLIP embeddings: {clip_embeddings_path}")
190
- print(f" PaintingCLIP embeddings: {paintingclip_embeddings_path}")
191
- print(f" CLIP sentence IDs: {clip_sentence_ids_path}")
192
- print(f" PaintingCLIP sentence IDs: {paintingclip_sentence_ids_path}")
193
 
 
194
  return {
195
- 'clip_embeddings_path': clip_embeddings_path,
196
- 'paintingclip_embeddings_path': paintingclip_embeddings_path,
197
- 'clip_sentence_ids_path': clip_sentence_ids_path,
198
- 'paintingclip_sentence_ids_path': paintingclip_sentence_ids_path
199
  }
200
  except Exception as e:
201
  print(f"❌ Failed to load embeddings datasets from HF: {e}")
 
6
  import os
7
  import json
8
  from pathlib import Path
9
+ from typing import Any, Dict, Optional, List, Tuple
10
 
11
  # Try to import required libraries
12
  try:
 
155
  return None
156
 
157
  def load_embeddings_datasets() -> Optional[Dict[str, Any]]:
158
+ """Load embeddings datasets from Hugging Face using streaming"""
159
+ if not DATASETS_AVAILABLE:
160
+ print("⚠️ datasets library not available - skipping HF embeddings loading")
161
  return None
162
 
163
  try:
164
+ print(f" Loading embeddings using streaming from {ARTEFACT_EMBEDDINGS_DATASET}...")
165
 
166
+ # Use streaming to avoid downloading large files
167
+ dataset = load_dataset(ARTEFACT_EMBEDDINGS_DATASET, split='train', streaming=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ print(f"βœ… Successfully loaded streaming dataset")
170
+ print(f" Dataset type: {type(dataset)}")
 
 
 
171
 
172
+ # Return the streaming dataset for on-demand processing
173
  return {
174
+ 'streaming_dataset': dataset,
175
+ 'use_streaming': True,
176
+ 'repo_id': ARTEFACT_EMBEDDINGS_DATASET
 
177
  }
178
  except Exception as e:
179
  print(f"❌ Failed to load embeddings datasets from HF: {e}")
backend/runner/inference.py CHANGED
@@ -68,34 +68,27 @@ TOP_K = 25 # Number of results to return
68
  # ─────────────────────────────────────────────────────────────────────────────
69
 
70
  def load_embeddings_from_hf():
71
- """Load embeddings from HF dataset"""
72
  try:
73
- print(f"πŸ” Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
74
 
75
- # Get the downloaded file paths from config
76
  if not EMBEDDINGS_DATASETS:
77
  print("❌ No embeddings datasets loaded")
78
  return None
79
 
80
- # Load sentence IDs
81
- with open(EMBEDDINGS_DATASETS['clip_sentence_ids_path'], 'r') as f:
82
- clip_sentence_ids = json.load(f)
83
- with open(EMBEDDINGS_DATASETS['paintingclip_sentence_ids_path'], 'r') as f:
84
- paintingclip_sentence_ids = json.load(f)
85
-
86
- # Load embeddings using safetensors
87
- import safetensors
88
- clip_embeddings = safetensors.safe_open(EMBEDDINGS_DATASETS['clip_embeddings_path'], framework="pt")
89
- paintingclip_embeddings = safetensors.safe_open(EMBEDDINGS_DATASETS['paintingclip_embeddings_path'], framework="pt")
90
-
91
- print(f"βœ… Successfully loaded embeddings from HF:")
92
- print(f" CLIP: {len(clip_sentence_ids)} embeddings")
93
- print(f" PaintingCLIP: {len(paintingclip_sentence_ids)} embeddings")
94
-
95
- return {
96
- "clip": (clip_embeddings, clip_sentence_ids),
97
- "paintingclip": (paintingclip_embeddings, paintingclip_sentence_ids)
98
- }
99
  except Exception as e:
100
  print(f"❌ Failed to load embeddings from HF: {e}")
101
  return None
@@ -173,15 +166,26 @@ def _initialize_pipeline():
173
  if embeddings_data is None:
174
  raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")
175
 
176
- if MODEL_TYPE == "clip":
177
- embeddings, sentence_ids = embeddings_data["clip"]
 
 
 
 
 
 
 
178
  else:
179
- embeddings, sentence_ids = embeddings_data["paintingclip"]
180
-
181
- if embeddings is None or sentence_ids is None:
182
- raise ValueError(f"Failed to load embeddings for model type: {MODEL_TYPE}")
183
-
184
- print(f"πŸ” Loaded {len(sentence_ids)} embeddings with shape {embeddings.shape}")
 
 
 
 
185
  except Exception as e:
186
  print(f"❌ Error loading embeddings: {e}")
187
  raise
@@ -521,3 +525,97 @@ def st_load_file(file_path: Path) -> Any:
521
  except Exception as e:
522
  print(f"❌ Error loading {file_path}: {e}")
523
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  # ─────────────────────────────────────────────────────────────────────────────
69
 
70
  def load_embeddings_from_hf():
71
+ """Load embeddings from HF dataset using streaming"""
72
  try:
73
+ print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
74
 
 
75
  if not EMBEDDINGS_DATASETS:
76
  print("❌ No embeddings datasets loaded")
77
  return None
78
 
79
+ # Check if we're using streaming
80
+ if EMBEDDINGS_DATASETS.get('use_streaming', False):
81
+ print("βœ… Using streaming embeddings dataset")
82
+ return {
83
+ "streaming": True,
84
+ "dataset": EMBEDDINGS_DATASETS['streaming_dataset'],
85
+ "repo_id": EMBEDDINGS_DATASETS['repo_id']
86
+ }
87
+ else:
88
+ # Fallback to old method if not streaming
89
+ print("⚠️ Using fallback embedding loading method")
90
+ return None
91
+
 
 
 
 
 
 
92
  except Exception as e:
93
  print(f"❌ Failed to load embeddings from HF: {e}")
94
  return None
 
166
  if embeddings_data is None:
167
  raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")
168
 
169
+ # Check if we're using streaming
170
+ if embeddings_data.get("streaming", False):
171
+ print("βœ… Using streaming embeddings - will load on-demand")
172
+ # For streaming, we'll load embeddings as needed during inference
173
+ return {
174
+ "streaming": True,
175
+ "dataset": embeddings_data["dataset"],
176
+ "repo_id": embeddings_data["repo_id"]
177
+ }
178
  else:
179
+ # Old code path for non-streaming
180
+ if MODEL_TYPE == "clip":
181
+ embeddings, sentence_ids = embeddings_data["clip"]
182
+ else:
183
+ embeddings, sentence_ids = embeddings_data["paintingclip"]
184
+
185
+ if embeddings is None or sentence_ids is None:
186
+ raise ValueError(f"Failed to load embeddings for model type: {MODEL_TYPE}")
187
+
188
+ print(f"πŸ” Loaded {len(sentence_ids)} embeddings with shape {embeddings.shape}")
189
  except Exception as e:
190
  print(f"❌ Error loading embeddings: {e}")
191
  raise
 
525
  except Exception as e:
526
  print(f"❌ Error loading {file_path}: {e}")
527
  return None
528
+
529
+ def load_embedding_for_sentence(sentence_id: str, model_type: str = "clip") -> Optional[torch.Tensor]:
530
+ """Load a single embedding for a specific sentence using streaming"""
531
+ try:
532
+ if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
533
+ print("❌ Streaming embeddings not available")
534
+ return None
535
+
536
+ dataset = EMBEDDINGS_DATASETS['streaming_dataset']
537
+
538
+ # Search for the sentence in the streaming dataset
539
+ for item in dataset:
540
+ if item.get('sentence_id') == sentence_id:
541
+ # Extract the appropriate embedding based on model type
542
+ if model_type == "clip" and 'clip_embedding' in item:
543
+ return torch.tensor(item['clip_embedding'])
544
+ elif model_type == "paintingclip" and 'paintingclip_embedding' in item:
545
+ return torch.tensor(item['paintingclip_embedding'])
546
+ else:
547
+ print(f"⚠️ Embedding not found for {model_type} in sentence {sentence_id}")
548
+ return None
549
+
550
+ print(f"⚠️ Sentence {sentence_id} not found in streaming dataset")
551
+ return None
552
+
553
+ except Exception as e:
554
+ print(f"❌ Error loading streaming embedding for {sentence_id}: {e}")
555
+ return None
556
+
557
+ def get_top_k_embeddings(query_embedding: torch.Tensor, k: int = 10, model_type: str = "clip") -> List[Tuple[str, float]]:
558
+ """Get top-k most similar embeddings using streaming"""
559
+ try:
560
+ if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
561
+ print("❌ Streaming embeddings not available")
562
+ return []
563
+
564
+ dataset = EMBEDDINGS_DATASETS['streaming_dataset']
565
+ similarities = []
566
+
567
+ # Process embeddings in batches to avoid memory issues
568
+ batch_size = 1000
569
+ batch = []
570
+
571
+ for item in dataset:
572
+ batch.append(item)
573
+
574
+ if len(batch) >= batch_size:
575
+ # Process batch
576
+ batch_similarities = process_embedding_batch(batch, query_embedding, model_type)
577
+ similarities.extend(batch_similarities)
578
+ batch = []
579
+
580
+ # Keep only top-k so far
581
+ similarities.sort(key=lambda x: x[1], reverse=True)
582
+ similarities = similarities[:k]
583
+
584
+ # Process remaining items
585
+ if batch:
586
+ batch_similarities = process_embedding_batch(batch, query_embedding, model_type)
587
+ similarities.extend(batch_similarities)
588
+ similarities.sort(key=lambda x: x[1], reverse=True)
589
+ similarities = similarities[:k]
590
+
591
+ return similarities
592
+
593
+ except Exception as e:
594
+ print(f"❌ Error getting top-k embeddings: {e}")
595
+ return []
596
+
597
+ def process_embedding_batch(batch: List[Dict], query_embedding: torch.Tensor, model_type: str) -> List[Tuple[str, float]]:
598
+ """Process a batch of embeddings to find similarities"""
599
+ similarities = []
600
+
601
+ for item in batch:
602
+ try:
603
+ sentence_id = item.get('sentence_id', '')
604
+
605
+ # Get the appropriate embedding
606
+ if model_type == "clip" and 'clip_embedding' in item:
607
+ embedding = torch.tensor(item['clip_embedding'])
608
+ elif model_type == "paintingclip" and 'paintingclip_embedding' in item:
609
+ embedding = torch.tensor(item['paintingclip_embedding'])
610
+ else:
611
+ continue
612
+
613
+ # Calculate similarity
614
+ similarity = F.cosine_similarity(query_embedding.unsqueeze(0), embedding.unsqueeze(0), dim=1)
615
+ similarities.append((sentence_id, similarity.item()))
616
+
617
+ except Exception as e:
618
+ print(f"⚠️ Error processing item in batch: {e}")
619
+ continue
620
+
621
+ return similarities
requirements.txt CHANGED
@@ -21,4 +21,4 @@ numpy>=1.24.0
21
  # Optional: GPU acceleration (if available)
22
  # torchvision>=0.15.0 # Uncomment if you need additional vision models
23
 
24
- safetensors>=0.4.0
 
21
  # Optional: GPU acceleration (if available)
22
  # torchvision>=0.15.0 # Uncomment if you need additional vision models
23
 
24
+ safetensors>=0.4.0