# Hugging Face Spaces status (at capture time): Runtime error
| import gradio as gr | |
| import os | |
| import datetime | |
| import numpy as np | |
| import pandas as pd | |
| import json | |
| import faiss | |
| from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import login | |
# --- Load Hugging Face API Key ---
# The token is read from the `API_KEY` environment variable (e.g. a Space
# secret). The checkpoints loaded below are public Hub models, so a missing
# token must not abort start-up: log in only when one is configured and
# otherwise continue with anonymous Hub access.
# NOTE(review): if a gated/private checkpoint is added later, downloads will
# fail with an authentication error when API_KEY is unset.
hf_token = os.environ.get("API_KEY")
if hf_token:
    login(token=hf_token)
else:
    print(
        "Warning: Hugging Face API key not found in 'API_KEY'; "
        "continuing with anonymous access to public models."
    )
# --- Model Setup ---
# Hub checkpoint identifiers for the three models the app relies on.
_QA_CHECKPOINT = "deepset/roberta-base-squad2"
_SUMMARIZER_CHECKPOINT = "facebook/bart-large-cnn"
_EMBEDDING_CHECKPOINT = "sentence-transformers/paraphrase-MiniLM-L6-v2"

# Extractive QA: rag_answer() uses this to pull an answer span out of the
# retrieved context.
qa_model = AutoModelForQuestionAnswering.from_pretrained(_QA_CHECKPOINT)
qa_tokenizer = AutoTokenizer.from_pretrained(_QA_CHECKPOINT)
qa_pipeline = pipeline(
    "question-answering",
    model=qa_model,
    tokenizer=qa_tokenizer,
)

# Summarizer: rag_answer() uses this to condense overly long extracted answers.
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(_SUMMARIZER_CHECKPOINT)
summarizer_tokenizer = AutoTokenizer.from_pretrained(_SUMMARIZER_CHECKPOINT)
summarizer_pipeline = pipeline(
    "summarization",
    model=summarizer_model,
    tokenizer=summarizer_tokenizer,
)

# Sentence encoder used both to index the corpus and to embed queries.
embed_model = SentenceTransformer(_EMBEDDING_CHECKPOINT)
# --- Load Corpus ---
# The retrieval corpus is the union of two sources: the "text" column of a
# Parquet file and the "text" field of each object in a JSON file (presumably
# a list of objects — non-dict entries are ignored).
# Non-string and whitespace-only passages are skipped: a non-string "text"
# would crash `.strip()`, and blank passages would be embedded into the FAISS
# index and could be retrieved as useless context.
df_parquet = pd.read_parquet("ibtehaj dataset.parquet")
corpus_parquet = [
    text
    for text in df_parquet["text"].dropna().tolist()
    if isinstance(text, str) and text.strip()
]
with open("pdf_data.json", "r", encoding="utf-8") as f:
    json_data = json.load(f)
corpus_json = [
    entry["text"].strip()
    for entry in json_data
    if isinstance(entry, dict)
    and isinstance(entry.get("text"), str)
    and entry["text"].strip()
]
corpus = corpus_parquet + corpus_json
# --- Build FAISS Index ---
# Exact (brute-force) L2 index holding one embedding per corpus passage, added
# in corpus order, so search result id i refers to corpus[i].
if not corpus:
    # Fail with a clear message instead of an obscure shape error below.
    raise ValueError(
        "Corpus is empty: no usable 'text' entries were found in the Parquet "
        "or JSON sources, so the FAISS index cannot be built."
    )
