dnzblgn committed on
Commit
e4d5b9b
Β·
verified Β·
1 Parent(s): 18541f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -73
app.py CHANGED
@@ -1,91 +1,176 @@
1
- import os
2
  import gradio as gr
3
- from langchain_community.embeddings import HuggingFaceEmbeddings
4
- from langchain_community.llms import HuggingFaceEndpoint
5
- from langchain_community.vectorstores import FAISS
6
- from langchain_community.document_loaders import PyPDFLoader
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
8
  from langchain.chains import ConversationalRetrievalChain
 
9
  from langchain.memory import ConversationBufferMemory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Initialize embeddings
12
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
 
13
 
14
- # Initialize Mistral LLM
15
- llm = HuggingFaceEndpoint(
16
- endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
17
- huggingfacehub_api_token=os.getenv("HF_TOKEN"),
18
- task="text-generation",
19
- )
20
 
21
- def process_pdf(pdf_file):
22
- # Load PDF
23
- loader = PyPDFLoader(pdf_file)
24
- documents = loader.load()
25
-
26
- # Split text into chunks
27
- text_splitter = RecursiveCharacterTextSplitter(
28
- chunk_size=1000,
29
- chunk_overlap=200,
30
- length_function=len
31
  )
32
- chunks = text_splitter.split_documents(documents)
33
-
34
- # Create vector store
35
- vectorstore = FAISS.from_documents(chunks, embeddings)
36
-
37
- return vectorstore
38
 
39
- def setup_rag_chain(vectorstore):
40
- memory = ConversationBufferMemory(
41
- memory_key="chat_history",
42
- return_messages=True,
43
- output_key='answer'
44
- )
45
-
46
- chain = ConversationalRetrievalChain.from_llm(
47
  llm=llm,
48
- retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
49
  memory=memory,
50
- return_source_documents=True,
51
- chain_type="stuff",
52
- verbose=True
53
  )
54
-
55
- return chain
56
-
57
- def get_response(query, chain):
58
- result = chain({"question": query})
59
- return result['answer']
60
 
61
- def create_demo():
62
- def process_file(file):
63
- vectorstore = process_pdf(file.name)
64
- return setup_rag_chain(vectorstore)
65
 
66
- def respond(message, history, chain_state):
67
- if chain_state is None:
68
- return history + [["Please upload a PDF first.", None]]
69
- response = get_response(message, chain_state)
70
- history = history + [[message, response]]
71
- return history
72
-
73
- with gr.Blocks() as demo:
74
- chain_state = gr.State(None)
75
-
76
- with gr.Row():
77
- file_input = gr.File(label="Upload PDF", file_types=[".pdf"])
78
 
79
- chatbot = gr.Chatbot()
80
- msg = gr.Textbox(label="Question")
81
- clear = gr.Button("Clear")
82
 
83
- file_input.upload(fn=process_file, outputs=[chain_state])
84
- msg.submit(fn=respond, inputs=[msg, chatbot, chain_state], outputs=[chatbot])
85
- clear.click(lambda: None, None, chatbot, queue=False)
86
-
87
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  if __name__ == "__main__":
90
- demo = create_demo()
91
- demo.launch()
 
 
1
  import gradio as gr
