File size: 2,705 Bytes
5f9784f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from pinecone import Pinecone, ServerlessSpec, PodSpec
from langchain_pinecone import PineconeVectorStore
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain.indexes import SQLRecordManager, index

from src.pdf_handler import extract_pdf, load_pdf_directory, split_pdf

import os
import shutil
from dotenv import load_dotenv

load_dotenv()


def setup_pinecone(index_name, embedding_model, embedding_dim, metric='cosine', use_serverless=True,
                   *, cloud='aws', region='us-east-1', pod_environment=None):
    """Create a fresh Pinecone index and wrap it in a LangChain vector store.

    WARNING: if an index called ``index_name`` already exists it is deleted
    first, so any vectors stored in it are lost.

    Args:
        index_name: Name of the Pinecone index to (re)create.
        embedding_model: LangChain embedding object used by the vector store.
        embedding_dim: Dimension of the vectors produced by ``embedding_model``.
        metric: Similarity metric for the index (default ``'cosine'``).
        use_serverless: Create a serverless index when True, a pod-based one
            otherwise.
        cloud: Cloud provider for the serverless spec (default ``'aws'``).
        region: Region for the serverless spec (default ``'us-east-1'``).
        pod_environment: Pod environment for pod-based indexes; falls back to
            the ``PINECONE_ENVIRONMENT`` environment variable when not given.

    Returns:
        A ``PineconeVectorStore`` bound to the newly created index.

    Raises:
        ValueError: ``use_serverless`` is False and no pod environment was
            supplied (``PodSpec`` cannot be built without one).
    """
    # Build the spec first so a misconfiguration fails before we touch the
    # Pinecone account (and before the existing index is deleted).
    if use_serverless:
        spec = ServerlessSpec(cloud=cloud, region=region)
    else:
        environment = pod_environment or os.environ.get('PINECONE_ENVIRONMENT')
        if not environment:
            raise ValueError(
                'Pod-based Pinecone indexes need an environment: pass '
                'pod_environment=... or set PINECONE_ENVIRONMENT.'
            )
        spec = PodSpec(environment=environment)

    pc = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))

    # Recreate the index from scratch on every call.
    if index_name in pc.list_indexes().names():
        pc.delete_index(index_name)

    pc.create_index(
        index_name,
        dimension=embedding_dim,
        metric=metric,
        spec=spec,
    )

    return PineconeVectorStore(index_name=index_name, embedding=embedding_model)


def setup_chroma(index_name, embedding_model, persist_directory=None):
    """Build a Chroma vector store persisted on local disk.

    Args:
        index_name: Collection name handed to ``Chroma``.
        embedding_model: Embedding function used by the store.
        persist_directory: Where Chroma keeps its files. An empty or ``None``
            value falls back to ``./.cache/database``. The directory is
            created if it does not exist yet.

    Returns:
        The ``Chroma`` vector store.
    """
    target_dir = persist_directory or './.cache/database'
    os.makedirs(target_dir, exist_ok=True)
    return Chroma(index_name, embedding_function=embedding_model, persist_directory=target_dir)


class VectorDB:
    """Vector store plus LangChain record manager for indexing uploaded PDFs.

    The backing store is Pinecone when ``db_name == 'pinecone'`` and Chroma
    for any other value. Embeddings come from an Ollama model. A SQLite-backed
    ``SQLRecordManager`` in ``cache_dir`` tracks what has been indexed.

    NOTE: ``cache_dir`` (default ``./.cache/database``) is deleted in
    ``__del__``. That also removes any Chroma data and the record-manager
    database stored there, so the cache only lives as long as the instance.
    """

    def __init__(self, db_name, index_name, cache_dir=None, *,
                 embedding_model='nomic-embed-text:latest', embedding_dim=768):
        """Build the embedding model, vector store and record manager.

        Args:
            db_name: ``'pinecone'`` selects Pinecone; any other value selects
                Chroma.
            index_name: Pinecone index name or Chroma collection name.
            cache_dir: Directory for Chroma data and the record-manager SQLite
                file. Defaults to ``./.cache/database`` when empty or ``None``.
            embedding_model: Ollama embedding model name.
            embedding_dim: Output size of ``embedding_model``. It is only used
                to create the Pinecone index, so it must match the model (the
                default 768 pairs with the default model).
        """
        embedding = OllamaEmbeddings(model=embedding_model, num_gpu=1)

        if not cache_dir:
            cache_dir = './.cache/database'
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        if db_name == 'pinecone':
            self.vectorstore = setup_pinecone(index_name, embedding, embedding_dim, 'cosine')
        else:
            self.vectorstore = setup_chroma(index_name, embedding, self.cache_dir)

        # Namespace the records per backend/index so they do not collide in
        # the shared SQLite file.
        namespace = f'{db_name}/{index_name}'
        self.record_manager = SQLRecordManager(
            namespace,
            db_url=f'sqlite:///{self.cache_dir}/record_manager_cache.sql',
        )
        self.record_manager.create_schema()

    def index(self, uploaded_file):
        """Extract, chunk and index an uploaded PDF, then clear the extraction dir.

        Args:
            uploaded_file: Upload object accepted by ``extract_pdf`` (its type
                is not visible here).

        Returns:
            The result of LangChain's ``index`` call (counts of added, updated,
            skipped and deleted documents).
        """
        directory = extract_pdf(uploaded_file)
        try:
            docs = load_pdf_directory(directory)
            chunks = split_pdf(docs)

            # Inside a method body the bare name ``index`` resolves to the
            # module-level ``langchain.indexes.index`` function, not this method.
            # cleanup='full' makes the store mirror exactly this batch of chunks.
            return index(
                docs_source=chunks,
                record_manager=self.record_manager,
                vector_store=self.vectorstore,
                cleanup='full',
                source_id_key='source',
            )
        finally:
            # Always remove the extracted files, even if indexing failed.
            self._clear_directory(directory)

    @staticmethod
    def _clear_directory(directory):
        """Delete every entry inside ``directory`` but keep the directory itself."""
        for entry in os.listdir(directory):
            path = os.path.join(directory, entry)
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)

    def as_retriever(self):
        """Return a LangChain retriever over the underlying vector store."""
        return self.vectorstore.as_retriever()

    def __del__(self):
        """Best-effort removal of the cache directory when the instance is finalized.

        Tolerates a partially initialized instance (``__init__`` raised before
        ``cache_dir`` was set). Also tolerates a directory that is already gone,
        e.g. removed by another instance sharing the same ``cache_dir``.
        """
        cache_dir = getattr(self, 'cache_dir', None)
        if cache_dir:
            shutil.rmtree(cache_dir, ignore_errors=True)