File size: 9,159 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""
Improved image generator with domain gating, similarity thresholding,
and explicit retrieval failure reporting.

Phase 1B+1C: Addresses retrieval reliability for controlled experiments.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

from src.embeddings.aligned_embeddings import AlignedEmbedder
from src.embeddings.similarity import cosine_similarity
from src.exceptions import IndexError_

logger = logging.getLogger(__name__)


# Domain keywords for gating — reject obvious mismatches.
# Maps a domain name to the set of lowercase keywords that signal it.
# Used in two places in this module:
#   - _detect_prompt_domain: scores a prompt by whole-word overlap plus
#     substring hits against each keyword set.
#   - ImprovedImageRetrievalGenerator._infer_domain: tags an image by the first
#     domain (in dict order) with a keyword that is a substring of the filename stem.
# Because substring matching is used, short keywords (e.g. "car", "sea") also
# match inside longer words.
DOMAIN_KEYWORDS = {
    "nature": {"forest", "tree", "mountain", "jungle", "garden", "park", "field",
               "meadow", "countryside", "rural", "fog", "dawn", "sunrise", "hill",
               "valley", "woodland", "grove", "leaf", "green", "wildlife"},
    "urban": {"city", "street", "neon", "urban", "downtown", "skyscraper",
              "building", "traffic", "night", "cobblestone", "road", "car",
              "sign", "shop", "window", "concrete", "sidewalk"},
    "water": {"beach", "ocean", "wave", "sea", "shore", "coast", "lake",
              "river", "water", "sand", "surf", "tide", "tropical", "island"},
}

# Domains that should NOT co-occur in prompt+image.
# Key: domain detected for the prompt. Value: image domains to reject for it.
# Consulted by _is_domain_compatible; a prompt domain missing from this table
# rejects nothing. Note that "nature" and "water" are mutually compatible,
# while "urban" is incompatible with both.
INCOMPATIBLE_DOMAINS = {
    "nature": {"urban"},
    "urban": {"nature", "water"},
    "water": {"urban"},
}


@dataclass
class ImageRetrievalResult:
    """Result of image retrieval with metadata for experiment bundles.

    Produced by ImprovedImageRetrievalGenerator.retrieve; one instance
    describes the single image chosen for a query plus diagnostics about
    how the choice was made.
    """
    # Path of the chosen image, exactly as stored in the index "ids" array.
    image_path: str
    # Similarity score between the query text embedding and the chosen image.
    similarity: float
    # Domain tag of the chosen image (a DOMAIN_KEYWORDS key, or "other").
    domain: str
    # True when the returned image is a below-threshold fallback.
    retrieval_failed: bool
    # Total number of index entries that were scored.
    candidates_considered: int
    # Number of candidates that passed domain gating and the similarity floor.
    candidates_above_threshold: int
    # (file name, score) of the five best-scoring images, before domain gating.
    top_5: List[Tuple[str, float]]


def _detect_prompt_domain(prompt: str) -> Optional[str]:
    """Guess which keyword domain a prompt belongs to.

    Each domain earns one point per keyword that appears as a whole
    whitespace-delimited word, plus one point per keyword found anywhere in
    the lowercased prompt as a substring (so an exact word counts twice and
    compound words still register).

    Returns the highest-scoring domain name (first one in table order on a
    tie), or None when no keyword matches at all.
    """
    text = prompt.lower()
    tokens = set(text.split())

    tally = {
        name: len(tokens & vocab) + sum(kw in text for kw in vocab)
        for name, vocab in DOMAIN_KEYWORDS.items()
    }

    # Empty keyword table: nothing to detect.
    if not tally:
        return None

    winner = max(tally, key=tally.get)
    # A top score of zero means the prompt hit no keyword in any domain.
    return winner if tally[winner] != 0 else None


def _is_domain_compatible(prompt_domain: Optional[str], image_domain: str) -> bool:
    """Check if image domain is compatible with prompt domain."""
    if prompt_domain is None:
        return True  # No domain detected, allow everything
    if image_domain == "other":
        return True  # Unknown domain, don't reject
    incompatible = INCOMPATIBLE_DOMAINS.get(prompt_domain, set())
    return image_domain not in incompatible


