# Legal Cases Search Tool — Streamlit RAG app (extracted from a Hugging Face Space).
import os
import json

import streamlit as st
import pandas as pd
import psycopg2
from openai import OpenAI
# Single OpenAI client shared by every handler in this app.
# The key is read from the OPENAI_API_KEY environment variable
# (None if unset — API calls will then fail at request time).
_openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=_openai_api_key)
# =========================
# Database connection
# =========================
def get_connection():
    """Open a fresh connection to the Postgres instance on RDS.

    Host and password come from the RDS_ENDPOINT / YOUR_RDS_PASSWORD
    environment variables; database name, user and port are fixed.
    The caller is responsible for closing the returned connection.
    """
    conn_params = {
        "host": os.getenv("RDS_ENDPOINT"),
        "dbname": "postgres",
        "user": "postgres",
        "password": os.getenv("YOUR_RDS_PASSWORD"),
        "port": 5432,
    }
    return psycopg2.connect(**conn_params)
# =========================
# Search function
# =========================
def search_cases(query, limit=200):
    """Run a Postgres full-text search for *query* over the ``cases`` table.

    Args:
        query: Raw search string; parsed with ``websearch_to_tsquery``.
        limit: Maximum number of rows to return (default 200).

    Returns:
        pandas.DataFrame with columns case_id, citation_name, court,
        case_title, case_text, rank — ordered by rank (best first).
    """
    sql = """
    SELECT case_id, citation_name, court, case_title, case_text,
           ts_rank(search_vector, websearch_to_tsquery('english', %s)) AS rank
    FROM cases
    WHERE search_vector @@ websearch_to_tsquery('english', %s)
    ORDER BY rank DESC LIMIT %s
    """
    conn = get_connection()
    try:
        # Cursor context manager closes the cursor even if execute/fetch
        # raises; the finally block does the same for the connection
        # (the original leaked both on error).
        with conn.cursor() as cur:
            cur.execute(sql, (query, query, limit))
            rows = cur.fetchall()
    finally:
        conn.close()
    return pd.DataFrame(
        rows,
        columns=["case_id", "citation_name", "court", "case_title", "case_text", "rank"],
    )
def execute_sql_search(query, limit=200):
    """Execute the raw SQL full-text search for a single query string.

    Uses ``websearch_to_tsquery`` for better handling of "OR", "AND"
    and plain text.  Rows are ordered by case_year (newest first);
    rank is still returned so callers can re-rank client-side.

    Args:
        query: Search string.
        limit: Maximum number of rows to return (default 200).

    Returns:
        pandas.DataFrame with columns case_id, citation_name, court,
        case_title, case_year, case_text, rank.
    """
    sql = """
    SELECT case_id, citation_name, court, case_title, case_year, case_text,
           ts_rank(search_vector, websearch_to_tsquery('english', %s)) AS rank
    FROM cases
    WHERE search_vector @@ websearch_to_tsquery('english', %s)
    ORDER BY case_year DESC
    LIMIT %s
    """
    conn = get_connection()
    try:
        # Close cursor and connection even when the query fails
        # (the original leaked both on error).
        with conn.cursor() as cur:
            cur.execute(sql, (query, query, limit))
            rows = cur.fetchall()
    finally:
        conn.close()
    return pd.DataFrame(
        rows,
        columns=["case_id", "citation_name", "court", "case_title",
                 "case_year", "case_text", "rank"],
    )
def smart_search_cases(user_input, limit=30):
    """Search the cases table using GPT-optimized query variants.

    1. Uses GPT-4o-mini to generate optimized search terms.
    2. Runs the SQL search for EACH term (plus the raw input).
    3. Combines results and deduplicates by case_id, keeping the
       best-ranked hit per case.

    Args:
        user_input: Natural-language search request from the user.
        limit: Per-term row limit passed to execute_sql_search.

    Returns:
        pandas.DataFrame of deduplicated results (empty on no match).
    """
    # 1. Ask GPT to optimize the query.
    system_prompt = """
You are a legal search engine optimizer.
Convert the user's natural language request into a JSON list of 1 to 4 distinct, optimized keyword search strings for a Postgres full-text search.
Rules:
- Remove filler words (e.g., "find me cases about", "caselaw on").
- Focus on specific citations (e.g., "Section 158"), acts, and legal concepts.
- Generate variations (e.g., ["Section 158 Ordinance", "Section 158", "Ordinance 158"]).
- Output strictly a JSON list of strings. No markdown formatting.
"""
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_input},
            ],
            temperature=0,
        )
        parsed = json.loads(response.choices[0].message.content)
        # Guard against the model returning valid JSON that is not a list
        # of strings (a dict or bare string would break the loop below or
        # feed garbage into the SQL search).
        if isinstance(parsed, list) and all(isinstance(t, str) for t in parsed):
            search_terms = parsed
        else:
            search_terms = [user_input]
    except Exception as e:
        # Fallback if GPT fails: just use the raw user input.
        print(f"GPT Error: {e}")
        search_terms = [user_input]

    # Also include the original raw query just in case.
    if user_input not in search_terms:
        search_terms.append(user_input)

    # 2. Run the SQL search for each term.
    all_dfs = []
    for term in search_terms:
        df = execute_sql_search(term, limit)
        if not df.empty:
            df['search_term_used'] = term  # track which term found it
            all_dfs.append(df)

    # 3. Combine and deduplicate.
    if not all_dfs:
        return pd.DataFrame()  # empty result
    final_df = pd.concat(all_dfs, ignore_index=True)
    # Keep the highest-ranked occurrence of each case_id (the original
    # kept an arbitrary first occurrence, contrary to its own comment).
    final_df = (
        final_df.sort_values('rank', ascending=False)
        .drop_duplicates(subset=['case_id'])
    )
    return final_df
