"""Contains helper functions that are used in the RAG pipeline.""" import os import gc import json import torch import shutil from typing import List, Dict import faiss import numpy as np def save_cache(data: List[Dict], filepath: str) -> None: """Saving the chunks and the embeddings for easy retrieval in .json format""" try: with open(filepath, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) except Exception as e: print(f"Failed to save cache to {filepath}: {e}") def load_cache(filepath: str) -> List[Dict]: """Loading the saved cache""" if os.path.exists(filepath): try: with open(filepath, 'r', encoding='utf-8') as f: return json.load(f) except Exception as e: print(f"Failed to load cache from {filepath}: {e}") return [] # Vector Store Helper Functions using IndexFlatIP (for semantic search) def init_faiss_indexflatip(embedding_dim:int) -> faiss.IndexFlatIP: index = faiss.IndexFlatIP(embedding_dim) return index def add_embeddings_to_index(index, embeddings: np.ndarray): if embeddings.size > 0: # Embedding array is not empty index.add(embeddings.astype(np.float32)) def search_faiss_index(index, query_embedding: np.ndarray, k: int = 5): # Ensure query_embedding is 2D if query_embedding.ndim == 1: query_embedding = query_embedding.reshape(1, -1) distances, indices = index.search(query_embedding.astype(np.float32), k) return distances, indices def save_faiss_index(index, filepath: str): faiss.write_index(index, filepath) def load_faiss_index(filepath: str): return faiss.read_index(filepath) # Deleting extracted images directory after captioning def cleanup_images(image_dir: str): try: shutil.rmtree(image_dir) print(f"[INFO] Cleaned up extracted images directory: {image_dir}") except Exception as e: print(f"[WARNING] Failed to delete some images in {image_dir}: {e}") # Just being agnostic because my space may only be using CPU but why not? 
def clear_gpu_cache() -> None:
    """Run garbage collection, then clear the GPU cache (saving on memory).

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are freed before the CUDA cache is emptied; otherwise
    the memory they occupy could not be released by ``empty_cache``.

    Safe to call on CPU-only machines: the CUDA step is skipped when no GPU
    is available, and garbage collection still runs.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()