samwaugh commited on
Commit
ae940aa
·
1 Parent(s): 5b7bc70

Download instead of stream

Browse files
backend/runner/config.py CHANGED
@@ -155,24 +155,18 @@ def load_json_datasets() -> Optional[Dict[str, Any]]:
155
  return None
156
 
157
  def load_embeddings_datasets() -> Optional[Dict[str, Any]]:
158
- """Load embeddings datasets from Hugging Face using streaming"""
159
- if not DATASETS_AVAILABLE:
160
- print("⚠️ datasets library not available - skipping HF embeddings loading")
161
  return None
162
 
163
  try:
164
- print(f" Loading embeddings using streaming from {ARTEFACT_EMBEDDINGS_DATASET}...")
165
-
166
- # Use streaming to avoid downloading large files
167
- dataset = load_dataset(ARTEFACT_EMBEDDINGS_DATASET, split='train', streaming=True)
168
-
169
- print(f"✅ Successfully loaded streaming dataset")
170
- print(f" Dataset type: {type(dataset)}")
171
 
172
- # Return the streaming dataset for on-demand processing
 
173
  return {
174
- 'streaming_dataset': dataset,
175
- 'use_streaming': True,
176
  'repo_id': ARTEFACT_EMBEDDINGS_DATASET
177
  }
178
  except Exception as e:
 
155
  return None
156
 
157
  def load_embeddings_datasets() -> Optional[Dict[str, Any]]:
158
+ """Load embeddings datasets from Hugging Face using direct file download"""
159
+ if not HF_HUB_AVAILABLE:
160
+ print("⚠️ huggingface_hub library not available - skipping HF embeddings loading")
161
  return None
162
 
163
  try:
164
+ print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
 
 
 
 
 
 
165
 
166
+ # Return a flag indicating we should use direct file download
167
+ # The actual loading will be done in inference.py
168
  return {
169
+ 'use_direct_download': True,
 
170
  'repo_id': ARTEFACT_EMBEDDINGS_DATASET
171
  }
172
  except Exception as e:
backend/runner/inference.py CHANGED
@@ -72,60 +72,72 @@ def load_embeddings_from_hf():
72
  try:
73
  print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
74
 
75
- # Download the safetensors files
76
- from huggingface_hub import hf_hub_download
77
- import safetensors
78
-
79
- # Download CLIP embeddings
80
- print("🔍 Downloading CLIP embeddings...")
81
- clip_embeddings_path = hf_hub_download(
82
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
83
- filename="clip_embeddings.safetensors",
84
- repo_type="dataset"
85
- )
86
-
87
- clip_ids_path = hf_hub_download(
88
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
89
- filename="clip_embeddings_sentence_ids.json",
90
- repo_type="dataset"
91
- )
92
-
93
- # Download PaintingCLIP embeddings
94
- print("🔍 Downloading PaintingCLIP embeddings...")
95
- paintingclip_embeddings_path = hf_hub_download(
96
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
97
- filename="paintingclip_embeddings.safetensors",
98
- repo_type="dataset"
99
- )
100
-
101
- paintingclip_ids_path = hf_hub_download(
102
- repo_id=ARTEFACT_EMBEDDINGS_DATASET,
103
- filename="paintingclip_embeddings_sentence_ids.json",
104
- repo_type="dataset"
105
- )
106
-
107
- # Load the embeddings
108
- print("🔍 Loading CLIP embeddings...")
109
- clip_embeddings = safetensors.torch.load_file(clip_embeddings_path)['embeddings']
110
-
111
- print("🔍 Loading PaintingCLIP embeddings...")
112
- paintingclip_embeddings = safetensors.torch.load_file(paintingclip_embeddings_path)['embeddings']
113
-
114
- # Load the sentence IDs
115
- with open(clip_ids_path, 'r') as f:
116
- clip_sentence_ids = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- with open(paintingclip_ids_path, 'r') as f:
119
- paintingclip_sentence_ids = json.load(f)
120
-
121
- print(f"✅ Loaded CLIP embeddings: {clip_embeddings.shape}")
122
- print(f"✅ Loaded PaintingCLIP embeddings: {paintingclip_embeddings.shape}")
123
-
124
- return {
125
- "clip": (clip_embeddings, clip_sentence_ids),
126
- "paintingclip": (paintingclip_embeddings, paintingclip_sentence_ids)
127
- }
128
-
129
  except Exception as e:
130
  print(f"❌ Failed to load embeddings from HF: {e}")
