# legalAssistant / app.py
# Hugging Face Space by fizzarif7 — commit 786a3da ("Update app.py")
import gradio as gr
import os
import datetime
import numpy as np
import pandas as pd
import json
import faiss
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
# --- Load Hugging Face API Key ---
# The Hub token is taken from the `API_KEY` environment variable; startup is
# aborted with a ValueError when it is missing or empty, otherwise this process
# is logged in to the Hugging Face Hub before any model is loaded below.
hf_token = os.environ.get("API_KEY")
if hf_token:
    login(token=hf_token)
else:
    raise ValueError("Hugging Face API key not found. Please set 'API_KEY' in your environment.")
# --- Model Setup ---
# Checkpoint identifiers, named once so each tokenizer always matches its model.
QA_CHECKPOINT = "deepset/roberta-base-squad2"
SUMMARIZER_CHECKPOINT = "facebook/bart-large-cnn"
EMBEDDING_CHECKPOINT = "sentence-transformers/paraphrase-MiniLM-L6-v2"

# Extractive question answering: picks an answer span out of retrieved context.
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_CHECKPOINT)
qa_tokenizer = AutoTokenizer.from_pretrained(QA_CHECKPOINT)
qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)

# Abstractive summarization: used to condense overly long answer spans.
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_CHECKPOINT)
summarizer_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_CHECKPOINT)
summarizer_pipeline = pipeline("summarization", model=summarizer_model, tokenizer=summarizer_tokenizer)

# Sentence embedder: encodes both corpus passages and incoming questions.
embed_model = SentenceTransformer(EMBEDDING_CHECKPOINT)
# --- Load Corpus ---
# Passages from two sources are merged into one flat list; a passage's position
# in `corpus` is its row id in the FAISS index built from it.
df_parquet = pd.read_parquet("ibtehaj dataset.parquet")
# Keep only real, non-blank strings and strip surrounding whitespace so the
# parquet passages are cleaned the same way as the JSON passages below.
corpus_parquet = [
    text.strip()
    for text in df_parquet["text"].dropna()
    if isinstance(text, str) and text.strip()
]
with open("pdf_data.json", "r", encoding="utf-8") as f:
    json_data = json.load(f)
# Presumably a list of {"text": ...} objects (TODO confirm). Non-dict entries,
# and entries whose "text" is missing, not a string, or blank, are skipped
# rather than raising AttributeError on `.strip()` or polluting the index.
corpus_json = [
    entry["text"].strip()
    for entry in json_data
    if isinstance(entry, dict)
    and isinstance(entry.get("text"), str)
    and entry["text"].strip()
]
corpus = corpus_parquet + corpus_json
# --- Build FAISS Index ---
# Fail fast with a clear message: with no passages there is no embedding matrix
# to size the index from, and `embeddings.shape[1]` would fail cryptically.
if not corpus:
    raise ValueError("Corpus is empty: no usable 'text' entries were found in the data files.")