2
+ import os
3
+ import time
4
+ import PyPDF2
 
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain.memory import ConversationBufferMemory
10
+ from langchain_community.llms import HuggingFaceEndpoint
11
+
12
def read_file(file_path):
    """Read the text content of an uploaded .txt or .pdf file.

    Args:
        file_path: Path to the uploaded file on disk.

    Returns:
        A ``(content, message)`` tuple. ``content`` is the extracted text,
        or ``None`` when the file is unsupported, empty, or unreadable;
        ``message`` is a human-readable status string for the UI.
    """
    try:
        # Compare extensions case-insensitively so "DOC.PDF" is accepted too.
        lowered = file_path.lower()
        if lowered.endswith(".txt"):
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
        elif lowered.endswith(".pdf"):
            content = ""
            with open(file_path, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                for page in reader.pages:
                    # extract_text() may return None (e.g. image-only pages);
                    # treat that as empty instead of raising a TypeError.
                    content += (page.extract_text() or "") + "\n"
        else:
            return None, "Unsupported file format. Please upload a .txt or .pdf file."

        if not content.strip():
            return None, "File is empty. Please upload a valid document."

        return content, "Successfully processed the uploaded file! Ready for questions."
    except Exception as e:
        return None, f"Error reading file: {str(e)}"
32
 
33
def create_db_from_text(text):
    """Split raw text into chunks and index them in a FAISS vector store.

    Args:
        text: Full document text to index.

    Returns:
        A FAISS vector store built over the embedded chunks.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    documents = splitter.create_documents([text])
    # NOTE(review): HuggingFaceEmbeddings() falls back to the library's
    # default embedding model — confirm this is the intended model.
    return FAISS.from_documents(documents, HuggingFaceEmbeddings())
39
 
40
def initialize_chatbot(vector_db):
    """Wire a conversational RAG chain over the given vector store.

    Args:
        vector_db: Vector store produced by ``create_db_from_text``.

    Returns:
        A ConversationalRetrievalChain combining the store's retriever,
        buffered chat memory, and a hosted Mistral-7B-Instruct endpoint.
    """
    # Hosted inference endpoint; the API token comes from the environment.
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
        temperature=0.5,
        max_new_tokens=256
    )

    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
        verbose=False
    )
 
 
 
 
 
58
 
59
def process_and_initialize(file):
    """Read an uploaded file and build the retrieval chain for it.

    Args:
        file: Filepath from the Gradio File component, or ``None``.

    Returns:
        ``(vector_db, qa_chain, status_message)``. The first two are
        ``None`` on any failure so the UI state components stay unset.
    """
    # Guard clause: nothing uploaded yet.
    if file is None:
        return None, None, "Please upload a file first."

    try:
        text, status_message = read_file(file)
        if text is None:
            # read_file already produced a user-facing error message.
            return None, None, status_message
        vector_db = create_db_from_text(text)
        chain = initialize_chatbot(vector_db)
    except Exception as e:
        return None, None, f"Processing error: {str(e)}"
    return vector_db, chain, status_message
74
+
75
def user_query_typing_effect(query, qa_chain, chatbot):
    """Answer *query* and stream it into the chat with a typewriter effect.

    Args:
        query: The user's question.
        qa_chain: Object exposing ``invoke`` that returns ``{"answer": ...}``.
        chatbot: Existing messages-format history list, or ``None``.

    Yields:
        ``(history, "")`` pairs; the empty string clears the input textbox
        bound as the second output.
    """
    messages = chatbot or []
    try:
        result = qa_chain.invoke({"question": query, "chat_history": []})
        answer = result["answer"]

        messages.append({"role": "user", "content": query})
        messages.append({"role": "assistant", "content": ""})

        # Reveal the answer one character at a time for the typing effect.
        for ch in answer:
            messages[-1]["content"] += ch
            yield messages, ""
            time.sleep(0.05)
    except Exception as e:
        # Surface failures (including a missing chain) inline in the chat.
        messages.append({"role": "assistant", "content": f"Error: {str(e)}"})
        yield messages, ""
91
+
92
def demo():
    """Build and launch the Gradio document-chat UI."""
    # Custom look: orange page, rounded container, hidden Gradio footer.
    custom_css = """
    body {
        background-color: #FF8C00;
        font-family: Arial, sans-serif;
    }
    .gradio-container {
        border-radius: 15px;
        box-shadow: 0px 4px 20px rgba(0, 0, 0, 0.3);
        padding: 20px;
    }
    footer {
        visibility: hidden;
    }
    .chatbot {
        border: 2px solid #000;
        border-radius: 10px;
        background-color: #FFF5E1;
    }
    """

    with gr.Blocks(css=custom_css) as app:
        # Per-session state: the vector store and the retrieval chain.
        vector_db = gr.State(None)
        qa_chain = gr.State(None)

        gr.Markdown("### 🌟 **Document-Based Chatbot** 🌟")
        gr.Markdown("#### Upload your document and ask questions interactively!")

        with gr.Row():
            # Left column: upload + processing controls.
            with gr.Column(scale=1):
                upload_box = gr.File(
                    label="πŸ“ Upload Document",
                    file_types=[".txt", ".pdf"],
                    type="filepath"
                )
                process_btn = gr.Button("πŸš€ Process Document")
                status_box = gr.Textbox(
                    label="πŸ“Š Status",
                    placeholder="Status updates will appear here...",
                    interactive=False
                )

            # Right column: the chat itself.
            with gr.Column(scale=3):
                chat_window = gr.Chatbot(
                    label="πŸ€– Chat with your data",
                    height=600,
                    bubble_full_width=False,
                    show_label=False,
                    render_markdown=True,
                    type="messages",
                    elem_classes=["chatbot"]
                )
                question_box = gr.Textbox(
                    label="Ask a question",
                    placeholder="Ask about the document...",
                    show_label=False,
                    container=False
                )
                ask_btn = gr.Button("Ask")

        # Wire events: processing fills both State slots plus the status box.
        process_btn.click(
            fn=process_and_initialize,
            inputs=[upload_box],
            outputs=[vector_db, qa_chain, status_box],
            show_progress="minimal"
        )

        # Both the button and Enter in the textbox submit a question.
        ask_btn.click(
            fn=user_query_typing_effect,
            inputs=[question_box, qa_chain, chat_window],
            outputs=[chat_window, question_box],
            show_progress="minimal"
        )

        question_box.submit(
            fn=user_query_typing_effect,
            inputs=[question_box, qa_chain, chat_window],
            outputs=[chat_window, question_box],
            show_progress="minimal"
        )

    app.launch()
174
 
175
# Entry point: build the UI and start the Gradio server when run as a script.
if __name__ == "__main__":
    demo()