Commit c586a3a · 1 Parent(s): 0484f8b
rag.py: optionally use HuggingFaceHub zephyr model
Changed files: src/rag.py (+21, -5)
src/rag.py  CHANGED

@@ -6,9 +6,9 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 # from langchain.embeddings import OpenAIEmbeddings
 from langchain_community.embeddings import HuggingFaceEmbeddings
-
 from langchain.vectorstores import Chroma
 from langchain.chat_models import ChatOpenAI
+from langchain.llms import HuggingFaceHub
 from langchain.chains import RetrievalQA
 from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
 from tqdm import tqdm

@@ -24,7 +24,13 @@ class RAG():
         self.pdfs = pdfs  # Source PDFs to encode in vectorestore
         self.k = 3  # Number of relevant chunks to retrieve
 
-        #
+        # Constants
+        # self.use_model = 'gpt-4o-mini'
+        self.use_model = 'zephyr-7b-alpha'
+
+        # Load environment variables that should contain:
+        # - 'OPENAI_API_KEY' for OpenAI models
+        # - 'HUGGINGFACEHUB_API_TOKEN' for HuggingFace models
         dotenv.load_dotenv(dotenv.find_dotenv())
 
         # Placeholders:

@@ -59,20 +65,30 @@ class RAG():
 
     def create_embeddings(self):
         # embeddings = OpenAIEmbeddings()
+        print ('Using Embeddings from HuggingFace: all-MiniLM-L6-v2')
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
         return embeddings
 
     def create_retriever(self, texts, embeddings):
         # Create embeddings and vector store
+        print ('Creating vectore store with Chroma')
         vectorstore = Chroma.from_documents(texts, embeddings)
         retriever = vectorstore.as_retriever(search_kwargs={"k": self.k})
         return retriever
 
     def create_llm(self):
         # Create the language model
-
-
-
+        if self.use_model == 'gpt-4o-mini':
+            print(f'As llm, using OpenAI model: {self.use_model}')
+            llm = ChatOpenAI(
+                model_name="gpt-4o-mini",
+                temperature=0)
+        elif self.use_model == 'zephyr-7b-alpha':
+            print(f'As llm, using HF model: {self.use_model}')
+            llm = HuggingFaceHub(
+                repo_id="huggingfaceh4/zephyr-7b-alpha",
+                model_kwargs={"temperature": 0.5, "max_length": 64,"max_new_tokens":512}
+            )
         return llm
 
     def create_QAbot(self, retriever, llm):
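Note: the new comments in __init__ spell out which API keys dotenv is expected to load. A minimal sketch of checking for them up front (the os.getenv checks and messages are illustrative, not part of the commit):

import os
import dotenv

# Load the same .env file the class reads in __init__
dotenv.load_dotenv(dotenv.find_dotenv())

# The zephyr branch of create_llm() calls the HuggingFace Hub, the gpt-4o-mini
# branch calls OpenAI, so only the key for the selected backend is required.
if os.getenv("HUGGINGFACEHUB_API_TOKEN") is None:
    print("HUGGINGFACEHUB_API_TOKEN not set; the 'zephyr-7b-alpha' option will not work")
if os.getenv("OPENAI_API_KEY") is None:
    print("OPENAI_API_KEY not set; the 'gpt-4o-mini' option will not work")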
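Note: create_embeddings and create_retriever together implement a standard embed-and-index step. A condensed standalone sketch of the same pipeline, assuming texts is the list of document chunks produced earlier by the RecursiveCharacterTextSplitter:

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

def build_retriever(texts, k=3):
    # Same settings as the class: MiniLM sentence embeddings indexed
    # in an in-memory Chroma collection, returning the top-k chunks.
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(texts, embeddings)
    return vectorstore.as_retriever(search_kwargs={"k": k})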
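Note: create_QAbot itself is outside this diff. A hedged sketch of how the new create_llm output is typically combined with the retriever via RetrievalQA; the from_chain_type wiring is an assumption about the rest of the class, not something shown in this commit:

from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA

# Mirrors the new zephyr branch of create_llm(); needs HUGGINGFACEHUB_API_TOKEN.
llm = HuggingFaceHub(
    repo_id="huggingfaceh4/zephyr-7b-alpha",
    model_kwargs={"temperature": 0.5, "max_length": 64, "max_new_tokens": 512},
)

# retriever as produced by the build_retriever sketch above.
retriever = build_retriever(texts)

# Assumed wiring: a stuff-style RetrievalQA chain over the Chroma retriever.
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
print(qa.run("What does the document say about the topic of interest?"))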