samwaugh committed on
Commit
fd6d67a
Β·
1 Parent(s): 28fdc3d

Add logging for large output

Browse files
backend/runner/inference.py CHANGED
@@ -646,6 +646,7 @@ def run_inference_streaming(
646
  """Run inference using streaming embeddings"""
647
  try:
648
  print(f"πŸ” Running streaming inference for {image_path}")
 
649
 
650
  # Load and preprocess the image
651
  print(f"πŸ” Loading and preprocessing image: {image_path}")
@@ -653,12 +654,14 @@ def run_inference_streaming(
653
  print(f"βœ… Image loaded successfully, size: {image.size}")
654
 
655
  # Compute image embedding
 
656
  inputs = processor(images=image, return_tensors="pt")
657
  inputs = {k: v.to(device) for k, v in inputs.items()}
658
 
659
  with torch.no_grad():
660
  image_features = model.get_image_features(**inputs)
661
  image_embedding = F.normalize(image_features.squeeze(0), dim=-1)
 
662
 
663
  # Get streaming dataset
664
  if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
@@ -670,14 +673,35 @@ def run_inference_streaming(
670
  results = []
671
  batch_size = 1000
672
  batch = []
 
 
673
 
674
- print(f"πŸ” Processing streaming embeddings...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
 
676
  for item in dataset:
677
  batch.append(item)
 
678
 
679
  if len(batch) >= batch_size:
 
 
 
680
  # Process batch
 
681
  batch_results = process_embedding_batch_streaming(
682
  batch, image_embedding, model_type, device
683
  )
@@ -688,10 +712,29 @@ def run_inference_streaming(
688
  results.sort(key=lambda x: x["score"], reverse=True)
689
  results = results[:top_k]
690
 
691
- print(f"πŸ” Processed batch, current top score: {results[0]['score'] if results else 'N/A'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
 
693
  # Process remaining items
694
  if batch:
 
695
  batch_results = process_embedding_batch_streaming(
696
  batch, image_embedding, model_type, device
697
  )
@@ -699,7 +742,15 @@ def run_inference_streaming(
699
  results.sort(key=lambda x: x["score"], reverse=True)
700
  results = results[:top_k]
701
 
702
- print(f"βœ… Streaming inference completed, returning {len(results)} results")
 
 
 
 
 
 
 
 
703
  return results
704
 
705
  except Exception as e:
 
646
  """Run inference using streaming embeddings"""
647
  try:
648
  print(f"πŸ” Running streaming inference for {image_path}")
649
+ start_time = time.time()
650
 
651
  # Load and preprocess the image
652
  print(f"πŸ” Loading and preprocessing image: {image_path}")
 
654
  print(f"βœ… Image loaded successfully, size: {image.size}")
655
 
656
  # Compute image embedding
657
+ print(f"πŸ” Computing image embedding...")
658
  inputs = processor(images=image, return_tensors="pt")
659
  inputs = {k: v.to(device) for k, v in inputs.items()}
660
 
661
  with torch.no_grad():
662
  image_features = model.get_image_features(**inputs)
663
  image_embedding = F.normalize(image_features.squeeze(0), dim=-1)
664
+ print(f"βœ… Image embedding computed successfully")
665
 
666
  # Get streaming dataset
667
  if not EMBEDDINGS_DATASETS or not EMBEDDINGS_DATASETS.get('use_streaming', False):
 
673
  results = []
674
  batch_size = 1000
675
  batch = []
676
+ total_processed = 0
677
+ batch_count = 0
678
 
679
+ print(f"πŸ” Starting streaming processing of 3.1M+ sentence embeddings...")
680
+ print(f"πŸ” Batch size: {batch_size}")
681
+ print(f"πŸ” Target top-k: {top_k}")
682
+
683
+ # Estimate total items for progress tracking
684
+ try:
685
+ # Try to get dataset size if available
686
+ if hasattr(dataset, '__len__'):
687
+ total_items = len(dataset)
688
+ print(f"πŸ” Total embeddings to process: {total_items:,}")
689
+ else:
690
+ total_items = None
691
+ print(f"πŸ” Dataset size unknown (streaming mode)")
692
+ except Exception:
693
+ total_items = None
694
 
695
  for item in dataset:
696
  batch.append(item)
697
+ total_processed += 1
698
 
699
  if len(batch) >= batch_size:
700
+ batch_count += 1
701
+ batch_start_time = time.time()
702
+
703
  # Process batch
704
+ print(f"πŸ” Processing batch {batch_count} ({total_processed:,} items processed)...")
705
  batch_results = process_embedding_batch_streaming(
706
  batch, image_embedding, model_type, device
707
  )
 
712
  results.sort(key=lambda x: x["score"], reverse=True)
713
  results = results[:top_k]
714
 
715
+ batch_time = time.time() - batch_start_time
716
+ elapsed_time = time.time() - start_time
717
+
718
+ # Progress reporting
719
+ if total_items:
720
+ progress_pct = (total_processed / total_items) * 100
721
+ print(f"πŸ” Batch {batch_count} completed in {batch_time:.2f}s")
722
+ print(f"πŸ” Progress: {total_processed:,}/{total_items:,} ({progress_pct:.1f}%)")
723
+ print(f"πŸ” Elapsed time: {elapsed_time:.1f}s")
724
+ print(f"πŸ” Current top score: {results[0]['score']:.4f} if results else 'N/A'")
725
+ print(f"πŸ” Estimated time remaining: {((elapsed_time / total_processed) * (total_items - total_processed)):.1f}s")
726
+ else:
727
+ print(f"πŸ” Batch {batch_count} completed in {batch_time:.2f}s")
728
+ print(f"πŸ” Total processed: {total_processed:,}")
729
+ print(f"πŸ” Elapsed time: {elapsed_time:.1f}s")
730
+ print(f"πŸ” Current top score: {results[0]['score']:.4f} if results else 'N/A'")
731
+
732
+ print(f"πŸ” Current top result: {results[0]['english_original'][:100]}..." if results else "No results yet")
733
+ print("─" * 80)
734
 
735
  # Process remaining items
736
  if batch:
737
+ print(f"πŸ” Processing final batch of {len(batch)} items...")
738
  batch_results = process_embedding_batch_streaming(
739
  batch, image_embedding, model_type, device
740
  )
 
742
  results.sort(key=lambda x: x["score"], reverse=True)
743
  results = results[:top_k]
744
 
745
+ total_time = time.time() - start_time
746
+ print(f"βœ… Streaming inference completed!")
747
+ print(f"πŸ” Total time: {total_time:.2f}s")
748
+ print(f"πŸ” Total embeddings processed: {total_processed:,}")
749
+ print(f"πŸ” Final results: {len(results)} items")
750
+ if results:
751
+ print(f"πŸ” Top result score: {results[0]['score']:.4f}")
752
+ print(f"πŸ” Top result: {results[0]['english_original'][:100]}...")
753
+
754
  return results
755
 
756
  except Exception as e:
consolidate_embeddings.py DELETED
@@ -1,81 +0,0 @@
1
#!/usr/bin/env python3
"""Consolidate per-sentence .pt embedding files into single safetensors archives."""
import json
import sys
from pathlib import Path
from typing import List, Tuple

import torch
from safetensors.torch import save_file

# Repository-relative layout of the embedding dumps: one directory of
# individual .pt files per model family.
ROOT = Path(__file__).resolve().parent
DATA_DIR = ROOT / "data" / "embeddings"
CLIP_DIR = DATA_DIR / "CLIP_Embeddings"
PAINTINGCLIP_DIR = DATA_DIR / "PaintingCLIP_Embeddings"
15
def load_one(pt_path: Path) -> torch.Tensor:
    """Load a single .pt embedding, handling dict-or-tensor variants.

    Accepts either a bare tensor or a dict keyed by one of
    "embedding" / "embeddings" / "features"; raises ValueError otherwise.
    """
    payload = torch.load(pt_path, map_location="cpu", weights_only=True)
    if isinstance(payload, torch.Tensor):
        return payload
    if isinstance(payload, dict):
        # First key that actually maps to a tensor wins; non-tensor
        # values under a known key are skipped, matching the original scan.
        for key in ("embedding", "embeddings", "features"):
            value = payload.get(key)
            if isinstance(value, torch.Tensor):
                return value
    raise ValueError(f"Unsupported .pt content in {pt_path}")
27
-
28
def derive_id_from_filename(stem: str) -> str:
    """Strip the model-specific suffix from an embedding filename stem.

    - CLIP: Wxxxx_sYYYY_clip -> Wxxxx_sYYYY
    - PaintingCLIP: Wxxxx_sYYYY_painting_clip -> Wxxxx_sYYYY
    """
    # "_painting_clip" must be tested first: it also ends with "_clip".
    for suffix in ("_painting_clip", "_clip"):
        if stem.endswith(suffix):
            return stem[: -len(suffix)]
    return stem  # fallback: no recognized suffix, keep the stem as-is
38
-
39
def consolidate_dir(indir: Path) -> Tuple[torch.Tensor, List[str]]:
    """Load every .pt embedding under *indir* and stack them into one matrix.

    Returns an ([N, D] float tensor, list of N sentence ids) pair.
    Raises RuntimeError when the directory holds no .pt files and
    ValueError when an embedding cannot be reduced to a 1-D vector.
    """
    paths = sorted(indir.glob("*.pt"))
    if not paths:
        raise RuntimeError(f"No .pt files found under {indir}")

    vectors: List[torch.Tensor] = []
    sentence_ids: List[str] = []

    for count, path in enumerate(paths, 1):
        vec = load_one(path).float()
        # Tolerate a leading singleton batch dim, but nothing beyond 1-D.
        if vec.ndim > 1:
            vec = vec.squeeze()
        if vec.ndim != 1:
            raise ValueError(f"Embedding is not 1D in {path}: shape={tuple(vec.shape)}")
        vectors.append(vec)
        sentence_ids.append(derive_id_from_filename(path.stem))
        if count % 1000 == 0:
            print(f"... processed {count} files from {indir}")

    # One contiguous [N, D] matrix keeps the downstream safetensors export simple.
    return torch.stack(vectors, dim=0).contiguous(), sentence_ids
61
-
62
def save_as_safetensors(embeddings: torch.Tensor, ids: List[str], out_prefix: Path) -> None:
    """Persist *embeddings* as <prefix>.safetensors plus a JSON sentence-id sidecar."""
    tensor_path = out_prefix.with_suffix(".safetensors")
    ids_path = out_prefix.with_name(out_prefix.name + "_sentence_ids.json")

    save_file({"embeddings": embeddings}, str(tensor_path))
    with open(ids_path, "w", encoding="utf-8") as fh:
        json.dump(ids, fh, ensure_ascii=False, indent=2)

    print(f"Saved embeddings: {tensor_path} [{tuple(embeddings.shape)}]")
    print(f"Saved sentence IDs: {ids_path} [{len(ids)} ids]")
70
-
71
def main():
    """Consolidate both embedding families into .safetensors archives."""
    jobs = (
        ("CLIP", CLIP_DIR, "clip_embeddings"),
        ("PaintingCLIP", PAINTINGCLIP_DIR, "paintingclip_embeddings"),
    )
    for label, src_dir, out_name in jobs:
        print(f"Consolidating {label}...")
        emb, ids = consolidate_dir(src_dir)
        save_as_safetensors(emb, ids, DATA_DIR / out_name)

if __name__ == "__main__":
    main()