# Hugging Face Space upload metadata (user Saint5, commit c3a4b6a, verified):
# "Uploading Multimodal Retrieval Augmented Generation System."
"""Contains helper functions that are used in the RAG pipeline."""
import os
import gc
import json
import torch
import shutil
from typing import List, Dict
import faiss
import numpy as np
def save_cache(data: List[Dict], filepath: str) -> None:
    """Persist chunk/embedding records to *filepath* as pretty-printed UTF-8 JSON.

    Failures are reported on stdout instead of raising, so a broken cache
    write never aborts the surrounding pipeline.
    """
    try:
        payload = json.dumps(data, ensure_ascii=False, indent=2)
        with open(filepath, 'w', encoding='utf-8') as handle:
            handle.write(payload)
    except Exception as e:
        print(f"Failed to save cache to {filepath}: {e}")
def load_cache(filepath: str) -> List[Dict]:
    """Load previously saved cache records.

    Returns an empty list when the file is missing or unreadable, printing
    the failure reason in the unreadable case.
    """
    if not os.path.exists(filepath):
        return []
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as e:
        print(f"Failed to load cache from {filepath}: {e}")
    return []
# Vector Store Helper Functions using IndexFlatIP (for semantic search)
def init_faiss_indexflatip(embedding_dim: int) -> faiss.IndexFlatIP:
    """Create an empty flat inner-product FAISS index of the given dimensionality."""
    return faiss.IndexFlatIP(embedding_dim)
def add_embeddings_to_index(index, embeddings: np.ndarray):
    """Append *embeddings* to *index*, cast to float32 (the dtype FAISS expects).

    Empty arrays are silently ignored so callers need not guard the call.
    """
    if embeddings.size == 0:
        return
    index.add(embeddings.astype(np.float32))
def search_faiss_index(index, query_embedding: np.ndarray, k: int = 5):
    """Return (distances, indices) for the top-*k* nearest entries to the query.

    A 1-D query vector is promoted to a (1, dim) batch, and the array is
    cast to float32 before being handed to the index.
    """
    query = query_embedding
    if query.ndim == 1:
        query = query.reshape(1, -1)
    return index.search(query.astype(np.float32), k)
def save_faiss_index(index, filepath: str):
    """Serialize the FAISS index to *filepath* (binary format used by faiss.write_index)."""
    faiss.write_index(index, filepath)
def load_faiss_index(filepath: str):
    """Deserialize and return a FAISS index previously written with save_faiss_index."""
    return faiss.read_index(filepath)
# Deleting extracted images directory after captioning
def cleanup_images(image_dir: str):
    """Best-effort removal of the extracted-images directory once captioning is done.

    Any failure (missing directory, permission error, ...) is reported as a
    warning rather than raised, since cleanup is not critical to the pipeline.
    """
    try:
        shutil.rmtree(image_dir)
        print(f"[INFO] Cleaned up extracted images directory: {image_dir}")
    except Exception as e:
        print(f"[WARNING] Failed to delete some images in {image_dir}: {e}")
# Just being agnostic because my space may only be using CPU but why not?
def clear_gpu_cache():
    """Release cached CUDA memory (when a GPU is present) and force a GC pass.

    Safe on CPU-only hosts: the CUDA step is skipped and only garbage
    collection runs.
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()