# ClauseWise – Granite 3.2 (2B) Legal Assistant (Hugging Face Space / Streamlit app)
# Standard library
import io
import json
import math
import os
import re
import tempfile
from typing import List, Dict, Tuple, Any, Optional

# Third-party
import docx
import pandas as pd
import spacy
import streamlit as st
import torch
from pypdf import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
# -------------------------
# PAGE CONFIG
# -------------------------
st.set_page_config(page_title="ClauseWise – Granite 3.2 (2B) Legal Assistant", page_icon="⚖️", layout="wide")
# -------------------------
# MODEL SETUP
# -------------------------
# Hugging Face hub id of the instruction-tuned Granite 3.2 2B model.
MODEL_ID = "ibm-granite/granite-3.2-2b-instruct"
# Prefer GPU when available; bfloat16 halves memory on CUDA, float32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def load_llm_model():
    """Load the Granite tokenizer and causal LM, placing the model on DEVICE.

    Returns:
        (tokenizer, model) tuple used by the module-level globals below.
    """
    # NOTE(review): this runs on every Streamlit rerun; consider wrapping with
    # @st.cache_resource so the 2B model is loaded once per process — confirm
    # the deployed Streamlit version supports it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        # On CUDA let accelerate place/shard the weights; on CPU we move
        # the model explicitly below.
        device_map="auto" if DEVICE == "cuda" else None
    )
    if DEVICE != "cuda":
        model.to(DEVICE)
    return tokenizer, model
# Module-level globals used by every helper below; loaded once at import time.
tokenizer, model = load_llm_model()
# NOTE(review): requires the `en_core_web_sm` spaCy model to be installed.
nlp = spacy.load("en_core_web_sm")
| # ------------------------- | |
| # HELPER FUNCTIONS | |
| # ------------------------- | |
def build_chat_prompt(system_prompt: str, user_prompt: str) -> str:
    """Render a chat prompt for the model.

    Prefers the tokenizer's built-in chat template; if that is unavailable
    (or fails for any reason), falls back to a minimal tagged transcript.
    """
    chat = [{"role": "system", "content": system_prompt}] if system_prompt else []
    chat.append({"role": "user", "content": user_prompt})
    try:
        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    except Exception:
        # No usable chat template: emit [SYSTEM]/[USER]/[ASSISTANT] tags that
        # llm_generate() knows how to strip from the decoded output.
        prefix = f"[SYSTEM]\n{system_prompt}\n" if system_prompt else ""
        return prefix + f"[USER]\n{user_prompt}\n[ASSISTANT]\n"
def llm_generate(system_prompt: str, user_prompt: str, max_new_tokens=512, temperature=0.3, top_p=0.9) -> str:
    """Generate a completion from the module-level Granite model.

    Returns only the assistant's portion of the decoded text, stripping either
    the "[ASSISTANT]" fallback tag or the echoed prompt prefix.
    """
    prompt = build_chat_prompt(system_prompt, user_prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            # Some tokenizers define no pad token; reuse EOS to silence warnings.
            pad_token_id=tokenizer.eos_token_id
        )
    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Fallback prompt format: the answer follows the literal "[ASSISTANT]" tag.
    if "[ASSISTANT]" in full_text:
        return full_text.split("[ASSISTANT]")[-1].strip()
    # Chat-template path: the decoded text typically echoes the prompt verbatim.
    if full_text.startswith(prompt):
        return full_text[len(prompt):].strip()
    return full_text.strip()
| # ------------------------- | |
| # DOCUMENT LOADING | |
| # ------------------------- | |
def load_text_from_pdf(file_obj) -> str:
    """Extract text from every page of a PDF; unreadable pages become empty strings."""
    extracted = []
    for page in PdfReader(file_obj).pages:
        try:
            page_text = page.extract_text() or ""
        except Exception:
            # Scanned or malformed pages may fail extraction; keep going.
            page_text = ""
        extracted.append(page_text)
    return "\n".join(extracted).strip()
