|
|
""" |
|
|
PaintingCLIP inference pipeline for art-historical text retrieval. |
|
|
|
|
|
This module provides a pure functional interface for comparing artwork images |
|
|
against a corpus of pre-computed sentence embeddings using CLIP models with |
|
|
optional LoRA fine-tuning. |
|
|
|
|
|
The pipeline: |
|
|
1. Loads an image and computes its embedding using CLIP/PaintingCLIP |
|
|
2. Compares against pre-computed sentence embeddings via cosine similarity |
|
|
3. Returns the top-K most similar sentences with their metadata |
|
|
""" |
|
|
|
|
|
import base64 |
|
|
import io |
|
|
import json |
|
|
import time |
|
|
from functools import lru_cache |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Literal, Optional, Tuple |
|
|
|
|
|
import cv2 |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from peft import PeftModel |
|
|
from PIL import Image |
|
|
from transformers import CLIPModel, CLIPProcessor |
|
|
from datasets import load_dataset |
|
|
|
|
|
from .filtering import get_filtered_sentence_ids |
|
|
|
|
|
from .heatmap import generate_heatmap |
|
|
from .config import ( |
|
|
JSON_INFO_DIR, |
|
|
EMBEDDINGS_DIR, |
|
|
JSON_DATASETS, |
|
|
EMBEDDINGS_DATASETS, |
|
|
PAINTINGCLIP_MODEL_DIR, |
|
|
ARTEFACT_EMBEDDINGS_DATASET, |
|
|
sentences, |
|
|
CLIP_EMBEDDINGS_ST, |
|
|
PAINTINGCLIP_EMBEDDINGS_ST, |
|
|
CLIP_SENTENCE_IDS, |
|
|
PAINTINGCLIP_SENTENCE_IDS, |
|
|
CLIP_EMBEDDINGS_DIR, |
|
|
PAINTINGCLIP_EMBEDDINGS_DIR |
|
|
) |
|
|
|
|
|
|
|
|
# Active model variant; change at runtime via set_model_type().
MODEL_TYPE: Literal["clip", "paintingclip"] = "paintingclip"


# Per-variant model configuration. Both variants share the same CLIP
# backbone; "paintingclip" additionally loads a LoRA adapter from disk.
MODEL_CONFIG = {
    "clip": {
        "model_id": "openai/clip-vit-base-patch32",
        "use_lora": False,
        "lora_dir": None,
    },
    "paintingclip": {
        "model_id": "openai/clip-vit-base-patch32",
        "use_lora": True,
        "lora_dir": PAINTINGCLIP_MODEL_DIR,
    },
}


# Default number of most-similar sentences returned by run_inference().
TOP_K = 25
|
|
|
|
|
|
|
|
def load_embeddings_from_hf():
    """Load pre-computed sentence embeddings from the HF artefact dataset.

    Downloads the safetensors embedding matrices and their sentence-id JSON
    files for both model variants from ARTEFACT_EMBEDDINGS_DATASET.

    Returns:
        Dict mapping model type ("clip" / "paintingclip") to a tuple of
        (embeddings tensor, list of sentence ids), or None on any failure
        (callers are expected to fall back to local files).
    """
    try:
        print(f" Loading embeddings from {ARTEFACT_EMBEDDINGS_DATASET}...")

        embeddings_datasets = EMBEDDINGS_DATASETS()
        if not embeddings_datasets:
            print("β No embeddings datasets loaded")
            return None

        if not embeddings_datasets.get('use_direct_download', False):
            # No other strategy is implemented here; returning None lets the
            # caller fall back to load_embeddings_for_model().
            print("β οΈ Using fallback embedding loading method")
            return None

        print("β Using direct file download for embeddings")

        from huggingface_hub import hf_hub_download
        # BUG FIX: `import safetensors` does not guarantee that the
        # `safetensors.torch` submodule is reachable as an attribute;
        # import the loader function explicitly.
        from safetensors.torch import load_file as st_torch_load_file

        def _fetch(filename: str) -> str:
            # Download (or reuse the HF cache for) one artefact file.
            return hf_hub_download(
                repo_id=ARTEFACT_EMBEDDINGS_DATASET,
                filename=filename,
                repo_type="dataset",
            )

        print("π Downloading CLIP embeddings...")
        clip_embeddings_path = _fetch("clip_embeddings.safetensors")
        clip_ids_path = _fetch("clip_embeddings_sentence_ids.json")

        print("π Downloading PaintingCLIP embeddings...")
        paintingclip_embeddings_path = _fetch("paintingclip_embeddings.safetensors")
        paintingclip_ids_path = _fetch("paintingclip_embeddings_sentence_ids.json")

        print("π Loading CLIP embeddings...")
        clip_embeddings = st_torch_load_file(clip_embeddings_path)['embeddings']

        print("π Loading PaintingCLIP embeddings...")
        paintingclip_embeddings = st_torch_load_file(paintingclip_embeddings_path)['embeddings']

        with open(clip_ids_path, 'r') as f:
            clip_sentence_ids = json.load(f)

        with open(paintingclip_ids_path, 'r') as f:
            paintingclip_sentence_ids = json.load(f)

        print(f"β Loaded CLIP embeddings: {clip_embeddings.shape}")
        print(f"β Loaded PaintingCLIP embeddings: {paintingclip_embeddings.shape}")

        return {
            "clip": (clip_embeddings, clip_sentence_ids),
            "paintingclip": (paintingclip_embeddings, paintingclip_sentence_ids),
        }

    except Exception as e:
        print(f"β Failed to load embeddings from HF: {e}")
        return None
