usmanyousaf commited on
Commit
d62bf95
Β·
verified Β·
1 Parent(s): 54c78cb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +441 -0
app.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import uuid
5
+ from langchain_groq import ChatGroq
6
+ from langchain.prompts import ChatPromptTemplate
7
+ from langchain.schema import HumanMessage, AIMessage
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.document_loaders import PyPDFLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import Chroma
12
+ from langchain.chains import RetrievalQA
13
+ import re
14
+
15
+ from app import check_custom_db_exists
16
+
17
+ # Custom CSS Injection
18
+ def inject_custom_css():
19
+ st.markdown("""
20
+ <style>
21
+ /* Main container */
22
+ .stApp {
23
+ background: linear-gradient(135deg, #1a1a1a, #2d2d2d);
24
+ color: #e0e0e0;
25
+ }
26
+
27
+ /* Chat containers */
28
+ .stChatMessage {
29
+ padding: 1.5rem;
30
+ border-radius: 15px;
31
+ margin: 1rem 0;
32
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
33
+ }
34
+
35
+ /* User message styling */
36
+ [data-testid="stChatMessage"][aria-label="user"] {
37
+ background-color: #2d2d2d;
38
+ border: 1px solid #3d3d3d;
39
+ margin-left: 10%;
40
+ }
41
+
42
+ /* Assistant message styling */
43
+ [data-testid="stChatMessage"][aria-label="assistant"] {
44
+ background-color: #004d40;
45
+ border: 1px solid #00695c;
46
+ margin-right: 10%;
47
+ }
48
+
49
+ /* Sidebar styling */
50
+ [data-testid="stSidebar"] {
51
+ background: #121212 !important;
52
+ border-right: 2px solid #2d2d2d;
53
+ padding: 1rem;
54
+ }
55
+
56
+ /* Button styling */
57
+ .stButton>button {
58
+ background: linear-gradient(45deg, #00695c, #004d40);
59
+ color: white !important;
60
+ border: none;
61
+ border-radius: 8px;
62
+ padding: 0.8rem 1.5rem;
63
+ transition: all 0.3s;
64
+ font-weight: 500;
65
+ }
66
+
67
+ .stButton>button:hover {
68
+ transform: translateY(-2px);
69
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);
70
+ }
71
+
72
+ /* File uploader */
73
+ [data-testid="stFileUploader"] {
74
+ border: 2px dashed #3d3d3d;
75
+ border-radius: 10px;
76
+ padding: 1rem;
77
+ background: #2d2d2d;
78
+ }
79
+
80
+ /* Input field */
81
+ .stTextInput>div>div>input {
82
+ background-color: #2d2d2d;
83
+ color: white;
84
+ border: 1px solid #3d3d3d;
85
+ border-radius: 8px;
86
+ padding: 0.8rem;
87
+ }
88
+
89
+ /* Spinner color */
90
+ .stSpinner>div>div {
91
+ border-color: #00bcd4 transparent transparent transparent;
92
+ }
93
+
94
+ /* Custom title styling */
95
+ .title-text {
96
+ background: linear-gradient(45deg, #00bcd4, #00695c);
97
+ -webkit-background-clip: text;
98
+ -webkit-text-fill-color: transparent;
99
+ font-family: 'Roboto', sans-serif;
100
+ font-size: 2.8rem;
101
+ text-align: center;
102
+ margin-bottom: 2rem;
103
+ letter-spacing: -0.5px;
104
+ text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
105
+ }
106
+
107
+ /* Similar questions buttons */
108
+ .stButton>button.similar-q {
109
+ background: #2d2d2d;
110
+ border: 1px solid #00bcd4;
111
+ color: #00bcd4 !important;
112
+ white-space: normal;
113
+ height: auto;
114
+ min-height: 3rem;
115
+ transition: all 0.3s;
116
+ }
117
+
118
+ /* Hover effects */
119
+ .stButton>button.similar-q:hover {
120
+ background: #004d40 !important;
121
+ transform: scale(1.02);
122
+ }
123
+
124
+ /* Source text styling */
125
+ .source-text {
126
+ color: #00bcd4;
127
+ font-size: 0.9rem;
128
+ margin-top: 1rem;
129
+ padding-top: 0.5rem;
130
+ border-top: 1px solid #3d3d3d;
131
+ }
132
+ </style>
133
+ """, unsafe_allow_html=True)
134
+
135
+ # Page Configuration
136
+ st.set_page_config(
137
+ page_title="AI Law Agent",
138
+ page_icon="βš–οΈ",
139
+ layout="centered",
140
+ initial_sidebar_state="expanded"
141
+ )
142
+
143
+ # Constants
144
+ DEFAULT_GROQ_API_KEY = "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"
145
+ MODEL_NAME = "llama-3.3-70b-versatile"
146
+ DEFAULT_DOCUMENT_PATH = "/Users/appleenterprises/Desktop/ai law bot/lawbook.pdf"
147
+ DEFAULT_COLLECTION_NAME = "pakistan_laws_default"
148
+ CHROMA_PERSIST_DIR = "./chroma_db"
149
+
150
+ # Session state initialization
151
+ if "messages" not in st.session_state:
152
+ st.session_state.messages = []
153
+ if "user_id" not in st.session_state:
154
+ st.session_state.user_id = str(uuid.uuid4())
155
+ if "vectordb" not in st.session_state:
156
+ st.session_state.vectordb = None
157
+ if "llm" not in st.session_state:
158
+ st.session_state.llm = None
159
+ if "qa_chain" not in st.session_state:
160
+ st.session_state.qa_chain = None
161
+ if "similar_questions" not in st.session_state:
162
+ st.session_state.similar_questions = []
163
+ if "using_custom_docs" not in st.session_state:
164
+ st.session_state.using_custom_docs = False
165
+ if "custom_collection_name" not in st.session_state:
166
+ st.session_state.custom_collection_name = f"custom_laws_{st.session_state.user_id}"
167
+
168
+ def setup_embeddings():
169
+ return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
170
+
171
+ def setup_llm():
172
+ if st.session_state.llm is None:
173
+ st.session_state.llm = ChatGroq(
174
+ model_name=MODEL_NAME,
175
+ groq_api_key=DEFAULT_GROQ_API_KEY,
176
+ temperature=0.2
177
+ )
178
+ return st.session_state.llm
179
+
180
+ def check_default_db_exists():
181
+ return os.path.exists(os.path.join(CHROMA_PERSIST_DIR, DEFAULT_COLLECTION_NAME))
182
+
183
+ def load_existing_vectordb(collection_name):
184
+ try:
185
+ return Chroma(
186
+ persist_directory=CHROMA_PERSIST_DIR,
187
+ embedding_function=setup_embeddings(),
188
+ collection_name=collection_name
189
+ )
190
+ except Exception as e:
191
+ st.error(f"Error loading database: {str(e)}")
192
+ return None
193
+
194
+ def process_default_document(force_rebuild=False):
195
+ if check_default_db_exists() and not force_rebuild:
196
+ db = load_existing_vectordb(DEFAULT_COLLECTION_NAME)
197
+ if db:
198
+ st.session_state.vectordb = db
199
+ setup_qa_chain()
200
+ st.session_state.using_custom_docs = False
201
+ return True
202
+
203
+ if not os.path.exists(DEFAULT_DOCUMENT_PATH):
204
+ st.error("Default document not found.")
205
+ return False
206
+
207
+ try:
208
+ with st.spinner("Building knowledge base..."):
209
+ loader = PyPDFLoader(DEFAULT_DOCUMENT_PATH)
210
+ documents = loader.load()
211
+
212
+ for doc in documents:
213
+ doc.metadata["source"] = "Pakistan Laws (Official)"
214
+
215
+ text_splitter = RecursiveCharacterTextSplitter(
216
+ chunk_size=1000,
217
+ chunk_overlap=200
218
+ )
219
+ chunks = text_splitter.split_documents(documents)
220
+
221
+ db = Chroma.from_documents(
222
+ documents=chunks,
223
+ embedding=setup_embeddings(),
224
+ collection_name=DEFAULT_COLLECTION_NAME,
225
+ persist_directory=CHROMA_PERSIST_DIR
226
+ )
227
+
228
+ db.persist()
229
+ st.session_state.vectordb = db
230
+ setup_qa_chain()
231
+ st.session_state.using_custom_docs = False
232
+ return True
233
+ except Exception as e:
234
+ st.error(f"Error processing document: {str(e)}")
235
+ return False
236
+
237
+ def process_custom_documents(uploaded_files):
238
+ collection_name = st.session_state.custom_collection_name
239
+ documents = []
240
+
241
+ for uploaded_file in uploaded_files:
242
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
243
+ tmp_file.write(uploaded_file.getvalue())
244
+ tmp_path = tmp_file.name
245
+
246
+ try:
247
+ loader = PyPDFLoader(tmp_path)
248
+ file_docs = loader.load()
249
+ for doc in file_docs:
250
+ doc.metadata["source"] = uploaded_file.name
251
+ documents.extend(file_docs)
252
+ os.unlink(tmp_path)
253
+ except Exception as e:
254
+ st.error(f"Error processing {uploaded_file.name}: {str(e)}")
255
+
256
+ if documents:
257
+ text_splitter = RecursiveCharacterTextSplitter(
258
+ chunk_size=1000,
259
+ chunk_overlap=200
260
+ )
261
+ chunks = text_splitter.split_documents(documents)
262
+
263
+ with st.spinner("Analyzing documents..."):
264
+ if check_custom_db_exists(collection_name):
265
+ temp_db = Chroma(
266
+ persist_directory=CHROMA_PERSIST_DIR,
267
+ embedding_function=setup_embeddings(),
268
+ collection_name=collection_name
269
+ )
270
+ temp_db.delete_collection()
271
+
272
+ db = Chroma.from_documents(
273
+ documents=chunks,
274
+ embedding=setup_embeddings(),
275
+ collection_name=collection_name,
276
+ persist_directory=CHROMA_PERSIST_DIR
277
+ )
278
+
279
+ db.persist()
280
+ st.session_state.vectordb = db
281
+ setup_qa_chain()
282
+ st.session_state.using_custom_docs = True
283
+ return True
284
+ return False
285
+
286
+ def setup_qa_chain():
287
+ if st.session_state.vectordb:
288
+ template = """You are a legal expert specializing in Pakistani law.
289
+ Use context to answer. If unsure, state uncertainty but provide general legal info.
290
+
291
+ Context: {context}
292
+
293
+ Question: {question}
294
+
295
+ Answer:"""
296
+
297
+ prompt = ChatPromptTemplate.from_template(template)
298
+
299
+ st.session_state.qa_chain = RetrievalQA.from_chain_type(
300
+ llm=setup_llm(),
301
+ chain_type="stuff",
302
+ retriever=st.session_state.vectordb.as_retriever(search_kwargs={"k": 3}),
303
+ chain_type_kwargs={"prompt": prompt},
304
+ return_source_documents=True
305
+ )
306
+
307
+ def generate_similar_questions(question, docs):
308
+ llm = setup_llm()
309
+ context = "\n".join([doc.page_content for doc in docs[:2]])
310
+
311
+ prompt = f"""Generate 3 specific Pakistani law questions related to:
312
+
313
+ Original: {question}
314
+
315
+ Context: {context}
316
+
317
+ Generate exactly 3 questions:"""
318
+
319
+ try:
320
+ response = llm.invoke(prompt)
321
+ questions = re.findall(r"\d+\.\s+(.*?)(?=\d+\.|$)", response.content, re.DOTALL)
322
+ if not questions:
323
+ questions = response.content.split("\n")
324
+ questions = [q.strip() for q in questions if q.strip() and "?" in q]
325
+ return [q.strip().replace("\n", " ") for q in questions if "?" in q][:3]
326
+ except:
327
+ return []
328
+
329
+ def get_answer(question):
330
+ if not st.session_state.vectordb:
331
+ with st.spinner("Initializing system..."):
332
+ process_default_document()
333
+
334
+ if st.session_state.qa_chain:
335
+ result = st.session_state.qa_chain({"query": question})
336
+ answer = result["result"]
337
+
338
+ st.session_state.similar_questions = generate_similar_questions(question, result.get("source_documents", []))
339
+
340
+ sources = set()
341
+ for doc in result.get("source_documents", []):
342
+ if "source" in doc.metadata:
343
+ sources.add(doc.metadata["source"])
344
+
345
+ if sources:
346
+ answer += f"\n\n<div class='source-text'>Sources: {', '.join(sources)}</div>"
347
+
348
+ return answer
349
+ return "System initializing... Please try again."
350
+
351
+ def main():
352
+ inject_custom_css()
353
+
354
+ st.markdown("""
355
+ <h1 class="title-text">
356
+ <div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem;">
357
+ <span>βš–οΈ</span>
358
+ <span>Your AI Law Agent</span>
359
+ </div>
360
+ </h1>
361
+ """, unsafe_allow_html=True)
362
+
363
+ # Sidebar Management
364
+ with st.sidebar:
365
+ st.header("πŸ“š Document Management")
366
+
367
+ if st.session_state.using_custom_docs:
368
+ if st.button("πŸ”™ Return to Official Database", use_container_width=True):
369
+ with st.spinner("Switching..."):
370
+ process_default_document()
371
+ st.session_state.messages.append(AIMessage(content="Switched to official database"))
372
+ st.rerun()
373
+
374
+ if not st.session_state.using_custom_docs:
375
+ if st.button("πŸ”„ Rebuild Database", use_container_width=True):
376
+ with st.spinner("Rebuilding..."):
377
+ process_default_document(force_rebuild=True)
378
+ st.rerun()
379
+
380
+ st.header("πŸ“ Upload Documents")
381
+ uploaded_files = st.file_uploader(
382
+ "Upload legal PDFs",
383
+ type=["pdf"],
384
+ accept_multiple_files=True,
385
+ label_visibility="collapsed"
386
+ )
387
+
388
+ if st.button("πŸš€ Train on Uploads", use_container_width=True) and uploaded_files:
389
+ with st.spinner("Processing..."):
390
+ if process_custom_documents(uploaded_files):
391
+ st.session_state.messages.append(AIMessage(content="Custom documents loaded"))
392
+ st.rerun()
393
+
394
+ # Chat Display
395
+ for message in st.session_state.messages:
396
+ avatar = "πŸ‘€" if isinstance(message, HumanMessage) else "βš–οΈ"
397
+ with st.chat_message("user" if isinstance(message, HumanMessage) else "assistant", avatar=avatar):
398
+ st.write(message.content)
399
+
400
+ # Similar Questions
401
+ if st.session_state.similar_questions:
402
+ st.markdown("""
403
+ <div style="padding: 1rem; background: #2d2d2d; border-radius: 10px; margin: 1rem 0;">
404
+ <h4 style="color: #00bcd4; margin-bottom: 0.5rem;">πŸ” Related Queries</h4>
405
+ """, unsafe_allow_html=True)
406
+
407
+ cols = st.columns([1,1,1])
408
+ for i, question in enumerate(st.session_state.similar_questions):
409
+ with cols[i]:
410
+ if st.button(
411
+ f"❓ {question}",
412
+ key=f"similar_q_{i}",
413
+ use_container_width=True,
414
+ help="Click to ask this related question"
415
+ ):
416
+ st.session_state.messages.append(HumanMessage(content=question))
417
+ with st.chat_message("assistant", avatar="βš–οΈ"):
418
+ with st.spinner("Analyzing..."):
419
+ response = get_answer(question)
420
+ st.write(response, unsafe_allow_html=True)
421
+ st.session_state.messages.append(AIMessage(content=response))
422
+ st.rerun()
423
+
424
+ st.markdown("</div>", unsafe_allow_html=True)
425
+
426
+ # Input Handling
427
+ if user_input := st.chat_input("Ask your legal question..."):
428
+ st.session_state.messages.append(HumanMessage(content=user_input))
429
+ with st.chat_message("user"):
430
+ st.write(user_input)
431
+
432
+ with st.chat_message("assistant", avatar="βš–οΈ"):
433
+ with st.spinner("Researching..."):
434
+ response = get_answer(user_input)
435
+ st.write(response, unsafe_allow_html=True)
436
+
437
+ st.session_state.messages.append(AIMessage(content=response))
438
+ st.rerun()
439
+
440
+ if __name__ == "__main__":
441
+ main()