131
  return None
@@ -203,15 +215,12 @@ def _initialize_pipeline():
203
  if embeddings_data is None:
204
  raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")
205
 
206
- # Check if we're using streaming
207
  if embeddings_data.get("streaming", False):
208
  print("✅ Using streaming embeddings - will load on-demand")
209
- # For streaming, we need to handle this differently
210
- # We'll return the components but mark embeddings as streaming
211
- # The calling code will need to handle this case
212
  return processor, model, "STREAMING", "STREAMING", "STREAMING", device
213
  else:
214
- # New code path for safetensors files
215
  if MODEL_TYPE == "clip":
216
  embeddings, sentence_ids = embeddings_data["clip"]
217
  else:
 
72
  try:
73
  print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
74
 
75
+ if not EMBEDDINGS_DATASETS:
76
+ print("❌ No embeddings datasets loaded")
77
+ return None
78
+
79
+ # Check if we're using direct download
80
+ if EMBEDDINGS_DATASETS.get('use_direct_download', False):
81
+ print("✅ Using direct file download for embeddings")
82
+
83
+ # Download the safetensors files
84
+ from huggingface_hub import hf_hub_download
85
+ import safetensors
86
+
87
+ # Download CLIP embeddings
88
+ print("🔍 Downloading CLIP embeddings...")
89
+ clip_embeddings_path = hf_hub_download(
90
+ repo_id=ARTEFACT_EMBEDDINGS_DATASET,
91
+ filename="clip_embeddings.safetensors",
92
+ repo_type="dataset"
93
+ )
94
+
95
+ clip_ids_path = hf_hub_download(
96
+ repo_id=ARTEFACT_EMBEDDINGS_DATASET,
97
+ filename="clip_embeddings_sentence_ids.json",
98
+ repo_type="dataset"
99
+ )
100
+
101
+ # Download PaintingCLIP embeddings
102
+ print("🔍 Downloading PaintingCLIP embeddings...")
103
+ paintingclip_embeddings_path = hf_hub_download(
104
+ repo_id=ARTEFACT_EMBEDDINGS_DATASET,
105
+ filename="paintingclip_embeddings.safetensors",
106
+ repo_type="dataset"
107
+ )
108
+
109
+ paintingclip_ids_path = hf_hub_download(
110
+ repo_id=ARTEFACT_EMBEDDINGS_DATASET,
111
+ filename="paintingclip_embeddings_sentence_ids.json",
112
+ repo_type="dataset"
113
+ )
114
+
115
+ # Load the embeddings
116
+ print("🔍 Loading CLIP embeddings...")
117
+ clip_embeddings = safetensors.torch.load_file(clip_embeddings_path)['embeddings']
118
+
119
+ print("🔍 Loading PaintingCLIP embeddings...")
120
+ paintingclip_embeddings = safetensors.torch.load_file(paintingclip_embeddings_path)['embeddings']
121
+
122
+ # Load the sentence IDs
123
+ with open(clip_ids_path, 'r') as f:
124
+ clip_sentence_ids = json.load(f)
125
+
126
+ with open(paintingclip_ids_path, 'r') as f:
127
+ paintingclip_sentence_ids = json.load(f)
128
+
129
+ print(f"✅ Loaded CLIP embeddings: {clip_embeddings.shape}")
130
+ print(f"✅ Loaded PaintingCLIP embeddings: {paintingclip_embeddings.shape}")
131
+
132
+ return {
133
+ "clip": (clip_embeddings, clip_sentence_ids),
134
+ "paintingclip": (paintingclip_embeddings, paintingclip_sentence_ids)
135
+ }
136
+ else:
137
+ # Fallback to old method if not using direct download
138
+ print("⚠️ Using fallback embedding loading method")
139
+ return None
140
 
 
 
 
 
 
 
 
 
 
 
 
141
  except Exception as e:
142
  print(f"❌ Failed to load embeddings from HF: {e}")
143
  return None
 
215
  if embeddings_data is None:
216
  raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")
217
 
218
+ # Check if we're using streaming (old approach)
219
  if embeddings_data.get("streaming", False):
220
  print("✅ Using streaming embeddings - will load on-demand")
 
 
 
221
  return processor, model, "STREAMING", "STREAMING", "STREAMING", device
222
  else:
223
+ # New code path for direct file download
224
  if MODEL_TYPE == "clip":
225
  embeddings, sentence_ids = embeddings_data["clip"]
226
  else: