File size: 5,966 Bytes
61d3625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Runtime retrieval service utilities."""
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional

import numpy as np
import pandas as pd
from PIL import Image

from .config import CONFIG
from .embeddings import ImageEncoder, TextEncoder
from .indexers import AnnoyVectorIndex


class RetrievalService:
    def __init__(self) -> None:
        cfg = CONFIG
        cfg.prepare()
        paths = cfg.paths
        
        def _parse_tags(value: object) -> list[str]:
            if isinstance(value, str):
                try:
                    parsed = json.loads(value)
                    if isinstance(parsed, list):
                        return [str(item) for item in parsed]
                except json.JSONDecodeError:
                    return [part.strip() for part in value.split(",") if part.strip()]
            elif isinstance(value, (list, tuple, set)):
                return [str(item) for item in value]
            return []

        metadata_df = pd.read_parquet(paths.omni_metadata_path).set_index("id")
        if "tags" in metadata_df.columns:
            metadata_df["tags"] = metadata_df["tags"].apply(_parse_tags)
        else:
            metadata_df["tags"] = [[] for _ in range(len(metadata_df))]
        metadata_df["style"] = metadata_df["style"].fillna("unknown")
        metadata_df["genre"] = metadata_df["genre"].fillna("unknown")
        self.metadata = metadata_df
        self.id_to_idx = {sample_id: idx for idx, sample_id in enumerate(self.metadata.index)}

        self.omni_tag_matrix = np.load(paths.indexes_dir / "labels" / "omni_tag_embeddings.npy")
        with open(paths.indexes_dir / "labels" / "omni_tags.json", "r", encoding="utf-8") as fin:
            self.omni_tags: list[str] = json.load(fin)

        self.caption_embeddings = np.load(paths.embeddings_dir / "caption_embeddings.npy")
        self.image_embeddings = np.load(paths.embeddings_dir / "image_embeddings.npy")

        self.image_index = AnnoyVectorIndex(
            dimension=self.image_embeddings.shape[1],
        )
        self.image_index.load(paths.indexes_dir / "image.ann")

        self.caption_index = AnnoyVectorIndex(
            dimension=self.caption_embeddings.shape[1],
        )
        self.caption_index.load(paths.indexes_dir / "caption.ann")

        self.paths = paths
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()

    def _resolve_image(self, reference: str | Path | Image.Image) -> Image.Image:
        if isinstance(reference, Image.Image):
            return reference
        path = Path(reference)
        if not path.is_absolute():
            path = self.paths.root / path
        image = Image.open(path)
        return image

    def _metadata_for(self, sample_id: str) -> dict:
        row = self.metadata.loc[sample_id].to_dict()
        row["id"] = sample_id
        image_path = row.get("image_path")
        if image_path:
            row["image_path"] = str((self.paths.root / image_path).resolve())
        return row

    def search_similar_images(self, image: str | Path | Image.Image, top_k: int | None = None) -> list[dict]:
        candidate_image = self._resolve_image(image)
        embedding = self.image_encoder.encode([candidate_image]).numpy()[0]
        results = self.image_index.query(embedding, top_k=top_k)
        return [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results]

    def search_by_caption(self, caption: str, top_k: int | None = None) -> list[dict]:
        embedding = self.text_encoder.encode([caption]).numpy()[0]
        results = self.caption_index.query(embedding, top_k=top_k)
        return [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results]

    def search_omni(

        self,

        text_query: str | None = None,

        styles: Optional[Iterable[int]] = None,

        genres: Optional[Iterable[int]] = None,

        extra_tags: Optional[Iterable[str]] = None,

        top_k: int | None = None,

    ) -> list[dict]:
        candidate_ids = self._filter_candidates(
            styles=styles or [],
            genres=genres or [],
            extra_tags=extra_tags or [],
        )
        if not candidate_ids:
            return []

        candidate_vectors = self.caption_embeddings[[self.id_to_idx[sample_id] for sample_id in candidate_ids]]
        
        if text_query:
            temp_index = AnnoyVectorIndex(dimension=candidate_vectors.shape[1]) # Building runs quickly, so we can afford it in inference
            temp_index.build(candidate_vectors, candidate_ids)

            requested_top_k = top_k or CONFIG.index.top_k
            requested_top_k = min(requested_top_k, len(candidate_ids))

            text_embedding = self.text_encoder.encode([text_query]).numpy()[0]

            results = temp_index.query(text_embedding, top_k=requested_top_k)
        else:
            results = [(sample_id, 0.0) for sample_id in candidate_ids[: top_k or len(candidate_ids)]]

        formatted = [self._metadata_for(sample_id) | {"distance": float(distance)} for sample_id, distance in results]
        return formatted

    def _filter_candidates(

        self,

        styles: Iterable[int],

        genres: Iterable[int],

        extra_tags: Iterable[str],

    ) -> list[str]:
        df = self.metadata
        mask = pd.Series(True, index=df.index)
        if styles:
            mask &= df["style"].isin(list(styles))
        if genres:
            mask &= df["genre"].isin(list(genres))
        if extra_tags:
            required_tags = set(extra_tags)
            mask &= df["tags"].apply(lambda tags: required_tags.issubset(set(tags)))
        return df.index[mask].tolist()