SheriaAit / app.py
PeterMaks's picture
Update that only allows the RAG model to get context from the provided database (#1)
8be2711 verified
import streamlit as st
import ollama
import chromadb
from chromadb.utils import embedding_functions
import os
import time
# --- SETUP ---
# Filesystem path of the persistent ChromaDB store built by the ingestion step.
DB_PATH = "./legal_db"
# Name of the Chroma collection holding the chunked legal documents.
COLLECTION_NAME = "legal_docs"
# Embedding model served by the local Ollama instance (BGE base, GGUF build).
EMBEDDING_MODEL = 'hf.co/CompendiumLabs/bge-base-en-v1.5-gguf'
# Chat model used to generate the final answer from retrieved context.
LANGUAGE_MODEL = 'llama3.2:3b'
st.set_page_config(page_title="Legal Assistant AI")
st.title("⚖️ Kenya Law RAG Bot")
# --- INITIALIZE DATABASE ---
@st.cache_resource
def get_collection():
    """Open (or create) the persistent legal-documents collection.

    Cached by Streamlit so the ChromaDB client and the Ollama embedding
    function are constructed only once per server process.
    """
    embedder = embedding_functions.OllamaEmbeddingFunction(
        model_name=EMBEDDING_MODEL,
        url="http://localhost:11434/api/embeddings",
    )
    db_client = chromadb.PersistentClient(path=DB_PATH)
    collection = db_client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedder,
    )
    return collection
# Connect to the vector store up front; abort the page render on failure
# so the chat UI never appears without a working database behind it.
try:
    collection = get_collection()
except Exception as exc:
    st.error(f"Could not connect to database: {exc}")
    st.stop()
# --- CHAT INTERFACE ---
# --- CHAT INTERFACE ---
# Ensure the conversation history exists before anything reads it.
st.session_state.setdefault("messages", [])

# Replay the conversation so far, one bubble per stored message.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# Handle Input
def _build_context(results):
    """Join retrieved document chunks into one cited context string.

    Each chunk is prefixed with its source filename so the model can cite
    it. Returns "" when the query produced no documents. Metadata access is
    guarded: a missing or short metadatas list falls back to an empty dict
    instead of raising (the original indexed it unconditionally).
    """
    docs = results.get('documents') or []
    metas = results.get('metadatas') or []
    context_str = ""
    if docs and docs[0]:
        for i, doc in enumerate(docs[0]):
            meta = metas[0][i] if metas and metas[0] and i < len(metas[0]) else {}
            context_str += f"[Source: {meta.get('source', 'unknown')}]\n{doc}\n\n"
    return context_str


def _build_system_msg(context_str):
    """Return the strict RAG system prompt for the given context.

    When no context was retrieved, returns a refusal instruction instead of
    a prompt with empty extracts, so the model cannot improvise an answer.
    (Replaces the original's after-the-fact override, which also left a dead
    `context_str = "No relevant documents found."` assignment behind.)
    """
    if not context_str:
        return "The database contained no relevant information. Inform the user you cannot answer based on the available documents."
    return f"""
You are a strict specialized assistant. You verify facts against the provided database extracts.
DATABASE EXTRACTS:
{context_str}
RULES:
1. You must ONLY answer using the information in the 'DATABASE EXTRACTS' above.
2. If the answer is not explicitly in the extracts, you MUST say: "The provided documents do not contain information about this."
3. Do not use outside knowledge. Do not make up laws or facts.
4. Cite the source file names provided in the extracts.
"""


# Handle Input
if prompt := st.chat_input("Ask a legal question..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # RAG Logic: retrieve, build the strict prompt, then stream the answer.
    with st.chat_message("assistant"):
        with st.spinner("Searching legal documents..."):
            results = collection.query(query_texts=[prompt], n_results=3)
            context_str = _build_context(results)
            system_msg = _build_system_msg(context_str)

        # Streaming Response
        stream = ollama.chat(
            model=LANGUAGE_MODEL,
            messages=[
                {'role': 'system', 'content': system_msg},
                {'role': 'user', 'content': prompt},
            ],
            stream=True,
        )
        response = st.write_stream(chunk['message']['content'] for chunk in stream)

    st.session_state.messages.append({"role": "assistant", "content": response})