def load_text_from_docx(file_obj) -> str:
    """Extract paragraph text from a .docx upload without consuming the caller's stream."""
    raw = file_obj.read()
    file_obj.seek(0)  # leave the upload re-readable for any later consumer
    document = docx.Document(io.BytesIO(raw))
    return "\n".join(paragraph.text for paragraph in document.paragraphs).strip()
def load_text_from_txt(file_obj) -> str:
    """Read a plain-text upload, decoding bytes as UTF-8 with a latin-1 fallback.

    Bug fixed: the original decoded with errors="ignore", which never raises,
    so its latin-1 fallback branch was unreachable dead code (and the bare
    `except` would have swallowed any other error).
    """
    data = file_obj.read()
    if isinstance(data, bytes):
        try:
            data = data.decode("utf-8")
        except UnicodeDecodeError:
            # Not valid UTF-8: latin-1 maps every byte, so this cannot fail.
            data = data.decode("latin-1", errors="ignore")
    return str(data).strip()
def load_document(file) -> str:
    """Extract text from an uploaded file, dispatching on its extension.

    Unknown extensions are probed with each loader in turn. Bug fixed: the
    original never rewound the stream between probe attempts, so a failed PDF
    parse left the file pointer consumed and the later loaders saw an empty
    stream; it also used bare `except:` clauses.

    Returns "" when `file` is falsy or no loader succeeds.
    """
    if not file:
        return ""
    name = (file.name or "").lower()
    if name.endswith(".pdf"):
        return load_text_from_pdf(file)
    if name.endswith(".docx"):
        return load_text_from_docx(file)
    if name.endswith(".txt"):
        return load_text_from_txt(file)
    # Unknown extension: probe each loader, rewinding first because a failed
    # parse may have consumed part (or all) of the stream.
    for loader in (load_text_from_pdf, load_text_from_docx, load_text_from_txt):
        try:
            file.seek(0)
            return loader(file)
        except Exception:
            continue
    return ""
def get_text_from_inputs(file, text):
    """Return whichever source yields more text: the uploaded file or the pasted text."""
    from_file = load_document(file) if file else ""
    pasted = (text or "").strip()
    return from_file if len(from_file) > len(pasted) else pasted
| # ------------------------- | |
| # CLAUSE PROCESSING | |
| # ------------------------- | |
# Split on numbered headings ("1. ", "2.1) "), lettered headings ("A. ", "B) "),
# or clause-ending line breaks (optionally preceded by a semicolon).
# Bug fixed: `(?:\.\d+)` was mandatory, so plain "1. " headings never matched,
# and `^\s` demanded exactly one leading whitespace, so "A. " at the start of a
# line never matched either.
CLAUSE_SPLIT_REGEX = re.compile(
    r"(?:(?:^\s*\d+(?:\.\d+)*[.)]\s+)|(?:^\s*[A-Z]\s*[.)]\s+)|(?:;?\s*\n))",
    re.MULTILINE,
)

def split_into_clauses(text: str, min_len: int = 40) -> List[str]:
    """Split a contract into candidate clauses.

    Drops fragments shorter than `min_len` and de-duplicates (case- and
    whitespace-insensitively) while preserving order. Falls back to
    sentence-style splitting when the heading regex finds no boundaries.
    """
    if not text:
        return []
    parts = re.split(CLAUSE_SPLIT_REGEX, text)
    if len(parts) < 2:
        # No headings detected: split after sentence/semicolon terminators.
        parts = re.split(r"(?<=[.;])\s+\n?\s*", text)
    clauses = [p.strip() for p in parts if len(p.strip()) >= min_len]
    seen = set()
    unique = []
    for clause in clauses:
        key = re.sub(r"\s+", " ", clause.lower())
        if key not in seen:
            seen.add(key)
            unique.append(clause)
    return unique
def simplify_clause(clause: str) -> str:
    """Ask the LLM to restate a clause in plain English, flagging risks as bullets."""
    sys_msg = "You are a legal assistant that rewrites clauses into plain English while preserving meaning."
    usr_msg = f"Rewrite the following clause in plain English with bullet points for risks.\n\nClause:\n{clause}"
    return llm_generate(sys_msg, usr_msg, max_new_tokens=400)
