File size: 3,524 Bytes
25d1423
 
7687529
25d1423
 
 
 
 
 
aa1686a
 
b329469
d5b1c8d
7687529
25d1423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e0aff0
b329469
25d1423
 
aa1686a
 
 
 
 
 
 
 
 
25d1423
 
 
 
 
 
 
 
 
 
 
 
 
3f6019e
25d1423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa1686a
 
b329469
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import torch
import os
import tempfile
from langchain_groq import ChatGroq


# Define the embedding class
class SentenceTransformerEmbedding:
    """Adapter exposing a SentenceTransformer model through the
    ``embed_documents`` / ``embed_query`` interface that LangChain
    vector stores (Chroma) expect."""

    def __init__(self, model_name):
        self.model = SentenceTransformer(model_name)

    def _as_lists(self, encoded):
        # encode(convert_to_tensor=True) normally yields a torch tensor;
        # the vector store wants plain Python lists, so convert when needed.
        if isinstance(encoded, torch.Tensor):
            return encoded.cpu().detach().numpy().tolist()
        return encoded

    def embed_documents(self, texts):
        """Embed a batch of document texts; returns one vector per text."""
        return self._as_lists(self.model.encode(texts, convert_to_tensor=True))

    def embed_query(self, query):
        """Embed a single query string; returns a single vector."""
        return self._as_lists(self.model.encode([query], convert_to_tensor=True))[0]

# Initialize the embedding model used by Chroma for indexing and querying.
embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')

# SECURITY: API keys were previously hard-coded here and committed to
# source control — those keys are compromised and must be revoked/rotated.
# Read credentials from the environment instead; never embed secrets in code.
groq_api_key = os.environ.get("GROQ_API_KEY")
langchain_api_key = os.environ.get("LANGCHAIN_API_KEY")

# Groq-hosted Llama 3 8B chat model used as the RAG answer generator.
llm = ChatGroq(model="llama3-8b-8192", groq_api_key=groq_api_key)

def load_document(document_path):
    """Persist the uploaded PDF to a temp file, load it, and split it
    into overlapping text chunks.

    Parameters
    ----------
    document_path : uploaded-file object exposing ``getvalue()``
        (a Streamlit ``UploadedFile``, despite the name suggesting a path).

    Returns
    -------
    A list of split Document chunks on success, or the error message as
    a plain ``str`` on failure — callers distinguish via ``isinstance``.
    """
    try:
        # PyPDFLoader needs a real on-disk path, so spool the upload
        # into a temporary directory that is cleaned up automatically.
        with tempfile.TemporaryDirectory() as workdir:
            pdf_path = os.path.join(workdir, 'temp.pdf')
            with open(pdf_path, 'wb') as out:
                out.write(document_path.getvalue())
            pages = PyPDFLoader(pdf_path).load()
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=4000, chunk_overlap=200
            )
            return splitter.split_documents(pages)
    except Exception as err:
        return str(err)

def initialize_chroma(splits):
    """Index *splits* in an in-memory Chroma store and assemble the RAG
    chain: retriever -> prompt -> LLM -> string parser.

    Returns the runnable chain on success, or the error message as a
    plain ``str`` on failure — callers distinguish via ``isinstance``.
    """
    try:
        store = Chroma.from_documents(documents=splits, embedding=embedding_model)
        retriever = store.as_retriever()
        rag_prompt = hub.pull("rlm/rag-prompt")

        def join_pages(docs):
            # Collapse the retrieved chunks into a single context string.
            return "\n\n".join(d.page_content for d in docs)

        return (
            {"context": retriever | join_pages, "question": RunnablePassthrough()}
            | rag_prompt
            | llm
            | StrOutputParser()
        )
    except Exception as err:
        return str(err)

def answer_question(rag_chain, query):
    """Run *query* through *rag_chain* and return the answer.

    Returns the chain's output on success, or the exception message as
    a plain ``str`` if invocation fails — consistent with the other
    helpers' error-as-string convention.
    """
    try:
        return rag_chain.invoke(query)
    except Exception as err:
        return str(err)

# --- Streamlit UI: upload a PDF, ask a question, show the RAG answer ---
st.title("PDF Question Answering")
st.write("Upload your PDF document and ask a question!")

document_path = st.file_uploader("Upload your PDF document", type=["pdf"])
query = st.text_input("Enter your question")

# Run the pipeline only once both a file and a non-empty question exist.
# Each stage signals failure by returning a plain str, so an isinstance
# check on `outcome` gates progression to the next stage.
if document_path is not None and query:
    outcome = load_document(document_path)
    if isinstance(outcome, str):
        st.write("Error loading document:", outcome)
    else:
        outcome = initialize_chroma(outcome)
        if isinstance(outcome, str):
            st.write("Error initializing Chroma:", outcome)
        else:
            st.write("Result:", answer_question(outcome, query))



# st.write("Note: Replace `llm` with an appropriate language model.")