| import streamlit as st |
| import shelve |
| import docx2txt |
| import PyPDF2 |
| import time |
| import nltk |
|
|
| import re |
| import os |
| import requests |
| from dotenv import load_dotenv |
|
|
|
|
| import torch |
| from sentence_transformers import SentenceTransformer, util |
| from transformers import pipeline |
| import nltk |
|
|
| nltk.download('punkt') |
| import hashlib |
| from nltk import sent_tokenize |
| nltk.download('punkt_tab') |
| from transformers import LEDTokenizer, LEDForConditionalGeneration |
| import torch |
|
|
# Page-level configuration and the main heading of the app.
st.set_page_config(
    layout="wide",
    page_title="Legal Document Summarizer",
)
st.title("π Legal Document Summarizer (zero shot)")

# Avatars handed to st.chat_message for the two conversation roles.
# NOTE(review): both literals look mis-encoded and are currently identical;
# confirm the intended emoji against the original file.
USER_AVATAR, BOT_AVATAR = "π€", "π€"
|
|
| |
def load_chat_history():
    """Return the message list stored in the ``chat_history`` shelf.

    Returns:
        The value saved under the ``"messages"`` key, or an empty list when
        that key is not present in the shelf.
    """
    with shelve.open("chat_history") as store:
        if "messages" in store:
            return store["messages"]
        return []
|
|
| |
def save_chat_history(messages):
    """Store ``messages`` under the ``"messages"`` key of the ``chat_history`` shelf.

    Args:
        messages: The conversation to persist; it replaces whatever was
            stored under that key before.
    """
    with shelve.open("chat_history") as store:
        store.update({"messages": messages})
|
|
| |
def limit_text(text, word_limit=500):
    """Cut ``text`` down to its first ``word_limit`` whitespace-separated words.

    Args:
        text: The string to shorten.
        word_limit: Maximum number of words to keep.

    Returns:
        The kept words joined by single spaces, followed by ``"..."`` when
        the input contained more than ``word_limit`` words.
    """
    tokens = text.split()
    clipped = " ".join(tokens[:word_limit])
    if len(tokens) > word_limit:
        clipped += "..."
    return clipped
|
|
|
|
| |
|
|
|
|
def clean_text(text):
    """Strip layout artifacts from extracted text and normalise whitespace.

    Removes, in order: "Page N of M" markers (case-insensitive), runs of five
    or more underscores, runs of five or more hyphens, and runs of four or
    more dots. Whitespace is collapsed *after* those removals so that
    deleting an artifact from the middle of a sentence does not leave a
    double space behind.

    Args:
        text: Raw text, possibly containing line breaks and filler rules.

    Returns:
        The cleaned text with every whitespace run (including line breaks)
        reduced to one space and no leading or trailing whitespace.
    """
    artifact_patterns = (
        (r'Page\s+\d+\s+of\s+\d+', re.IGNORECASE),  # page footers
        (r'[_]{5,}', 0),                            # underscore rules
        (r'[-]{5,}', 0),                            # hyphen rules
        (r'[.]{4,}', 0),                            # dot leaders
    )
    for pattern, flags in artifact_patterns:
        text = re.sub(pattern, '', text, flags=flags)

    # Collapse last: this also covers the gaps opened by the removals above.
    return re.sub(r'\s+', ' ', text).strip()
|
|
|
|
| |
|
|
|
|
| |
|
|
| |
# Run python-dotenv's loader, then read the Hugging Face token from the
# process environment (None when the variable is not set).
# NOTE(review): HF_API_TOKEN is not referenced anywhere else in the visible
# part of this file.
load_dotenv()
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
|
|
| |
| |
|
|
|
|
| |
| |
| |
|
|
|
|
|
|
| |
@st.cache_resource
def load_local_zero_shot_classifier():
    """Build the zero-shot classification pipeline used to label text chunks.

    Wrapped in ``st.cache_resource`` so Streamlit can hand back the same
    pipeline object instead of constructing it again.

    Returns:
        A transformers ``zero-shot-classification`` pipeline backed by the
        ``typeform/distilbert-base-uncased-mnli`` checkpoint.
    """
    zero_shot = pipeline(
        "zero-shot-classification",
        model="typeform/distilbert-base-uncased-mnli",
    )
    return zero_shot


