import io
import json
import math
import os
import re
import tempfile
from typing import Any, Dict, List, Optional, Tuple

import docx
import pandas as pd
import spacy
import streamlit as st
import torch
from pypdf import PdfReader
from transformers import AutoModelForCausalLM, AutoTokenizer

# -------------------------
# PAGE CONFIG
# -------------------------
# Streamlit requires set_page_config to be the first Streamlit command the script runs.
st.set_page_config(
    page_title="ClauseWise – Granite 3.2 (2B) Legal Assistant",
    page_icon="⚖️",
    layout="wide",
)

# -------------------------
# MODEL SETUP
# -------------------------
MODEL_ID = "ibm-granite/granite-3.2-2b-instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32


@st.cache_resource
def load_llm_model():
    """Load the Granite tokenizer and causal-LM once and reuse them across reruns.

    On CUDA the weights are loaded in bfloat16 with ``device_map="auto"``;
    otherwise they are loaded in float32 and moved to the CPU explicitly.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        device_map="auto" if DEVICE == "cuda" else None,
    )
    if DEVICE != "cuda":
        model.to(DEVICE)
    return tokenizer, model


@st.cache_resource
def load_nlp():
    """Load the spaCy English pipeline once instead of on every script rerun."""
    return spacy.load("en_core_web_sm")


tokenizer, model = load_llm_model()
nlp = load_nlp()


# -------------------------
# HELPER FUNCTIONS
# -------------------------
def build_chat_prompt(system_prompt: str, user_prompt: str) -> str:
    """Render a system/user exchange into one prompt string for the model.

    Uses the tokenizer's chat template (with the assistant generation prompt
    appended). If applying the template fails, falls back to a plain
    ``[SYSTEM]`` / ``[USER]`` / ``[ASSISTANT]`` text layout.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Deliberate best effort (e.g. a tokenizer with no chat template).
        system_part = f"[SYSTEM]\n{system_prompt}\n" if system_prompt else ""
        return system_part + f"[USER]\n{user_prompt}\n[ASSISTANT]\n"