def ner_entities(text: str) -> Dict[str, List[str]]:
    """Run spaCy NER and group mentions by entity label.

    Returns a mapping of label -> sorted unique surface forms; {} for empty input.
    """
    if not text:
        return {}
    grouped: Dict[str, List[str]] = {}
    for entity in nlp(text).ents:
        grouped.setdefault(entity.label_, []).append(entity.text)
    return {label: sorted(set(mentions)) for label, mentions in grouped.items()}
def extract_clauses(text: str) -> List[str]:
    """Thin wrapper over split_into_clauses, kept for the UI tab's naming."""
    return split_into_clauses(text)
| # ------------------------- | |
| # DOCUMENT CLASSIFICATION | |
| # ------------------------- | |
# Closed set of labels the classifier chooses from; classify_document() falls
# back to keyword spotting when the LLM names none of these.
DOC_TYPES = [
    "Non-Disclosure Agreement (NDA)",
    "Lease Agreement",
    "Employment Contract",
    "Service Agreement",
    "Sales Agreement",
    "Consulting Agreement",
    "End User License Agreement (EULA)",
    "Terms of Service",
]
def classify_document(text: str) -> str:
    """Classify the document into one of DOC_TYPES via the LLM, with a keyword fallback."""
    system = "You are a legal document classifier. Choose the best matching document type."
    labels = "\n".join(f"- {t}" for t in DOC_TYPES)
    user = f"Classify the following document:\n{labels}\n\n{text[:5000]}"
    resp = llm_generate(system, user, max_new_tokens=200)
    resp_lower = resp.lower()
    # Score 1.0 for any label literally mentioned in the response, else 0.0.
    scores = {label: float(label.lower() in resp_lower) for label in DOC_TYPES}
    best = max(scores.items(), key=lambda kv: kv[1])[0]
    if scores[best] == 0.0:
        # The LLM named none of the labels: fall back to keyword spotting.
        lower = text.lower()
        if any(k in lower for k in ("confidential", "non-disclosure", "nda")):
            best = "Non-Disclosure Agreement (NDA)"
        elif any(k in lower for k in ("lease", "tenant", "landlord")):
            best = "Lease Agreement"
        elif any(k in lower for k in ("employment", "employee", "employer")):
            best = "Employment Contract"
        elif any(k in lower for k in ("services", "service", "statement of work")):
            best = "Service Agreement"
    return best
| # ------------------------- | |
| # Negotiation Coach | |
| # ------------------------- | |
def negotiation_coach(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Propose ranked alternative clause drafts via the LLM.

    Returns (pretty-printed JSON, list of alternative dicts).

    Bugs fixed: the original appended heuristic alternatives unconditionally —
    corrupting valid LLM JSON — and raised KeyError when the parsed JSON lacked
    an "alternatives" key; it also used a bare `except`.
    """
    system = "You are an AI negotiation coach."
    user = (
        "Propose 3 alternative versions ranked by acceptance rate in JSON.\n"
        f"Clause:\n{clause}"
    )
    resp = llm_generate(system, user, max_new_tokens=700)
    try:
        json_str = re.search(r"\{[\s\S]*\}", resp).group(0)
        data = json.loads(json_str)
    except (AttributeError, json.JSONDecodeError):
        # No JSON object found (re.search -> None) or it did not parse.
        data = {}
    alternatives = data.setdefault("alternatives", [])
    if not alternatives:
        # Fallback: heuristically split a numbered list out of the raw response.
        chunks = re.split(r"\n\s*\d+[.)]\s*", resp)
        for i, chunk in enumerate(chunks[1:4], start=1):
            alternatives.append({
                "rank": i,
                "acceptance_rate_percent": max(50, 90 - (i - 1) * 10),
                "title": f"Alternative {i}",
                "clause_text": chunk.strip()[:800],
                "rationale": "Heuristic parse from model response."
            })
    return json.dumps(data, indent=2), data.get("alternatives", [])
| # ------------------------- | |
| # Future Risk Predictor | |
| # ------------------------- | |
def future_risk_predictor(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Forecast clause risk over a 1–5 year horizon via the LLM.

    Returns (pretty-printed JSON, timeline list of per-year dicts).

    Bugs fixed: the original appended the heuristic timeline unconditionally —
    corrupting valid LLM JSON — and raised KeyError when the parsed JSON lacked
    a "timeline" key; it also used a bare `except`.
    """
    system = "Forecast future risks over 1–5 years."
    user = f"Analyze clause and return JSON timeline.\nClause:\n{clause}"
    resp = llm_generate(system, user, max_new_tokens=700)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, json.JSONDecodeError):
        data = {}
    timeline = data.setdefault("timeline", [])
    if not timeline:
        # Fallback: synthesize a simple rising-risk timeline.
        for y in range(1, 6):
            timeline.append({
                "year": y,
                "risk_score_0_100": min(95, 40 + y * 8),
                "key_risks": ["Heuristic timeline due to JSON parse fallback."],
                "mitigation": ["Seek legal review", "Adjust clause terms"]
            })
    return json.dumps(data, indent=2), data["timeline"]
