Nelly-43 committed on
Commit
775c3a6
·
verified ·
1 Parent(s): dba6dd7

Update create_retriever.py

Browse files
Files changed (1) hide show
  1. create_retriever.py +105 -97
create_retriever.py CHANGED
@@ -1,98 +1,106 @@
1
- import os
2
- import glob
3
- from langchain_community.document_loaders import Docx2txtLoader, TextLoader, PyPDFLoader
4
- from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter
5
- from langchain.embeddings import HuggingFaceEmbeddings
6
- from langchain.vectorstores import Chroma
7
- from langchain.retrievers import EnsembleRetriever
8
- # from ragatouille import RAGPretrainedModel
9
-
10
- # Function to load and process documents
11
- def docs_return(flag):
12
- directory_path = 'rag_data/'
13
- docx_file_pattern = '*.docx'
14
- pdf_file_pattern = '*.pdf'
15
- txt_file_pattern = '*.txt'
16
-
17
- docx_file_paths = glob.glob(directory_path + docx_file_pattern)
18
- pdf_file_paths = glob.glob(directory_path + pdf_file_pattern)
19
- txt_file_paths = glob.glob(directory_path + txt_file_pattern)
20
-
21
- all_doc, all_doc2 = [], []
22
-
23
- for x in docx_file_paths:
24
- loader = Docx2txtLoader(x)
25
- documents = loader.load()
26
- all_doc.extend(documents)
27
- all_doc2.append(str(documents[0].page_content))
28
-
29
- for x in pdf_file_paths:
30
- loader = PyPDFLoader(x, extract_images=True)
31
- docs_lazy = loader.lazy_load()
32
- documents = []
33
- for doc in docs_lazy:
34
- documents.append(doc)
35
- all_doc.extend(documents)
36
- all_doc2.append(str(documents[0].page_content))
37
-
38
- for x in txt_file_paths:
39
- loader = TextLoader(x)
40
- documents = loader.load()
41
- all_doc.extend(documents)
42
- all_doc2.append(str(documents[0].page_content))
43
-
44
- docs = '\n\n'.join(all_doc2)
45
-
46
- return all_doc if flag == 0 else docs
47
-
48
- # Function to get or download the embedding model
49
- def get_embedding_model(model_name):
50
- local_model_path = f"embedding_model/{model_name.replace('/', '_')}"
51
- if os.path.exists(local_model_path):
52
- print(f"Loading local model from {local_model_path}")
53
- return HuggingFaceEmbeddings(model_name=local_model_path)
54
- else:
55
- print(f"Downloading model {model_name}")
56
- return HuggingFaceEmbeddings(model_name=model_name)
57
-
58
- # Function to return different types of text splitters
59
- def get_text_splitter(splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
60
- if splitter_type == 'character':
61
- return CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator)
62
- elif splitter_type == 'recursive':
63
- return RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
64
- elif splitter_type == 'token':
65
- return TokenTextSplitter(chunk_size=max_tokens, chunk_overlap=chunk_overlap)
66
- else:
67
- raise ValueError("Unsupported splitter type. Choose from 'character', 'recursive', or 'token'.")
68
-
69
- # Retriever using Chroma and HuggingFace embeddings
70
- def retriever_chroma(flag, model_name="BAAI/bge-large-en-v1.5", splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
71
- # Load or download the embedding model
72
- embeddings = get_embedding_model(model_name)
73
-
74
- if not flag:
75
- # Load the documents
76
- all_doc = docs_return(0)
77
-
78
- # Use the splitter parameters
79
- text_splitter = get_text_splitter(splitter_type=splitter_type, chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator, max_tokens=max_tokens)
80
-
81
- # Split the documents using the text splitter
82
- docs = text_splitter.split_documents(documents=all_doc)
83
-
84
- # Create a Chroma vector database
85
- vectordb = Chroma.from_documents(docs, embeddings, persist_directory="./chroma_db")
86
-
87
- # Create the retriever
88
- chroma_retriever = vectordb.as_retriever(
89
- search_type="mmr", search_kwargs={"k": 4, "fetch_k": 10}
90
- )
91
- return chroma_retriever
92
- else:
93
- # Load a local Chroma vectorstore
94
- vectordb = Chroma.load_local("vectorstore", embeddings)
95
- chroma_retriever = vectordb.as_retriever(
96
- search_type="mmr", search_kwargs={"k": 4, "fetch_k": 10}
97
- )
 
 
 
 
 
 
 
 
98
  return chroma_retriever
 
1
+ import os
2
+ import glob
3
+ from langchain_community.document_loaders import Docx2txtLoader, TextLoader, PyPDFLoader, CSVLoader
4
+ from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.vectorstores import Chroma
7
+ from langchain.retrievers import EnsembleRetriever
8
+ # from ragatouille import RAGPretrainedModel
9
+
10
# Function to load and process documents
def docs_return(flag, directory_path='rag_data/'):
    """Load the .docx, .pdf, .txt and .csv files found in ``directory_path``.

    Files are processed grouped by type, in the order docx, pdf, txt, csv.

    Args:
        flag: ``0`` returns the list of loaded documents. Any other value
            returns a single string built from the text of the first
            document of each file, joined by blank lines.
        directory_path: Folder searched (non-recursively) for input files.
            Defaults to ``'rag_data/'``; a trailing slash is optional.

    Returns:
        The list of loaded documents when ``flag == 0``, otherwise a ``str``.
        With no matching files these are ``[]`` and ``''`` respectively.
    """
    # (glob pattern, loader) pairs in processing order. The lambdas defer
    # loader construction until a matching file is actually found.
    loaders = (
        ('*.docx', lambda path: Docx2txtLoader(path).load()),
        ('*.pdf', lambda path: list(PyPDFLoader(path, extract_images=True).lazy_load())),
        ('*.txt', lambda path: TextLoader(path).load()),
        ('*.csv', lambda path: CSVLoader(file_path=path, source_column="translation").load()),
    )

    all_doc, all_text = [], []
    for pattern, load in loaders:
        for path in glob.glob(os.path.join(directory_path, pattern)):
            documents = load(path)
            if not documents:
                # A file that yields no documents (e.g. a PDF without pages)
                # is skipped instead of crashing on documents[0] below.
                continue
            all_doc.extend(documents)
            # NOTE(review): only the first document of each file feeds the
            # text output. If a loader returns several documents per file
            # (presumably one per PDF page / CSV row -- confirm), the rest
            # are left out of the joined string. Kept as-is for callers.
            all_text.append(str(documents[0].page_content))

    if flag == 0:
        return all_doc
    return '\n\n'.join(all_text)
55
+
56
# Function to get or download the embedding model
def get_embedding_model(model_name):
    """Build a ``HuggingFaceEmbeddings`` instance for ``model_name``.

    The path ``embedding_model/<model_name with '/' replaced by '_'>`` is
    checked first. When it exists, that path is handed to
    ``HuggingFaceEmbeddings``; otherwise the model name itself is used.

    Args:
        model_name: Model identifier, e.g. ``"BAAI/bge-large-en-v1.5"``.

    Returns:
        The ``HuggingFaceEmbeddings`` object created for the chosen source.
    """
    folder_name = model_name.replace('/', '_')
    cached_path = f"embedding_model/{folder_name}"

    # Guard clause: no local copy, so fall back to the plain model name.
    if not os.path.exists(cached_path):
        print(f"Downloading model {model_name}")
        return HuggingFaceEmbeddings(model_name=model_name)

    print(f"Loading local model from {cached_path}")
    return HuggingFaceEmbeddings(model_name=cached_path)
65
+
66
# Function to return different types of text splitters
def get_text_splitter(splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
    """Create the text splitter selected by ``splitter_type``.

    Args:
        splitter_type: ``'character'``, ``'recursive'`` or ``'token'``.
        chunk_size: Chunk size for the character and recursive splitters.
        chunk_overlap: Overlap between chunks, used by every splitter.
        separator: Separator passed to the character splitter only.
        max_tokens: Chunk size passed to the token splitter only.

    Returns:
        A ``CharacterTextSplitter``, ``RecursiveCharacterTextSplitter`` or
        ``TokenTextSplitter`` instance.

    Raises:
        ValueError: If ``splitter_type`` is not one of the supported names.
    """
    if splitter_type == 'character':
        return CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
        )

    if splitter_type == 'recursive':
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

    if splitter_type == 'token':
        # The token splitter is sized by max_tokens, not chunk_size.
        return TokenTextSplitter(chunk_size=max_tokens, chunk_overlap=chunk_overlap)

    raise ValueError("Unsupported splitter type. Choose from 'character', 'recursive', or 'token'.")
76
+
77
# Retriever using Chroma and HuggingFace embeddings
def retriever_chroma(flag, model_name="BAAI/bge-large-en-v1.5", splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000, persist_directory="./chroma_db"):
    """Return an MMR retriever backed by a Chroma vector store.

    Args:
        flag: Falsy builds the store from the loaded documents and persists
            it to ``persist_directory``. Truthy reopens the store already
            persisted there without re-reading or re-embedding documents.
        model_name: Embedding model passed to ``get_embedding_model``.
        splitter_type: Splitter kind passed to ``get_text_splitter``.
        chunk_size: Chunk size passed to ``get_text_splitter``.
        chunk_overlap: Chunk overlap passed to ``get_text_splitter``.
        separator: Separator passed to ``get_text_splitter``.
        max_tokens: Token chunk size passed to ``get_text_splitter``.
        persist_directory: Directory of the Chroma store, used by both the
            build and the reload path. Defaults to ``"./chroma_db"``.

    Returns:
        The retriever created by ``vectordb.as_retriever`` with
        ``search_type="mmr"``, ``k=4`` and ``fetch_k=10``.
    """
    # Load or download the embedding model
    embeddings = get_embedding_model(model_name)

    if flag:
        # Reopen the persisted store. The previous code called
        # Chroma.load_local("vectorstore", ...), which is not part of the
        # Chroma vector store API and pointed at a different directory than
        # the one the build branch writes to.
        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    else:
        # Load the documents, then split them with the requested splitter
        all_doc = docs_return(0)
        text_splitter = get_text_splitter(
            splitter_type=splitter_type,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            max_tokens=max_tokens,
        )
        docs = text_splitter.split_documents(documents=all_doc)

        # Create the Chroma vector database and persist it
        vectordb = Chroma.from_documents(docs, embeddings, persist_directory=persist_directory)

    # Both branches expose the store through the same retriever settings
    return vectordb.as_retriever(
        search_type="mmr", search_kwargs={"k": 4, "fetch_k": 10}
    )