# app.py — Kenya Legal Assistant (Hugging Face Space by Brian269, rev 7643630)
import os
from datasets import load_dataset
import torch
from langdetect import detect
from deep_translator import GoogleTranslator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr
# -----------------------------
# Load PDFs from Hugging Face dataset
# -----------------------------
dataset = load_dataset("Brian269/Kenyan_Judgements", split="train")  # Replace with your dataset

# Wrap each dataset row in a LangChain Document, carrying the file name as
# provenance metadata so answers can cite their source.
# NOTE(review): assumes the dataset exposes "text" and "file_name" columns —
# confirm against the dataset card.
documents = [
    Document(
        page_content=item["text"],
        metadata={"source": item["file_name"], "page": 1},
    )
    for item in dataset
]
# -----------------------------
# Split text into chunks
# -----------------------------
# 1200-char windows with 200-char overlap: small enough for retrieval,
# overlapping enough to preserve context across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)

# Every chunk inherits its parent document's metadata so citations survive splitting.
chunks = [
    Document(page_content=piece, metadata=doc.metadata)
    for doc in documents
    for piece in text_splitter.split_text(doc.page_content)
]
# -----------------------------
# Embeddings + FAISS index
# -----------------------------
# Multilingual sentence embeddings so English and Swahili queries map into
# the same vector space as the judgment text.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
vectorstore = FAISS.from_documents(chunks, embedding_model)
# -----------------------------
# Load LLM
# -----------------------------
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # fp16 halves memory on GPU; fall back to fp32 on CPU where fp16 is slow/unsupported.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
# FIX: `do_sample=True` is required for `temperature` to take effect. Without
# it the pipeline greedy-decodes and silently ignores temperature=0.2
# (transformers emits a warning to that effect).
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.2,
)
# -----------------------------
# Helpers for multilingual queries
# -----------------------------
def detect_language(query):
    """Return a best-effort ISO 639-1 language code for *query*.

    langdetect raises on empty or undetectable input; any failure falls back
    to "en" so a bad query never crashes the request.
    """
    try:
        return detect(query)
    except Exception:
        # FIX: narrowed from a bare `except:`, which also swallowed
        # SystemExit and KeyboardInterrupt.
        return "en"
def translate_text(text, target_lang):
    """Translate *text* into Swahili or English; other targets pass through unchanged."""
    # Both supported targets build the translator identically, so one guarded
    # call covers them.
    if target_lang not in ("sw", "en"):
        return text
    return GoogleTranslator(source='auto', target=target_lang).translate(text)
# -----------------------------
# Build prompts
# -----------------------------
# User-facing legal disclaimer appended verbatim to every answer in
# query_system; the text is part of the app's output, so keep it intact.
DISCLAIMER_TEXT = """
⚠️ DISCLAIMER:
This AI assistant provides legal information derived from publicly
available Kenyan court judgments for educational purposes only.
It does NOT provide legal advice.
For professional legal assistance, consult a qualified advocate.
"""
def build_prompt(question, context):
    """Assemble the retrieval-augmented prompt handed to the LLM."""
    # Instruction text reproduced exactly as the original (leading/trailing
    # newlines included) — it is part of the runtime prompt.
    instruction = (
        "\nYou are a Kenyan legal assistant.\n"
        "Answer concisely using ONLY the provided context.\n"
        "Include proper case citation (case name and page).\n"
        "Do not fabricate information.\n"
    )
    sections = [
        instruction,
        f"Context:\n{context}",
        f"Question:\n{question}",
        "Provide a clear structured answer.",
    ]
    return "\n\n".join(sections)
# -----------------------------
# Query system
# -----------------------------
def ask_kenya_law(question, k=4):
    """Answer *question* from the indexed judgments via RAG.

    Detects the query language, translates Swahili queries to English for
    retrieval, builds a context-grounded prompt, and translates the answer
    back to Swahili when the question was Swahili.

    Returns a tuple of (answer_text, newline-joined source citations).
    """
    language = detect_language(question)
    translated_question = translate_text(question, "en") if language == "sw" else question
    retrieved_docs = vectorstore.similarity_search(translated_question, k=k)
    context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    prompt = build_prompt(translated_question, context)
    generated = pipe(prompt)[0]["generated_text"]
    # FIX: text-generation pipelines return the prompt + completion by default,
    # so the user-facing answer (and its Swahili translation) included the
    # entire prompt. Strip the echoed prompt; the startswith guard keeps this
    # safe if the pipeline is later configured with return_full_text=False.
    response = generated[len(prompt):].lstrip() if generated.startswith(prompt) else generated
    if language == "sw":
        response = translate_text(response, "sw")
    sources = [f'{doc.metadata["source"]} - Page {doc.metadata["page"]}' for doc in retrieved_docs]
    return response, "\n".join(sources)
# -----------------------------
# Gradio Interface
# -----------------------------
def query_system(user_input):
    """Gradio callback: run the RAG pipeline and format one text reply.

    The reply is the answer, the source citations, and the mandatory
    disclaimer, concatenated exactly as before.
    """
    answer, sources = ask_kenya_law(user_input)
    return f"{answer}\n\n📚 SOURCES:\n{sources}{DISCLAIMER_TEXT}"
# Plain text in / plain text out UI — suitable for hosting as a HF Space.
iface = gr.Interface(
    fn=query_system,
    inputs="text",
    outputs="text",
    title="Kenya Legal Assistant",
    description="Ask questions about Kenyan court judgments in English or Swahili.",
)

iface.launch()