File size: 4,113 Bytes
d8f0836 36654b6 7d9b1aa d8f0836 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import HumanMessage
from langchain.document_loaders import UnstructuredFileLoader
#from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import Chroma
from langchain_groq import ChatGroq
import gradio as gr
# --- Vector store and LLM configuration ---
# Directory used as Chroma's persist_directory, and the single collection
# that every uploaded document is written to.
DB_DIR = "chroma_db"
COLLECTION_NAME = "document_collection"

# Embedding model, constructed with the library's default settings.
embedding_function = HuggingFaceEmbeddings()

# The Groq API key is read from the environment (None when the variable is
# unset). Both spellings of the name are kept at module level.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
groq_api_key = GROQ_API_KEY
llm = ChatGroq(api_key=GROQ_API_KEY, model_name="llama-3.1-8b-instant")

# ID of the document that queries are restricted to. It is assigned by
# upload_and_process and stays None until then.
current_document_id = None
def load_and_split_document(file_path):
    """Read the file at ``file_path`` and return it as a list of chunks.

    The file is loaded with ``UnstructuredFileLoader`` and the resulting
    documents are split by a ``CharacterTextSplitter`` configured with
    ``chunk_size=400`` and ``chunk_overlap=50``.
    """
    raw_documents = UnstructuredFileLoader(file_path).load()
    splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=50)
    return splitter.split_documents(raw_documents)
def upload_and_process(file):
    """Index an uploaded file in ChromaDB and make it the current document.

    Args:
        file: The value delivered by the Gradio file component. A path
            string (or other ``os.PathLike``) is used directly; any other
            object is expected to expose the path as ``.name``. ``None``
            means nothing was uploaded.

    Returns:
        A human-readable status string for the UI. Failures are reported in
        the returned string rather than raised.
    """
    global current_document_id

    # Resolve the on-disk path. Pressing the button with no file selected
    # delivers None; report that clearly instead of surfacing an
    # AttributeError from ``None.name``.
    if isinstance(file, (str, os.PathLike)):
        uploaded_file_path = os.fspath(file)
    else:
        uploaded_file_path = getattr(file, "name", None)
    if not uploaded_file_path:
        return "No file provided. Please upload a document first."

    try:
        # The file name doubles as the document ID.
        document_id = os.path.basename(uploaded_file_path)

        # Load and split first: this is the step most likely to fail, and
        # nothing in the store has been touched yet.
        chunks = load_and_split_document(uploaded_file_path)
        if not chunks:
            return "Error processing document: no text could be extracted."

        # Tag every chunk so searches can be restricted to this document.
        for chunk in chunks:
            chunk.metadata["document_id"] = document_id

        # Get or create the vector store.
        vector_store = Chroma(
            persist_directory=DB_DIR,
            embedding_function=embedding_function,
            collection_name=COLLECTION_NAME,
        )

        # Remove chunks stored by an earlier upload under the same ID;
        # otherwise re-uploading a file duplicates every chunk and the
        # similarity search can return the same text more than once.
        stale_ids = vector_store.get(where={"document_id": document_id})["ids"]
        if stale_ids:
            vector_store.delete(ids=stale_ids)

        vector_store.add_documents(chunks)
    except Exception as e:
        return f"Error processing document: {str(e)}"

    # Publish the ID only after indexing succeeded, so a failed upload does
    # not redirect queries to a document that was never stored.
    current_document_id = document_id
    return f"Document successfully processed: {document_id}"
def retrieve_and_generate_response(query):
    """Answer ``query`` from the current document via retrieval and the LLM.

    The two most similar chunks whose ``document_id`` metadata matches the
    current document are fetched from ChromaDB, joined into one context
    string, and sent to the Groq LLM together with the question.

    Args:
        query: The user's question.

    Returns:
        The model's answer, or a message explaining why no answer was
        produced. Failures are reported in the returned string rather than
        raised.
    """
    # A blank question would still cost an embedding, a vector search and
    # an LLM call, so reject it up front.
    if not query or not query.strip():
        return "Please enter a question."

    # Nothing to search until an upload has set the current document; check
    # before opening the vector store.
    if not current_document_id:
        return "Please upload a document first."

    try:
        vector_store = Chroma(
            persist_directory=DB_DIR,
            embedding_function=embedding_function,
            collection_name=COLLECTION_NAME,
        )

        # Only search within the current document.
        results = vector_store.similarity_search(
            query,
            k=2,
            filter={"document_id": current_document_id},
        )

        context = "\n".join(doc.page_content for doc in results)
        if not context:
            return "No relevant content found in the current document."

        messages = [
            HumanMessage(content=f"Use the following context to answer the question:\n\n{context}\n\nQuestion: {query}")
        ]
        return llm.invoke(messages).content
    except Exception as e:
        return f"Error generating response: {str(e)}"
# Gradio interface. Components are created in the order they should appear
# on the page, so the creation order below must be kept.
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RAG Chatbot with Groq & ChromaDB")

    # Document upload controls
    file_input = gr.File(label="Upload a PDF")
    upload_button = gr.Button("Process Document")
    upload_status = gr.Textbox(label="Upload Status", interactive=False)

    # Question / answer controls
    query_input = gr.Textbox(label="Ask a Question")
    response_output = gr.Textbox(label="Response", interactive=False)
    chat_button = gr.Button("Get Answer")

    # Event wiring: each button calls its handler and shows the returned
    # string in the matching textbox.
    upload_button.click(
        fn=upload_and_process,
        inputs=[file_input],
        outputs=[upload_status],
    )
    chat_button.click(
        fn=retrieve_and_generate_response,
        inputs=[query_input],
        outputs=[response_output],
    )

# Start the Gradio app.
demo.launch()
|