Update src/streamlit_app.py

#1
by omarkashif - opened
Files changed (1) hide show
  1. src/streamlit_app.py +194 -55
src/streamlit_app.py CHANGED
@@ -7,19 +7,139 @@ os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models"
7
 
8
  import streamlit as st
9
  import openai
 
10
  from collections import deque
11
  from sentence_transformers import SentenceTransformer
12
- from pinecone import Pinecone
13
  import re
14
 
15
- # Setup (exact hardcoded keys you provided)
16
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
17
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
18
- index = pc.Index("legal-ai")
19
- model = SentenceTransformer('all-mpnet-base-v2')
20
- chat_history = deque(maxlen=10) # last 5 pairs = 10 messages
21
  ll_model = 'gpt-4o-mini'
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  st.title("AI Legal Assistant ⚖️")
24
 
25
  if "history" not in st.session_state:
@@ -30,7 +150,7 @@ def get_rewritten_query(user_query):
30
  hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in hist)
31
  messages = [
32
  {"role": "system", "content":
33
- "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If its already clear then dont rewite"},
34
  {"role": "user", "content":
35
  f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
36
  "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."}
@@ -46,76 +166,103 @@ def get_rewritten_query(user_query):
46
  except Exception as e:
47
  st.error(f"Rewrite error: {e}")
48
  rewritten = user_query
49
- # st.session_state.history.append({"role": "assistant", "content": f"🔍 Rewritten query: {rewritten}"})
50
  return rewritten
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def retrieve_documents(query, top_k=10):
53
- emb = model.encode(query).tolist()
54
- try:
55
- return index.query(vector=emb, top_k=top_k, include_metadata=True)['matches']
 
 
 
 
 
 
 
 
 
 
 
56
  except Exception as e:
57
  st.error(f"Retrieve error: {e}")
58
  return []
59
 
60
-
61
  def clean_chunk_id(cid: str) -> str:
62
- """Beautify chunk_id by replacing underscores/dashes with spaces and capitalizing words."""
63
- # Remove any trailing '_chunk_xxx' stuff
64
  cid = re.sub(r'_chunk.*$', '', cid)
65
- # Replace _ and - with spaces
66
  cid = cid.replace("_", " ").replace("-", " ")
67
- # Capitalize each word
68
  cid = " ".join(word.capitalize() for word in cid.split())
69
  return cid
70
 
71
-
72
-
73
  def generate_response(user_query, docs):
74
- # --- Collect context ---
75
- context = "\n\n---\n\n".join(d['metadata']['text'] for d in docs)
76
 
77
- # --- Build human-friendly sources + mapping ---
78
  source_links = {}
79
  for d in docs:
80
- meta = d['metadata']
81
- src = meta.get("source", "unknown").lower()
82
- cid = meta.get("chunk_id", "")
83
- text_preview = " ".join(meta.get("text", "").split()[:30])
84
-
85
- if src in ["constitution"]:
86
- display_name = f"Constitution ({clean_chunk_id(cid)})"
87
-
88
- elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
89
- display_name = f"Tax Ordinance ({clean_chunk_id(cid)})"
90
-
91
- elif src in ["case_law", "case", "tax_case"]:
92
- display_name = f"Case Law: {text_preview}..."
93
 
 
 
94
  else:
95
- display_name = f"{src.title()} ({clean_chunk_id(cid)})"
96
 
97
- source_links[display_name] = meta.get("text", "")
98
 
99
- # Deduplicate
100
  source_links = dict(sorted(source_links.items()))
101
 
102
- # --- System prompt ---
103
  messages = [
104
  {"role": "system", "content":
105
  "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
106
  "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
107
  "Formatting rules:\n"
108
- "- For Constitution / Ordinances: show the clean chunk id, no underscores/dashes, capitalized words.\n"
109
- "- For Case law: ignore chunk id, instead show first ~30 words of the case text.\n"
110
  "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
111
  "If multiple are used, separate them with commas."}
112
  ]
113
 
114
- messages.extend(st.session_state.history)
115
-
116
  messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
117
- f"Sources:\n{', '.join(source_links.keys())}\n\n"
118
- f"Question:\n{user_query}"})
 
119
  try:
120
  resp = client.chat.completions.create(
121
  model=ll_model,
@@ -128,19 +275,14 @@ def generate_response(user_query, docs):
128
  st.error(f"Response error: {e}")
129
  reply = "Sorry, I encountered an error generating the answer."
130
 
131
- # Optional: force clean source line if LLM misses it
132
  if source_links:
133
  clean_sources = ", ".join(source_links.keys())
134
  if "Source:" not in reply:
135
  reply += f"\n\nSource: {clean_sources}"
136
 
137
- # Save reply into history
138
  st.session_state.history.append({"role": "assistant", "content": reply})
139
-
140
- # --- Render in Streamlit ---
141
  st.markdown(reply)
142
 
143
- # Add expandable sources
144
  if source_links:
145
  st.write("### Sources")
146
  for name, text in source_links.items():
@@ -149,9 +291,6 @@ def generate_response(user_query, docs):
149
 
150
  return reply
151
 
152
-
153
-
154
-
155
  # Chat UI
156
  with st.form("chat_input", clear_on_submit=True):
157
  user_input = st.text_input("You:", "")
@@ -164,12 +303,12 @@ if submit and user_input:
164
  assistant_reply = generate_response(rewritten, docs)
165
 
166
  c = 0
167
- # Display history
168
  st.markdown("---")
169
  for msg in reversed(st.session_state.history):
170
- c+=1
171
  if msg["role"] == "user":
172
  st.markdown(f"**You:** {msg['content']}")
173
  else:
174
  st.markdown(f"**Legal Assistant:** {msg['content']}")
175
- if c ^ 1: st.markdown("---")
 
 
7
 
8
  import streamlit as st
9
  import openai
10
+ import psycopg2
11
  from collections import deque
12
  from sentence_transformers import SentenceTransformer
 
13
  import re
14
 
15
+ # Setup
16
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
 
 
 
17
  ll_model = 'gpt-4o-mini'
18
 
19
+ # ── NEW: PostgreSQL connection ──────────────────────────────
20
def get_db_connection():
    """Open a new PostgreSQL connection configured from environment variables.

    Reads RDS_HOST, RDS_PORT (5432 when unset), RDS_DB, RDS_USER and
    RDS_PASS. Every call creates a fresh connection; closing it is the
    caller's responsibility.
    """
    settings = {
        "host": os.getenv("RDS_HOST"),
        "port": os.getenv("RDS_PORT", 5432),
        "dbname": os.getenv("RDS_DB"),
        "user": os.getenv("RDS_USER"),
        "password": os.getenv("RDS_PASS"),
    }
    return psycopg2.connect(**settings)
28
+
29
+ # ── NEW: BGE model ──────────────────────────────────────────
30
+ model = SentenceTransformer('BAAI/bge-small-en-v1.5')
31
+
32
def retrieve_summaries(query, top_k=40):
    """Return the ``top_k`` chunk summaries nearest to ``query``.

    The query is embedded with ``get_embedding`` and compared against the
    ``embedding`` column of ``public.case_chunks`` with pgvector's ``<=>``
    distance operator; rows are ordered nearest-first.

    Args:
        query: Search text to embed.
        top_k: Maximum number of rows to return.

    Returns:
        A list of dicts with keys ``id``, ``case_id``, ``chunk_index``,
        ``chunk_summary`` and ``similarity`` (``1 - distance``). On any
        failure the error is shown via ``st.error`` and ``[]`` is returned.
    """
    try:
        embedding = get_embedding(query)
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            try:
                cur.execute("""
                    SELECT
                        id,
                        case_id,
                        chunk_index,
                        chunk_summary,
                        1 - (embedding <=> %s::vector) AS similarity
                    FROM public.case_chunks
                    ORDER BY embedding <=> %s::vector
                    LIMIT %s;
                """, [embedding, embedding, top_k])
                rows = cur.fetchall()
            finally:
                cur.close()
        finally:
            # Close even when the query raises, so a failed search does not
            # leave a connection open.
            conn.close()

        return [
            {
                "id": row[0],
                "case_id": row[1],
                "chunk_index": row[2],
                "chunk_summary": row[3],
                "similarity": row[4],
            }
            for row in rows
        ]
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []
65
+
66
+
67
+ # ── STEP 2: LLM picks best chunks based on summaries ───────
68
def rerank_with_llm(query, candidates, final_k=10):
    """Ask the chat model to pick the most relevant chunk IDs from summaries.

    Args:
        query: The user's (possibly rewritten) question.
        candidates: Dicts with at least ``id``, ``case_id`` and
            ``chunk_summary``, ordered best-first by similarity.
        final_k: Maximum number of IDs to return.

    Returns:
        Up to ``final_k`` IDs, in the order the model listed them. Only IDs
        that belong to ``candidates`` are kept and duplicates are dropped.
        When the model call fails, or its answer contains no usable ID, the
        first ``final_k`` candidate IDs are returned instead. An empty
        ``candidates`` list yields ``[]`` without calling the model.
    """
    if not candidates:
        # Nothing to rank; skip the LLM round-trip.
        return []

    # Similarity-ordered fallback for a failed call or an unusable answer.
    fallback_ids = [c["id"] for c in candidates[:final_k]]

    summary_list = "\n".join([
        f"[ID: {c['id']}] Case: {c['case_id']} | Summary: {c['chunk_summary']}"
        for c in candidates
    ])

    messages = [
        {"role": "system", "content":
            "You are a legal research assistant. Given a user query and a list of document chunk summaries, "
            "select the most relevant chunk IDs that would best answer the query. "
            "Return ONLY a comma-separated list of IDs, nothing else. Example: 12,45,67,23"
        },
        {"role": "user", "content":
            f"Query: {query}\n\n"
            f"Chunks:\n{summary_list}\n\n"
            f"Select the {final_k} most relevant chunk IDs."
        }
    ]

    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.0,
            max_tokens=200
        )
        raw = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Rerank error: {e}")
        # Fallback: just return top final_k by similarity
        return fallback_ids

    # Map the textual form of each candidate ID back to its original value so
    # that IDs the model invented are rejected.
    known_ids = {str(c["id"]): c["id"] for c in candidates}

    selected_ids = []
    seen = set()
    # Pull out every digit run rather than splitting on commas: this also
    # accepts answers such as "[ID: 12], [ID: 45]" or one ID per line.
    for token in re.findall(r"\d+", raw):
        key = str(int(token))  # normalise leading zeros
        if key in known_ids and key not in seen:
            seen.add(key)
            selected_ids.append(known_ids[key])

    if not selected_ids:
        return fallback_ids
    return selected_ids[:final_k]
