bhoomi19 commited on
Commit
1f579c9
·
verified ·
1 Parent(s): 8ad66cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +413 -407
app.py CHANGED
@@ -1,407 +1,413 @@
1
- import streamlit as st
2
- import tempfile
3
- import os
4
- import re
5
- import io
6
- import json
7
- from typing import List, Dict, Tuple, Any, Optional
8
- import torch
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
- from pypdf import PdfReader
11
- import docx
12
- import spacy
13
- import math
14
-
15
- # -------------------------
16
- # PAGE CONFIG
17
- # -------------------------
18
- st.set_page_config(page_title="ClauseWise – Granite 3.2 (2B) Legal Assistant", page_icon="⚖️", layout="wide")
19
-
20
- # -------------------------
21
- # MODEL SETUP
22
- # -------------------------
23
- MODEL_ID = "ibm-granite/granite-3.2-2b-instruct"
24
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
- DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
26
-
27
- @st.cache_resource
28
- def load_llm_model():
29
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
30
- model = AutoModelForCausalLM.from_pretrained(
31
- MODEL_ID,
32
- torch_dtype=DTYPE,
33
- device_map="auto" if DEVICE == "cuda" else None
34
- )
35
- if DEVICE != "cuda":
36
- model.to(DEVICE)
37
- return tokenizer, model
38
-
39
- tokenizer, model = load_llm_model()
40
-
41
- nlp = spacy.load("en_core_web_sm")
42
-
43
- # -------------------------
44
- # HELPER FUNCTIONS
45
- # -------------------------
46
- def build_chat_prompt(system_prompt: str, user_prompt: str) -> str:
47
- messages = []
48
- if system_prompt:
49
- messages.append({"role": "system", "content": system_prompt})
50
- messages.append({"role": "user", "content": user_prompt})
51
- try:
52
- return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
53
- except Exception:
54
- sys = f"[SYSTEM]\n{system_prompt}\n" if system_prompt else ""
55
- usr = f"[USER]\n{user_prompt}\n[ASSISTANT]\n"
56
- return sys + usr
57
-
58
- def llm_generate(system_prompt: str, user_prompt: str, max_new_tokens=512, temperature=0.3, top_p=0.9) -> str:
59
- prompt = build_chat_prompt(system_prompt, user_prompt)
60
- inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
61
- with torch.inference_mode():
62
- output_ids = model.generate(
63
- **inputs,
64
- max_new_tokens=max_new_tokens,
65
- temperature=temperature,
66
- top_p=top_p,
67
- do_sample=True,
68
- pad_token_id=tokenizer.eos_token_id
69
- )
70
- full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
71
- if "[ASSISTANT]" in full_text:
72
- return full_text.split("[ASSISTANT]")[-1].strip()
73
- if full_text.startswith(prompt):
74
- return full_text[len(prompt):].strip()
75
- return full_text.strip()
76
-
77
- # -------------------------
78
- # DOCUMENT LOADING
79
- # -------------------------
80
- def load_text_from_pdf(file_obj) -> str:
81
- reader = PdfReader(file_obj)
82
- pages = []
83
- for page in reader.pages:
84
- try:
85
- pages.append(page.extract_text() or "")
86
- except Exception:
87
- pages.append("")
88
- return "\n".join(pages).strip()
89
-
90
- def load_text_from_docx(file_obj) -> str:
91
- data = file_obj.read()
92
- file_obj.seek(0)
93
- f = io.BytesIO(data)
94
- doc = docx.Document(f)
95
- paras = [p.text for p in doc.paragraphs]
96
- return "\n".join(paras).strip()
97
-
98
- def load_text_from_txt(file_obj) -> str:
99
- data = file_obj.read()
100
- if isinstance(data, bytes):
101
- try:
102
- data = data.decode("utf-8", errors="ignore")
103
- except:
104
- data = data.decode("latin-1", errors="ignore")
105
- return str(data).strip()
106
-
107
- def load_document(file) -> str:
108
- if not file:
109
- return ""
110
- name = (file.name or "").lower()
111
- if name.endswith(".pdf"):
112
- return load_text_from_pdf(file)
113
- elif name.endswith(".docx"):
114
- return load_text_from_docx(file)
115
- elif name.endswith(".txt"):
116
- return load_text_from_txt(file)
117
- else:
118
- try:
119
- return load_text_from_pdf(file)
120
- except:
121
- pass
122
- try:
123
- return load_text_from_docx(file)
124
- except:
125
- pass
126
- try:
127
- return load_text_from_txt(file)
128
- except:
129
- pass
130
- return ""
131
-
132
- def get_text_from_inputs(file, text):
133
- file_text = load_document(file) if file else ""
134
- final = (text or "").strip()
135
- return file_text if len(file_text) > len(final) else final
136
-
137
- # -------------------------
138
- # CLAUSE PROCESSING
139
- # -------------------------
140
- CLAUSE_SPLIT_REGEX = re.compile(r"(?:(?:^\s*\d+(?:\.\d+)[.)]\s+)|(?:^\s[A-Z]\s*[.)]\s+)|(?:;?\s*\n))", re.MULTILINE)
141
-
142
- def split_into_clauses(text: str, min_len: int = 40) -> List[str]:
143
- if not text:
144
- return []
145
- parts = re.split(CLAUSE_SPLIT_REGEX, text)
146
- if len(parts) < 2:
147
- parts = re.split(r"(?<=[.;])\s+\n?\s*", text)
148
- clauses = [p.strip() for p in parts if len(p.strip()) >= min_len]
149
- seen = set()
150
- unique = []
151
- for c in clauses:
152
- key = re.sub(r"\s+", " ", c.lower())
153
- if key not in seen:
154
- seen.add(key)
155
- unique.append(c)
156
- return unique
157
-
158
- def simplify_clause(clause: str) -> str:
159
- system = "You are a legal assistant that rewrites clauses into plain English while preserving meaning."
160
- user = f"Rewrite the following clause in plain English with bullet points for risks.\n\nClause:\n{clause}"
161
- return llm_generate(system, user, max_new_tokens=400)
162
-
163
- def ner_entities(text: str) -> Dict[str, List[str]]:
164
- if not text:
165
- return {}
166
- doc = nlp(text)
167
- out: Dict[str, List[str]] = {}
168
- for ent in doc.ents:
169
- out.setdefault(ent.label_, []).append(ent.text)
170
- out = {k: sorted(set(v)) for k, v in out.items()}
171
- return out
172
-
173
- def extract_clauses(text: str) -> List[str]:
174
- return split_into_clauses(text)
175
-
176
- # -------------------------
177
- # DOCUMENT CLASSIFICATION
178
- # -------------------------
179
- DOC_TYPES = [
180
- "Non-Disclosure Agreement (NDA)",
181
- "Lease Agreement",
182
- "Employment Contract",
183
- "Service Agreement",
184
- "Sales Agreement",
185
- "Consulting Agreement",
186
- "End User License Agreement (EULA)",
187
- "Terms of Service",
188
- ]
189
-
190
- def classify_document(text: str) -> str:
191
- system = "You are a legal document classifier. Choose the best matching document type."
192
- labels = "\n".join(f"- {t}" for t in DOC_TYPES)
193
- user = f"Classify the following document:\n{labels}\n\n{text[:5000]}"
194
- resp = llm_generate(system, user, max_new_tokens=200)
195
- scores = {t: (1.0 if t.lower() in resp.lower() else 0.0) for t in DOC_TYPES}
196
- best = max(scores.items(), key=lambda kv: kv[1])[0]
197
- if scores[best] == 0.0:
198
- lower = text.lower()
199
- if "confidential" in lower or "non-disclosure" in lower or "nda" in lower:
200
- best = "Non-Disclosure Agreement (NDA)"
201
- elif "lease" in lower or "tenant" in lower or "landlord" in lower:
202
- best = "Lease Agreement"
203
- elif "employment" in lower or "employee" in lower or "employer" in lower:
204
- best = "Employment Contract"
205
- elif "services" in lower or "service" in lower or "statement of work" in lower:
206
- best = "Service Agreement"
207
- return best
208
-
209
- # -------------------------
210
- # Negotiation Coach
211
- # -------------------------
212
- def negotiation_coach(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
213
- system = "You are an AI negotiation coach."
214
- user = (
215
- "Propose 3 alternative versions ranked by acceptance rate in JSON.\n"
216
- f"Clause:\n{clause}"
217
- )
218
- resp = llm_generate(system, user, max_new_tokens=700)
219
- data = None
220
- try:
221
- json_str = re.search(r"\{[\s\S]*\}", resp).group(0)
222
- data = json.loads(json_str)
223
- except:
224
- data = {"alternatives": []}
225
- alts = re.split(r"\n\s*\d+[.)]\s*", resp)
226
- for i, chunk in enumerate(alts[1:4], start=1):
227
- data["alternatives"].append({
228
- "rank": i,
229
- "acceptance_rate_percent": max(50, 90 - (i-1)*10),
230
- "title": f"Alternative {i}",
231
- "clause_text": chunk.strip()[:800],
232
- "rationale": "Heuristic parse from model response."
233
- })
234
- return json.dumps(data, indent=2), data.get("alternatives", [])
235
-
236
- # -------------------------
237
- # Future Risk Predictor
238
- # -------------------------
239
- def future_risk_predictor(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
240
- system = "Forecast future risks over 1–5 years."
241
- user = f"Analyze clause and return JSON timeline.\nClause:\n{clause}"
242
- resp = llm_generate(system, user, max_new_tokens=700)
243
- data = None
244
- try:
245
- json_str = re.search(r"\{[\s\S]*\}", resp).group(0)
246
- data = json.loads(json_str)
247
- except:
248
- data = {"timeline": []}
249
- for y in range(1,6):
250
- data["timeline"].append({
251
- "year": y,
252
- "risk_score_0_100": min(95, 40 + y*8),
253
- "key_risks": ["Heuristic timeline due to JSON parse fallback."],
254
- "mitigation": ["Seek legal review", "Adjust clause terms"]
255
- })
256
- return json.dumps(data, indent=2), data["timeline"]
257
-
258
- # -------------------------
259
- # Fairness Balance Meter
260
- # -------------------------
261
- def fairness_balance_meter(clause: str) -> Tuple[str,int,str]:
262
- system = "Evaluate clause fairness (0=Party A,50=balanced,100=Party B)."
263
- user = f"Return JSON: score_0_100 and rationale.\nClause:\n{clause}"
264
- resp = llm_generate(system, user, max_new_tokens=400)
265
- try:
266
- data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
267
- score = int(data.get("score_0_100", 50))
268
- rationale = data.get("rationale","")
269
- except:
270
- score,rationale=50,"Fallback balanced score."
271
- return json.dumps({"score_0_100":score,"rationale":rationale,"notes":[]}, indent=2), score, rationale
272
-
273
- # -------------------------
274
- # Clause Battle Arena
275
- # -------------------------
276
- def clause_battle_arena(text_a: str, text_b: str) -> Tuple[str,str]:
277
- system="Compare 2 contract drafts across categories."
278
- user=f"Compare Document A vs Document B and return JSON.\nA:\n{text_a[:4000]}\nB:\n{text_b[:4000]}"
279
- resp = llm_generate(system,user,max_new_tokens=900)
280
- try:
281
- data=json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
282
- except:
283
- data={"rounds":[{"category":c,"winner":"Draw","rationale":"Fallback"} for c in ["Liability","Termination","IP","Payment","Confidentiality","Governing Law"]],
284
- "overall_winner":"Draw","summary":"Fallback"}
285
- pretty=json.dumps(data, indent=2)
286
- rounds_md="\n".join([f"- {r['category']}: {r['winner']} — {r.get('rationale','')}" for r in data.get("rounds",[])])
287
- md=f"Overall Winner: {data.get('overall_winner','Draw')}\n\nRounds:\n{rounds_md}\n\nSummary:\n{data.get('summary','')}"
288
- return pretty,md
289
-
290
- # -------------------------
291
- # Sensitive Data Sniffer
292
- # -------------------------
293
- PII_REGEXES = {
294
- "Email": r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}",
295
- "Phone": r"\+?\d[\d\-\s]{7,}\d",
296
- "SSN (US)": r"\b\d{3}-\d{2}-\d{4}\b",
297
- "Credit Card": r"\b(?:\d[ -]*?){13,16}\b",
298
- }
299
-
300
- def sensitive_data_sniffer(text: str) -> Tuple[str, Dict[str,List[str]]]:
301
- system="Find hidden privacy traps and personal data."
302
- user=f"Return JSON.\nText:\n{text[:6000]}"
303
- resp=llm_generate(system,user,max_new_tokens=700)
304
- try:data=json.loads(re.search(r"\{[\s\S]*\}",resp).group(0))
305
- except:data={"data_categories":["Name","Email"],"sharing_parties":["Service Provider"],"processing_purposes":["Service delivery"],"risks":["Potential over-collection"],"recommendations":["Narrow purpose","Limit retention"]}
306
- regex_hits={}
307
- for label, pattern in PII_REGEXES.items():
308
- hits=re.findall(pattern,text or "",flags=re.IGNORECASE)
309
- if hits: regex_hits[label]=sorted(set([h.strip() for h in hits]))
310
- pretty=json.dumps({"llm":data,"regex_hits":regex_hits},indent=2)
311
- return pretty, regex_hits
312
-
313
- # -------------------------
314
- # Litigation Risk Radar
315
- # -------------------------
316
- def litigation_risk_radar(text:str)->Tuple[str,str]:
317
- clauses=split_into_clauses(text)
318
- sample="\n\n".join(clauses[:8]) if clauses else text[:4000]
319
- system="Identify clauses likely to trigger disputes."
320
- user=f"Return JSON of hotspots.\nClauses:\n{sample}"
321
- resp=llm_generate(system,user,max_new_tokens=900)
322
- try:data=json.loads(re.search(r"\{[\s\S]*\}",resp).group(0))
323
- except:data={"hotspots":[{"clause_excerpt":(clauses[0][:280] if clauses else text[:280]),"risk_level":"Medium","why":"Ambiguous obligations.","sample_dispute_scenario":"Party A alleges non-performance due to unclear milestones."}]}
324
- pretty=json.dumps(data,indent=2)
325
- md="\n".join([f"- [{h.get('risk_level','Medium')}] {h.get('clause_excerpt','')}\n Why: {h.get('why','')}\n Scenario: {h.get('sample_dispute_scenario','')}" for h in data.get("hotspots",[])])
326
- return pretty, md
327
-
328
- # -------------------------
329
- # STREAMLIT UI
330
- # -------------------------
331
- st.title("ClauseWise Granite 3.2 (2B) Legal Assistant")
332
- st.markdown("Upload a PDF/DOCX/TXT or paste text below. Tabs provide different legal analysis tools.")
333
-
334
- with st.sidebar:
335
- uploaded_file = st.file_uploader("Upload PDF/DOCX/TXT (optional)", type=["pdf","docx","txt"])
336
- pasted_text = st.text_area("Or paste text here", height=200)
337
-
338
- text_data = get_text_from_inputs(uploaded_file, pasted_text)
339
-
340
- tabs = st.tabs([
341
- "Clause Simplification","Named Entity Recognition","Clause Extraction",
342
- "Document Classification","Negotiation Coach","Future Risk Predictor",
343
- "Fairness Balance Meter","Clause Battle Arena","Sensitive Data Sniffer","Litigation Risk Radar"
344
- ])
345
-
346
- with tabs[0]:
347
- clause_input = st.text_area("Clause (optional)", height=150)
348
- if st.button("Simplify Clause", key="simplify"):
349
- target = clause_input.strip() or text_data
350
- st.text_area("Plain English Output", simplify_clause(target), height=250)
351
-
352
- with tabs[1]:
353
- if st.button("Run NER", key="ner"):
354
- st.json(ner_entities(text_data[:12000]))
355
-
356
- with tabs[2]:
357
- if st.button("Extract Clauses", key="extract"):
358
- clauses = extract_clauses(text_data)
359
- st.dataframe([[c] for c in clauses], columns=["Clause"])
360
-
361
- with tabs[3]:
362
- if st.button("Classify Document", key="classify"):
363
- st.text_area("Predicted Type", classify_document(text_data))
364
-
365
- with tabs[4]:
366
- negotiation_clause = st.text_area("Clause to Optimize", height=150)
367
- if st.button("Suggest Alternatives", key="negotiation"):
368
- pretty, alts = negotiation_coach(negotiation_clause.strip() or text_data)
369
- st.json(json.loads(pretty))
370
- table=[[a.get("rank",""),a.get("acceptance_rate_percent",""),a.get("title",""),a.get("clause_text",""),a.get("rationale","")] for a in alts]
371
- st.dataframe(table, columns=["Rank","Acceptance %","Title","Clause Text","Rationale"])
372
-
373
- with tabs[5]:
374
- risk_clause = st.text_area("Clause for Risk Prediction", height=150)
375
- if st.button("Predict 1–5 Year Risks", key="risk"):
376
- pretty, timeline = future_risk_predictor(risk_clause.strip() or text_data)
377
- st.json(json.loads(pretty))
378
- table=[[t.get("year",""),t.get("risk_score_0_100",""),"; ".join(t.get("key_risks",[])),"; ".join(t.get("mitigation",[]))] for t in timeline]
379
- st.dataframe(table, columns=["Year","Risk Score (0–100)","Key Risks","Mitigation"])
380
-
381
- with tabs[6]:
382
- fairness_clause = st.text_area("Clause", height=150)
383
- if st.button("Compute Fairness", key="fairness"):
384
- pretty, score, rationale = fairness_balance_meter(fairness_clause.strip() or text_data)
385
- st.json(json.loads(pretty))
386
- st.slider("Balance Score", min_value=0,max_value=100,value=score)
387
- st.text_area("Rationale / Notes", rationale, height=100)
388
-
389
- with tabs[7]:
390
- clause_a = st.text_area("Document A", height=150)
391
- clause_b = st.text_area("Document B", height=150)
392
- if st.button("Compare Clauses", key="battle"):
393
- pretty, md = clause_battle_arena(clause_a.strip() or text_data, clause_b.strip() or text_data)
394
- st.text_area("Battle JSON", pretty, height=300)
395
- st.markdown(md)
396
-
397
- with tabs[8]:
398
- if st.button("Scan for Sensitive Data", key="sensitive"):
399
- pretty, hits = sensitive_data_sniffer(text_data)
400
- st.text_area("Sensitive Data JSON", pretty, height=300)
401
- st.json(hits)
402
-
403
- with tabs[9]:
404
- if st.button("Identify Litigation Risk Hotspots", key="litigation"):
405
- pretty, md = litigation_risk_radar(text_data)
406
- st.text_area("Litigation JSON", pretty, height=300)
407
- st.markdown(md)
 
 
 
 
 
 
 
1
# Standard library
import io
import json
import math
import os
import re
import tempfile
from typing import Any, Dict, List, Optional, Tuple

# Third-party
import docx
import pandas as pd
import spacy
import streamlit as st
import torch
from pypdf import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
14
+
15
# -------------------------
# PAGE CONFIG
# -------------------------
# Must be the first Streamlit call executed in the script.
st.set_page_config(page_title="ClauseWise – Granite 3.2 (2B) Legal Assistant", page_icon="⚖️", layout="wide")
19
+
20
# -------------------------
# MODEL SETUP
# -------------------------
MODEL_ID = "ibm-granite/granite-3.2-2b-instruct"
# Prefer GPU when available; bfloat16 halves memory on CUDA, fp32 elsewhere.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

@st.cache_resource
def load_llm_model():
    """Load the Granite tokenizer and model once per Streamlit process.

    Cached with st.cache_resource so script reruns reuse the same model
    instance instead of reloading weights from disk.

    Returns:
        (tokenizer, model) tuple ready for generation.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        # device_map="auto" lets accelerate place layers on the GPU(s);
        # on CPU the model is moved explicitly below instead.
        device_map="auto" if DEVICE == "cuda" else None
    )
    if DEVICE != "cuda":
        model.to(DEVICE)
    return tokenizer, model

# Module-level globals used by the helpers below.
tokenizer, model = load_llm_model()
40
+
41
from spacy.cli import download as spacy_download

# Load the small English pipeline for NER; spacy.load raises OSError when the
# model package is not installed, in which case we download it once and retry.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    spacy_download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
48
+
49
+ # -------------------------
50
+ # HELPER FUNCTIONS
51
+ # -------------------------
52
def build_chat_prompt(system_prompt: str, user_prompt: str) -> str:
    """Render a system/user exchange as a single prompt string.

    Prefers the tokenizer's chat template; if the tokenizer has none (or
    applying it fails), falls back to a simple bracketed transcript format
    that llm_generate knows how to strip.
    """
    messages = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    messages.append({"role": "user", "content": user_prompt})
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        pieces = []
        if system_prompt:
            pieces.append(f"[SYSTEM]\n{system_prompt}\n")
        pieces.append(f"[USER]\n{user_prompt}\n[ASSISTANT]\n")
        return "".join(pieces)
63
+
64
def llm_generate(system_prompt: str, user_prompt: str, max_new_tokens=512, temperature=0.3, top_p=0.9) -> str:
    """Run one generation round against the Granite model.

    Args:
        system_prompt: System instruction (may be empty).
        user_prompt: User message.
        max_new_tokens / temperature / top_p: passed to HF ``generate``;
            sampling is always enabled, so temperature/top_p take effect.

    Returns:
        The assistant's reply text, with the prompt echo removed.
    """
    prompt = build_chat_prompt(system_prompt, user_prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            # No dedicated pad token on this model; reuse EOS to avoid warnings.
            pad_token_id=tokenizer.eos_token_id
        )
    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Fallback-format prompts embed an [ASSISTANT] marker; the reply follows it.
    if "[ASSISTANT]" in full_text:
        return full_text.split("[ASSISTANT]")[-1].strip()
    # Chat-template prompts: strip the echoed prompt prefix when present.
    if full_text.startswith(prompt):
        return full_text[len(prompt):].strip()
    return full_text.strip()
82
+
83
+ # -------------------------
84
+ # DOCUMENT LOADING
85
+ # -------------------------
86
def load_text_from_pdf(file_obj) -> str:
    """Extract plain text from every page of a PDF, joined by newlines.

    A page whose text extraction fails contributes an empty string rather
    than aborting the whole document.
    """
    extracted = []
    for page in PdfReader(file_obj).pages:
        try:
            page_text = page.extract_text() or ""
        except Exception:
            page_text = ""
        extracted.append(page_text)
    return "\n".join(extracted).strip()
95
+
96
def load_text_from_docx(file_obj) -> str:
    """Return the concatenated paragraph text of a DOCX upload."""
    raw = file_obj.read()
    # Rewind so the caller can re-read the upload afterwards.
    file_obj.seek(0)
    document = docx.Document(io.BytesIO(raw))
    return "\n".join(paragraph.text for paragraph in document.paragraphs).strip()
103
+
104
def load_text_from_txt(file_obj) -> str:
    """Decode an uploaded text file to a stripped str.

    Tries strict UTF-8 first; on failure falls back to latin-1, which maps
    every byte, so no content is silently dropped.

    FIX: the previous version decoded with errors="ignore", which never
    raises — the latin-1 branch was dead code and any non-UTF-8 bytes were
    silently discarded.
    """
    data = file_obj.read()
    if isinstance(data, bytes):
        try:
            data = data.decode("utf-8")
        except UnicodeDecodeError:
            data = data.decode("latin-1")
    return str(data).strip()
112
+
113
def load_document(file) -> str:
    """Best-effort text extraction from an uploaded file.

    Dispatches on the filename extension; for unknown extensions it probes
    PDF, then DOCX, then plain-text parsing. Returns "" when nothing can be
    extracted.

    FIX: the fallback chain now rewinds the stream before every probe.
    Previously a failed PDF parse left the file pointer at EOF, so the
    subsequent DOCX/TXT probes read an empty stream and silently returned "".
    """
    if not file:
        return ""
    name = (file.name or "").lower()
    if name.endswith(".pdf"):
        return load_text_from_pdf(file)
    if name.endswith(".docx"):
        return load_text_from_docx(file)
    if name.endswith(".txt"):
        return load_text_from_txt(file)
    # Unknown extension: try each parser in turn, most structured first.
    for parser in (load_text_from_pdf, load_text_from_docx, load_text_from_txt):
        try:
            file.seek(0)
            return parser(file)
        except Exception:
            continue
    return ""
137
+
138
def get_text_from_inputs(file, text):
    """Pick the richer of the two inputs: uploaded file vs pasted text.

    Whichever source yields the longer string wins; ties go to the pasted
    text (so an empty upload never shadows pasted content).
    """
    extracted = load_document(file) if file else ""
    pasted = (text or "").strip()
    if len(extracted) > len(pasted):
        return extracted
    return pasted
142
+
143
# -------------------------
# CLAUSE PROCESSING
# -------------------------
# Split on numbered headings ("1.", "2.3)", ...), lettered headings
# ("A.", "B)"), or clause-terminating newlines.
# FIX: the original quantifiers were wrong — "\d+(?:\.\d+)[.)]" required a
# sub-number (matched "1.1." but never "1."), and "^\s[A-Z]" required exactly
# one leading whitespace character before a lettered item.
CLAUSE_SPLIT_REGEX = re.compile(
    r"(?:(?:^\s*\d+(?:\.\d+)*[.)]\s+)|(?:^\s*[A-Z]\s*[.)]\s+)|(?:;?\s*\n))",
    re.MULTILINE,
)

def split_into_clauses(text: str, min_len: int = 40) -> List[str]:
    """Split a document into candidate clauses.

    Tries the heading/newline regex first; if that produces fewer than two
    parts, falls back to sentence-ish splitting after "."/";". Fragments
    shorter than ``min_len`` characters are dropped, and near-duplicates
    (same text modulo case and whitespace) are removed, preserving order.
    """
    if not text:
        return []
    parts = re.split(CLAUSE_SPLIT_REGEX, text)
    if len(parts) < 2:
        parts = re.split(r"(?<=[.;])\s+\n?\s*", text)
    clauses = [p.strip() for p in parts if len(p.strip()) >= min_len]
    seen = set()
    unique = []
    for c in clauses:
        key = re.sub(r"\s+", " ", c.lower())
        if key not in seen:
            seen.add(key)
            unique.append(c)
    return unique
163
+
164
def simplify_clause(clause: str) -> str:
    """Ask the LLM to restate a clause in plain English with risk bullets."""
    system_msg = "You are a legal assistant that rewrites clauses into plain English while preserving meaning."
    user_msg = (
        "Rewrite the following clause in plain English with bullet points for risks."
        f"\n\nClause:\n{clause}"
    )
    return llm_generate(system_msg, user_msg, max_new_tokens=400)
168
+
169
def ner_entities(text: str) -> Dict[str, List[str]]:
    """Run spaCy NER and group entity texts by label.

    Returns {label: sorted unique entity strings}; empty dict for empty input.
    """
    if not text:
        return {}
    grouped: Dict[str, set] = {}
    for ent in nlp(text).ents:
        grouped.setdefault(ent.label_, set()).add(ent.text)
    return {label: sorted(values) for label, values in grouped.items()}
178
+
179
def extract_clauses(text: str) -> List[str]:
    """Thin alias for split_into_clauses, kept as the tab-facing public name."""
    return split_into_clauses(text)
181
+
182
# -------------------------
# DOCUMENT CLASSIFICATION
# -------------------------
# Closed label set offered to the classifier prompt.  The keyword fallback in
# classify_document returns these exact strings, so keep them in sync.
DOC_TYPES = [
    "Non-Disclosure Agreement (NDA)",
    "Lease Agreement",
    "Employment Contract",
    "Service Agreement",
    "Sales Agreement",
    "Consulting Agreement",
    "End User License Agreement (EULA)",
    "Terms of Service",
]
195
+
196
def classify_document(text: str) -> str:
    """Classify a document into one of DOC_TYPES.

    Asks the LLM first and returns the first known label mentioned in its
    reply (in DOC_TYPES order).  When no label appears, falls back to simple
    keyword heuristics over the raw text; the default is the first label.
    """
    system = "You are a legal document classifier. Choose the best matching document type."
    labels = "\n".join(f"- {t}" for t in DOC_TYPES)
    user = f"Classify the following document:\n{labels}\n\n{text[:5000]}"
    reply = llm_generate(system, user, max_new_tokens=200).lower()
    for doc_type in DOC_TYPES:
        if doc_type.lower() in reply:
            return doc_type
    # Keyword heuristics when the model reply names no known label.
    lower = text.lower()
    if "confidential" in lower or "non-disclosure" in lower or "nda" in lower:
        return "Non-Disclosure Agreement (NDA)"
    if "lease" in lower or "tenant" in lower or "landlord" in lower:
        return "Lease Agreement"
    if "employment" in lower or "employee" in lower or "employer" in lower:
        return "Employment Contract"
    if "services" in lower or "service" in lower or "statement of work" in lower:
        return "Service Agreement"
    return DOC_TYPES[0]
214
+
215
# -------------------------
# Negotiation Coach
# -------------------------
def negotiation_coach(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Suggest up to three alternative clause drafts ranked by acceptance.

    Returns (pretty_json, alternatives).  If the model reply contains no
    parseable JSON object, falls back to heuristically splitting the reply
    on numbered list markers and synthesising alternative entries.

    FIX: previously a valid JSON reply that was not an object (e.g. a bare
    list) slipped past the try/except and crashed on ``data.get``; the bare
    except also swallowed unrelated errors.
    """
    system = "You are an AI negotiation coach."
    user = (
        "Propose 3 alternative versions ranked by acceptance rate in JSON.\n"
        f"Clause:\n{clause}"
    )
    resp = llm_generate(system, user, max_new_tokens=700)
    data = None
    match = re.search(r"\{[\s\S]*\}", resp)
    if match:
        try:
            parsed = json.loads(match.group(0))
            if isinstance(parsed, dict):
                data = parsed
        except json.JSONDecodeError:
            data = None
    if data is None:
        # Heuristic fallback: carve alternatives out of a numbered list.
        data = {"alternatives": []}
        alts = re.split(r"\n\s*\d+[.)]\s*", resp)
        for i, chunk in enumerate(alts[1:4], start=1):
            data["alternatives"].append({
                "rank": i,
                "acceptance_rate_percent": max(50, 90 - (i - 1) * 10),
                "title": f"Alternative {i}",
                "clause_text": chunk.strip()[:800],
                "rationale": "Heuristic parse from model response."
            })
    return json.dumps(data, indent=2), data.get("alternatives", [])
