| import os |
| import streamlit as st |
| import torch |
| from datasets import load_dataset |
| from langdetect import detect |
| from deep_translator import GoogleTranslator |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain.vectorstores import FAISS |
| from langchain.docstore.document import Document |
| from langchain.embeddings import HuggingFaceEmbeddings |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
| |
| |
| |
# Streamlit page setup. st.set_page_config must be the first Streamlit
# command executed in the script, before any other st.* call renders output.
st.set_page_config(
    page_title="Kenya Legal Assistant",
    layout="wide"
)

# App header. NOTE(review): the emoji strings appear mojibake-encoded
# (likely a flag/book emoji originally) — confirm the file's source encoding.
st.title("π°πͺ Kenya Legal Assistant")
st.caption("Ask questions about Kenyan court judgments (English or Swahili)")
|
|
| |
| |
| |
@st.cache_resource(show_spinner=True)
def load_vectorstore():
    """Load (or build on first run) the FAISS index over Kenyan judgments.

    Returns a FAISS vector store backed by a multilingual
    sentence-transformer embedding model. The index is persisted under
    ``faiss_index`` so subsequent runs skip the dataset download and
    chunking entirely. Cached for the app session by st.cache_resource.
    """
    st.write("Loading legal knowledge base...")

    INDEX_PATH = "faiss_index"

    # The embedding model is required both to query a saved index and to
    # build a fresh one, so create it before branching.
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    if os.path.exists(INDEX_PATH):
        # Fast path: reuse the persisted index. Previously the dataset was
        # streamed and chunked even when the index already existed.
        st.write("✅ Loading FAISS index...")
        return FAISS.load_local(
            INDEX_PATH,
            embeddings,
            # The index is produced locally by this app, so unpickling the
            # docstore is acceptable here; never enable this for indexes
            # downloaded from untrusted sources.
            allow_dangerous_deserialization=True,
        )

    st.warning("⚠️ FAISS index not found — building (first run only)...")

    # Stream the dataset so the full corpus is never materialized in memory.
    dataset = load_dataset(
        "Brian269/Kenyan_Judgements",
        split="train",
        streaming=True,
    )

    # Cap at the first 200 judgments to keep first-run build time bounded.
    documents = []
    for i, item in enumerate(dataset):
        if i >= 200:  # was `i > 200`, which took 201 items
            break
        documents.append(
            Document(
                page_content=item["text"],
                metadata={
                    "source": item["file_name"],
                    # Page numbers are not provided by the dataset; a
                    # placeholder keeps the citation format uniform.
                    "page": 1,
                },
            )
        )

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1200,
        chunk_overlap=200,
    )

    # Split each judgment into overlapping chunks, carrying its metadata
    # through so retrieved chunks can still be cited by source file.
    chunks = [
        Document(page_content=chunk, metadata=doc.metadata)
        for doc in documents
        for chunk in splitter.split_text(doc.page_content)
    ]

    vectorstore = FAISS.from_documents(chunks, embeddings)
    vectorstore.save_local(INDEX_PATH)
    return vectorstore
|
|
|
|
| |
| |
| |
@st.cache_resource(show_spinner=True)
def load_llm():
    """Load the TinyLlama chat model and wrap it in a generation pipeline.

    Cached by Streamlit so the weights are loaded once per app session.
    Returns a transformers ``text-generation`` pipeline producing up to
    512 new tokens with low-temperature sampling.
    """
    st.write("Loading language model...")

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,  # stream weights to keep peak RAM down
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # temperature only takes effect when sampling is enabled; without
        # do_sample=True generation is greedy and temperature=0.2 was
        # silently ignored (transformers emits a warning).
        do_sample=True,
        temperature=0.2,
    )

    return pipe