101
+
102
+
103
+ # ── STEP 3: fetch full chunk_text for selected IDs only ────
104
def fetch_chunks_by_ids(selected_ids):
    """Load the full chunk rows for the given IDs, preserving their order.

    Args:
        selected_ids: Chunk IDs, most relevant first.

    Returns:
        A list of dicts with keys ``id``, ``case_id``, ``chunk_index``,
        ``chunk_text`` and ``chunk_summary``, ordered like ``selected_ids``.
        An empty or missing ID list returns ``[]`` without touching the
        database. On any failure the error is shown via ``st.error`` and
        ``[]`` is returned.
    """
    ids = list(selected_ids or [])
    if not ids:
        return []

    try:
        conn = get_db_connection()
        try:
            cur = conn.cursor()
            try:
                cur.execute("""
                    SELECT
                        id,
                        case_id,
                        chunk_index,
                        chunk_text,
                        chunk_summary
                    FROM public.case_chunks
                    WHERE id = ANY(%s);
                """, [ids])
                rows = cur.fetchall()
            finally:
                cur.close()
        finally:
            # Close even when the query raises, so errors do not leak the
            # connection.
            conn.close()

        docs = [
            {
                "id": row[0],
                "case_id": row[1],
                "chunk_index": row[2],
                "chunk_text": row[3],
                "chunk_summary": row[4],
            }
            for row in rows
        ]

        # The query has no ORDER BY, so restore the ranking carried by the
        # ID list (first occurrence wins for duplicated IDs).
        rank = {}
        for position, chunk_id in enumerate(ids):
            rank.setdefault(chunk_id, position)
        docs.sort(key=lambda doc: rank.get(doc["id"], len(rank)))
        return docs
    except Exception as e:
        st.error(f"Fetch error: {e}")
        return []
135
+
136
+
137
+
138
def get_embedding(text):
    """Embed a search query with the module-level sentence-transformer model.

    The query text is prepended with the BGE retrieval instruction before
    encoding, and the encoded result is converted with ``.tolist()``.
    """
    # BGE requires this prefix for queries
    query_text = "Represent this sentence for searching relevant passages: {}".format(text)
    query_vector = model.encode(query_text)
    return query_vector.tolist()
142
+
143
  st.title("AI Legal Assistant ⚖️")
144
 
145
  if "history" not in st.session_state:
 
150
  hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in hist)