class ImprovedImageRetrievalGenerator:
    """
    Image retrieval with:
    - Domain gating: rejects obvious domain mismatches (forest prompt → no city images)
    - Raised similarity floor: min_similarity=0.20 (was 0.15)
    - Explicit retrieval failure: returns retrieval_failed=True instead of silent nonsense
    - Full diagnostic metadata for experiment bundles
    """

    def __init__(
        self,
        index_path: str = "data/embeddings/image_index.npz",
        min_similarity: float = 0.20,
        top_k: int = 5,
    ):
        """Load the image embedding index and create the text embedder.

        Args:
            index_path: Path to an .npz archive holding "ids" and "embs"
                arrays and, optionally, a "domains" array.
            min_similarity: Default similarity floor used by the retrieval
                methods when the caller does not pass one.
            top_k: Stored on the instance as ``self.top_k``; no method of
                this class reads it.

        Raises:
            IndexError_: If the index file does not exist or holds no ids.
        """
        self.index_path = Path(index_path)
        self.min_similarity = min_similarity
        self.top_k = top_k

        if not self.index_path.exists():
            raise IndexError_(
                f"Missing image index at {self.index_path}. "
                "Run: python scripts/build_embedding_indexes.py",
                index_path=str(self.index_path),
            )

        # NOTE(security): allow_pickle=True lets numpy unpickle object arrays,
        # which can execute arbitrary code if the index file is untrusted.
        # Only load indexes you built yourself. Kept as-is because existing
        # indexes may store ids as object arrays — TODO confirm and tighten.
        #
        # The archive is opened as a context manager so its file handle is
        # closed as soon as the arrays have been read into memory.
        with np.load(self.index_path, allow_pickle=True) as data:
            self.ids = data["ids"].tolist()
            self.embs = data["embs"].astype("float32")

            # Load domain tags if available (from rebuilt index)
            if "domains" in data:
                self.domains = data["domains"].tolist()
            else:
                # Infer from filenames for old indexes
                self.domains = [self._infer_domain(p) for p in self.ids]

        if len(self.ids) == 0:
            raise IndexError_(
                "Image index is empty. "
                "Add images and run: python scripts/build_embedding_indexes.py",
                index_path=str(self.index_path),
            )

        self.embedder = AlignedEmbedder(target_dim=512)

    @staticmethod
    def _infer_domain(filepath: str) -> str:
        """Infer an image's domain from its filename.

        Returns the first domain (in DOMAIN_KEYWORDS order) that has a keyword
        occurring as a substring of the lowercased filename stem, or "other"
        when no keyword matches.
        """
        name = Path(filepath).stem.lower()
        for domain, keywords in DOMAIN_KEYWORDS.items():
            if any(kw in name for kw in keywords):
                return domain
        return "other"

    def _score_and_filter(self, query_text: str, min_similarity: float):
        """Score every indexed image and apply domain gating and thresholding.

        Returns a 3-tuple of lists of ``(path, similarity, domain)`` tuples,
        each sorted by descending similarity:

        - ``scored``: every index entry.
        - ``pool``: the domain-compatible entries, or all of ``scored`` when
          gating would reject everything.
        - ``above_threshold``: entries of ``pool`` with similarity
          >= ``min_similarity``.
        """
        prompt_domain = _detect_prompt_domain(query_text)
        query_emb = self.embedder.embed_text(query_text)

        # Score all candidates
        scored = [
            (img_path, cosine_similarity(query_emb, img_emb), img_domain)
            for img_path, img_emb, img_domain in zip(self.ids, self.embs, self.domains)
        ]
        scored.sort(key=lambda item: item[1], reverse=True)

        # Phase 1: Domain gating — filter out incompatible domains
        domain_filtered = [
            item for item in scored
            if _is_domain_compatible(prompt_domain, item[2])
        ]

        # Phase 2: Similarity thresholding (gating is dropped if it left nothing)
        pool = domain_filtered if domain_filtered else scored
        above_threshold = [item for item in pool if item[1] >= min_similarity]
        return scored, pool, above_threshold

    def retrieve(
        self,
        query_text: str,
        min_similarity: Optional[float] = None,
    ) -> ImageRetrievalResult:
        """
        Retrieve best matching image with domain gating and quality checks.

        Args:
            query_text: Prompt to embed and match against the image index.
            min_similarity: Similarity floor for this call; defaults to the
                instance's ``min_similarity`` when None.

        Returns:
            ImageRetrievalResult with full metadata. When at least one
            candidate passes gating and the threshold, the best of those is
            returned with ``retrieval_failed=False``. Otherwise the best
            domain-compatible candidate is returned as a fallback with
            ``retrieval_failed=True`` and ``candidates_above_threshold=0``.
        """
        if min_similarity is None:
            min_similarity = self.min_similarity

        scored, pool, above_threshold = self._score_and_filter(
            query_text, min_similarity
        )

        # Diagnostics: best five by raw score, before any gating.
        top_5 = [(Path(p).name, s) for p, s, _ in scored[:5]]

        passed = bool(above_threshold)
        # Phase 3: Fallback — if nothing passed the threshold, return the best
        # candidate from the pool anyway (even though it is below threshold).
        best_path, best_sim, best_domain = above_threshold[0] if passed else pool[0]

        return ImageRetrievalResult(
            image_path=best_path,
            similarity=best_sim,
            domain=best_domain,
            # Set from the branch taken rather than by re-comparing the score,
            # so a NaN similarity cannot be reported as a success.
            retrieval_failed=not passed,
            candidates_considered=len(scored),
            candidates_above_threshold=len(above_threshold),
            top_5=top_5,
        )

    # Backward-compatible method
    def retrieve_top_k(
        self,
        query_text: str,
        k: int = 1,
        min_similarity: Optional[float] = None,
    ) -> List[Tuple[str, float]]:
        """Backward-compatible interface. Returns list of (path, score) tuples.

        Args:
            query_text: Prompt to embed and match against the image index.
            k: Maximum number of results to return; values below 1 are
                treated as 1 so the list is never empty.
            min_similarity: Similarity floor for this call; defaults to the
                instance's ``min_similarity`` when None.

        Returns:
            Up to ``k`` domain-compatible results at or above the threshold,
            best first. When nothing passes the threshold, a one-element list
            with the best fallback candidate (same choice as ``retrieve``).
        """
        if min_similarity is None:
            min_similarity = self.min_similarity

        _, pool, above_threshold = self._score_and_filter(query_text, min_similarity)

        if not above_threshold:
            fallback_path, fallback_sim, _ = pool[0]
            return [(fallback_path, fallback_sim)]

        return [(p, s) for p, s, _ in above_threshold[:max(k, 1)]]


