Spaces:
Running
Running
| # import chromadb | |
| import logging | |
| import grpc | |
| from config import settings | |
| from qdrant_client import QdrantClient, models | |
| from sentence_transformers import SentenceTransformer | |
| from src.database.session import engine | |
| from src.database.models import Animes | |
| from sqlmodel import Session, select | |
# Import-time logging setup: configure the root logger (INFO level,
# timestamped format) — a no-op if the root logger already has handlers —
# and obtain this module's named logger.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
class AnimeRetriever:
    """Retrieve anime via Qdrant vector search, enriched with database rows.

    Candidate IDs come from a semantic query against a Qdrant collection.
    The query text is sent as a ``models.Document`` together with a model
    name, and the client is created with ``cloud_inference=True``. The full
    metadata for each hit is then loaded through the ``Animes`` model.
    """

    def __init__(self,
                 collection_name: str = "anime_collection",
                 model_name: str = "sentence-transformers/all-minilm-l6-v2"):
        """Connect to Qdrant and record the collection size.

        Args:
            collection_name: Qdrant collection holding the anime points.
            model_name: Embedding model name passed to Qdrant with each
                query document. No local SentenceTransformer is loaded.
        """
        self.client = QdrantClient(url=settings.qdrant_url,
                                   api_key=settings.qdrant_api_key,
                                   cloud_inference=True,
                                   prefer_grpc=True,
                                   timeout=10)
        self.collection_name = collection_name
        # exact=False asks Qdrant for an approximate (cheaper) count.
        self.points_count = self.client.count(
            collection_name=self.collection_name, exact=False).count
        self.model = model_name
        logger.info("Loaded collection with %s anime approximately",
                    self.points_count)

    def fetch_anime_batch_from_postgres(self, mal_ids: list[int]) -> dict[int, Animes]:
        """Load the ``Animes`` rows for ``mal_ids`` in a single query.

        Args:
            mal_ids: MAL IDs to look up. IDs without a matching row are
                simply absent from the result.

        Returns:
            Mapping of ``mal_id`` to its ``Animes`` row.
        """
        if not mal_ids:
            # An empty IN () can never match; skip the database round-trip.
            return {}
        with Session(engine) as session:
            # .in_() renders as SQL "WHERE mal_id IN (1, 2, 3)".
            statement = select(Animes).where(Animes.mal_id.in_(mal_ids))
            results = session.exec(statement).all()
            return {anime.mal_id: anime for anime in results}

    @staticmethod
    def _build_filter(
        genre_filter: list[str] | None,
        min_score: float | None,
        anime_type: str | None,
        min_scored_by: int | None,
    ) -> models.Filter | None:
        """Build the Qdrant payload filter for ``search``.

        Every argument is optional; a falsy value (None, 0, empty list)
        disables that condition. Returns None when no condition applies.
        """
        conditions: list[models.FieldCondition] = []
        if min_scored_by:
            # Require at least this many user ratings (payload "scored_by").
            conditions.append(
                models.FieldCondition(
                    key="scored_by",
                    range=models.Range(gte=min_scored_by)
                )
            )
        if min_score:
            logger.info("SCORE: Filtered based on min_score: %s", min_score)
            conditions.append(
                models.FieldCondition(
                    key="score",
                    range=models.Range(gte=min_score)
                )
            )
        if anime_type:
            logger.info("ANIME TYPE: Filtered based on anime_type: %s",
                        anime_type)
            conditions.append(
                models.FieldCondition(
                    key="type",
                    match=models.MatchValue(value=anime_type)
                )
            )
        if genre_filter:
            logger.info("GENRE: Pre-filtering (OR) for genres: %s",
                        ', '.join(genre_filter))
            # MatchAny matches when the "genres" payload contains ANY of the
            # given values, i.e. OR semantics across the requested genres.
            conditions.append(
                models.FieldCondition(
                    key="genres",
                    match=models.MatchAny(any=genre_filter)
                )
            )
        # All collected conditions must hold (AND across conditions).
        return models.Filter(must=conditions) if conditions else None

    @staticmethod
    def _to_result_dict(pg_anime: Animes) -> dict:
        """Flatten an ``Animes`` row into the dict shape ``search`` returns."""
        return {
            "mal_id": pg_anime.mal_id,
            "mal_url": pg_anime.url,
            "title": pg_anime.title,
            "title_english": pg_anime.title_english,
            "score": pg_anime.score,
            "scored_by": pg_anime.scored_by,
            "type": pg_anime.type,
            "year": pg_anime.year,
            "genres": pg_anime.genres,
            "studios": pg_anime.studios,
            "themes": pg_anime.themes,
            "demographics": pg_anime.demographics,
            "episodes": pg_anime.episodes,
            "popularity": pg_anime.popularity,
            "rating": pg_anime.rating,
            "aired_from": pg_anime.aired_from,
            "aired_to": pg_anime.aired_to,
            "favorites": pg_anime.favorites,
            "images": pg_anime.images,
            "synopsis": pg_anime.synopsis,
            "searchable_text": pg_anime.searchable_text,
        }

    def search(
        self,
        query: str,
        n_results: int = 5,
        genre_filter: list[str] | None = None,
        min_score: float | None = 6.0,
        anime_type: str | None = None,
        min_scored_by: int | None = 20000,
    ) -> list[dict]:
        """Search for anime semantically similar to ``query``.

        Args:
            query: User search query.
            n_results: Maximum number of hits requested from Qdrant.
            genre_filter: Optional list of genres; a hit must have at least
                one of them (OR).
            min_score: Minimum MAL score (e.g. 7.0). Falsy disables it.
            anime_type: Exact anime type (e.g. "TV", "Movie"). Falsy
                disables it.
            min_scored_by: Minimum number of MAL ratings. Falsy disables it.

        Returns:
            One dict per hit that also has a database row, in Qdrant's rank
            order. Hits missing from the database are logged and skipped,
            so fewer than ``n_results`` items may be returned.
        """
        query_filter = self._build_filter(
            genre_filter, min_score, anime_type, min_scored_by)

        search_results = self.client.query_points(
            collection_name=self.collection_name,
            query=models.Document(text=query, model=self.model),
            query_filter=query_filter,
            limit=n_results,
        ).points
        if not search_results:
            return []

        # One batched DB query for all hits instead of one per hit.
        rows_by_id = self.fetch_anime_batch_from_postgres(
            [hit.id for hit in search_results])

        anime_list = []
        # Iterate the Qdrant hits (not the dict) to keep similarity ranking.
        for hit in search_results:
            pg_anime = rows_by_id.get(hit.id)
            if pg_anime is None:
                logger.warning(
                    "Anime ID %s found in Qdrant but missing in Postgres!",
                    hit.id)
                continue
            anime_list.append(self._to_result_dict(pg_anime))
        return anime_list

    def get_by_title(self, title: str) -> dict | None:
        """Return the single best semantic match for ``title``.

        This is a vector search that uses ``title`` as the query, not a
        string match. It applies ``search``'s default filters (min_score
        6.0, at least 20000 ratings), so lower-rated or rarely rated titles
        cannot be returned. Returns None when nothing matches.
        """
        results = self.search(query=title, n_results=1)
        return results[0] if results else None
