pratik-250620's picture
Upload folder using huggingface_hub
6835659 verified
"""
Improved image generator with domain gating, similarity thresholding,
and explicit retrieval failure reporting.
Phase 1B+1C: Addresses retrieval reliability for controlled experiments.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
from src.embeddings.aligned_embeddings import AlignedEmbedder
from src.embeddings.similarity import cosine_similarity
from src.exceptions import IndexError_
logger = logging.getLogger(__name__)
# Domain keywords for gating — reject obvious mismatches
DOMAIN_KEYWORDS = {
"nature": {"forest", "tree", "mountain", "jungle", "garden", "park", "field",
"meadow", "countryside", "rural", "fog", "dawn", "sunrise", "hill",
"valley", "woodland", "grove", "leaf", "green", "wildlife"},
"urban": {"city", "street", "neon", "urban", "downtown", "skyscraper",
"building", "traffic", "night", "cobblestone", "road", "car",
"sign", "shop", "window", "concrete", "sidewalk"},
"water": {"beach", "ocean", "wave", "sea", "shore", "coast", "lake",
"river", "water", "sand", "surf", "tide", "tropical", "island"},
}
# Domains that should NOT co-occur in prompt+image
INCOMPATIBLE_DOMAINS = {
"nature": {"urban"},
"urban": {"nature", "water"},
"water": {"urban"},
}
@dataclass
class ImageRetrievalResult:
"""Result of image retrieval with metadata for experiment bundles."""
image_path: str
similarity: float
domain: str
retrieval_failed: bool
candidates_considered: int
candidates_above_threshold: int
top_5: List[Tuple[str, float]]
def _detect_prompt_domain(prompt: str) -> Optional[str]:
"""Detect the primary domain of a prompt from keywords."""
prompt_lower = prompt.lower()
prompt_words = set(prompt_lower.split())
scores = {}
for domain, keywords in DOMAIN_KEYWORDS.items():
overlap = len(prompt_words & keywords)
# Also check substring matches for compound words
substring_hits = sum(1 for kw in keywords if kw in prompt_lower)
scores[domain] = overlap + substring_hits
if not scores or max(scores.values()) == 0:
return None
best_domain = max(scores, key=scores.get)
return best_domain
def _is_domain_compatible(prompt_domain: Optional[str], image_domain: str) -> bool:
"""Check if image domain is compatible with prompt domain."""
if prompt_domain is None:
return True # No domain detected, allow everything
if image_domain == "other":
return True # Unknown domain, don't reject
incompatible = INCOMPATIBLE_DOMAINS.get(prompt_domain, set())
return image_domain not in incompatible
class ImprovedImageRetrievalGenerator:
"""
Image retrieval with:
- Domain gating: rejects obvious domain mismatches (forest prompt → no city images)
- Raised similarity floor: min_similarity=0.20 (was 0.15)
- Explicit retrieval failure: returns retrieval_failed=True instead of silent nonsense
- Full diagnostic metadata for experiment bundles
"""
def __init__(
self,
index_path: str = "data/embeddings/image_index.npz",
min_similarity: float = 0.20,
top_k: int = 5,
):
self.index_path = Path(index_path)
self.min_similarity = min_similarity
self.top_k = top_k
if not self.index_path.exists():
raise IndexError_(
f"Missing image index at {self.index_path}. "
"Run: python scripts/build_embedding_indexes.py",
index_path=str(self.index_path),
)
data = np.load(self.index_path, allow_pickle=True)
self.ids = data["ids"].tolist()
self.embs = data["embs"].astype("float32")
# Load domain tags if available (from rebuilt index)
if "domains" in data:
self.domains = data["domains"].tolist()
else:
# Infer from filenames for old indexes
self.domains = [self._infer_domain(p) for p in self.ids]
if len(self.ids) == 0:
raise IndexError_(
"Image index is empty. "
"Add images and run: python scripts/build_embedding_indexes.py",
index_path=str(self.index_path),
)
self.embedder = AlignedEmbedder(target_dim=512)
@staticmethod
def _infer_domain(filepath: str) -> str:
"""Infer domain from filename."""
name = Path(filepath).stem.lower()
for domain, keywords in DOMAIN_KEYWORDS.items():
if any(kw in name for kw in keywords):
return domain
return "other"
def retrieve(
self,
query_text: str,
min_similarity: Optional[float] = None,
) -> ImageRetrievalResult:
"""
Retrieve best matching image with domain gating and quality checks.
Returns ImageRetrievalResult with full metadata including retrieval_failed flag.
"""
if min_similarity is None:
min_similarity = self.min_similarity
prompt_domain = _detect_prompt_domain(query_text)
query_emb = self.embedder.embed_text(query_text)
# Score all candidates
scored = []
for img_path, img_emb, img_domain in zip(self.ids, self.embs, self.domains):
sim = cosine_similarity(query_emb, img_emb)
scored.append((img_path, sim, img_domain))
scored.sort(key=lambda x: x[1], reverse=True)
top_5 = [(Path(p).name, s) for p, s, _ in scored[:5]]
# Phase 1: Domain gating — filter out incompatible domains
domain_filtered = [
(p, s, d) for p, s, d in scored
if _is_domain_compatible(prompt_domain, d)
]
# Phase 2: Similarity thresholding
candidates = domain_filtered if domain_filtered else scored
above_threshold = [(p, s, d) for p, s, d in candidates if s >= min_similarity]
if above_threshold:
# Best candidate passes both domain and similarity checks
best_path, best_sim, best_domain = above_threshold[0]
return ImageRetrievalResult(
image_path=best_path,
similarity=best_sim,
domain=best_domain,
retrieval_failed=False,
candidates_considered=len(scored),
candidates_above_threshold=len(above_threshold),
top_5=top_5,
)
# Phase 3: Fallback — nothing passed threshold
# Return best domain-compatible candidate (even if below threshold)
if domain_filtered:
best_path, best_sim, best_domain = domain_filtered[0]
else:
best_path, best_sim, best_domain = scored[0]
return ImageRetrievalResult(
image_path=best_path,
similarity=best_sim,
domain=best_domain,
retrieval_failed=best_sim < min_similarity,
candidates_considered=len(scored),
candidates_above_threshold=0,
top_5=top_5,
)
# Backward-compatible method
def retrieve_top_k(
self,
query_text: str,
k: int = 1,
min_similarity: Optional[float] = None,
) -> List[Tuple[str, float]]:
"""Backward-compatible interface. Returns list of (path, score) tuples."""
result = self.retrieve(query_text, min_similarity)
return [(result.image_path, result.similarity)]
def generate_image_improved(
prompt: str,
out_dir: str,
index_path: str = "data/embeddings/image_index.npz",
min_similarity: float = 0.20,
) -> str:
"""
Generate (retrieve) an image for a prompt.
Returns the image path. Warns on low similarity or retrieval failure.
"""
generator = ImprovedImageRetrievalGenerator(
index_path=index_path,
min_similarity=min_similarity,
)
result = generator.retrieve(prompt, min_similarity=min_similarity)
if result.retrieval_failed:
logger.warning(
"Image retrieval failed: no image above threshold (%.2f) "
"for prompt: \"%s...\" — best: %s (sim=%.4f, domain=%s)",
min_similarity, prompt[:60], Path(result.image_path).name,
result.similarity, result.domain,
)
elif result.similarity < 0.25:
logger.warning(
"Low image similarity: %.4f for \"%s...\" -> %s",
result.similarity, prompt[:60], Path(result.image_path).name,
)
return result.image_path
def generate_image_with_metadata(
prompt: str,
index_path: str = "data/embeddings/image_index.npz",
min_similarity: float = 0.20,
) -> ImageRetrievalResult:
"""
Generate (retrieve) an image and return full metadata.
Use this in experiment pipelines where retrieval quality matters.
"""
generator = ImprovedImageRetrievalGenerator(
index_path=index_path,
min_similarity=min_similarity,
)
return generator.retrieve(prompt, min_similarity=min_similarity)