Spaces:
Sleeping
Sleeping
File size: 7,193 Bytes
883d885 3ced916 883d885 3ced916 883d885 3ced916 883d885 ddfc14c 883d885 3ced916 883d885 3ced916 883d885 3ced916 883d885 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import os
import logging
import streamlit as st
from dotenv import load_dotenv
import pickle
from llama_index.llms.groq import Groq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever, RecursiveRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.tools import QueryEngineTool
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import get_response_synthesizer
from llama_index.core.agent import ReActAgent
from chromadb import PersistentClient
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Chroma persistence directories, one per case. The order must match
# _DOCUMENTS_INFO because retrievers and tool metadata are paired by position.
_PERSIST_DIRS = (
    "./vectordb/case_2021",
    "./vectordb/case_2022",
    "./vectordb/case_2023",
    "./vectordb/case_2024",
    "./vectordb/case_2025",
)

# Tool name/description shown to the agent for each case, in the same order
# as _PERSIST_DIRS.
_DOCUMENTS_INFO = (
    {
        "name": "Quezada2021_Retriever",
        "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Quezada (21-0089-MC), issued on December 20, 2021.",
    },
    {
        "name": "Thompson2022_Retriever",
        "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Thompson (22-0098-AF), issued on November 21, 2022.",
    },
    {
        "name": "Brown2023_Retriever",
        "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Brown (22-0249-CG), issued on October 23, 2023.",
    },
    {
        "name": "Smith2024_Retriever",
        "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Smith (23-0207-AF), issued on November 26, 2024.",
    },
    {
        "name": "Lopez2025_Retriever",
        "description": "Retrieves information from the United States Court of Appeals for the Armed Forces decision in United States v. Lopez (24-0226-CG), issued on September 2, 2025.",
    },
)

# System prompt handed to the agent. The dash, arrow and bullet characters were
# previously mis-encoded ("β", "β’"); they are restored here so the model sees
# the intended text.
_SYSTEM_PROMPT = """
You are a highly specialized legal research assistant.
You may ONLY answer questions that are legal in nature.
This includes both:
- Specific case law queries from the provided case documents (2021–2025).
- General legal concepts, doctrines, or terminology.
Before answering, always perform this intermediate reasoning step:
1. Classify the user query:
- If the query relates to law, legal concepts, legal systems, court rulings, rights, duties, contracts, procedures, or legal doctrines → classify as: LEGAL_QUERY.
- If the query is casual conversation, mathematics, trivia, technical programming, or anything outside the legal domain → classify as: NON_LEGAL_QUERY.
2. Response rules:
- If LEGAL_QUERY:
a) If the query references specific cases between 2021–2025, use the provided case documents to retrieve and answer. Cite the case name and year.
b) If the query is a general legal question, answer concisely and professionally, using legal reasoning. Do NOT speculate beyond standard legal knowledge.
- If NON_LEGAL_QUERY:
Respond ONLY with: "I can only answer questions about legal cases (2021–2025) or general law queries."
3. Examples:
- LEGAL_QUERY (answer these):
• "What is the difference between civil and criminal law?"
• "Explain the principle of judicial review."
• "Summarize the ruling in United States v. Lopez (2025)."
• "What is mens rea in criminal law?"
- NON_LEGAL_QUERY (reject these):
• "What is 2+2?"
• "Who won the FIFA World Cup in 2022?"
• "Write me a Python script."
• "Tell me a joke."
4. Style & tone:
- Be concise, professional, and clear.
- Use citations ONLY when referring to case documents (case name + year).
- Never provide speculative or non-legal answers.
"""


def _load_nodes(persist_dir):
    """Load the pickled nodes stored as ``nodes.pkl`` inside *persist_dir*.

    Reports the problem with ``st.error`` and calls ``st.stop()`` when the
    pickle file does not exist.
    """
    nodes_path = os.path.join(persist_dir, "nodes.pkl")
    if not os.path.exists(nodes_path):
        st.error(f"Pickle file {nodes_path} not found.")
        st.stop()
    # SECURITY: pickle.load can execute arbitrary code while deserializing.
    # Only ship nodes.pkl files that were produced by this project's own
    # indexing step; never load one from an untrusted source.
    with open(nodes_path, "rb") as f:
        return pickle.load(f)


