"""
patch_inference.py
──────────────────
Fast patch‑text similarity ranking on top of the existing PaintingCLIP
inference pipeline.

Public API
----------
rank_sentences_for_cell(...)
list_grid_scores(...)          # optional diagnostic helper
"""

from __future__ import annotations

import math
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F
from PIL import Image

# Local import: reuse the heavyweight initialiser & sentence metadata
from .inference import _initialize_pipeline  # same package, no circular import


# ════════════════════════════════════════════════════════════════════════════
# Internal helpers
# ════════════════════════════════════════════════════════════════════════════
def _infer_patch_hw(num_patches: int) -> Tuple[int, int]:
    """
    Infer the ViT patch grid (H, W) from the flat token count.
    Works for square layouts only (ViT‑B/32 → 7×7; ViT‑B/16 → 14×14).
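
    Example: a 224×224 input to ViT‑B/32 yields 49 patch tokens, so
    _infer_patch_hw(49) == (7, 7); any non‑square count raises ValueError.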
    """
    root = int(math.sqrt(num_patches))
    if root * root == num_patches:
        return root, root
    raise ValueError(f"Unexpected non‑square patch layout: {num_patches}")


@lru_cache(maxsize=8)  # cache a few recent paintings × grid sizes
def _prepare_image(image_path: str, grid_size: Tuple[int, int]) -> torch.Tensor:
    """
    Generate cell embeddings for the entire image.

    Uses ViT patch embeddings directly for efficiency.
    """
    # Load resources from main inference pipeline
    processor, model, _, _, _, device = _initialize_pipeline()

    # Load and process image
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Ensure inputs are on the correct device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # 1. Extract patch embeddings from the vision tower
        vision_out = model.vision_model(**inputs)
        # Exclude CLS (token‑0), keep only the patch tokens
        patch_tokens = vision_out.last_hidden_state[:, 1:, :]  # (1, N, 768)
        patch_tokens = model.vision_model.post_layernorm(patch_tokens)  # final LayerNorm
        patch_feats = model.visual_projection(patch_tokens)  # project to joint space, (1, N, 512)
        patch_feats = F.normalize(patch_feats.squeeze(0), dim=-1)  # (N, 512)

    # 2. Reshape → (D, H, W) to pool channel‑wise
    num_patches, dim = patch_feats.shape
    H, W = _infer_patch_hw(num_patches)
    patch_grid = (
        patch_feats.view(H, W, dim)
        .permute(2, 0, 1)  # (D, H, W)
        .unsqueeze(0)  # (1, D, H, W) for pooling
    )

    # 3. Adaptive average‑pool down to UI grid resolution
    grid_h, grid_w = grid_size

    # Special case: if grid size matches patch grid, no pooling needed
    if (grid_h, grid_w) == (H, W):
        cell_grid = patch_grid.squeeze(0)  # Just remove batch dimension
    else:
        cell_grid = F.adaptive_avg_pool2d(
            patch_grid, output_size=(grid_h, grid_w)
        ).squeeze(0)

    # 4. Flatten → (cells, D) & L2‑normalise
    cell_vecs = cell_grid.permute(1, 2, 0).reshape(-1, dim)  # (grid_h*grid_w, 512)
    return F.normalize(cell_vecs, dim=-1).to(device)

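# Note: thanks to the lru_cache above, repeated queries against the same
# painting at the same grid size reuse the cached (cells, 512) tensor;
# e.g. a (7, 7) grid yields a (49, 512) matrix of unit‑norm cell vectors.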

# ════════════════════════════════════════════════════════════════════════════
# Public API
# ════════════════════════════════════════════════════════════════════════════
def rank_sentences_for_cell(
    image_path: str | Path,
    cell_row: int,
    cell_col: int,
    grid_size: Tuple[int, int] = (7, 7),
    top_k: int = 25,
    filter_topics: List[str] | None = None,
    filter_creators: List[str] | None = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve the *top‑k* sentences whose text embeddings align most strongly
    with a specific grid cell of the painting.

    Parameters
    ----------
    image_path : str | Path
        Path (local or mounted) to the RGB image file.
    cell_row, cell_col : int
        Zero‑indexed row/column of the clicked grid cell.
    grid_size : (int, int), default (7, 7)
        Resolution of the UI grid (7x7 matches ViT-B/32 patch grid).
    top_k : int, default 25
        How many sentences to return.
    filter_topics : List[str], optional
        Topic codes to restrict the candidate sentences to.
    filter_creators : List[str], optional
        Creator names to restrict the candidate sentences to.

    Returns
    -------
    List[dict]
        Each item follows the same schema as `run_inference`, facilitating front‑end reuse:
        { "id", "score", "english_original", "work", "rank" }
    """
    # Shared resources
    _proc, _model, sent_mat, sentence_ids, sent_meta, device = _initialize_pipeline()
    sent_mat = F.normalize(sent_mat.to(device), dim=-1)

    # Apply filtering if needed
    if filter_topics or filter_creators:
        from .filtering import get_filtered_sentence_ids

        valid_sentence_ids = get_filtered_sentence_ids(filter_topics, filter_creators)

        # Create mask for valid sentences
        valid_indices = [
            i for i, sid in enumerate(sentence_ids) if sid in valid_sentence_ids
        ]

        if not valid_indices:
            return []

        # Filter embeddings and sentence_ids
        sent_mat = sent_mat[valid_indices]
        sentence_ids = [sentence_ids[i] for i in valid_indices]

    # Validate cell indices
    grid_h, grid_w = grid_size
    if not (0 <= cell_row < grid_h and 0 <= cell_col < grid_w):
        raise ValueError(f"Cell ({cell_row}, {cell_col}) outside grid {grid_size}")

    # Cell feature vector
    cell_idx = cell_row * grid_w + cell_col
    cell_vecs = _prepare_image(str(image_path), grid_size)  # str() keeps the cache key type stable
    cell_vec = cell_vecs[cell_idx]

    # Cosine similarity (all vectors are unit‑norm, so scores lie in [-1, 1]) and top‑k ranking
    scores = torch.matmul(sent_mat, cell_vec)
    k = min(top_k, scores.size(0))
    top_scores, top_idx = torch.topk(scores, k)

    # Assemble output
    out: List[Dict[str, Any]] = []
    for rank, (idx, sc) in enumerate(zip(top_idx.tolist(), top_scores.tolist()), 1):
        sid = sentence_ids[idx]
        meta = sent_meta.get(
            sid,
            {"English Original": f"[Sentence data not found for {sid}]"},
        )
        work_id = sid.split("_")[0]
        out.append(
            {
                "id": sid,  # Frontend expects "id", not "sentence_id"
                "score": float(sc),
                "english_original": meta.get("English Original", ""),
                "work": work_id,
                "rank": rank,
            }
        )
    return out


# Optional helper for debugging / heat‑map pre‑computation
def list_grid_scores(
    image_path: str | Path,
    grid_size: Tuple[int, int] = (7, 7),
) -> torch.Tensor:
    """
    Return the full similarity matrix of shape (sentences, cells).
    Primarily for diagnostics or off‑line analysis.
    """
    _p, _m, sent_mat, *_, device = _initialize_pipeline()
    sent_mat = F.normalize(sent_mat.to(device), dim=-1)
    cell_vecs = _prepare_image(str(image_path), grid_size)
    return sent_mat @ cell_vecs.T  # (sentences, cells)
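

# Minimal manual smoke test: a sketch, not part of the public API. The image
# path below is a hypothetical placeholder; point it at a real painting file.
if __name__ == "__main__":
    demo_image = "example_painting.jpg"  # hypothetical path, adjust locally
    hits = rank_sentences_for_cell(demo_image, cell_row=3, cell_col=3, top_k=5)
    for hit in hits:
        print(f'{hit["rank"]:>2}. {hit["score"]:.3f}  {hit["id"]}  {hit["english_original"][:60]}')

    # Full (sentences, cells) similarity matrix, e.g. for heat-map experiments
    sims = list_grid_scores(demo_image)
    print("similarity matrix shape:", tuple(sims.shape))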