embeddings = embed_model.encode(corpus, show_progress_bar=True, batch_size=16)
# FAISS expects a C-contiguous float32 matrix of shape (n_passages, dim);
# this converts only if needed instead of always copying.
embeddings = np.ascontiguousarray(embeddings, dtype="float32")
# Exact (brute-force) L2 search; index row i corresponds to corpus[i].
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
# --- Global Variables ---
# Most recent question and answer, overwritten by generate_answer() after each
# answered query. Both start out as empty strings.
last_question = last_response = ""
# --- Core Function ---
def rag_answer(question: str, k: int = 3) -> str:
    """Answer ``question`` with retrieval-augmented extractive QA.

    The question is embedded with ``embed_model``, and the ``k`` nearest
    passages are fetched from the FAISS ``index`` and joined into one context.
    ``qa_pipeline`` then extracts an answer span from that context, and long
    spans are condensed with ``summarizer_pipeline``.

    Args:
        question: The user's question.
        k: Number of passages to retrieve; values below 1 are treated as 1.

    Returns:
        The (possibly summarized) answer followed by a "[Context Used]"
        preview of up to 400 characters of the retrieved context. When no
        context could be retrieved, a message asking the user to rephrase.
    """
    # FAISS needs a float32 (n_queries, dim) query matrix and a positive k.
    q_emb = np.asarray(embed_model.encode([question]), dtype="float32")
    _, ids = index.search(q_emb, max(1, int(k)))
    # FAISS pads missing neighbours with -1, so keep only valid corpus rows.
    context = "\n\n".join(corpus[i] for i in ids[0] if 0 <= i < len(corpus))
    if not context.strip():
        return "Context is empty. Try rephrasing the question."
    result = qa_pipeline(question=question, context=context)
    # An empty span is as unhelpful as a missing one: use the same fallback.
    raw_answer = result.get("answer") or "No answer found."
    # Condense unusually long extracted spans (> 40 words or > 300 chars).
    if len(raw_answer.split()) > 40 or len(raw_answer) > 300:
        summary = summarizer_pipeline(raw_answer, max_length=50, min_length=15, do_sample=False)
        answer = summary[0]["summary_text"]
    else:
        answer = raw_answer
    # Only add an ellipsis when the preview actually truncates the context.
    preview = context[:400] + ("..." if len(context) > 400 else "")
    return f"{answer}\n\n[Context Used]:\n{preview}"
# --- Answer Handler ---
def generate_answer(user_input):
    """Handle the "Generate Response" button.

    Runs ``rag_answer`` on the question and records the pair in the
    module-level ``last_question`` / ``last_response``.

    Args:
        user_input: Text from the input box; ``None`` is treated as empty.

    Returns:
        The answer text, or a prompt to enter a question when the input is blank.
    """
    global last_question, last_response
    # The textbox value may arrive as None; guard before calling .strip().
    if not user_input or not user_input.strip():
        return "Please enter a question."
    answer = rag_answer(user_input)
    last_question = user_input
    last_response = answer
    return answer