151
  messages = [
152
  {"role": "system", "content":
153
+ "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If its already clear then dont rewrite"},
154
  {"role": "user", "content":
155
  f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
156
  "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."}
 
166
  except Exception as e:
167
  st.error(f"Rewrite error: {e}")
168
  rewritten = user_query
 
169
  return rewritten
170
 
171
+ # ── UPDATED: retrieve from pgvector ────────────────────────
172
+ # def retrieve_documents(query, top_k=10):
173
+ # try:
174
+ # embedding = get_embedding(query)
175
+ # conn = get_db_connection()
176
+ # cur = conn.cursor()
177
+ # cur.execute("""
178
+ # SELECT
179
+ # case_id,
180
+ # chunk_index,
181
+ # chunk_text,
182
+ # chunk_summary,
183
+ # 1 - (embedding <=> %s::vector) AS similarity
184
+ # FROM public.case_chunks
185
+ # ORDER BY embedding <=> %s::vector
186
+ # LIMIT %s;
187
+ # """, [embedding, embedding, top_k])
188
+ # rows = cur.fetchall()
189
+ # cur.close()
190
+ # conn.close()
191
+
192
+ # # Format to match the rest of the app
193
+ # docs = []
194
+ # for row in rows:
195
+ # docs.append({
196
+ # "case_id": row[0],
197
+ # "chunk_index": row[1],
198
+ # "chunk_text": row[2],
199
+ # "chunk_summary": row[3],
200
+ # "similarity": row[4]
201
+ # })
202
+ # return docs
203
+ # ── COMBINED: full retrieval pipeline ──────────────────────
204
  def retrieve_documents(query, top_k=10):
205
+ # 1. Get 4x summaries
206
+ candidates = retrieve_summaries(query, top_k=top_k * 4)
207
+ if not candidates:
208
+ return []
209
+
210
+ # 2. LLM picks best IDs from summaries
211
+ selected_ids = rerank_with_llm(query, candidates, final_k=top_k)
212
+ if not selected_ids:
213
+ return []
214
+
215
+ # 3. Fetch full text for selected chunks only
216
+ docs = fetch_chunks_by_ids(selected_ids)
217
+ return docs
218
+
219
  except Exception as e:
220
  st.error(f"Retrieve error: {e}")
221
  return []
222
 
 
223
  def clean_chunk_id(cid: str) -> str:
 
 
224
  cid = re.sub(r'_chunk.*$', '', cid)
 
225
  cid = cid.replace("_", " ").replace("-", " ")
 
226
  cid = " ".join(word.capitalize() for word in cid.split())
227
  return cid
228
 
229
+ # ── UPDATED: generate response with new doc structure ───────
 
230
  def generate_response(user_query, docs):
231
+ # Collect context from chunk_text
232
+ context = "\n\n---\n\n".join(d['chunk_text'] for d in docs if d['chunk_text'])
233
 
234
+ # Build sources
235
  source_links = {}
236
  for d in docs:
237
+ case_id = d.get("case_id", "unknown")
238
+ chunk_idx = d.get("chunk_index", "")
239
+ text_preview = " ".join((d.get("chunk_text") or "").split()[:30])
 
 
 
 
 
 
 
 
 
 
240
 
241
+ if case_id == "constitution":
242
+ display_name = f"Constitution (Chunk {chunk_idx})"
243
  else:
244
+ display_name = f"Case Law: {text_preview}..."
245
 
246
+ source_links[display_name] = d.get("chunk_text", "")
247
 
 
248
  source_links = dict(sorted(source_links.items()))
249
 
 
250
  messages = [
251
  {"role": "system", "content":
252
  "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
253
  "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
254
  "Formatting rules:\n"
255
+ "- For Constitution: show the chunk number.\n"
256
+ "- For Case law: show first ~30 words of the case text.\n"
257
  "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
258
  "If multiple are used, separate them with commas."}
259
  ]
260
 
261
+ messages.extend(list(st.session_state.history))
 
262
  messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
263
+ f"Sources:\n{', '.join(source_links.keys())}\n\n"
264
+ f"Question:\n{user_query}"})
265
+
266
  try:
267
  resp = client.chat.completions.create(
268
  model=ll_model,
 
275
  st.error(f"Response error: {e}")
276
  reply = "Sorry, I encountered an error generating the answer."
277
 
 
278
  if source_links:
279
  clean_sources = ", ".join(source_links.keys())
280
  if "Source:" not in reply:
281
  reply += f"\n\nSource: {clean_sources}"
282
 
 
283
  st.session_state.history.append({"role": "assistant", "content": reply})
 
 
284
  st.markdown(reply)
285
 
 
286
  if source_links:
287
  st.write("### Sources")
288
  for name, text in source_links.items():
 
291
 
292
  return reply
293
 
 
 
 
294
  # Chat UI
295
  with st.form("chat_input", clear_on_submit=True):
296
  user_input = st.text_input("You:", "")
 
303
  assistant_reply = generate_response(rewritten, docs)
304
 
305
  c = 0
 
306
  st.markdown("---")
307
  for msg in reversed(st.session_state.history):
308
+ c += 1
309
  if msg["role"] == "user":
310
  st.markdown(f"**You:** {msg['content']}")
311
  else:
312
  st.markdown(f"**Legal Assistant:** {msg['content']}")
313
+ if c ^ 1:
314
+ st.markdown("---")