# case-search / src/streamlit_app.py
# (hosting-page header converted to comments: omarkashif — "Update src/streamlit_app.py", commit 7834345, verified)
import os
import json
import streamlit as st
import pandas as pd
import psycopg2
from openai import OpenAI
# Initialize OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# =========================
# Database connection
# =========================
def get_connection():
    """Open and return a new psycopg2 connection to the RDS Postgres instance.

    Host and password come from the RDS_ENDPOINT / YOUR_RDS_PASSWORD
    environment variables; db name, user and port are fixed.
    """
    connection_params = {
        "host": os.getenv("RDS_ENDPOINT"),
        "dbname": "postgres",
        "user": "postgres",
        "password": os.getenv("YOUR_RDS_PASSWORD"),
        "port": 5432,
    }
    return psycopg2.connect(**connection_params)
# =========================
# Search function
# =========================
def search_cases(query, limit=200):
    """Full-text search over the ``cases`` table, ranked by relevance.

    Parameters
    ----------
    query : str
        Raw user query; parsed server-side by ``websearch_to_tsquery``
        (handles quoted phrases, OR, -exclusions).
    limit : int, optional
        Maximum number of rows returned (default 200).

    Returns
    -------
    pandas.DataFrame
        Columns: case_id, citation_name, court, case_title, case_text, rank.

    Fix: the original leaked the cursor and connection if execute/fetch
    raised; cleanup now happens in a context manager + ``finally``.
    """
    sql = """
        SELECT case_id, citation_name, court, case_title, case_text,
               ts_rank(search_vector, websearch_to_tsquery('english', %s)) AS rank
        FROM cases
        WHERE search_vector @@ websearch_to_tsquery('english', %s)
        ORDER BY rank DESC LIMIT %s
    """
    conn = get_connection()
    try:
        # Cursor context manager guarantees the cursor is closed even on error.
        with conn.cursor() as cur:
            cur.execute(sql, (query, query, limit))
            rows = cur.fetchall()
    finally:
        conn.close()
    return pd.DataFrame(
        rows,
        columns=["case_id", "citation_name", "court", "case_title", "case_text", "rank"],
    )
def execute_sql_search(query, limit=200):
    """Execute the raw SQL full-text search for a single query string.

    Uses ``websearch_to_tsquery`` for better handling of "OR", "AND"
    and plain text. Unlike :func:`search_cases`, results are ordered by
    ``case_year`` (newest first) and include the year column.

    Parameters
    ----------
    query : str
        One optimized search string.
    limit : int, optional
        Maximum number of rows returned (default 200).

    Returns
    -------
    pandas.DataFrame
        Columns: case_id, citation_name, court, case_title, case_year,
        case_text, rank.

    Fixes: cursor/connection are now released even if execute/fetch
    raises; a stale commented-out variant of the query was removed.
    """
    sql = """
        SELECT case_id, citation_name, court, case_title, case_year, case_text,
               ts_rank(search_vector, websearch_to_tsquery('english', %s)) AS rank
        FROM cases
        WHERE search_vector @@ websearch_to_tsquery('english', %s)
        ORDER BY case_year DESC
        LIMIT %s
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(sql, (query, query, limit))
            rows = cur.fetchall()
    finally:
        conn.close()
    return pd.DataFrame(
        rows,
        columns=["case_id", "citation_name", "court", "case_title", "case_year", "case_text", "rank"],
    )
def smart_search_cases(user_input, limit=30):
    """Search cases using GPT-optimized query expansion.

    1. Asks GPT-4o-mini to turn the natural-language request into 1-4
       optimized keyword strings (JSON list).
    2. Runs :func:`execute_sql_search` for each term (plus the raw input).
    3. Concatenates results and deduplicates by ``case_id``.

    Parameters
    ----------
    user_input : str
        The user's natural-language search request.
    limit : int, optional
        Per-term row limit passed to the SQL search (default 30).

    Returns
    -------
    pandas.DataFrame
        Deduplicated results; empty DataFrame when nothing matched.

    Fixes: strips markdown code fences the model sometimes emits despite
    instructions; validates the parsed JSON is a list of strings; keeps
    the highest-ranked duplicate instead of an arbitrary first occurrence
    (the original comment acknowledged this shortcut).
    """
    # 1. Ask GPT to optimize the query.
    system_prompt = """
    You are a legal search engine optimizer.
    Convert the user's natural language request into a JSON list of 1 to 4 distinct, optimized keyword search strings for a Postgres full-text search.
    Rules:
    - Remove filler words (e.g., "find me cases about", "caselaw on").
    - Focus on specific citations (e.g., "Section 158"), acts, and legal concepts.
    - Generate variations (e.g., ["Section 158 Ordinance", "Section 158", "Ordinance 158"]).
    - Output strictly a JSON list of strings. No markdown formatting.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_input},
            ],
            temperature=0,
        )
        raw = response.choices[0].message.content.strip()
        # Models occasionally wrap output in ``` fences despite instructions.
        if raw.startswith("```"):
            raw = raw.strip("`").strip()
            if raw.lower().startswith("json"):
                raw = raw[4:]
        search_terms = json.loads(raw)
        if not isinstance(search_terms, list):
            raise ValueError("model output is not a JSON list")
        search_terms = [str(term) for term in search_terms]
    except Exception as e:
        # Fallback if GPT or parsing fails: use the raw input as the only term.
        print(f"GPT Error: {e}")
        search_terms = [user_input]

    # Also include the original raw query just in case.
    if user_input not in search_terms:
        search_terms.append(user_input)

    # 2. Run the SQL search for each term.
    all_dfs = []
    for term in search_terms:
        df = execute_sql_search(term, limit)
        if not df.empty:
            df['search_term_used'] = term  # track which term found each row
            all_dfs.append(df)

    # 3. Combine and deduplicate.
    if not all_dfs:
        return pd.DataFrame()  # empty result
    final_df = pd.concat(all_dfs, ignore_index=True)
    # Keep the best-ranked hit for each case rather than the first seen.
    final_df = final_df.sort_values("rank", ascending=False).drop_duplicates(subset=['case_id'])
    return final_df
