# import chromadb
import logging
import grpc
from config import settings
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from src.database.session import engine
from src.database.models import Animes
from sqlmodel import Session, select
# Configure root logging at import time: INFO and above, each record prefixed
# with a timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
class AnimeRetriever:
"""Handles anime retrieval from ChromaDB"""
def __init__(self,
collection_name: str = "anime_collection"):
self.client = QdrantClient(url=settings.qdrant_url,
api_key=settings.qdrant_api_key,
cloud_inference=True,
prefer_grpc=True,
timeout=10)
self.collection_name = collection_name
self.points_count = self.client.count(
collection_name=self.collection_name, exact=False).count
# self.model = SentenceTransformer(model)
self.model = "sentence-transformers/all-minilm-l6-v2"
print(
f"Loaded collection with {self.points_count} anime approximately")
def fetch_anime_batch_from_postgres(self, mal_ids: list[int]) -> dict[int, Animes]:
"""Fetch multiple animes at once and return a dictionary mapped by mal_id"""
with Session(engine) as session:
# The .in_() operator acts like SQL's "WHERE mal_id IN (1, 2, 3)"
statement = select(Animes).where(Animes.mal_id.in_(mal_ids))
results = session.exec(statement).all()
return {anime.mal_id: anime for anime in results}
def search(
self,
query: str,
n_results: int = 5,
genre_filter: list[str] | None = None,
min_score: float | None = 6.0,
anime_type: str | None = None
) -> list[dict]:
"""
Search for anime similar to query
Args:
query: User search query
n_results: Number of results to return
genre_filter: Optional genre to filter by
min_score: Minimum MAL score (e.g., 7.0)
anime_type: Type of Anime (e.g. TV, Movie, etc)
Returns:
List of dicts with anime info
"""
must_conditions = [
# Base condition: scored_by >= 9000
models.FieldCondition(
key="scored_by",
range=models.Range(gte=20000)
)
]
if min_score:
logger.info(f"SCORE: Filtered based on min_score: {min_score}")
must_conditions.append(
models.FieldCondition(
key="score",
range=models.Range(gte=min_score)
)
)
if anime_type:
logger.info(
f"ANIME TYPE: Filtered based on anime_type: {anime_type}")
must_conditions.append(
models.FieldCondition(
key="type",
match=models.MatchValue(value=anime_type)
)
)
if genre_filter:
logger.info(
f"GENRE: Pre-filtering (OR) for genres: {', '.join(genre_filter)}")
# Qdrant's MatchAny automatically acts as an OR condition against list fields!
must_conditions.append(
models.FieldCondition(
key="genres",
match=models.MatchAny(any=genre_filter)
)
)
# Wrap all conditions in a Filter object
query_filter = models.Filter(
must=must_conditions) if must_conditions else None
search_results = self.client.query_points(
collection_name=self.collection_name,
query=models.Document(
text=query,
model=self.model
),
query_filter=query_filter,
limit=n_results
).points
if not search_results:
return []
retrieved_ids = [hit.id for hit in search_results]
postgres_data_map = self.fetch_anime_batch_from_postgres(retrieved_ids)
anime_list = []
for hit in search_results:
mal_id = hit.id
similarity_score = hit.score # Qdrant returns cosine similarity here
# Get the rich data from our Postgres map
pg_anime = postgres_data_map.get(mal_id)
if not pg_anime:
logger.warning(
f"Anime ID {mal_id} found in Qdrant but missing in Postgres!")
continue
# Merge Vector Search results with Postgres truths
anime_info = {
"mal_id": pg_anime.mal_id,
"mal_url": pg_anime.url,
"title": pg_anime.title,
"title_english": pg_anime.title_english,
"score": pg_anime.score,
"scored_by": pg_anime.scored_by,
"type": pg_anime.type,
"year": pg_anime.year,
"genres": pg_anime.genres,
"studios": pg_anime.studios,
"themes": pg_anime.themes,
"demographics": pg_anime.demographics,
"episodes": pg_anime.episodes,
"popularity": pg_anime.popularity,
"rating": pg_anime.rating,
"aired_from": pg_anime.aired_from,
"aired_to": pg_anime.aired_to,
"favorites": pg_anime.favorites,
"images": pg_anime.images,
"synopsis": pg_anime.synopsis,
"searchable_text": pg_anime.searchable_text,
}
anime_list.append(anime_info)
return anime_list
def get_by_title(self, title: str) -> dict | None:
"""Get anime by exact or partial title match"""
# Search with title as query
results = self.search(query=title, n_results=1)
return results[0] if results else None
if __name__ == "__main__":
    # Manual smoke test. Constructing the retriever contacts Qdrant, and each
    # search() also reads from Postgres, so both services must be reachable.
    # Other ad-hoc checks that were tried here: min_score filtering
    # (e.g. min_score=9.0), anime_type filtering (e.g. anime_type="Special"),
    # and filtering on several genres at once.
    retriever = AnimeRetriever()

    print("=== Test 1: Basic Search ===")
    for anime in retriever.search("dark psychological anime", n_results=15):
        print(f"- {anime['title']} (score: {anime['score']})")

    print("\n=== Test 3: Genre Filter ===")
    genre_hits = retriever.search(
        "Overpowered Main character", n_results=5, genre_filter=["Adventure"])
    # Dump only the first hit in full, to inspect every returned field.
    if genre_hits:
        print(genre_hits[0])
|