|
|
|
|
|
|
| from langchain.embeddings import HuggingFaceEmbeddings
|
| from langchain.vectorstores import Chroma
|
| from langchain.text_splitter import CharacterTextSplitter
|
| from langchain.llms import HuggingFaceHub
|
| from langchain.chains import RetrievalQA
|
| import os
|
|
|
# Placeholder credential for the Hugging Face Inference API. Replace it, or
# preferably export a real token in the environment before running this file.
# setdefault() keeps an already-exported token from being clobbered by the
# placeholder.
os.environ.setdefault("HF_API_TOKEN", "your_huggingface_api_token")
# LangChain's HuggingFaceHub wrapper looks up HUGGINGFACEHUB_API_TOKEN rather
# than HF_API_TOKEN, so mirror the value under the name it actually reads.
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", os.environ["HF_API_TOKEN"])
|
|
|
class Chatbot:
    """Question-answering bot backed by a Chroma index and a Hugging Face Hub LLM.

    Documents are split into chunks, embedded, stored in a Chroma collection
    persisted under ``db_path``, and queried through a ``RetrievalQA`` chain
    of type ``"stuff"``.
    """

    def __init__(
        self,
        db_path="chroma_db",
        *,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        llm_repo_id="google/flan-t5-base",
        chunk_size=1000,
        chunk_overlap=0,
        top_k=2,
    ):
        """Create the embedding model and remember the pipeline settings.

        Args:
            db_path: Directory handed to Chroma as ``persist_directory``.
            embedding_model: Model name passed to ``HuggingFaceEmbeddings``.
            llm_repo_id: Hub repository id passed to ``HuggingFaceHub``.
            chunk_size: ``chunk_size`` for the ``CharacterTextSplitter``.
            chunk_overlap: ``chunk_overlap`` for the ``CharacterTextSplitter``.
            top_k: Number of chunks the retriever returns per query.
        """
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.db_path = db_path
        self.llm_repo_id = llm_repo_id
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.top_k = top_k
        # Populated by load_documents(); None means "nothing indexed yet".
        self.vectorstore = None
        self.retriever = None
        self.qa_chain = None

    def load_documents(self, documents):
        """Index ``documents`` and (re)build the retrieval QA chain.

        Args:
            documents: Iterable of documents accepted by
                ``CharacterTextSplitter.split_documents``.

        An empty input is a no-op: the current store, retriever and chain are
        left untouched, so ``get_response`` keeps reporting that nothing has
        been loaded instead of querying an empty index.
        """
        documents = list(documents)
        if not documents:
            return

        splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        chunks = splitter.split_documents(documents)

        self.vectorstore = Chroma.from_documents(
            chunks,
            self.embeddings,
            persist_directory=self.db_path,
        )
        self.vectorstore.persist()

        self.retriever = self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": self.top_k},
        )
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=HuggingFaceHub(repo_id=self.llm_repo_id),
            chain_type="stuff",
            retriever=self.retriever,
        )

    def get_response(self, query):
        """Answer ``query`` with the QA chain.

        Returns:
            The chain's answer, or a fixed notice string when no documents
            have been loaded yet.
        """
        if self.qa_chain is None:
            return "No documents loaded. Please load documents first."
        return self.qa_chain.run(query)
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: build a bot and ask it a single question.
    # NOTE(review): no documents are loaded before the query, so this prints
    # the "No documents loaded" notice; call load_documents() first to get a
    # real answer.
    bot = Chatbot()
    print(bot.get_response("What is One.Chat?"))
|
|
|
|
|