Spaces:
Sleeping
Sleeping
| import pandas as pd # Import pandas for data manipulation, aliased as pd | |
| import numpy as np # Import numpy for numerical operations, aliased as np | |
| import faiss # Import faiss for efficient similarity search | |
| import json # Import json for working with JSON data | |
| from typing import List, Dict, Optional # Import typing hints for better code readability and static analysis | |
| from datetime import datetime # Import datetime for handling date and time objects | |
| import logging # Import logging module for application logging | |
| from sentence_transformers import SentenceTransformer, CrossEncoder, util # Import specific classes from sentence_transformers library | |
| from indicnlp.tokenize import indic_tokenize # Import tokenizer for Indic languages | |
| from indicnlp.normalize.indic_normalize import IndicNormalizerFactory # Import normalizer factory for Indic languages | |
| import torch # Import PyTorch for deep learning functionalities | |
| import os # Import os module for interacting with the operating system (e.g., file paths) | |
| from indicnlp import common # Import common utilities from the indicnlp library | |
| import pickle # Import pickle for saving/loading Python objects (like lists of IDs) | |
| import re # Import re module for regular expression operations | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Import classes for loading pre-trained models from Hugging Face Transformers | |
| from fastapi import HTTPException # Import HTTPException for raising HTTP errors in FastAPI applications | |
| import asyncio # Import asyncio for asynchronous operations | |
| from src.fine_tuning.trainer import ModelTrainer | |
| from src.fine_tuning.config import MODEL_STATUS | |
| from src.config.settings import ( # Import configuration constants from the settings file | |
| EMBED_MODEL_NAME, GENERATOR_MODEL_NAME, RERANKER_MODEL_NAME, # Model names | |
| INDEX_PATH, INTERACTION_LOG_PATH, INDIC_NLP_RESOURCES_PATH, MONGO_FAISS_META_COLLECTION_NAME, # File paths & DB names | |
| HEADLINE_COL, SEOLOCATION_COL, DEEPLINK_COL, LAST_UPDATED_COL, FINE_TUNED_RERANKER_SAVE_PATH, # Column names & paths | |
| IMAGE_ID_COL, IMAGE_RATIO_COL, IMAGE_SIZE_COL, TAXONOMY_COL, INDEX_IDS_PATH, # Column names | |
| SYN_COL, KEY_COL, ID_COL, TOPIC_COL, PROPERTY_COL, # Existing Column names | |
| DEFAULT_K, SIMILARITY_THRESHOLD, CANDIDATE_MULTIPLIER # Recommendation parameters | |
| ) | |
| from src.database.mongodb import mongodb # Import the MongoDB client instance | |
| logger = logging.getLogger(__name__) # Initialize a logger instance for this module | |
| class RecoRecommender: | |
| """ | |
| A RAG-based recommender system for multi language content using FAISS retrieval. | |
| """ | |
| FAISS_IDS_DOC_ID = "faiss_indexed_document_ids_v1" # Document ID for storing indexed IDs in MongoDB | |
| MODEL_METADATA_DOC_ID = "model_metadata" | |
| EMBEDDING_CHECKSUM_KEY = "embedding_checksum" | |
| MODEL_VERSION_KEY = "model_version" | |
| FINE_TUNING_CONFIG_KEY = "fine_tuning_config" | |
| EMBEDDING_STATUS_KEY = "embedding_status" | |
| MODEL_STATUS = MODEL_STATUS # Add MODEL_STATUS as class attribute | |
def __init__(self):
    """Create a recommender holding configuration and empty runtime state.

    Copies column names from ``src.config.settings``, selects a compute device,
    binds the MongoDB FAISS-metadata collection (``None`` when MongoDB is not
    connected), creates the reranker ``ModelTrainer`` and sets up Indic NLP.
    Models, data and the FAISS index are NOT loaded here.

    Raises:
        FileNotFoundError: if the Indic NLP resources directory is missing
            (raised by ``_setup_indic_nlp``).
    """
    logger.info("Initializing RecoRecommender...")
    # Configuration: column names shared by the MongoDB projection and the DataFrame.
    self.headline_col = HEADLINE_COL
    self.key = KEY_COL
    self.syn = SYN_COL
    self.id_col = ID_COL
    self.topic_col = TOPIC_COL
    self.property_col = PROPERTY_COL
    self.taxonomy_col = TAXONOMY_COL
    self.seolocation_col = SEOLOCATION_COL
    self.deeplink_col = DEEPLINK_COL
    self.last_updated_col = LAST_UPDATED_COL
    self.image_id_col = IMAGE_ID_COL
    self.image_ratio_col = IMAGE_RATIO_COL
    self.image_size_col = IMAGE_SIZE_COL
    # Derived column holding the combined, preprocessed text that gets embedded.
    self.processed_content_col = "Processed Content"
    # State variables, populated later by the loading/indexing methods.
    self.df: Optional[pd.DataFrame] = None
    self.index: Optional[faiss.Index] = None
    self.embed_model: Optional[SentenceTransformer] = None
    self.base_reranker: Optional[CrossEncoder] = None
    self.fine_tuned_reranker: Optional[CrossEncoder] = None
    # Currently selected reranker (base or fine-tuned). Declared here so the attribute
    # exists (as None) before load_models() assigns it, instead of raising AttributeError.
    self.reranker: Optional[CrossEncoder] = None
    self.indexed_ids: List[str] = []
    self.tokenizer: Optional[AutoTokenizer] = None
    self.generator: Optional[AutoModelForSeq2SeqLM] = None
    self.normalizer = None
    self.device = self._detect_compute_device()
    logger.info(f"Selected device: {self.device}")
    # MongoDB may be unavailable at startup; the collection is then left as None.
    if mongodb.db is not None:
        self.faiss_meta_collection = mongodb.db[MONGO_FAISS_META_COLLECTION_NAME]
    else:
        self.faiss_meta_collection = None
        logger.warning("MongoDB not available. Some features may be limited.")
    # The trainer receives the same device so fine-tuning runs where inference does.
    self.model_trainer = ModelTrainer(RERANKER_MODEL_NAME, device=self.device)
    self._setup_indic_nlp()
    logger.info("RecoRecommender initialized.")


def _detect_compute_device(self) -> str:
    """Return "cuda", "xpu" or "cpu": CUDA first, then Intel XPU via IPEX, else CPU."""
    logger.info("Determining compute device...")
    if torch.cuda.is_available():
        logger.info("CUDA is available. Using NVIDIA GPU.")
        return "cuda"
    try:
        # Imported only for its side effects; the module object itself is never used.
        import intel_extension_for_pytorch  # noqa: F401
        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            logger.info("Intel XPU is available via IPEX. Using Intel GPU.")
            return "xpu"
        logger.info("CUDA not available. Intel XPU not available or IPEX not fully configured. Using CPU.")
        return "cpu"
    except ImportError:
        logger.info("CUDA not available. Intel Extension for PyTorch (IPEX) not found. Using CPU.")
        return "cpu"
    except Exception as e:  # Any other failure while probing the XPU runtime.
        logger.error(f"Error during Intel XPU check: {e}. Using CPU.")
        return "cpu"
def _setup_indic_nlp(self):
    """Point indicnlp at its resource directory and create the Hindi ("hi") normalizer.

    The resource path is exported both as the ``INDIC_RESOURCES_PATH`` environment
    variable and through ``indicnlp.common.set_resources_path``.

    Raises:
        FileNotFoundError: if ``INDIC_NLP_RESOURCES_PATH`` does not exist.
        Exception: any error from the indicnlp setup is logged with traceback and re-raised.
    """
    resources_dir = INDIC_NLP_RESOURCES_PATH
    logger.info("Setting up Indic NLP resources...")
    if not os.path.exists(resources_dir):
        raise FileNotFoundError(f"Indic NLP resources not found at {resources_dir}")
    os.environ["INDIC_RESOURCES_PATH"] = resources_dir
    try:
        common.set_resources_path(resources_dir)
        self.normalizer = IndicNormalizerFactory().get_normalizer("hi")
    except Exception as e:
        logger.error(f"Error setting up Indic NLP resources: {e}", exc_info=True)
        raise  # Setup failure is fatal for the recommender.
    else:
        logger.info("Indic NLP resources setup complete.")
def load_models(self):
    """Load the embedding model, the generator (tokenizer + seq2seq LM) and the base reranker.

    Each model is loaded only when it is not already present. The fine-tuned
    reranker check (``_load_fine_tuned_model``) runs on every call and decides
    which reranker ends up in ``self.reranker``.

    Raises:
        Exception: any loading error is logged with traceback and re-raised.
    """
    logger.info("Loading ML models...")
    try:
        need_embedder = self.embed_model is None
        need_generator = self.tokenizer is None or self.generator is None
        need_base_reranker = self.base_reranker is None

        if need_embedder:
            logger.info(f"Loading embedding model: {EMBED_MODEL_NAME}")
            self.embed_model = SentenceTransformer(EMBED_MODEL_NAME, device=self.device)

        if need_generator:
            logger.info(f"Loading generator model: {GENERATOR_MODEL_NAME}")
            self.tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_NAME, model_max_length=1024)
            self.generator = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_MODEL_NAME)
            # NOTE(review): only CUDA moves the generator; with device "xpu" it stays where
            # from_pretrained put it while the embedder/reranker get device=self.device.
            # Confirm this matches how generation inputs are placed elsewhere.
            if self.device == "cuda":
                self.generator = self.generator.to(self.device)

        if need_base_reranker:
            logger.info(f"Loading base reranker model: {RERANKER_MODEL_NAME}")
            self.base_reranker = CrossEncoder(RERANKER_MODEL_NAME, device=self.device)
            self.reranker = self.base_reranker  # Base model is the default choice.

        # May switch self.reranker to the fine-tuned model if one is available.
        self._load_fine_tuned_model()
        logger.info("ML models loaded.")
    except Exception as e:
        logger.error(f"Error loading ML models: {e}", exc_info=True)
        raise
