Rahul2298 commited on
Commit
ac882e7
·
verified ·
1 Parent(s): 3b85947

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +144 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,146 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import faiss
3
+ import os
4
+ from io import BytesIO
5
+ from docx import Document
6
+ import numpy as np
7
+ from langchain_community.document_loaders import WebBaseLoader
8
+ from PyPDF2 import PdfReader
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.text_splitter import CharacterTextSplitter
11
+ from langchain_community.embeddings import HuggingFaceEmbeddings
12
+ from langchain_community.vectorstores import FAISS
13
+ from langchain_community.docstore.in_memory import InMemoryDocstore
14
+ from langchain_huggingface import HuggingFaceEndpoint
15
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
16
+
17
+
18
+ # from secret_api_keys import huggingface_api_key # Set the Hugging Face Hub API token as an environment variable
19
+ huggingface_api_key = "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"
20
+ os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
21
+
22
+ def process_input(input_type, input_data):
23
+ """Processes different input types and returns a vectorstore."""
24
+ loader = None
25
+ if input_type == "Link":
26
+ loader = WebBaseLoader(input_data)
27
+ documents = loader.load()
28
+ elif input_type == "PDF":
29
+ if isinstance(input_data, BytesIO):
30
+ pdf_reader = PdfReader(input_data)
31
+ elif isinstance(input_data, UploadedFile):
32
+ pdf_reader = PdfReader(BytesIO(input_data.read()))
33
+ else:
34
+ raise ValueError("Invalid input data for PDF")
35
+ text = ""
36
+ for page in pdf_reader.pages:
37
+ text += page.extract_text()
38
+ documents = text
39
+ elif input_type == "Text":
40
+ if isinstance(input_data, str):
41
+ documents = input_data # Input is already a text string
42
+ else:
43
+ raise ValueError("Expected a string for 'Text' input type.")
44
+ elif input_type == "DOCX":
45
+ if isinstance(input_data, BytesIO):
46
+ doc = Document(input_data)
47
+ elif isinstance(input_data, UploadedFile):
48
+ doc = Document(BytesIO(input_data.read()))
49
+ else:
50
+ raise ValueError("Invalid input data for DOCX")
51
+ text = "\n".join([para.text for para in doc.paragraphs])
52
+ documents = text
53
+ elif input_type == "TXT":
54
+ if isinstance(input_data, BytesIO):
55
+ text = input_data.read().decode('utf-8')
56
+ elif isinstance(input_data, UploadedFile):
57
+ text = str(input_data.read().decode('utf-8'))
58
+ else:
59
+ raise ValueError("Invalid input data for TXT")
60
+ documents = text
61
+ else:
62
+ raise ValueError("Unsupported input type")
63
+
64
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
65
+ if input_type == "Link":
66
+ texts = text_splitter.split_documents(documents)
67
+ texts = [ str(doc.page_content) for doc in texts ] # Access page_content from each Document
68
+ else:
69
+ texts = text_splitter.split_text(documents)
70
+
71
+ model_name = "sentence-transformers/all-mpnet-base-v2"
72
+ model_kwargs = {'device': 'cpu'}
73
+ encode_kwargs = {'normalize_embeddings': False}
74
+
75
+ hf_embeddings = HuggingFaceEmbeddings(
76
+ model_name=model_name,
77
+ model_kwargs=model_kwargs,
78
+ encode_kwargs=encode_kwargs
79
+ )
80
+ # Create FAISS index
81
+ sample_embedding = np.array(hf_embeddings.embed_query("sample text"))
82
+ dimension = sample_embedding.shape[0]
83
+ index = faiss.IndexFlatL2(dimension)
84
+ # Create FAISS vector store with the embedding function
85
+ vector_store = FAISS(
86
+ embedding_function=hf_embeddings.embed_query,
87
+ index=index,
88
+ docstore=InMemoryDocstore(),
89
+ index_to_docstore_id={},
90
+ )
91
+ vector_store.add_texts(texts) # Add documents to the vector store
92
+ return vector_store
93
+
94
+ # def answer_question(vectorstore, query):
95
+ # """Answers a question based on the provided vectorstore."""
96
+ # llm = HuggingFaceEndpoint(repo_id= 'meta-llama/Meta-Llama-3-8B-Instruct',
97
+ # huggingfacehub_api_token = huggingface_api_key, temperature= 0.6)
98
+ # qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
99
+
100
+ # answer = qa({"query": query})
101
+ # return answer
102
+
103
+ # In your answer_question function
104
+
105
+ def answer_question(vectorstore, query):
106
+ """Answers a question based on the provided vectorstore."""
107
+ llm = HuggingFaceEndpoint(
108
+ # repo_id='meta-llama/Meta-Llama-3-8B-Instruct',
109
+ huggingfacehub_api_token=huggingface_api_key,
110
+ temperature=0.6,
111
+ # Add this line to ensure you're using the official HF endpoint
112
+ endpoint_url="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
113
+ )
114
+ qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
115
+ answer = qa.invoke({"query": query})
116
+ return answer
117
+
118
+ def main():
119
+ st.title("RAG Q&A App")
120
+ input_type = st.selectbox("Input Type", ["Link", "PDF", "Text", "DOCX", "TXT"])
121
+ if input_type == "Link":
122
+ number_input = st.number_input(min_value=1, max_value=20, step=1, label = "Enter the number of Links")
123
+ input_data = []
124
+ for i in range(number_input):
125
+ url = st.sidebar.text_input(f"URL {i+1}")
126
+ input_data.append(url)
127
+ elif input_type == "Text":
128
+ input_data = st.text_input("Enter the text")
129
+ elif input_type == 'PDF':
130
+ input_data = st.file_uploader("Upload a PDF file", type=["pdf"])
131
+ elif input_type == 'TXT':
132
+ input_data = st.file_uploader("Upload a text file", type=['txt'])
133
+ elif input_type == 'DOCX':
134
+ input_data = st.file_uploader("Upload a DOCX file", type=[ 'docx', 'doc'])
135
+ if st.button("Proceed"):
136
+ # st.write(process_input(input_type, input_data))
137
+ vectorstore = process_input(input_type, input_data)
138
+ st.session_state["vectorstore"] = vectorstore
139
+ if "vectorstore" in st.session_state:
140
+ query = st.text_input("Ask your question")
141
+ if st.button("Submit"):
142
+ answer = answer_question(st.session_state["vectorstore"], query)
143
+ st.write(answer)
144
 
145
+ if __name__ == "__main__":
146
+ main()