# --- Save Logic ---
def save_to_file(q, a):
    """Write a question/answer pair to a timestamped text file in the working directory.

    Args:
        q: Question text, written after "Q: " (``None`` is written as empty).
        a: Answer text, written after "A: " (``None`` is written as empty).

    Returns:
        Two Gradio updates: the status-box message (success or failure) and
        one that hides the save-confirmation dialog.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"legal_qna_{timestamp}.txt"
    content = f"Q: {q or ''}\n\nA: {a or ''}"
    try:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(content)
    except OSError as err:
        # Report the failure in the status box instead of crashing the handler.
        status = f"Failed to save: {err}"
    else:
        status = f"Saved to {filename}"
    # `gr.update` works across Gradio versions; the component-specific
    # `gr.Textbox.update` was removed in Gradio 4.
    return gr.update(value=status), gr.update(visible=False)
def cancel_save():
    """Close the save-confirmation dialog without writing anything.

    Returns:
        A Gradio update that hides the dialog component.
    """
    hide_dialog = gr.update(visible=False)
    return hide_dialog
# --- Styling ---
# Custom stylesheet passed to gr.Blocks(css=...) when the UI is built. The
# #left-panel / #center-panel / #right-panel rules target the elem_id values
# given to the three layout columns.
# NOTE(review): `.gr-block.gr-row` and `.gr-chatbot` look like Gradio 3.x class
# names and may not match anything in newer Gradio releases — confirm against
# the installed version.
# NOTE(review): the `button:nth-child(N)` rules color buttons by their position
# among sibling elements, so they depend on Gradio's rendered DOM — verify.
custom_css = """
html, body {
height: 100%;
margin: 0;
}
.gr-block.gr-row {
height: 100vh !important;
}
#left-panel, #center-panel, #right-panel {
height: 100%;
padding: 10px;
border-radius: 10px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.05);
}
#left-panel { background-color: #e3f2fd; }
#center-panel { background-color: #e1f5fe; }
#right-panel { background-color: #e0f7fa; }
#right-panel .gr-chatbot {
max-height: 100%;
overflow-y: auto;
background-color: #ffffff;
border: 1px solid #dcdcdc;
border-radius: 10px;
padding: 8px;
}
textarea, input[type="text"] {
border-radius: 10px;
border: 1px solid #bbb;
background-color: #f0f8ff;
padding: 12px;
font-size: 16px;
color: #333;
height: 120px !important;
}
button {
border-radius: 20px !important;
color: #333;
border: none;
transition: background-color 0.3s ease;
padding: 12px 24px;
font-weight: bold;
}
button:nth-child(1) { background-color: #bbdefb !important; }
button:nth-child(2) { background-color: #c8e6c9 !important; }
button:nth-child(3) { background-color: #ffe0b2 !important; }
button:nth-child(4) { background-color: #d1c4e9 !important; }
button:nth-child(5) { background-color: #f8bbd0 !important; }
button:nth-child(6) { background-color: #b2ebf2 !important; }
button:hover { filter: brightness(95%); }
#center-panel h1 {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
color: #2c3e50;
text-align: center;
}
#center-panel i {
display: block;
margin-top: 15px;
text-align: center;
font-style: italic;
color: #777;
}
"""
# --- Gradio UI ---
# Three-column layout: reference image (left), Q&A controls (center) and
# conversation history (right). Event listeners are registered inside the
# Blocks context, as Gradio requires.
with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        # Left panel
        with gr.Column(scale=1, elem_id="left-panel"):
            gr.Image(label="Legal Reference Image", show_label=False)
        # Center panel
        with gr.Column(scale=2, elem_id="center-panel"):
            gr.Markdown("<h1>Legal Assistant</h1>")
            with gr.Row():
                input_box = gr.Textbox(label="Input Box", lines=6, placeholder="Ask your legal question here...")
                answer_box = gr.Textbox(label="Answer Box", lines=6, interactive=False)
            with gr.Row():
                # TODO: dictate_btn and upload_btn have no handlers yet.
                dictate_btn = gr.Button(" Dictate ")
                generate_btn = gr.Button(" Generate Response")
                upload_btn = gr.Button(" Upload MP3")
            with gr.Row():
                # TODO: read_btn has no handler yet.
                read_btn = gr.Button(" Read Aloud")
                save_btn = gr.Button(" Save")
                reset_btn = gr.Button(" Reset")
            gr.Markdown("<i>Consult here for legal statements</i>")
            # Save confirmation dialog, hidden until "Save" is clicked.
            with gr.Column(visible=False, elem_id="confirm-save") as confirm_save_block:
                gr.Markdown("Are you sure you want to save this Q&A?")
                with gr.Row():
                    yes_btn = gr.Button("Yes, Save")
                    no_btn = gr.Button("No, Cancel")
            # Kept outside the dialog: save_to_file hides the dialog in the same
            # update that writes the status message.
            status_box = gr.Textbox(label="Status", interactive=False)
        # Right panel
        with gr.Column(scale=1, elem_id="right-panel"):
            # TODO: no handler populates this history yet.
            gr.Chatbot(label="Conversation History")

    # --- Button Actions ---
    generate_btn.click(
        fn=generate_answer,
        inputs=[input_box],
        outputs=[answer_box],
    )
    save_btn.click(
        lambda: gr.update(visible=True),
        outputs=[confirm_save_block],
    )
    yes_btn.click(
        fn=save_to_file,
        inputs=[input_box, answer_box],
        outputs=[status_box, confirm_save_block],
    )
    no_btn.click(
        fn=cancel_save,
        outputs=[confirm_save_block],
    )
    # Reset was previously unwired: clear the text fields and close the dialog.
    reset_btn.click(
        fn=lambda: ("", "", "", gr.update(visible=False)),
        outputs=[input_box, answer_box, status_box, confirm_save_block],
    )
# --- Launch App ---
# Guarded so that importing this module (e.g. from tests or tooling) does not
# start a server; running `python app.py` still launches the app.
if __name__ == "__main__":
    demo.launch()