# AnimeRAGSystem/src/retrieval/vector_search.py
# Pushkar02-n's picture
# Final changes with supabase and qdrant for first phase
# 929258f
# import chromadb
import logging
import grpc
from config import settings
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from src.database.session import engine
from src.database.models import Animes
from sqlmodel import Session, select
# Configure the root logger at import time (INFO level, timestamped lines),
# then create the logger used throughout this module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class AnimeRetriever:
    """Semantic anime search backed by Qdrant, enriched with Postgres rows.

    The query text is sent to Qdrant as a ``models.Document`` (the client is
    created with ``cloud_inference=True``), optionally restricted by payload
    filters. The ids of the returned points are then looked up as ``mal_id``
    in the ``Animes`` table, and the result dictionaries are built from those
    Postgres rows.
    """

    def __init__(self,
                 collection_name: str = "anime_collection"):
        """Connect to Qdrant and record the approximate collection size.

        Args:
            collection_name: Name of the Qdrant collection to query.
        """
        self.client = QdrantClient(url=settings.qdrant_url,
                                   api_key=settings.qdrant_api_key,
                                   cloud_inference=True,
                                   prefer_grpc=True,
                                   timeout=10)
        self.collection_name = collection_name
        # exact=False asks Qdrant for an approximate count.
        self.points_count = self.client.count(
            collection_name=self.collection_name, exact=False).count
        # Only the model *name* is kept: it is passed along with every query
        # document, so no local SentenceTransformer instance is loaded here.
        self.model = "sentence-transformers/all-minilm-l6-v2"
        logger.info("Loaded collection with %s anime approximately",
                    self.points_count)

    def fetch_anime_batch_from_postgres(self, mal_ids: list[int]) -> dict[int, Animes]:
        """Fetch several ``Animes`` rows in one query, keyed by ``mal_id``.

        Args:
            mal_ids: The ``mal_id`` values to load.

        Returns:
            A dict mapping each ``mal_id`` that was found to its row. Ids
            with no matching row are simply absent from the dict.
        """
        if not mal_ids:
            # Nothing to look up: skip the database round trip entirely.
            return {}

        with Session(engine) as session:
            # .in_() acts like SQL's "WHERE mal_id IN (1, 2, 3)".
            statement = select(Animes).where(Animes.mal_id.in_(mal_ids))
            rows = session.exec(statement).all()
            return {anime.mal_id: anime for anime in rows}

    @staticmethod
    def _build_filter(
        genre_filter: list[str] | None,
        min_score: float | None,
        anime_type: str | None,
        min_scored_by: int | None,
    ) -> models.Filter | None:
        """Translate the optional search constraints into a Qdrant filter.

        Every active constraint is added as a ``must`` condition. Returns
        ``None`` when no constraint is active.
        """
        must_conditions = []

        if min_scored_by:
            must_conditions.append(
                models.FieldCondition(
                    key="scored_by",
                    range=models.Range(gte=min_scored_by),
                )
            )

        # Truthiness check on purpose: both None and 0 disable the filter.
        if min_score:
            logger.info("SCORE: Filtered based on min_score: %s", min_score)
            must_conditions.append(
                models.FieldCondition(
                    key="score",
                    range=models.Range(gte=min_score),
                )
            )

        if anime_type:
            logger.info(
                "ANIME TYPE: Filtered based on anime_type: %s", anime_type)
            must_conditions.append(
                models.FieldCondition(
                    key="type",
                    match=models.MatchValue(value=anime_type),
                )
            )

        if genre_filter:
            logger.info(
                "GENRE: Pre-filtering (OR) for genres: %s",
                ', '.join(genre_filter))
            # MatchAny matches a point when any of the given genres is present.
            must_conditions.append(
                models.FieldCondition(
                    key="genres",
                    match=models.MatchAny(any=genre_filter),
                )
            )

        return models.Filter(must=must_conditions) if must_conditions else None

    @staticmethod
    def _anime_to_dict(pg_anime: Animes) -> dict:
        """Copy the fields exposed by ``search`` from a Postgres row."""
        return {
            "mal_id": pg_anime.mal_id,
            "mal_url": pg_anime.url,
            "title": pg_anime.title,
            "title_english": pg_anime.title_english,
            "score": pg_anime.score,
            "scored_by": pg_anime.scored_by,
            "type": pg_anime.type,
            "year": pg_anime.year,
            "genres": pg_anime.genres,
            "studios": pg_anime.studios,
            "themes": pg_anime.themes,
            "demographics": pg_anime.demographics,
            "episodes": pg_anime.episodes,
            "popularity": pg_anime.popularity,
            "rating": pg_anime.rating,
            "aired_from": pg_anime.aired_from,
            "aired_to": pg_anime.aired_to,
            "favorites": pg_anime.favorites,
            "images": pg_anime.images,
            "synopsis": pg_anime.synopsis,
            "searchable_text": pg_anime.searchable_text,
        }

    def search(
        self,
        query: str,
        n_results: int = 5,
        genre_filter: list[str] | None = None,
        min_score: float | None = 6.0,
        anime_type: str | None = None,
        min_scored_by: int | None = 20000,
    ) -> list[dict]:
        """Search for anime similar to ``query``.

        Args:
            query: User search query.
            n_results: Maximum number of points requested from Qdrant.
            genre_filter: Optional genres; a hit must have at least one of
                them.
            min_score: Minimum MAL score (e.g., 7.0). ``None`` or ``0``
                disables the filter.
            anime_type: Type of anime (e.g. TV, Movie, etc).
            min_scored_by: Minimum value of the ``scored_by`` field.
                ``None`` or ``0`` disables the filter.

        Returns:
            List of dicts with anime info, in the order Qdrant returned the
            hits. Hits whose id has no Postgres row are logged and skipped,
            so the list may be shorter than ``n_results``.
        """
        query_filter = self._build_filter(
            genre_filter=genre_filter,
            min_score=min_score,
            anime_type=anime_type,
            min_scored_by=min_scored_by,
        )

        search_results = self.client.query_points(
            collection_name=self.collection_name,
            query=models.Document(
                text=query,
                model=self.model,
            ),
            query_filter=query_filter,
            limit=n_results,
        ).points

        if not search_results:
            return []

        # Point ids are used as mal_id values for the Postgres lookup.
        retrieved_ids = [hit.id for hit in search_results]
        postgres_data_map = self.fetch_anime_batch_from_postgres(retrieved_ids)

        anime_list = []
        for hit in search_results:
            pg_anime = postgres_data_map.get(hit.id)
            if not pg_anime:
                logger.warning(
                    "Anime ID %s found in Qdrant but missing in Postgres!",
                    hit.id)
                continue
            anime_list.append(self._anime_to_dict(pg_anime))
        return anime_list

    def get_by_title(self, title: str) -> dict | None:
        """Return the top search hit for ``title``, or ``None`` if none.

        NOTE: this runs ``search`` with ``title`` as the query text and the
        default filters, so the result is the best-ranked hit rather than a
        guaranteed exact or partial title match.
        """
        results = self.search(query=title, n_results=1)
        return results[0] if results else None
