# Spaces: Runtime error
# File size: 9,713 Bytes
# 59a8a7c
import pandas as pd
from pymongo import MongoClient
from pathlib import Path, PurePosixPath
import os
class MongoDBHandler:
    """
    Helper around a single MongoDB database for loading CSV files into
    collections and exporting collections back to CSV / DataFrames.

    Most methods report what they did with ``print``.
    """

    def __init__(self, connection_string, db_name, base_dir):
        """
        Initialize the MongoDB handler.
        :param connection_string: MongoDB connection string.
        :param db_name: MongoDB database name.
        :param base_dir: Directory containing the CSV files to load
            (``models_meta_data.csv``, ``models_errors.csv`` and one
            sub-directory of CSV files per questionnaire type).
        """
        self.client = MongoClient(connection_string)
        self.db = self.client[db_name]
        self.base_dir = Path(base_dir)

    def insert_evaluation_files(self, identifiers=None):
        """
        Load every CSV file found in the sub-directories of ``base_dir``.
        Each file goes into a collection named ``<folder>_<file stem>``;
        only rows not already stored are inserted.
        :param identifiers: Field names that together identify a row.
            Defaults to ``["model_version_id", "ordinal"]``.
        """
        if identifiers is None:
            identifiers = ["model_version_id", "ordinal"]

        for folder_path in sorted(self.base_dir.iterdir()):
            if not folder_path.is_dir():
                continue
            questionnaires_type = folder_path.name
            for questionnaire_path in sorted(folder_path.iterdir()):
                # Skip anything that is not a CSV file (sub-directories,
                # editor/OS artefacts); read_csv would fail or load junk.
                if not questionnaire_path.is_file() or questionnaire_path.suffix != ".csv":
                    continue
                collection_name = f"{questionnaires_type}_{questionnaire_path.stem}"
                self.create_empty_collection(collection_name)
                self.insert_new_rows(
                    questionnaire_path,
                    collection_name=collection_name,
                    identifiers=identifiers,
                )

    def insert_meta_data(self, identifiers=None):
        """
        Load ``base_dir/models_meta_data.csv`` into the ``models_meta_data``
        collection, inserting only rows that are not stored yet.
        :param identifiers: Field names that together identify a row.
            Defaults to ``["model_version_id"]``.
        """
        if identifiers is None:
            identifiers = ["model_version_id"]
        meta_data_file = self.base_dir / "models_meta_data.csv"
        self.create_empty_collection("models_meta_data")
        self.insert_new_rows(meta_data_file, collection_name="models_meta_data", identifiers=identifiers)

    def insert_errors(self, identifiers=None):
        """
        Load ``base_dir/models_errors.csv`` into the ``models_errors``
        collection, inserting only rows that are not stored yet.
        :param identifiers: Field names that together identify a row.
            Defaults to ``["model_version_id"]``.
        """
        if identifiers is None:
            identifiers = ["model_version_id"]
        errors_file = self.base_dir / "models_errors.csv"
        self.create_empty_collection("models_errors")
        self.insert_new_rows(errors_file, collection_name="models_errors", identifiers=identifiers)

    def insert_csv_to_mongo(self, csv_path, collection_name):
        """
        Insert data from a CSV file into a specified MongoDB collection.
        Every row is inserted; no check against existing documents is made.
        :param csv_path: Path to the CSV file.
        :param collection_name: The name of the MongoDB collection.
        """
        collection = self.db[collection_name]
        # utf-8-sig transparently drops a leading byte-order mark, if any.
        df = pd.read_csv(csv_path, encoding="utf-8-sig")
        documents = df.to_dict(orient="records")
        if documents:
            collection.insert_many(documents)
            print(f"{len(documents)} documents inserted into {collection_name}.")
        else:
            print("No data found in the CSV.")

    def fetch_documents(self, collection_name, query=None):
        """
        Fetch documents from a specified collection with an optional query.
        :param collection_name: The name of the MongoDB collection.
        :param query: Optional query to filter the documents; ``None``
            (the default) matches every document.
        :return: A list of documents.
        """
        collection = self.db[collection_name]
        return list(collection.find({} if query is None else query))

    def insert_new_rows(self, csv_path, collection_name, identifiers):
        """
        Insert new rows from a CSV file into a specified MongoDB collection.
        Only new rows (those not already in the database) will be inserted.
        :param csv_path: Path to the CSV file.
        :param collection_name: The name of the MongoDB collection.
        :param identifiers: Field names whose combined values identify a row.
            A CSV row is skipped when a stored document has the same values
            for all of them.
        """
        collection = self.db[collection_name]
        df = pd.read_csv(csv_path, encoding="utf-8-sig")
        documents = df.to_dict(orient="records")

        # Only the identifier fields are needed to build the lookup keys,
        # so ask the server for just those instead of whole documents.
        projection = {identifier: 1 for identifier in identifiers}
        existing_keys = {
            tuple(doc[identifier] for identifier in identifiers)
            for doc in collection.find({}, projection)
        }

        new_documents = [
            document
            for document in documents
            if tuple(document[identifier] for identifier in identifiers) not in existing_keys
        ]

        if new_documents:
            collection.insert_many(new_documents)
            print(f"Inserted {len(new_documents)} new documents in {collection_name}.")
        else:
            print(f"No new documents to insert to {collection_name}.")

    def create_empty_collection(self, collection_name):
        """
        Explicitly create an empty collection.
        Does nothing (apart from printing) when the collection already exists.
        :param collection_name: The name of the collection to create.
        """
        if collection_name not in self.db.list_collection_names():
            self.db.create_collection(collection_name)
            print(f"Collection '{collection_name}' created.")
        else:
            print(f"Collection '{collection_name}' already exists.")

    @staticmethod
    def _confirmed():
        """Prompt on stdin; only the exact answer ``yes`` counts as confirmation."""
        return input("Are you sure? (yes/no)") == "yes"

    def delete_collection(self, collection_name):
        """
        Remove every document from a collection after interactive confirmation.
        The collection itself is not dropped.
        :param collection_name: The name of the MongoDB collection.
        """
        if self._confirmed():
            result = self.db[collection_name].delete_many({})
            print(f"Deleted {result.deleted_count} documents.")

    def clear_all_collections(self):
        """
        Clears all collections in the MongoDB database by removing all documents.
        This method does not drop the collections themselves, only clears them of any documents.
        Asks for interactive confirmation first.
        """
        if self._confirmed():
            for collection_name in self.db.list_collection_names():
                result = self.db[collection_name].delete_many({})
                print(f"Cleared {result.deleted_count} documents from collection: {collection_name}")

    def count_documents(self, collection_name):
        """
        Count how many documents are in a given MongoDB collection.
        :param collection_name: The name of the MongoDB collection.
        :return: The number of documents in the collection.
        """
        collection = self.db[collection_name]
        document_count = collection.count_documents({})
        print(f"{document_count} documents in {collection_name}")
        return document_count

    def check_duplicates(self, collection_name, field_name):
        """
        Check if there are any duplicate documents in a collection based on a specific field.
        :param collection_name: The name of the MongoDB collection.
        :param field_name: The field name to check for duplicates (e.g., 'commit_hash').
        :return: A list with one entry per duplicated value; each entry is the
            list of documents sharing that value.
        """
        collection = self.db[collection_name]
        duplicates = collection.aggregate([
            # Bucket documents by the field value, keeping the full documents.
            {"$group": {
                "_id": f"${field_name}",
                "count": {"$sum": 1},
                "docs": {"$push": "$$ROOT"},
            }},
            # Keep only values that occur more than once.
            {"$match": {
                "count": {"$gt": 1},
            }},
        ])
        duplicate_docs = [group["docs"] for group in duplicates]
        print(f"{len(duplicate_docs)} duplicate documents in {collection_name}")
        return duplicate_docs

    @staticmethod
    def _csv_relative_path(collection_name):
        """Return the CSV path, relative to the export directory, for a collection."""
        if "QMNLI" in collection_name or "QMLM" in collection_name:
            # Collections are named "<folder>_<file stem>" on import, so only
            # the first underscore separates folder from file name; the stem
            # may itself contain underscores.
            return Path(collection_name.replace("_", "/", 1) + ".csv")
        return Path(collection_name + ".csv")

    def export_collection_to_csv(self, collection_name, output_path="./fetched_model_logs/model_logs/"):
        """
        Export documents from a MongoDB collection to a CSV file.
        If "all" is passed as the collection_name, export all collections.
        Collections whose name contains "QMNLI" or "QMLM" are written to
        ``<folder>/<rest of name>.csv``; all others to ``<name>.csv``.
        Empty collections are skipped.
        :param collection_name: The name of the MongoDB collection or "all".
        :param output_path: Directory the CSV files are written to. It is
            created if missing, together with its ``QMNLI`` and ``QMLM``
            sub-directories.
        """
        output_dir = Path(output_path)
        for sub_dir in ("QMNLI", "QMLM"):
            (output_dir / sub_dir).mkdir(parents=True, exist_ok=True)

        if collection_name == "all":
            names_to_export = self.db.list_collection_names()
        else:
            names_to_export = [collection_name]

        for name in names_to_export:
            # Drop Mongo's internal _id so it does not become a CSV column.
            document_list = list(self.db[name].find({}, {"_id": 0}))
            if not document_list:
                print(f"No documents found in the collection '{name}'.")
                continue
            csv_file = output_dir / self._csv_relative_path(name)
            csv_file.parent.mkdir(parents=True, exist_ok=True)
            pd.DataFrame(document_list).to_csv(csv_file, index=False, encoding="utf-8-sig")
            print(f"Exported {len(document_list)} documents to {csv_file}")

    def export_qmlm_qmnli_collections_to_csv(self):
        """
        Concatenate all documents from collections that have 'QMLM' or 'QMNLI' in their name.
        Despite the method name, nothing is written to disk.
        :return: A DataFrame with all matching documents (without ``_id``),
            or ``None`` when no documents were found.
        """
        all_documents = []
        for collection_name in self.db.list_collection_names():
            if "QMLM" in collection_name or "QMNLI" in collection_name:
                all_documents.extend(self.db[collection_name].find({}, {"_id": 0}))

        if all_documents:
            return pd.DataFrame(all_documents)
        print("No matching documents found in the QMLM or QMNLI collections.")
        return None

    def close_connection(self):
        """Close the connection to the MongoDB server."""
        self.client.close()
|