File size: 1,820 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import annotations

from pathlib import Path
from typing import List, Tuple

import numpy as np

from src.embeddings.aligned_embeddings import AlignedEmbedder
from src.embeddings.similarity import cosine_similarity


class ImageRetrievalGenerator:
    """
    V1 image generator via retrieval.

    Rather than synthesising an image, this class loads a precomputed
    embedding index from disk and ranks its entries against the embedding
    of a text query.

    Attributes:
        index_path: ``Path`` to the ``.npz`` index file.
        ids: List read from the ``"ids"`` array of the index (presumably
            image paths -- confirm against the index build script).
        embs: Array read from the ``"embs"`` array of the index, cast to
            float32; entry ``i`` is scored on behalf of ``ids[i]``.
        embedder: ``AlignedEmbedder(target_dim=512)`` used to embed queries.
    """

    def __init__(self, index_path: str = "data/embeddings/image_index.npz"):
        """
        Load the image index and prepare the text embedder.

        Args:
            index_path: Location of the ``.npz`` file holding the ``"ids"``
                and ``"embs"`` arrays.

        Raises:
            RuntimeError: If the index file does not exist, contains no
                ids, or holds a different number of ids and embeddings.
        """
        self.index_path = Path(index_path)

        if not self.index_path.exists():
            raise RuntimeError(
                f"[ImageRetrievalGenerator] Missing image index at {self.index_path}. "
                "Run scripts/build_embedding_indexes.py first."
            )

        # NOTE(security): allow_pickle=True lets numpy unpickle object arrays,
        # which can execute arbitrary code. Only load index files produced by
        # a trusted build step.
        # The context manager closes the underlying archive once both arrays
        # have been copied out, so no file handle is left open.
        with np.load(self.index_path, allow_pickle=True) as data:
            self.ids = data["ids"].tolist()
            self.embs = data["embs"].astype("float32")

        if len(self.ids) == 0:
            raise RuntimeError(
                "[ImageRetrievalGenerator] Image index is empty. "
                "Add images to data/processed/images/ and rebuild the index."
            )

        # zip() in retrieve_top_k would silently drop the surplus entries of
        # the longer array, so reject an inconsistent index up front.
        if len(self.ids) != len(self.embs):
            raise RuntimeError(
                f"[ImageRetrievalGenerator] Image index at {self.index_path} is "
                f"inconsistent: {len(self.ids)} ids but {len(self.embs)} embeddings. "
                "Rebuild the index."
            )

        self.embedder = AlignedEmbedder(target_dim=512)

    def retrieve_top_k(self, query_text: str, k: int = 5) -> List[Tuple[str, float]]:
        """
        Return the ``k`` index entries that score highest for ``query_text``.

        Args:
            query_text: Text to embed and compare against the index.
            k: Maximum number of results. Values of zero or below yield an
                empty list.

        Returns:
            ``(id, score)`` pairs sorted by descending score, at most ``k``
            long. Entries with equal scores keep their index order.
        """
        # A negative k would otherwise slice from the end (scored[:-1] is
        # "everything but the last"), which is never what a caller wants.
        if k <= 0:
            return []

        query_emb = self.embedder.embed_text(query_text)
        scored = [
            (path, cosine_similarity(query_emb, emb))
            for path, emb in zip(self.ids, self.embs)
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:k]


def generate_image(
    prompt: str,
    out_dir: str,
    index_path: str = "data/embeddings/image_index.npz",
) -> str:
    """
    Return the id of the single best-matching indexed image for ``prompt``.

    A new ``ImageRetrievalGenerator`` is built from ``index_path`` on every
    call and queried with ``k=1``.

    Args:
        prompt: Text used as the retrieval query.
        out_dir: Accepted for interface compatibility; not used by this
            function.
        index_path: Location of the image embedding index.

    Returns:
        The first element of the top-ranked ``(id, score)`` pair.

    Raises:
        RuntimeError: If retrieval returns no results. Errors raised while
            constructing the generator propagate unchanged.
    """
    top_matches = ImageRetrievalGenerator(index_path=index_path).retrieve_top_k(
        prompt, k=1
    )
    if top_matches:
        best_match = top_matches[0]
        return best_match[0]
    raise RuntimeError("No images available for retrieval.")