ArteFact / backend /runner /inference.py
samwaugh's picture
Fix bad inference
8a51c8c
"""
PaintingCLIP inference pipeline for art-historical text retrieval.
This module provides a pure functional interface for comparing artwork images
against a corpus of pre-computed sentence embeddings using CLIP models with
optional LoRA fine-tuning.
The pipeline:
1. Loads an image and computes its embedding using CLIP/PaintingCLIP
2. Compares against pre-computed sentence embeddings via cosine similarity
3. Returns the top-K most similar sentences with their metadata
"""
import base64
import io
import json
import time
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple
import cv2
import torch
import torch.nn.functional as F
from peft import PeftModel
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from datasets import load_dataset
from .filtering import get_filtered_sentence_ids
# on-demand Grad-ECLIP & region-aware ranking
from .heatmap import generate_heatmap
from .config import (
JSON_INFO_DIR,
EMBEDDINGS_DIR,
JSON_DATASETS,
EMBEDDINGS_DATASETS,
PAINTINGCLIP_MODEL_DIR,
ARTEFACT_EMBEDDINGS_DATASET,
sentences, # Add this
CLIP_EMBEDDINGS_ST, # Add these for backward compatibility
PAINTINGCLIP_EMBEDDINGS_ST,
CLIP_SENTENCE_IDS,
PAINTINGCLIP_SENTENCE_IDS,
CLIP_EMBEDDINGS_DIR,
PAINTINGCLIP_EMBEDDINGS_DIR
)
# ─── Configuration ───────────────────────────────────────────────────────────
MODEL_TYPE: Literal["clip", "paintingclip"] = "paintingclip"
# Model selection - change this to switch between models
MODEL_CONFIG = {
"clip": {
"model_id": "openai/clip-vit-base-patch32",
"use_lora": False,
"lora_dir": None,
},
"paintingclip": {
"model_id": "openai/clip-vit-base-patch32",
"use_lora": True,
"lora_dir": PAINTINGCLIP_MODEL_DIR, # This should now point to the correct path
},
}
# Inference settings
TOP_K = 25 # Number of results to return
# ─────────────────────────────────────────────────────────────────────────────
def load_embeddings_from_hf():
"""Load embeddings from HF dataset using safetensors files"""
try:
print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")
# Call the function to get the actual dictionary
embeddings_datasets = EMBEDDINGS_DATASETS()
if not embeddings_datasets:
print("❌ No embeddings datasets loaded")
return None
# Check if we're using direct download
if embeddings_datasets.get('use_direct_download', False):
print("βœ… Using direct file download for embeddings")
# Download the safetensors files
from huggingface_hub import hf_hub_download
import safetensors
# Download CLIP embeddings
print("πŸ” Downloading CLIP embeddings...")
clip_embeddings_path = hf_hub_download(
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
filename="clip_embeddings.safetensors",
repo_type="dataset"
)
clip_ids_path = hf_hub_download(
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
filename="clip_embeddings_sentence_ids.json",
repo_type="dataset"
)
# Download PaintingCLIP embeddings
print("πŸ” Downloading PaintingCLIP embeddings...")
paintingclip_embeddings_path = hf_hub_download(
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
filename="paintingclip_embeddings.safetensors",
repo_type="dataset"
)
paintingclip_ids_path = hf_hub_download(
repo_id=ARTEFACT_EMBEDDINGS_DATASET,
filename="paintingclip_embeddings_sentence_ids.json",
repo_type="dataset"
)
# Load the embeddings
print("πŸ” Loading CLIP embeddings...")
clip_embeddings = safetensors.torch.load_file(clip_embeddings_path)['embeddings']
print("πŸ” Loading PaintingCLIP embeddings...")
paintingclip_embeddings = safetensors.torch.load_file(paintingclip_embeddings_path)['embeddings']
# Load the sentence IDs
with open(clip_ids_path, 'r') as f:
clip_sentence_ids = json.load(f)
with open(paintingclip_ids_path, 'r') as f:
paintingclip_sentence_ids = json.load(f)
print(f"βœ… Loaded CLIP embeddings: {clip_embeddings.shape}")
print(f"βœ… Loaded PaintingCLIP embeddings: {paintingclip_embeddings.shape}")
return {
"clip": (clip_embeddings, clip_sentence_ids),
"paintingclip": (paintingclip_embeddings, paintingclip_sentence_ids)
}
else:
# Fallback to old method if not using direct download
print("⚠️ Using fallback embedding loading method")
return None
except Exception as e:
print(f"❌ Failed to load embeddings from HF: {e}")
return None
def _load_sentences_metadata() -> Dict[str, Dict[str, Any]]:
"""
Get sentence metadata from global config (loaded from HF datasets).
"""
if not sentences:
print("⚠️ No sentence metadata available - check if HF datasets loaded successfully")
return {}
return sentences
@lru_cache(maxsize=1)
def _initialize_pipeline():
"""
Initialize the inference pipeline components (cached).
This function loads all heavy resources once and caches them:
- CLIP model (with optional LoRA adapter)
- Pre-computed sentence embeddings from HF
- Sentence metadata from HF
Returns:
Tuple of (processor, model, embeddings, sentence_ids, sentences_data, device)
"""
# Select configuration based on MODEL_TYPE
config = MODEL_CONFIG[MODEL_TYPE]
# Determine compute device
if torch.backends.mps.is_available():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
# Load CLIP processor and base model
processor = CLIPProcessor.from_pretrained(config["model_id"], use_fast=False)
base_model = CLIPModel.from_pretrained(config["model_id"])
# Apply LoRA adapter if configured and available
if config["use_lora"] and config["lora_dir"]:
lora_path = Path(config["lora_dir"])
adapter_config_path = lora_path / "adapter_config.json"
if adapter_config_path.exists():
print(f"βœ… Loading LoRA adapter from {lora_path}")
model = PeftModel.from_pretrained(base_model, str(lora_path))
else:
print(f"⚠️ LoRA adapter not found at {lora_path}")
print(f"⚠️ Missing file: {adapter_config_path}")
print(f"⚠️ Falling back to base CLIP model without LoRA adapter")
model = base_model
else:
model = base_model
# Check for meta tensors and handle them properly
has_meta_tensors = any(p.device.type == "meta" for p in model.parameters())
if has_meta_tensors:
# Meta tensors mean the model needs to be materialized
print("[inference] meta tensors detected – materializing model on CPU")
device = torch.device("cpu")
# Materialize the model by moving it to CPU
# This converts meta tensors to actual tensors with allocated memory
model = model.to(device)
# Ensure all parameters are properly initialized
for param in model.parameters():
if param.device.type == "meta":
# This shouldn't happen after .to(device), but as a safety check
param.data = param.data.to(device)
else:
# Normal case: move model to selected device
if device.type != "cpu":
model = model.to(device)
model = model.eval()
# Load pre-computed embeddings from HF
try:
embeddings_data = load_embeddings_from_hf()
if embeddings_data is None:
raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")
# Check if we're using streaming (old approach)
if embeddings_data.get("streaming", False):
print("βœ… Using streaming embeddings - will load on-demand")
return processor, model, "STREAMING", "STREAMING", "STREAMING", device
else:
# New code path for direct file download
if MODEL_TYPE == "clip":
embeddings, sentence_ids = embeddings_data["clip"]
else:
embeddings, sentence_ids = embeddings_data["paintingclip"]
if embeddings is None or sentence_ids is None:
raise ValueError(f"Failed to load embeddings for model type: {MODEL_TYPE}")
print(f"πŸ” Loaded {len(sentence_ids)} embeddings with shape {embeddings.shape}")
except Exception as e:
print(f"❌ Error loading embeddings: {e}")
raise
# Get sentence metadata from global config
sentences_data = _load_sentences_metadata()
print(f"πŸ” Loaded {len(sentences_data)} sentence metadata entries")
if sentences_data:
sample_key = next(iter(sentences_data.keys()))
print(f"πŸ” Sample sentence data structure: {sentences_data[sample_key]}")
return processor, model, embeddings, sentence_ids, sentences_data, device
# ========================================================================== #
# Optional saliency overlay #
# ========================================================================== #
def compute_heatmap(
image_path: str,
sentence: str,
*,
layer_idx: int = -1,
alpha: float = 0.45, # Add this
colormap: int = cv2.COLORMAP_JET, # Add this
) -> str:
"""
Generate a Grad-ECLIP heat-map for (image, sentence).
Parameters
----------
image_path : str
Path to the input image (same one sent to run_inference).
sentence : str
Caption text to explain (usually one of the sentences returned by
run_inference).
layer_idx : int, optional
Vision transformer block to analyse (default last).
alpha : float, optional
Heatmap overlay opacity (default: 0.45)
colormap : int, optional
OpenCV colormap for visualization (default: COLORMAP_JET)
Returns
-------
data_url : str
PNG overlay encoded as ``data:image/png;base64,...`` suitable for the
front-end.
"""
# Re-use cached objects
processor, model, _, _, _, device = _initialize_pipeline()
pil_img = Image.open(image_path).convert("RGB")
overlay = generate_heatmap(
image=pil_img,
sentence=sentence,
model=model,
processor=processor,
device=device,
layer_idx=layer_idx,
alpha=alpha, # Pass through
colormap=colormap, # Pass through
)
buf = io.BytesIO()
overlay.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
return f"data:image/png;base64,{b64}"
# ========================================================================== #
# Main retrieval routine #
# ========================================================================== #
def run_inference(
image_path: str,
*,
cell: Optional[Tuple[int, int]] = None,
grid_size: Tuple[int, int] = (7, 7),
top_k: int = TOP_K,
filter_topics: List[str] = None,
filter_creators: List[str] = None,
model_type: str = None,
) -> List[Dict[str, Any]]:
print(f"πŸ” run_inference called with:")
print(f"πŸ” image_path: {image_path}")
print(f"πŸ” cell: {cell}")
print(f"πŸ” filter_topics: {filter_topics}")
print(f"πŸ” filter_creators: {filter_creators}")
print(f"πŸ” model_type: {model_type}")
try:
# Set model type if specified
if model_type:
print(f"πŸ” Setting model type to: {model_type}")
set_model_type(model_type.lower())
# ---- Region-aware pathway --------------------------------------------
if cell is not None:
print(f"πŸ” Using region-aware pathway for cell {cell}")
from .patch_inference import rank_sentences_for_cell
row, col = cell
results = rank_sentences_for_cell(
image_path=image_path,
cell_row=row,
cell_col=col,
grid_size=grid_size,
top_k=top_k * 3,
)
# Apply filtering
if filter_topics or filter_creators:
from .filtering import apply_filters_to_results
results = apply_filters_to_results(results, filter_topics, filter_creators)
results = results[:top_k]
return results
# ---- Whole-painting pathway (original implementation) ----------------
print(f"πŸ” Using whole-painting pathway")
# Load cached pipeline components
print(f"πŸ” Loading pipeline components...")
processor, model, embeddings, sentence_ids, sentences_data, device = (
_initialize_pipeline()
)
print(f"βœ… Pipeline components loaded successfully")
# Check if we're in streaming mode
if embeddings == "STREAMING":
print("βœ… Streaming mode detected - using streaming embeddings")
return run_inference_streaming(
image_path=image_path,
filter_topics=filter_topics,
filter_creators=filter_creators,
model_type=model_type,
top_k=top_k,
processor=processor,
model=model,
device=device
)
# Non-streaming mode - continue with existing logic
# Get valid sentence IDs based on filters
if filter_topics or filter_creators:
print(f"πŸ” Applying filters...")
valid_sentence_ids = get_filtered_sentence_ids(filter_topics, filter_creators)
print(f"βœ… Filtered to {len(valid_sentence_ids)} valid sentences")
# Create mask for valid sentences
valid_indices = [
i for i, sid in enumerate(sentence_ids) if sid in valid_sentence_ids
]
if not valid_indices:
print(f"⚠️ No sentences match the filters")
return []
# Filter embeddings and sentence_ids
filtered_embeddings = embeddings[valid_indices]
filtered_sentence_ids = [sentence_ids[i] for i in valid_indices]
else:
print(f"πŸ” No filtering applied")
filtered_embeddings = embeddings
filtered_sentence_ids = sentence_ids
# Load and preprocess the image
print(f"πŸ” Loading and preprocessing image: {image_path}")
image = Image.open(image_path).convert("RGB")
print(f"βœ… Image loaded successfully, size: {image.size}")
# Compute image embedding
inputs = processor(images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
image_features = model.get_image_features(**inputs)
image_embedding = F.normalize(image_features.squeeze(0), dim=-1)
# Normalize sentence embeddings and compute similarities
sentence_embeddings = F.normalize(filtered_embeddings.to(device), dim=-1)
similarities = torch.matmul(sentence_embeddings, image_embedding).cpu()
# Get top-K results
k = min(top_k, len(similarities))
top_scores, top_indices = torch.topk(similarities, k=k)
# Build results with full sentence metadata
results = []
for rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), start=1):
sentence_id = filtered_sentence_ids[idx]
sentence_data = sentences_data.get(
sentence_id,
{"English Original": f"[Sentence data not found for {sentence_id}]", "Has PaintingCLIP Embedding": True},
).copy()
work_id = sentence_id.split("_")[0]
sentence_data.setdefault("Work", work_id)
results.append({
"id": sentence_id,
"score": float(score),
"english_original": sentence_data.get("English Original", "N/A"),
"work": work_id,
"rank": rank,
})
print(f"πŸ” run_inference returning {len(results)} results")
if results:
print(f"πŸ” First result: {results[0]}")
return results
except Exception as e:
print(f"❌ Error in run_inference: {e}")
print(f"❌ Error type: {type(e).__name__}")
import traceback
print(f"❌ Full traceback:")
traceback.print_exc()
raise
# ─── Utilities ───────────────────────────────────────────────────────────────
def get_available_models() -> List[str]:
"""Return list of available model types."""
return list(MODEL_CONFIG.keys())
def set_model_type(model_type: str) -> None:
"""
Change the active model type.
Args:
model_type: Either "clip" or "paintingclip"
Raises:
ValueError: If model_type is not recognized
"""
global MODEL_TYPE
if model_type not in MODEL_CONFIG:
raise ValueError(
f"Unknown model type: {model_type}. "
f"Available options: {', '.join(MODEL_CONFIG.keys())}"
)
MODEL_TYPE = model_type
# Clear the cache to force reinitialization
_initialize_pipeline.cache_clear()
def load_consolidated_embeddings(embedding_file: Path, metadata_file: Path):
"""Load embeddings from consolidated file with metadata"""
print(f"Loading consolidated embeddings from {embedding_file}")
# Load consolidated data with weights_only=False for compatibility
# This is safe since we're loading our own pre-computed embeddings
try:
consolidated_data = torch.load(embedding_file, map_location='cpu', weights_only=False)
print(f"βœ… Successfully loaded consolidated embeddings")
except Exception as e:
print(f"❌ Failed to load with weights_only=False: {e}")
# Fallback: try with weights_only=True (may fail if file has non-tensor data)
try:
print(f"πŸ” Trying fallback with weights_only=True...")
consolidated_data = torch.load(embedding_file, map_location='cpu', weights_only=True)
print(f"βœ… Successfully loaded with weights_only=True")
except Exception as e2:
print(f"❌ Both loading methods failed:")
print(f" weights_only=False: {e}")
print(f" weights_only=True: {e2}")
raise RuntimeError(f"Cannot load embedding file with either method: {e2}")
embeddings = consolidated_data['embeddings']
# Load metadata for file mapping
with open(metadata_file, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Create filename to index mapping
filename_to_index = {item['filename']: item['index'] for item in metadata['file_mapping']}
print(f"Loaded {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")
return embeddings, filename_to_index
def load_consolidated_embeddings_st(embedding_st_file: Path, ids_json_file: Path):
print(f"Loading safetensors embeddings from {embedding_st_file}")
if not embedding_st_file.exists():
raise FileNotFoundError(f"Missing {embedding_st_file}")
if not ids_json_file.exists():
raise FileNotFoundError(f"Missing {ids_json_file}")
data = st_load_file(str(embedding_st_file))
if "embeddings" not in data:
raise KeyError(f"'embeddings' tensor missing in {embedding_st_file}")
embeddings = data["embeddings"].to(dtype=torch.float32, device="cpu").contiguous()
with open(ids_json_file, "r", encoding="utf-8") as f:
sentence_ids = json.load(f)
if not isinstance(sentence_ids, list):
raise ValueError(f"IDs file malformed: {ids_json_file}")
print(f"Loaded {len(sentence_ids)} embeddings with dim {embeddings.shape[1]}")
return embeddings, sentence_ids
# Update your embedding loading logic
def load_embeddings_for_model(model_type: str):
"""Load embeddings for the specified model type with safetensors-first strategy."""
if model_type == "clip":
st_file = CLIP_EMBEDDINGS_ST
ids_file = CLIP_SENTENCE_IDS
# Legacy PT locations for fallback (if repo still has them)
pt_file = EMBEDDINGS_DIR / "clip_embeddings_consolidated.pt"
meta_file = EMBEDDINGS_DIR / "clip_embeddings_metadata.json"
indiv_dir = CLIP_EMBEDDINGS_DIR
else:
st_file = PAINTINGCLIP_EMBEDDINGS_ST
ids_file = PAINTINGCLIP_SENTENCE_IDS
pt_file = EMBEDDINGS_DIR / "paintingclip_embeddings_consolidated.pt"
meta_file = EMBEDDINGS_DIR / "paintingclip_embeddings_metadata.json"
indiv_dir = PAINTINGCLIP_EMBEDDINGS_DIR
# 1) safetensors
if st_file.exists() and ids_file.exists():
try:
return load_consolidated_embeddings_st(st_file, ids_file)
except Exception as e:
print(f"⚠️ Safetensors load failed: {e}")
# 2) legacy PT (if present)
if pt_file.exists() and meta_file.exists():
try:
return load_consolidated_embeddings(pt_file, meta_file)
except Exception as e:
print(f"⚠️ Legacy PT load failed: {e}")
# 3) final fallback: refuse to scan 10k files here (HF Spaces file count limits)
print("❌ No valid consolidated embeddings found. Make sure you committed:")
print(f" - {st_file.name}")
print(f" - {ids_file.name}")
return None, None
# Add this function for backward compatibility
def st_load_file(file_path: Path) -> Any:
"""Load a file using safetensors or other methods"""
try:
if file_path.suffix == '.safetensors':
import safetensors
return safetensors.safe_open(str(file_path), framework="pt")
else:
import torch
return torch.load(str(file_path))
except ImportError:
print(f"⚠️ Required library not available for loading {file_path}")
return None
except Exception as e:
print(f"❌ Error loading {file_path}: {e}")
return None
def load_embedding_for_sentence(sentence_id: str, model_type: str = "clip") -> Optional[torch.Tensor]:
"""Load a single embedding for a specific sentence using streaming"""
try:
# Call the function to get the actual dictionary
embeddings_datasets = EMBEDDINGS_DATASETS()
if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
print("❌ Streaming embeddings not available")
return None
dataset = embeddings_datasets['streaming_dataset']
# Search for the sentence in the streaming dataset
for item in dataset:
if item.get('sentence_id') == sentence_id:
# Extract the appropriate embedding based on model type
if model_type == "clip" and 'clip_embedding' in item:
return torch.tensor(item['clip_embedding'])
elif model_type == "paintingclip" and 'paintingclip_embedding' in item:
return torch.tensor(item['paintingclip_embedding'])
else:
print(f"⚠️ Embedding not found for {model_type} in sentence {sentence_id}")
return None
print(f"⚠️ Sentence {sentence_id} not found in streaming dataset")
return None
except Exception as e:
print(f"❌ Error loading streaming embedding for {sentence_id}: {e}")
return None
def get_top_k_embeddings(query_embedding: torch.Tensor, k: int = 10, model_type: str = "clip") -> List[Tuple[str, float]]:
"""Get top-k most similar embeddings using streaming"""
try:
# Call the function to get the actual dictionary
embeddings_datasets = EMBEDDINGS_DATASETS()
if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
print("❌ Streaming embeddings not available")
return []
dataset = embeddings_datasets['streaming_dataset']
similarities = []
# Process embeddings in batches to avoid memory issues
batch_size = 1000
batch = []
for item in dataset:
batch.append(item)
if len(batch) >= batch_size:
# Process batch
batch_similarities = process_embedding_batch(batch, query_embedding, model_type)
similarities.extend(batch_similarities)
batch = []
# Keep only top-k so far
similarities.sort(key=lambda x: x[1], reverse=True)
similarities = similarities[:k]
# Process remaining items
if batch:
batch_similarities = process_embedding_batch(batch, query_embedding, model_type)
similarities.extend(batch_similarities)
similarities.sort(key=lambda x: x[1], reverse=True)
similarities = similarities[:k]
return similarities
except Exception as e:
print(f"❌ Error getting top-k embeddings: {e}")
return []
def process_embedding_batch(batch: List[Dict], query_embedding: torch.Tensor, model_type: str) -> List[Tuple[str, float]]:
"""Process a batch of embeddings to find similarities"""
similarities = []
for item in batch:
try:
sentence_id = item.get('sentence_id', '')
# Get the appropriate embedding
if model_type == "clip" and 'clip_embedding' in item:
embedding = torch.tensor(item['clip_embedding'])
elif model_type == "paintingclip" and 'paintingclip_embedding' in item:
embedding = torch.tensor(item['paintingclip_embedding'])
else:
continue
# Calculate similarity
similarity = F.cosine_similarity(query_embedding.unsqueeze(0), embedding.unsqueeze(0), dim=1)
similarities.append((sentence_id, similarity.item()))
except Exception as e:
print(f"⚠️ Error processing item in batch: {e}")
continue
return similarities
def run_inference_streaming(
image_path: str,
filter_topics: List[str] = None,
filter_creators: List[str] = None,
model_type: str = "CLIP",
top_k: int = 10,
processor=None,
model=None,
device=None
) -> List[Dict[str, Any]]:
"""Run inference using streaming embeddings"""
try:
print(f"πŸ” Running streaming inference for {image_path}")
start_time = time.time()
# Load and preprocess the image
print(f"πŸ” Loading and preprocessing image: {image_path}")
image = Image.open(image_path).convert("RGB")
print(f"βœ… Image loaded successfully, size: {image.size}")
# Compute image embedding
print(f"πŸ” Computing image embedding...")
inputs = processor(images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
image_features = model.get_image_features(**inputs)
image_embedding = F.normalize(image_features.squeeze(0), dim=-1)
print(f"βœ… Image embedding computed successfully")
# Get streaming dataset
# Call the function to get the actual dictionary
embeddings_datasets = EMBEDDINGS_DATASETS()
if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
raise ValueError("Streaming embeddings not available")
dataset = embeddings_datasets['streaming_dataset']
# Process embeddings in streaming mode
results = []
batch_size = 1000
batch = []
total_processed = 0
batch_count = 0
print(f"πŸ” Starting streaming processing of 3.1M+ sentence embeddings...")
print(f"πŸ” Batch size: {batch_size}")
print(f"πŸ” Target top-k: {top_k}")
# Estimate total items for progress tracking
try:
# Try to get dataset size if available
if hasattr(dataset, '__len__'):
total_items = len(dataset)
print(f"πŸ” Total embeddings to process: {total_items:,}")
else:
total_items = None
print(f"πŸ” Dataset size unknown (streaming mode)")
except:
total_items = None
for item in dataset:
batch.append(item)
total_processed += 1
if len(batch) >= batch_size:
batch_count += 1
batch_start_time = time.time()
# Process batch
print(f"πŸ” Processing batch {batch_count} ({total_processed:,} items processed)...")
batch_results = process_embedding_batch_streaming(
batch, image_embedding, model_type, device
)
results.extend(batch_results)
batch = []
# Keep only top-k so far
results.sort(key=lambda x: x["score"], reverse=True)
results = results[:top_k]
batch_time = time.time() - batch_start_time
elapsed_time = time.time() - start_time
# Progress reporting
if total_items:
progress_pct = (total_processed / total_items) * 100
print(f"πŸ” Batch {batch_count} completed in {batch_time:.2f}s")
print(f"πŸ” Progress: {total_processed:,}/{total_items:,} ({progress_pct:.1f}%)")
print(f"πŸ” Elapsed time: {elapsed_time:.1f}s")
print(f"πŸ” Current top score: {results[0]['score']:.4f}" if results else "πŸ” Current top score: N/A")
print(f"πŸ” Estimated time remaining: {((elapsed_time / total_processed) * (total_items - total_processed)):.1f}s")
else:
print(f"πŸ” Batch {batch_count} completed in {batch_time:.2f}s")
print(f"πŸ” Total processed: {total_processed:,}")
print(f"πŸ” Elapsed time: {elapsed_time:.1f}s")
print(f"πŸ” Current top score: {results[0]['score']:.4f}" if results else "πŸ” Current top score: N/A")
print(f"πŸ” Current top result: {results[0]['english_original'][:100]}..." if results else "πŸ” No results yet")
print("─" * 80)
# Process remaining items
if batch:
print(f"πŸ” Processing final batch of {len(batch)} items...")
batch_results = process_embedding_batch_streaming(
batch, image_embedding, model_type, device
)
results.extend(batch_results)
results.sort(key=lambda x: x["score"], reverse=True)
results = results[:top_k]
total_time = time.time() - start_time
print(f"βœ… Streaming inference completed!")
print(f"πŸ” Total time: {total_time:.2f}s")
print(f"πŸ” Total embeddings processed: {total_processed:,}")
print(f"πŸ” Final results: {len(results)} items")
if results:
print(f"πŸ” Top result score: {results[0]['score']:.4f}")
print(f"πŸ” Top result: {results[0]['english_original'][:100]}...")
return results
except Exception as e:
print(f"❌ Error in streaming inference: {e}")
raise
def process_embedding_batch_streaming(
batch: List[Dict],
image_embedding: torch.Tensor,
model_type: str,
device: torch.device
) -> List[Dict[str, Any]]:
"""Process a batch of streaming embeddings"""
results = []
processed_count = 0
error_count = 0
print(f"πŸ” Processing batch of {len(batch)} items...")
# Debug: show first few items to understand the data structure
for i, item in enumerate(batch[:3]):
print(f" Item {i}: keys = {list(item.keys())}")
print(f" Item {i}: full item = {item}")
for item in batch:
try:
sentence_id = item.get('sentence_id', '')
# Get the appropriate embedding
if model_type == "CLIP" and 'clip_embedding' in item:
embedding = torch.tensor(item['clip_embedding'])
elif model_type == "PaintingCLIP" and 'paintingclip_embedding' in item:
embedding = torch.tensor(item['paintingclip_embedding'])
else:
if processed_count < 3: # Only show first few errors
print(f"⚠️ No embedding found for {model_type} in item: {list(item.keys())}")
continue
# Calculate similarity
embedding = embedding.to(device)
similarity = F.cosine_similarity(
image_embedding.unsqueeze(0),
embedding.unsqueeze(0),
dim=1
).item()
# Get sentence metadata
sentences_data = _load_sentences_metadata()
sentence_data = sentences_data.get(sentence_id, {})
work_id = sentence_id.split("_")[0]
results.append({
"id": sentence_id,
"score": similarity,
"english_original": sentence_data.get("English Original", "N/A"),
"work": work_id,
"rank": len(results) + 1,
})
processed_count += 1
except Exception as e:
error_count += 1
if error_count < 3: # Only show first few errors
print(f"⚠️ Error processing item in streaming batch: {e}")
continue
print(f"πŸ” Batch processing complete: {processed_count} successful, {error_count} errors")
return results