import os
import json

import streamlit as st
import pandas as pd
import psycopg2
from openai import OpenAI

# Initialize OpenAI client (API key read from the environment).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


# =========================
# Database connection
# =========================
def get_connection():
    """Open a new psycopg2 connection to the RDS Postgres instance.

    Host and password come from the RDS_ENDPOINT / YOUR_RDS_PASSWORD
    environment variables; dbname/user/port are fixed.
    """
    return psycopg2.connect(
        host=os.getenv("RDS_ENDPOINT"),
        dbname="postgres",
        user="postgres",
        password=os.getenv("YOUR_RDS_PASSWORD"),
        port=5432,
    )


# =========================
# Search function
# =========================
def search_cases(query, limit=200):
    """Full-text search over `cases`, ranked by ts_rank (best match first).

    Returns a DataFrame with columns:
    case_id, citation_name, court, case_title, case_text, rank.
    """
    sql = """
        SELECT case_id, citation_name, court, case_title, case_text,
               ts_rank(search_vector, websearch_to_tsquery('english', %s)) AS rank
        FROM cases
        WHERE search_vector @@ websearch_to_tsquery('english', %s)
        ORDER BY rank DESC
        LIMIT %s
    """
    conn = get_connection()
    try:
        # Cursor context manager + finally ensure no leak if execute/fetch raises.
        with conn.cursor() as cur:
            cur.execute(sql, (query, query, limit))
            rows = cur.fetchall()
    finally:
        conn.close()
    return pd.DataFrame(
        rows,
        columns=["case_id", "citation_name", "court", "case_title", "case_text", "rank"],
    )


def execute_sql_search(query, limit=200):
    """Executes the raw SQL search for a single query string.

    Uses websearch_to_tsquery for better handling of "OR", "AND" and
    plain text.  NOTE: results are ordered by case_year DESC (newest
    first), not by rank — rank is still returned for downstream use.
    """
    sql = """
        SELECT case_id, citation_name, court, case_title, case_year, case_text,
               ts_rank(search_vector, websearch_to_tsquery('english', %s)) AS rank
        FROM cases
        WHERE search_vector @@ websearch_to_tsquery('english', %s)
        ORDER BY case_year DESC
        LIMIT %s
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(sql, (query, query, limit))
            rows = cur.fetchall()
    finally:
        conn.close()
    return pd.DataFrame(
        rows,
        columns=["case_id", "citation_name", "court", "case_title",
                 "case_year", "case_text", "rank"],
    )


def smart_search_cases(user_input, limit=30):
    """
    1. Uses GPT-4o-mini to generate optimized search terms.
    2. Runs SQL search for EACH term.
    3. Combines results and removes duplicates (keeping the highest-ranked hit).
    """
    # 1. Ask GPT to optimize the query
    system_prompt = """
    You are a legal search engine optimizer.
    Convert the user's natural language request into a JSON list of 1 to 4
    distinct, optimized keyword search strings for a Postgres full-text search.

    Rules:
    - Remove filler words (e.g., "find me cases about", "caselaw on").
    - Focus on specific citations (e.g., "Section 158"), acts, and legal concepts.
    - Generate variations (e.g., ["Section 158 Ordinance", "Section 158", "Ordinance 158"]).
    - Output strictly a JSON list of strings. No markdown formatting.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_input},
            ],
            temperature=0,
        )
        parsed = json.loads(response.choices[0].message.content)
        # Validate shape: the model must return a JSON list of strings.
        # Anything else (dict, bare string, mixed list) falls back below.
        if not isinstance(parsed, list) or not all(isinstance(t, str) for t in parsed):
            raise ValueError(f"Expected a JSON list of strings, got: {parsed!r}")
        search_terms = parsed
    except Exception as e:
        # Fallback if GPT fails or returns malformed output:
        # just use the raw user input as a single-item list.
        print(f"GPT Error: {e}")
        search_terms = [user_input]

    # Also include the original raw query just in case.
    if user_input not in search_terms:
        search_terms.append(user_input)

    # 2. Run Search for each term
    all_dfs = []
    for term in search_terms:
        df = execute_sql_search(term, limit)
        if not df.empty:
            df["search_term_used"] = term  # Optional: track which term found it
            all_dfs.append(df)

    # 3. Combine and Deduplicate
    if not all_dfs:
        return pd.DataFrame()  # Empty result

    final_df = pd.concat(all_dfs, ignore_index=True)
    # Sort by rank first so drop_duplicates keeps the highest-ranked
    # occurrence of each case_id rather than an arbitrary first hit.
    final_df = (
        final_df.sort_values("rank", ascending=False)
        .drop_duplicates(subset=["case_id"])
        .reset_index(drop=True)
    )
    return final_df