if __name__ == "__main__":
    # Manual smoke test: runs a couple of live queries and prints the hits.
    retriever = AnimeRetriever()

    print("=== Test 1: Basic Search ===")
    basic_hits = retriever.search("dark psychological anime", n_results=15)
    for entry in basic_hits:
        title, score = entry['title'], entry['score']
        print(f"- {title} (score: {score})")

    # Disabled example: OR-filter across several genres.
    # print("\n=== Test 2: Genre Filter ===")
    # results = retriever.search(
    #     "high school", n_results=30, genre_filter=["Fantasy", "Action", "Comedy", "Adventure"])
    # for anime in results:
    #     print(f"- {anime['title']} ({anime['genres']})")

    print("\n=== Test 3: Genre Filter ===")
    genre_hits = retriever.search(
        "Overpowered Main character",
        n_results=5,
        genre_filter=["Adventure"],
    )
    # Only the first hit is dumped in full, to inspect the complete record.
    # Compact alternative:
    # print(f"- {anime['title']} ({anime['genres']}) (Score: {anime['score']}) (Scored by: {anime['scored_by']})")
    if genre_hits:
        print(genre_hits[0])
# print("\n=== Test 3: Score Filter ===")
# results = retriever.search("adventure", n_results=5, min_score=9.0)
# for anime in results:
# print(f"- {anime['title']} (score: {anime['score']})")
# print("\n=== Test 4: Scored by Filter ===")
# results = retriever.search("adventure", n_results=5, min_score=8.0)
# for anime in results:
# print(
# f"- {anime['title']} (score: {anime['score']}) (scored_by: {anime['scored_by']})")
# print("\n=== Test 5: TYPE Filter ===")
# results = retriever.search(
# "Attack On Titan", n_results=5, anime_type="Special")
# for anime in results:
# print(
# f"- {anime['title']} (Anime Type: {anime['type']})")