Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| from pymongo import MongoClient | |
| from pathlib import Path, PurePosixPath | |
| import os | |
class MongoDBHandler:
    """
    Load evaluation CSV files into MongoDB and export collections back to CSV.

    Layout read under ``base_dir``: one sub-folder per questionnaire type
    (e.g. ``QMNLI``, ``QMLM``) holding CSV files, plus ``models_meta_data.csv``
    and ``models_errors.csv`` at the top level.
    """

    # Questionnaire types whose collections ("<TYPE>_<name>") are exported into
    # a sub-folder of the same name by export_collection_to_csv.
    _NESTED_EXPORT_TYPES = ("QMNLI", "QMLM")

    def __init__(self, connection_string, db_name, base_dir):
        """
        Initialize the MongoDB handler.

        :param connection_string: MongoDB connection string.
        :param db_name: MongoDB database name.
        :param base_dir: Directory containing the CSV files to import.
        """
        self.client = MongoClient(connection_string)
        self.db = self.client[db_name]
        self.base_dir = Path(base_dir)

    def insert_evaluation_files(self, identifiers=("model_version_id", "ordinal")):
        """
        Import every CSV file found in the sub-folders of ``base_dir``.

        File ``<base_dir>/<type>/<name>.csv`` is loaded into collection
        ``<type>_<name>``; only rows whose identifier key is not already taken
        are inserted (see :meth:`insert_new_rows`).

        :param identifiers: Column names that together identify a row.
        """
        folders = sorted(path for path in self.base_dir.iterdir() if path.is_dir())
        for folder_path in folders:
            questionnaires_type = folder_path.name
            # Only regular CSV files: stray files (e.g. .DS_Store) or nested
            # folders would otherwise make pd.read_csv fail mid-import.
            csv_files = sorted(
                path
                for path in folder_path.iterdir()
                if path.is_file() and path.suffix.lower() == ".csv"
            )
            for questionnaire_path in csv_files:
                collection_name = f"{questionnaires_type}_{questionnaire_path.stem}"
                self.create_empty_collection(collection_name)
                self.insert_new_rows(
                    questionnaire_path,
                    collection_name=collection_name,
                    identifiers=identifiers,
                )

    def insert_meta_data(self, identifiers=("model_version_id",)):
        """
        Import ``<base_dir>/models_meta_data.csv`` into ``models_meta_data``.

        :param identifiers: Column names that together identify a row.
        """
        meta_data_file = self.base_dir / "models_meta_data.csv"
        self.create_empty_collection("models_meta_data")
        self.insert_new_rows(meta_data_file, collection_name="models_meta_data", identifiers=identifiers)

    def insert_errors(self, identifiers=("model_version_id",)):
        """
        Import ``<base_dir>/models_errors.csv`` into ``models_errors``.

        :param identifiers: Column names that together identify a row.
        """
        errors_file = self.base_dir / "models_errors.csv"
        self.create_empty_collection("models_errors")
        self.insert_new_rows(errors_file, collection_name="models_errors", identifiers=identifiers)

    def insert_csv_to_mongo(self, csv_path, collection_name):
        """
        Insert all rows of a CSV file into a MongoDB collection (no de-duplication).

        :param csv_path: Path to the CSV file.
        :param collection_name: The name of the MongoDB collection.
        """
        collection = self.db[collection_name]
        df = pd.read_csv(csv_path, encoding="utf-8-sig")
        documents = df.to_dict(orient='records')
        if documents:
            collection.insert_many(documents)
            print(f"{len(documents)} documents inserted into {collection_name}.")
        else:
            print("No data found in the CSV.")

    def fetch_documents(self, collection_name, query=None):
        """
        Fetch documents from a specified collection with an optional query.

        :param collection_name: The name of the MongoDB collection.
        :param query: Optional filter document; ``None`` matches everything.
        :return: A list of documents.
        """
        collection = self.db[collection_name]
        return list(collection.find({} if query is None else query))

    def insert_new_rows(self, csv_path, collection_name, identifiers):
        """
        Insert only the CSV rows whose identifier key is not yet in the collection.

        The key of a row is the tuple of its ``identifiers`` column values.
        Rows whose key already exists in the collection are skipped, and when a
        key repeats within the CSV only its first row is inserted.

        :param csv_path: Path to the CSV file.
        :param collection_name: The name of the MongoDB collection.
        :param identifiers: Column names that together identify a row.
        :raises KeyError: If a CSV row lacks one of the identifier columns.
        """
        collection = self.db[collection_name]
        df = pd.read_csv(csv_path, encoding='utf-8-sig')
        documents = df.to_dict(orient='records')
        # Only the identifier fields are needed to build the key set.
        projection = {identifier: 1 for identifier in identifiers}
        existing_documents = collection.find({}, projection)
        new_documents = self._select_new_documents(documents, existing_documents, identifiers)
        if new_documents:
            collection.insert_many(new_documents)
            print(f"Inserted {len(new_documents)} new documents in {collection_name}.")
        else:
            print(f"No new documents to insert to {collection_name}.")

    @staticmethod
    def _select_new_documents(documents, existing_documents, identifiers):
        """
        Return the ``documents`` whose identifier key is not already taken, in order.

        Keys of ``existing_documents`` are taken (a missing field counts as
        ``None`` instead of raising), and so is every key accepted earlier in
        ``documents``.
        """
        taken_keys = {
            tuple(doc.get(identifier) for identifier in identifiers)
            for doc in existing_documents
        }
        new_documents = []
        for document in documents:
            document_key = tuple(document[identifier] for identifier in identifiers)
            if document_key in taken_keys:
                continue
            taken_keys.add(document_key)
            new_documents.append(document)
        return new_documents

    def create_empty_collection(self, collection_name):
        """
        Explicitly create an empty collection if it does not exist yet.

        :param collection_name: The name of the collection to create.
        """
        if collection_name not in self.db.list_collection_names():
            self.db.create_collection(collection_name)
            print(f"Collection '{collection_name}' created.")
        else:
            print(f"Collection '{collection_name}' already exists.")

    def delete_collection(self, collection_name):
        """
        After interactive confirmation, delete all documents of a collection.

        The collection itself is kept; only its documents are removed.

        :param collection_name: The name of the MongoDB collection.
        """
        ans = input("Are you sure? (yes/no)")
        if ans == "yes":
            collection = self.db[collection_name]
            result = collection.delete_many({})
            print(f"Deleted {result.deleted_count} documents.")

    def clear_all_collections(self):
        """
        Clears all collections in the MongoDB database by removing all documents.

        Asks for interactive confirmation first. This method does not drop the
        collections themselves, only clears them of any documents.
        """
        ans = input("Are you sure? (yes/no)")
        if ans == "yes":
            for collection_name in self.db.list_collection_names():
                result = self.db[collection_name].delete_many({})
                print(f"Cleared {result.deleted_count} documents from collection: {collection_name}")

    def count_documents(self, collection_name):
        """
        Count how many documents are in a given MongoDB collection.

        :param collection_name: The name of the MongoDB collection.
        :return: The number of documents in the collection (also printed).
        """
        collection = self.db[collection_name]
        document_count = collection.count_documents({})
        print(f"{document_count} documents in {collection_name}")
        return document_count

    def check_duplicates(self, collection_name, field_name):
        """
        Find groups of documents that share the same value of ``field_name``.

        Documents missing the field are grouped together under ``None``.

        :param collection_name: The name of the MongoDB collection.
        :param field_name: The field name to check for duplicates (e.g., 'commit_hash').
        :return: One list of documents per duplicated value.
        """
        collection = self.db[collection_name]
        pipeline = [
            {"$group": {
                "_id": f"${field_name}",
                "count": {"$sum": 1},
                "docs": {"$push": "$$ROOT"},
            }},
            {"$match": {
                "count": {"$gt": 1},
            }},
        ]
        duplicate_docs = [group['docs'] for group in collection.aggregate(pipeline)]
        print(f"{len(duplicate_docs)} duplicate documents in {collection_name}")
        return duplicate_docs

    @classmethod
    def _export_relative_path(cls, collection_name):
        """
        Map a collection name to its CSV path relative to the export directory.

        ``<TYPE>_<name>`` with TYPE in ``_NESTED_EXPORT_TYPES`` maps to
        ``<TYPE>/<name>.csv``. Only the first underscore is the separator,
        mirroring how insert_evaluation_files builds the name. Everything else
        maps to ``<collection_name>.csv``.
        """
        prefix, separator, rest = collection_name.partition("_")
        if separator and rest and prefix in cls._NESTED_EXPORT_TYPES:
            return Path(prefix) / f"{rest}.csv"
        return Path(f"{collection_name}.csv")

    def export_collection_to_csv(self, collection_name, output_path="./fetched_model_logs/model_logs/"):
        """
        Export documents from a MongoDB collection to a CSV file.

        If "all" is passed as the collection_name, export all collections.
        QMNLI/QMLM questionnaire collections are written into the matching
        sub-folder of ``output_path``; other collections go directly into it.

        :param collection_name: The name of the MongoDB collection or "all".
        :param output_path: Output directory (trailing slash optional).
        """
        output_dir = Path(output_path)
        output_dir.mkdir(parents=True, exist_ok=True)
        for type_name in self._NESTED_EXPORT_TYPES:
            (output_dir / type_name).mkdir(parents=True, exist_ok=True)

        if collection_name == "all":
            collection_names = self.db.list_collection_names()
        else:
            collection_names = [collection_name]

        for name in collection_names:
            document_list = list(self.db[name].find({}, {'_id': 0}))
            if not document_list:
                print(f"No documents found in the collection '{name}'.")
                continue
            csv_path = output_dir / self._export_relative_path(name)
            pd.DataFrame(document_list).to_csv(csv_path, index=False, encoding="utf-8-sig")
            print(f"Exported {len(document_list)} documents to {csv_path}")

    def export_qmlm_qmnli_collections_to_csv(self):
        """
        Concatenate all documents from collections whose name contains 'QMLM' or 'QMNLI'.

        Despite the name, nothing is written to disk; the caller receives the
        combined DataFrame (``_id`` excluded) and may save it as needed.

        :return: A DataFrame of all matching documents, or ``None`` if there are none.
        """
        all_documents = []
        for collection_name in self.db.list_collection_names():
            if "QMLM" in collection_name or "QMNLI" in collection_name:
                all_documents.extend(self.db[collection_name].find({}, {'_id': 0}))
        if all_documents:
            return pd.DataFrame(all_documents)
        print("No matching documents found in the QMLM or QMNLI collections.")
        return None

    def close_connection(self):
        """Close the connection to the MongoDB server."""
        self.client.close()