# =========================
# RAG Chat Logic
# =========================
def ask_gpt(context, user_query):
    """Answer `user_query` grounded strictly in `context` (case text).

    Context is truncated to 50k characters to stay within token bounds.
    Returns the assistant's reply as a string.
    """
    # NOTE: keep implementation comments OUT of the f-string — anything
    # inside it is sent to the model as prompt text.
    system_prompt = f"""You are a professional legal assistant. Use the provided legal case text to answer questions accurately.
    - Do not hallucinate or make up facts.
    - Only state facts found in the text.
    - If the information is not in the text, say you don't know.
    - Always include references or direct quotes from the text as evidence.

    CASE CONTEXT:
    {context[:50000]}
    """
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_query},
        ],
        temperature=0,
    )
    return response.choices[0].message.content


# =========================
# Streamlit UI & Routing
# =========================

# Navigation State
if "view" not in st.session_state:
    st.session_state.view = "search"
if "selected_case" not in st.session_state:
    st.session_state.selected_case = None

# --- Page 1: Search Results ---
if st.session_state.view == "search":
    st.title("⚖️ Legal Cases Search Tool")

    search_query = st.text_input("Enter search term or phrase:")

    if search_query:
        results = smart_search_cases(search_query)

        if results.empty:
            st.warning("No results found.")
        else:
            for idx, row in results.iterrows():
                # Title truncation: shorten only titles that actually exceed
                # the 200-char display limit (the old code appended "..." to
                # any title over 50 chars without shortening it).
                raw_title = row["case_title"] or "Unknown Title"
                display_title = (raw_title[:200] + "...") if len(raw_title) > 200 else raw_title

                header = f"{display_title} | {row['citation_name']} | {row['court']}"
                with st.expander(header):
                    # Display first 250 words as a preview
                    words = row["case_text"].split()
                    st.write(" ".join(words[:250]) + "...")

                    # Button to "Read More"
                    if st.button("READ MORE", key=f"btn_{row['case_id']}"):
                        st.session_state.selected_case = row
                        st.session_state.view = "detail"
                        st.rerun()

# --- Page 2: Case Detail & Chat ---
elif st.session_state.view == "detail":
    case = st.session_state.selected_case

    # Initialize chat history and counter if they don't exist
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "chat_counter" not in st.session_state:
        st.session_state.chat_counter = 0

    if st.button("← Back to Search"):
        # Reset chat when going back to search
        st.session_state.chat_history = []
        st.session_state.chat_counter = 0
        st.session_state.view = "search"
        st.rerun()

    st.title(case["case_title"])
    st.caption(f"{case['citation_name']} | {case['court']}")

    tab1, tab2 = st.tabs(["📄 Full Case Text", "💬 Chat with Case"])

    with tab1:
        st.write(case["case_text"])

    with tab2:
        st.subheader("Legal AI Assistant")

        # Check if user reached follow-up limit
        if st.session_state.chat_counter >= 8:
            st.warning("⚠️ You have reached the maximum of 8 follow-up questions for this case.")
        else:
            # Display chat history (scrolling)
            for message in st.session_state.chat_history:
                with st.chat_message(message["role"]):
                    st.write(message["content"])

            # Chat Input
            if prompt := st.chat_input("Ask a follow-up question..."):
                # Append user message to history
                st.session_state.chat_history.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.write(prompt)

                with st.chat_message("assistant"):
                    with st.spinner("Analyzing case history..."):
                        # Sliding window: last 4 exchange pairs = 8 messages
                        recent_history = st.session_state.chat_history[-8:]

                        # System prompt with (truncated) case context
                        messages = [
                            {
                                "role": "system",
                                "content": (
                                    "You are a professional legal assistant. "
                                    f"Use this case context to answer: {case['case_text'][:15000]}. "
                                    "Be factual and cite the text."
                                ),
                            }
                        ]
                        messages.extend(recent_history)

                        # Call OpenAI
                        response = client.chat.completions.create(
                            model="gpt-4o-mini",
                            messages=messages,
                            temperature=0,
                        )
                        full_response = response.choices[0].message.content
                        st.write(full_response)

                # Update history and counter
                st.session_state.chat_history.append(
                    {"role": "assistant", "content": full_response}
                )
                st.session_state.chat_counter += 1
                # Rerun so the counter / limit check reflects immediately
                st.rerun()

        # Display remaining questions
        st.info(f"Questions asked: {st.session_state.chat_counter}/8")