| # ------------------------- | |
| # Fairness Balance Meter | |
| # ------------------------- | |
def fairness_balance_meter(clause: str) -> Tuple[str, int, str]:
    """Score how one-sided a clause is (0=favors Party A, 50=balanced, 100=Party B).

    Returns (pretty-printed JSON, score, rationale).

    Fixes: narrowed the bare `except` to the failures this parse can actually
    produce, and clamped the score to 0–100 since the UI feeds it to a slider
    with that range.
    """
    system = "Evaluate clause fairness (0=Party A,50=balanced,100=Party B)."
    user = f"Return JSON: score_0_100 and rationale.\nClause:\n{clause}"
    resp = llm_generate(system, user, max_new_tokens=400)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
        score = int(data.get("score_0_100", 50))
        rationale = data.get("rationale", "")
    except (AttributeError, json.JSONDecodeError, TypeError, ValueError):
        score, rationale = 50, "Fallback balanced score."
    score = max(0, min(100, score))  # keep within the slider's range
    return json.dumps({"score_0_100": score, "rationale": rationale, "notes": []}, indent=2), score, rationale
| # ------------------------- | |
| # Clause Battle Arena | |
| # ------------------------- | |
def clause_battle_arena(text_a: str, text_b: str) -> Tuple[str, str]:
    """Compare two contract drafts category-by-category via the LLM.

    Returns (pretty-printed JSON, markdown summary).

    Fixes: narrowed the bare `except`, and the markdown rendering now uses
    `.get` so a partially-formed LLM round (missing "category"/"winner") no
    longer raises KeyError.
    """
    system = "Compare 2 contract drafts across categories."
    user = f"Compare Document A vs Document B and return JSON.\nA:\n{text_a[:4000]}\nB:\n{text_b[:4000]}"
    resp = llm_generate(system, user, max_new_tokens=900)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, json.JSONDecodeError):
        # Parse failure: declare every round a draw.
        data = {
            "rounds": [
                {"category": c, "winner": "Draw", "rationale": "Fallback"}
                for c in ["Liability", "Termination", "IP", "Payment", "Confidentiality", "Governing Law"]
            ],
            "overall_winner": "Draw",
            "summary": "Fallback",
        }
    pretty = json.dumps(data, indent=2)
    rounds_md = "\n".join(
        f"- {r.get('category', '?')}: {r.get('winner', 'Draw')} — {r.get('rationale', '')}"
        for r in data.get("rounds", [])
    )
    md = f"Overall Winner: {data.get('overall_winner','Draw')}\n\nRounds:\n{rounds_md}\n\nSummary:\n{data.get('summary','')}"
    return pretty, md
