# Qpsychometric / db_module / db_handler.py
# Fadi12's picture
# Qpsychometric Space
# 59a8a7c
import pandas as pd
from pymongo import MongoClient
from pathlib import Path, PurePosixPath
import os
class MongoDBHandler:
    """
    Convenience wrapper around one MongoDB database for moving CSV data in and out.

    CSV files located under ``base_dir`` can be loaded into collections, and
    collections can be exported back to CSV files or to a pandas DataFrame.
    """

    # Substrings identifying collections that are exported into their own
    # sub-folder of the export directory.
    _QUESTIONNAIRE_MARKERS = ("QMNLI", "QMLM")

    def __init__(self, connection_string, db_name, base_dir):
        """
        Initialize the MongoDB handler.

        :param connection_string: MongoDB connection string.
        :param db_name: MongoDB database name.
        :param base_dir: Directory containing the CSV files to upload.
        """
        self.client = MongoClient(connection_string)
        self.db = self.client[db_name]
        self.base_dir = Path(base_dir)

    @staticmethod
    def _read_csv_records(csv_path):
        """Read a CSV file (UTF-8, optional BOM) into a list of row dictionaries."""
        return pd.read_csv(csv_path, encoding="utf-8-sig").to_dict(orient="records")

    def _relative_csv_path(self, collection_name):
        """Return the CSV path of a collection, relative to the export directory."""
        is_questionnaire = any(
            marker in collection_name for marker in self._QUESTIONNAIRE_MARKERS
        )
        # Collections are named "<folder>_<file>" on upload, so only the FIRST
        # underscore separates the folder; later ones belong to the file name.
        folder, separator, file_stem = collection_name.partition("_")
        if is_questionnaire and separator:
            return os.path.join(folder, file_stem + ".csv")
        return collection_name + ".csv"

    def insert_evaluation_files(self, identifiers=("model_version_id", "ordinal")):
        """
        Upload every CSV file found in the direct sub-folders of ``base_dir``.

        Each ``<folder>/<name>.csv`` is loaded into the collection
        ``<folder>_<name>``. Rows whose identifier values are already present
        in that collection are skipped.

        :param identifiers: Column names that together identify a row.
        """
        folders = sorted(path for path in self.base_dir.iterdir() if path.is_dir())
        for folder_path in folders:
            questionnaires_type = folder_path.name
            for questionnaire_path in sorted(folder_path.iterdir()):
                # Ignore anything that is not a CSV file (nested folders,
                # OS/editor metadata, ...) instead of trying to parse it.
                if not questionnaire_path.is_file() or questionnaire_path.suffix != ".csv":
                    continue
                collection_name = f"{questionnaires_type}_{questionnaire_path.stem}"
                self.create_empty_collection(collection_name)
                self.insert_new_rows(
                    questionnaire_path,
                    collection_name=collection_name,
                    identifiers=identifiers,
                )

    def insert_meta_data(self, identifiers=("model_version_id",)):
        """
        Upload ``models_meta_data.csv`` from ``base_dir`` into ``models_meta_data``.

        :param identifiers: Column names that together identify a row.
        """
        meta_data_file = self.base_dir / "models_meta_data.csv"
        self.create_empty_collection("models_meta_data")
        self.insert_new_rows(meta_data_file, collection_name="models_meta_data", identifiers=identifiers)

    def insert_errors(self, identifiers=("model_version_id",)):
        """
        Upload ``models_errors.csv`` from ``base_dir`` into ``models_errors``.

        :param identifiers: Column names that together identify a row.
        """
        errors_file = self.base_dir / "models_errors.csv"
        self.create_empty_collection("models_errors")
        self.insert_new_rows(errors_file, collection_name="models_errors", identifiers=identifiers)

    def insert_csv_to_mongo(self, csv_path, collection_name):
        """
        Insert data from a CSV file into a specified MongoDB collection.

        Every row is inserted; no check for already stored rows is made.

        :param csv_path: Path to the CSV file.
        :param collection_name: The name of the MongoDB collection.
        """
        collection = self.db[collection_name]
        documents = self._read_csv_records(csv_path)
        if documents:
            collection.insert_many(documents)
            print(f"{len(documents)} documents inserted into {collection_name}.")
        else:
            print("No data found in the CSV.")

    def fetch_documents(self, collection_name, query=None):
        """
        Fetch documents from a specified collection with an optional query.

        :param collection_name: The name of the MongoDB collection.
        :param query: Optional query to filter the documents; ``None`` matches all.
        :return: A list of documents.
        """
        collection = self.db[collection_name]
        return list(collection.find({} if query is None else query))

    def insert_new_rows(self, csv_path, collection_name, identifiers):
        """
        Insert new rows from a CSV file into a specified MongoDB collection.

        Only new rows (those whose identifier values are not already in the
        collection) will be inserted.

        :param csv_path: Path to the CSV file.
        :param collection_name: The name of the MongoDB collection.
        :param identifiers: Column names that together identify a row.
        :raises KeyError: If a CSV row lacks one of the identifier columns.
        """
        collection = self.db[collection_name]
        documents = self._read_csv_records(csv_path)

        # Only the identifier fields are needed to recognise stored rows, so
        # avoid transferring whole documents. ``.get`` tolerates stored
        # documents that lack one of the fields.
        projection = {identifier: 1 for identifier in identifiers}
        existing_keys = {
            tuple(doc.get(identifier) for identifier in identifiers)
            for doc in collection.find({}, projection)
        }

        new_documents = [
            document
            for document in documents
            if tuple(document[identifier] for identifier in identifiers) not in existing_keys
        ]
        if new_documents:
            collection.insert_many(new_documents)
            print(f"Inserted {len(new_documents)} new documents in {collection_name}.")
        else:
            print(f"No new documents to insert to {collection_name}.")

    def create_empty_collection(self, collection_name):
        """
        Explicitly create an empty collection unless it already exists.

        :param collection_name: The name of the collection to create.
        """
        if collection_name not in self.db.list_collection_names():
            self.db.create_collection(collection_name)
            print(f"Collection '{collection_name}' created.")
        else:
            print(f"Collection '{collection_name}' already exists.")

    def delete_collection(self, collection_name):
        """
        Remove all documents from a collection after interactive confirmation.

        The collection itself is not dropped. Any answer other than the exact
        string ``yes`` leaves the collection untouched.

        :param collection_name: The name of the MongoDB collection.
        """
        ans = input("Are you sure? (yes/no)")
        if ans == "yes":
            collection = self.db[collection_name]
            result = collection.delete_many({})
            print(f"Deleted {result.deleted_count} documents.")

    def clear_all_collections(self):
        """
        Clear all collections in the database after interactive confirmation.

        This method does not drop the collections themselves, it only removes
        their documents. Any answer other than the exact string ``yes``
        leaves the database untouched.
        """
        ans = input("Are you sure? (yes/no)")
        if ans == "yes":
            for collection_name in self.db.list_collection_names():
                result = self.db[collection_name].delete_many({})
                print(f"Cleared {result.deleted_count} documents from collection: {collection_name}")

    def count_documents(self, collection_name):
        """
        Count how many documents are in a given MongoDB collection.

        :param collection_name: The name of the MongoDB collection.
        :return: The number of documents in the collection.
        """
        collection = self.db[collection_name]
        document_count = collection.count_documents({})
        print(f"{document_count} documents in {collection_name}")
        return document_count

    def check_duplicates(self, collection_name, field_name):
        """
        Find documents that share the same value of a specific field.

        :param collection_name: The name of the MongoDB collection.
        :param field_name: The field name to check for duplicates (e.g., 'commit_hash').
        :return: A list with one entry per duplicated value; each entry is the
            list of documents sharing that value. The printed number is the
            number of such groups.
        """
        collection = self.db[collection_name]
        duplicates = collection.aggregate([
            # Group documents by the field value, keeping the full documents.
            {"$group": {
                "_id": f"${field_name}",
                "count": {"$sum": 1},
                "docs": {"$push": "$$ROOT"}
            }},
            # Keep only the values that occur more than once.
            {"$match": {
                "count": {"$gt": 1}
            }}
        ])
        duplicate_docs = [duplicate['docs'] for duplicate in duplicates]
        print(f"{len(duplicate_docs)} duplicate documents in {collection_name}")
        return duplicate_docs

    def export_collection_to_csv(self, collection_name, output_path="./fetched_model_logs/model_logs/"):
        """
        Export documents from a MongoDB collection to a CSV file.

        If "all" is passed as the collection_name, export all collections.
        Collections whose name contains ``QMNLI`` or ``QMLM`` are written to
        ``<output_path>/<folder>/<file>.csv``, where the collection name is
        split at its first underscore; every other collection is written to
        ``<output_path>/<collection_name>.csv``. Empty collections are skipped.

        :param collection_name: The name of the MongoDB collection or "all".
        :param output_path: Directory that receives the CSV files; a trailing
            slash is optional.
        """
        output_path = os.fspath(output_path) or "."
        os.makedirs(output_path, exist_ok=True)
        for marker in self._QUESTIONNAIRE_MARKERS:
            os.makedirs(os.path.join(output_path, marker), exist_ok=True)

        if collection_name == "all":
            names = self.db.list_collection_names()
        else:
            names = [collection_name]

        for name in names:
            document_list = list(self.db[name].find({}, {'_id': 0}))
            if not document_list:
                print(f"No documents found in the collection '{name}'.")
                continue
            csv_path = os.path.join(output_path, self._relative_csv_path(name))
            # The target folder may differ from the two pre-created ones.
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            pd.DataFrame(document_list).to_csv(csv_path, index=False, encoding="utf-8-sig")
            print(f"Exported {len(document_list)} documents to {csv_path}")

    def export_qmlm_qmnli_collections_to_csv(self):
        """
        Concatenate all documents from collections that have 'QMLM' or 'QMNLI' in their name.

        Despite its name, this method writes no file; the caller receives the
        combined data and decides what to do with it.

        :return: A DataFrame holding every matching document (without
            ``_id``), or ``None`` when no document was found.
        """
        all_documents = []
        for collection_name in self.db.list_collection_names():
            if any(marker in collection_name for marker in self._QUESTIONNAIRE_MARKERS):
                all_documents.extend(self.db[collection_name].find({}, {'_id': 0}))
        if all_documents:
            return pd.DataFrame(all_documents)
        print("No matching documents found in the QMLM or QMNLI collections.")
        return None

    def close_connection(self):
        """Close the connection to the MongoDB server."""
        self.client.close()