samwaugh committed on
Commit
28fdc3d
·
1 Parent(s): 267162d

Try to fix inference.py

Browse files
Files changed (1) hide show
  1. backend/runner/inference.py +140 -6
backend/runner/inference.py CHANGED
@@ -169,12 +169,10 @@ def _initialize_pipeline():
169
  # Check if we're using streaming
170
  if embeddings_data.get("streaming", False):
171
  print("βœ… Using streaming embeddings - will load on-demand")
172
- # For streaming, we'll load embeddings as needed during inference
173
- return {
174
- "streaming": True,
175
- "dataset": embeddings_data["dataset"],
176
- "repo_id": embeddings_data["repo_id"]
177
- }
178
  else:
179
  # Old code path for non-streaming
180
  if MODEL_TYPE == "clip":
@@ -314,6 +312,21 @@ def run_inference(
314
  )
315
  print(f"βœ… Pipeline components loaded successfully")
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  # Get valid sentence IDs based on filters
318
  if filter_topics or filter_creators:
319
  print(f"πŸ” Applying filters...")
@@ -619,3 +632,124 @@ def process_embedding_batch(batch: List[Dict], query_embedding: torch.Tensor, mo
619
  continue
620
 
621
  return similarities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  # Check if we're using streaming
170
  if embeddings_data.get("streaming", False):
171
  print("βœ… Using streaming embeddings - will load on-demand")
172
+ # For streaming, we need to handle this differently
173
+ # We'll return the components but mark embeddings as streaming
174
+ # The calling code will need to handle this case
175
+ return processor, model, "STREAMING", "STREAMING", "STREAMING", device
 
 
176
  else:
177
  # Old code path for non-streaming
178
  if MODEL_TYPE == "clip":
 
312
  )
313
  print(f"βœ… Pipeline components loaded successfully")
314
 
315
+ # Check if we're in streaming mode
316
+ if embeddings == "STREAMING":
317
+ print("βœ… Streaming mode detected - using streaming embeddings")
318
+ return run_inference_streaming(
319
+ image_path=image_path,
320
+ filter_topics=filter_topics,
321
+ filter_creators=filter_creators,
322
+ model_type=model_type,
323
+ top_k=top_k,
324
+ processor=processor,
325
+ model=model,
326
+ device=device
327
+ )
328
+
329
+ # Non-streaming mode - continue with existing logic
330
  # Get valid sentence IDs based on filters
331
  if filter_topics or filter_creators:
332
  print(f"πŸ” Applying filters...")
 
632
  continue
633
 
634
  return similarities
635
+
636
def run_inference_streaming(
    image_path: str,
    filter_topics: List[str] = None,
    filter_creators: List[str] = None,
    model_type: str = "CLIP",
    top_k: int = 10,
    processor=None,
    model=None,
    device=None
) -> List[Dict[str, Any]]:
    """Run similarity search against embeddings streamed from a remote dataset.

    Embeds the image at *image_path* with the given processor/model pair,
    then scans the streaming dataset in fixed-size chunks, keeping only the
    running top-k scored results so memory stays bounded.

    NOTE(review): filter_topics and filter_creators are accepted but never
    applied on this streaming path — confirm whether filtering is intended here.

    Returns:
        Up to *top_k* result dicts sorted by descending "score".

    Raises:
        ValueError: if streaming embeddings are not available.
    """
    try:
        print(f"πŸ” Running streaming inference for {image_path}")

        # Load the query image.
        print(f"πŸ” Loading and preprocessing image: {image_path}")
        pil_image = Image.open(image_path).convert("RGB")
        print(f"βœ… Image loaded successfully, size: {pil_image.size}")

        # Encode the image and move tensors to the target device.
        encoded = processor(images=pil_image, return_tensors="pt")
        encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

        with torch.no_grad():
            feats = model.get_image_features(**encoded)
            query_embedding = F.normalize(feats.squeeze(0), dim=-1)

        # The streaming dataset must have been registered at pipeline init.
        if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
            raise ValueError("Streaming embeddings not available")
        stream = EMBEDDINGS_DATASETS['streaming_dataset']

        top_results: List[Dict[str, Any]] = []
        pending: List[Dict] = []
        chunk_size = 1000  # records scored per pass

        print("πŸ” Processing streaming embeddings...")

        for record in stream:
            pending.append(record)
            if len(pending) < chunk_size:
                continue
            # Score the accumulated chunk, then prune to the running top-k.
            top_results.extend(
                process_embedding_batch_streaming(pending, query_embedding, model_type, device)
            )
            pending = []
            top_results.sort(key=lambda r: r["score"], reverse=True)
            top_results = top_results[:top_k]
            print(f"πŸ” Processed batch, current top score: {top_results[0]['score'] if top_results else 'N/A'}")

        # Score whatever is left over after the last full chunk.
        if pending:
            top_results.extend(
                process_embedding_batch_streaming(pending, query_embedding, model_type, device)
            )
            top_results.sort(key=lambda r: r["score"], reverse=True)
            top_results = top_results[:top_k]

        print(f"βœ… Streaming inference completed, returning {len(top_results)} results")
        return top_results

    except Exception as e:
        print(f"❌ Error in streaming inference: {e}")
        raise
