| | import json |
| | import os |
| | import sys |
| | from typing import Any, Dict, List |
| |
|
| | |
# Make the project root importable when this module is run as a script.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_root)
| |
|
| | import chromadb |
| | import numpy as np |
| | import torch |
| | from nanoid import generate |
| | from transformers import AutoModel, AutoTokenizer |
| |
|
| |
|
def append_to_json(new_entries, filename="json_file_record.json"):
    """
    Append new entries to an existing JSON array file, or create a new one if it doesn't exist.

    Args:
        new_entries (list): List of dictionaries to append
        filename (str): Name of the JSON file
    """
    try:
        # Load the existing array; fall back to an empty list when the file
        # is missing, empty, or does not hold a JSON array.
        if os.path.exists(filename) and os.path.getsize(filename) > 0:
            with open(filename, "r") as f:
                data = json.load(f)
            if not isinstance(data, list):
                data = []
        else:
            data = []

        # Bug fix: the loaded array was previously written straight back
        # without the new entries ever being appended.
        data.extend(new_entries)

        with open(filename, "w") as f:
            json.dump(data, f, indent=4)

    except json.JSONDecodeError:
        # Corrupt JSON on disk: start over with just the new entries.
        data = new_entries
        with open(filename, "w") as f:
            json.dump(data, f, indent=4)