# =========================
# RAG Chat Logic
# =========================
def ask_gpt(context, user_query):
    """Answer *user_query* using *context* (a case's full text) via GPT-4o-mini.

    Parameters
    ----------
    context : str
        The case text to ground the answer in; truncated to 50,000 chars.
    user_query : str
        The user's question.

    Returns
    -------
    str
        The assistant's answer.

    Fix: the original had a ``#`` comment INSIDE the f-string, so the
    literal text "# Limiting context to stay within token bounds" was
    sent to the model as part of the prompt. It is now a real comment.
    """
    # Truncate context to stay within token bounds.
    system_prompt = f"""You are a professional legal assistant.
Use the provided legal case text to answer questions accurately.
- Do not hallucinate or make up facts.
- Only state facts found in the text.
- If the information is not in the text, say you don't know.
- Always include references or direct quotes from the text as evidence.
CASE CONTEXT:
{context[:50000]}
"""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_query},
        ],
        temperature=0,
    )
    return response.choices[0].message.content
# =========================
# Streamlit UI & Routing
# =========================
# Navigation State
# Seed navigation state on first run: which page is shown and which case
# (if any) is currently opened in the detail view.
for _key, _default in (("view", "search"), ("selected_case", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# --- Page 1: Search Results ---
if st.session_state.view == "search":
    st.title("⚖️ Legal Cases Search Tool")
    search_query = st.text_input("Enter search term or phrase:")
    if search_query:
        results = smart_search_cases(search_query)
        if results.empty:
            st.warning("No results found.")
        else:
            for idx, row in results.iterrows():
                # Title truncation: append "..." only when the title was
                # actually cut. (Original compared len > 50 against a [:200]
                # slice, so mid-length titles got "..." without truncation.)
                raw_title = row['case_title'] if row['case_title'] else "Unknown Title"
                display_title = (raw_title[:200] + '...') if len(raw_title) > 200 else raw_title
                header = f"{display_title} | {row['citation_name']} | {row['court']}"
                with st.expander(header):
                    # Preview only the first 250 words of the case text.
                    words = row['case_text'].split()
                    st.write(" ".join(words[:250]) + "...")
                    # Button to open the full detail/chat page for this case.
                    if st.button("READ MORE", key=f"btn_{row['case_id']}"):
                        st.session_state.selected_case = row
                        st.session_state.view = "detail"
                        st.rerun()
# --- Page 2: Case Detail & Chat ---
elif st.session_state.view == "detail":
    # The full result row (Series) selected on the search page.
    case = st.session_state.selected_case
    # Initialize chat history and follow-up counter if they don't exist yet.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "chat_counter" not in st.session_state:
        st.session_state.chat_counter = 0
    if st.button("← Back to Search"):
        # Reset chat when going back to search so the next case starts fresh.
        st.session_state.chat_history = []
        st.session_state.chat_counter = 0
        st.session_state.view = "search"
        st.rerun()
    st.title(case['case_title'])
    st.caption(f"{case['citation_name']} | {case['court']}")
    tab1, tab2 = st.tabs(["📄 Full Case Text", "💬 Chat with Case"])
    with tab1:
        st.write(case['case_text'])
    with tab2:
        st.subheader("Legal AI Assistant")
        # Hard cap of 8 follow-up questions per case.
        # NOTE(review): when the cap is hit the history is no longer rendered
        # (it sits inside the else branch) — confirm that is intended.
        if st.session_state.chat_counter >= 8:
            st.warning("⚠️ You have reached the maximum of 8 follow-up questions for this case.")
        else:
            # Replay prior turns so the conversation scrolls naturally.
            for message in st.session_state.chat_history:
                with st.chat_message(message["role"]):
                    st.write(message["content"])
            # Chat input; walrus binds the submitted text when non-empty.
            if prompt := st.chat_input("Ask a follow-up question..."):
                # Append user message to history, then echo it immediately.
                st.session_state.chat_history.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.write(prompt)
                with st.chat_message("assistant"):
                    with st.spinner("Analyzing case history..."):
                        # Sliding window: last 4 user/assistant pairs (8 msgs).
                        recent_history = st.session_state.chat_history[-8:]
                        # System prompt carries the case text (truncated to
                        # 15,000 chars) as grounding context.
                        messages = [
                            {"role": "system", "content": f"You are a professional legal assistant. Use this case context to answer: {case['case_text'][:15000]}. Be factual and cite the text."}
                        ]
                        messages.extend(recent_history)
                        # Call OpenAI with the context + recent turns.
                        response = client.chat.completions.create(
                            model="gpt-4o-mini",
                            messages=messages,
                            temperature=0
                        )
                        full_response = response.choices[0].message.content
                        st.write(full_response)
                # Persist the answer and count this follow-up against the cap.
                st.session_state.chat_history.append({"role": "assistant", "content": full_response})
                st.session_state.chat_counter += 1
                # Rerun so the refreshed history/counter render consistently.
                st.rerun()
        # Always show how many of the 8 follow-ups have been used.
        st.info(f"Questions asked: {st.session_state.chat_counter}/8")