embeddings = embed_model.encode(corpus, show_progress_bar=True, batch_size=16)
# FAISS operates on C-contiguous float32 matrices; convert explicitly so the
# dtype does not depend on encoder settings.
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
# --- Global Variables ---
# Most recent question/answer pair; generate_answer() overwrites both after
# each non-empty query. Both start out as empty strings.
last_question = last_response = ""
# --- Core Function ---
def rag_answer(question: str, k: int = 3, preview_chars: int = 400) -> str:
    """Answer *question* by retrieving passages and running extractive QA.

    The ``k`` passages nearest to the question (L2 distance between sentence
    embeddings) are joined into one context string. The QA pipeline extracts
    an answer span from that context. A span longer than 40 words or 300
    characters is condensed with the summarizer.

    Args:
        question: The user's question.
        k: Number of passages to retrieve; clamped to ``[1, index.ntotal]``.
        preview_chars: How many leading characters of the retrieved context
            to echo back after the answer.

    Returns:
        The (possibly summarized) answer followed by a ``[Context Used]``
        preview, or a message asking the user to rephrase when no context
        could be retrieved.
    """
    # FAISS pads results with id -1 when asked for more neighbours than it
    # holds; clamp k so we never over-ask, and always ask for at least one.
    k = max(1, min(k, index.ntotal))
    q_emb = np.ascontiguousarray(embed_model.encode([question]), dtype=np.float32)
    _, ids = index.search(q_emb, k)
    context = "\n\n".join(corpus[i] for i in ids[0] if 0 <= i < len(corpus))
    if not context.strip():
        return "Context is empty. Try rephrasing the question."

    result = qa_pipeline(question=question, context=context)
    # An empty extracted span is as unhelpful as a missing one.
    raw_answer = result.get("answer") or "No answer found."
    if len(raw_answer.split()) > 40 or len(raw_answer) > 300:
        summary = summarizer_pipeline(
            raw_answer, max_length=50, min_length=15, do_sample=False
        )
        answer = summary[0]["summary_text"]
    else:
        answer = raw_answer

    # Only mark the preview as truncated when it actually is.
    preview = context[:preview_chars]
    if len(context) > preview_chars:
        preview += "..."
    return f"{answer}\n\n[Context Used]:\n{preview}"
# --- Answer Handler ---
def generate_answer(user_input):
    """Gradio click handler: answer *user_input* and remember the exchange.

    Args:
        user_input: Text from the question box. Can be ``None`` when the
            textbox is empty, so it is checked before ``.strip()``.

    Returns:
        The answer produced by ``rag_answer``, or a prompt to enter a question
        when the input is missing or whitespace-only. Only answered questions
        update ``last_question`` / ``last_response``.
    """
    global last_question, last_response
    if not user_input or not user_input.strip():
        return "Please enter a question."
    answer = rag_answer(user_input)
    last_question = user_input
    last_response = answer
    return answer