708
+
709
def process_embedding_batch_streaming(
    batch: List[Dict],
    image_embedding: torch.Tensor,
    model_type: str,
    device: torch.device,
    sentences_data: Dict[str, Dict] = None,
) -> List[Dict[str, Any]]:
    """Score a batch of streamed embedding records against a query embedding.

    Args:
        batch: Streamed records; each is expected to carry a 'sentence_id'
            plus a 'clip_embedding' or 'paintingclip_embedding' float list.
        image_embedding: 1-D query image embedding tensor.
        model_type: "CLIP" or "PaintingCLIP" — selects which stored
            embedding field to score; records lacking it are skipped.
        device: Device on which similarities are computed.
        sentences_data: Optional pre-loaded sentence-metadata mapping
            (sentence_id -> metadata dict). When omitted it is loaded once
            via _load_sentences_metadata().

    Returns:
        One dict per scorable record with keys: id, score, english_original,
        work, rank. "rank" is the insertion position within THIS batch only;
        callers re-sort globally afterwards.
    """
    # Bug fix: the original implementation called _load_sentences_metadata()
    # inside the per-item loop, re-loading loop-invariant metadata for every
    # record. Load it exactly once up front (or accept it pre-loaded).
    if sentences_data is None:
        sentences_data = _load_sentences_metadata()

    results: List[Dict[str, Any]] = []

    for item in batch:
        try:
            sentence_id = item.get('sentence_id', '')

            # Pick the embedding matching the requested model type; records
            # without the matching field are silently skipped.
            if model_type == "CLIP" and 'clip_embedding' in item:
                embedding = torch.tensor(item['clip_embedding'])
            elif model_type == "PaintingCLIP" and 'paintingclip_embedding' in item:
                embedding = torch.tensor(item['paintingclip_embedding'])
            else:
                continue

            # Cosine similarity between the query and the stored embedding.
            embedding = embedding.to(device)
            similarity = F.cosine_similarity(
                image_embedding.unsqueeze(0),
                embedding.unsqueeze(0),
                dim=1
            ).item()

            sentence_data = sentences_data.get(sentence_id, {})
            # Convention appears to be "<work_id>_<n>" ids — TODO confirm.
            work_id = sentence_id.split("_")[0]

            results.append({
                "id": sentence_id,
                "score": similarity,
                "english_original": sentence_data.get("English Original", "N/A"),
                "work": work_id,
                "rank": len(results) + 1,
            })

        except Exception as e:
            # Best-effort: one malformed record must not abort the batch.
            print(f"⚠️ Error processing item in streaming batch: {e}")
            continue

    return results