# MajorProjectRAG / app.py
# Prasanga73's picture
# Update app.py
# cc90595 verified
import gradio as gr
from huggingface_hub import InferenceClient
import os
from src.data_processor import LegalDocProcessor
from src.hybrid_retriever import HybridRetriever
# --- Configuration & Initialization ---
# Raw JSON inputs; they are only read when the search index has to be built.
PARENT_DATA = "data/parent_docs.json"
CHILD_DATA = "data/child_docs.json"
# Directory where the hybrid retriever's index is saved and later reloaded.
INDEX_DIR = "index_storage"
def initialize_retriever():
    """Load the persisted hybrid index, or build (and save) it from the JSON data.

    Returns:
        A ``HybridRetriever`` instance, or ``None`` when neither loading nor
        building succeeded. Callers treat ``None`` as "no retrieval available".

    Behaviour notes:
        * An empty ``INDEX_DIR`` is treated like a missing one, because it
          cannot contain a saved index.
        * If loading an existing index raises, the index is rebuilt from
          ``PARENT_DATA`` / ``CHILD_DATA``. Without this, retrieval would stay
          disabled for the whole process lifetime.
        * If saving a freshly built index fails, the in-memory retriever is
          still returned; only the next start-up loses the cached index.
    """
    try:
        if os.path.isdir(INDEX_DIR) and os.listdir(INDEX_DIR):
            print("[*] Loading existing index...")
            return HybridRetriever(index_dir=INDEX_DIR)
    except Exception as e:
        # Fall through to a rebuild instead of giving up on retrieval.
        print(f"Error loading index from {INDEX_DIR}: {e}. Rebuilding...")

    try:
        print("[*] Building new index...")
        processor = LegalDocProcessor(PARENT_DATA, CHILD_DATA)
        docs = processor.load_and_clean()
        if not docs:
            print(f"[!] No documents loaded from {PARENT_DATA} / {CHILD_DATA}; retrieval disabled.")
            return None
        ret = HybridRetriever(documents=docs, index_dir=INDEX_DIR)
    except Exception as e:
        print(f"Error initializing retriever: {e}")
        return None

    try:
        ret.save_index()
    except Exception as e:
        # Persisting only speeds up the next start-up (the load branch above);
        # the index just built in memory is still usable now.
        print(f"Warning: could not save index to {INDEX_DIR}: {e}")
    return ret


# Global retriever instance, built once at import time; None if setup failed.
retriever = initialize_retriever()
_CONTEXT_HEADER = "\n\nRELEVANT NEPALESE LAW CONTEXT:\n"


def _build_legal_context(message):
    """Return the retrieved-law section appended to the system prompt.

    Runs a top-3 hybrid search for *message* against the module-level
    ``retriever`` and formats each hit with its source, clause id and text.
    Never raises: retrieval or formatting errors are printed and replaced by
    a short notice, so the chat can continue without grounding.
    """
    if not retriever:
        # Say so explicitly; otherwise the prompt points the model at
        # "context provided below" that does not exist.
        return "\n(Legal database unavailable; no law context was retrieved.)"
    try:
        search_results = retriever.hybrid_search(message, top_k=3)
        if not search_results:
            return _CONTEXT_HEADER + "No specific legal clauses found for this query."
        parts = [_CONTEXT_HEADER]
        for res in search_results:
            src = res.get('legal_document_source', 'Unknown')
            cid = res.get('parent_clause_id', 'N/A')
            txt = res.get('parent_clause_text', 'Text not found')
            parts.append(f"--- Source: {src} ---\nClause: {cid}\nText: {txt}\n\n")
        return "".join(parts)
    except Exception as e:
        print(f"Retrieval Error: {e}")
        return "\n(Error retrieving specific law context.)"


def _resolve_hf_token(hf_token):
    """Return a Hugging Face API token, or None if none is available.

    Prefers the OAuth token from the sidebar login. Falls back to the
    ``HF_TOKEN`` environment variable when there is no OAuth token or it is
    empty. Surrounding whitespace is stripped.
    """
    raw_token = getattr(hf_token, "token", None) if hf_token else None
    if not raw_token:
        raw_token = os.getenv("HF_TOKEN", "")
    raw_token = raw_token.strip() if raw_token else ""
    return raw_token or None