def _build_hybrid_retriever(persist_dir, embedding_model, llm):
    """Build a retriever that fuses vector and BM25 results for one case.

    The previous implementation wrapped both retrievers in a
    ``RecursiveRetriever`` rooted at ``"vector"``. That class only queries the
    root retriever and follows ``IndexNode`` links, so the BM25 retriever was
    registered but never queried. ``QueryFusionRetriever`` queries both and
    merges the rankings.
    """
    # Local import: lives in the same package the module already imports
    # VectorIndexRetriever from, so no new dependency is introduced.
    from llama_index.core.retrievers import QueryFusionRetriever

    nodes = _load_nodes(persist_dir)

    # Vector side: reopen the persisted Chroma collection and wrap it in an index.
    client = PersistentClient(path=persist_dir)
    collection = client.get_collection("case_collection")
    vector_store = ChromaVectorStore(chroma_collection=collection)
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store, embed_model=embedding_model
    )

    # NOTE(review): `retriever_mode` is kept exactly as before, but the
    # documented knob on VectorIndexRetriever is `vector_store_query_mode` --
    # confirm whether MMR is actually in effect here.
    vector_retriever = VectorIndexRetriever(
        index=index, similarity_top_k=2, retriever_mode="mmr"
    )
    # Keyword side: BM25 over the same nodes.
    bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)

    return QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        llm=llm,  # passed explicitly so no global default LLM has to be resolved
        mode="reciprocal_rerank",
        similarity_top_k=2,  # same number of nodes reaches the synthesizer as before
        num_queries=1,  # 1 = use the original query only, no LLM query expansion
        use_async=False,
        verbose=True,
    )


def _create_retriever_tool(retriever, llm, name, description):
    """Wrap *retriever* in a query engine and expose it as an agent tool (retriever → tool)."""
    response_synthesizer = get_response_synthesizer(
        llm=llm, response_mode="compact", use_async=False
    )
    query_engine = RetrieverQueryEngine(
        retriever=retriever, response_synthesizer=response_synthesizer
    )
    return QueryEngineTool.from_defaults(
        query_engine=query_engine, name=name, description=description
    )


@st.cache_resource
def setup_rag_system(debug=False):
    """Build the legal-research RAG agent and its supporting components.

    Loads the Groq API key (environment first, then Streamlit secrets),
    creates the LLM and embedding model, builds one hybrid (vector + BM25)
    retriever per persisted case database, wraps each retriever as a query
    engine tool, and assembles a ``ReActAgent`` over those tools.

    If the API key, a vector database directory, or a ``nodes.pkl`` file is
    missing, the problem is reported with ``st.error`` and ``st.stop()`` is
    called.

    Args:
        debug: When true, the per-case retrievers are returned as well.

    Returns:
        ``(agent, llm)``, or ``(agent, llm, hybrid_retrievers)`` when *debug*
        is true.
    """
    load_dotenv()
    groq_api_key = os.getenv("GROQ_API_KEY") or st.secrets.get("groq", {}).get("api_key")
    if not groq_api_key:
        st.error("GROQ API key not found. Please check your environment variables or secrets.")
        st.stop()

    # LLM
    # NOTE(review): confirm that `max_input_tokens` / `max_output_tokens` are
    # honored by this Groq wrapper; OpenAI-style clients usually take
    # `max_tokens`. Left unchanged to avoid altering generation limits blindly.
    llm = Groq(
        model="llama-3.1-8b-instant",
        api_key=groq_api_key,
        max_input_tokens=1200,
        max_output_tokens=1200,
    )

    # Embeddings
    embedding_model = HuggingFaceEmbedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # Validate every directory up front so a missing database is reported
    # before any index is opened.
    for persist_dir in _PERSIST_DIRS:
        if not os.path.exists(persist_dir):
            st.error(f"Vector database directory {persist_dir} not found.")
            st.stop()

    # One hybrid retriever per case database, in _PERSIST_DIRS order.
    hybrid_retrievers = [
        _build_hybrid_retriever(persist_dir, embedding_model, llm)
        for persist_dir in _PERSIST_DIRS
    ]

    # Pair each retriever with its case metadata by position.
    retriever_tools = [
        _create_retriever_tool(retriever, llm, info["name"], info["description"])
        for retriever, info in zip(hybrid_retrievers, _DOCUMENTS_INFO)
    ]

    # ReActAgent
    agent = ReActAgent(
        tools=retriever_tools,
        llm=llm,
        verbose=True,
        max_iterations=20,
        system_prompt=_SYSTEM_PROMPT,
    )
    logger.info("RAG system setup complete.")

    if debug:
        return agent, llm, hybrid_retrievers
    return agent, llm
|