| # -------------------------- | |
| # Sensitive Data Sniffer | |
| # ------------------------- | |
# Regex heuristics for common PII. NOTE(review): the credit-card pattern is
# loose (any 13–16 digits with optional separators) and will also match other
# long digit runs — acceptable for a scanner that only flags candidates.
PII_REGEXES = {
    "Email": r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}",
    "Phone": r"\+?\d[\d\-\s]{7,}\d",
    "SSN (US)": r"\b\d{3}-\d{2}-\d{4}\b",
    "Credit Card": r"\b(?:\d[ -]*?){13,16}\b",
}
def sensitive_data_sniffer(text: str) -> Tuple[str, Dict[str, List[str]]]:
    """LLM privacy-trap analysis plus a regex PII scan over the raw text.

    Returns (pretty-printed combined JSON, regex hits keyed by PII label).

    Fix: narrowed the bare `except` to the failures the JSON parse can produce.
    """
    system = "Find hidden privacy traps and personal data."
    user = f"Return JSON.\nText:\n{text[:6000]}"
    resp = llm_generate(system, user, max_new_tokens=700)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, json.JSONDecodeError):
        # Model returned no parseable JSON: use a generic placeholder analysis.
        data = {
            "data_categories": ["Name", "Email"],
            "sharing_parties": ["Service Provider"],
            "processing_purposes": ["Service delivery"],
            "risks": ["Potential over-collection"],
            "recommendations": ["Narrow purpose", "Limit retention"],
        }
    regex_hits = {}
    for label, pattern in PII_REGEXES.items():
        hits = re.findall(pattern, text or "", flags=re.IGNORECASE)
        if hits:
            regex_hits[label] = sorted(set(h.strip() for h in hits))
    pretty = json.dumps({"llm": data, "regex_hits": regex_hits}, indent=2)
    return pretty, regex_hits
| # ------------------------- | |
| # Litigation Risk Radar | |
| # ------------------------- | |
def litigation_risk_radar(text: str) -> Tuple[str, str]:
    """Flag dispute-prone clauses via the LLM.

    Samples up to the first 8 extracted clauses (or a 4000-char prefix when no
    clauses were found). Returns (pretty-printed JSON, markdown summary).

    Fix: narrowed the bare `except` to the failures the JSON parse can produce.
    """
    clauses = split_into_clauses(text)
    sample = "\n\n".join(clauses[:8]) if clauses else text[:4000]
    system = "Identify clauses likely to trigger disputes."
    user = f"Return JSON of hotspots.\nClauses:\n{sample}"
    resp = llm_generate(system, user, max_new_tokens=900)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, json.JSONDecodeError):
        # Fallback: a single generic hotspot built from the first clause.
        data = {"hotspots": [{
            "clause_excerpt": (clauses[0][:280] if clauses else text[:280]),
            "risk_level": "Medium",
            "why": "Ambiguous obligations.",
            "sample_dispute_scenario": "Party A alleges non-performance due to unclear milestones.",
        }]}
    pretty = json.dumps(data, indent=2)
    md = "\n".join(
        f"- [{h.get('risk_level','Medium')}] {h.get('clause_excerpt','')}\n Why: {h.get('why','')}\n Scenario: {h.get('sample_dispute_scenario','')}"
        for h in data.get("hotspots", [])
    )
    return pretty, md
# -------------------------
# STREAMLIT UI
# -------------------------
st.title("ClauseWise – Granite 3.2 (2B) Legal Assistant")
st.markdown("Upload a PDF/DOCX/TXT or paste text below. Tabs provide different legal analysis tools.")
# Sidebar collects the shared input (file upload and/or pasted text) reused by every tab.
with st.sidebar:
    uploaded_file = st.file_uploader("Upload PDF/DOCX/TXT (optional)", type=["pdf","docx","txt"])
    pasted_text = st.text_area("Or paste text here", height=200)
# Longest available source wins (extracted file text vs pasted text).
text_data = get_text_from_inputs(uploaded_file, pasted_text)
tabs = st.tabs([
    "Clause Simplification","Named Entity Recognition","Clause Extraction",
    "Document Classification","Negotiation Coach","Future Risk Predictor",
    "Fairness Balance Meter","Clause Battle Arena","Sensitive Data Sniffer","Litigation Risk Radar"
])
# Tab 0: plain-English clause simplification (empty input falls back to the sidebar document).
with tabs[0]:
    clause_input = st.text_area("Clause (optional)", height=150)
    if st.button("Simplify Clause", key="simplify"):
        target = clause_input.strip() or text_data
        st.text_area("Plain English Output", simplify_clause(target), height=250)
