omarkashif committed on
Commit
f6d9e3b
·
verified ·
1 Parent(s): b832b0a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +32 -17
src/streamlit_app.py CHANGED
@@ -18,6 +18,7 @@ pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
18
  index = pc.Index("legal-ai")
19
  model = SentenceTransformer('all-mpnet-base-v2')
20
  chat_history = deque(maxlen=10) # last 5 pairs = 10 messages
 
21
 
22
  st.title("AI Legal Assistant ⚖️")
23
 
@@ -36,7 +37,7 @@ def get_rewritten_query(user_query):
36
  ]
37
  try:
38
  resp = client.chat.completions.create(
39
- model="gpt-4.1-mini",
40
  messages=messages,
41
  temperature=0.1,
42
  max_tokens=400
@@ -70,10 +71,11 @@ def clean_chunk_id(cid: str) -> str:
70
 
71
 
72
  def generate_response(user_query, docs):
 
73
  context = "\n\n---\n\n".join(d['metadata']['text'] for d in docs)
74
- # sources = sorted({d['metadata']['chunk_id'] for d in docs if 'source' in d['metadata']})
75
- # --- Build human-friendly sources ---
76
- readable_sources = []
77
  for d in docs:
78
  meta = d['metadata']
79
  src = meta.get("source", "unknown").lower()
@@ -81,20 +83,21 @@ def generate_response(user_query, docs):
81
  text_preview = " ".join(meta.get("text", "").split()[:30])
82
 
83
  if src in ["constitution"]:
84
- readable_sources.append(f"Constitution ({clean_chunk_id(cid)})")
85
 
86
  elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
87
- readable_sources.append(f"Tax Ordinance ({clean_chunk_id(cid)})")
88
 
89
  elif src in ["case_law", "case", "tax_case"]:
90
- # Use first ~30 words of the actual text
91
- readable_sources.append(f"Case Law: {text_preview}...")
92
 
93
  else:
94
- readable_sources.append(f"{src.title()} ({clean_chunk_id(cid)})")
 
 
95
 
96
- # Deduplicate and sort
97
- readable_sources = sorted(set(readable_sources))
98
 
99
  # --- System prompt ---
100
  messages = [
@@ -108,15 +111,14 @@ def generate_response(user_query, docs):
108
  "If multiple are used, separate them with commas."}
109
  ]
110
 
111
-
112
  messages.extend(st.session_state.history)
113
-
114
  messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
115
- f"Sources:\n{', '.join(readable_sources)}\n\n"
116
  f"Question:\n{user_query}"})
117
  try:
118
  resp = client.chat.completions.create(
119
- model="gpt-4.1-mini",
120
  messages=messages,
121
  temperature=0.1,
122
  max_tokens=900
@@ -127,16 +129,29 @@ def generate_response(user_query, docs):
127
  reply = "Sorry, I encountered an error generating the answer."
128
 
129
  # Optional: force clean source line if LLM misses it
130
- if readable_sources:
131
- clean_sources = ", ".join(readable_sources)
132
  if "Source:" not in reply:
133
  reply += f"\n\nSource: {clean_sources}"
134
 
 
135
  st.session_state.history.append({"role": "assistant", "content": reply})
 
 
 
 
 
 
 
 
 
 
 
136
  return reply
137
 
138
 
139
 
 
140
  # Chat UI
141
  with st.form("chat_input", clear_on_submit=True):
142
  user_input = st.text_input("You:", "")
 
18
  index = pc.Index("legal-ai")
19
  model = SentenceTransformer('all-mpnet-base-v2')
20
  chat_history = deque(maxlen=10) # last 5 pairs = 10 messages
21
+ ll_model = 'gpt-4o-mini'
22
 
23
  st.title("AI Legal Assistant ⚖️")
24
 
 
37
  ]
38
  try:
39
  resp = client.chat.completions.create(
40
+ model=ll_model,
41
  messages=messages,
42
  temperature=0.1,
43
  max_tokens=400
 
71
 
72
 
73
  def generate_response(user_query, docs):
74
+ # --- Collect context ---
75
  context = "\n\n---\n\n".join(d['metadata']['text'] for d in docs)
76
+
77
+ # --- Build human-friendly sources + mapping ---
78
+ source_links = {}
79
  for d in docs:
80
  meta = d['metadata']
81
  src = meta.get("source", "unknown").lower()
 
83
  text_preview = " ".join(meta.get("text", "").split()[:30])
84
 
85
  if src in ["constitution"]:
86
+ display_name = f"Constitution ({clean_chunk_id(cid)})"
87
 
88
  elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
89
+ display_name = f"Tax Ordinance ({clean_chunk_id(cid)})"
90
 
91
  elif src in ["case_law", "case", "tax_case"]:
92
+ display_name = f"Case Law: {text_preview}..."
 
93
 
94
  else:
95
+ display_name = f"{src.title()} ({clean_chunk_id(cid)})"
96
+
97
+ source_links[display_name] = meta.get("text", "")
98
 
99
+ # Deduplicate
100
+ source_links = dict(sorted(source_links.items()))
101
 
102
  # --- System prompt ---
103
  messages = [
 
111
  "If multiple are used, separate them with commas."}
112
  ]
113
 
 
114
  messages.extend(st.session_state.history)
115
+
116
  messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
117
+ f"Sources:\n{', '.join(source_links.keys())}\n\n"
118
  f"Question:\n{user_query}"})
119
  try:
120
  resp = client.chat.completions.create(
121
+ model=ll_model,
122
  messages=messages,
123
  temperature=0.1,
124
  max_tokens=900
 
129
  reply = "Sorry, I encountered an error generating the answer."
130
 
131
  # Optional: force clean source line if LLM misses it
132
+ if source_links:
133
+ clean_sources = ", ".join(source_links.keys())
134
  if "Source:" not in reply:
135
  reply += f"\n\nSource: {clean_sources}"
136
 
137
+ # Save reply into history
138
  st.session_state.history.append({"role": "assistant", "content": reply})
139
+
140
+ # --- Render in Streamlit ---
141
+ st.markdown(reply)
142
+
143
+ # Add expandable sources
144
+ if source_links:
145
+ st.write("### Sources")
146
+ for name, text in source_links.items():
147
+ with st.expander(name):
148
+ st.write(text)
149
+
150
  return reply
151
 
152
 
153
 
154
+
155
  # Chat UI
156
  with st.form("chat_input", clear_on_submit=True):
157
  user_input = st.text_input("You:", "")