import os

import torch
import gradio as gr
from datasets import load_dataset
from langdetect import detect
from deep_translator import GoogleTranslator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# -----------------------------
# Load PDFs from Hugging Face dataset
# -----------------------------
dataset = load_dataset("Brian269/Kenyan_Judgements", split="train")  # Replace with your dataset

documents = []
for item in dataset:
    pdf_text = item["text"]  # Assuming your dataset has a "text" field
    doc = Document(page_content=pdf_text, metadata={"source": item["file_name"], "page": 1})
    documents.append(doc)

# -----------------------------
# Split text into chunks
# -----------------------------
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=200)
chunks = []
for doc in documents:
    for chunk in text_splitter.split_text(doc.page_content):
        chunks.append(Document(page_content=chunk, metadata=doc.metadata))

# -----------------------------
# Embeddings + FAISS index
# -----------------------------
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
vectorstore = FAISS.from_documents(chunks, embedding_model)

# -----------------------------
# Load LLM
# -----------------------------
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,   # temperature only takes effect when sampling is enabled
    temperature=0.2,
)

# -----------------------------
# Helpers for multilingual queries
# -----------------------------
def detect_language(query):
    """Detect the query language; default to English if detection fails."""
    try:
        return detect(query)
    except Exception:
        return "en"


def translate_text(text, target_lang):
    """Translate text to Swahili or English; return it unchanged for other targets."""
    if target_lang == "sw":
        return GoogleTranslator(source="auto", target="sw").translate(text)
    elif target_lang == "en":
        return GoogleTranslator(source="auto", target="en").translate(text)
    return text


# -----------------------------
# Build prompts
# -----------------------------
DISCLAIMER_TEXT = """
⚠️ DISCLAIMER: This AI assistant provides legal information derived from publicly
available Kenyan court judgments for educational purposes only. It does NOT provide
legal advice. For professional legal assistance, consult a qualified advocate.
"""


def build_prompt(question, context):
    """Assemble the instruction, retrieved context, and user question into one prompt."""
    instruction = """
You are a Kenyan legal assistant. Answer concisely using ONLY the provided context.
Include proper case citation (case name and page). Do not fabricate information.
"""
    return f"{instruction}\n\nContext:\n{context}\n\nQuestion:\n{question}\n\nProvide a clear structured answer."
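
# -----------------------------
# Optional: keep prompts within the model's context window
# -----------------------------
# NOTE: This helper is a sketch and is not wired into the original flow. TinyLlama's
# context window is 2048 tokens, so k=4 chunks of ~1200 characters plus the instruction
# can overflow it. One option is to trim the retrieved context with the model tokenizer
# before calling build_prompt; the 1500-token budget below is an assumption, not a
# measured value.
def truncate_context(context, max_tokens=1500):
    """Truncate the retrieved context to roughly max_tokens tokens using the model tokenizer."""
    token_ids = tokenizer.encode(context, add_special_tokens=False)
    if len(token_ids) <= max_tokens:
        return context
    return tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)
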
# -----------------------------
# Query system
# -----------------------------
def ask_kenya_law(question, k=4):
    """Retrieve relevant judgment chunks and generate an answer, translating Swahili queries as needed."""
    language = detect_language(question)
    translated_question = translate_text(question, "en") if language == "sw" else question

    retrieved_docs = vectorstore.similarity_search(translated_question, k=k)
    context = "\n\n".join([doc.page_content for doc in retrieved_docs])

    prompt = build_prompt(translated_question, context)
    # return_full_text=False keeps the echoed prompt out of the answer
    response = pipe(prompt, return_full_text=False)[0]["generated_text"]

    if language == "sw":
        response = translate_text(response, "sw")

    sources = [f'{doc.metadata["source"]} - Page {doc.metadata["page"]}' for doc in retrieved_docs]
    return response, "\n".join(sources)


# -----------------------------
# Gradio Interface
# -----------------------------
def query_system(user_input):
    """Gradio handler: combine the model answer, its sources, and the disclaimer."""
    answer, sources = ask_kenya_law(user_input)
    return answer + "\n\n📚 SOURCES:\n" + sources + DISCLAIMER_TEXT


iface = gr.Interface(
    fn=query_system,
    inputs="text",
    outputs="text",
    title="Kenya Legal Assistant",
    description="Ask questions about Kenyan court judgments in English or Swahili.",
)

iface.launch()
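
# -----------------------------
# Usage notes (assumptions, not part of the original flow)
# -----------------------------
# Re-embedding every chunk on each startup is slow. LangChain's FAISS wrapper can
# persist the index to disk, so one option (untested here) is to save it once and
# reload it on later runs, e.g.:
#
#   vectorstore.save_local("faiss_index")
#   vectorstore = FAISS.load_local("faiss_index", embedding_model)
#
# Recent LangChain versions also require allow_dangerous_deserialization=True when
# loading a saved index. To expose the demo beyond the local machine (e.g. from Colab),
# Gradio supports iface.launch(share=True) to create a temporary public link.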