|
|
|
|
|
def _load_sentences_metadata() -> Dict[str, Dict[str, Any]]:
    """Return the sentence metadata dict loaded at import time from HF.

    Falls back to an empty dict (with a warning) when the config module
    could not populate the global `sentences` mapping.
    """
    if sentences:
        return sentences
    print("β οΈ No sentence metadata available - check if HF datasets loaded successfully")
    return {}
|
|
|
|
|
@lru_cache(maxsize=1)
def _initialize_pipeline():
    """
    Initialize the inference pipeline components (cached).

    This function loads all heavy resources once and caches them:
    - CLIP model (with optional LoRA adapter)
    - Pre-computed sentence embeddings from HF
    - Sentence metadata from HF

    The cache is invalidated by set_model_type() when the variant changes.

    Returns:
        Tuple of (processor, model, embeddings, sentence_ids, sentences_data, device)
    """

    config = MODEL_CONFIG[MODEL_TYPE]

    # Device preference: Apple MPS > CUDA > CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    processor = CLIPProcessor.from_pretrained(config["model_id"], use_fast=False)
    base_model = CLIPModel.from_pretrained(config["model_id"])

    # Optionally wrap the base model with the PaintingCLIP LoRA adapter.
    if config["use_lora"] and config["lora_dir"]:
        lora_path = Path(config["lora_dir"])
        adapter_config_path = lora_path / "adapter_config.json"

        if adapter_config_path.exists():
            print(f"β
Loading LoRA adapter from {lora_path}")
            model = PeftModel.from_pretrained(base_model, str(lora_path))
        else:
            # Degrade gracefully to the base model if the adapter is missing.
            print(f"β οΈ LoRA adapter not found at {lora_path}")
            print(f"β οΈ Missing file: {adapter_config_path}")
            print(f"β οΈ Falling back to base CLIP model without LoRA adapter")
            model = base_model
    else:
        model = base_model

    # Models loaded with accelerate/low_cpu_mem_usage may contain "meta"
    # tensors (placeholders without storage) that must be materialized
    # before the model can run.
    has_meta_tensors = any(p.device.type == "meta" for p in model.parameters())

    if has_meta_tensors:
        # Materialize on CPU; moving meta tensors straight to an
        # accelerator is not supported.
        print("[inference] meta tensors detected β materializing model on CPU")
        device = torch.device("cpu")

        model = model.to(device)

        # Force-copy any parameter still left on the meta device.
        for param in model.parameters():
            if param.device.type == "meta":
                param.data = param.data.to(device)
    else:
        # .to("cpu") on a CPU model would be a no-op; only move when needed.
        if device.type != "cpu":
            model = model.to(device)

    model = model.eval()  # inference only; disables dropout etc.

    # Load the pre-computed sentence embeddings for the active model type.
    try:
        embeddings_data = load_embeddings_from_hf()
        if embeddings_data is None:
            raise ValueError(f"Failed to load embeddings from HF dataset: {ARTEFACT_EMBEDDINGS_DATASET}")

        # NOTE(review): load_embeddings_from_hf never sets a "streaming" key,
        # so this branch looks like dead code kept for a streaming loader
        # variant — confirm before removing.
        if embeddings_data.get("streaming", False):
            print("β
Using streaming embeddings - will load on-demand")
            # Sentinel strings tell run_inference to take the streaming path.
            return processor, model, "STREAMING", "STREAMING", "STREAMING", device
        else:
            # Pick the embedding matrix matching the active model type.
            if MODEL_TYPE == "clip":
                embeddings, sentence_ids = embeddings_data["clip"]
            else:
                embeddings, sentence_ids = embeddings_data["paintingclip"]

            if embeddings is None or sentence_ids is None:
                raise ValueError(f"Failed to load embeddings for model type: {MODEL_TYPE}")

            print(f"π Loaded {len(sentence_ids)} embeddings with shape {embeddings.shape}")
    except Exception as e:
        print(f"β Error loading embeddings: {e}")
        raise

    # Sentence metadata (id -> record) used to enrich results.
    sentences_data = _load_sentences_metadata()
    print(f"π Loaded {len(sentences_data)} sentence metadata entries")
    if sentences_data:
        sample_key = next(iter(sentences_data.keys()))
        print(f"π Sample sentence data structure: {sentences_data[sample_key]}")

    return processor, model, embeddings, sentence_ids, sentences_data, device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_heatmap(
    image_path: str,
    sentence: str,
    *,
    layer_idx: int = -1,
    alpha: float = 0.45,
    colormap: int = cv2.COLORMAP_JET,
) -> str:
    """
    Generate a Grad-ECLIP heat-map for (image, sentence).

    Parameters
    ----------
    image_path : str
        Path to the input image (same one sent to run_inference).
    sentence : str
        Caption text to explain (usually one of the sentences returned by
        run_inference).
    layer_idx : int, optional
        Vision transformer block to analyse (default last).
    alpha : float, optional
        Heatmap overlay opacity (default: 0.45)
    colormap : int, optional
        OpenCV colormap for visualization (default: COLORMAP_JET)

    Returns
    -------
    data_url : str
        PNG overlay encoded as ``data:image/png;base64,...`` suitable for the
        front-end.
    """
    # Reuse the cached pipeline; embeddings/metadata are not needed here.
    processor, model, _, _, _, device = _initialize_pipeline()

    source_image = Image.open(image_path).convert("RGB")

    heatmap_overlay = generate_heatmap(
        image=source_image,
        sentence=sentence,
        model=model,
        processor=processor,
        device=device,
        layer_idx=layer_idx,
        alpha=alpha,
        colormap=colormap,
    )

    # Encode the overlay image as an inline PNG data URL.
    png_buffer = io.BytesIO()
    heatmap_overlay.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_inference(
    image_path: str,
    *,
    cell: Optional[Tuple[int, int]] = None,
    grid_size: Tuple[int, int] = (7, 7),
    top_k: int = TOP_K,
    filter_topics: Optional[List[str]] = None,
    filter_creators: Optional[List[str]] = None,
    model_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Rank corpus sentences by visual similarity to an artwork image.

    Args:
        image_path: Path to the artwork image on disk.
        cell: Optional (row, col) grid cell; when given, ranking is delegated
            to the region-aware pathway in patch_inference.
        grid_size: Grid dimensions used by the region-aware pathway.
        top_k: Number of results to return.
        filter_topics: Optional topic names used to restrict the corpus.
        filter_creators: Optional creator names used to restrict the corpus.
        model_type: Optional model variant override ("clip"/"paintingclip",
            case-insensitive); switches the cached pipeline when given.

    Returns:
        List of result dicts with keys: id, score, english_original, work, rank.
    """
    print(f"π run_inference called with:")
    print(f"π image_path: {image_path}")
    print(f"π cell: {cell}")
    print(f"π filter_topics: {filter_topics}")
    print(f"π filter_creators: {filter_creators}")
    print(f"π model_type: {model_type}")

    try:
        # Switch the active variant first so the cached pipeline matches.
        if model_type:
            print(f"π Setting model type to: {model_type}")
            set_model_type(model_type.lower())

        # Region-aware pathway: rank sentences for a single grid cell.
        if cell is not None:
            print(f"π Using region-aware pathway for cell {cell}")
            from .patch_inference import rank_sentences_for_cell

            row, col = cell
            # Over-fetch (x3) so post-hoc filtering still leaves enough results.
            results = rank_sentences_for_cell(
                image_path=image_path,
                cell_row=row,
                cell_col=col,
                grid_size=grid_size,
                top_k=top_k * 3,
            )

            if filter_topics or filter_creators:
                from .filtering import apply_filters_to_results
                results = apply_filters_to_results(results, filter_topics, filter_creators)
            results = results[:top_k]

            return results

        print(f"π Using whole-painting pathway")

        print(f"π Loading pipeline components...")
        processor, model, embeddings, sentence_ids, sentences_data, device = (
            _initialize_pipeline()
        )
        print(f"β Pipeline components loaded successfully")

        # BUG FIX: `embeddings` is normally a torch.Tensor; comparing a tensor
        # to a string with == is fragile (elementwise semantics are
        # version-dependent). Check the type explicitly for the sentinel.
        if isinstance(embeddings, str) and embeddings == "STREAMING":
            print("β Streaming mode detected - using streaming embeddings")
            return run_inference_streaming(
                image_path=image_path,
                filter_topics=filter_topics,
                filter_creators=filter_creators,
                model_type=model_type,
                top_k=top_k,
                processor=processor,
                model=model,
                device=device
            )

        # Restrict the corpus to sentences matching the requested filters.
        if filter_topics or filter_creators:
            print(f"π Applying filters...")
            valid_sentence_ids = get_filtered_sentence_ids(filter_topics, filter_creators)
            print(f"β Filtered to {len(valid_sentence_ids)} valid sentences")

            valid_indices = [
                i for i, sid in enumerate(sentence_ids) if sid in valid_sentence_ids
            ]

            if not valid_indices:
                print(f"β οΈ No sentences match the filters")
                return []

            filtered_embeddings = embeddings[valid_indices]
            filtered_sentence_ids = [sentence_ids[i] for i in valid_indices]
        else:
            print(f"π No filtering applied")
            filtered_embeddings = embeddings
            filtered_sentence_ids = sentence_ids

        print(f"π Loading and preprocessing image: {image_path}")
        image = Image.open(image_path).convert("RGB")
        print(f"β Image loaded successfully, size: {image.size}")

        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Compute a unit-norm image embedding (no gradients needed).
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
            image_embedding = F.normalize(image_features.squeeze(0), dim=-1)

        # Cosine similarity = dot product of unit-norm vectors.
        sentence_embeddings = F.normalize(filtered_embeddings.to(device), dim=-1)
        similarities = torch.matmul(sentence_embeddings, image_embedding).cpu()

        # Guard k against corpora smaller than top_k.
        k = min(top_k, len(similarities))
        top_scores, top_indices = torch.topk(similarities, k=k)

        results = []
        for rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), start=1):
            sentence_id = filtered_sentence_ids[idx]
            sentence_data = sentences_data.get(
                sentence_id,
                {"English Original": f"[Sentence data not found for {sentence_id}]", "Has PaintingCLIP Embedding": True},
            ).copy()
            # Sentence ids are "<work>_<n>"; the prefix identifies the artwork.
            work_id = sentence_id.split("_")[0]
            sentence_data.setdefault("Work", work_id)
            results.append({
                "id": sentence_id,
                "score": float(score),
                "english_original": sentence_data.get("English Original", "N/A"),
                "work": work_id,
                "rank": rank,
            })

        print(f"π run_inference returning {len(results)} results")
        if results:
            print(f"π First result: {results[0]}")
        return results

    except Exception as e:
        print(f"β Error in run_inference: {e}")
        print(f"β Error type: {type(e).__name__}")
        import traceback
        print(f"β Full traceback:")
        traceback.print_exc()
        raise
|
|
|
|
|
|
|
|
|
|
|
def get_available_models() -> List[str]:
    """Return the names of all configured model variants."""
    return [*MODEL_CONFIG]
|
|
|
|
|
|
|
|
def set_model_type(model_type: str) -> None:
    """
    Switch the active model variant and invalidate the cached pipeline.

    Args:
        model_type: Either "clip" or "paintingclip"

    Raises:
        ValueError: If model_type is not recognized
    """
    global MODEL_TYPE
    if model_type not in MODEL_CONFIG:
        raise ValueError(
            f"Unknown model type: {model_type}. "
            f"Available options: {', '.join(MODEL_CONFIG.keys())}"
        )
    MODEL_TYPE = model_type

    # Drop the cached pipeline so the next call rebuilds with the new model.
    _initialize_pipeline.cache_clear()
|
|
|
|
|
|
|
|
def load_consolidated_embeddings(embedding_file: Path, metadata_file: Path):
    """Load embeddings from consolidated file with metadata.

    Args:
        embedding_file: Legacy ``.pt`` file containing an 'embeddings' tensor.
        metadata_file: JSON file with a 'file_mapping' list of
            {filename, index} entries.

    Returns:
        Tuple of (embeddings tensor, dict mapping filename -> row index).

    Raises:
        RuntimeError: If the embedding file cannot be loaded by either method.
    """
    print(f"Loading consolidated embeddings from {embedding_file}")

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
    # only load embedding files from trusted sources.
    try:
        consolidated_data = torch.load(embedding_file, map_location='cpu', weights_only=False)
        print(f"β
Successfully loaded consolidated embeddings")
    except Exception as e:
        print(f"β Failed to load with weights_only=False: {e}")

        # Fallback: strict tensor-only loading (safe against pickle payloads).
        try:
            print(f"π Trying fallback with weights_only=True...")
            consolidated_data = torch.load(embedding_file, map_location='cpu', weights_only=True)
            print(f"β
Successfully loaded with weights_only=True")
        except Exception as e2:
            print(f"β Both loading methods failed:")
            print(f" weights_only=False: {e}")
            print(f" weights_only=True: {e2}")
            raise RuntimeError(f"Cannot load embedding file with either method: {e2}")

    embeddings = consolidated_data['embeddings']

    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # Map each original filename to its row in the embeddings matrix.
    filename_to_index = {item['filename']: item['index'] for item in metadata['file_mapping']}

    print(f"Loaded {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")

    return embeddings, filename_to_index
|
|
|
|
|
def load_consolidated_embeddings_st(embedding_st_file: Path, ids_json_file: Path):
    """Load an embeddings matrix (safetensors) plus its sentence-id list (JSON).

    Raises FileNotFoundError / KeyError / ValueError on missing or malformed
    inputs; returns (float32 CPU tensor, list of sentence ids).
    """
    print(f"Loading safetensors embeddings from {embedding_st_file}")
    for required_file in (embedding_st_file, ids_json_file):
        if not required_file.exists():
            raise FileNotFoundError(f"Missing {required_file}")

    tensors = st_load_file(str(embedding_st_file))
    if "embeddings" not in tensors:
        raise KeyError(f"'embeddings' tensor missing in {embedding_st_file}")
    # Normalize to float32 on CPU with contiguous storage.
    embeddings = tensors["embeddings"].to(dtype=torch.float32, device="cpu").contiguous()

    with open(ids_json_file, "r", encoding="utf-8") as f:
        sentence_ids = json.load(f)
    if not isinstance(sentence_ids, list):
        raise ValueError(f"IDs file malformed: {ids_json_file}")

    print(f"Loaded {len(sentence_ids)} embeddings with dim {embeddings.shape[1]}")
    return embeddings, sentence_ids
|
|
|
|
|
|
|
|
def load_embeddings_for_model(model_type: str):
    """Load embeddings for the specified model type with safetensors-first strategy."""
    if model_type == "clip":
        st_file, ids_file = CLIP_EMBEDDINGS_ST, CLIP_SENTENCE_IDS
        pt_file = EMBEDDINGS_DIR / "clip_embeddings_consolidated.pt"
        meta_file = EMBEDDINGS_DIR / "clip_embeddings_metadata.json"
        indiv_dir = CLIP_EMBEDDINGS_DIR
    else:
        st_file, ids_file = PAINTINGCLIP_EMBEDDINGS_ST, PAINTINGCLIP_SENTENCE_IDS
        pt_file = EMBEDDINGS_DIR / "paintingclip_embeddings_consolidated.pt"
        meta_file = EMBEDDINGS_DIR / "paintingclip_embeddings_metadata.json"
        indiv_dir = PAINTINGCLIP_EMBEDDINGS_DIR

    # Preferred format: safetensors matrix + JSON id list.
    if st_file.exists() and ids_file.exists():
        try:
            return load_consolidated_embeddings_st(st_file, ids_file)
        except Exception as e:
            print(f"β οΈ Safetensors load failed: {e}")

    # Legacy fallback: pickled .pt matrix + metadata JSON.
    if pt_file.exists() and meta_file.exists():
        try:
            return load_consolidated_embeddings(pt_file, meta_file)
        except Exception as e:
            print(f"β οΈ Legacy PT load failed: {e}")

    print("β No valid consolidated embeddings found. Make sure you committed:")
    print(f" - {st_file.name}")
    print(f" - {ids_file.name}")
    return None, None
|
|
|
|
|
|
|
|
def st_load_file(file_path: Path) -> Any:
    """Load a tensor file from disk.

    For ``.safetensors`` files, returns a dict mapping tensor names to
    tensors; for anything else, falls back to ``torch.load``.

    Args:
        file_path: Path to the file to load.

    Returns:
        The loaded object, or None if loading failed (missing library or
        any read/parse error).
    """
    try:
        if file_path.suffix == '.safetensors':
            # BUG FIX: safetensors.safe_open returns a lazy file handle that
            # supports neither `in` nor subscripting, but callers index the
            # result like a dict (data["embeddings"]). load_file eagerly
            # returns a {name: tensor} dict.
            from safetensors.torch import load_file as _st_load
            return _st_load(str(file_path))
        else:
            import torch
            # NOTE(review): torch.load may unpickle arbitrary objects;
            # only use on trusted files.
            return torch.load(str(file_path))
    except ImportError:
        print(f"β οΈ Required library not available for loading {file_path}")
        return None
    except Exception as e:
        print(f"β Error loading {file_path}: {e}")
        return None
|
|
|
|
|
def load_embedding_for_sentence(sentence_id: str, model_type: str = "clip") -> Optional[torch.Tensor]:
    """Fetch a single sentence embedding from the streaming HF dataset.

    Scans the stream linearly until the requested sentence id is found,
    then returns the embedding for the requested model type (or None if
    the record lacks one).
    """
    try:
        embeddings_datasets = EMBEDDINGS_DATASETS()
        if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
            print("β Streaming embeddings not available")
            return None

        stream = embeddings_datasets['streaming_dataset']

        for record in stream:
            if record.get('sentence_id') != sentence_id:
                continue
            # Matched record: pick the embedding for the requested variant.
            if model_type == "clip" and 'clip_embedding' in record:
                return torch.tensor(record['clip_embedding'])
            if model_type == "paintingclip" and 'paintingclip_embedding' in record:
                return torch.tensor(record['paintingclip_embedding'])
            print(f"β οΈ Embedding not found for {model_type} in sentence {sentence_id}")
            return None

        print(f"β οΈ Sentence {sentence_id} not found in streaming dataset")
        return None

    except Exception as e:
        print(f"β Error loading streaming embedding for {sentence_id}: {e}")
        return None
|
|
|
|
|
def get_top_k_embeddings(query_embedding: torch.Tensor, k: int = 10, model_type: str = "clip") -> List[Tuple[str, float]]:
    """Return the k most similar (sentence_id, score) pairs via a streaming scan.

    Buffers the stream into chunks, scores each chunk, and keeps only a
    rolling top-k so memory stays bounded.
    """
    try:
        embeddings_datasets = EMBEDDINGS_DATASETS()
        if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
            print("β Streaming embeddings not available")
            return []

        stream = embeddings_datasets['streaming_dataset']
        top_pairs: List[Tuple[str, float]] = []
        chunk: List[Dict] = []
        chunk_capacity = 1000

        for record in stream:
            chunk.append(record)
            if len(chunk) >= chunk_capacity:
                # Score the buffered records and keep only the best k so far.
                top_pairs.extend(process_embedding_batch(chunk, query_embedding, model_type))
                chunk = []
                top_pairs.sort(key=lambda pair: pair[1], reverse=True)
                top_pairs = top_pairs[:k]

        # Score whatever is left in the final partial chunk.
        if chunk:
            top_pairs.extend(process_embedding_batch(chunk, query_embedding, model_type))
            top_pairs.sort(key=lambda pair: pair[1], reverse=True)
            top_pairs = top_pairs[:k]

        return top_pairs

    except Exception as e:
        print(f"β Error getting top-k embeddings: {e}")
        return []
|
|
|
|
|
def process_embedding_batch(batch: List[Dict], query_embedding: torch.Tensor, model_type: str) -> List[Tuple[str, float]]:
    """Score one batch of streamed records against the query embedding.

    Returns (sentence_id, cosine similarity) pairs; records without an
    embedding for the requested model type are skipped silently.
    """
    scored: List[Tuple[str, float]] = []

    for record in batch:
        try:
            sid = record.get('sentence_id', '')

            # Only records carrying the embedding for the requested variant
            # are scored; everything else is skipped.
            if model_type == "clip" and 'clip_embedding' in record:
                vector = torch.tensor(record['clip_embedding'])
            elif model_type == "paintingclip" and 'paintingclip_embedding' in record:
                vector = torch.tensor(record['paintingclip_embedding'])
            else:
                continue

            score = F.cosine_similarity(query_embedding.unsqueeze(0), vector.unsqueeze(0), dim=1)
            scored.append((sid, score.item()))

        except Exception as exc:
            print(f"β οΈ Error processing item in batch: {exc}")
            continue

    return scored
|
|
|
|
|
def run_inference_streaming(
    image_path: str,
    filter_topics: List[str] = None,
    filter_creators: List[str] = None,
    model_type: str = "CLIP",
    top_k: int = 10,
    processor=None,
    model=None,
    device=None
) -> List[Dict[str, Any]]:
    """Run inference using streaming embeddings.

    Scans the streaming HF dataset in fixed-size batches, scoring every
    sentence embedding against the image embedding and keeping only a
    rolling top-k to bound memory.

    NOTE(review): filter_topics and filter_creators are accepted but never
    applied in this pathway — confirm whether streaming results should be
    filtered like the whole-painting pathway.
    """
    try:
        print(f"π Running streaming inference for {image_path}")
        start_time = time.time()

        print(f"π Loading and preprocessing image: {image_path}")
        image = Image.open(image_path).convert("RGB")
        print(f"β
Image loaded successfully, size: {image.size}")

        # Compute a unit-norm image embedding (no gradients needed).
        print(f"π Computing image embedding...")
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
            image_embedding = F.normalize(image_features.squeeze(0), dim=-1)
        print(f"β
Image embedding computed successfully")

        # Obtain the streaming dataset handle from the global config.
        embeddings_datasets = EMBEDDINGS_DATASETS()
        if not embeddings_datasets or not embeddings_datasets.get('use_streaming', False):
            raise ValueError("Streaming embeddings not available")

        dataset = embeddings_datasets['streaming_dataset']

        results = []
        batch_size = 1000
        batch = []
        total_processed = 0
        batch_count = 0

        print(f"π Starting streaming processing of 3.1M+ sentence embeddings...")
        print(f"π Batch size: {batch_size}")
        print(f"π Target top-k: {top_k}")

        # Dataset length is unknown in true streaming mode; used only for ETA.
        try:
            if hasattr(dataset, '__len__'):
                total_items = len(dataset)
                print(f"π Total embeddings to process: {total_items:,}")
            else:
                total_items = None
                print(f"π Dataset size unknown (streaming mode)")
        except:  # NOTE(review): bare except kept as-is; guards odd __len__ impls
            total_items = None

        for item in dataset:
            batch.append(item)
            total_processed += 1

            if len(batch) >= batch_size:
                batch_count += 1
                batch_start_time = time.time()

                print(f"π Processing batch {batch_count} ({total_processed:,} items processed)...")
                batch_results = process_embedding_batch_streaming(
                    batch, image_embedding, model_type, device
                )
                results.extend(batch_results)
                batch = []

                # Keep only the rolling top-k so memory stays bounded.
                results.sort(key=lambda x: x["score"], reverse=True)
                results = results[:top_k]

                batch_time = time.time() - batch_start_time
                elapsed_time = time.time() - start_time

                # Progress reporting (with ETA when the total is known).
                if total_items:
                    progress_pct = (total_processed / total_items) * 100
                    print(f"π Batch {batch_count} completed in {batch_time:.2f}s")
                    print(f"π Progress: {total_processed:,}/{total_items:,} ({progress_pct:.1f}%)")
                    print(f"π Elapsed time: {elapsed_time:.1f}s")
                    print(f"π Current top score: {results[0]['score']:.4f}" if results else "π Current top score: N/A")
                    print(f"π Estimated time remaining: {((elapsed_time / total_processed) * (total_items - total_processed)):.1f}s")
                else:
                    print(f"π Batch {batch_count} completed in {batch_time:.2f}s")
                    print(f"π Total processed: {total_processed:,}")
                    print(f"π Elapsed time: {elapsed_time:.1f}s")
                    print(f"π Current top score: {results[0]['score']:.4f}" if results else "π Current top score: N/A")

                print(f"π Current top result: {results[0]['english_original'][:100]}..." if results else "π No results yet")
                print("β" * 80)

        # Score whatever remains in the final partial batch.
        if batch:
            print(f"π Processing final batch of {len(batch)} items...")
            batch_results = process_embedding_batch_streaming(
                batch, image_embedding, model_type, device
            )
            results.extend(batch_results)
            results.sort(key=lambda x: x["score"], reverse=True)
            results = results[:top_k]

        total_time = time.time() - start_time
        print(f"β
Streaming inference completed!")
        print(f"π Total time: {total_time:.2f}s")
        print(f"π Total embeddings processed: {total_processed:,}")
        print(f"π Final results: {len(results)} items")
        if results:
            print(f"π Top result score: {results[0]['score']:.4f}")
            print(f"π Top result: {results[0]['english_original'][:100]}...")

        return results

    except Exception as e:
        print(f"β Error in streaming inference: {e}")
        raise
|
|
|
|
|
def process_embedding_batch_streaming(
    batch: List[Dict],
    image_embedding: torch.Tensor,
    model_type: str,
    device: torch.device
) -> List[Dict[str, Any]]:
    """Process a batch of streaming embeddings.

    Args:
        batch: Streamed records, each expected to carry 'sentence_id' plus a
            '<model>_embedding' list for the requested variant.
        image_embedding: Normalized query image embedding (on `device`).
        model_type: "CLIP" or "PaintingCLIP" (capitalized UI names).
        device: Device each sentence embedding is moved to before scoring.

    Returns:
        Result dicts (id/score/english_original/work/rank) for every record
        that had a usable embedding.
    """
    results = []
    processed_count = 0
    error_count = 0

    print(f"π Processing batch of {len(batch)} items...")

    # Debug aid: dump the structure of the first few records.
    for i, item in enumerate(batch[:3]):
        print(f" Item {i}: keys = {list(item.keys())}")
        print(f" Item {i}: full item = {item}")

    # PERF FIX: the metadata mapping is identical for every record, so fetch
    # it once instead of calling _load_sentences_metadata() per item inside
    # the hot loop (this function runs over millions of streamed records).
    sentences_data = _load_sentences_metadata()

    for item in batch:
        try:
            sentence_id = item.get('sentence_id', '')

            if model_type == "CLIP" and 'clip_embedding' in item:
                embedding = torch.tensor(item['clip_embedding'])
            elif model_type == "PaintingCLIP" and 'paintingclip_embedding' in item:
                embedding = torch.tensor(item['paintingclip_embedding'])
            else:
                # Only report the first few mismatches to keep logs readable.
                if processed_count < 3:
                    print(f"β οΈ No embedding found for {model_type} in item: {list(item.keys())}")
                continue

            embedding = embedding.to(device)
            similarity = F.cosine_similarity(
                image_embedding.unsqueeze(0),
                embedding.unsqueeze(0),
                dim=1
            ).item()

            sentence_data = sentences_data.get(sentence_id, {})
            # Sentence ids are "<work>_<n>"; the prefix identifies the artwork.
            work_id = sentence_id.split("_")[0]

            results.append({
                "id": sentence_id,
                "score": similarity,
                "english_original": sentence_data.get("English Original", "N/A"),
                "work": work_id,
                # NOTE(review): rank is positional within this batch only;
                # the caller re-sorts merged results, so this is not global.
                "rank": len(results) + 1,
            })
            processed_count += 1

        except Exception as e:
            error_count += 1
            if error_count < 3:
                print(f"β οΈ Error processing item in streaming batch: {e}")
            continue

    print(f"π Batch processing complete: {processed_count} successful, {error_count} errors")
    return results
|
|
|