def generate_image_improved(
    prompt: str,
    out_dir: str,
    index_path: str = "data/embeddings/image_index.npz",
    min_similarity: float = 0.20,
) -> str:
    """
    Retrieve the best-matching indexed image for a prompt and return its path.

    A fresh ImprovedImageRetrievalGenerator is built on every call. A warning
    is logged when retrieval failed (nothing met ``min_similarity``) or, if it
    succeeded, when the similarity is still below 0.25. ``out_dir`` is accepted
    but not used by this function.
    """
    result = ImprovedImageRetrievalGenerator(
        index_path=index_path,
        min_similarity=min_similarity,
    ).retrieve(prompt, min_similarity=min_similarity)

    # Only the first 60 characters of the prompt go into log messages.
    snippet = prompt[:60]

    if result.retrieval_failed:
        logger.warning(
            'Image retrieval failed: no image above threshold (%.2f) for prompt: '
            '"%s..." — best: %s (sim=%.4f, domain=%s)',
            min_similarity,
            snippet,
            Path(result.image_path).name,
            result.similarity,
            result.domain,
        )
        return result.image_path

    if result.similarity < 0.25:
        logger.warning(
            'Low image similarity: %.4f for "%s..." -> %s',
            result.similarity,
            snippet,
            Path(result.image_path).name,
        )

    return result.image_path


def generate_image_with_metadata(
    prompt: str,
    index_path: str = "data/embeddings/image_index.npz",
    min_similarity: float = 0.20,
) -> ImageRetrievalResult:
    """
    Retrieve an image for a prompt and hand back the full retrieval record.

    Intended for experiment pipelines that need the diagnostics (similarity,
    domain, failure flag, candidate counts) rather than just the path. Builds
    a new ImprovedImageRetrievalGenerator on each call.
    """
    settings = {"index_path": index_path, "min_similarity": min_similarity}
    retriever = ImprovedImageRetrievalGenerator(**settings)
    outcome = retriever.retrieve(prompt, min_similarity=min_similarity)
    return outcome