def _load_fine_tuned_model(self):
    """Activate the fine-tuned reranker when ModelTrainer reports one on disk.

    If ``self.model_trainer`` reports status FINE_TUNED and the path for its current
    version exists, that path is loaded as a CrossEncoder into
    ``self.fine_tuned_reranker`` and made ``self.reranker``. In every other case,
    including any exception, ``self.reranker`` is reset to ``self.base_reranker``.
    """
    try:
        trainer_meta = self.model_trainer.get_model_status()
        fine_tuned_path = None
        if trainer_meta.get("current_model_status") == MODEL_STATUS["FINE_TUNED"]:
            version = trainer_meta.get("current_version", "v0")
            fine_tuned_path = str(self.model_trainer.get_model_path(version))

        if fine_tuned_path is None or not os.path.exists(fine_tuned_path):
            logger.info("No fine-tuned model found or not in fine-tuned state, using base model")
            self.reranker = self.base_reranker
            return

        logger.info(f"Loading fine-tuned reranker model version {version}")
        self.fine_tuned_reranker = CrossEncoder(fine_tuned_path, device=self.device)
        self.reranker = self.fine_tuned_reranker
        logger.info("Successfully loaded fine-tuned reranker model")
    except Exception as e:
        logger.error(f"Error loading fine-tuned model: {e}. Falling back to base model.")
        self.reranker = self.base_reranker
def _calculate_embedding_checksum(self, content: str) -> str:
    """Return an MD5 hex digest identifying *content* as embedded by the current model.

    ``EMBED_MODEL_NAME`` is folded into the hashed text, so changing the embedding
    model changes every checksum and forces re-embedding. The digest is used only
    for change detection, never for security.
    """
    import hashlib
    checksum_content = f"{content}_{EMBED_MODEL_NAME}"
    # usedforsecurity=False (Python 3.9+) keeps md5 usable on FIPS-enabled OpenSSL
    # builds, where a plain md5() call raises ValueError. The digest is unchanged.
    return hashlib.md5(checksum_content.encode("utf-8"), usedforsecurity=False).hexdigest()
def _get_model_metadata(self) -> Dict:
    """Fetch this recommender's metadata document from MongoDB.

    Returns the stored document when one exists. Otherwise it returns a minimal
    default (``_id`` plus ``embedding_model_name``) when MongoDB is unavailable, the
    document is missing/empty, or the lookup raises. Reranker status, version and
    fine-tuning configuration are owned by ModelTrainer, not by this document.
    """
    def default_metadata() -> Dict:
        return {"_id": self.MODEL_METADATA_DOC_ID, "embedding_model_name": EMBED_MODEL_NAME}

    try:
        collection = self.faiss_meta_collection
        if collection is None:
            logger.warning("MongoDB not available. Returning default metadata.")
            return default_metadata()
        stored = collection.find_one({"_id": self.MODEL_METADATA_DOC_ID})
        return stored if stored else default_metadata()
    except Exception as e:
        logger.error(f"Error retrieving model metadata: {e}")
        return default_metadata()
| def _update_model_metadata(self, updates: Dict) -> bool: | |
| """Update model metadata in MongoDB.""" | |
| try: | |
| if self.faiss_meta_collection is None: | |
| logger.warning("MongoDB not available. Cannot update model metadata.") | |
| return False | |
| result = self.faiss_meta_collection.update_one( | |
| {"_id": self.MODEL_METADATA_DOC_ID}, | |
| {"$set": {**updates, "metadata_last_updated": datetime.now()}}, # Key changed for clarity | |
| upsert=True | |
| ) | |
| return result.acknowledged | |
| except Exception as e: | |
| logger.error(f"Error updating model metadata: {e}") | |
| return False | |
def _increment_model_version(self) -> str:
    """Deprecated shim: report ModelTrainer's current reranker version; increments nothing.

    Reranker versions are owned by ModelTrainer, so this only logs a deprecation
    warning and forwards the version ModelTrainer reports.

    Returns:
        ``current_version`` from ``self.model_trainer.get_model_status()``, or "v0"
        when that key is absent.
    """
    logger.warning("RecoRecommender._increment_model_version() is deprecated. Reranker versioning is managed by ModelTrainer.")
    trainer_status = self.model_trainer.get_model_status()
    current_version = trainer_status.get("current_version", "v0")
    return current_version
