import os
from datasets import load_dataset
import torch
from langdetect import detect
from deep_translator import GoogleTranslator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr
# -----------------------------
# Load PDFs from Hugging Face dataset
# -----------------------------
dataset = load_dataset("Brian269/Kenyan_Judgements", split="train")  # Replace with your dataset
# One LangChain Document per dataset row.
# NOTE(review): assumes every row has "text" and "file_name" fields — confirm against the dataset schema.
documents = [
    Document(
        page_content=item["text"],
        metadata={"source": item["file_name"], "page": 1},
    )
    for item in dataset
]
# -----------------------------
# Split text into chunks
# -----------------------------
# 1200-char chunks with 200-char overlap so context isn't cut mid-sentence.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)
# Each chunk keeps its parent document's metadata (source file + page).
chunks = [
    Document(page_content=piece, metadata=doc.metadata)
    for doc in documents
    for piece in text_splitter.split_text(doc.page_content)
]
# -----------------------------
# Embeddings + FAISS index
# -----------------------------
# Multilingual sentence-transformer so English and Swahili queries embed
# into the same vector space as the judgment text.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
# In-memory FAISS similarity index over all chunks; rebuilt on every run.
vectorstore = FAISS.from_documents(chunks, embedding_model)
# -----------------------------
# Load LLM
# -----------------------------
# Small chat model so the app can run on modest hardware.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",  # let accelerate place layers on GPU/CPU automatically
# fp16 on GPU halves memory; fp32 on CPU avoids unsupported-dtype issues
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# Low temperature keeps answers close to the retrieved context.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512, temperature=0.2)
# -----------------------------
# Helpers for multilingual queries
# -----------------------------
def detect_language(query):
    """Best-effort language detection.

    Returns the ISO-639-1 code reported by langdetect, or "en" when
    detection fails (e.g. empty or ambiguous input).
    """
    try:
        return detect(query)
    except Exception:
        # Was a bare `except:`, which would also swallow KeyboardInterrupt
        # and SystemExit; Exception is the correct scope for a fallback.
        return "en"
def translate_text(text, target_lang):
    """Translate *text* into *target_lang* ("sw" or "en" only).

    Any other target language returns the text unchanged.
    """
    # Both original branches were identical apart from the target code,
    # so a single parameterized call covers them.
    if target_lang not in ("sw", "en"):
        return text
    return GoogleTranslator(source='auto', target=target_lang).translate(text)
# -----------------------------
# Build prompts
# -----------------------------
# Appended to every answer shown to the user.
# Fixed mojibake: "β οΈ" was the mis-decoded UTF-8 bytes of the warning emoji.
DISCLAIMER_TEXT = """
⚠️ DISCLAIMER:
This AI assistant provides legal information derived from publicly
available Kenyan court judgments for educational purposes only.
It does NOT provide legal advice.
For professional legal assistance, consult a qualified advocate.
"""
def build_prompt(question, context):
    """Assemble the LLM prompt: instruction, retrieved context, user question."""
    instruction = """
You are a Kenyan legal assistant.
Answer concisely using ONLY the provided context.
Include proper case citation (case name and page).
Do not fabricate information.
"""
    sections = (
        instruction,
        f"Context:\n{context}",
        f"Question:\n{question}",
        "Provide a clear structured answer.",
    )
    return "\n\n".join(sections)
# -----------------------------
# Query system
# -----------------------------
def ask_kenya_law(question, k=4):
    """Answer a legal question from the indexed judgments.

    Swahili questions are translated to English for retrieval and
    generation, and the answer is translated back to Swahili.

    Args:
        question: user question in English or Swahili.
        k: number of chunks to retrieve from the FAISS index.

    Returns:
        (answer_text, newline-joined "source - Page n" citations).
    """
    language = detect_language(question)
    # The corpus and LLM operate in English, so retrieve/generate in English.
    translated_question = translate_text(question, "en") if language == "sw" else question
    retrieved_docs = vectorstore.similarity_search(translated_question, k=k)
    context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    prompt = build_prompt(translated_question, context)
    generated = pipe(prompt)[0]["generated_text"]
    # Fix: HF text-generation pipelines return prompt + completion by default,
    # so the original echoed the entire prompt back (and re-translated it).
    # Keep only the completion.
    response = generated[len(prompt):].strip() if generated.startswith(prompt) else generated
    if language == "sw":
        response = translate_text(response, "sw")
    sources = [f'{doc.metadata["source"]} - Page {doc.metadata["page"]}' for doc in retrieved_docs]
    return response, "\n".join(sources)
# -----------------------------
# Gradio Interface
# -----------------------------
def query_system(user_input):
    """Run one query and format answer, sources, and disclaimer as one string."""
    # NOTE(review): "π" below appears to be a mojibake'd emoji; preserved as-is.
    answer, sources = ask_kenya_law(user_input)
    return f"{answer}\n\nπ SOURCES:\n{sources}{DISCLAIMER_TEXT}"
# NOTE(review): this import belongs at the top of the file with the others;
# also `import gradio as gr` above is unused — the UI is Streamlit, not Gradio.
import streamlit as st
# --- Streamlit UI (runs top-to-bottom on every interaction) ---
# NOTE(review): "π"/"π°πͺ"/"β οΈ" in the strings below look like mojibake'd
# emoji from a bad encoding round-trip — confirm intended characters.
st.set_page_config(page_title="Kenya Legal Assistant")
st.title("π°πͺ Kenya Legal Assistant")
st.write(
"Ask questions about Kenyan court judgments in English or Swahili."
)
user_input = st.text_input("Enter your legal question:")
if st.button("Ask"):
# Ignore clicks with an empty/whitespace-only question.
if user_input.strip():
with st.spinner("Analyzing legal documents..."):
answer, sources = ask_kenya_law(user_input)
st.markdown("### π Answer")
st.write(answer)
st.markdown("### π Sources")
st.text(sources)
st.warning("""
β οΈ DISCLAIMER:
This AI assistant provides legal information derived from publicly
available Kenyan court judgments for educational purposes only.
It does NOT provide legal advice.
For professional legal assistance, consult a qualified advocate.
""")