QuaidKhalid committed on
Commit
25d1423
·
verified ·
1 Parent(s): b95521c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import tempfile

import streamlit as st
import torch
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
+
11
class SentenceTransformerEmbedding:
    """Adapter exposing a SentenceTransformer model through the embedding
    interface Chroma expects (``embed_documents`` / ``embed_query``).

    Embeddings are returned as plain Python lists of floats, never tensors.
    """

    def __init__(self, model_name):
        # Load the transformer once; it is reused for every embed call.
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Return one embedding (list of floats) per input text."""
        vectors = self.model.encode(texts, convert_to_tensor=True)
        if isinstance(vectors, torch.Tensor):
            # Downstream consumers expect plain lists, not tensors.
            vectors = vectors.cpu().detach().numpy().tolist()
        return vectors

    def embed_query(self, query):
        """Return the embedding for a single query string."""
        vectors = self.model.encode([query], convert_to_tensor=True)
        if isinstance(vectors, torch.Tensor):
            vectors = vectors.cpu().detach().numpy().tolist()
        return vectors[0]
27
# Shared embedding model used for every Chroma index built in this app.
embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')

# API keys for Groq and LangChain.
# SECURITY: the original file hard-coded live secrets in source. Never do
# that — anyone with repo access can use (and revoke-force) the keys.
# Read them from the environment (or st.secrets) instead.
groq_api_key = os.environ.get("GROQ_API_KEY", "")
langchain_api_key = os.environ.get("LANGCHAIN_API_KEY", "")
34
+
35
def load_document(document_path):
    """Load a PDF and split it into overlapping text chunks.

    Accepts either a filesystem path or a file-like object (e.g. the
    Streamlit ``UploadedFile`` this app passes in). PyPDFLoader can only
    read from a path, so file-like input is first spooled to a temporary
    file — the original code passed the UploadedFile straight through,
    which fails for uploads.

    Returns:
        list of Document chunks on success, or the error message string
        on failure (callers distinguish via ``isinstance(result, str)``).
    """
    try:
        if hasattr(document_path, "read"):
            # Persist uploaded bytes so PyPDFLoader has a real path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(document_path.read())
                path = tmp.name
        else:
            path = document_path
        docs = PyPDFLoader(path).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=200)
        return splitter.split_documents(docs)
    except Exception as e:
        # Error-as-string contract: the UI checks isinstance(..., str).
        return str(e)
44
+
45
def initialize_chroma(splits, llm=None):
    """Build a RAG chain over *splits* backed by a Chroma vector store.

    Args:
        splits: document chunks produced by ``load_document``.
        llm: a LangChain-compatible language-model runnable (e.g. a
            ChatGroq instance). The original code referenced an undefined
            global ``llm``, so chain construction always raised NameError;
            requiring it as a parameter surfaces the problem clearly while
            keeping the one-argument call signature working.

    Returns:
        The runnable RAG chain on success, or the error message string on
        failure (callers distinguish via ``isinstance(result, str)``).
    """
    try:
        if llm is None:
            raise ValueError(
                "initialize_chroma() requires an `llm` runnable "
                "(e.g. ChatGroq); none was provided"
            )
        vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
        retriever = vectorstore.as_retriever()
        prompt = hub.pull("rlm/rag-prompt")

        def format_docs(docs):
            # The rag-prompt expects one context string, not Document objects.
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        return rag_chain
    except Exception as e:
        # Error-as-string contract: the UI checks isinstance(..., str).
        return str(e)
61
+
62
def answer_question(rag_chain, query):
    """Run *query* through the RAG chain.

    Returns the chain's answer, or the error message string if invocation
    raises (callers distinguish via ``isinstance(result, str)``-style checks
    elsewhere in this app).
    """
    try:
        return rag_chain.invoke(query)
    except Exception as exc:
        return str(exc)
68
+
69
# --- Streamlit UI ---------------------------------------------------------
st.title("PDF Question Answering")
st.write("Upload your PDF document and ask a question!")

document_path = st.file_uploader("Upload your PDF document", type=["pdf"])
query = st.text_input("Enter your question")

# Run the pipeline only once both a document and a question are present.
if document_path is not None and query:
    splits = load_document(document_path)
    if isinstance(splits, str):
        # The helpers report failures as an error-message string.
        st.write("Error loading document:", splits)
    else:
        rag_chain = initialize_chroma(splits)
        if isinstance(rag_chain, str):
            st.write("Error initializing Chroma:", rag_chain)
        else:
            result = answer_question(rag_chain, query)
            st.write("Result:", result)

st.write("Note: Replace `llm` with an appropriate language model.")