File size: 4,885 Bytes
11e272f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6fab9c
 
 
 
 
 
 
11e272f
 
d6fab9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
from datasets import load_dataset
import torch
from langdetect import detect
from deep_translator import GoogleTranslator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr

# -----------------------------
# Load PDFs from Hugging Face dataset
# -----------------------------
# One LangChain Document per dataset row; metadata records the originating
# file so answers can be cited back to a judgment.
dataset = load_dataset("Brian269/Kenyan_Judgements", split="train")  # Replace with your dataset

documents = [
    Document(
        page_content=item["text"],  # Assuming your dataset has a "text" field
        metadata={"source": item["file_name"], "page": 1},
    )
    for item in dataset
]

# -----------------------------
# Split text into chunks
# -----------------------------
# 1200-char windows with 200-char overlap; every chunk keeps its parent
# document's metadata so citations survive the split.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)
chunks = [
    Document(page_content=piece, metadata=doc.metadata)
    for doc in documents
    for piece in text_splitter.split_text(doc.page_content)
]

# -----------------------------
# Embeddings + FAISS index
# -----------------------------
# Multilingual sentence-transformer so English and Swahili queries embed into
# the same space; the FAISS index is built in memory from all chunks at startup.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
vectorstore = FAISS.from_documents(chunks, embedding_model)

# -----------------------------
# Load LLM
# -----------------------------
# Small chat-tuned model; half precision when a CUDA device is available,
# full fp32 otherwise. Low temperature keeps answers close to the context.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=_dtype,
)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512, temperature=0.2)

# -----------------------------
# Helpers for multilingual queries
# -----------------------------
def detect_language(query):
    """Best-effort language detection for a user query.

    Returns the ISO 639-1 code reported by langdetect, falling back to
    "en" when detection fails (langdetect raises on empty or
    undetectable text).
    """
    try:
        return detect(query)
    # Bug fix: a bare `except:` also swallows KeyboardInterrupt/SystemExit;
    # catch Exception so only real detection failures fall back to English.
    except Exception:
        return "en"

def translate_text(text, target_lang):
    """Translate *text* into Swahili ("sw") or English ("en").

    Any other target language code is a no-op: the original text is
    returned unchanged.
    """
    if target_lang in ("sw", "en"):
        return GoogleTranslator(source='auto', target=target_lang).translate(text)
    return text

# -----------------------------
# Build prompts
# -----------------------------
# Appended to every answer shown to the user; this tool is informational
# only and must not be presented as legal advice.
DISCLAIMER_TEXT = """
⚠️ DISCLAIMER:
This AI assistant provides legal information derived from publicly
available Kenyan court judgments for educational purposes only.
It does NOT provide legal advice.
For professional legal assistance, consult a qualified advocate.
"""

def build_prompt(question, context):
    """Assemble the RAG prompt: fixed instructions, retrieved context, question."""
    header = (
        "\nYou are a Kenyan legal assistant."
        "\nAnswer concisely using ONLY the provided context."
        "\nInclude proper case citation (case name and page)."
        "\nDo not fabricate information.\n"
    )
    # Sections are separated by blank lines, mirroring the original template.
    sections = [
        header,
        f"Context:\n{context}",
        f"Question:\n{question}",
        "Provide a clear structured answer.",
    ]
    return "\n\n".join(sections)

# -----------------------------
# Query system
# -----------------------------
def ask_kenya_law(question, k=4):
    """Answer a legal question with retrieval-augmented generation.

    Detects the query language, translates Swahili queries to English for
    retrieval, fetches the top-*k* chunks from the FAISS index, generates
    an answer, and translates it back to Swahili when the query was Swahili.

    Returns a ``(answer_text, sources)`` tuple, where *sources* is a
    newline-joined list of "file - Page n" citations.
    """
    language = detect_language(question)
    translated_question = translate_text(question, "en") if language == "sw" else question
    retrieved_docs = vectorstore.similarity_search(translated_question, k=k)
    context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    prompt = build_prompt(translated_question, context)
    response = pipe(prompt)[0]["generated_text"]
    # Bug fix: HF text-generation pipelines include the prompt at the start
    # of "generated_text" by default, so the user would see (and Swahili
    # users would pay to translate) the whole prompt. Strip it off.
    if response.startswith(prompt):
        response = response[len(prompt):].strip()
    if language == "sw":
        response = translate_text(response, "sw")
    sources = [f'{doc.metadata["source"]} - Page {doc.metadata["page"]}' for doc in retrieved_docs]
    return response, "\n".join(sources)

# -----------------------------
# Gradio Interface
# -----------------------------
def query_system(user_input):
    """Build the full UI reply: answer, cited sources, and the disclaimer."""
    answer, sources = ask_kenya_law(user_input)
    reply_parts = [answer, "\n\n📚 SOURCES:\n", sources, DISCLAIMER_TEXT]
    return "".join(reply_parts)

import streamlit as st  # NOTE(review): mid-file import; consider moving to the top-of-file import block

st.set_page_config(page_title="Kenya Legal Assistant")

st.title("🇰🇪 Kenya Legal Assistant")
st.write(
    "Ask questions about Kenyan court judgments in English or Swahili."
)

user_input = st.text_input("Enter your legal question:")

if st.button("Ask"):
    if user_input.strip():
        with st.spinner("Analyzing legal documents..."):
            answer, sources = ask_kenya_law(user_input)

            st.markdown("### 📖 Answer")
            st.write(answer)

            st.markdown("### 📚 Sources")
            st.text(sources)

            # Consistency: reuse the module-level DISCLAIMER_TEXT constant
            # instead of duplicating the same literal here.
            st.warning(DISCLAIMER_TEXT)
    else:
        # Robustness: previously an empty query silently did nothing.
        st.info("Please enter a question first.")