Try fix 5
Browse files
- backend/runner/config.py +13 -35
- backend/runner/inference.py +128 -30
- requirements.txt +1 -1
backend/runner/config.py
CHANGED
|
@@ -6,7 +6,7 @@ All runner modules should import from this module instead of defining their own
|
|
| 6 |
import os
|
| 7 |
import json
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Any, Dict, Optional
|
| 10 |
|
| 11 |
# Try to import required libraries
|
| 12 |
try:
|
|
@@ -155,47 +155,25 @@ def load_json_datasets() -> Optional[Dict[str, Any]]:
|
|
| 155 |
return None
|
| 156 |
|
| 157 |
def load_embeddings_datasets() -> Optional[Dict[str, Any]]:
|
| 158 |
-
"""Load embeddings datasets from Hugging Face"""
|
| 159 |
-
if not
|
| 160 |
-
print("β οΈ
|
| 161 |
return None
|
| 162 |
|
| 163 |
try:
|
| 164 |
-
print(f" Loading embeddings
|
| 165 |
|
| 166 |
-
#
|
| 167 |
-
|
| 168 |
-
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
|
| 169 |
-
filename='clip_embeddings.safetensors',
|
| 170 |
-
repo_type="dataset"
|
| 171 |
-
)
|
| 172 |
-
paintingclip_embeddings_path = hf_hub_download(
|
| 173 |
-
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
|
| 174 |
-
filename='paintingclip_embeddings.safetensors',
|
| 175 |
-
repo_type="dataset"
|
| 176 |
-
)
|
| 177 |
-
clip_sentence_ids_path = hf_hub_download(
|
| 178 |
-
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
|
| 179 |
-
filename='clip_embeddings_sentence_ids.json',
|
| 180 |
-
repo_type="dataset"
|
| 181 |
-
)
|
| 182 |
-
paintingclip_sentence_ids_path = hf_hub_download(
|
| 183 |
-
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
|
| 184 |
-
filename='paintingclip_embeddings_sentence_ids.json',
|
| 185 |
-
repo_type="dataset"
|
| 186 |
-
)
|
| 187 |
|
| 188 |
-
print(f"β
Successfully
|
| 189 |
-
print(f"
|
| 190 |
-
print(f" PaintingCLIP embeddings: {paintingclip_embeddings_path}")
|
| 191 |
-
print(f" CLIP sentence IDs: {clip_sentence_ids_path}")
|
| 192 |
-
print(f" PaintingCLIP sentence IDs: {paintingclip_sentence_ids_path}")
|
| 193 |
|
|
|
|
| 194 |
return {
|
| 195 |
-
'
|
| 196 |
-
'
|
| 197 |
-
'
|
| 198 |
-
'paintingclip_sentence_ids_path': paintingclip_sentence_ids_path
|
| 199 |
}
|
| 200 |
except Exception as e:
|
| 201 |
print(f"β Failed to load embeddings datasets from HF: {e}")
|
|
|
|
| 6 |
import os
|
| 7 |
import json
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional, List, Tuple
|
| 10 |
|
| 11 |
# Try to import required libraries
|
| 12 |
try:
|
|
|
|
| 155 |
return None
|
| 156 |
|
| 157 |
def load_embeddings_datasets() -> Optional[Dict[str, Any]]:
|
| 158 |
+
"""Load embeddings datasets from Hugging Face using streaming"""
|
| 159 |
+
if not DATASETS_AVAILABLE:
|
| 160 |
+
print("β οΈ datasets library not available - skipping HF embeddings loading")
|
| 161 |
return None
|
| 162 |
|
| 163 |
try:
|
| 164 |
+
print(f" Loading embeddings using streaming from {ARTEFACT_EMBEDDINGS_DATASET}...")
|
| 165 |
|
| 166 |
+
# Use streaming to avoid downloading large files
|
| 167 |
+
dataset = load_dataset(ARTEFACT_EMBEDDINGS_DATASET, split='train', streaming=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
print(f"β
Successfully loaded streaming dataset")
|
| 170 |
+
print(f" Dataset type: {type(dataset)}")
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
+
# Return the streaming dataset for on-demand processing
|
| 173 |
return {
|
| 174 |
+
'streaming_dataset': dataset,
|
| 175 |
+
'use_streaming': True,
|
| 176 |
+
'repo_id': ARTEFACT_EMBEDDINGS_DATASET
|
|
|
|
| 177 |
}
|
| 178 |
except Exception as e:
|
| 179 |
print(f"β Failed to load embeddings datasets from HF: {e}")
|
backend/runner/inference.py
CHANGED
|
@@ -68,34 +68,27 @@ TOP_K = 25 # Number of results to return
|
|
| 68 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 69 |
|
| 70 |
def load_embeddings_from_hf():
|
| 71 |
-
"""Load embeddings from HF dataset"""
|
| 72 |
try:
|
| 73 |
-
print(f"
|
| 74 |
|
| 75 |
-
# Get the downloaded file paths from config
|
| 76 |
if not EMBEDDINGS_DATASETS:
|
| 77 |
print("β No embeddings datasets loaded")
|
| 78 |
return None
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
print(f" PaintingCLIP: {len(paintingclip_sentence_ids)} embeddings")
|
| 94 |
-
|
| 95 |
-
return {
|
| 96 |
-
"clip": (clip_embeddings, clip_sentence_ids),
|
| 97 |
-
"paintingclip": (paintingclip_embeddings, paintingclip_sentence_ids)
|
| 98 |
-
}
|
| 99 |
except Exception as e:
|
| 100 |
print(f"β Failed to load embeddings from HF: {e}")
|
| 101 |
return None
|
|
@@ -173,15 +166,26 @@ def _initialize_pipeline():
|
|
| 173 |
if embeddings_data is None:
|
| 174 |
raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")
|
| 175 |
|
| 176 |
-
if
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
else:
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
except Exception as e:
|
| 186 |
print(f"β Error loading embeddings: {e}")
|
| 187 |
raise
|
|
@@ -521,3 +525,97 @@ def st_load_file(file_path: Path) -> Any:
|
|
| 521 |
except Exception as e:
|
| 522 |
print(f"β Error loading {file_path}: {e}")
|
| 523 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 69 |
|
| 70 |
def load_embeddings_from_hf():
    """Return a handle to the streaming embeddings dataset prepared by config.

    Reads the module-level ``EMBEDDINGS_DATASETS`` mapping and, when it is
    flagged with ``use_streaming``, repackages it for the pipeline.

    Returns:
        A dict with keys ``"streaming"`` (always ``True``), ``"dataset"``
        (the streaming dataset object) and ``"repo_id"``; or ``None`` when
        no dataset is loaded, when streaming is not enabled, or when any
        error occurs while reading the mapping.
    """
    try:
        print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")

        if not EMBEDDINGS_DATASETS:
            print("❌ No embeddings datasets loaded")
            return None

        if not EMBEDDINGS_DATASETS.get('use_streaming', False):
            # NOTE(review): the message announces a fallback, but no
            # non-streaming loader exists here - the function returns None.
            print("⚠️ Using fallback embedding loading method")
            return None

        print("✅ Using streaming embeddings dataset")
        return {
            "streaming": True,
            "dataset": EMBEDDINGS_DATASETS['streaming_dataset'],
            "repo_id": EMBEDDINGS_DATASETS['repo_id'],
        }

    except Exception as e:
        # Deliberate best-effort: any failure (e.g. a missing key) is
        # reported and signalled to the caller as None.
        print(f"❌ Failed to load embeddings from HF: {e}")
        return None
|
|
|
|
| 166 |
if embeddings_data is None:
|
| 167 |
raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")
|
| 168 |
|
| 169 |
+
# Check if we're using streaming
|
| 170 |
+
if embeddings_data.get("streaming", False):
|
| 171 |
+
print("β
Using streaming embeddings - will load on-demand")
|
| 172 |
+
# For streaming, we'll load embeddings as needed during inference
|
| 173 |
+
return {
|
| 174 |
+
"streaming": True,
|
| 175 |
+
"dataset": embeddings_data["dataset"],
|
| 176 |
+
"repo_id": embeddings_data["repo_id"]
|
| 177 |
+
}
|
| 178 |
else:
|
| 179 |
+
# Old code path for non-streaming
|
| 180 |
+
if MODEL_TYPE == "clip":
|
| 181 |
+
embeddings, sentence_ids = embeddings_data["clip"]
|
| 182 |
+
else:
|
| 183 |
+
embeddings, sentence_ids = embeddings_data["paintingclip"]
|
| 184 |
+
|
| 185 |
+
if embeddings is None or sentence_ids is None:
|
| 186 |
+
raise ValueError(f"Failed to load embeddings for model type: {MODEL_TYPE}")
|
| 187 |
+
|
| 188 |
+
print(f"π Loaded {len(sentence_ids)} embeddings with shape {embeddings.shape}")
|
| 189 |
except Exception as e:
|
| 190 |
print(f"β Error loading embeddings: {e}")
|
| 191 |
raise
|
|
|
|
| 525 |
except Exception as e:
|
| 526 |
print(f"β Error loading {file_path}: {e}")
|
| 527 |
return None
|
| 528 |
+
|
| 529 |
+
def load_embedding_for_sentence(sentence_id: str, model_type: str = "clip") -> Optional[torch.Tensor]:
    """Look up the embedding of one sentence in the streaming dataset.

    The streaming dataset is iterated from its start until a row whose
    ``sentence_id`` matches is found, so each call is a linear scan.

    Args:
        sentence_id: Value compared against each row's ``sentence_id`` field.
        model_type: ``"clip"`` reads the row's ``clip_embedding`` field,
            ``"paintingclip"`` reads ``paintingclip_embedding``.

    Returns:
        The embedding as a tensor, or ``None`` when streaming is not
        available, the sentence is not found, the matching row has no
        embedding for ``model_type``, or an error occurs.
    """
    try:
        if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
            print("❌ Streaming embeddings not available")
            return None

        # Row field holding the vector for each supported model type.
        embedding_key = {
            "clip": "clip_embedding",
            "paintingclip": "paintingclip_embedding",
        }.get(model_type)

        for item in EMBEDDINGS_DATASETS['streaming_dataset']:
            if item.get('sentence_id') != sentence_id:
                continue

            # First matching row decides the outcome; later rows are ignored.
            if embedding_key is None or embedding_key not in item:
                print(f"⚠️ Embedding not found for {model_type} in sentence {sentence_id}")
                return None
            return torch.tensor(item[embedding_key])

        print(f"⚠️ Sentence {sentence_id} not found in streaming dataset")
        return None

    except Exception as e:
        # Deliberate best-effort: report and signal failure as None.
        print(f"❌ Error loading streaming embedding for {sentence_id}: {e}")
        return None
|
| 556 |
+
|
| 557 |
+
def get_top_k_embeddings(query_embedding: torch.Tensor, k: int = 10, model_type: str = "clip") -> List[Tuple[str, float]]:
    """Return the ``k`` rows of the streaming dataset most similar to a query.

    The whole streaming dataset is consumed once per call, in chunks of
    1000 rows; each chunk is scored by ``process_embedding_batch`` and only
    the best ``k`` candidates seen so far are retained between chunks.

    Args:
        query_embedding: Query vector forwarded to ``process_embedding_batch``.
        k: Maximum number of results. Values ``<= 0`` yield an empty list
            without reading the dataset.
        model_type: Forwarded to ``process_embedding_batch`` to choose which
            embedding field of each row is scored.

    Returns:
        ``(sentence_id, similarity)`` pairs sorted by descending similarity;
        an empty list when streaming is not available or an error occurs.
    """
    # Local imports keep this block self-contained (stdlib only).
    import heapq
    from itertools import islice

    try:
        if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
            print("❌ Streaming embeddings not available")
            return []

        if k <= 0:
            # A non-positive k can never select anything; skip the scan.
            return []

        rows = iter(EMBEDDINGS_DATASETS['streaming_dataset'])
        batch_size = 1000  # rows scored per chunk, bounds peak memory
        best: List[Tuple[str, float]] = []

        while True:
            batch = list(islice(rows, batch_size))
            if not batch:
                break

            scored = process_embedding_batch(batch, query_embedding, model_type)
            # nlargest is equivalent to sorted(..., reverse=True)[:k], so
            # ordering (including ties) matches a full sort-and-slice.
            best = heapq.nlargest(k, best + scored, key=lambda pair: pair[1])

        return best

    except Exception as e:
        # Deliberate best-effort: report and fall back to "no results".
        print(f"❌ Error getting top-k embeddings: {e}")
        return []
|
| 596 |
+
|
| 597 |
+
def process_embedding_batch(batch: List[Dict], query_embedding: torch.Tensor, model_type: str) -> List[Tuple[str, float]]:
    """Score every usable row of ``batch`` against the query by cosine similarity.

    Args:
        batch: Rows (dict-like) that may carry ``sentence_id`` plus a
            ``clip_embedding`` and/or ``paintingclip_embedding`` field.
        query_embedding: Query vector; it is flattened, so both ``(dim,)``
            and ``(1, dim)`` inputs are accepted.
        model_type: ``"clip"`` or ``"paintingclip"``; selects which embedding
            field is read. Any other value yields an empty result.

    Returns:
        ``(sentence_id, similarity)`` pairs in the order the rows appear in
        ``batch``. Rows without the selected field are skipped silently;
        rows whose vector cannot be converted or whose length differs from
        the query are skipped with a warning.
    """
    embedding_key = {
        "clip": "clip_embedding",
        "paintingclip": "paintingclip_embedding",
    }.get(model_type)
    if embedding_key is None or not batch:
        return []

    # Flatten and normalise the query once so every row is compared on the
    # same device and in float32.
    query = query_embedding.detach().reshape(-1).to(torch.float32)

    sentence_ids: List[str] = []
    vectors: List[torch.Tensor] = []
    for item in batch:
        try:
            if embedding_key not in item:
                continue
            vector = torch.as_tensor(
                item[embedding_key], dtype=torch.float32, device=query.device
            ).reshape(-1)
            if vector.shape != query.shape:
                raise ValueError(
                    f"embedding has {vector.numel()} values, query has {query.numel()}"
                )
            sentence_id = item.get('sentence_id', '')
        except (TypeError, ValueError, RuntimeError, AttributeError) as e:
            # One malformed row must not discard the rest of the batch.
            print(f"⚠️ Error processing item in batch: {e}")
            continue

        sentence_ids.append(sentence_id)
        vectors.append(vector)

    if not vectors:
        return []

    # One batched call: (n, dim) against (1, dim) -> n similarities.
    scores = F.cosine_similarity(torch.stack(vectors), query.unsqueeze(0), dim=1)
    return list(zip(sentence_ids, scores.tolist()))
|
requirements.txt
CHANGED
|
@@ -21,4 +21,4 @@ numpy>=1.24.0
|
|
| 21 |
# Optional: GPU acceleration (if available)
|
| 22 |
# torchvision>=0.15.0 # Uncomment if you need additional vision models
|
| 23 |
|
| 24 |
-
safetensors>=0.4.0
|
|
|
|
| 21 |
# Optional: GPU acceleration (if available)
|
| 22 |
# torchvision>=0.15.0 # Uncomment if you need additional vision models
|
| 23 |
|
| 24 |
+
safetensors>=0.4.0
|