def llm_generate(
    system_prompt: str,
    user_prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.3,
    top_p: float = 0.9,
) -> str:
    """Generate a reply from the model and return only the newly generated text.

    Args:
        system_prompt: Role/instructions for the model; may be empty.
        user_prompt: The request itself.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; a value ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold (only used when sampling).

    Returns:
        The model's continuation, without the prompt, stripped of surrounding
        whitespace.
    """
    prompt = build_chat_prompt(system_prompt, user_prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling with temperature 0 is rejected by transformers; decode greedily.
        gen_kwargs["do_sample"] = False

    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation. Decoding the whole sequence with
    # skip_special_tokens=True no longer matches the templated prompt string, so
    # slice off the prompt tokens instead of trying to strip the prompt text.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# -------------------------
# DOCUMENT LOADING
# -------------------------
def _rewind(file_obj):
    """Seek *file_obj* back to position 0 when it supports seeking, then return it."""
    try:
        file_obj.seek(0)
    except (AttributeError, OSError, ValueError):
        # No seek() or not seekable: read from the current position.
        pass
    return file_obj


def load_text_from_pdf(file_obj) -> str:
    """Extract text from every page of a PDF; pages that fail to extract yield ''."""
    reader = PdfReader(file_obj)
    pages = []
    for page in reader.pages:
        try:
            pages.append(page.extract_text() or "")
        except Exception:
            # Best effort per page: one bad page should not lose the whole document.
            pages.append("")
    return "\n".join(pages).strip()


def load_text_from_docx(file_obj) -> str:
    """Return the paragraph text of a .docx file, one paragraph per line.

    The bytes are copied into a BytesIO and *file_obj* is rewound afterwards.
    """
    data = file_obj.read()
    file_obj.seek(0)
    document = docx.Document(io.BytesIO(data))
    return "\n".join(p.text for p in document.paragraphs).strip()


# Tried in order; latin-1 maps every byte, so decoding always succeeds eventually.
_TEXT_ENCODINGS = ("utf-8-sig", "cp1252", "latin-1")


def load_text_from_txt(file_obj) -> str:
    """Read a plain-text file, decoding bytes as UTF-8, then cp1252, then latin-1."""
    data = file_obj.read()
    if isinstance(data, bytes):
        for encoding in _TEXT_ENCODINGS:
            try:
                data = data.decode(encoding)
                break
            except UnicodeDecodeError:
                continue
    return str(data).strip()


def load_document(file) -> str:
    """Extract text from an uploaded PDF/DOCX/TXT file.

    The loader is chosen by file extension; errors from a chosen loader
    propagate. For unknown extensions each format is tried in turn, with the
    file rewound before every attempt, and '' is returned if all fail.
    """
    if not file:
        return ""
    name = (getattr(file, "name", "") or "").lower()
    loaders = {
        ".pdf": load_text_from_pdf,
        ".docx": load_text_from_docx,
        ".txt": load_text_from_txt,
    }
    for suffix, loader in loaders.items():
        if name.endswith(suffix):
            return loader(_rewind(file))
    for loader in loaders.values():
        try:
            return loader(_rewind(file))
        except Exception:
            # Format sniffing: a parse failure just means "not this format".
            continue
    return ""


def get_text_from_inputs(file, text):
    """Return whichever of the uploaded file's text and the pasted text is longer."""
    file_text = load_document(file) if file else ""
    final = (text or "").strip()
    return file_text if len(file_text) > len(final) else final


# -------------------------
# CLAUSE PROCESSING
# -------------------------
# Split points: numbered headings ("1. ", "2.3) "), lettered headings ("A. ",
# "B) ") at line start, and line breaks optionally preceded by ';'. The heading
# text itself is consumed, so it does not appear in the resulting clauses.
CLAUSE_SPLIT_REGEX = re.compile(
    r"(?:(?:^\s*\d+(?:\.\d+)*[.)]\s+)"
    r"|(?:^\s*[A-Z]\s*[.)]\s+)"
    r"|(?:;?\s*\n))",
    re.MULTILINE,
)


def split_into_clauses(text: str, min_len: int = 40) -> List[str]:
    """Split *text* into clause-like chunks.

    Chunks shorter than *min_len* characters (after stripping) are dropped,
    and duplicates (compared case-insensitively with whitespace collapsed)
    are removed while keeping first-occurrence order.
    """
    if not text:
        return []
    parts = CLAUSE_SPLIT_REGEX.split(text)
    if len(parts) < 2:
        # No headings or line breaks: fall back to sentence-ish boundaries.
        parts = re.split(r"(?<=[.;])\s+\n?\s*", text)

    seen = set()
    unique: List[str] = []
    for part in parts:
        clause = part.strip()
        if len(clause) < min_len:
            continue
        key = re.sub(r"\s+", " ", clause.lower())
        if key in seen:
            continue
        seen.add(key)
        unique.append(clause)
    return unique


def simplify_clause(clause: str) -> str:
    """Ask the model to rewrite *clause* in plain English with risk bullet points."""
    system = "You are a legal assistant that rewrites clauses into plain English while preserving meaning."
    user = f"Rewrite the following clause in plain English with bullet points for risks.\n\nClause:\n{clause}"
    return llm_generate(system, user, max_new_tokens=400)


def ner_entities(text: str) -> Dict[str, List[str]]:
    """Group spaCy named entities by label, each label's texts deduplicated and sorted."""
    if not text:
        return {}
    grouped: Dict[str, set] = {}
    for ent in nlp(text).ents:
        grouped.setdefault(ent.label_, set()).add(ent.text)
    return {label: sorted(values) for label, values in grouped.items()}


def extract_clauses(text: str) -> List[str]:
    """Return the clauses of *text* using the default splitting settings."""
    return split_into_clauses(text)


# -------------------------
# DOCUMENT CLASSIFICATION
# -------------------------
DOC_TYPES = [
    "Non-Disclosure Agreement (NDA)",
    "Lease Agreement",
    "Employment Contract",
    "Service Agreement",
    "Sales Agreement",
    "Consulting Agreement",
    "End User License Agreement (EULA)",
    "Terms of Service",
]

# Keyword fallback, checked in order against the lower-cased document. Word
# boundaries keep "nda" from matching "standard"/"calendar" and "lease" from
# matching "release"/"please".
_DOC_TYPE_KEYWORDS = [
    ("Non-Disclosure Agreement (NDA)", re.compile(r"confidential|non-disclosure|\bnda\b")),
    ("Lease Agreement", re.compile(r"\blease|\btenant|\blandlord")),
    ("Employment Contract", re.compile(r"employment|employee|employer")),
    ("Service Agreement", re.compile(r"service|statement of work")),
]


def _label_position(label: str, response: str) -> int:
    """Return where *label* (or its parenthesised short form) first occurs in *response*, else -1."""
    full_name, _, abbreviation = label.partition(" (")
    patterns = [re.escape(full_name)]
    if abbreviation:
        patterns.append(rf"\b{re.escape(abbreviation.rstrip(')'))}\b")
    match = re.search("|".join(patterns), response, flags=re.IGNORECASE)
    return match.start() if match else -1


def classify_document(text: str) -> str:
    """Classify *text* as one of DOC_TYPES.

    The label the model mentions first wins. If the reply names no known
    type, a keyword heuristic is used, defaulting to DOC_TYPES[0].
    """
    system = "You are a legal document classifier. Choose the best matching document type."
    labels = "\n".join(f"- {t}" for t in DOC_TYPES)
    user = (
        "Classify the following document as exactly one of these types. "
        "Reply with the type name only.\n"
        f"{labels}\n\nDocument:\n{text[:5000]}"
    )
    resp = llm_generate(system, user, max_new_tokens=200)

    hits = []
    for label in DOC_TYPES:
        position = _label_position(label, resp)
        if position >= 0:
            hits.append((position, label))
    if hits:
        return min(hits)[1]

    lower = text.lower()
    for label, pattern in _DOC_TYPE_KEYWORDS:
        if pattern.search(lower):
            return label
    return DOC_TYPES[0]


# -------------------------
# JSON HELPERS
# -------------------------
def _extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in *text*, or None if there is none.

    Decoding starts at each '{' in turn and ignores any trailing text, so
    Markdown fences or commentary around the JSON do not break parsing.
    """
    text = text or ""
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", text):
        try:
            obj, _ = decoder.raw_decode(text[match.start():])
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return obj
    return None


def _dict_items(value: Any) -> List[Dict[str, Any]]:
    """Keep only the dict entries of a JSON list; any non-list value yields []."""
    if not isinstance(value, list):
        return []
    return [item for item in value if isinstance(item, dict)]


# -------------------------
# Negotiation Coach
# -------------------------
_NEGOTIATION_JSON_SHAPE = (
    '{"alternatives": [{"rank": 1, "acceptance_rate_percent": 85, '
    '"title": "...", "clause_text": "...", "rationale": "..."}]}'
)


def negotiation_coach(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Ask for three ranked rewrites of *clause*.

    Returns:
        ``(pretty_json, alternatives)``. If the reply has no usable
        ``alternatives`` list, up to three items are salvaged from a numbered
        list in the raw reply, with heuristic acceptance rates.
    """
    system = "You are an AI negotiation coach."
    user = (
        "Propose 3 alternative versions of the clause below, ranked by how likely "
        "the other party is to accept them. Respond with JSON only, shaped like:\n"
        + _NEGOTIATION_JSON_SHAPE
        + f"\n\nClause:\n{clause}"
    )
    resp = llm_generate(system, user, max_new_tokens=700)

    data = _extract_json_object(resp) or {}
    alternatives = _dict_items(data.get("alternatives"))
    if not alternatives:
        chunks = re.split(r"\n\s*\d+[.)]\s*", resp)
        alternatives = [
            {
                "rank": i,
                "acceptance_rate_percent": max(50, 90 - (i - 1) * 10),
                "title": f"Alternative {i}",
                "clause_text": chunk.strip()[:800],
                "rationale": "Heuristic parse from model response.",
            }
            for i, chunk in enumerate(chunks[1:4], start=1)
        ]
    data["alternatives"] = alternatives
    return json.dumps(data, indent=2, ensure_ascii=False), alternatives


# -------------------------
# Future Risk Predictor
# -------------------------
_RISK_JSON_SHAPE = (
    '{"timeline": [{"year": 1, "risk_score_0_100": 40, '
    '"key_risks": ["..."], "mitigation": ["..."]}]}'
)


def future_risk_predictor(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Forecast year-by-year (1-5) risks for *clause*.

    Returns:
        ``(pretty_json, timeline)``. If the reply has no usable ``timeline``
        list, a fixed heuristic five-year timeline is substituted.
    """
    system = "Forecast future risks over 1–5 years."
    user = (
        "Analyze the clause below and return a JSON timeline with one entry per "
        "year (1 to 5). Respond with JSON only, shaped like:\n"
        + _RISK_JSON_SHAPE
        + f"\n\nClause:\n{clause}"
    )
    resp = llm_generate(system, user, max_new_tokens=700)

    data = _extract_json_object(resp) or {}
    timeline = _dict_items(data.get("timeline"))
    if not timeline:
        timeline = [
            {
                "year": year,
                "risk_score_0_100": min(95, 40 + year * 8),
                "key_risks": ["Heuristic timeline due to JSON parse fallback."],
                "mitigation": ["Seek legal review", "Adjust clause terms"],
            }
            for year in range(1, 6)
        ]
    data["timeline"] = timeline
    return json.dumps(data, indent=2, ensure_ascii=False), timeline


# -------------------------
# Fairness Balance Meter
# -------------------------
def fairness_balance_meter(clause: str) -> Tuple[str, int, str]:
    """Score how balanced *clause* is (0 = favours Party A, 100 = favours Party B).

    Returns:
        ``(pretty_json, score, rationale)``; the score is clamped to 0-100, and
        an unusable reply yields 50 with a fallback rationale.
    """
    system = "Evaluate clause fairness (0=Party A,50=balanced,100=Party B)."
    user = (
        'Return JSON only, shaped like {"score_0_100": 50, "rationale": "..."}.\n'
        f"Clause:\n{clause}"
    )
    resp = llm_generate(system, user, max_new_tokens=400)

    score, rationale = 50, "Fallback balanced score."
    data = _extract_json_object(resp)
    if data is not None:
        try:
            score = int(float(data.get("score_0_100", 50)))
            rationale = str(data.get("rationale") or "")
        except (TypeError, ValueError, OverflowError):
            score, rationale = 50, "Fallback balanced score."
    # The UI shows the score on a 0-100 slider, which rejects out-of-range values.
    score = max(0, min(100, score))
    pretty = json.dumps(
        {"score_0_100": score, "rationale": rationale, "notes": []},
        indent=2,
        ensure_ascii=False,
    )
    return pretty, score, rationale


# -------------------------
# Clause Battle Arena
# -------------------------
_BATTLE_CATEGORIES = ["Liability", "Termination", "IP", "Payment", "Confidentiality", "Governing Law"]
_BATTLE_JSON_SHAPE = (
    '{"rounds": [{"category": "Liability", "winner": "A", "rationale": "..."}], '
    '"overall_winner": "A", "summary": "..."}'
)


def clause_battle_arena(text_a: str, text_b: str) -> Tuple[str, str]:
    """Compare two drafts category by category (first 4000 chars of each).

    Returns:
        ``(pretty_json, markdown_summary)``; an unparseable reply yields an
        all-"Draw" fallback result.
    """
    system = "Compare 2 contract drafts across categories."
    user = (
        "Compare Document A vs Document B in these categories: "
        + ", ".join(_BATTLE_CATEGORIES)
        + '. Use "A", "B" or "Draw" as winners. Respond with JSON only, shaped like:\n'
        + _BATTLE_JSON_SHAPE
        + f"\nA:\n{text_a[:4000]}\nB:\n{text_b[:4000]}"
    )
    resp = llm_generate(system, user, max_new_tokens=900)

    data = _extract_json_object(resp)
    if data is None:
        data = {
            "rounds": [
                {"category": c, "winner": "Draw", "rationale": "Fallback"}
                for c in _BATTLE_CATEGORIES
            ],
            "overall_winner": "Draw",
            "summary": "Fallback",
        }
    pretty = json.dumps(data, indent=2, ensure_ascii=False)
    rounds_md = "\n".join(
        f"- {r.get('category', '?')}: {r.get('winner', 'Draw')} — {r.get('rationale', '')}"
        for r in _dict_items(data.get("rounds"))
    )
    md = (
        f"Overall Winner: {data.get('overall_winner', 'Draw')}\n\n"
        f"Rounds:\n{rounds_md}\n\n"
        f"Summary:\n{data.get('summary', '')}"
    )
    return pretty, md


# -------------------------
# Sensitive Data Sniffer
# -------------------------
PII_REGEXES = {
    "Email": r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}",
    "Phone": r"\+?\d[\d\-\s]{7,}\d",
    "SSN (US)": r"\b\d{3}-\d{2}-\d{4}\b",
    "Credit Card": r"\b(?:\d[ -]*?){13,16}\b",
}
_PRIVACY_JSON_SHAPE = (
    '{"data_categories": ["..."], "sharing_parties": ["..."], '
    '"processing_purposes": ["..."], "risks": ["..."], "recommendations": ["..."]}'
)


def sensitive_data_sniffer(text: str) -> Tuple[str, Dict[str, List[str]]]:
    """Combine an LLM privacy review (first 6000 chars) with regex PII matches.

    Returns:
        ``(pretty_json, regex_hits)`` where ``regex_hits`` maps each PII_REGEXES
        label to its sorted, deduplicated matches (labels with no match omitted).
    """
    system = "Find hidden privacy traps and personal data."
    user = (
        "Respond with JSON only, shaped like:\n"
        + _PRIVACY_JSON_SHAPE
        + f"\nText:\n{text[:6000]}"
    )
    resp = llm_generate(system, user, max_new_tokens=700)

    data = _extract_json_object(resp)
    if data is None:
        data = {
            "data_categories": ["Name", "Email"],
            "sharing_parties": ["Service Provider"],
            "processing_purposes": ["Service delivery"],
            "risks": ["Potential over-collection"],
            "recommendations": ["Narrow purpose", "Limit retention"],
        }

    regex_hits: Dict[str, List[str]] = {}
    for label, pattern in PII_REGEXES.items():
        hits = re.findall(pattern, text or "", flags=re.IGNORECASE)
        if hits:
            regex_hits[label] = sorted({h.strip() for h in hits})

    pretty = json.dumps({"llm": data, "regex_hits": regex_hits}, indent=2, ensure_ascii=False)
    return pretty, regex_hits


# -------------------------
# Litigation Risk Radar
# -------------------------
_HOTSPOT_JSON_SHAPE = (
    '{"hotspots": [{"clause_excerpt": "...", "risk_level": "Low|Medium|High", '
    '"why": "...", "sample_dispute_scenario": "..."}]}'
)


def litigation_risk_radar(text: str) -> Tuple[str, str]:
    """Flag dispute-prone clauses among the first 8 clauses (or first 4000 chars).

    Returns:
        ``(pretty_json, markdown_list)``; an unparseable reply yields a single
        "Medium" fallback hotspot built from the first clause.
    """
    clauses = split_into_clauses(text)
    sample = "\n\n".join(clauses[:8]) if clauses else text[:4000]
    system = "Identify clauses likely to trigger disputes."
    user = (
        "Respond with JSON only, shaped like:\n"
        + _HOTSPOT_JSON_SHAPE
        + f"\nClauses:\n{sample}"
    )
    resp = llm_generate(system, user, max_new_tokens=900)

    data = _extract_json_object(resp)
    if data is None:
        data = {
            "hotspots": [
                {
                    "clause_excerpt": clauses[0][:280] if clauses else text[:280],
                    "risk_level": "Medium",
                    "why": "Ambiguous obligations.",
                    "sample_dispute_scenario": "Party A alleges non-performance due to unclear milestones.",
                }
            ]
        }
    pretty = json.dumps(data, indent=2, ensure_ascii=False)
    md = "\n".join(
        f"- [{h.get('risk_level', 'Medium')}] {h.get('clause_excerpt', '')}\n"
        f"  - Why: {h.get('why', '')}\n"
        f"  - Scenario: {h.get('sample_dispute_scenario', '')}"
        for h in _dict_items(data.get("hotspots"))
    )
    return pretty, md


# -------------------------
# STREAMLIT UI
# -------------------------
def _require_text(text: str) -> bool:
    """Return True when *text* has content; otherwise show a warning and return False."""
    if text and text.strip():
        return True
    st.warning("Nothing to analyze. Upload a document or paste text in the sidebar.")
    return False


def _join_items(value: Any) -> str:
    """Render a JSON value for a table cell; lists become '; '-separated text."""
    if isinstance(value, (list, tuple)):
        return "; ".join(str(item) for item in value)
    return "" if value is None else str(value)


st.title("ClauseWise – Granite 3.2 (2B) Legal Assistant")
st.markdown("Upload a PDF/DOCX/TXT or paste text below. Tabs provide different legal analysis tools.")

with st.sidebar:
    uploaded_file = st.file_uploader("Upload PDF/DOCX/TXT (optional)", type=["pdf", "docx", "txt"])
    pasted_text = st.text_area("Or paste text here", height=200)

try:
    text_data = get_text_from_inputs(uploaded_file, pasted_text)
except Exception as exc:
    # UI boundary: report unreadable uploads instead of crashing the whole app.
    st.sidebar.error(f"Could not read the uploaded file: {exc}")
    text_data = (pasted_text or "").strip()

tabs = st.tabs([
    "Clause Simplification", "Named Entity Recognition", "Clause Extraction",
    "Document Classification", "Negotiation Coach", "Future Risk Predictor",
    "Fairness Balance Meter", "Clause Battle Arena", "Sensitive Data Sniffer",
    "Litigation Risk Radar",
])

with tabs[0]:
    clause_input = st.text_area("Clause (optional)", height=150)
    if st.button("Simplify Clause", key="simplify"):
        target = clause_input.strip() or text_data
        if _require_text(target):
            st.text_area("Plain English Output", simplify_clause(target), height=250)

with tabs[1]:
    if st.button("Run NER", key="ner") and _require_text(text_data):
        st.json(ner_entities(text_data[:12000]))

with tabs[2]:
    if st.button("Extract Clauses", key="extract") and _require_text(text_data):
        # st.dataframe has no `columns` argument; name the column via a DataFrame.
        st.dataframe(pd.DataFrame({"Clause": extract_clauses(text_data)}))

with tabs[3]:
    if st.button("Classify Document", key="classify") and _require_text(text_data):
        st.text_area("Predicted Type", classify_document(text_data))

with tabs[4]:
    negotiation_clause = st.text_area("Clause to Optimize", height=150)
    if st.button("Suggest Alternatives", key="negotiation"):
        target = negotiation_clause.strip() or text_data
        if _require_text(target):
            pretty, alts = negotiation_coach(target)
            st.json(json.loads(pretty))
            rows = [
                [
                    a.get("rank", ""),
                    a.get("acceptance_rate_percent", ""),
                    _join_items(a.get("title", "")),
                    _join_items(a.get("clause_text", "")),
                    _join_items(a.get("rationale", "")),
                ]
                for a in alts
            ]
            st.dataframe(pd.DataFrame(rows, columns=["Rank", "Acceptance %", "Title", "Clause Text", "Rationale"]))

with tabs[5]:
    risk_clause = st.text_area("Clause for Risk Prediction", height=150)
    if st.button("Predict 1–5 Year Risks", key="risk"):
        target = risk_clause.strip() or text_data
        if _require_text(target):
            pretty, timeline = future_risk_predictor(target)
            st.json(json.loads(pretty))
            rows = [
                [
                    t.get("year", ""),
                    t.get("risk_score_0_100", ""),
                    _join_items(t.get("key_risks", [])),
                    _join_items(t.get("mitigation", [])),
                ]
                for t in timeline
            ]
            st.dataframe(pd.DataFrame(rows, columns=["Year", "Risk Score (0–100)", "Key Risks", "Mitigation"]))

with tabs[6]:
    fairness_clause = st.text_area("Clause", height=150)
    if st.button("Compute Fairness", key="fairness"):
        target = fairness_clause.strip() or text_data
        if _require_text(target):
            pretty, score, rationale = fairness_balance_meter(target)
            st.json(json.loads(pretty))
            # Display only: the user should not be able to drag the computed score.
            st.slider("Balance Score", min_value=0, max_value=100, value=score, disabled=True)
            st.text_area("Rationale / Notes", rationale, height=100)

with tabs[7]:
    clause_a = st.text_area("Document A", height=150)
    clause_b = st.text_area("Document B", height=150)
    if st.button("Compare Clauses", key="battle"):
        doc_a = clause_a.strip() or text_data
        doc_b = clause_b.strip() or text_data
        if _require_text(doc_a) and _require_text(doc_b):
            pretty, md = clause_battle_arena(doc_a, doc_b)
            st.text_area("Battle JSON", pretty, height=300)
            st.markdown(md)

with tabs[8]:
    if st.button("Scan for Sensitive Data", key="sensitive") and _require_text(text_data):
        pretty, hits = sensitive_data_sniffer(text_data)
        st.text_area("Sensitive Data JSON", pretty, height=300)
        st.json(hits)

with tabs[9]:
    if st.button("Identify Litigation Risk Hotspots", key="litigation") and _require_text(text_data):
        pretty, md = litigation_risk_radar(text_data)
        st.text_area("Litigation JSON", pretty, height=300)
        st.markdown(md)