samwaugh commited on
Commit
8a51c8c
·
1 Parent(s): 65310de

Fix bad inference

Browse files
backend/runner/__pycache__/inference.cpython-313.pyc ADDED
Binary file (36.7 kB). View file
 
backend/runner/inference.py CHANGED
@@ -72,12 +72,14 @@ def load_embeddings_from_hf():
72
  try:
73
  print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
74
 
75
- if not EMBEDDINGS_DATASETS:
 
 
76
  print("❌ No embeddings datasets loaded")
77
  return None
78
 
79
  # Check if we're using direct download
80
- if EMBEDDINGS_DATASETS.get('use_direct_download', False):
81
  print("✅ Using direct file download for embeddings")
82
 
83
  # Download the safetensors files
@@ -598,11 +600,13 @@ def st_load_file(file_path: Path) -> Any:
598
  def load_embedding_for_sentence(sentence_id: str, model_type: str = "clip") -> Optional[torch.Tensor]:
599
  """Load a single embedding for a specific sentence using streaming"""
600
  try:
601
- if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
 
 
602
  print("❌ Streaming embeddings not available")
603
  return None
604
 
605
- dataset = EMBEDDINGS_DATASETS['streaming_dataset']
606
 
607
  # Search for the sentence in the streaming dataset
608
  for item in dataset:
@@ -626,11 +630,13 @@ def load_embedding_for_sentence(sentence_id: str, model_type: str = "clip") -> O
626
  def get_top_k_embeddings(query_embedding: torch.Tensor, k: int = 10, model_type: str = "clip") -> List[Tuple[str, float]]:
627
  """Get top-k most similar embeddings using streaming"""
628
  try:
629
- if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
 
 
630
  print("❌ Streaming embeddings not available")
631
  return []
632
 
633
- dataset = EMBEDDINGS_DATASETS['streaming_dataset']
634
  similarities = []
635
 
636
  # Process embeddings in batches to avoid memory issues
@@ -720,10 +726,12 @@ def run_inference_streaming(
720
  print(f"✅ Image embedding computed successfully")
721
 
722
  # Get streaming dataset
723
- if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
 
 
724
  raise ValueError("Streaming embeddings not available")
725
 
726
- dataset = EMBEDDINGS_DATASETS['streaming_dataset']
727
 
728
  # Process embeddings in streaming mode
729
  results = []
 
72
  try:
73
  print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
74
 
75
+ # Call the function to get the actual dictionary
76
+ embeddings_datasets = EMBEDDINGS_DATASETS()
77
+ if not embeddings_datasets:
78
  print("❌ No embeddings datasets loaded")
79
  return None
80
 
81
  # Check if we're using direct download
82
+ if embeddings_datasets.get('use_direct_download', False):
83
  print("✅ Using direct file download for embeddings")
84
 
85
  # Download the safetensors files
 
600
  def load_embedding_for_sentence(sentence_id: str, model_type: str = "clip") -> Optional[torch.Tensor]:
601
  """Load a single embedding for a specific sentence using streaming"""
602
  try:
603
+ # Call the function to get the actual dictionary
604
+ embeddings_datasets = EMBEDDINGS_DATASETS()
605
+ if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
606
  print("❌ Streaming embeddings not available")
607
  return None
608
 
609
+ dataset = embeddings_datasets['streaming_dataset']
610
 
611
  # Search for the sentence in the streaming dataset
612
  for item in dataset:
 
630
  def get_top_k_embeddings(query_embedding: torch.Tensor, k: int = 10, model_type: str = "clip") -> List[Tuple[str, float]]:
631
  """Get top-k most similar embeddings using streaming"""
632
  try:
633
+ # Call the function to get the actual dictionary
634
+ embeddings_datasets = EMBEDDINGS_DATASETS()
635
+ if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
636
  print("❌ Streaming embeddings not available")
637
  return []
638
 
639
+ dataset = embeddings_datasets['streaming_dataset']
640
  similarities = []
641
 
642
  # Process embeddings in batches to avoid memory issues
 
726
  print(f"✅ Image embedding computed successfully")
727
 
728
  # Get streaming dataset
729
+ # Call the function to get the actual dictionary
730
+ embeddings_datasets = EMBEDDINGS_DATASETS()
731
+ if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
732
  raise ValueError("Streaming embeddings not available")
733
 
734
+ dataset = embeddings_datasets['streaming_dataset']
735
 
736
  # Process embeddings in streaming mode
737
  results = []