# Tab 1: spaCy named-entity recognition over the shared document (capped at 12k chars).
with tabs[1]:
    if st.button("Run NER", key="ner"):
        st.json(ner_entities(text_data[:12000]))
# Tab 2: clause extraction table.
with tabs[2]:
    if st.button("Extract Clauses", key="extract"):
        clauses = extract_clauses(text_data)
        # Bug fix: st.dataframe has no `columns` keyword (that belongs to
        # pandas.DataFrame), so the original call raised TypeError.
        st.dataframe(pd.DataFrame({"Clause": clauses}))
# Tab 3: document-type classification (LLM choice with keyword fallback).
with tabs[3]:
    if st.button("Classify Document", key="classify"):
        st.text_area("Predicted Type", classify_document(text_data))
# Tab 4: negotiation coach — ranked alternative clause drafts.
with tabs[4]:
    negotiation_clause = st.text_area("Clause to Optimize", height=150)
    if st.button("Suggest Alternatives", key="negotiation"):
        pretty, alts = negotiation_coach(negotiation_clause.strip() or text_data)
        st.json(json.loads(pretty))
        table = [
            [a.get("rank",""), a.get("acceptance_rate_percent",""), a.get("title",""), a.get("clause_text",""), a.get("rationale","")]
            for a in alts
        ]
        # Bug fix: st.dataframe has no `columns` keyword; build a pandas
        # DataFrame so the headers actually render instead of raising TypeError.
        st.dataframe(pd.DataFrame(table, columns=["Rank","Acceptance %","Title","Clause Text","Rationale"]))
# Tab 5: 1–5 year risk forecast rendered as JSON plus a tabular timeline.
with tabs[5]:
    risk_clause = st.text_area("Clause for Risk Prediction", height=150)
    if st.button("Predict 1–5 Year Risks", key="risk"):
        pretty, timeline = future_risk_predictor(risk_clause.strip() or text_data)
        st.json(json.loads(pretty))
        table = [
            [t.get("year",""), t.get("risk_score_0_100",""), "; ".join(t.get("key_risks",[])), "; ".join(t.get("mitigation",[]))]
            for t in timeline
        ]
        # Bug fix: st.dataframe has no `columns` keyword; build a pandas
        # DataFrame so the headers actually render instead of raising TypeError.
        st.dataframe(pd.DataFrame(table, columns=["Year","Risk Score (0–100)","Key Risks","Mitigation"]))
# Tab 6: fairness meter — the slider visualizes the 0–100 balance score.
with tabs[6]:
    fairness_clause = st.text_area("Clause", height=150)
    if st.button("Compute Fairness", key="fairness"):
        pretty, score, rationale = fairness_balance_meter(fairness_clause.strip() or text_data)
        st.json(json.loads(pretty))
        st.slider("Balance Score", min_value=0,max_value=100,value=score)
        st.text_area("Rationale / Notes", rationale, height=100)
# Tab 7: side-by-side comparison of two drafts ("battle" rounds per category).
with tabs[7]:
    clause_a = st.text_area("Document A", height=150)
    clause_b = st.text_area("Document B", height=150)
    if st.button("Compare Clauses", key="battle"):
        # Either empty input falls back to the shared sidebar document.
        pretty, md = clause_battle_arena(clause_a.strip() or text_data, clause_b.strip() or text_data)
        st.text_area("Battle JSON", pretty, height=300)
        st.markdown(md)
# Tab 8: LLM privacy analysis plus regex PII hits over the shared document.
with tabs[8]:
    if st.button("Scan for Sensitive Data", key="sensitive"):
        pretty, hits = sensitive_data_sniffer(text_data)
        st.text_area("Sensitive Data JSON", pretty, height=300)
        st.json(hits)
# Tab 9: litigation hotspot radar over the shared document.
with tabs[9]:
    if st.button("Identify Litigation Risk Hotspots", key="litigation"):
        pretty, md = litigation_risk_radar(text_data)
        st.text_area("Litigation JSON", pretty, height=300)
        st.markdown(md)