diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b708933f5bdc3a5132e24431f1fd2d721c7c1baf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +vectordb/**/*.sqlite3 filter=lfs diff=lfs merge=lfs -text diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7482727440b284f3c1b318059c9dd9dcb1249cdb --- /dev/null +++ b/app.py @@ -0,0 +1,53 @@ +# app.py +import streamlit as st +from db import init_connection +from ui_pages import login_page, create_account_page, main_page +from chat import load_user_sessions # import it here + +from dotenv import load_dotenv +load_dotenv() + +def app(): + # Initialize session state + default_keys = { + 'logged_in': False, + 'username': "", + 'show_create_account': False, + 'messages': [], + 'current_chat_session': None, + 'chat_sessions': [], + 'last_session_id': None + } + for key, val in default_keys.items(): + if key not in st.session_state: + st.session_state[key] = val + + # Initialize database connection (returns dict of collections) + db_conn = init_connection() + if db_conn is None: + return + + # Route to appropriate page + if st.session_state.logged_in: + # load_user_sessions returns (sessions, current_session, messages_stub) + sessions, current, messages = load_user_sessions( + st.session_state.username, + db_conn["sessions"], + st.session_state.get("last_session_id") + ) + st.session_state.chat_sessions = sessions + st.session_state.current_chat_session = current + # load chat messages if we have a current session + if current: + from chat import load_chat_history + st.session_state.messages = load_chat_history(str(current["_id"]), db_conn["messages"]) + else: + st.session_state.messages = [] + main_page() + elif st.session_state.show_create_account: + create_account_page() + else: + login_page() + +if __name__ == "__main__": + app() diff --git a/auth.py b/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..542a16c71f12137aad6c1cedf2b9f1b169fea727 --- /dev/null +++ b/auth.py @@ -0,0 +1,26 @@ +import bcrypt +import logging + +logger = logging.getLogger(__name__) + +def check_login(username: str, password: str, users_collection) -> bool: + """Checks if the provided username and password are valid against MongoDB.""" + logger.info(f"Login attempt for user: {username}") + user = users_collection.find_one({"username": username}) + if user: + stored_hash = user["password"] + + # Ensure we always have bytes for bcrypt.checkpw + stored_hash_bytes = stored_hash.encode('utf-8') if isinstance(stored_hash, str) else stored_hash + + try: + if bcrypt.checkpw(password.encode('utf-8'), stored_hash_bytes): + logger.info(f"User '{username}' logged in successfully.") + return True + else: + logger.warning(f"Invalid password attempt for user: {username}") + except Exception as e: + logger.error(f"Error checking password for user {username}: {e}") + else: + logger.warning(f"Login failed, user not found: {username}") + return False \ No newline at end of file diff --git a/chat.py b/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..707f84dfe5b0fd6a04106c381753066147d3442c --- /dev/null +++ b/chat.py @@ -0,0 +1,61 @@ +from datetime import datetime +from bson.objectid import ObjectId + +def load_user_sessions(username, sessions_collection, last_session_id=None): + """ + Load sessions for a user. Restore last active session if possible. + Returns: (sessions_list, current_session_or_None, messages_list) + """ + if sessions_collection is None: + return [], None, [] + + sessions = list(sessions_collection.find({"username": username}).sort("timestamp", -1)) + current_session = None + messages = [] + + if sessions: + if last_session_id: + try: + last = sessions_collection.find_one({"_id": ObjectId(last_session_id)}) + except Exception: + last = None + if last: + current_session = last + # messages will be loaded by caller or by calling load_chat_history + if not current_session: + current_session = sessions[0] + + # Note: do NOT load messages here unless you also have messages_collection. + # Return sessions and current_session; caller can call load_chat_history with messages_collection. + return sessions, current_session, messages + + +def load_chat_history(session_id, messages_collection): + """ + Loads messages for a given chat session from messages_collection. + """ + if messages_collection is None: + return [] + try: + msgs = list(messages_collection.find({"session_id": session_id}).sort("timestamp", 1)) + return [{"role": m.get("role", "assistant"), "content": m.get("content", "")} for m in msgs] + except Exception: + return [] + + +def save_message(session_id, role, content, messages_collection): + """ + Save a message to the chat history in messages_collection. + """ + if messages_collection is None: + return None + try: + doc = { + "session_id": session_id, + "role": role, + "content": content, + "timestamp": datetime.utcnow(), + } + return messages_collection.insert_one(doc).inserted_id + except Exception: + return None diff --git a/db.py b/db.py new file mode 100644 index 0000000000000000000000000000000000000000..c8aba6720b5b9d472637943f14bc24369e12e651 --- /dev/null +++ b/db.py @@ -0,0 +1,42 @@ +from pymongo import MongoClient +import os + +from dotenv import load_dotenv +load_dotenv() + + +client = None +db = None +users_collection = None +sessions_collection = None +messages_collection = None + + +def init_connection(): + """Initialize MongoDB connection and collections.""" + global client, db, users_collection, sessions_collection, messages_collection + + mongo_uri = os.getenv("MONGO_URI") + if not mongo_uri: + raise ValueError("❌ MONGO_URI not found in environment variables.") + + client = MongoClient(mongo_uri) + db = client.get_database("law_cases_db") + + users_collection = db.get_collection("users") + sessions_collection = db.get_collection("chat_sessions") + messages_collection = db.get_collection("chat_messages") + + # ✅ create unique index (username + normalized chat name) + sessions_collection.create_index( + [("username", 1), ("session_name_normalized", 1)], + unique=True + ) + + return { + "client": client, + "db": db, + "users": users_collection, + "sessions": sessions_collection, + "messages": messages_collection + } diff --git a/rag.py b/rag.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5e071c14f4f29d9230573843071a937c2c4d12 --- /dev/null +++ b/rag.py @@ -0,0 +1,175 @@ +import os +import logging +import streamlit as st +from dotenv import load_dotenv +import pickle + +from llama_index.llms.groq import Groq +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.vector_stores.chroma import ChromaVectorStore +from llama_index.core import VectorStoreIndex +from llama_index.core.retrievers import VectorIndexRetriever, RecursiveRetriever +from llama_index.retrievers.bm25 import BM25Retriever +from llama_index.core.tools import QueryEngineTool +from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core import get_response_synthesizer +from llama_index.core.agent import ReActAgent +from chromadb import PersistentClient + +logger = logging.getLogger(__name__) + +@st.cache_resource +def setup_rag_system(debug=False): + load_dotenv() + + groq_api_key = os.getenv("GROQ_API_KEY") or st.secrets.get("groq", {}).get("api_key") + if not groq_api_key: + st.error("GROQ API key not found. Please check your environment variables or secrets.") + st.stop() + + # LLM + llm = Groq( + model="llama-3.1-8b-instant", + api_key=groq_api_key, + max_input_tokens=1200, + max_output_tokens=1200 + ) + + # Embeddings + embedding_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2") + + # Persisted vector DBs + persist_dirs = [ + "vectordb/case_2021", + "vectordb/case_2022", + "vectordb/case_2023", + "vectordb/case_2024", + "vectordb/case_2025" + ] + for persist_dir in persist_dirs: + if not os.path.exists(persist_dir): + st.error(f"Vector database directory {persist_dir} not found.") + st.stop() + + # Build hybrid retrievers + hybrid_retrievers = [] + for persist_dir in persist_dirs: + # Load pickled nodes + nodes_path = os.path.join(persist_dir, "nodes.pkl") + if not os.path.exists(nodes_path): + st.error(f"Pickle file {nodes_path} not found.") + st.stop() + + with open(nodes_path, "rb") as f: + nodes = pickle.load(f) + + # Vector store + client = PersistentClient(path=persist_dir) + collection = client.get_collection("case_collection") + vector_store = ChromaVectorStore(chroma_collection=collection) + index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embedding_model) + + # Retrievers + vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=2, retriever_mode="mmr") + bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2) + + hybrid_retriever = RecursiveRetriever( + "vector", + retriever_dict={"vector": vector_retriever, "bm25": bm25_retriever}, + verbose=True + ) + hybrid_retrievers.append(hybrid_retriever) + + # Case metadata + documents_info = [ + { + "name": "Quezada2021_Retriever", + "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Quezada (21-0089-MC), issued on December 20, 2021." + }, + { + "name": "Thompson2022_Retriever", + "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Thompson (22-0098-AF), issued on November 21, 2022." + }, + { + "name": "Brown2023_Retriever", + "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Brown (22-0249-CG), issued on October 23, 2023." + }, + { + "name": "Smith2024_Retriever", + "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Smith (23-0207-AF), issued on November 26, 2024." + }, + { + "name": "Lopez2025_Retriever", + "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Lopez (24-0226-CG), issued on September 2, 2025." + }, + ] + + + # Create retriever → tool + def create_retriever_tool(retriever, llm, name, description): + response_synthesizer = get_response_synthesizer( + llm=llm, response_mode="compact", use_async=False + ) + query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer) + return QueryEngineTool.from_defaults(query_engine=query_engine, name=name, description=description) + + retriever_tools = [ + create_retriever_tool(hybrid_retrievers[i], llm, info["name"], info["description"]) + for i, info in enumerate(documents_info) + ] + + # System prompt + system_prompt = """ + You are a highly specialized legal research assistant. + You may ONLY answer questions that are legal in nature. + This includes both: + - Specific case law queries from the provided case documents (2021–2025). + - General legal concepts, doctrines, or terminology. + + Before answering, always perform this intermediate reasoning step: + + 1. Classify the user query: + - If the query relates to law, legal concepts, legal systems, court rulings, rights, duties, contracts, procedures, or legal doctrines → classify as: LEGAL_QUERY. + - If the query is casual conversation, mathematics, trivia, technical programming, or anything outside the legal domain → classify as: NON_LEGAL_QUERY. + + 2. Response rules: + - If LEGAL_QUERY: + a) If the query references specific cases between 2021–2025, use the provided case documents to retrieve and answer. Cite the case name and year. + b) If the query is a general legal question, answer concisely and professionally, using legal reasoning. Do NOT speculate beyond standard legal knowledge. + - If NON_LEGAL_QUERY: + Respond ONLY with: "I can only answer questions about legal cases (2021–2025) or general law queries." + + 3. Examples: + - LEGAL_QUERY (answer these): + • "What is the difference between civil and criminal law?" + • "Explain the principle of judicial review." + • "Summarize the ruling in United States v. Lopez (2025)." + • "What is mens rea in criminal law?" + - NON_LEGAL_QUERY (reject these): + • "What is 2+2?" + • "Who won the FIFA World Cup in 2022?" + • "Write me a Python script." + • "Tell me a joke." + + 4. Style & tone: + - Be concise, professional, and clear. + - Use citations ONLY when referring to case documents (case name + year). + - Never provide speculative or non-legal answers. + """ + + + # ReActAgent + agent = ReActAgent( + tools=retriever_tools, + llm=llm, + verbose=True, + max_iterations=20, + system_prompt=system_prompt + ) + + logger.info("RAG system setup complete.") + + if debug: + return agent, llm, hybrid_retrievers + + return agent, llm diff --git a/requirements.txt b/requirements.txt index 28d994e22f8dd432b51df193562052e315ad95f7..1119022e72ee95c1df77ed95dc9f63454194720c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,17 @@ -altair -pandas -streamlit \ No newline at end of file +requests +PyMuPDF +chromadb +sentence-transformers + +llama-index +llama-index-vector-stores-chroma +llama-index-embeddings-huggingface +llama-index-llms-groq +llama-index +llama-index-retrievers-bm25 +llama-index-storage-chat-store-mongo + +pymongo + +streamlit +nest_asyncio diff --git a/ui_pages.py b/ui_pages.py new file mode 100644 index 0000000000000000000000000000000000000000..3f12a6d8138240862fb52dc13041d5f6b858fa11 --- /dev/null +++ b/ui_pages.py @@ -0,0 +1,577 @@ +import time +import streamlit as st +import re +import asyncio +import nest_asyncio + +from auth import check_login +from chat import load_user_sessions, load_chat_history, save_message +from rag import setup_rag_system +from db import init_connection + +nest_asyncio.apply() + + +def login_page(): + # Custom CSS for login page + st.markdown(""" + + """, unsafe_allow_html=True) + + st.markdown(""" +
Ask me about legal cases (2021–2025). I'll retrieve documents and give citations.
+Set up your account to start chatting with the legal RAG system.
+