if __name__ == "__main__":
    # Manual smoke test. Needs the Qdrant instance configured in `settings`
    # and the database behind `engine` to be reachable.
    retriever = AnimeRetriever()

    print("=== Test 1: Basic Search ===")
    results = retriever.search("dark psychological anime", n_results=15)
    for anime in results:
        print(f"- {anime['title']} (score: {anime['score']})")

    # print("\n=== Test 2: Genre Filter ===")
    # results = retriever.search(
    #     "high school", n_results=30,
    #     genre_filter=["Fantasy", "Action", "Comedy", "Adventure"])
    # for anime in results:
    #     print(f"- {anime['title']} ({anime['genres']})")

    print("\n=== Test 3: Genre Filter ===")
    results = retriever.search(
        "Overpowered Main character", n_results=5, genre_filter=["Adventure"])
    # Dump only the top hit in full: each result carries the whole record
    # (synopsis, images, searchable_text, ...), so printing all is noisy.
    if results:
        print(results[0])
    # Compact alternative for all hits:
    # for anime in results:
    #     print(f"- {anime['title']} ({anime['genres']}) "
    #           f"(Score: {anime['score']}) (Scored by: {anime['scored_by']})")

    # print("\n=== Test 4: Score Filter ===")
    # results = retriever.search("adventure", n_results=5, min_score=9.0)
    # for anime in results:
    #     print(f"- {anime['title']} (score: {anime['score']})")

    # print("\n=== Test 5: Score Filter (8.0) ===")
    # results = retriever.search("adventure", n_results=5, min_score=8.0)
    # for anime in results:
    #     print(f"- {anime['title']} (score: {anime['score']}) "
    #           f"(scored_by: {anime['scored_by']})")

    # print("\n=== Test 6: TYPE Filter ===")
    # results = retriever.search(
    #     "Attack On Titan", n_results=5, anime_type="Special")
    # for anime in results:
    #     print(f"- {anime['title']} (Anime Type: {anime['type']})")