# --- Save Logic ---
def save_to_file(q, a):
    """Write a question/answer pair to a timestamped text file.

    The file is named ``legal_qna_<YYYY-mm-dd_HH-MM-SS>.txt`` and is created in
    the current working directory. The timestamp has one-second resolution
    and the file is opened in ``"w"`` mode, so two saves within the same
    second overwrite each other.

    Args:
        q: Question text; ``None`` is written as an empty string.
        a: Answer text; ``None`` is written as an empty string.

    Returns:
        Two Gradio updates: the new value for the status textbox and a
        ``visible=False`` update that hides the confirmation dialog.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"legal_qna_{timestamp}.txt"
    content = f"Q: {q or ''}\n\nA: {a or ''}"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(content)
    # Use the generic gr.update: the component-specific gr.Textbox.update was
    # removed in Gradio 4.0 and made this handler crash.
    return gr.update(value=f"Saved to {filename}"), gr.update(visible=False)
def cancel_save():
    """Dismiss the save-confirmation dialog without writing anything.

    Returns:
        A Gradio update that hides the confirmation block.
    """
    hide_dialog = gr.update(visible=False)
    return hide_dialog
# --- Styling ---
# Stylesheet passed to gr.Blocks(css=...) in the UI section. Summary of the rules:
# - Colours the three panels via #left-panel / #center-panel / #right-panel
#   (the elem_id values used in the layout).
# - Forces every textarea and text input to a 120px height.
# - Rounds all buttons and tints them via :nth-child position.
# - Centres the heading and the italic footer note in the centre panel.
# NOTE(review): `.gr-block.gr-row` and `.gr-chatbot` look like Gradio 3.x class
# names — confirm they still match the DOM of the installed Gradio version.
# NOTE(review): the `textarea, input[type="text"]` and `button:nth-child(...)`
# selectors are global, so they also apply to the Status box and the Yes/No
# confirmation buttons.
custom_css = """
html, body {
height: 100%;
margin: 0;
}
.gr-block.gr-row {
height: 100vh !important;
}
#left-panel, #center-panel, #right-panel {
height: 100%;
padding: 10px;
border-radius: 10px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.05);
}
#left-panel { background-color: #e3f2fd; }
#center-panel { background-color: #e1f5fe; }
#right-panel { background-color: #e0f7fa; }
#right-panel .gr-chatbot {
max-height: 100%;
overflow-y: auto;
background-color: #ffffff;
border: 1px solid #dcdcdc;
border-radius: 10px;
padding: 8px;
}
textarea, input[type="text"] {
border-radius: 10px;
border: 1px solid #bbb;
background-color: #f0f8ff;
padding: 12px;
font-size: 16px;
color: #333;
height: 120px !important;
}
button {
border-radius: 20px !important;
color: #333;
border: none;
transition: background-color 0.3s ease;
padding: 12px 24px;
font-weight: bold;
}
button:nth-child(1) { background-color: #bbdefb !important; }
button:nth-child(2) { background-color: #c8e6c9 !important; }
button:nth-child(3) { background-color: #ffe0b2 !important; }
button:nth-child(4) { background-color: #d1c4e9 !important; }
button:nth-child(5) { background-color: #f8bbd0 !important; }
button:nth-child(6) { background-color: #b2ebf2 !important; }
button:hover { filter: brightness(95%); }
#center-panel h1 {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
color: #2c3e50;
text-align: center;
}
#center-panel i {
display: block;
margin-top: 15px;
text-align: center;
font-style: italic;
color: #777;
}
"""
# --- Gradio UI ---
# Three-column layout:
# - left: a reference image;
# - centre: the question/answer workspace, action buttons and a hidden
#   save-confirmation dialog;
# - right: a conversation-history chatbot.
# Event listeners are attached inside the Blocks context, which Gradio 4
# requires.
with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        # Left panel
        with gr.Column(scale=1, elem_id="left-panel"):
            gr.Image(label="Legal Reference Image", show_label=False)

        # Center panel
        with gr.Column(scale=2, elem_id="center-panel"):
            gr.Markdown("<h1>Legal Assistant</h1>")
            with gr.Row():
                input_box = gr.Textbox(
                    label="Input Box",
                    lines=6,
                    placeholder="Ask your legal question here...",
                )
                answer_box = gr.Textbox(label="Answer Box", lines=6, interactive=False)
            with gr.Row():
                dictate_btn = gr.Button(" Dictate ")
                generate_btn = gr.Button(" Generate Response")
                upload_btn = gr.Button(" Upload MP3")
            with gr.Row():
                read_btn = gr.Button(" Read Aloud")
                save_btn = gr.Button(" Save")
                reset_btn = gr.Button(" Reset")
            gr.Markdown("<i>Consult here for legal statements</i>")

            # Save confirmation dialog (hidden until Save is clicked).
            with gr.Column(visible=False, elem_id="confirm-save") as confirm_save_block:
                gr.Markdown("Are you sure you want to save this Q&A?")
                with gr.Row():
                    yes_btn = gr.Button("Yes, Save")
                    no_btn = gr.Button("No, Cancel")
            # Kept outside the dialog: save_to_file hides the dialog in the same
            # update that writes "Saved to ...", which would otherwise be hidden.
            status_box = gr.Textbox(label="Status", interactive=False)

        # Right panel
        with gr.Column(scale=1, elem_id="right-panel"):
            # NOTE(review): no handler writes to this chatbot yet.
            gr.Chatbot(label="Conversation History")

    # --- Button Actions ---
    # TODO: dictate_btn, upload_btn and read_btn are rendered but not wired.
    generate_btn.click(
        fn=generate_answer,
        inputs=[input_box],
        outputs=[answer_box],
    )
    save_btn.click(
        lambda: gr.update(visible=True),
        outputs=[confirm_save_block],
    )
    yes_btn.click(
        fn=save_to_file,
        inputs=[input_box, answer_box],
        outputs=[status_box, confirm_save_block],
    )
    no_btn.click(
        fn=cancel_save,
        outputs=[confirm_save_block],
    )
    # Reset previously had no handler; clear the question, answer and status
    # and close the confirmation dialog.
    reset_btn.click(
        lambda: ("", "", "", gr.update(visible=False)),
        outputs=[input_box, answer_box, status_box, confirm_save_block],
    )
# --- Launch App ---
# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or other tooling) builds `demo` without blocking
# on a running server.
if __name__ == "__main__":
    demo.launch()