import os

import streamlit as st
import torch
from datasets import load_dataset
from langdetect import detect
from deep_translator import GoogleTranslator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ===================================
# PAGE CONFIG
# ===================================
st.set_page_config(
    page_title="Kenya Legal Assistant",
    layout="wide",
)
st.title("🇰🇪 Kenya Legal Assistant")
st.caption("Ask questions about Kenyan court judgments (English or Swahili)")


# ===================================
# LOAD VECTOR DATABASE (CACHED)
# ===================================
@st.cache_resource(show_spinner=True)
def load_vectorstore():
    """Load (or build once) the FAISS vector store over Kenyan judgments.

    Returns a langchain ``FAISS`` vector store backed by multilingual
    sentence embeddings. If a prebuilt index exists at ``faiss_index/`` it
    is loaded directly; otherwise the corpus is downloaded, chunked,
    embedded, built and persisted for subsequent runs.

    Cached with ``st.cache_resource`` so this runs once per server process.
    """
    st.write("🔎 Loading legal knowledge base...")

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    INDEX_PATH = "faiss_index"

    # Fast path: a prebuilt index was shipped/persisted — load it and skip
    # the expensive dataset download + chunking entirely.
    # NOTE: allow_dangerous_deserialization unpickles the index; only safe
    # because the index file is produced by this app, never user-supplied.
    if os.path.exists(INDEX_PATH):
        st.write("✅ Loading FAISS index...")
        return FAISS.load_local(
            INDEX_PATH,
            embeddings,
            allow_dangerous_deserialization=True,
        )

    st.warning("⚠️ FAISS index not found — building (first run only)...")

    # Stream the dataset and cap the number of judgments so the first-run
    # build finishes before hosting platforms time out the startup.
    MAX_DOCS = 200
    dataset = load_dataset(
        "Brian269/Kenyan_Judgements",
        split="train",
        streaming=True,
    )

    documents = []
    for i, item in enumerate(dataset):
        if i >= MAX_DOCS:  # was `i > 200`, which admitted 202 docs
            break
        documents.append(
            Document(
                page_content=item["text"],
                metadata={
                    "source": item["file_name"],
                    # Page-level granularity is not available in the
                    # dataset; every chunk is attributed to page 1.
                    "page": 1,
                },
            )
        )

    # Overlapping chunks keep citations intact across chunk boundaries.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1200,
        chunk_overlap=200,
    )
    chunks = [
        Document(page_content=chunk, metadata=doc.metadata)
        for doc in documents
        for chunk in splitter.split_text(doc.page_content)
    ]

    vectorstore = FAISS.from_documents(chunks, embeddings)
    vectorstore.save_local(INDEX_PATH)
    return vectorstore

#
# ===================================
# LOAD LANGUAGE MODEL (CACHED)
# ===================================
@st.cache_resource(show_spinner=True)
def load_llm():
    """Load the chat model and wrap it in a text-generation pipeline.

    Returns a ``transformers`` pipeline. Cached with ``st.cache_resource``
    so the model is downloaded and loaded once per server process.
    """
    st.write("🧠 Loading language model...")

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
    )

    # NOTE(review): temperature has no effect unless do_sample=True is also
    # passed; generation is effectively greedy here. Left unchanged to
    # preserve current (deterministic) behavior.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        temperature=0.2,
    )


# Load once per process (both functions are cache_resource-cached).
vectorstore = load_vectorstore()
pipe = load_llm()


# ===================================
# HELPERS
# ===================================
def detect_language(text):
    """Best-effort language detection; default to English on failure.

    langdetect raises on empty/undetectable input — fall back to "en"
    rather than crashing the chat turn.
    """
    try:
        return detect(text)
    except Exception:  # narrow from bare `except:` — don't swallow SystemExit
        return "en"


def translate(text, target_lang):
    """Translate `text` into `target_lang` via Google Translate.

    NOTE(review): GoogleTranslator rejects inputs over ~5000 chars;
    long model answers may need chunking — confirm against usage.
    """
    return GoogleTranslator(source="auto", target=target_lang).translate(text)


def build_prompt(question, context):
    """Assemble the grounded-QA prompt fed to the LLM."""
    return f"""
You are a Kenyan legal assistant.
Answer ONLY using the provided context.
Include proper case citations.
Do not fabricate information.

Context:
{context}

Question:
{question}

Structured Answer:
"""


def ask_kenya_law(question):
    """Answer a legal question using retrieval-augmented generation.

    Detects the question language (Swahili questions are translated to
    English for retrieval/generation, and the answer translated back),
    retrieves the top-4 similar chunks, and prompts the LLM with them.

    Returns a ``(answer, sources)`` tuple of strings.
    """
    language = detect_language(question)
    question_en = translate(question, "en") if language == "sw" else question

    retrieved_docs = vectorstore.similarity_search(question_en, k=4)
    context = "\n\n".join(doc.page_content for doc in retrieved_docs)

    prompt = build_prompt(question_en, context)

    # return_full_text=False: without it the pipeline echoes the entire
    # prompt (instructions + retrieved context) back in generated_text,
    # so users were shown the prompt, not just the answer.
    result = pipe(prompt, return_full_text=False)[0]["generated_text"]

    if language == "sw":
        result = translate(result, "sw")

    sources = "\n".join(
        f'{doc.metadata["source"]} - Page {doc.metadata["page"]}'
        for doc in retrieved_docs
    )
    return result, sources


# ===================================
# STREAMLIT CHAT UI
# ===================================
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

prompt = st.chat_input("Ask a legal question...")

if prompt:
    st.session_state.messages.append(
        {"role": "user", "content": prompt}
    )
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Analyzing Kenyan case law..."):
            answer, sources = ask_kenya_law(prompt)

            response = f"""
{answer}

---
📚 **Sources**
{sources}

⚠️ DISCLAIMER:
This AI provides legal information for educational purposes only.
It does NOT constitute legal advice.
"""
            st.markdown(response)

    st.session_state.messages.append(
        {"role": "assistant", "content": response}
    )