def _history_to_messages(history):
    """Convert Gradio chat history into chat-completion message dicts.

    Accepts both Gradio history formats:
      * "messages" format: ``{"role": ..., "content": ...}`` dicts;
      * legacy "tuples" format: ``[user_msg, bot_msg]`` pairs.

    Entries that are malformed, or whose content is not a plain string, are
    skipped rather than aborting the request. This app only sends text to
    the model.
    """
    messages = []
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
            content = item.get("content")
            if role and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(item, (list, tuple)):
            user_msg = item[0] if len(item) > 0 else None
            bot_msg = item[1] if len(item) > 1 else None
            if isinstance(user_msg, str) and user_msg:
                messages.append({"role": "user", "content": user_msg})
            if isinstance(bot_msg, str) and bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
    return messages


def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken | None = None,
):
    """Stream an answer to *message*, grounded in retrieved Nepalese law.

    Generator used as the ``gr.ChatInterface`` callback. Each yielded string
    is the full response so far. Steps:
      1. Resolve an HF token (OAuth login or ``HF_TOKEN`` env var).
      2. Retrieve relevant clauses and append them to the system prompt.
      3. Convert the chat history, then stream a completion from
         ``meta-llama/Llama-3.1-70B-Instruct`` via ``InferenceClient``.

    Errors are reported as chat text rather than raised. If generation fails
    mid-stream, the text streamed so far is kept and the error is appended.
    """
    # Check credentials first so no retrieval work is wasted when the model
    # cannot be called anyway.
    token = _resolve_hf_token(hf_token)
    if not token:
        yield "⚠️ Error: Please sign in with Hugging Face (see sidebar) or set HF_TOKEN secret."
        return

    augmented_system_message = (
        f"{system_message}\n\n"
        "You are a legal assistant specializing in Nepalese Law. "
        "Use the legal context provided below to answer. Cite the Source and Clause ID.\n"
        f"{_build_legal_context(message)}"
    )
    messages = [{"role": "system", "content": augmented_system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    client = InferenceClient(token=token, model="meta-llama/Llama-3.1-70B-Instruct")

    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Guard against stream chunks that carry no choices; indexing
            # [0] on them would abort the whole answer.
            if not chunk.choices:
                continue
            token_text = chunk.choices[0].delta.content
            if token_text:
                response += token_text
                yield response
    except Exception as e:
        error_text = f"AI Error: {str(e)}"
        # Keep what the user has already seen instead of overwriting it.
        yield f"{response}\n\n{error_text}" if response else error_text
        return

    if not response:
        yield "⚠️ The model returned an empty response. Please try again."
# --- Gradio UI Setup ---
# Chat widget wired to `respond`. The extra inputs are passed to it, in order,
# as system_message, max_tokens, temperature and top_p.
# NOTE: `type="messages"` is intentionally not passed; it raised a TypeError
# on the deployed Gradio version, and `respond` handles either history format.
chatbot = gr.ChatInterface(
    fn=respond,
    title="Nepal Law Search AI",
    description="Ask questions about Nepalese Constitution and Acts.",
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="You are a helpful Nepalese Legal Advisor.",
        ),
        gr.Slider(label="Max tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
    ],
    examples=[
        ["What are the punishments for cybercrime?"],
        ["What does the constitution say about the right to equality?"],
        ["Is witchcraft accusation a crime in Nepal?"],
    ],
    cache_examples=False,
)
# Page layout: a sidebar with HF login and database status, next to the chat.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.Markdown("### Authentication")
        gr.LoginButton()
        gr.Markdown("---")
        # Reflect whether the retriever actually initialised (it is None on
        # failure) instead of always claiming the database is ready.
        if retriever:
            gr.Markdown("**Status:** Database Ready ✅")
        else:
            gr.Markdown("**Status:** Database unavailable ⚠️ (answers will not include retrieved law)")
    chatbot.render()
if __name__ == "__main__":
    # show_error reports exceptions in the UI; SSR stays off to avoid the
    # asyncio noise observed under Python 3.13.
    demo.launch(show_error=True, ssr_mode=False)