# Sarathrsk03's picture
# Upload 8 files
# d21bdd3 verified
import chromadb
from chromadb.utils import embedding_functions
import csv
# Embedding function instantiated once at import time and shared by the
# module: createEmbeddings() uses it to embed the stored documents and
# retrieveInfo() uses it to embed the query text.
default_ef = embedding_functions.DefaultEmbeddingFunction()
def createEmbeddings(batch_size=100):
    """Embed the rule notes and store them in the persistent Chroma collection.

    Reads one document per line from ``notes.txt`` and one metadata row per
    document from ``metaData.csv``; the third CSV column is stored as the
    ``ruleDescription`` metadata field. Documents are embedded with the
    module-level ``default_ef`` and added to the ``ICC_Rules`` collection
    under ``./chromaDB``. Document ids are the zero-based line numbers,
    as strings.

    Args:
        batch_size: Number of documents embedded and added per call to the
            collection. Values below 1 are treated as 1.

    Returns:
        None. Warnings, the success message and any error are reported with
        ``print``; exceptions are caught here and not propagated.
    """
    try:
        # Load the inputs first so a missing or unreadable file is reported
        # before the persistent store is opened.
        with open("notes.txt", "r", encoding="utf-8") as f:
            data = [line.strip() for line in f]
        # newline="" lets the csv module handle line endings and quoted
        # newlines itself, as recommended by the csv documentation.
        with open("metaData.csv", "r", encoding="utf-8", newline="") as f:
            metaData = list(csv.reader(f))

        if len(data) != len(metaData):
            print(f"Warning: Mismatch in data lengths. notes.txt has {len(data)} lines but metaData.csv has {len(metaData)} rows.")

        # One metadata dict per document. Documents without a matching row,
        # and rows with fewer than three columns (a blank CSV line parses as
        # []), get a placeholder instead of raising IndexError mid-ingestion.
        fallback = "No metadata available"
        rows = metaData[:len(data)] + [[]] * max(0, len(data) - len(metaData))
        metadatas = [
            {"ruleDescription": row[2] if len(row) > 2 else fallback}
            for row in rows
        ]

        client = chromadb.PersistentClient("./chromaDB")
        collection = client.get_or_create_collection("ICC_Rules")

        # Embed and insert in batches rather than one document per call.
        step = max(1, int(batch_size))
        for start in range(0, len(data), step):
            batch_docs = data[start:start + step]
            collection.add(
                ids=[str(i) for i in range(start, start + len(batch_docs))],
                embeddings=default_ef(batch_docs),
                metadatas=metadatas[start:start + step],
                documents=batch_docs,  # also store the original text
            )
        print(f"Successfully added {len(data)} documents to the collection.")
    except Exception as e:
        # Deliberate catch-all: this function reports failures, never raises.
        print(f"Error creating embeddings: {e}")
def retrieveInfo(query, n_results=3):
    """Return the stored rule passages that best match *query*.

    The query is embedded with the module-level ``default_ef`` and looked up
    in the ``ICC_Rules`` collection under ``./chromaDB``.

    Args:
        query: Natural-language question to search for.
        n_results: Maximum number of matches requested from the collection.

    Returns:
        On success, a list of dicts with the keys ``"document"``,
        ``"metadata"`` and ``"distance"``, one per match, in the order the
        collection query returned them. The string
        ``"No relevant information found."`` when the query is blank or
        nothing matched. The string ``"An error occurred: <message>"`` when
        the lookup raised; exceptions are never propagated.
    """
    # A blank query has nothing meaningful to embed; answer without opening
    # the database instead of returning arbitrary nearest neighbours.
    if isinstance(query, str) and not query.strip():
        return "No relevant information found."

    try:
        client = chromadb.PersistentClient("./chromaDB")
        collection = client.get_or_create_collection("ICC_Rules")
        results = collection.query(
            query_embeddings=default_ef([query]),  # embedding function takes a list of texts
            n_results=n_results,
        )

        documents = (results.get("documents") or [[]])[0] if results else []
        if not documents:
            return "No relevant information found."

        # A result field may be missing or None (e.g. when it was not
        # included in the query), so a plain `key in results` check is not
        # enough; fall back per field instead of raising.
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        # Flatten the per-query result lists into one record per match.
        formatted_results = [
            {
                "document": document,
                "metadata": metadatas[i] if i < len(metadatas) else {},
                "distance": distances[i] if i < len(distances) else None,
            }
            for i, document in enumerate(documents)
        ]
        print("Retrieved information successfully.")
        return formatted_results
    except Exception as e:
        # Deliberate catch-all: callers get an error string, not an exception.
        print(f"Error retrieving information: {e}")
        return f"An error occurred: {e}"
if __name__ == "__main__":
    # Script entry point: run a single sample query and print the result.
    # The ingestion step is left disabled; uncomment it to (re)build the
    # collection before querying.
    # createEmbeddings()
    sample_query = "What happens when there are less than 11 players fit?"
    print(retrieveInfo(sample_query))