|
|
""" |
|
|
patch_inference.py |
|
|
ββββββββββββββββββ |
|
|
Fast patchβtext similarity ranking on top of the existing PaintingCLIP |
|
|
inference pipeline. |
|
|
|
|
|
Public API |
|
|
---------- |
|
|
rank_sentences_for_cell(...) |
|
|
list_grid_scores(...) # optional diagnostic helper |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import math |
|
|
from functools import lru_cache |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
from .inference import _initialize_pipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _infer_patch_hw(num_patches: int) -> Tuple[int, int]: |
|
|
""" |
|
|
Infer the ViT patch grid (H, W) from the flat token count. |
|
|
Works for square layouts only (ViTβB/32 β 7Γ7; ViTβB/16 β 14Γ14). |
|
|
""" |
|
|
root = int(math.sqrt(num_patches)) |
|
|
if root * root == num_patches: |
|
|
return root, root |
|
|
raise ValueError(f"Unexpected nonβsquare patch layout: {num_patches}") |
|
|
|
|
|
|
|
|
@lru_cache(maxsize=8)  # cache recent images; keyed on (image_path, grid_size)
def _prepare_image(image_path: str, grid_size: Tuple[int, int]) -> torch.Tensor:
    """
    Generate per-cell embeddings for the entire image.

    Uses the ViT patch embeddings from a single forward pass (rather than
    cropping and re-encoding each cell) for efficiency.

    Parameters
    ----------
    image_path : str
        Path to the image file; opened and converted to RGB.
    grid_size : (int, int)
        Target (rows, cols) of the UI grid the patch grid is pooled to.

    Returns
    -------
    torch.Tensor
        Shape (grid_h * grid_w, embed_dim); rows are L2-normalised and laid
        out row-major (index = row * grid_w + col), on the pipeline device.
    """

    # Only the processor, model, and device slots of the pipeline tuple
    # are needed here; the sentence data is ignored.
    processor, model, _, _, _, device = _initialize_pipeline()

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():

        vision_out = model.vision_model(**inputs, output_hidden_states=True)

        # Drop the CLS token (index 0); keep only the spatial patch tokens.
        patch_tokens = vision_out.last_hidden_state[:, 1:, :]
        # Apply the final layernorm + visual projection so patch features
        # live in the same joint space as the text embeddings.
        patch_tokens = model.vision_model.post_layernorm(patch_tokens)
        patch_feats = model.visual_projection(patch_tokens)
        patch_feats = F.normalize(patch_feats.squeeze(0), dim=-1)

    num_patches, dim = patch_feats.shape
    H, W = _infer_patch_hw(num_patches)
    # (N, dim) -> (1, dim, H, W) so adaptive_avg_pool2d can operate on it.
    patch_grid = (
        patch_feats.view(H, W, dim)
        .permute(2, 0, 1)
        .unsqueeze(0)
    )

    grid_h, grid_w = grid_size

    if (grid_h, grid_w) == (H, W):
        # UI grid matches the native patch grid: no pooling needed.
        cell_grid = patch_grid.squeeze(0)
    else:
        # Average-pool patch features to the requested grid resolution.
        cell_grid = F.adaptive_avg_pool2d(
            patch_grid, output_size=(grid_h, grid_w)
        ).squeeze(0)

    # (dim, gh, gw) -> (gh * gw, dim), row-major cell order.
    cell_vecs = cell_grid.permute(1, 2, 0).reshape(-1, dim)
    # Re-normalise after pooling: the mean of unit vectors is not unit length.
    return F.normalize(cell_vecs, dim=-1).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rank_sentences_for_cell(
    image_path: str | Path,
    cell_row: int,
    cell_col: int,
    grid_size: Tuple[int, int] = (7, 7),
    top_k: int = 25,
    filter_topics: List[str] | None = None,
    filter_creators: List[str] | None = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve the *top-k* sentences whose text embeddings align most strongly
    with a specific grid cell of the painting.

    Parameters
    ----------
    image_path : str | Path
        Path (local or mounted) to the RGB image file.
    cell_row, cell_col : int
        Zero-indexed row/column of the clicked grid cell.
    grid_size : (int, int), default (7, 7)
        Resolution of the UI grid (7x7 matches ViT-B/32 patch grid).
    top_k : int, default 25
        How many sentences to return.
    filter_topics : List[str], optional
        List of topic codes to filter results by.
    filter_creators : List[str], optional
        List of creator names to filter results by.

    Returns
    -------
    List[dict]
        Each item follows the `run_inference` schema, facilitating
        front-end reuse:
        { "id", "score", "english_original", "work", "rank" }

    Raises
    ------
    ValueError
        If (cell_row, cell_col) falls outside `grid_size`.
    """

    _proc, _model, sent_mat, sentence_ids, sent_meta, device = _initialize_pipeline()
    sent_mat = F.normalize(sent_mat.to(device), dim=-1)

    # Optionally restrict the candidate pool before scoring.
    if filter_topics or filter_creators:
        from .filtering import get_filtered_sentence_ids

        # set() gives O(1) membership tests in the comprehension below.
        valid_sentence_ids = set(
            get_filtered_sentence_ids(filter_topics, filter_creators)
        )

        valid_indices = [
            i for i, sid in enumerate(sentence_ids) if sid in valid_sentence_ids
        ]

        if not valid_indices:
            return []

        sent_mat = sent_mat[valid_indices]
        sentence_ids = [sentence_ids[i] for i in valid_indices]

    grid_h, grid_w = grid_size
    if not (0 <= cell_row < grid_h and 0 <= cell_col < grid_w):
        raise ValueError(f"Cell ({cell_row}, {cell_col}) outside grid {grid_size}")

    # Row-major cell index, matching _prepare_image's output layout.
    cell_idx = cell_row * grid_w + cell_col
    # str() so Path and str arguments share one lru_cache entry (and match
    # _prepare_image's `str` annotation).
    cell_vecs = _prepare_image(str(image_path), grid_size)
    cell_vec = cell_vecs[cell_idx]

    # Cosine similarity: both sides are L2-normalised, so a dot product suffices.
    scores = torch.matmul(sent_mat, cell_vec)
    k = min(top_k, scores.size(0))
    top_scores, top_idx = torch.topk(scores, k)

    out: List[Dict[str, Any]] = []
    for rank, (idx, sc) in enumerate(zip(top_idx.tolist(), top_scores.tolist()), 1):
        sid = sentence_ids[idx]
        meta = sent_meta.get(
            sid,
            {"English Original": f"[Sentence data not found for {sid}]"},
        )
        # Sentence ids are "<work>_<n>"; the prefix identifies the work.
        work_id = sid.split("_")[0]
        out.append(
            {
                "id": sid,
                "score": float(sc),
                "english_original": meta.get("English Original", ""),
                "work": work_id,
                "rank": rank,
            }
        )
    return out
|
|
|
|
|
|
|
|
|
|
|
def list_grid_scores(
    image_path: str | Path,
    grid_size: Tuple[int, int] = (7, 7),
) -> torch.Tensor:
    """
    Compute the full sentence-by-cell similarity matrix.

    Returns a tensor of shape (num_sentences, num_cells), where a cell's
    column index is row-major over `grid_size`. Primarily intended for
    diagnostics or offline analysis.
    """
    _p, _m, raw_text_mat, *_, device = _initialize_pipeline()
    text_embeds = F.normalize(raw_text_mat.to(device), dim=-1)
    cell_embeds = _prepare_image(image_path, grid_size)
    # Both operands are unit-normalised, so this is cosine similarity.
    return torch.matmul(text_embeds, cell_embeds.T)
|
|
|