# Module-level classifier instance shared by classify_chunk.
local_classifier = load_local_zero_shot_classifier()
|
|
|
|
# Candidate labels offered to the zero-shot classifier for every chunk.
SECTION_LABELS = ["Facts", "Arguments", "Judgment", "Other"]


def classify_chunk(text):
    """Pick a section label for ``text`` using the zero-shot classifier.

    Args:
        text: The chunk of document text to classify.

    Returns:
        The first entry of the classifier's ``"labels"`` list (presumably
        the best-scoring label; confirm against the transformers pipeline
        documentation).
    """
    ranked_labels = local_classifier(text, candidate_labels=SECTION_LABELS)["labels"]
    return ranked_labels[0]
|
|
|
|
| |
def section_by_zero_shot(text):
    """Split ``text`` into labelled sections by classifying 3-sentence chunks.

    The text is sentence-tokenised and walked in groups of three sentences
    (the final group may be shorter). Each group is classified with
    ``classify_chunk`` and appended, followed by a newline, to the section
    named by the predicted label. Labels that do not match a known section
    after ``str.capitalize`` fall back to ``"Other"``. A debug line is
    printed for every chunk.

    Args:
        text: The document text to partition.

    Returns:
        A dict with the keys ``"Facts"``, ``"Arguments"``, ``"Judgment"`` and
        ``"Other"`` mapping to the concatenated chunks of each section
        (an empty string for sections that received nothing).
    """
    collected = {name: [] for name in ("Facts", "Arguments", "Judgment", "Other")}
    sentences = sent_tokenize(text)

    for start in range(0, len(sentences), 3):
        # Every sentence keeps a trailing space, exactly as it is stored.
        chunk = "".join(f"{sentence} " for sentence in sentences[start:start + 3])
        label = classify_chunk(chunk.strip())
        print(f"π Chunk: {chunk[:60]}...\nπ Predicted Label: {label}")

        bucket = label.capitalize()
        if bucket not in collected:
            bucket = "Other"
        collected[bucket].append(chunk + "\n")

    return {name: "".join(pieces) for name, pieces in collected.items()}
|
|
|
|
|
|
|
|
| |
|
|
|
|
|
|
| |
|
|
| |
def extract_text(file):
    """Extract plain text from an uploaded PDF, DOCX or TXT file.

    The file type is chosen from the extension of ``file.name``, compared
    case-insensitively so that names such as ``REPORT.PDF`` are accepted.

    Args:
        file: A binary file-like object exposing ``name`` and ``read()``.

    Returns:
        The extracted text, or the literal string ``"Unsupported file type."``
        when the extension is not one of ``.pdf``, ``.docx`` or ``.txt``.
    """
    name = file.name.lower()

    if name.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        # Pages whose extraction yields nothing contribute an empty string.
        return "\n".join(page.extract_text() or "" for page in reader.pages)

    if name.endswith(".docx"):
        return docx2txt.process(file)

    if name.endswith(".txt"):
        # Undecodable bytes become U+FFFD instead of aborting the upload
        # with a UnicodeDecodeError.
        return file.read().decode("utf-8", errors="replace")

    return "Unsupported file type."
|
|
|
|
| |
|
|
| |
|
|
|
|
@st.cache_resource
def load_legalbert():
    """Create the sentence-embedding model used for extractive summaries.

    Wrapped in ``st.cache_resource`` so Streamlit can hand back the same
    model object instead of constructing it again.

    Returns:
        A ``SentenceTransformer`` built from the
        ``nlpaueb/legal-bert-base-uncased`` checkpoint.
    """
    checkpoint = "nlpaueb/legal-bert-base-uncased"
    return SentenceTransformer(checkpoint)