| |
|
| |
|
class BERTEmbedder:
    """Sentence embedder using bert-base-uncased with mean pooling."""

    def __init__(self):
        # Load the tokenizer and model once; run on GPU when available.
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.model = AutoModel.from_pretrained("bert-base-uncased")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Return one embedding row per text, stacked as a 2-D array."""
        self.model.eval()
        vectors = []
        with torch.no_grad():
            # Texts are encoded one at a time (truncated to BERT's 512-token
            # limit), so each mean is taken over real tokens only.
            for sentence in texts:
                encoded = self.tokenizer(
                    sentence,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                encoded = {
                    name: tensor.to(self.device) for name, tensor in encoded.items()
                }
                output = self.model(**encoded)
                # Mean over the token dimension -> one vector per text.
                vectors.append(output.last_hidden_state.mean(dim=1).cpu().numpy())
        return np.vstack(vectors)
| |
|
| |
|
class VectorStore:
    """Persistent ChromaDB collection with BERT embeddings.

    A sidecar JSON file maps each upload id to the source file path of the
    chunks added in that upload, so documents can later be deleted by
    (sub)filename via ``delete_documents_by_filename``.
    """

    def __init__(
        self, persist_directory: str = "../chroma_rag", query=False, is_uploaded=False
    ):
        """Connect to ChromaDB in either embedding-store or query mode.

        Args:
            persist_directory: Chroma store directory (store mode only;
                query mode always uses "chroma_rag").
            query: Select query mode.
            is_uploaded: Store mode is used only when ``query`` is falsy
                and ``is_uploaded`` is truthy; otherwise query mode.
        """
        try:
            # The two modes differ only in the store directory and the
            # location of the JSON record file; everything else is shared.
            if not query and is_uploaded:
                print("Embbeding store mode.")
                self.json_file_path = "json_file_record.json"
            else:
                print("query mode")
                persist_directory = "chroma_rag"
                self.json_file_path = "utils/json_file_record.json"

            print(f"Initializing ChromaDB with directory: {persist_directory}")
            self.client = chromadb.PersistentClient(path=persist_directory)
            print("ChromaDB client created successfully")

            # embedding_function=None: embeddings are computed externally
            # with BERT and passed in explicitly.
            self.collection = self.client.get_or_create_collection(
                name="documents",
                metadata={"hnsw:space": "cosine"},
                embedding_function=None,
            )
            print("Collection 'documents' ready")

            self.embedder = BERTEmbedder()
            print("BERT embedder initialized")

            content = self.collection.get()
            print(f"Collection contains {len(content['documents'])} documents")

        except Exception as e:
            print(f"Error initializing VectorStore: {e}")
            raise

    def is_collection_empty(self) -> bool:
        """Return True when the collection has no documents (or on error)."""
        try:
            content = self.collection.get()
            return len(content["documents"]) == 0
        except Exception as e:
            print(f"Error checking collection: {e}")
            return True

    def add_documents(self, chunks: List[Dict[str, Any]]):
        """Embed ``chunks`` and add them to the collection.

        Each chunk is a dict with "content" (str) and "metadata" (dict that
        must contain "source" and "topics"). One record linking the generated
        upload id to the first chunk's source path is appended to the JSON
        record file (all chunks of one upload share a source file).

        Returns:
            True on success, False on any error (logged, not raised).
        """
        try:
            if not chunks:
                print("No chunks to add.")
                return False

            texts = [chunk["content"] for chunk in chunks]
            metadatas = [chunk["metadata"] for chunk in chunks]

            print(f"Generating embeddings for {len(texts)} documents...")
            print(texts)
            embeddings = self.embedder.get_embeddings(texts)

            # Short random id shared by every chunk of this upload.
            id_val = str(generate(size=8))
            print(f"Generated ID: {id_val}")

            # Record the upload-id -> source-file mapping. Bug fix: the old
            # else-branch opened the record file for writing twice (the first
            # truncating handle was never used) and crashed on empty chunks.
            record = {"id": id_val, "file_path": chunks[0]["metadata"]["source"]}
            if os.path.exists(self.json_file_path):
                with open(self.json_file_path, "r") as f:
                    data = json.load(f)
            else:
                data = []
            data.append(record)
            with open(self.json_file_path, "w") as f:
                json.dump(data, f, indent=4)

            print("*************")
            ids = []
            for count, metadata in enumerate(metadatas):
                # Chroma metadata values must be scalars; stringify topics.
                metadata["topics"] = str(metadata["topics"])
                ids.append(f"{id_val}{count}")
            print(metadatas)
            print("------------------------")
            print(len(metadatas))

            print(f"Adding {len(texts)} documents to collection...")
            self.collection.add(
                embeddings=embeddings.tolist(),
                documents=texts,
                metadatas=metadatas,
                ids=ids,
            )

            collection_content = self.collection.get()
            print(
                f"Collection now contains {len(collection_content['documents'])} documents"
            )

            return True
        except Exception as e:
            print(f"Error adding documents: {e}")
            return False

    def query(self, query_text: str, n_results: int = 3) -> Dict:
        """Embed ``query_text`` and return the nearest documents.

        Args:
            query_text: Free-text query.
            n_results: Maximum results (capped at the collection size).

        Returns:
            Chroma query result dict, or {"error": ...} on failure.
        """
        try:
            print(f"Generating embedding for query: {query_text}")
            query_embedding = self.embedder.get_embeddings([query_text])

            print("Checking collection content:")
            collection_content = self.collection.get()
            print(
                f"Number of documents in collection: {len(collection_content['documents'])}"
            )

            print("Executing query...")
            query_vector = query_embedding.tolist()
            results = self.collection.query(
                n_results=min(n_results, len(collection_content["documents"])),
                query_embeddings=query_vector,
            )

            print(f"Query results: {json.dumps(results, indent=2)}")
            return results
        except Exception as e:
            print(f"Error during query: {e}")
            return {"error": str(e)}

    def delete_documents_by_filename(self, file_substring: str):
        """
        Delete documents from the collection and JSON file by matching a substring in the file path.

        Args:
            file_substring (str): Substring to match in the file paths.
        """
        try:
            print(file_substring)
            json_file = self.json_file_path
            if not os.path.exists(json_file):
                print(f"JSON file {json_file} does not exist.")
                return

            with open(json_file, "r") as f:
                data = json.load(f)

            matching_records = [
                record for record in data if file_substring in record["file_path"]
            ]
            if not matching_records:
                print(f"No records found matching substring: {file_substring}")
                return

            matching_ids = [record["id"] for record in matching_records]
            print("matching_ids", matching_ids)

            updated_data = [
                record for record in data if record["id"] not in matching_ids
            ]

            print("updated data", updated_data)

            with open(json_file, "w") as f:
                json.dump(updated_data, f, indent=4)

            print(f"Deleted {len(matching_records)} records from JSON file.")

            # Chunk ids are f"{upload_id}{index}". Bug fix: previously only
            # matching_ids[0] was matched, so additional matching uploads
            # were never removed from the collection; also the success
            # message hard-coded a stale debug id ('LDtz9CG5').
            all_ids = self.collection.get()["ids"]
            ids_to_delete = [
                id_ for id_ in all_ids if any(mid in id_ for mid in matching_ids)
            ]

            if ids_to_delete:
                self.collection.delete(ids=ids_to_delete)
                print(
                    f"Deleted {len(ids_to_delete)} records with IDs matching {matching_ids}."
                )
            else:
                print("No matching IDs found.")
        except Exception as e:
            print(f"Error deleting documents: {e}")
|