|
|
|
|
| |
# Eagerly initialize the heavy resources at script start. Both functions are
# wrapped in st.cache_resource, so reruns reuse the same objects instead of
# rebuilding the index / reloading model weights.
vectorstore = load_vectorstore()
pipe = load_llm()
|
|
| |
| |
| |
def detect_language(text):
    """Best-effort language detection; falls back to English.

    langdetect raises LangDetectException on empty or undetectable input;
    default to "en" so the question pipeline never crashes on odd input.
    """
    try:
        return detect(text)
    # Was a bare `except:`, which also swallowed KeyboardInterrupt and
    # SystemExit. Exception is still broad but no longer traps interpreter
    # control-flow exceptions; ideally narrow to LangDetectException.
    except Exception:
        return "en"
|
|
|
|
def translate(text, target_lang):
    """Translate *text* into *target_lang*, auto-detecting the source language."""
    translator = GoogleTranslator(source="auto", target=target_lang)
    return translator.translate(text)
|
|
|
|
def build_prompt(question, context):
    """Compose the instruction prompt sent to the LLM.

    Wraps the retrieved *context* and the user's *question* in a fixed
    template that instructs the model to answer only from the context
    and to cite cases.
    """
    template_lines = [
        "",
        "You are a Kenyan legal assistant.",
        "",
        "Answer ONLY using the provided context.",
        "Include proper case citations.",
        "Do not fabricate information.",
        "",
        "Context:",
        context,
        "",
        "Question:",
        question,
        "",
        "Structured Answer:",
        "",  # yields the trailing newline of the original triple-quoted string
    ]
    return "\n".join(template_lines)
|
|
|
|
def ask_kenya_law(question):
    """Answer a legal question via retrieval-augmented generation.

    Detects the question language (Swahili questions are translated to
    English for retrieval/generation and the answer translated back),
    retrieves the top-4 matching chunks, prompts the LLM, and returns a
    ``(answer, sources)`` tuple where *sources* is a newline-joined
    '"source" - Page n' listing of the retrieved chunks.
    """
    language = detect_language(question)

    # Retrieval and generation operate in English; round-trip Swahili.
    question_en = (
        translate(question, "en") if language == "sw" else question
    )

    retrieved_docs = vectorstore.similarity_search(question_en, k=4)

    context = "\n\n".join(doc.page_content for doc in retrieved_docs)

    prompt = build_prompt(question_en, context)

    generated = pipe(prompt)[0]["generated_text"]

    # text-generation pipelines echo the prompt at the start of
    # generated_text by default; previously the full prompt (instructions
    # and retrieved context) was shown to the user as part of the answer.
    if generated.startswith(prompt):
        result = generated[len(prompt):].strip()
    else:
        result = generated

    if language == "sw":
        result = translate(result, "sw")

    sources = "\n".join(
        f'{doc.metadata["source"]} - Page {doc.metadata["page"]}'
        for doc in retrieved_docs
    )

    return result, sources
|
|
|
|
| |
| |
| |
# Persist the chat history across Streamlit reruns; initialize on first load.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far — Streamlit re-executes the whole script
# on every interaction, so prior turns must be re-rendered from state.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# Chat input box; returns None until the user submits a question.
prompt = st.chat_input("Ask a legal question...")
|
|
if prompt:
    # Record the user's turn first so it is replayed on the next rerun.
    st.session_state.messages.append(
        {"role": "user", "content": prompt}
    )

    # Echo the user's message immediately in the current render.
    with st.chat_message("user"):
        st.markdown(prompt)

    # Run the RAG pipeline and render the assistant's reply.
    with st.chat_message("assistant"):
        with st.spinner("Analyzing Kenyan case law..."):
            answer, sources = ask_kenya_law(prompt)

        # NOTE(review): the header/disclaimer emoji below appear
        # mojibake-encoded — confirm the file's source encoding.
        response = f"""
{answer}

---

π **Sources**
{sources}

β οΈ DISCLAIMER:
This AI provides legal information for educational purposes only.
It does NOT constitute legal advice.
"""

        st.markdown(response)

        # Save the assistant turn so it survives the next rerun.
        st.session_state.messages.append(
            {"role": "assistant", "content": response}
        )
|
|