namfam committed
Commit 9915cfc · verified · 1 Parent(s): 7001200

Update chains.py

Files changed (1)
  1. chains.py +125 -125
chains.py CHANGED
@@ -1,126 +1,126 @@
  from langchain.document_loaders import SitemapLoader, RecursiveUrlLoader, WebBaseLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.vectorstores import FAISS, Chroma
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.runnables import RunnablePassthrough, RunnableParallel

  from langchain.retrievers import BM25Retriever, EnsembleRetriever

  from prompts import qa_prompt, condense_question_prompt
  from db import load_session_history, save_message

  from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings

  def get_llm():
      # llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=1000)
      llm = GoogleGenerativeAI(model="models/gemini-1.5-flash", temperature=0, max_tokens=1000)
      return llm

  def get_embeddings():
      # embeddings = OpenAIEmbeddings()
      embeddings = GoogleGenerativeAIEmbeddings(
          model="models/text-embedding-004",
          # output_dimensionality=768
      )
      return embeddings

  def load_documents(urls):

      loader = WebBaseLoader(urls)

      # docs = sitemap_loader.load()
      docs = loader.load()

      return docs

  def get_keyword_retriever(docs):

      keyword_retriever = BM25Retriever.from_documents(docs)
      return keyword_retriever

  def create_vector_db(collection_name, docs):
      # Split
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                     chunk_overlap=200)

      # Split the documents into smaller text chunks
      texts = text_splitter.split_documents(docs)
      persist_directory = "persist"

      # Create a new Chroma collection from the text chunks
      try:
          vector_db = Chroma.from_documents(
              documents=texts,
              embedding=get_embeddings(),
              persist_directory=persist_directory,
              collection_name=collection_name,
          )
      except Exception as e:
          print(f"Error creating collection: {e}")
          return None

      return vector_db

  def load_vector_db(collection_name):
-     persist_directory = "../persist"
+     persist_directory = "persist"
      # Load the Chroma collection from the specified directory
      vector_db = Chroma(
          persist_directory=persist_directory,
          embedding_function=get_embeddings(),
          collection_name=collection_name,
      )

      return vector_db

  def get_vectordb_retriever(vector_db):
      # print("vector_db:", vector_db)
      vector_db_retriever = vector_db.as_retriever()

      return vector_db_retriever

  def get_rag_chain():
      llm = get_llm()
      urls = [
          'https://ati.vn/',
      ]
      docs = load_documents(urls)

      vector_db = create_vector_db(collection_name="ask_ati", docs=docs)
      keyword_retriever = get_keyword_retriever(docs)
      vectordb_retriever = get_vectordb_retriever(vector_db)

      ensemble_retriever = EnsembleRetriever(retrievers=[keyword_retriever, vectordb_retriever],
                                             weights=[0.5, 0.5])

      condense_question_chain = condense_question_prompt | llm | StrOutputParser()
      context_chain = condense_question_chain | ensemble_retriever
      rag_chain = qa_prompt | llm | StrOutputParser()

      parallel_chain = RunnableParallel({
          "context": lambda x: x["context"],
          "question": lambda x: x["question"],
          "chat_history": lambda x: x["chat_history"]
      })

      rag_with_sources_chain = RunnablePassthrough.assign(
          context=context_chain,
          question=condense_question_chain
      ) | parallel_chain.assign(answer=rag_chain)

      return rag_with_sources_chain


  def get_response(session_id, question):

      chat_history = load_session_history(session_id).messages
      chat_history = chat_history[-6:]  # using the last 3 turns of chat
      # print(chat_history)

      chain = get_rag_chain()
      input = {"question": question, "chat_history": chat_history}
      response = chain.invoke(input)

      return response
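
With this change, `create_vector_db` and `load_vector_db` both resolve the same `persist` directory, so a Chroma collection written by one can be reopened by the other from the same working directory. For reference, a minimal driver sketch — not part of this commit — showing how the updated chain might be invoked end to end. It assumes `GOOGLE_API_KEY` is set in the environment, that `prompts.py` and `db.py` expose the names imported above, and that the session id and question are placeholders.

# Hypothetical usage sketch (not part of this commit).
import os

from chains import get_response

if __name__ == "__main__":
    # langchain_google_genai reads the API key from the environment.
    assert os.environ.get("GOOGLE_API_KEY"), "export GOOGLE_API_KEY first"

    # Placeholder session id; db.load_session_history keys the chat history on it.
    result = get_response(session_id="demo-session", question="What does ATI do?")

    # get_rag_chain() ends in RunnableParallel(...).assign(answer=...), so the
    # result is a dict with "context", "question", "chat_history", and "answer".
    print(result["answer"])
    for doc in result["context"]:
        # "context" holds the Documents returned by the ensemble retriever;
        # WebBaseLoader records the originating URL under metadata["source"].
        print(doc.metadata.get("source"))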