# =========================
# RAG Chat Logic
# =========================
def ask_gpt(context, user_query):
    """Answer *user_query* grounded in *context* (case text) via GPT-4o-mini.

    Args:
        context: Full case text used as grounding material.
        user_query: The user's question.

    Returns:
        The assistant's reply as a string.
    """
    # Truncate OUTSIDE the f-string. In the original, the trailing
    # "# Limiting context to stay within token bounds" comment sat inside
    # the triple-quoted f-string, so the comment text was sent to the
    # model as part of the prompt.
    truncated_context = context[:50000]  # stay within token bounds
    system_prompt = f"""You are a professional legal assistant.
Use the provided legal case text to answer questions accurately.
- Do not hallucinate or make up facts.
- Only state facts found in the text.
- If the information is not in the text, say you don't know.
- Always include references or direct quotes from the text as evidence.
CASE CONTEXT:
{truncated_context}
"""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_query},
        ],
        temperature=0,
    )
    return response.choices[0].message.content
# =========================
# Streamlit UI & Routing
# =========================
# Two-page app driven entirely by st.session_state:
#   view == "search" -> full-text search page
#   view == "detail" -> one case's full text + a bounded RAG chat
# Streamlit reruns this whole script on every interaction, so navigation
# works by mutating session_state and calling st.rerun().

# Navigation State (initialized once per browser session).
if "view" not in st.session_state:
    st.session_state.view = "search"
if "selected_case" not in st.session_state:
    st.session_state.selected_case = None

# --- Page 1: Search Results ---
if st.session_state.view == "search":
    st.title("⚖️ Legal Cases Search Tool")
    search_query = st.text_input("Enter search term or phrase:")
    if search_query:
        # GPT-assisted multi-term search; the plain single-term search
        # is kept here as an alternative.
        # results = search_cases(search_query)
        results = smart_search_cases(search_query)
        if results.empty:
            st.warning("No results found.")
        else:
            for idx, row in results.iterrows():
                # Title Truncation Logic.
                # NOTE(review): the slice keeps 200 chars but the guard
                # triggers at 50 — the two thresholds disagree, so titles
                # of 51-200 chars get "..." appended without being cut.
                # Confirm the intended cut-off and make them match.
                raw_title = row['case_title'] if row['case_title'] else "Unknown Title"
                display_title = (raw_title[:200] + '...') if len(raw_title) > 50 else raw_title
                header = f"{display_title} | {row['citation_name']} | {row['court']}"
                with st.expander(header):
                    # Preview: first 250 whitespace-separated words.
                    words = row['case_text'].split()
                    st.write(" ".join(words[:250]) + "...")
                    # "READ MORE" stores the whole result row and switches
                    # to the detail page.
                    if st.button("READ MORE", key=f"btn_{row['case_id']}"):
                        st.session_state.selected_case = row
                        st.session_state.view = "detail"
                        st.rerun()

# --- Page 2: Case Detail & Chat ---
elif st.session_state.view == "detail":
    case = st.session_state.selected_case
    # Initialize chat history and counter if they don't exist.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "chat_counter" not in st.session_state:
        st.session_state.chat_counter = 0
    if st.button("← Back to Search"):
        # Reset chat when going back to search so the next case
        # starts with a clean thread.
        st.session_state.chat_history = []
        st.session_state.chat_counter = 0
        st.session_state.view = "search"
        st.rerun()
    st.title(case['case_title'])
    st.caption(f"{case['citation_name']} | {case['court']}")
    tab1, tab2 = st.tabs(["📄 Full Case Text", "💬 Chat with Case"])
    with tab1:
        st.write(case['case_text'])
    with tab2:
        st.subheader("Legal AI Assistant")
        # Hard cap of 8 follow-up questions per case.
        if st.session_state.chat_counter >= 8:
            st.warning("⚠️ You have reached the maximum of 8 follow-up questions for this case.")
        else:
            # Replay stored conversation so it survives Streamlit reruns.
            for message in st.session_state.chat_history:
                with st.chat_message(message["role"]):
                    st.write(message["content"])
            # Chat Input (walrus: only truthy when the user submitted).
            if prompt := st.chat_input("Ask a follow-up question..."):
                # Append user message to history.
                st.session_state.chat_history.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.write(prompt)
                with st.chat_message("assistant"):
                    with st.spinner("Analyzing case history..."):
                        # Sliding window: last 4 user/assistant pairs
                        # (8 messages) are sent to the model.
                        recent_history = st.session_state.chat_history[-8:]
                        # System prompt carries the truncated case text
                        # as grounding context.
                        messages = [
                            {"role": "system", "content": f"You are a professional legal assistant. Use this case context to answer: {case['case_text'][:15000]}. Be factual and cite the text."}
                        ]
                        # Add recent history to the API call.
                        messages.extend(recent_history)
                        # Call OpenAI.
                        response = client.chat.completions.create(
                            model="gpt-4o-mini",
                            messages=messages,
                            temperature=0
                        )
                        full_response = response.choices[0].message.content
                        st.write(full_response)
                # Persist the answer and bump the counter, then rerun so
                # the limit check above sees the updated count.
                st.session_state.chat_history.append({"role": "assistant", "content": full_response})
                st.session_state.chat_counter += 1
                st.rerun()
        # Display how many follow-ups have been used.
        st.info(f"Questions asked: {st.session_state.chat_counter}/8")