| def _needs_reembedding_batch(self, doc_ids: List[str], current_checksum: str) -> List[str]: | |
| """Check which documents from a batch need reembedding.""" | |
| try: | |
| if self.faiss_meta_collection is None: | |
| logger.warning("MongoDB not available. Assuming all documents need reembedding.") | |
| return doc_ids | |
| # Query for all documents in one go | |
| metadata_docs = self.faiss_meta_collection.find( | |
| {"_id": {"$in": doc_ids}} | |
| ) | |
| # Create a dict for quick lookup | |
| metadata_map = {doc["_id"]: doc.get(self.EMBEDDING_CHECKSUM_KEY) for doc in metadata_docs} | |
| # Return IDs that need reembedding | |
| return [ | |
| doc_id for doc_id in doc_ids | |
| if doc_id not in metadata_map or metadata_map[doc_id] != current_checksum | |
| ] | |
| except Exception as e: | |
| logger.error(f"Error checking embedding status for batch: {e}") | |
| return doc_ids # If error, assume all need reembedding | |
| def _update_embedding_metadata(self, doc_id: str, checksum: str): | |
| """Update metadata for document embeddings.""" | |
| try: | |
| self.faiss_meta_collection.update_one( | |
| {"_id": doc_id}, | |
| { | |
| "$set": { | |
| self.EMBEDDING_CHECKSUM_KEY: checksum, | |
| "last_updated": datetime.now() | |
| } | |
| }, | |
| upsert=True | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error updating embedding metadata for doc {doc_id}: {e}") | |
| def _needs_reembedding(self, doc_id: str, current_checksum: str) -> bool: | |
| """Check if a document needs to be reembedded.""" | |
| try: | |
| metadata = self.faiss_meta_collection.find_one({"_id": doc_id}) | |
| if not metadata or metadata.get(self.EMBEDDING_CHECKSUM_KEY) != current_checksum: | |
| return True | |
| return False | |
| except Exception as e: | |
| logger.error(f"Error checking embedding status for doc {doc_id}: {e}") | |
| return True | |
async def update_embeddings_and_index(self, force_reload_data: bool = True):
    """Synchronize the FAISS index with the current document set.

    A full rebuild via ``build_indexes_and_save(data_already_loaded=True)`` happens when
    any of these holds:
    - there is no index yet;
    - the index's vector count differs from ``len(self.indexed_ids)``;
    - a previously indexed ID is missing from the current data;
    - an indexed document's stored embedding checksum no longer matches its current
      content/model.

    Otherwise only documents not yet indexed are embedded and appended. The index is
    then written to INDEX_PATH and, when MongoDB is available, the indexed-ID list
    and the per-document checksums are persisted.

    Args:
        force_reload_data: reload ``self.df`` from MongoDB first (default True).

    Note:
        Declared ``async`` but contains no ``await``: MongoDB queries, embedding and
        FAISS I/O all run synchronously and block the event loop while they run.
    """
    if force_reload_data:
        logger.info("Reloading data from MongoDB for index update...")
        self._load_data_from_mongo()
    if self.df is None or self.df.empty:
        logger.warning("DataFrame is empty. Cannot update embeddings or index.")
        return
    if self.embed_model is None:
        logger.info("Embedding model not loaded. Loading models...")
        self.load_models()

    logger.info("Checking for documents requiring embedding updates or additions...")
    if self._index_requires_full_rebuild():
        logger.info("Triggering a full rebuild of the FAISS index.")
        self.build_indexes_and_save(data_already_loaded=True)  # self.df is already loaded
        return

    logger.info("No changes detected in existing indexed documents that require a full rebuild.")
    self._append_new_documents_to_index()


def _index_requires_full_rebuild(self) -> bool:
    """Return True when the current FAISS index cannot safely be updated incrementally."""
    if self.index is None:
        logger.info("No existing FAISS index found. A full build is required.")
        return True
    # Vectors are appended in the same order IDs are appended to self.indexed_ids, so
    # position i pairs vector i with ID i; if the counts disagree that pairing is broken.
    if self.index.ntotal != len(self.indexed_ids):
        logger.info(
            f"FAISS index holds {self.index.ntotal} vectors but {len(self.indexed_ids)} "
            "document IDs are tracked. Rebuild needed."
        )
        return True
    if not self.indexed_ids:
        return False

    indexed_rows = self.df[self.df[self.id_col].isin(self.indexed_ids)]
    content_by_id = dict(zip(indexed_rows[self.id_col], indexed_rows[self.processed_content_col]))
    for doc_id in self.indexed_ids:
        if doc_id not in content_by_id:
            logger.info(f"Document ID {doc_id} was in index but not in current data (deleted). Rebuild needed.")
            return True

    expected_checksums = {
        doc_id: self._calculate_embedding_checksum(content_by_id[doc_id])
        for doc_id in self.indexed_ids
    }
    stale_ids = self._find_ids_with_stale_checksums(expected_checksums)
    if stale_ids:
        logger.info(f"Existing document ID {stale_ids[0]} requires re-embedding. Content/model change detected.")
        return True
    return False


def _find_ids_with_stale_checksums(self, expected_checksums: Dict[str, str], chunk_size: int = 1000) -> List[str]:
    """Return IDs whose stored checksum is missing or differs from *expected_checksums*.

    Uses chunked ``$in`` queries instead of one round trip per document. Fails open:
    every ID is returned if MongoDB is unavailable or a query raises.
    """
    doc_ids = list(expected_checksums)
    if self.faiss_meta_collection is None:
        logger.warning("MongoDB not available. Assuming all documents need reembedding.")
        return doc_ids
    stale_ids: List[str] = []
    try:
        for start in range(0, len(doc_ids), chunk_size):
            chunk = doc_ids[start:start + chunk_size]
            cursor = self.faiss_meta_collection.find(
                {"_id": {"$in": chunk}}, {self.EMBEDDING_CHECKSUM_KEY: 1}
            )
            stored = {doc["_id"]: doc.get(self.EMBEDDING_CHECKSUM_KEY) for doc in cursor}
            stale_ids.extend(doc_id for doc_id in chunk if stored.get(doc_id) != expected_checksums[doc_id])
    except Exception as e:
        logger.error(f"Error checking embedding status for batch: {e}")
        return doc_ids
    return stale_ids


def _append_new_documents_to_index(self) -> None:
    """Embed rows whose IDs are not yet indexed, append them, and persist index + metadata."""
    new_doc_ids_to_add = list(set(self.df[self.id_col].tolist()) - set(self.indexed_ids))
    if not new_doc_ids_to_add:
        logger.info("No new documents to add. Index is up-to-date with the current DataFrame.")
        return

    logger.info(f"Found {len(new_doc_ids_to_add)} new documents to add to the index.")
    new_docs_df = self.df[self.df[self.id_col].isin(new_doc_ids_to_add)].copy()
    if new_docs_df.empty:
        logger.info("Identified new document IDs, but the corresponding DataFrame slice was empty.")
        return

    new_embeddings, actual_new_ids_embedded = self._generate_embeddings(new_docs_df)
    if new_embeddings.size == 0:
        logger.info("No embeddings were generated for the new documents (e.g., content was empty).")
        return

    self.index.add(new_embeddings.astype(np.float32))
    self.indexed_ids.extend(actual_new_ids_embedded)
    logger.info(f"Added {len(actual_new_ids_embedded)} new vectors to FAISS index. Total vectors: {self.index.ntotal}.")
    faiss.write_index(self.index, INDEX_PATH)

    if self.faiss_meta_collection is None:
        logger.warning("MongoDB not available. Indexed ID list and embedding metadata were not persisted.")
        return
    self.faiss_meta_collection.update_one(
        {"_id": self.FAISS_IDS_DOC_ID},
        {"$set": {"ids": self.indexed_ids, "last_updated": datetime.now(), "total_vectors": self.index.ntotal}},
        upsert=True,
    )
    logger.info("Updated FAISS index and list of indexed IDs saved.")

    # One pass to map ID -> content (first occurrence wins) instead of re-filtering
    # the DataFrame once per ID.
    content_by_id: Dict[str, str] = {}
    for doc_id, content in zip(new_docs_df[self.id_col], new_docs_df[self.processed_content_col]):
        content_by_id.setdefault(doc_id, content)
    for doc_id in actual_new_ids_embedded:
        self._update_embedding_metadata(doc_id, self._calculate_embedding_checksum(content_by_id[doc_id]))
    logger.info(f"Updated embedding metadata for {len(actual_new_ids_embedded)} new documents.")
def fallback_to_base_model(self):
    """Make the base (non-fine-tuned) reranker active, if it has been loaded."""
    base = self.base_reranker
    if base is None:
        logger.warning("Base reranker model not available")
        return
    self.reranker = base
    logger.info("Switched to base reranker model")
def use_fine_tuned_model(self):
    """Make the fine-tuned reranker active; keep the current one if none is loaded."""
    tuned = self.fine_tuned_reranker
    if tuned is None:
        logger.warning("Fine-tuned model not available, staying with current model")
        return
    self.reranker = tuned
    logger.info("Switched to fine-tuned reranker model")
def _load_data_from_mongo(self):
    """Load news documents from MongoDB into ``self.df`` and build the embedding text column.

    Steps:
    1. Fetch the projected fields from ``mongodb.news_collection``.
    2. Require the ID, headline, keyword and synopsis columns.
    3. Fill defaults for missing optional columns and coerce taxonomy cells to lists.
    4. Drop rows with a missing ID or a missing/blank headline or synopsis.
    5. Build ``self.processed_content_col`` from headline, synopsis, keywords and
       taxonomy names.

    When MongoDB returns nothing, or every row is dropped, ``self.df`` becomes an
    empty DataFrame with the expected columns.

    Raises:
        ValueError: if a required column is absent from the fetched documents.
        Exception: any failure is logged, ``self.df`` is set to None, and it is re-raised.
    """
    logger.info("Loading data from MongoDB...")
    try:
        projection = {
            self.id_col: 1,
            self.headline_col: 1,
            self.key: 1,
            self.syn: 1,
            self.topic_col: 1,
            self.taxonomy_col: 1,
            self.property_col: 1,
            self.seolocation_col: 1,
            self.deeplink_col: 1,
            self.last_updated_col: 1,
            self.image_id_col: 1,
            self.image_ratio_col: 1,
            self.image_size_col: 1,
            "_id": 0,  # Exclude MongoDB's own _id field.
        }
        data = list(mongodb.news_collection.find({}, projection))
        if not data:
            logger.warning("No documents found in MongoDB.")
            self.df = self._empty_news_frame()
            return

        self.df = pd.DataFrame(data)
        logger.info(f"Loaded {len(self.df)} documents from MongoDB.")

        for col in (self.id_col, self.headline_col, self.key, self.syn):
            if col not in self.df.columns:
                raise ValueError(f"Required column '{col}' not found in MongoDB documents.")

        # Text columns: missing -> constant default; present -> fill NA and force str.
        # (self.key is required above, so only its fill-NA path can apply.)
        textual_optional_cols = {
            self.topic_col: "N/A",
            self.property_col: "N/A",
            self.key: "",
        }
        for col, default_val in textual_optional_cols.items():
            if col not in self.df.columns:
                logger.warning(f"Optional column '{col}' not found. Adding with default value '{default_val}'.")
                self.df[col] = default_val
            else:
                self.df[col] = self.df[col].fillna(default_val).astype(str)

        # Pass-through columns: missing -> None; NaN in object columns -> None.
        other_optional_cols = [
            self.seolocation_col, self.deeplink_col, self.last_updated_col,
            self.image_id_col, self.image_ratio_col, self.image_size_col,
        ]
        for col in other_optional_cols:
            if col not in self.df.columns:
                logger.warning(f"Optional column '{col}' not found. Adding with default value None.")
                self.df[col] = None
            elif pd.api.types.is_object_dtype(self.df[col].dtype):
                self.df[col] = self.df[col].replace({np.nan: None})

        if self.taxonomy_col not in self.df.columns:
            logger.warning(f"Taxonomy column '{self.taxonomy_col}' not found. Adding with default empty list for each row.")
            self.df[self.taxonomy_col] = [[] for _ in range(len(self.df))]
        else:
            self.df[self.taxonomy_col] = self.df[self.taxonomy_col].apply(self._clean_taxonomy_cell)

        initial_len = len(self.df)
        self.df = self.df.dropna(subset=[self.id_col, self.headline_col, self.syn])

        def _non_blank(value) -> bool:
            return isinstance(value, str) and value.strip() != ''

        # Explicit bool ndarray mask: if dropna left no rows, Series.apply would return
        # an empty object-dtype Series, which pandas treats as a list of column labels,
        # yielding a frame with no columns (and a KeyError on the next lookup).
        keep = np.array(
            [_non_blank(h) and _non_blank(s) for h, s in zip(self.df[self.headline_col], self.df[self.syn])],
            dtype=bool,
        )
        # .copy() so later column assignments do not act on a view (SettingWithCopyWarning).
        self.df = self.df[keep].copy()
        if len(self.df) < initial_len:
            logger.warning(f"Dropped {initial_len - len(self.df)} rows due to missing/invalid values.")
        if self.df.empty:
            logger.warning("DataFrame is empty after cleaning.")
            self.df = self._empty_news_frame()
            return

        logger.info("Preprocessing and combining content (headline, synopsis, keywords, taxonomy) for contextual embeddings...")
        for col in (self.headline_col, self.syn, self.key):
            self.df[col] = self.df[col].astype(str)
        self.df[self.processed_content_col] = self.df.apply(self._build_processed_content, axis=1)
        logger.info("Data loading and preprocessing complete.")
    except Exception as e:
        logger.error(f"Error loading data from MongoDB: {e}", exc_info=True)
        self.df = None  # Signal failure to callers before propagating.
        raise


def _empty_news_frame(self) -> pd.DataFrame:
    """Return an empty DataFrame carrying every column ``_load_data_from_mongo`` produces."""
    return pd.DataFrame(columns=[
        self.id_col, self.headline_col, self.key, self.syn, self.taxonomy_col,
        self.topic_col, self.property_col, self.processed_content_col,
        self.seolocation_col, self.deeplink_col, self.last_updated_col,
        self.image_id_col, self.image_ratio_col, self.image_size_col,
    ])


def _build_processed_content(self, row) -> str:
    """Combine one row's headline, synopsis, keywords and taxonomy names into labelled text.

    Each part goes through ``_preprocess_text`` and empty parts are omitted. The
    Hindi labels mark each segment's source field: शीर्षक (headline), सारांश
    (synopsis), कीवर्ड (keywords), श्रेणी (category/taxonomy).
    """
    processed_headline = self._preprocess_text(row[self.headline_col].strip())
    processed_synopsis = self._preprocess_text(row[self.syn].strip())
    processed_keywords = self._preprocess_text(str(row[self.key]).strip())

    taxonomy_names = []
    terms = row[self.taxonomy_col]
    if isinstance(terms, list):
        for term_obj in terms:
            if isinstance(term_obj, dict) and "name" in term_obj and term_obj["name"]:
                taxonomy_names.append(self._preprocess_text(str(term_obj["name"])))
    processed_taxonomy = " ".join(name for name in taxonomy_names if name)

    labelled_parts = [
        ("शीर्षक", processed_headline),
        ("सारांश", processed_synopsis),
        ("कीवर्ड", processed_keywords),
        ("श्रेणी", processed_taxonomy),
    ]
    return " ".join(f"{label}: {text}" for label, text in labelled_parts if text)
| def _clean_taxonomy_cell(self, cell_value) -> list: | |
| """ | |
| Cleans individual cells of the taxonomy column. | |
| Ensures that the cell content is a list, handling scalars, NaNs, and unexpected array types. | |
| """ | |
| if isinstance(cell_value, list): | |
| # If it's already a list (empty or not), return it. | |
| return cell_value | |
| elif pd.api.types.is_scalar(cell_value) and pd.isna(cell_value): | |
| # Handles scalar np.nan, None, pd.NA by returning an empty list. | |
| return [] | |
| elif isinstance(cell_value, (np.ndarray, pd.Series)): | |
| # Handles unexpected np.ndarray or pd.Series in a cell. | |
| logger.warning( | |
| f"Unexpected array/Series type in taxonomy column cell: {type(cell_value)}. " | |
| f"Content (first 100 chars): {str(cell_value)[:100]}. Converting to empty list." | |
| ) | |
| return [] | |
| else: | |
| # Handles other scalar types (e.g., string, int) or unhandled types. | |
| # These are converted to an empty list, mimicking the original lambda's behavior. | |
| if not pd.api.types.is_scalar(cell_value): # Log if it's an unexpected non-scalar, non-list, non-array type | |
| logger.warning( | |
| f"Unexpected non-scalar, non-list/array type in taxonomy column cell: {type(cell_value)}. " | |
| f"Content (first 100 chars): {str(cell_value)[:100]}. Converting to empty list." | |
| ) | |
| return [] | |
| def _preprocess_text(self, text: str) -> str: | |
| """Normalize and tokenize Hindi text.""" | |
| """ | |
| Normalize and tokenize text. Applies Indic normalization for Hindi text | |
| and generic tokenization for others. | |
| """ | |
| if not isinstance(text, str): | |
| logger.warning(f"Received non-string input for preprocessing: {type(text)}") | |
| return "" | |
| text = text.strip() # Remove leading/trailing whitespace | |
| if not text: # Handle empty strings after stripping | |
| return "" | |
| try: | |
| # Check for presence of Devanagari characters to identify Hindi text. | |
| # This is a basic check; for more complex scenarios, consider a language detection library. | |
| is_text = bool(re.search(r'[\u0900-\u097F]', text)) | |
| if is_text: | |
| if self.normalizer is None: | |
| logger.error("IndicNormalizer not initialized, but text detected. Proceeding without normalization.") | |
| normalized_text = text # Fallback: use original text if normalizer is missing | |
| else: | |
| normalized_text = self.normalizer.normalize(text) | |
| else: | |
| # For non-Hindi text, skip Hindi-specific normalization | |
| # logger.debug(f"Skipping Hindi normalization for non-Hindi text: {text[:50]}...") # Optional: enable if needed | |
| normalized_text = text | |
| tokens = indic_tokenize.trivial_tokenize(normalized_text) # Tokenize the (potentially normalized) text | |
| return " ".join(tokens) | |
| except Exception as e: | |
| logger.error(f"Error during text preprocessing for text starting with '{text[:50]}...': {e}", exc_info=True) | |
| # Fallback to returning the original text to prevent downstream errors | |
| return text | |
def _generate_embeddings(self, df_subset: pd.DataFrame) -> tuple[np.ndarray, List[str]]:
    """
    Encode the processed-content column of ``df_subset`` with ``self.embed_model``.

    Args:
        df_subset: Rows to embed. Must contain ``self.processed_content_col`` and
            ``self.id_col``; null cells in the content column are encoded as "".

    Returns:
        ``(embeddings, ids)`` where ``embeddings`` is whatever ``encode`` returns for
        the texts (one row per input row) and ``ids`` is the ID column as a list in the
        same order. On empty input, missing columns, or an encoder exception the pair
        ``(np.array([]), [])`` is returned instead of raising.

    Raises:
        RuntimeError: if ``self.embed_model`` is not set (falsy).
    """
    if not self.embed_model:
        raise RuntimeError("Embedding model not loaded")

    required_columns = (self.processed_content_col, self.id_col)
    lacks_column = any(column not in df_subset.columns for column in required_columns)
    if df_subset.empty or lacks_column:
        logger.warning(
            f"DataFrame subset is empty or missing required columns "
            f"('{self.processed_content_col}', '{self.id_col}') for embedding generation."
        )
        return np.array([]), []

    # Null cells become "" so the encoder only ever receives strings.
    raw_texts = df_subset[self.processed_content_col].tolist()
    texts = [str(cell) if pd.notna(cell) else "" for cell in raw_texts]
    ids = df_subset[self.id_col].tolist()

    # Only triggers for a zero-length list; empty strings are still encoded.
    if len(texts) == 0:
        logger.warning("No non-empty texts available in the subset to generate embeddings.")
        return np.array([]), []

    try:
        vectors = self.embed_model.encode(texts, show_progress_bar=False, convert_to_numpy=True)
    except Exception as encode_error:
        logger.error(f"Error generating embeddings: {encode_error}", exc_info=True)
        return np.array([]), []
    return vectors, ids
def _build_faiss_index(self) -> List[str]:
    """
    Build an HNSW inner-product FAISS index over ``self.df`` and return the indexed IDs.

    Embeddings are produced in chunks of 1000 rows via ``self._generate_embeddings``.
    A chunk whose embedding fails is skipped together with its IDs, so vector ``i`` in
    the index always corresponds to element ``i`` of the returned list.

    Side effects:
        Sets ``self.index`` to a new ``faiss.IndexHNSWFlat`` (M=32, efConstruction=100),
        or to ``None`` when no embeddings could be generated, in which case ``[]`` is
        returned.

    Raises:
        ValueError: if ``self.df`` is None, empty, or lacks ``self.processed_content_col``.
        RuntimeError: if ``self.embed_model`` is None.
    """
    if self.df is None or self.df.empty or self.processed_content_col not in self.df.columns:
        raise ValueError("Cannot build FAISS index: Data not ready (DataFrame is None, empty, or missing processed content column)")
    if self.embed_model is None:
        raise RuntimeError("Cannot build FAISS index: Embedding model not loaded")
    logger.info("Building FAISS index...")

    # Rows per encoder call; bounds the size of any single encode() invocation.
    embed_batch_size = 1000
    total_docs = len(self.df)
    num_embed_batches = (total_docs + embed_batch_size - 1) // embed_batch_size
    all_embeddings = []
    all_ids = []
    for batch_no, start_idx in enumerate(range(0, total_docs, embed_batch_size), start=1):
        batch_df = self.df.iloc[start_idx:start_idx + embed_batch_size]
        batch_embeddings, batch_ids = self._generate_embeddings(batch_df)
        if batch_embeddings.size > 0:
            # Keep embeddings and IDs in lockstep so positions stay aligned.
            all_embeddings.append(batch_embeddings)
            all_ids.extend(batch_ids)
        logger.info(f"Processed batch {batch_no}/{num_embed_batches}")

    if not all_embeddings:
        logger.warning("No embeddings were generated. FAISS index will be empty.")
        self.index = None
        return []

    # FAISS requires contiguous float32 input.
    embeddings = np.vstack(all_embeddings).astype(np.float32)
    dimension = embeddings.shape[1]

    # NOTE(review): METRIC_INNER_PRODUCT equals cosine similarity only for L2-normalized
    # vectors; _generate_embeddings does not request normalization. Confirm this is
    # intended for the configured embedding model.
    hnsw_m = 32  # graph neighbours per node
    ef_construction = 100  # larger = better recall, slower build
    self.index = faiss.IndexHNSWFlat(dimension, hnsw_m, faiss.METRIC_INNER_PRODUCT)
    self.index.hnsw.efConstruction = ef_construction

    # All embeddings are already in memory; chunked add() only makes progress logging granular.
    add_batch_size = 10000
    num_add_batches = (len(embeddings) + add_batch_size - 1) // add_batch_size
    for batch_no, start in enumerate(range(0, len(embeddings), add_batch_size), start=1):
        self.index.add(embeddings[start:start + add_batch_size])
        logger.info(f"Added batch {batch_no}/{num_add_batches} to index")

    logger.info(f"FAISS index built with {self.index.ntotal} vectors")
    return all_ids
def _load_faiss_index_and_ids(self):
    """
    Restore the FAISS index from disk and its position-to-ID list from MongoDB.

    Sets ``self.index`` from ``INDEX_PATH`` and ``self.indexed_ids`` from the ``ids``
    field of the ``self.FAISS_IDS_DOC_ID`` document in ``self.faiss_meta_collection``.
    ``self.indexed_ids`` falls back to ``[]`` when the collection is None, the document
    or its ``ids`` field is missing, or the MongoDB lookup raises. Those cases are
    logged rather than raised, so ``load_components`` can decide whether to rebuild.
    A mismatch between vector count and ID count is only logged.

    Raises:
        FileNotFoundError: if ``INDEX_PATH`` does not exist.
        Exception: anything raised while reading the index; ``self.index`` is reset to None first.
    """
    if not os.path.exists(INDEX_PATH):
        raise FileNotFoundError(f"FAISS index file not found at {INDEX_PATH}")

    logger.info(f"Loading FAISS index from: {INDEX_PATH}")
    try:
        self.index = faiss.read_index(INDEX_PATH)
        logger.info(f"FAISS index loaded with {self.index.ntotal} vectors")
    except Exception as read_error:
        logger.error(f"Error loading FAISS index: {read_error}", exc_info=True)
        self.index = None  # never leave a half-initialised index behind
        raise

    meta_collection = self.faiss_meta_collection
    if meta_collection is None:
        logger.warning("MongoDB not available. Cannot load indexed IDs. Operating with empty ID list.")
        self.indexed_ids = []
        return

    logger.info(
        f"Loading FAISS index IDs from MongoDB collection '{MONGO_FAISS_META_COLLECTION_NAME}', "
        f"document_id '{self.FAISS_IDS_DOC_ID}'"
    )
    try:
        ids_document = meta_collection.find_one({"_id": self.FAISS_IDS_DOC_ID})
        if not ids_document or "ids" not in ids_document:
            logger.warning("FAISS index IDs document not found in MongoDB or 'ids' field missing. Will attempt to build if necessary.")
            self.indexed_ids = []
            return

        self.indexed_ids = ids_document["ids"]
        id_count = len(self.indexed_ids)
        logger.info(f"Loaded {id_count} indexed IDs from MongoDB.")
        # Query-time lookup is indexed_ids[vector_position], so the two counts must agree.
        if self.index and self.index.ntotal != id_count:
            logger.warning(
                f"FAISS index vector count ({self.index.ntotal}) "
                f"does not match loaded ID count from MongoDB ({id_count}). "
                "Index might be inconsistent. Consider rebuilding."
            )
    except Exception as mongo_error:
        logger.error(f"Error loading FAISS index IDs from MongoDB: {mongo_error}", exc_info=True)
        # Deliberately not re-raised: an empty ID list lets load_components decide on a rebuild.
        self.indexed_ids = []
def build_indexes_and_save(self, data_already_loaded: bool = False):
    """
    Build the FAISS index from ``self.df`` and persist it to disk, plus its ID list to MongoDB.

    Args:
        data_already_loaded: When False, ``self._load_data_from_mongo()`` is called first.
            When True, the current ``self.df`` is used as-is.

    Side effects:
        Calls ``self.load_models()`` if ``self.embed_model`` is None. Sets ``self.index``
        and ``self.indexed_ids``. Writes the index to ``INDEX_PATH``, creating its
        directory if needed. Upserts ``{"ids", "last_updated"}`` into the
        ``self.FAISS_IDS_DOC_ID`` document when ``self.faiss_meta_collection`` is
        available. A failure while saving the IDs is logged, not raised.

    Raises:
        ValueError: if there is no data to index.
        RuntimeError: if no embeddings could be generated, so there is no index to write.
        Exception: any other failure is logged and re-raised.
    """
    logger.info("Starting index building process...")
    try:
        if not data_already_loaded:
            self._load_data_from_mongo()
        if self.df is None or self.df.empty:
            raise ValueError("Data is empty. Cannot build index.")
        if self.embed_model is None:
            self.load_models()

        # The returned IDs are positionally aligned with the vectors in self.index.
        self.indexed_ids = self._build_faiss_index()
        if self.index is None:
            # _build_faiss_index sets self.index to None when every batch failed to embed.
            # Passing None to faiss.write_index would fail with an opaque error.
            raise RuntimeError("No embeddings were generated; FAISS index was not built and nothing was saved.")

        logger.info(f"Saving FAISS index to: {INDEX_PATH}")
        index_dir = os.path.dirname(INDEX_PATH)
        if index_dir:  # empty string means INDEX_PATH is a bare filename in the CWD
            os.makedirs(index_dir, exist_ok=True)
        faiss.write_index(self.index, INDEX_PATH)

        logger.info(f"Saving FAISS index IDs to MongoDB collection '{MONGO_FAISS_META_COLLECTION_NAME}', document_id '{self.FAISS_IDS_DOC_ID}'")
        try:
            if self.faiss_meta_collection is not None:
                self.faiss_meta_collection.update_one(
                    {"_id": self.FAISS_IDS_DOC_ID},
                    {"$set": {"ids": self.indexed_ids, "last_updated": datetime.now()}},
                    upsert=True
                )
                logger.info(f"Saved {len(self.indexed_ids)} indexed IDs to MongoDB.")
            else:
                logger.warning("MongoDB not available. Skipping indexed IDs save.")
        except Exception as e:
            # Best-effort: the index file is already written. The stored ID list may now
            # lag behind it, which _load_faiss_index_and_ids reports as a count mismatch.
            logger.error(f"Error saving FAISS index IDs to MongoDB: {e}", exc_info=True)
        logger.info("Index building and saving complete")
    except Exception as e:
        logger.error(f"Error during index building: {e}", exc_info=True)
        raise
def load_components(self):
    """
    Load models, content data, and the FAISS index, rebuilding or extending the index as needed.

    Steps:
      1. ``self.load_models()``; a failure here aborts.
      2. ``self._load_data_from_mongo()``. On failure ``self.df`` is set to None and
         loading continues with whatever index exists on disk ("limited mode").
      3. ``self._load_faiss_index_and_ids()``, then reconcile with ``self.df``:
         - index non-empty, data present, no stored IDs: full rebuild;
         - index non-empty, data present, IDs present: append not-yet-indexed documents;
         - index missing, empty, or unreadable, with data present: full rebuild;
         - no data: keep whatever was loaded and log.

    A full rebuild is attempted at most once. If a rebuild itself fails, the error
    propagates instead of triggering a second rebuild from the error handlers.

    Raises:
        Exception: any unrecovered error is logged and re-raised.
    """
    logger.info("Loading components...")
    try:
        self.load_models()

        try:
            self._load_data_from_mongo()
            logger.info("Successfully loaded data from MongoDB")
        except Exception as mongo_error:
            logger.warning(f"Failed to load data from MongoDB: {mongo_error}")
            logger.info("Attempting to work with existing FAISS index without MongoDB data...")
            self.df = None  # signals "no MongoDB data" to the logic below

        has_data = self.df is not None and not self.df.empty
        # Set just before any rebuild so the handlers below re-raise instead of rebuilding twice.
        rebuild_attempted = False
        try:
            # Raises FileNotFoundError if INDEX_PATH is missing; leaves indexed_ids == [] if IDs are unavailable.
            self._load_faiss_index_and_ids()
            if self.index is not None and self.index.ntotal > 0:
                logger.info(f"FAISS index loaded successfully with {self.index.ntotal} vectors")
                if has_data:
                    if not self.indexed_ids:
                        logger.warning("FAISS index file loaded, but no corresponding IDs found in MongoDB. Rebuilding for consistency.")
                        rebuild_attempted = True
                        self.build_indexes_and_save(data_already_loaded=True)
                    else:
                        logger.info("Existing FAISS index and IDs loaded from storage.")
                        # Failures here (e.g. a dimension mismatch after a model change)
                        # fall through to the generic handler, which rebuilds once.
                        self._append_unindexed_documents()
                else:
                    logger.info("FAISS index available but no MongoDB data. Operating in limited mode.")
                    if not self.indexed_ids:
                        logger.warning("No indexed IDs available. Some functionality may be limited.")
            elif has_data:
                logger.info("FAISS index and/or IDs not found or empty. Building new index.")
                rebuild_attempted = True
                self.build_indexes_and_save(data_already_loaded=True)
            else:
                logger.warning("No data available (neither MongoDB nor FAISS index). Cannot build index.")
        except FileNotFoundError:
            if rebuild_attempted:
                raise
            logger.warning(f"FAISS index file ({INDEX_PATH}) not found.")
            if has_data:
                logger.info("Building index from scratch.")
                self.build_indexes_and_save(data_already_loaded=True)
            else:
                logger.error("Cannot build index: no data available.")
        except Exception as e:
            if rebuild_attempted:
                raise
            logger.error(f"Error loading FAISS index: {e}", exc_info=True)
            if has_data:
                logger.info("Attempting to rebuild index due to loading error.")
                self.build_indexes_and_save(data_already_loaded=True)
            else:
                logger.error("Cannot rebuild index: no data available.")

        logger.info("Components loaded successfully")
    except Exception as e:
        logger.error(f"Error loading components: {e}", exc_info=True)
        raise

def _append_unindexed_documents(self):
    """Embed rows of ``self.df`` whose IDs are not yet in ``self.indexed_ids``, append them to the index, and persist both."""
    current_df_ids = set(self.df[self.id_col].tolist())
    new_ids_to_add = list(current_df_ids - set(self.indexed_ids))
    if not new_ids_to_add:
        logger.info("No new documents found to add to the index. Index is up-to-date.")
        return

    logger.info(f"Found {len(new_ids_to_add)} new documents to add to the index.")
    new_docs_df = self.df[self.df[self.id_col].isin(new_ids_to_add)].copy()
    new_embeddings, new_doc_ids_added = self._generate_embeddings(new_docs_df)
    if new_embeddings.size == 0:
        return

    # Append vectors and IDs together so index position i keeps mapping to indexed_ids[i].
    self.index.add(new_embeddings.astype(np.float32))
    self.indexed_ids.extend(new_doc_ids_added)
    logger.info(f"Added {len(new_doc_ids_added)} new vectors to FAISS index. Total vectors: {self.index.ntotal}")
    faiss.write_index(self.index, INDEX_PATH)

    if self.faiss_meta_collection is None:
        logger.warning("MongoDB not available. Skipping indexed IDs save.")
        return
    try:
        self.faiss_meta_collection.update_one(
            {"_id": self.FAISS_IDS_DOC_ID},
            {"$set": {"ids": self.indexed_ids, "last_updated": datetime.now()}},
            upsert=True
        )
        logger.info("Updated FAISS index and IDs saved to MongoDB.")
    except Exception as e:
        # Best-effort: the index file is already written, so the stored IDs may now lag
        # behind it (reported as a count mismatch on the next load).
        logger.warning(f"Could not save IDs to MongoDB: {e}")
def get_recommendations(
    self,
    query: str,
    k: int = DEFAULT_K,
    similarity_threshold: float = SIMILARITY_THRESHOLD
) -> Dict:
    """
    Recommend up to ``k`` documents for a free-text query.

    Pipeline:
      1. Preprocess and embed the query.
      2. Fetch ``max(k * CANDIDATE_MULTIPLIER, k)`` nearest neighbours from FAISS and map
         positions to IDs via ``self.indexed_ids``.
      3. Inner-join with ``self.df``.
      4. Drop candidates whose headline or synopsis equals the query (stripped,
         case-insensitive), then de-duplicate on synopsis.
      5. Rerank ``(query, synopsis)`` pairs with the cross-encoder, keep scores
         ``>= similarity_threshold``, and return the top ``k``.

    Args:
        query: User query text.
        k: Maximum number of documents to return.
        similarity_threshold: Minimum cross-encoder score for a document to be returned.

    Returns:
        ``{"retrieved_documents": [...], "generated_response": str}``. Recoverable failures
        (empty query embedding, FAISS search error, no surviving candidates) return an
        empty document list with an explanatory message.

    Raises:
        HTTPException: 503 if ``self.df`` is missing or empty, or if any of the FAISS index,
            embedding model, reranker or generator is None.
    """
    if self.df is None or self.df.empty:
        raise HTTPException(status_code=503, detail="Recommender data not available")
    # One test (`is None`) decides both whether to fail and what to report, so the
    # 503 detail always names the absent components.
    missing = [
        name
        for name, component in (
            ("FAISS index", self.index),
            ("Embedding model", self.embed_model),
            ("Reranker model", self.reranker),
            ("Generator model", self.generator),
        )
        if component is None
    ]
    if missing:
        raise HTTPException(
            status_code=503,
            detail=f"Recommender not fully initialized. Missing: {', '.join(missing)}"
        )

    logger.info(f"Processing recommendation request: query='{query}', k={k}")

    # Embed the query through the same preprocessing and encoder used for indexed documents.
    processed_query = self._preprocess_text(query)
    query_embedding, _ = self._generate_embeddings(
        pd.DataFrame({self.processed_content_col: [processed_query], self.id_col: ["query"]})
    )
    if query_embedding.size == 0:
        logger.warning("Query embedding is empty.")
        return {
            "retrieved_documents": [],
            "generated_response": "No recommendations found. (Query embedding failed)"
        }

    # Over-fetch so that filtering and reranking can still leave up to k results.
    num_candidates = max(k * CANDIDATE_MULTIPLIER, k)
    try:
        D, I = self.index.search(query_embedding.astype(np.float32), num_candidates)
    except Exception as e:
        logger.error(f"Error during FAISS search: {e}", exc_info=True)
        return {
            "retrieved_documents": [],
            "generated_response": "No recommendations found. (FAISS search failed)"
        }

    # Map FAISS positions to document IDs. FAISS pads with -1 when it has fewer hits.
    num_indexed = len(self.indexed_ids)
    valid_candidate_data = []
    for faiss_idx, score in zip(I[0], D[0]):
        if faiss_idx == -1:
            continue
        if faiss_idx < num_indexed:
            valid_candidate_data.append({
                "original_id": self.indexed_ids[faiss_idx],
                "faiss_score": score
            })
        else:
            logger.warning(
                f"FAISS index {faiss_idx} is out of bounds for self.indexed_ids (len: {num_indexed}). Skipping."
            )
    if not valid_candidate_data:
        logger.info(f"No valid candidates found from FAISS/ID mapping for query '{query}'.")
        return {
            "retrieved_documents": [],
            "generated_response": f"No recommendations found for '{query}' (no FAISS results or ID mapping issue)."
        }

    # Left frame first so FAISS order is kept; the inner join drops IDs absent from self.df.
    candidates = pd.merge(
        pd.DataFrame(valid_candidate_data),
        self.df,
        left_on="original_id",
        right_on=self.id_col,
        how="inner"
    )
    if "original_id" in candidates.columns and "original_id" != self.id_col:
        candidates = candidates.drop(columns=["original_id"])

    # Don't "recommend" the query back to the user.
    query_key = query.strip().lower()
    candidates = candidates[
        (candidates[self.headline_col].str.strip().str.lower() != query_key)
        & (candidates[self.syn].str.strip().str.lower() != query_key)
    ]
    if candidates.empty:
        logger.info(f"No candidates left after filtering exact query matches for query '{query}'.")
        return {"retrieved_documents": [], "generated_response": f"No distinct recommendations found for '{query}'."}

    candidates = candidates.drop_duplicates(subset=[self.syn]).copy()
    if candidates.empty:
        logger.info(f"No candidates left after dropping duplicates for query '{query}'.")
        return {"retrieved_documents": [], "generated_response": f"No unique recommendations found for '{query}'."}

    # candidates is non-empty here, so there is always at least one pair to score.
    rerank_pairs = [(query, str(synopsis)) for synopsis in candidates[self.syn]]
    rerank_scores = self.reranker.predict(rerank_pairs, show_progress_bar=False)
    logger.info(f"Raw rerank scores for query '{query}': {rerank_scores.tolist()}")
    candidates["rerank_score"] = rerank_scores
    candidates = candidates.sort_values("rerank_score", ascending=False)
    logger.debug(f"Top candidates before thresholding (query='{query}', threshold={similarity_threshold}):")
    for _, row in candidates.head().iterrows():
        logger.debug(f" ID: {row[self.id_col]}, Synopsis: {str(row[self.syn])[:50]}..., Rerank Score: {row['rerank_score']:.4f}")
    candidates = candidates[candidates["rerank_score"] >= similarity_threshold]
    logger.info(f"Number of candidates after applying similarity_threshold ({similarity_threshold}): {len(candidates)}")

    top_candidates = candidates.head(k)

    retrieved_documents = []
    for _, row in top_candidates.iterrows():
        taxonomy_data = row[self.taxonomy_col]
        taxonomy_names = []
        if isinstance(taxonomy_data, list):
            taxonomy_names = [
                str(term_obj["name"])
                for term_obj in taxonomy_data
                if isinstance(term_obj, dict) and term_obj.get("name")
            ]
        retrieved_documents.append({
            "id": row[self.id_col],
            "hl": str(row[self.headline_col]),
            "synopsis": row[self.syn],
            "keywords": row[self.key],
            "type": row.get(self.topic_col, None),
            "taxonomy": taxonomy_names,
            "score": row["rerank_score"],
            "seolocation": row.get(self.seolocation_col, None),
            "dl": row.get(self.deeplink_col, None),
            "lu": row.get(self.last_updated_col, None),
            "imageid": row.get(self.image_id_col, None),
            "imgratio": row.get(self.image_ratio_col, None),
            "imgsize": row.get(self.image_size_col, None)
        })

    # The generator model is not used to compose a response here; a fixed summary is returned.
    generated_response = f"Top {len(retrieved_documents)} recommendations for '{query}'."
    return {
        "retrieved_documents": retrieved_documents,
        "generated_response": generated_response
    }
def _format_retrieved_documents(self, documents_df: pd.DataFrame) -> List[Dict]:
    """
    Convert ranked candidate rows into the response dictionaries returned to callers.

    Args:
        documents_df: Candidate rows. ``self.id_col`` and ``self.headline_col`` must be
            present; every other column (synopsis, keywords, taxonomy, topic, image
            fields, ``rerank_score``) may be absent.

    Returns:
        One dict per row, in row order. Missing or null ``synopsis`` and ``keywords``
        become ``""``. ``taxonomy`` lists the non-empty ``"name"`` values of dict
        entries in the taxonomy cell, or ``[]`` if the cell is missing or not a list.
        ``score`` is ``rerank_score`` as a plain ``float`` (0.0 if absent).
    """
    def _as_text(value) -> str:
        # str(NaN) would leak the literal "nan" into the API response.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    retrieved_documents = []
    for _, row in documents_df.iterrows():
        # .get keeps a frame without a taxonomy column from raising KeyError.
        taxonomy_data = row.get(self.taxonomy_col)
        taxonomy_names = []
        if isinstance(taxonomy_data, list):
            taxonomy_names = [
                str(term_obj["name"])
                for term_obj in taxonomy_data
                if isinstance(term_obj, dict) and term_obj.get("name")
            ]
        retrieved_documents.append({
            "id": row[self.id_col],
            "hl": str(row[self.headline_col]),
            "synopsis": _as_text(row.get(self.syn, "")),
            "keywords": _as_text(row.get(self.key, "")),
            "type": row.get(self.topic_col, None),
            "taxonomy": taxonomy_names,
            # Plain float: numpy scalars are not JSON-serializable by default.
            "score": float(row.get("rerank_score", 0.0)),
            "seolocation": row.get(self.seolocation_col, None),
            "dl": row.get(self.deeplink_col, None),
            "lu": row.get(self.last_updated_col, None),
            "imageid": row.get(self.image_id_col, None),
            "imgratio": row.get(self.image_ratio_col, None),
            "imgsize": row.get(self.image_size_col, None)
        })
    return retrieved_documents
def get_recommendations_by_id(
    self,
    msid: str,
    k: int = DEFAULT_K,
    similarity_threshold: float = SIMILARITY_THRESHOLD
) -> Dict:
    """
    Recommend up to ``k`` items similar to the item whose ID is ``msid``.

    The source item's processed content is embedded and searched in FAISS, fetching
    ``max((k + 1) * CANDIDATE_MULTIPLIER, k + 1)`` neighbours. Candidates are then
    filtered in three ways:
      - the source item itself is removed;
      - candidates whose synopsis equals the source's synopsis (stripped,
        case-insensitive) are removed as re-published copies;
      - remaining candidates are de-duplicated on synopsis.
    The survivors are reranked against the source synopsis with the cross-encoder,
    thresholded at ``similarity_threshold``, and truncated to ``k``.

    Returns:
        ``{"retrieved_documents": [...], "generated_response": str}``. An unknown ``msid``
        or any recoverable failure yields an empty list with an explanatory message, so
        the route's response model still validates.

    Raises:
        HTTPException: 503 if ``self.df`` is missing or empty, or if the FAISS index,
            embedding model or reranker is None. The generator is not needed on this path.
    """
    if self.df is None or self.df.empty:
        logger.error("Recommender data not available for get_recommendations_by_id.")
        raise HTTPException(status_code=503, detail="Recommender data not available")
    missing = [
        name
        for name, component in (
            ("FAISS index", self.index),
            ("Embedding model", self.embed_model),
            ("Reranker model", self.reranker),
        )
        if component is None
    ]
    if missing:
        logger.error(f"Recommender not fully initialized for get_recommendations_by_id. Missing: {', '.join(missing)}")
        raise HTTPException(
            status_code=503,
            detail=f"Recommender not fully initialized. Missing: {', '.join(missing)}"
        )

    logger.info(f"Processing recommendation request for item ID: '{msid}', k={k}")

    source_item_row = self.df[self.df[self.id_col] == msid]
    if source_item_row.empty:
        logger.warning(f"Item with ID '{msid}' not found in DataFrame.")
        # Keep the normal response shape rather than raising a 404.
        return {"retrieved_documents": [], "generated_response": f"Item with ID '{msid}' not found."}

    source_item = source_item_row.iloc[0]
    source_item_content_for_embedding = source_item[self.processed_content_col]
    source_item_content_for_reranking = str(source_item[self.syn])

    item_embedding, _ = self._generate_embeddings(
        pd.DataFrame({
            self.processed_content_col: [source_item_content_for_embedding],
            self.id_col: [msid]  # placeholder; _generate_embeddings requires an ID column
        })
    )
    if item_embedding.size == 0:
        logger.warning(f"Embedding generation failed for item ID '{msid}'.")
        return {"retrieved_documents": [], "generated_response": "No recommendations found (source item embedding failed)"}

    # +1 because the source item itself is normally its own nearest neighbour.
    num_candidates_to_fetch = max((k + 1) * CANDIDATE_MULTIPLIER, k + 1)
    try:
        D, I = self.index.search(item_embedding.astype(np.float32), num_candidates_to_fetch)
    except Exception as e:
        logger.error(f"Error during FAISS search for item ID '{msid}': {e}", exc_info=True)
        return {"retrieved_documents": [], "generated_response": "No recommendations found (FAISS search failed)"}

    candidate_faiss_indices = I[0][I[0] != -1]  # -1 is FAISS padding for "no neighbour"
    if len(candidate_faiss_indices) == 0:
        logger.info(f"No candidates found from FAISS for item ID '{msid}'.")
        return {"retrieved_documents": [], "generated_response": f"No similar items found for ID '{msid}'."}

    num_indexed = len(self.indexed_ids)
    candidate_original_ids = [self.indexed_ids[i] for i in candidate_faiss_indices if i < num_indexed]

    candidates_df = self.df[
        self.df[self.id_col].isin(candidate_original_ids) & (self.df[self.id_col] != msid)
    ].copy()
    # Drop re-published copies of the source (same synopsis, different ID), mirroring
    # the exact-match filter in get_recommendations; then de-duplicate remaining content.
    source_syn_key = source_item_content_for_reranking.strip().lower()
    candidates_df = candidates_df[candidates_df[self.syn].str.strip().str.lower() != source_syn_key]
    candidates_df = candidates_df.drop_duplicates(subset=[self.syn]).copy()
    if candidates_df.empty:
        logger.info(f"No candidates left after filtering for item ID '{msid}'.")
        return {"retrieved_documents": [], "generated_response": f"No other similar items found for ID '{msid}'."}

    # candidates_df is non-empty here, so there is always at least one pair to score.
    rerank_pairs = [(source_item_content_for_reranking, str(synopsis)) for synopsis in candidates_df[self.syn]]
    rerank_scores = self.reranker.predict(rerank_pairs, show_progress_bar=False)
    candidates_df["rerank_score"] = rerank_scores
    candidates_df = candidates_df.sort_values("rerank_score", ascending=False)
    candidates_df = candidates_df[candidates_df["rerank_score"] >= similarity_threshold]

    retrieved_documents = self._format_retrieved_documents(candidates_df.head(k))
    if retrieved_documents:
        generated_response = f"Top {len(retrieved_documents)} recommendations similar to item ID '{msid}'."
    else:
        generated_response = f"No recommendations found similar to item ID '{msid}'."
    return {"retrieved_documents": retrieved_documents, "generated_response": generated_response}
| def prepare_reranker_training_data_from_new_feedback_format(self, user_id: str, training_event_details: Dict) -> List[Dict]: | |
| """ | |
| Prepares training data for the reranker model from a user's feedback document. | |
| Generates both positive and negative training samples using semantic similarity | |
| based negative sampling for better model discrimination. | |
| """ | |
| if self.df is None or self.df.empty: | |
| logger.warning(f"User {user_id}: DataFrame not loaded. Cannot prepare training data.") | |
| return [] | |
| training_samples = [] | |
| query_msid = training_event_details.get("query_msid") | |
| positive_msids_list = training_event_details.get("positive_msids") | |
| if not query_msid or not isinstance(query_msid, str): | |
| logger.warning(f"User {user_id}: 'query_msid' missing or invalid in training_event_details. Details: {str(training_event_details)[:200]}") | |
| return [] | |
| if not isinstance(positive_msids_list, list) or not positive_msids_list: | |
| logger.warning(f"User {user_id}: 'positive_msids' field missing, not a list, or empty in training_event_details for query_msid '{query_msid}'. Details: {str(training_event_details)[:200]}") | |
| return [] | |
| source_item_row = self.df[self.df[self.id_col] == query_msid] | |
| if source_item_row.empty: | |
| logger.warning(f"User {user_id}: Query item (msid: {query_msid}) for training data not found in DataFrame.") | |
| return [] | |
| query_text = str(source_item_row.iloc[0].get(self.syn, "")).strip() | |
| if not query_text: | |
| logger.warning(f"User {user_id}: Query text (synopsis) is empty for msid {query_msid}. Skipping training sample generation for this event.") | |
| return [] | |
| # Process positive samples | |
| positive_samples = [] | |
| for positive_msid in positive_msids_list: | |
| if not isinstance(positive_msid, str) or not positive_msid.strip(): | |
| continue | |
| clicked_item_row = self.df[self.df[self.id_col] == positive_msid] | |
| if not clicked_item_row.empty: | |
| candidate_text_positive = str(clicked_item_row.iloc[0].get(self.syn, "")).strip() | |
| if candidate_text_positive: | |
| positive_samples.append({ | |
| "query_text": query_text, | |
| "candidate_text": candidate_text_positive, | |
| "label": 1.0, | |
| "msid": positive_msid | |
| }) | |
| if not positive_samples: | |
| logger.warning(f"User {user_id}: No valid positive samples found for query_msid {query_msid}") | |
| return [] | |
| # Generate negative samples through semantic similarity based sampling | |
| num_negatives_per_positive = 5 # Increased for better training | |
| all_msids = set(self.df[self.id_col].tolist()) | |
| positive_msids_set = set(p["msid"] for p in positive_samples) | |
| # Get query embedding for semantic search | |
| query_embedding, _ = self._generate_embeddings( | |
| pd.DataFrame({ | |
| self.processed_content_col: [query_text], | |
| self.id_col: ["temp_query"] | |
| }) | |
| ) | |
| if query_embedding.size > 0: | |
| # Get semantically similar candidates (harder negatives) | |
| D, I = self.index.search(query_embedding.astype(np.float32), k=50) # Get more candidates | |
| candidate_indices = I[0] | |
| candidate_msids = [ | |
| self.indexed_ids[idx] for idx in candidate_indices | |
| if idx != -1 and idx < len(self.indexed_ids) | |
| ] | |
| # Filter out positives and query item | |
| negative_candidates = [ | |
| msid for msid in candidate_msids | |
| if msid not in positive_msids_set and msid != query_msid | |
| ] | |
| import random | |
| for pos_sample in positive_samples: | |
| # Mix of hard and random negatives | |
| num_hard_negatives = min(3, len(negative_candidates)) | |
| num_random_negatives = num_negatives_per_positive - num_hard_negatives | |
| # Select hard negatives (semantically similar) | |
| hard_negatives = negative_candidates[:num_hard_negatives] | |
| # Select random negatives from remaining pool | |
| remaining_candidates = list(all_msids - positive_msids_set - set(hard_negatives) - {query_msid}) | |
| random_negatives = random.sample(remaining_candidates, num_random_negatives) | |
| # Combine hard and random negatives | |
| selected_negatives = hard_negatives + random_negatives | |
| for neg_msid in selected_negatives: | |
| neg_item_row = self.df[self.df[self.id_col] == neg_msid] | |
| if not neg_item_row.empty: | |
| candidate_text_negative = str(neg_item_row.iloc[0].get(self.syn, "")).strip() | |
| if candidate_text_negative: | |
| training_samples.append({ | |
| "query_text": pos_sample["query_text"], | |
| "candidate_text": candidate_text_negative, | |
| "label": 0.0, | |
| "msid": neg_msid | |
| }) | |
| # Add positive samples to final training data | |
| training_samples.extend(positive_samples) | |
| if training_samples: | |
| logger.info(f"User {user_id}: Prepared {len(training_samples)} training samples ({len(positive_samples)} positive, {len(training_samples)-len(positive_samples)} negative) from the interaction event (query_msid: {query_msid}).") | |
| return training_samples | |
def _log_training_data_for_refinement(self, training_data: List[Dict]):
    """Forward prepared training samples to ``self.model_trainer.prepare_training_data``.

    When ``training_data`` is empty or None nothing is forwarded; an
    informational message is logged instead.
    """
    if training_data:
        self.model_trainer.prepare_training_data(training_data)
        return
    logger.info("No training data provided for saving.")
async def check_and_trigger_fine_tuning(self):
    """
    Check if fine-tuning should be triggered based on conditions and start if needed.

    The background task is kept on ``self._fine_tuning_task``. asyncio's event loop
    only holds weak references to tasks, so an unreferenced task may be
    garbage-collected before it finishes. The stored reference also lets this
    method refuse to start a second run while one is still in progress.

    Returns:
        True if a new fine-tuning task was scheduled; False if the trainer's
        conditions are not met, a previous run is still in progress, or an
        error occurred.
    """
    try:
        running_task = getattr(self, "_fine_tuning_task", None)
        if running_task is not None and not running_task.done():
            logger.info("Fine-tuning already in progress; not starting another run.")
            return False
        if not self.model_trainer.check_training_conditions():
            return False
        # Start fine-tuning process in background and keep a strong reference to it.
        self._fine_tuning_task = asyncio.create_task(self._run_fine_tuning_process())
        return True
    except Exception as e:
        logger.exception(f"Error in check_and_trigger_fine_tuning: {e}")
        return False
async def _run_fine_tuning_process(self):
    """
    Fine-tune the reranker, validate the result, and deploy it if validation passes.

    Blocking work (training, loading the CrossEncoder from disk, the index rebuild)
    runs through ``asyncio.to_thread`` so the event loop stays responsive.

    On success, ``self.fine_tuned_reranker`` and ``self.reranker`` both point to the
    new model, and ``self.update_embeddings_and_index`` is run.

    On any failure, ``self.reranker`` is left untouched. Failures include fine-tuning
    returning no version, a missing model path, a rejected validation, or an
    exception. Whatever model is currently serving stays in service, as the
    "Keeping current model" log message states.
    """
    try:
        logger.info("Starting fine-tuning process")
        new_version = await asyncio.to_thread(self.model_trainer.fine_tune)
        if not new_version:
            logger.error("Fine-tuning process failed")
            return

        model_path = str(self.model_trainer.get_model_path(new_version))
        if not os.path.exists(model_path):
            logger.error(f"Fine-tuned model file not found at {model_path}")
            return

        # Loading the model can be slow; keep it off the event loop.
        new_model = await asyncio.to_thread(CrossEncoder, model_path, device=self.device)

        if not await self._validate_fine_tuned_model(new_model):
            logger.warning("Fine-tuned model validation failed. Keeping current model.")
            return

        logger.info(f"Fine-tuned model validation passed. Deploying version: {new_version}")
        self.fine_tuned_reranker = new_model  # loaded with device=self.device above
        self.reranker = new_model  # switch active reranker
        # update_embeddings_and_index is synchronous and can be long.
        await asyncio.to_thread(self.update_embeddings_and_index)
        logger.info(f"Fine-tuning process completed successfully. New version: {new_version}")
    except Exception as e:
        logger.exception(f"Error during fine-tuning process: {e}")
async def _validate_fine_tuned_model(
    self,
    new_model: "CrossEncoder",
    threshold: float = 0.5,
    min_relative_improvement: float = 0.001,
    min_absolute_f1: float = 0.3,
) -> bool:
    """
    Validate the fine-tuned model's performance before deployment.

    Data and models:
        Every (query_text, candidate_text) pair from
        ``self.model_trainer.get_validation_data()`` is scored by both
        ``self.base_reranker`` and ``new_model``.

    Metrics:
        A score >= ``threshold`` counts as a positive prediction. A sample is a
        positive example only when its ``label`` equals 1.0. Precision, recall,
        F1 and accuracy are then computed for each model.

    Pass rule:
        The fine-tuned model passes when both conditions hold:
        - its F1 is at least ``min_absolute_f1``;
        - at least one metric improves over the base model by
          ``min_relative_improvement`` (relative change). A gain over a base
          value of 0 counts as an infinite relative improvement.

    NOTE(review): the default 0.5 threshold assumes ``predict`` returns
    probability-like scores (e.g. a sigmoid-activated single-output
    CrossEncoder). Confirm this for the models in use.

    Args:
        new_model: Candidate reranker exposing ``predict(list_of_pairs)``.
        threshold: Score at or above which a prediction counts as positive.
        min_relative_improvement: Minimum relative gain required in some metric.
        min_absolute_f1: Minimum F1 the fine-tuned model must reach.

    Returns:
        True if validation passes. False otherwise, including when there is no
        validation data, when score counts do not match sample counts, or on
        any exception.
    """
    try:
        validation_data = self.model_trainer.get_validation_data()
        if not validation_data:
            logger.warning("No validation data available")
            return False

        pairs = [(sample["query_text"], sample["candidate_text"]) for sample in validation_data]
        labels = [float(sample["label"]) for sample in validation_data]

        # predict() is blocking CPU/GPU work: score all pairs in one call per model
        # instead of one worker-thread hop (and one forward pass) per sample.
        base_raw = await asyncio.to_thread(self.base_reranker.predict, pairs)
        new_raw = await asyncio.to_thread(new_model.predict, pairs)
        base_scores = np.asarray(base_raw, dtype=np.float64).reshape(-1)
        new_scores = np.asarray(new_raw, dtype=np.float64).reshape(-1)

        if base_scores.size == 0 or new_scores.size == 0:
            logger.warning("No predictions generated during validation")
            return False
        if base_scores.size != len(labels) or new_scores.size != len(labels):
            # Per-pair thresholding is meaningless if a model does not return
            # exactly one score per pair (e.g. a multi-output head).
            logger.warning(
                f"Prediction count mismatch during validation: {len(labels)} samples, "
                f"{base_scores.size} base scores, {new_scores.size} fine-tuned scores"
            )
            return False

        def calculate_model_metrics(scores):
            """Precision/recall/F1/accuracy of ``scores >= threshold`` against ``labels``."""
            tp = fp = tn = fn = 0
            for score, label in zip(scores, labels):
                predicted_positive = score >= threshold
                if label == 1.0:
                    if predicted_positive:
                        tp += 1
                    else:
                        fn += 1
                elif predicted_positive:
                    fp += 1
                else:
                    tn += 1
            total = tp + fp + tn + fn
            # Prevent division by zero
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            accuracy = (tp + tn) / total if total > 0 else 0
            return {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "accuracy": accuracy
            }

        def relative_change(delta, base_value):
            """``delta`` relative to ``base_value``; infinite only for a real gain over a zero base."""
            if base_value > 0:
                return delta / base_value
            return float("inf") if delta > 0 else 0.0

        base_performance = calculate_model_metrics(base_scores)
        new_performance = calculate_model_metrics(new_scores)

        logger.info("Model validation results:")
        logger.info(f"Base model metrics: {base_performance}")
        logger.info(f"Fine-tuned model metrics: {new_performance}")

        f1_improvement = new_performance["f1"] - base_performance["f1"]
        precision_improvement = new_performance["precision"] - base_performance["precision"]
        recall_improvement = new_performance["recall"] - base_performance["recall"]
        accuracy_improvement = new_performance["accuracy"] - base_performance["accuracy"]

        relative_f1_imp = relative_change(f1_improvement, base_performance["f1"])
        relative_prec_imp = relative_change(precision_improvement, base_performance["precision"])
        relative_recall_imp = relative_change(recall_improvement, base_performance["recall"])
        relative_acc_imp = relative_change(accuracy_improvement, base_performance["accuracy"])

        # Pass if the model meets the F1 floor and improves in at least one metric.
        improved = any(
            rel >= min_relative_improvement
            for rel in (relative_f1_imp, relative_prec_imp, relative_recall_imp, relative_acc_imp)
        )
        if new_performance["f1"] >= min_absolute_f1 and improved:
            logger.info(
                f"Fine-tuned model shows improvement. Metrics changes:\n"
                f"F1: {f1_improvement:.4f} ({relative_f1_imp:.2%})\n"
                f"Precision: {precision_improvement:.4f} ({relative_prec_imp:.2%})\n"
                f"Recall: {recall_improvement:.4f} ({relative_recall_imp:.2%})\n"
                f"Accuracy: {accuracy_improvement:.4f} ({relative_acc_imp:.2%})"
            )
            return True

        logger.warning(
            f"Fine-tuned model does not meet improvement criteria.\n"
            f"F1 change: {f1_improvement:.4f} ({relative_f1_imp:.2%})\n"
            f"Precision change: {precision_improvement:.4f} ({relative_prec_imp:.2%})\n"
            f"Recall change: {recall_improvement:.4f} ({relative_recall_imp:.2%})\n"
            f"Accuracy change: {accuracy_improvement:.4f} ({relative_acc_imp:.2%})"
        )
        return False
    except Exception as e:
        logger.exception(f"Error during model validation: {e}")
        return False
def reload_fine_tuned_model(self):
    """
    Re-run the fine-tuned reranker loader and log that a reload was attempted.

    All loading work is delegated to ``self._load_fine_tuned_model()``. Whether the
    active reranker actually changes is decided by that loader, not here.
    """
    load = self._load_fine_tuned_model
    load()
    logger.info("Reloaded fine-tuned reranker model (if available).")
def get_recommendations_summary(self, msid: str, k: int = DEFAULT_K, summary: bool = True, smart_tip: bool = True):
    """
    Synchronous wrapper for recommendations with summary and smart tip. This is a simplified version for Gradio or direct calls.

    Args:
        msid: ID of the item to recommend from (passed to ``get_recommendations_by_id``).
        k: Number of recommendations requested.
        summary: When True, each document with MongoDB details gets a ``"summary"``
            value (a string, or None if no summary could be produced).
        smart_tip: When True, each document with MongoDB details gets a
            ``"smart_tip"`` value (a dict, or None when the article lacks
            ``seolocation``/``tn``/``hl``).

    Returns:
        The dict returned by ``get_recommendations_by_id``, with the entries of
        ``"retrieved_documents"`` augmented in place. If that call returned nothing
        usable, a dict with an explanatory ``"generated_response"`` and an empty
        ``"retrieved_documents"`` list is returned instead.

    Summary generation and the related-article lookup are best-effort. Their
    failures are logged and degrade to None or a fallback suggestion rather than
    failing the whole response.
    """
    # Get base recommendations by msid
    recommendations_data = self.get_recommendations_by_id(msid, k)
    if not recommendations_data or "retrieved_documents" not in recommendations_data:
        return {
            "generated_response": f"No recommendations found for item ID '{msid}'.",
            "retrieved_documents": []
        }
    retrieved_docs = recommendations_data["retrieved_documents"]
    if not retrieved_docs:
        return recommendations_data

    from src.database.mongodb import mongodb

    # Fetch only the article fields needed for the requested enrichments.
    articles_details_map = {}
    doc_ids_to_fetch = [doc["id"] for doc in retrieved_docs if doc.get("id")]
    if doc_ids_to_fetch and (summary or smart_tip):
        projection = {"_id": 0, "id": 1}
        if summary:
            projection.update({"story": 1, "syn": 1})
        if smart_tip:
            projection.update({"seolocation": 1, "tn": 1, "hl": 1})
        for article in mongodb.news_collection.find({"id": {"$in": doc_ids_to_fetch}}, projection):
            if not article.get("id"):
                continue
            # Summarize the synopsis when the full story is missing or empty.
            if summary and not article.get("story") and article.get("syn"):
                article["story"] = article["syn"]
            articles_details_map[article["id"]] = article

    def _generate_summary(article_data):
        """Return the article's summary points as one string, or None (best-effort)."""
        try:
            from src.test_summarize import get_summary_points
            story = article_data.get("story", "")
            if not story:
                return None
            summary_points = get_summary_points(story)
            if isinstance(summary_points, list):
                return " ".join(summary_points) if summary_points else None
            if isinstance(summary_points, str):
                return summary_points if summary_points.strip() else None
            return None
        except Exception as e:
            logger.warning(f"Summary generation failed for article {article_data.get('id')}: {e}")
            return None

    def _generate_smart_tip(article_data):
        """Build a smart-tip dict of up to 3 related articles, or None if key fields are missing."""
        seolocation = article_data.get("seolocation")
        title = article_data.get("tn")
        headline = article_data.get("hl")
        if not all([seolocation, title, headline]):
            return None
        # The title is literal text, not a pattern. Escape it so characters such as
        # "+", "(" or "?" cannot make the $regex invalid or match unintended text.
        topic_pattern = re.escape(str(title))
        query = {
            "$or": [
                {"tn": {"$regex": topic_pattern, "$options": "i"}},
                {"hl": {"$regex": topic_pattern, "$options": "i"}}
            ]
        }
        if article_data.get("id"):
            query["id"] = {"$ne": article_data["id"]}
        try:
            related_articles = list(
                mongodb.news_collection.find(query, {"hl": 1, "seolocation": 1, "tn": 1, "_id": 0}).limit(3)
            )
        except Exception as e:
            logger.warning(f"Related-article lookup failed for article {article_data.get('id')}: {e}")
            related_articles = []
        suggestions = [
            {"label": rel_article.get("hl", ""), "url": rel_article.get("seolocation", "")}
            for rel_article in related_articles
            if rel_article.get("hl") and rel_article.get("seolocation")
        ]
        if not suggestions:
            # Fall back to suggesting the article itself.
            suggestions = [{
                "label": headline,
                "url": seolocation
            }]
        return {
            "title": f"\U0001F50D Smart Tip: {title}",
            "description": "You might also be interested in:",
            "suggestions": suggestions
        }

    processed_documents = []
    for doc in retrieved_docs:
        article_data = articles_details_map.get(doc.get("id"))
        if summary and article_data:
            doc["summary"] = _generate_summary(article_data)
        if smart_tip and article_data:
            doc["smart_tip"] = _generate_smart_tip(article_data)
        processed_documents.append(doc)
    recommendations_data["retrieved_documents"] = processed_documents
    return recommendations_data
def get_recommendations_user_feedback(self, user_id: str, msid: str, clicked_msid: str, k: int = DEFAULT_K):
    """
    Synchronous wrapper for user feedback recommendations. Returns recommendations based on clicked items.

    Args:
        user_id: Accepted for interface compatibility; not used by this method.
        msid: Accepted for interface compatibility; not used by this method.
        clicked_msid: One clicked item ID, or several separated by commas.
            Blank entries are ignored, and None is treated as no clicks.
        k: Number of recommendations fetched per clicked item, and the maximum
            number returned overall.

    Returns:
        dict with two keys:
        - ``"retrieved_documents"``: at most ``k`` documents, one per document
          ``id``, ordered by descending ``"score"``. When a document is
          recommended for several clicked items, its highest-scoring copy is
          kept.
        - ``"generated_response"``: a human-readable message.
    """
    actual_clicked_msids = [s.strip() for s in (clicked_msid or "").split(',') if s.strip()]

    def _score(doc):
        # Missing or None scores are treated as 0.0.
        return doc.get("score") or 0.0

    best_doc_by_id = {}
    for c_msid in actual_clicked_msids:
        # get_recommendations_by_id may return a falsy value when nothing is found.
        result = self.get_recommendations_by_id(c_msid, k) or {}
        for doc in result.get("retrieved_documents") or []:
            doc_id = doc.get("id")
            if doc_id is None:
                continue
            kept = best_doc_by_id.get(doc_id)
            if kept is None or _score(doc) > _score(kept):
                best_doc_by_id[doc_id] = doc

    if not best_doc_by_id:
        return {"retrieved_documents": [], "generated_response": "No recommendations found for the clicked items."}

    # sorted() is stable, so equal scores keep first-seen order.
    final_retrieved_documents = sorted(best_doc_by_id.values(), key=_score, reverse=True)[:k]
    return {
        "retrieved_documents": final_retrieved_documents,
        "generated_response": f"Top {len(final_retrieved_documents)} recommendations based on your recent clicks on: {', '.join(actual_clicked_msids)}."
    }
# Shared module-level instance: constructed once when this module is first imported,
# and imported by other modules as ``recommender``.
recommender: RecoRecommender = RecoRecommender()