aia / backend /functions.py
MickMick102's picture
initial commit
1abb43c
import os
import logging
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient # IMPORT AsyncMongoClient
from pythainlp.tokenize import word_tokenize # Moved import here
import models # Keep standard import
import asyncio
from typing import Optional, Dict
from utils import get_device
# import time # No longer needed for reranker
# import numpy as np # No longer needed for reranker
# import onnxruntime as ort # No longer needed for reranker
# from transformers import AutoTokenizer # No longer needed for reranker
# Read a local .env file only when get_device() reports "mps"
# (presumably local macOS development — confirm); on every other device the
# process environment is used unchanged. This must run before the
# environment lookups below.
if get_device() == "mps":
    load_dotenv(override=True)

# Module-wide logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# MongoDB settings. DATABASE_URL is None when MONGO_URL is not set.
# Never hard-code connection strings (they may carry credentials) here.
DATABASE_URL = os.environ.get("MONGO_URL")
DB_NAME = "homeshopping"

# Atlas index names used for the vector and keyword searches.
DEFAULT_VECTOR_INDEX = "default"  # TODO: make configurable
DEFAULT_KEYWORD_INDEX = "default"  # TODO: make configurable
class MongoHybridSearch:
    """Hybrid (vector + keyword) search over a MongoDB Atlas collection.

    A query is answered by an Atlas ``$vectorSearch`` over document embeddings
    and an Atlas Search ``$search`` over the ``content_tokenized`` field; the
    two ranked lists are merged with weighted reciprocal rank fusion (WRRF).
    """

    def __init__(self, database_name=DB_NAME, mongo_uri=DATABASE_URL,
                 collection_name="homeshopping"):
        """
        Initialize the MongoDB client, target collection, LLM analyzer and embedder.

        Args:
            database_name: MongoDB database name.
            mongo_uri: MongoDB connection string. If it is empty/None the driver
                falls back to its default host, which is usually a
                misconfiguration, so a warning is logged.
            collection_name: Collection to search (default "homeshopping").

        Raises:
            Exception: Any error raised while building the client, analyzer or
                embedder is logged and re-raised so the app does not start with
                a bad configuration.
        """
        if not mongo_uri:
            logger.warning(
                "No MongoDB URI configured (is MONGO_URL set?); "
                "the driver will fall back to its default host."
            )
        try:
            self.client = AsyncIOMotorClient(mongo_uri)
            self.database = self.client[database_name]
            self.collection = self.database[collection_name]
            # Not used by the search path in this class; kept because other
            # code may rely on the attribute.
            self.llm_analyzer = models.LLMFinanceAnalyzer()
            self.embedder = models.Embedder()
            logger.info("MongoHybridSearch initialized successfully.")
        except Exception as e:
            logger.error(f"Failed to initialize MongoHybridSearch: {e}")
            raise

    async def search_documents(self, query: str, top_k: int = 100,
                               exact_top_k: int = 17) -> list[str]:
        """
        Run a hybrid search for ``query`` against the configured collection.

        Args:
            query: Free-text query; the keyword leg tokenizes it with PyThaiNLP.
            top_k: Number of candidates fetched from each search method.
            exact_top_k: Maximum number of fused documents returned.

        Returns:
            list[str]: Content strings of the best-ranked documents, best first.
            An empty list if nothing matches or an error occurs.
        """
        try:
            return await self.atlas_hybrid_search(
                collection_name=self.collection,
                query=query,
                top_k=top_k,
                exact_top_k=exact_top_k,
                vector_index_name=DEFAULT_VECTOR_INDEX,
                keyword_index_name=DEFAULT_KEYWORD_INDEX,
            )
        except Exception as e:
            logger.error(f"Error in search_documents: {e}")
            return []

    async def atlas_hybrid_search(self, collection_name: object, query: str, top_k: int,
                                  exact_top_k: int, vector_index_name: str,
                                  keyword_index_name: str,
                                  ) -> list[str]:
        """
        Perform hybrid search using Atlas Vector Search and Atlas Search (keyword).

        The query is embedded and sent to ``$vectorSearch`` while, concurrently,
        its PyThaiNLP tokens are sent to ``$search``. Both ranked lists are fused
        with :meth:`weighted_reciprocal_rank`.

        Args:
            collection_name: Currently ignored; the search always runs against
                ``self.collection``. Kept for backward compatibility.
            query: Query text.
            top_k: Number of candidates fetched from each search method.
            exact_top_k: Maximum number of fused documents returned.
            vector_index_name: Name of the Atlas vector index.
            keyword_index_name: Name of the Atlas Search (full-text) index.

        Returns:
            Content strings of the fused documents, best first. An empty list if
            the query cannot be embedded, nothing matches, or an error occurs.
        """
        try:
            query_vector = await self.embedder.embed(query, "query")
            # Validate before calling len() on it: len(None) would raise, and a
            # len()-based check also works for array-like embeddings whose
            # truthiness is ambiguous.
            if query_vector is None or len(query_vector) == 0:
                logger.error(f"Failed to get embedding for query: {query}")
                return []
            logger.debug(f"Query embedding dimension: {len(query_vector)}")

            # The two searches are independent, so run them concurrently.
            vector_results, keyword_results = await asyncio.gather(
                self._vector_search(query, query_vector, top_k, vector_index_name),
                self._keyword_search(query, top_k, keyword_index_name),
            )
            logger.debug(f"Vector results: {len(vector_results)}, Keyword results: {len(keyword_results)}")

            vec_docs = self._to_fusion_docs(vector_results)
            key_docs = self._to_fusion_docs(keyword_results)
            # Slicing already copes with fewer than exact_top_k fused documents.
            fused_documents = self.weighted_reciprocal_rank([vec_docs, key_docs], top_k)[:exact_top_k]

            if not fused_documents:
                logger.info("No documents to rank after fusion.")
                return []
            return [doc["content"] for doc in fused_documents]
        except Exception as e:
            logger.error(f"Error in atlas_hybrid_search for query '{query}': {e}", exc_info=True)
            return []

    async def _vector_search(self, query: str, query_vector, top_k: int,
                             index_name: str) -> list[dict]:
        """Run Atlas ``$vectorSearch`` on the ``embedding`` field; return up to ``top_k`` raw docs."""
        pipeline = [
            {
                "$vectorSearch": {
                    "queryVector": query_vector,
                    "path": "embedding",
                    "numCandidates": 10000,  # TODO: make configurable
                    "limit": top_k,
                    "index": index_name,
                }
            },
            {"$project": {"_id": 1, "content": 1, "score": {"$meta": "vectorSearchScore"}}},
        ]
        results = await self.collection.aggregate(pipeline).to_list(length=top_k)
        logger.info(f"Vector search found {len(results)} results for query: '{query}'")
        return results

    async def _keyword_search(self, query: str, top_k: int, index_name: str) -> list[dict]:
        """Tokenize ``query`` with PyThaiNLP (newmm) and run Atlas ``$search`` on ``content_tokenized``."""
        query_tokens = word_tokenize(query, engine="newmm", keep_whitespace=False)
        logger.info(f"Keyword search tokens: {query_tokens}")
        if not query_tokens:
            # Nothing to match on; skip the round-trip to Atlas.
            return []
        pipeline = [
            {
                "$search": {
                    "index": index_name,
                    "text": {
                        "query": query_tokens,
                        "path": "content_tokenized",
                    },
                }
            },
            {"$project": {"_id": 1, "content": 1, "score": {"$meta": "searchScore"}}},
            {"$limit": top_k},
        ]
        results = await self.collection.aggregate(pipeline).to_list(length=top_k)
        logger.info(f"Keyword search found {len(results)} results for query: '{query}'")
        return results

    @staticmethod
    def _to_fusion_docs(results: list[dict]) -> list[dict]:
        """Normalize raw aggregation results to WRRF input dicts ``{"_id": str, "content": ...}``.

        Missing or null ``content`` is passed on as None so that
        :meth:`weighted_reciprocal_rank` skips the document rather than
        returning an empty or "None" string; other non-string content is
        stringified.
        """
        docs = []
        for raw in results:
            doc_id = str(raw["_id"])
            content = raw.get("content")
            if content is not None and not isinstance(content, str):
                logger.warning(f"Document content is not a string (ID: {doc_id}), converting.")
                content = str(content)
            docs.append({"_id": doc_id, "content": content})
        return docs

    def weighted_reciprocal_rank(self, doc_lists: list[list[dict]], top_k: int,
                                 weights: Optional[list[float]] = None,
                                 c: float = 60) -> list[dict]:
        """
        Fuse several ranked lists with Weighted Reciprocal Rank Fusion (WRRF).

        A document at 1-based rank ``r`` in list ``i`` contributes
        ``weights[i] * 1 / (r + c)`` to its fused score. Documents are identified
        by their ``content`` string: identical content under different ``_id``
        values is merged into one entry (the first dict seen is returned), and
        within one list only its best (first) rank counts.

        Args:
            doc_lists: One ranked list per search method, best first. Each
                document is a dict with at least ``_id`` and ``content``.
            top_k: Maximum number of documents to return.
            weights: Per-list weights. Defaults to 1.0 for every list; if the
                length does not match ``doc_lists``, equal weights are used.
            c: Rank-damping constant; larger values flatten the rank curve.

        Returns:
            Fused documents sorted by descending score, at most ``top_k``.
            An empty list on invalid input or unexpected errors.
        """
        try:
            if not doc_lists or not all(isinstance(dl, list) for dl in doc_lists):
                logger.warning("WRRF called with invalid doc_lists.")
                return []

            if weights is None:
                weights = [1.0] * len(doc_lists)
            elif len(weights) != len(doc_lists):
                # Log before replacing so the message reports the real mismatch.
                logger.warning(
                    f"Number of doc lists ({len(doc_lists)}) != number of weights ({len(weights)}). Using equal weights."
                )
                weights = [1.0] * len(doc_lists)

            rrf_scores = {}  # content -> {'score': float, 'doc': dict}
            for doc_list, weight in zip(doc_lists, weights):
                # Same key as the fusion key, so duplicated content returned by
                # one source is not counted twice from that source.
                seen_in_list = set()
                for rank, doc in enumerate(doc_list, start=1):
                    doc_id = doc.get("_id")
                    content = doc.get("content")
                    if not doc_id or content is None:
                        logger.warning(f"Skipping doc with missing ID or content in WRRF: {doc}")
                        continue
                    if not isinstance(content, str):
                        content = str(content)
                        doc["content"] = content
                    if content in seen_in_list:
                        continue
                    seen_in_list.add(content)

                    rank_score = weight * (1.0 / (rank + c))
                    if content in rrf_scores:
                        rrf_scores[content]['score'] += rank_score
                    else:
                        rrf_scores[content] = {'score': rank_score, 'doc': doc}

            # sorted() is stable, so ties keep first-seen order.
            ranked = sorted(rrf_scores.values(), key=lambda entry: entry['score'], reverse=True)
            return [entry['doc'] for entry in ranked[:top_k]]
        except Exception as e:
            logger.error(f"Error in weighted_reciprocal_rank: {e}", exc_info=True)
            return []
# Manual smoke test: run this module directly to query the configured collection.
if __name__ == "__main__":

    async def _smoke_test():
        """Send one sample query through MongoHybridSearch and print the outcome."""
        print("Testing MongoHybridSearch...")
        try:
            engine = MongoHybridSearch()
            # Mixed Thai/English sample query (roughly: "what products are there?").
            sample_query = 'มี product ไรบ้าง'
            hits = await engine.search_documents(sample_query)
            print("\nSearch Results:")
            print(hits if hits else "Search failed or returned no results.")
        except Exception as exc:
            print(f"An error occurred during testing: {exc}")

    # asyncio.run creates an event loop, runs the coroutine, then closes the loop.
    asyncio.run(_smoke_test())