Pushkar02-n committed on
Commit
1d33c61
·
verified ·
1 Parent(s): d906298

Delete src/data_ingestion

Browse files
src/data_ingestion/clean_data.py DELETED
@@ -1,160 +0,0 @@
1
- import json
2
- import pandas as pd
3
- import logging
4
-
5
# Configure root logging once at import time, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
7
-
8
-
9
class AnimeDataCleaner:
    """Cleans and prepares anime data for embeddings.

    Every method is stateless, hence ``@staticmethod`` throughout.
    """

    @staticmethod
    def load_raw_data(filepath: str = "data/raw/raw_anime.json") -> list[dict]:
        """Load the raw anime JSON dump produced by the fetcher."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def clean_synopsis(synopsis: str) -> str:
        """Return a cleaned synopsis, or a placeholder when it is missing.

        Removes the "[Written by MAL Rewrite]" attribution footer and
        surrounding whitespace.
        """
        if not synopsis:
            return "No synopsis available."

        synopsis = synopsis.replace("[Written by MAL Rewrite]", "")
        # FIX: str.strip() returns a new string; the original discarded its
        # result, so surrounding whitespace was never actually removed.
        return synopsis.strip()

    @staticmethod
    def create_searchable_text(anime: dict) -> str:
        """
        Combine multiple fields into one searchable text.
        This is what we'll embed!

        Format: Title. Genres. Synopsis.
        """
        title = anime.get("title")
        title_en = anime.get("title_english")

        # Append the English title only when it adds information.
        title_text = title
        if title_en and title_en != title:
            title_text = f"{title} ({title_en})"

        genres = ", ".join(anime.get("genres", []))
        themes = ", ".join(anime.get("themes", []))
        demographics = ", ".join(anime.get("demographics", []))

        # ". ".join([]) is already "", so no fallback branch is needed.
        genre_parts = [p for p in [genres, themes, demographics] if p]
        genre_text = ". ".join(genre_parts)

        synopsis = AnimeDataCleaner.clean_synopsis(anime.get("synopsis", ""))

        searchable_text = f"{title_text}. {genre_text}. {synopsis}"
        return searchable_text.strip()

    @staticmethod
    def filter_valid_anime(anime_list: list[dict]) -> list[dict]:
        """Remove anime without a title or a usable synopsis (>= 50 chars)."""
        valid_anime = []

        for anime in anime_list:
            if not anime.get("title"):
                continue

            synopsis = anime.get("synopsis", "")
            # TODO: fetch a custom synopsis from online sources instead of
            # dropping short entries.
            if not synopsis or len(synopsis) < 50:
                continue

            valid_anime.append(anime)

        # FIX: the original emitted the same message via both print() and
        # logger.info(); keep only the logger (lazy %-args avoid formatting
        # when INFO is disabled).
        logger.info("Filtered %d -> %d animes",
                    len(anime_list), len(valid_anime))

        return valid_anime

    @staticmethod
    def prepare_for_embedding(anime_list: list[dict]) -> pd.DataFrame:
        """
        Prepare final dataset for embedding.

        Returns a dataframe with columns:
        - mal_id: unique identifier
        - searchable_text: what to embed
        - metadata: everything else (for filtering/display)
        """
        records = []

        for anime in anime_list:
            record = {
                "mal_id": anime["mal_id"],
                "url": anime.get("url"),
                "title": anime.get("title"),
                "title_english": anime.get("title_english"),
                "synopsis": AnimeDataCleaner.clean_synopsis(anime.get("synopsis", "")),

                # Keep these as native dicts/lists for Postgres JSONB!
                "images": anime.get("images", {}),
                "genres": anime.get("genres", []),
                "studios": anime.get("studios", []),
                "themes": anime.get("themes", []),
                "demographics": anime.get("demographics", []),

                "type": anime.get("type"),
                "episodes": anime.get("episodes"),
                "score": anime.get("score"),
                "scored_by": anime.get("scored_by"),
                "rank": anime.get("rank"),
                "popularity": anime.get("popularity"),
                "year": anime.get("year"),
                "season": anime.get("season"),
                "rating": anime.get("rating"),
                "aired_from": anime.get("aired_from"),
                "aired_to": anime.get("aired_to"),
                "favorites": anime.get("favorites"),

                # And our custom RAG field
                "searchable_text": AnimeDataCleaner.create_searchable_text(anime)
            }
            records.append(record)

        return pd.DataFrame(records)

    @staticmethod
    def save_processed_data(df: pd.DataFrame, filepath: str = "data/processed/anime_clean.csv"):
        """Persist the processed dataset as CSV plus a JSON twin."""
        df.to_csv(filepath, index=False, encoding="utf-8")
        logger.info("Saved %d anime to %s", len(df), filepath)

        json_path = filepath.replace(".csv", ".json")
        df.to_json(json_path, orient="records", indent=2, force_ascii=False)
        logger.info("Also saved to %s", json_path)
137
-
138
-
139
if __name__ == "__main__":
    # Script entry point: load raw data, filter, build the embedding
    # dataset, then persist and print a quick summary.
    cleaner = AnimeDataCleaner()

    print("Loading raw data....")
    raw = cleaner.load_raw_data("data/raw/raw_anime.json")

    kept = cleaner.filter_valid_anime(raw)

    print("\nPreparing data for embedding...")
    dataset = cleaner.prepare_for_embedding(kept)

    cleaner.save_processed_data(
        dataset, filepath="data/processed/anime_clean.csv")

    print("\nSample searchable text:")
    print(dataset.iloc[0]["searchable_text"][:500])

    print(f"\nDataset statistics:")
    print(f"Total anime: {len(dataset)}")
    print(
        f"Average text length: {dataset['searchable_text'].str.len().mean():.0f} chars")
    print(f"Score range: {dataset['score'].min():.1f} - {dataset['score'].max():.1f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/data_ingestion/create_embeddings.py DELETED
@@ -1,232 +0,0 @@
1
- from sentence_transformers import SentenceTransformer
2
- # import chromadb
3
- # from chromadb.config import Settings
4
- from qdrant_client import QdrantClient, models
5
- from sqlmodel import Session, select
6
- import logging
7
- from config import settings
8
-
9
- from src.database.session import engine
10
- from src.database.models import Animes
11
-
12
# Configure logging with timestamps before handing out this module's logger.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
15
-
16
-
17
class EmbeddingPipeline:
    """Creates sentence embeddings for anime records and stores them in Qdrant."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the embedding model and the Qdrant client.

        Args:
            model_name: HuggingFace model for embeddings.
                all-MiniLM-L6-v2: fast, good quality, 384 dims.
        """
        logger.info("Loading embedding model: %s", model_name)
        self.model = SentenceTransformer(model_name)
        self.vector_size = self.model.get_sentence_embedding_dimension() or 0

        self.client = QdrantClient(url=settings.qdrant_url,
                                   api_key=settings.qdrant_api_key,
                                   cloud_inference=True)

        # When True, the embedding step is skipped and stored vectors reused.
        self.use_existing_embeddings = False
        # FIX: the original printed "ChromaDB initialized at
        # data/embeddings/chroma_db" although the backend is Qdrant.
        logger.info("Qdrant client initialized")

    def create_or_get_collection(self, collection_name: str = "anime_collection"):
        """Create the collection, or interactively reuse/reset an existing one.

        Returns:
            The collection name (unchanged), for chaining into later steps.
        """
        if self.client.collection_exists(collection_name=collection_name):
            collection = self.client.get_collection(collection_name)
            # FIX: the original logged "Found existing collection" twice.
            logger.info(f"Found existing collection: {collection_name}")
            logger.info(f"Current count: {collection.points_count} points")

            user_input = input("Reset collection? (y/n): ")
            if user_input.lower() == "y":
                self.client.delete_collection(collection_name)
                logger.info("Collection reset")
            else:
                self.use_existing_embeddings = True
                return collection_name

        if not self.use_existing_embeddings:
            is_collection_created = self.client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=self.vector_size,
                    distance=models.Distance.COSINE
                ))
            logger.info(
                f"Created new collection: {collection_name}: {is_collection_created}")

        return collection_name

    def fetch_data_from_postgres(self, batch_size: int = 2000):
        """Fetch anime records from PostgreSQL in batches to avoid timeouts."""
        logger.info("Fetching data from PostgreSQL in batches...")
        all_results = []

        with Session(engine) as session:
            offset = 0
            while True:
                # order_by is strictly required when using offset/limit to
                # guarantee a stable order and hence no duplicates.
                statement = (
                    select(Animes)
                    .where(Animes.searchable_text != None)  # noqa: E711 (SQL expression)
                    .order_by(Animes.id)
                    .offset(offset)
                    .limit(batch_size)
                )

                batch = session.exec(statement).all()

                if not batch:
                    break  # No more rows returned -> done

                all_results.extend(batch)
                offset += len(batch)
                logger.info("Downloaded %d rows from Supabase so far...", offset)

        logger.info("Successfully fetched a total of %d records.",
                    len(all_results))
        return all_results

    def embed_texts(self, texts: list[str], batch_size: int = 32) -> list[list[float]] | None:
        """
        Create embeddings for texts.

        Args:
            texts: List of texts to embed.
            batch_size: Process in batches for efficiency.

        Returns:
            One vector per text, or None when stored embeddings are reused.
        """
        if self.use_existing_embeddings:
            logger.info("Using existing stored embeddings.")
            # FIX: return None explicitly instead of falling off the end.
            return None

        logger.info("Embedding %d texts...", len(texts))
        embeddings = self.model.encode(
            sentences=texts,
            batch_size=batch_size,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        return embeddings.tolist()

    def store_in_QdrantDB(self, client: QdrantClient, collection_name, db_records: list[Animes], final_texts: list[str], embeddings: list[list[float]]):
        """
        Store embeddings and metadata in Qdrant.

        Args:
            client: Qdrant client (kept for interface compatibility; the
                instance's own client is used for the upsert).
            collection_name: Qdrant collection name.
            db_records: Anime rows retrieved from the PostgreSQL database.
            final_texts: Texts that were embedded, stored in the payload.
            embeddings: Pre-computed embeddings, parallel to db_records.
        """
        logger.info("Storing in QdrantDB...")

        points = []
        for i, row in enumerate(db_records):
            genres_list = row.genres if isinstance(row.genres, list) else []
            if len(genres_list) == 0:
                genres_list = ["Unknown"]

            # Qdrant uses PointStruct holding ID, vector and payload
            # (metadata + document).
            point = models.PointStruct(
                # Qdrant requires IDs to be integers or UUIDs
                id=int(row.mal_id),
                vector=embeddings[i],
                payload={
                    # Store the text here since Qdrant doesn't separate
                    # documents from metadata.
                    "document": final_texts[i],
                    "title": row.title,
                    "genres": genres_list,
                    "score": float(row.score) if row.score else 0.0,
                    "type": row.type if row.type else "Unknown",
                    "scored_by": row.scored_by if row.scored_by else 0
                }
            )
            points.append(point)

        chunk_size = 500
        # FIX: ceiling division; the original `len // chunk + 1` reported one
        # extra batch whenever len(points) was an exact multiple of chunk_size.
        total_chunks = (len(points) + chunk_size - 1) // chunk_size
        logger.info("Inserting into Qdrant in %d batches...", total_chunks)

        for i in range(0, len(points), chunk_size):
            batch = points[i: i + chunk_size]
            self.client.upsert(
                collection_name=collection_name,
                points=batch
            )
            logger.info("Inserted batch %d/%d", (i // chunk_size) + 1, total_chunks)

        logger.info("Successfully stored %d animes in Qdrant", len(points))

    def run_pipeline(self):
        """Run the complete embedding pipeline and return the collection name."""
        # 1. Fetch from DB instead of CSV
        db_records = self.fetch_data_from_postgres()
        logger.info("Loaded %d animes from Postgres", len(db_records))

        collection_name = self.create_or_get_collection()

        if not self.use_existing_embeddings:
            texts_to_embed = []
            for row in db_records:
                text = row.searchable_text if row.searchable_text else ""
                # Enrich the embedded text with studio names when available.
                if hasattr(row, 'studios') and row.studios:
                    text += f" Studio: {', '.join(row.studios)}"
                texts_to_embed.append(text)

            print(texts_to_embed[0])
            embeddings = self.embed_texts(texts_to_embed)

            if embeddings:
                self.store_in_QdrantDB(
                    self.client, collection_name, db_records, texts_to_embed, embeddings)

        logger.info("Embedding pipeline complete!")
        return collection_name
200
-
201
-
202
if __name__ == "__main__":
    # Entry point deliberately disabled for now. To run the full pipeline:
    #   pipeline = EmbeddingPipeline()
    #   collection_name = pipeline.run_pipeline()
    # A manual vector-search smoke test against the Qdrant collection
    # previously lived here; it was removed because it referenced
    # undefined variables (query_vector, limit, results).
    pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/data_ingestion/fetch_anime.py DELETED
@@ -1,151 +0,0 @@
1
- import requests
2
- import time
3
- import json
4
- import logging
5
- from datetime import datetime
6
-
7
-
8
- def convert_datetime(dt: str | None):
9
- if not dt:
10
- return None
11
- return datetime.fromisoformat(dt)
12
-
13
-
14
# Set up timestamped logging, then create the module-level logger.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
17
-
18
-
19
class AnimeDataFetcher:
    """Fetches anime data from Jikan API (Unofficial MyAnimeList API)"""

    BASE_URL = "https://api.jikan.moe/v4/"

    def __init__(self):
        # Reuse one HTTP session for connection pooling across pages.
        self.session = requests.Session()

    def fetch_bulk_anime(self, total_limit: int = 10000, filename: str = "raw_anime.json"):
        """
        Fetch anime in bulk with resume and retry capabilities.

        Args:
            total_limit: Maximum number of records to accumulate.
            filename: File name under data/raw/ used both for resuming and
                for the auto-save after each page.

        Returns:
            The accumulated list of raw anime dicts (at most total_limit).
        """
        # FIX: the path was a garbled literal that ignored `filename`;
        # build it from the argument so resume and save use the same file.
        filepath = f"data/raw/{filename}"
        all_animes = []

        # --- RESUME LOGIC ---
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                all_animes = json.load(f)
            logger.info(
                f"Found existing data. Resuming from record {len(all_animes)}.")
        except (FileNotFoundError, json.JSONDecodeError):
            logger.info("No existing data found. Starting fresh.")

        # Calculate the next page to fetch (25 items per page)
        page = (len(all_animes) // 25) + 1
        max_retries = 5

        while len(all_animes) < total_limit:
            retries = 0
            success = False

            while retries < max_retries and not success:
                try:
                    logger.info(
                        f"🚀 Fetching page {page} (Progress: {len(all_animes)}/{total_limit})...")
                    response = self.session.get(
                        f"{self.BASE_URL}anime",
                        params={
                            "page": page,
                            "limit": 25,
                            "order_by": "popularity",
                            "sort": "asc"
                        },
                        timeout=30
                    )

                    if response.status_code == 429:
                        wait = 60 + (retries * 30)  # Increasing wait time
                        logger.warning(f"⚠️ Rate limit! Sleeping {wait}s...")
                        time.sleep(wait)
                        retries += 1
                        continue

                    response.raise_for_status()
                    data = response.json()
                    anime_list = data.get("data", [])

                    if not anime_list:
                        logger.info("🏁 No more anime found in API.")
                        return all_animes

                    all_animes.extend(anime_list)

                    # --- AUTO-SAVE EVERY PAGE ---
                    # Save often so we never lose more than 1 page of work.
                    self.save_raw_data(all_animes, filename=filename)

                    success = True
                    page += 1
                    time.sleep(1.2)  # Polite delay for the public API

                except Exception as e:
                    retries += 1
                    logger.error(
                        f"❌ Error on page {page}: {e}. Retry {retries}/{max_retries}")
                    time.sleep(10 * retries)

            if not success:
                logger.critical(
                    f"🛑 Giving up on page {page}. Run script again later to resume.")
                break

        return all_animes[:total_limit]

    def extract_relevant_fields(self, anime: dict) -> dict:
        """Extract only the fields we need for RAG from a raw API record."""
        return {
            "mal_id": anime.get("mal_id"),
            "url": anime.get("url", ""),
            "images": anime.get("images", {}),
            "title": anime.get("title"),
            "title_english": anime.get("title_english"),
            "synopsis": anime.get("synopsis"),
            "genres": [g["name"] for g in anime.get("genres", [])],
            "studios": [s["name"] for s in anime.get("studios", [])],
            "themes": [t["name"] for t in anime.get("themes", [])],
            "demographics": [d["name"] for d in anime.get("demographics", [])],
            "type": anime.get("type"),
            "episodes": anime.get("episodes"),
            "score": anime.get("score"),
            "scored_by": anime.get("scored_by"),
            "rank": anime.get("rank"),
            "popularity": anime.get("popularity"),
            "year": anime.get("year"),
            "rating": anime.get("rating"),
            "season": anime.get("season"),
            "aired_from": anime.get("aired", {}).get("from", ""),
            "aired_to": anime.get("aired", {}).get("to", ""),
            "favorites": anime.get("favorites")
        }

    def save_raw_data(self, anime_list: list[dict], filename: str = "raw_anime.json"):
        """Save raw anime data under data/raw/ as pretty-printed JSON."""
        # FIX: build the path from `filename` (previously a garbled literal
        # that ignored the parameter).
        filepath = f"data/raw/{filename}"
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(anime_list, f, indent=2, ensure_ascii=False)

        print(f"Saved {len(anime_list)} anime to {filepath}")
137
-
138
-
139
if __name__ == "__main__":
    # Script entry point: bulk-fetch, slim down to the relevant fields,
    # persist, and show one sample record.
    fetcher = AnimeDataFetcher()

    logger.info("Fetching top 10000 anime from MyAnimeList...")
    fetched = fetcher.fetch_bulk_anime(total_limit=10000)

    slimmed = [fetcher.extract_relevant_fields(item) for item in fetched]

    fetcher.save_raw_data(slimmed, filename="raw_anime.json")

    print("\nSample anime: ")
    print(json.dumps(slimmed[0], indent=2, ensure_ascii=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/data_ingestion/load_to_postgres.py DELETED
@@ -1,53 +0,0 @@
1
- import json
2
- import logging
3
- from sqlmodel import Session, select
4
- from src.database.session import engine
5
- from src.database.models import Animes
6
- from src.database import init_db
7
-
8
# Project-level database setup runs at import time — presumably creates the
# schema/tables; confirm against src.database.init_db.
init_db()

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
12
-
13
-
14
def load_json_data(filepath: str = "data/processed/anime_clean.json") -> list[dict]:
    """Read the cleaned anime records from a JSON file."""
    with open(filepath, encoding='utf-8') as fh:
        return json.load(fh)
18
-
19
-
20
def insert_animes_to_db(anime_list: list[dict]):
    """
    Insert a list of anime dictionaries into PostgreSQL safely.

    Existing rows (matched by mal_id) are skipped; all inserts are
    committed in a single transaction at the end.
    """
    inserted_count = 0
    skipped_count = 0
    with Session(engine) as session:
        for data in anime_list:
            try:
                existing = session.exec(
                    select(Animes).where(Animes.mal_id == data["mal_id"])
                ).first()

                if not existing:
                    new_anime = Animes(**data)
                    session.add(new_anime)
                    inserted_count += 1
                    print(f"Inserted_count: {inserted_count}/{len(anime_list)} animes")
                else:
                    skipped_count += 1

            except Exception as e:
                # FIX: the original f-string nested double quotes
                # ({data.get("mal_id")}) inside a double-quoted literal —
                # a SyntaxError on Python < 3.12. Lazy %-args also avoid
                # formatting when the record is fine.
                logger.error(
                    "Error processing anime ID: %s: %s", data.get('mal_id'), e)

        session.commit()
    logger.info(
        "Injection complete! Inserted: %d | Skipped (Duplicates): %d",
        inserted_count, skipped_count)
48
-
49
-
50
if __name__ == "__main__":
    # Load the cleaned dataset and push it into PostgreSQL.
    records = load_json_data(filepath="data/processed/anime_clean.json")
    insert_animes_to_db(records)