# Module-level embedding model shared by legalbert_extractive_summary.
legalbert_model = load_legalbert()
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
@st.cache_resource
def load_fast_bart():
    """Create the abstractive summarization pipeline.

    The pipeline is placed on CUDA device 0 when ``torch.cuda.is_available()``
    reports a GPU and on the CPU (``device=-1``) otherwise. Wrapped in
    ``st.cache_resource`` so Streamlit can hand back the same pipeline object
    instead of constructing it again.

    Returns:
        A transformers ``summarization`` pipeline backed by the
        ``sshleifer/distilbart-cnn-12-6`` checkpoint.
    """
    if torch.cuda.is_available():
        device = 0
    else:
        device = -1
    return pipeline(
        "summarization",
        model="sshleifer/distilbart-cnn-12-6",
        device=device,
    )


# Module-level summarizer shared by bart_abstractive_summary_chunked.
bart_summarizer = load_fast_bart()
|
|
def legalbert_extractive_summary(text, top_ratio=0.2):
    """Select the sentences most similar to the document's mean embedding.

    Args:
        text: The text to summarise.
        top_ratio: Fraction of sentences to keep; at least three sentences
            are always requested.

    Returns:
        ``text`` unchanged when it has no more sentences than would be kept;
        otherwise the selected sentences, in their original order, joined by
        single spaces.
    """
    sentences = sent_tokenize(text)
    total = len(sentences)
    keep = max(3, int(total * top_ratio))

    # Nothing to prune: the whole text already fits the budget.
    if total <= keep:
        return text

    embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
    centroid = embeddings.mean(dim=0)
    similarity = util.pytorch_cos_sim(centroid, embeddings)[0]

    # Re-sort the winning indices so sentences keep document order.
    chosen = torch.topk(similarity, k=keep).indices.tolist()
    chosen.sort()
    return " ".join(sentences[idx] for idx in chosen)
|
|
|
|
|
|
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
def bart_abstractive_summary_chunked(text, max_chunk_words=700, max_length=256, min_length=60):
    """Summarise ``text`` chunk by chunk with the abstractive summarizer.

    The text is split on whitespace and processed in consecutive windows of
    ``max_chunk_words`` words; the per-chunk summaries are joined with single
    spaces.

    Args:
        text: The text to summarise.
        max_chunk_words: Number of words sent to the summarizer per call.
        max_length: ``max_length`` forwarded to the summarizer.
        min_length: ``min_length`` forwarded to the summarizer.

    Returns:
        The concatenated chunk summaries, or an empty string when ``text``
        contains no words.
    """
    words = text.split()
    summaries = []

    for start in range(0, len(words), max_chunk_words):
        chunk = " ".join(words[start:start + max_chunk_words])
        # A word budget does not bound the token count, so ask the pipeline
        # to truncate over-long chunks rather than forward them as-is to a
        # model with a fixed maximum input length.
        result = bart_summarizer(
            chunk,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            truncation=True,
        )
        summaries.append(result[0]['summary_text'])

    return " ".join(summaries)
|
|
|
|
|
|
def hybrid_summary_by_section(text, top_ratio=0.8):
    """Produce a per-section extractive + abstractive summary as markdown.

    The text is cleaned, split into sections by ``section_by_zero_shot``, and
    every non-empty section is summarised twice: extractively with
    ``legalbert_extractive_summary`` and then abstractively (on the
    extractive result) with ``bart_abstractive_summary_chunked``.

    Args:
        text: The raw document or user text.
        top_ratio: Fraction of sentences kept by the extractive step.

    Returns:
        A markdown string with one ``###`` block per non-empty section,
        blocks separated by blank lines; an empty string when no section has
        content. (Previously the assembled blocks were discarded and the raw
        section dict was returned, which callers then rendered verbatim.)
    """
    sections = section_by_zero_shot(clean_text(text))

    summary_parts = []
    for name, content in sections.items():
        if not content.strip():
            continue

        # Honour the caller's ratio instead of a hard-coded 0.8.
        extractive = legalbert_extractive_summary(content, top_ratio)
        abstractive = bart_abstractive_summary_chunked(extractive)

        hybrid = f"π **Extractive Summary:**\n{extractive}\n\nπ **Abstractive Summary:**\n{abstractive}"
        summary_parts.append(f"### π {name} Section:\n{clean_text(hybrid)}")

    return "\n\n".join(summary_parts)
