File size: 7,193 Bytes
883d885
3ced916
883d885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ced916
883d885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ced916
883d885
 
 
ddfc14c
 
 
 
 
883d885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ced916
883d885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ced916
883d885
 
 
 
3ced916
 
883d885
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import logging
import os
import pickle

import streamlit as st
from chromadb import PersistentClient
from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.agent import ReActAgent
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import (
    QueryFusionRetriever,
    RecursiveRetriever,
    VectorIndexRetriever,
)
from llama_index.core.tools import QueryEngineTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.vector_stores.chroma import ChromaVectorStore

logger = logging.getLogger(__name__)

@st.cache_resource
def setup_rag_system(debug=False):
    """Build and cache the legal-RAG stack: LLM, hybrid retrievers, tools, agent.

    For each persisted case database (2021-2025) this loads the pickled nodes
    and the Chroma collection, builds a vector retriever and a BM25 retriever,
    and fuses them with reciprocal-rank fusion so BOTH actually contribute
    results. Each fused retriever is wrapped as a query-engine tool, and all
    tools are handed to a ReAct agent restricted (via system prompt) to legal
    queries.

    Args:
        debug: When True, also return the per-case hybrid retrievers so they
            can be inspected/evaluated directly.

    Returns:
        (agent, llm) by default, or (agent, llm, hybrid_retrievers) when
        debug=True.

    Side effects:
        Calls st.error() + st.stop() (which raises) when the GROQ API key or
        any persisted vector-DB artifact is missing.
    """
    load_dotenv()

    # Prefer the environment; fall back to Streamlit secrets. st.secrets
    # raises StreamlitSecretNotFoundError (a FileNotFoundError subclass) when
    # no secrets.toml exists, so guard it — otherwise a missing env var
    # crashes with a traceback instead of reaching the friendly error below.
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        try:
            groq_api_key = st.secrets.get("groq", {}).get("api_key")
        except FileNotFoundError:
            groq_api_key = None
    if not groq_api_key:
        st.error("GROQ API key not found. Please check your environment variables or secrets.")
        st.stop()

    # LLM
    # NOTE(review): max_input_tokens / max_output_tokens are passed through to
    # the Groq LLM wrapper — confirm these kwarg names against the installed
    # llama-index-llms-groq version (some versions expect max_tokens).
    llm = Groq(
        model="llama-3.1-8b-instant",
        api_key=groq_api_key,
        max_input_tokens=1200,
        max_output_tokens=1200
    )

    # Embeddings (must match the model used when the vector DBs were built)
    embedding_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Persisted vector DBs — fail fast with a clear message if any is missing.
    persist_dirs = [
        "./vectordb/case_2021",
        "./vectordb/case_2022",
        "./vectordb/case_2023",
        "./vectordb/case_2024",
        "./vectordb/case_2025"
    ]
    for persist_dir in persist_dirs:
        if not os.path.exists(persist_dir):
            st.error(f"Vector database directory {persist_dir} not found.")
            st.stop()

    # Build one hybrid (vector + BM25) retriever per case database.
    hybrid_retrievers = []
    for persist_dir in persist_dirs:
        # Load pickled nodes (needed by BM25, which scores raw node text).
        nodes_path = os.path.join(persist_dir, "nodes.pkl")
        if not os.path.exists(nodes_path):
            st.error(f"Pickle file {nodes_path} not found.")
            st.stop()

        # SECURITY: pickle.load executes arbitrary code from the file. These
        # are app-owned artifacts, but never point this at untrusted input.
        with open(nodes_path, "rb") as f:
            nodes = pickle.load(f)

        # Vector store backed by the persisted Chroma collection.
        client = PersistentClient(path=persist_dir)
        collection = client.get_collection("case_collection")
        vector_store = ChromaVectorStore(chroma_collection=collection)
        index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embedding_model)

        # Retrievers
        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=2, retriever_mode="mmr")
        bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)

        # FIX: the previous RecursiveRetriever("vector", retriever_dict=...)
        # only ever invoked the root "vector" retriever — the "bm25" entry was
        # unreachable (RecursiveRetriever only follows IndexNode references,
        # which these nodes do not contain). QueryFusionRetriever actually
        # runs both retrievers and merges results via reciprocal-rank fusion.
        hybrid_retriever = QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            similarity_top_k=2,
            num_queries=1,  # no LLM-generated query variants; pure fusion
            mode="reciprocal_rerank",
            use_async=False,
            verbose=True,
        )
        hybrid_retrievers.append(hybrid_retriever)

    # Per-case tool metadata, ordered to match persist_dirs.
    documents_info = [
        {
            "name": "Quezada2021_Retriever",
            "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Quezada (21-0089-MC), issued on December 20, 2021."
        },
        {
            "name": "Thompson2022_Retriever",
            "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Thompson (22-0098-AF), issued on November 21, 2022."
        },
        {
            "name": "Brown2023_Retriever",
            "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Brown (22-0249-CG), issued on October 23, 2023."
        },
        {
            "name": "Smith2024_Retriever",
            "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Smith (23-0207-AF), issued on November 26, 2024."
        },
        {
            "name": "Lopez2025_Retriever",
            "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Lopez (24-0226-CG), issued on September 2, 2025."
        },
    ]

    def create_retriever_tool(retriever, llm, name, description):
        """Wrap a retriever in a compact-synthesis query engine and expose it as a tool."""
        response_synthesizer = get_response_synthesizer(
            llm=llm, response_mode="compact", use_async=False
        )
        query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)
        return QueryEngineTool.from_defaults(query_engine=query_engine, name=name, description=description)

    retriever_tools = [
        create_retriever_tool(hybrid_retrievers[i], llm, info["name"], info["description"])
        for i, info in enumerate(documents_info)
    ]

    # System prompt: restricts the agent to legal queries and defines the
    # classify-then-answer protocol it must follow.
    system_prompt = """
        You are a highly specialized legal research assistant. 
        You may ONLY answer questions that are legal in nature. 
        This includes both:
        - Specific case law queries from the provided case documents (2021–2025).
        - General legal concepts, doctrines, or terminology.

        Before answering, always perform this intermediate reasoning step:

        1. Classify the user query:
        - If the query relates to law, legal concepts, legal systems, court rulings, rights, duties, contracts, procedures, or legal doctrines → classify as: LEGAL_QUERY.
        - If the query is casual conversation, mathematics, trivia, technical programming, or anything outside the legal domain → classify as: NON_LEGAL_QUERY.

        2. Response rules:
        - If LEGAL_QUERY:
            a) If the query references specific cases between 2021–2025, use the provided case documents to retrieve and answer. Cite the case name and year.
            b) If the query is a general legal question, answer concisely and professionally, using legal reasoning. Do NOT speculate beyond standard legal knowledge.
        - If NON_LEGAL_QUERY:
            Respond ONLY with: "I can only answer questions about legal cases (2021–2025) or general law queries."

        3. Examples:
        - LEGAL_QUERY (answer these):
            • "What is the difference between civil and criminal law?"
            • "Explain the principle of judicial review."
            • "Summarize the ruling in United States v. Lopez (2025)."
            • "What is mens rea in criminal law?"
        - NON_LEGAL_QUERY (reject these):
            • "What is 2+2?"
            • "Who won the FIFA World Cup in 2022?"
            • "Write me a Python script."
            • "Tell me a joke."

        4. Style & tone:
        - Be concise, professional, and clear.
        - Use citations ONLY when referring to case documents (case name + year).
        - Never provide speculative or non-legal answers.
        """

    # ReActAgent
    # NOTE(review): direct ReActAgent(...) construction matches the newer
    # workflow-based agent API; the legacy agent required
    # ReActAgent.from_tools(...). Confirm against the pinned llama-index
    # version.
    agent = ReActAgent(
        tools=retriever_tools,
        llm=llm,
        verbose=True,
        max_iterations=20,
        system_prompt=system_prompt
    )

    logger.info("RAG system setup complete.")

    if debug:
        return agent, llm, hybrid_retrievers

    return agent, llm