# app.py — Kenya Legal Assistant (Hugging Face Space by Brian269)
import os
import streamlit as st
import torch
from datasets import load_dataset
from langdetect import detect
from deep_translator import GoogleTranslator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# ===================================
# PAGE CONFIG
# ===================================
st.set_page_config(
page_title="Kenya Legal Assistant",
layout="wide"
)
st.title("πŸ‡°πŸ‡ͺ Kenya Legal Assistant")
st.caption("Ask questions about Kenyan court judgments (English or Swahili)")
# ===================================
# LOAD VECTOR DATABASE (CACHED)
# ===================================
# Cap on judgments ingested at startup (prevents HF startup timeout).
MAX_DOCS = 200


def _load_documents(limit=MAX_DOCS):
    """Stream up to *limit* judgments from the HF dataset as LangChain Documents."""
    dataset = load_dataset(
        "Brian269/Kenyan_Judgements",
        split="train",
        streaming=True,
    )
    documents = []
    for i, item in enumerate(dataset):
        if i >= limit:  # was `i > 200`, which collected 201 docs
            break
        documents.append(
            Document(
                page_content=item["text"],
                metadata={
                    "source": item["file_name"],
                    # Dataset rows carry no page info; 1 is a placeholder.
                    "page": 1,
                },
            )
        )
    return documents


def _split_documents(documents):
    """Split documents into overlapping chunks suitable for embedding."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1200,
        chunk_overlap=200,
    )
    return [
        Document(page_content=chunk, metadata=doc.metadata)
        for doc in documents
        for chunk in splitter.split_text(doc.page_content)
    ]


@st.cache_resource(show_spinner=True)
def load_vectorstore():
    """Return the FAISS vector store of Kenyan judgment chunks.

    Loads a prebuilt index from disk when available; otherwise streams the
    dataset, chunks it, embeds it, and saves the index for future runs.

    Fix: the original streamed and chunked the whole dataset unconditionally,
    even when the prebuilt index existed and the chunks were discarded. The
    expensive build now runs only on the cache-miss path.
    """
    st.write("🔎 Loading legal knowledge base...")
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )
    INDEX_PATH = "faiss_index"
    # ✅ Load prebuilt FAISS index if uploaded
    if os.path.exists(INDEX_PATH):
        st.write("✅ Loading FAISS index...")
        vectorstore = FAISS.load_local(
            INDEX_PATH,
            embeddings,
            allow_dangerous_deserialization=True,
        )
    else:
        st.warning("⚠️ FAISS index not found — building (first run only)...")
        chunks = _split_documents(_load_documents())
        vectorstore = FAISS.from_documents(chunks, embeddings)
        vectorstore.save_local(INDEX_PATH)
    return vectorstore
# ===================================
# LOAD LANGUAGE MODEL (CACHED)
# ===================================
@st.cache_resource(show_spinner=True)
def load_llm():
    """Build and cache the text-generation pipeline used to answer questions."""
    st.write("🧠 Loading language model...")
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        low_cpu_mem_usage=True,
    )
    # Greedy-ish sampling (low temperature) keeps answers grounded in context.
    return pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        max_new_tokens=512,
        temperature=0.2,
    )
# Load once
# Module-level singletons; st.cache_resource makes these cheap on Streamlit reruns.
vectorstore = load_vectorstore()
pipe = load_llm()
# ===================================
# HELPERS
# ===================================
def detect_language(text):
    """Best-effort ISO-639-1 language detection with an English fallback.

    langdetect raises on empty or undecidable input; any failure yields
    "en" so the app never crashes on odd user input.
    """
    try:
        return detect(text)
    except Exception:  # was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt
        return "en"
def translate(text, target_lang):
    """Translate *text* into *target_lang*, auto-detecting the source language."""
    translator = GoogleTranslator(source="auto", target=target_lang)
    return translator.translate(text)
# Template for the RAG prompt; filled in by build_prompt().
_PROMPT_TEMPLATE = """
You are a Kenyan legal assistant.
Answer ONLY using the provided context.
Include proper case citations.
Do not fabricate information.
Context:
{context}
Question:
{question}
Structured Answer:
"""


def build_prompt(question, context):
    """Render the grounded-answer prompt from the retrieved context and question."""
    return _PROMPT_TEMPLATE.format(context=context, question=question)
def ask_kenya_law(question):
    """Answer a legal question via retrieval-augmented generation.

    Swahili questions are translated to English for retrieval/generation,
    and the answer is translated back to Swahili. Retrieves the 4 most
    similar judgment chunks and generates an answer grounded in them.

    Returns:
        (answer, sources): answer text and a newline-joined list of
        '"file" - Page n' citations for the retrieved chunks.
    """
    language = detect_language(question)
    question_en = (
        translate(question, "en") if language == "sw" else question
    )
    retrieved_docs = vectorstore.similarity_search(question_en, k=4)
    context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    prompt = build_prompt(question_en, context)
    generated = pipe(prompt)[0]["generated_text"]
    # Fix: HF text-generation pipelines return prompt + completion by default,
    # so the original echoed the entire prompt (context included) back to the
    # user. Keep only the newly generated completion; guarded so behavior is
    # unchanged if the pipeline ever returns the completion alone.
    result = generated[len(prompt):] if generated.startswith(prompt) else generated
    if language == "sw":
        result = translate(result, "sw")
    sources = "\n".join(
        f'{doc.metadata["source"]} - Page {doc.metadata["page"]}'
        for doc in retrieved_docs
    )
    return result, sources
# ===================================
# STREAMLIT CHAT UI
# ===================================
# Chat history persists across Streamlit script reruns via session_state.
if "messages" not in st.session_state:
    st.session_state.messages = []
# Display history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
prompt = st.chat_input("Ask a legal question...")
if prompt:
    # Record and echo the user's question before answering.
    st.session_state.messages.append(
        {"role": "user", "content": prompt}
    )
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        with st.spinner("Analyzing Kenyan case law..."):
            answer, sources = ask_kenya_law(prompt)
        # Answer + citations + mandatory disclaimer, rendered as one message
        # and stored so it re-renders on subsequent reruns.
        response = f"""
{answer}
---
📚 **Sources**
{sources}
⚠️ DISCLAIMER:
This AI provides legal information for educational purposes only.
It does NOT constitute legal advice.
"""
        st.markdown(response)
        st.session_state.messages.append(
            {"role": "assistant", "content": response}
        )