241
+
242
# -------------------------
# Future Risk Predictor
# -------------------------
def future_risk_predictor(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Forecast clause risk over a 1–5 year horizon.

    Returns (pretty_json, timeline).  If the model reply cannot be parsed
    as a JSON object, a heuristic five-year timeline is synthesised.

    FIX: previously returned ``data["timeline"]``, which raised KeyError
    whenever the model produced well-formed JSON without that key; non-dict
    JSON also slipped past the bare except.
    """
    system = "Forecast future risks over 1–5 years."
    user = f"Analyze clause and return JSON timeline.\nClause:\n{clause}"
    resp = llm_generate(system, user, max_new_tokens=700)
    data = None
    match = re.search(r"\{[\s\S]*\}", resp)
    if match:
        try:
            parsed = json.loads(match.group(0))
            if isinstance(parsed, dict):
                data = parsed
        except json.JSONDecodeError:
            data = None
    if data is None:
        data = {"timeline": []}
        for y in range(1, 6):
            data["timeline"].append({
                "year": y,
                "risk_score_0_100": min(95, 40 + y * 8),
                "key_risks": ["Heuristic timeline due to JSON parse fallback."],
                "mitigation": ["Seek legal review", "Adjust clause terms"]
            })
    timeline = data.get("timeline", [])
    if not isinstance(timeline, list):
        timeline = []
    return json.dumps(data, indent=2), timeline
263
+
264
# -------------------------
# Fairness Balance Meter
# -------------------------
def fairness_balance_meter(clause: str) -> Tuple[str, int, str]:
    """Score how balanced a clause is between the parties.

    Returns (pretty_json, score, rationale), with score clamped to 0–100
    (0 = favors Party A, 50 = balanced, 100 = favors Party B).

    FIX: the clamp matters — the UI feeds the score into
    ``st.slider(min_value=0, max_value=100, value=score)``, which raises if
    the model returns an out-of-range number.  The bare except is also
    narrowed to the failures the parse can actually produce.
    """
    system = "Evaluate clause fairness (0=Party A,50=balanced,100=Party B)."
    user = f"Return JSON: score_0_100 and rationale.\nClause:\n{clause}"
    resp = llm_generate(system, user, max_new_tokens=400)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
        score = int(data.get("score_0_100", 50))
        rationale = data.get("rationale", "")
    except (AttributeError, TypeError, ValueError, json.JSONDecodeError):
        # AttributeError: no JSON object found (re.search -> None) or
        # non-dict JSON; TypeError/ValueError: score not int-convertible.
        score, rationale = 50, "Fallback balanced score."
    score = max(0, min(100, score))
    return json.dumps({"score_0_100": score, "rationale": rationale, "notes": []}, indent=2), score, rationale
278
+
279
# -------------------------
# Clause Battle Arena
# -------------------------
def clause_battle_arena(text_a: str, text_b: str) -> Tuple[str, str]:
    """Compare two contract drafts category by category.

    Returns (pretty_json, markdown_summary).  Falls back to an all-"Draw"
    scorecard when the model reply is not a parseable JSON object.

    FIX: rounds produced by the model may omit keys; the markdown builder
    previously used ``r['category']`` / ``r['winner']`` direct indexing and
    raised KeyError on such rounds.  All lookups now use ``.get``.
    """
    system = "Compare 2 contract drafts across categories."
    user = f"Compare Document A vs Document B and return JSON.\nA:\n{text_a[:4000]}\nB:\n{text_b[:4000]}"
    resp = llm_generate(system, user, max_new_tokens=900)
    data = None
    match = re.search(r"\{[\s\S]*\}", resp)
    if match:
        try:
            parsed = json.loads(match.group(0))
            if isinstance(parsed, dict):
                data = parsed
        except json.JSONDecodeError:
            data = None
    if data is None:
        categories = ["Liability", "Termination", "IP", "Payment", "Confidentiality", "Governing Law"]
        data = {
            "rounds": [{"category": c, "winner": "Draw", "rationale": "Fallback"} for c in categories],
            "overall_winner": "Draw",
            "summary": "Fallback",
        }
    pretty = json.dumps(data, indent=2)
    rounds_md = "\n".join(
        f"- {r.get('category', '?')}: {r.get('winner', 'Draw')} — {r.get('rationale', '')}"
        for r in data.get("rounds", [])
        if isinstance(r, dict)
    )
    md = f"Overall Winner: {data.get('overall_winner', 'Draw')}\n\nRounds:\n{rounds_md}\n\nSummary:\n{data.get('summary', '')}"
    return pretty, md
295
+
296
# -------------------------
# Sensitive Data Sniffer
# -------------------------
# Regex heuristics for common PII categories.  NOTE(review): the credit-card
# pattern is deliberately loose (13–16 digits with separators, no Luhn check)
# and will also match long phone numbers — treat hits as advisory only.
PII_REGEXES = {
    "Email": r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}",
    "Phone": r"\+?\d[\d\-\s]{7,}\d",
    "SSN (US)": r"\b\d{3}-\d{2}-\d{4}\b",
    "Credit Card": r"\b(?:\d[ -]*?){13,16}\b",
}
305
+
306
def sensitive_data_sniffer(text: str) -> Tuple[str, Dict[str, List[str]]]:
    """Scan text for personal data using the LLM plus regex heuristics.

    Returns (pretty_json, regex_hits): pretty_json bundles the LLM findings
    with the regex hits; regex_hits maps PII label -> sorted unique matches.

    FIX: the bare ``except`` is narrowed to the two failures the parse can
    actually produce (no JSON object found, or malformed JSON), so it no
    longer swallows KeyboardInterrupt and unrelated bugs.
    """
    system = "Find hidden privacy traps and personal data."
    user = f"Return JSON.\nText:\n{text[:6000]}"
    resp = llm_generate(system, user, max_new_tokens=700)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, json.JSONDecodeError):
        data = {
            "data_categories": ["Name", "Email"],
            "sharing_parties": ["Service Provider"],
            "processing_purposes": ["Service delivery"],
            "risks": ["Potential over-collection"],
            "recommendations": ["Narrow purpose", "Limit retention"],
        }
    regex_hits: Dict[str, List[str]] = {}
    for label, pattern in PII_REGEXES.items():
        hits = re.findall(pattern, text or "", flags=re.IGNORECASE)
        if hits:
            regex_hits[label] = sorted({h.strip() for h in hits})
    return json.dumps({"llm": data, "regex_hits": regex_hits}, indent=2), regex_hits
318
+
319
# -------------------------
# Litigation Risk Radar
# -------------------------
def litigation_risk_radar(text: str) -> Tuple[str, str]:
    """Flag clauses likely to trigger disputes.

    Returns (pretty_json, markdown) where markdown lists each hotspot with
    its risk level, rationale, and a sample dispute scenario.  Falls back to
    a single generic "Medium" hotspot when the model reply is not a
    parseable JSON object.

    FIX: narrows the bare ``except`` and guards against valid-but-non-dict
    JSON, which previously crashed on ``data.get``.
    """
    clauses = split_into_clauses(text)
    sample = "\n\n".join(clauses[:8]) if clauses else text[:4000]
    system = "Identify clauses likely to trigger disputes."
    user = f"Return JSON of hotspots.\nClauses:\n{sample}"
    resp = llm_generate(system, user, max_new_tokens=900)
    data = None
    match = re.search(r"\{[\s\S]*\}", resp)
    if match:
        try:
            parsed = json.loads(match.group(0))
            if isinstance(parsed, dict):
                data = parsed
        except json.JSONDecodeError:
            data = None
    if data is None:
        # Heuristic fallback hotspot built from the first clause (or raw text).
        data = {"hotspots": [{
            "clause_excerpt": (clauses[0][:280] if clauses else text[:280]),
            "risk_level": "Medium",
            "why": "Ambiguous obligations.",
            "sample_dispute_scenario": "Party A alleges non-performance due to unclear milestones.",
        }]}
    pretty = json.dumps(data, indent=2)
    md = "\n".join(
        f"- [{h.get('risk_level', 'Medium')}] {h.get('clause_excerpt', '')}\n"
        f"  Why: {h.get('why', '')}\n"
        f"  Scenario: {h.get('sample_dispute_scenario', '')}"
        for h in data.get("hotspots", [])
        if isinstance(h, dict)
    )
    return pretty, md
333
+
334
# -------------------------
# STREAMLIT UI
# -------------------------
st.title("ClauseWise – Granite 3.2 (2B) Legal Assistant")
st.markdown("Upload a PDF/DOCX/TXT or paste text below. Tabs provide different legal analysis tools.")

with st.sidebar:
    uploaded_file = st.file_uploader("Upload PDF/DOCX/TXT (optional)", type=["pdf", "docx", "txt"])
    pasted_text = st.text_area("Or paste text here", height=200)

# Shared document text: the longer of the upload and the pasted text.
text_data = get_text_from_inputs(uploaded_file, pasted_text)

tabs = st.tabs([
    "Clause Simplification", "Named Entity Recognition", "Clause Extraction",
    "Document Classification", "Negotiation Coach", "Future Risk Predictor",
    "Fairness Balance Meter", "Clause Battle Arena", "Sensitive Data Sniffer", "Litigation Risk Radar"
])

with tabs[0]:
    clause_input = st.text_area("Clause (optional)", height=150)
    if st.button("Simplify Clause", key="simplify"):
        target = clause_input.strip() or text_data
        st.text_area("Plain English Output", simplify_clause(target), height=250)

with tabs[1]:
    if st.button("Run NER", key="ner"):
        st.json(ner_entities(text_data[:12000]))

with tabs[2]:
    if st.button("Extract Clauses", key="extract"):
        clauses = extract_clauses(text_data)
        # FIX: st.dataframe has no `columns` keyword — the previous call
        # st.dataframe([[c] for c in clauses], columns=["Clause"]) raised a
        # TypeError.  Build a pandas DataFrame to get the column header.
        st.dataframe(pd.DataFrame({"Clause": clauses}))

with tabs[3]:
    if st.button("Classify Document", key="classify"):
        st.text_area("Predicted Type", classify_document(text_data))

with tabs[4]:
    negotiation_clause = st.text_area("Clause to Optimize", height=150)
    if st.button("Suggest Alternatives", key="negotiation"):
        pretty, alts = negotiation_coach(negotiation_clause.strip() or text_data)
        st.json(json.loads(pretty))
        table = [[a.get("rank", ""), a.get("acceptance_rate_percent", ""), a.get("title", ""),
                  a.get("clause_text", ""), a.get("rationale", "")] for a in alts]
        # FIX: same st.dataframe `columns` misuse as above.
        st.dataframe(pd.DataFrame(table, columns=["Rank", "Acceptance %", "Title", "Clause Text", "Rationale"]))

with tabs[5]:
    risk_clause = st.text_area("Clause for Risk Prediction", height=150)
    if st.button("Predict 1–5 Year Risks", key="risk"):
        pretty, timeline = future_risk_predictor(risk_clause.strip() or text_data)
        st.json(json.loads(pretty))
        table = [[t.get("year", ""), t.get("risk_score_0_100", ""), "; ".join(t.get("key_risks", [])),
                  "; ".join(t.get("mitigation", []))] for t in timeline]
        # FIX: same st.dataframe `columns` misuse as above.
        st.dataframe(pd.DataFrame(table, columns=["Year", "Risk Score (0–100)", "Key Risks", "Mitigation"]))

with tabs[6]:
    fairness_clause = st.text_area("Clause", height=150)
    if st.button("Compute Fairness", key="fairness"):
        pretty, score, rationale = fairness_balance_meter(fairness_clause.strip() or text_data)
        st.json(json.loads(pretty))
        st.slider("Balance Score", min_value=0, max_value=100, value=score)
        st.text_area("Rationale / Notes", rationale, height=100)

with tabs[7]:
    clause_a = st.text_area("Document A", height=150)
    clause_b = st.text_area("Document B", height=150)
    if st.button("Compare Clauses", key="battle"):
        pretty, md = clause_battle_arena(clause_a.strip() or text_data, clause_b.strip() or text_data)
        st.text_area("Battle JSON", pretty, height=300)
        st.markdown(md)

with tabs[8]:
    if st.button("Scan for Sensitive Data", key="sensitive"):
        pretty, hits = sensitive_data_sniffer(text_data)
        st.text_area("Sensitive Data JSON", pretty, height=300)
        st.json(hits)

with tabs[9]:
    if st.button("Identify Litigation Risk Hotspots", key="litigation"):
        pretty, md = litigation_risk_radar(text_data)
        st.text_area("Litigation JSON", pretty, height=300)
        st.markdown(md)