|
|
|
|
| |
|
|
|
|
| |
|
|
| |
# Seed the per-session state the first time each key is needed.
if "messages" not in st.session_state:
    st.session_state["messages"] = load_chat_history()

if "last_uploaded" not in st.session_state:
    st.session_state["last_uploaded"] = None

# Sidebar with a single action that wipes the conversation, both in the
# session and in the persisted shelf.
with st.sidebar:
    st.subheader("βοΈ Options")
    wipe_requested = st.button("Delete Chat History")
    if wipe_requested:
        st.session_state["messages"] = []
        st.session_state["last_uploaded"] = None
        save_chat_history([])
|
|
| |
def display_with_typing_effect(text, speed=0.005):
    """Render ``text`` into a placeholder one character at a time.

    Args:
        text: The markdown text to reveal.
        speed: Seconds to sleep after each character is shown.

    Returns:
        The full string that ended up being displayed.
    """
    slot = st.empty()
    shown = []
    for symbol in text:
        shown.append(symbol)
        # Re-render the growing prefix so the text appears to be typed.
        slot.markdown("".join(shown))
        time.sleep(speed)
    return "".join(shown)
|
|
| |
# Render every stored message with the avatar that matches its role.
for entry in st.session_state.messages:
    speaker = entry["role"]
    icon = BOT_AVATAR if speaker != "user" else USER_AVATAR
    with st.chat_message(speaker, avatar=icon):
        st.markdown(entry["content"])

# Free-text input; its value is handled at the bottom of the script.
prompt = st.chat_input("Type a message...")

# Upload widgets: the file picker and a button to force a re-run of the
# summarization on the currently selected file.
with st.container():
    st.subheader("π Upload a Legal Document")
    uploaded_file = st.file_uploader(
        "Upload a file (PDF, DOCX, TXT)",
        type=["pdf", "docx", "txt"],
    )
    reprocess_btn = st.button("π Reprocess Last Uploaded File")
|
|
|
|
| |
def get_file_hash(file):
    """Return the MD5 hex digest of a file-like object's full content.

    The stream is rewound before reading and again afterwards, so the caller
    gets the object back positioned at offset 0.

    Args:
        file: A seekable binary file-like object.

    Returns:
        The hexadecimal MD5 digest of the content.
    """
    file.seek(0)
    digest = hashlib.md5(file.read())
    file.seek(0)
    return digest.hexdigest()
|
|
|
|
|
|
# --- Uploaded document -----------------------------------------------------
if uploaded_file:
    file_hash = get_file_hash(uploaded_file)
    already_seen = file_hash == st.session_state.get("last_uploaded_hash")

    # Summarise when the content differs from the last recorded hash, or
    # when the user explicitly asked for a re-run.
    if reprocess_btn or not already_seen:
        raw_text = extract_text(uploaded_file)
        summary_text = hybrid_summary_by_section(raw_text)

        upload_notice = {
            "role": "user",
            "content": f"π€ Uploaded **{uploaded_file.name}**",
        }
        st.session_state.messages.append(upload_notice)

        with st.chat_message("assistant", avatar=BOT_AVATAR):
            preview_text = f"π§Ύ **Hybrid Summary of {uploaded_file.name}:**\n\n{summary_text}"
            display_with_typing_effect(clean_text(preview_text), speed=0)

        # The stored copy keeps the un-cleaned markdown.
        st.session_state.messages.append({"role": "assistant", "content": preview_text})

        # The hash is only recorded on a regular (non-reprocess) run.
        if not reprocess_btn:
            st.session_state.last_uploaded_hash = file_hash

        save_chat_history(st.session_state.messages)
        st.rerun()


# --- Typed text --------------------------------------------------------------
if prompt:
    raw_text = prompt
    summary_text = hybrid_summary_by_section(raw_text)

    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        bot_response = f"π **Hybrid Summary of your text:**\n\n{summary_text}"
        display_with_typing_effect(clean_text(bot_response), speed=0)

    # The stored copy keeps the un-cleaned markdown.
    st.session_state.messages.append({"role": "assistant", "content": bot_response})

    save_chat_history(st.session_state.messages)
    st.rerun()
|
|