bhoomi19 commited on
Commit
0d9ba1e
·
verified ·
1 Parent(s): d335737

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +405 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,407 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import json
import math
import os
import re
import tempfile
from typing import List, Dict, Tuple, Any, Optional

import docx
import pandas as pd
import spacy
import streamlit as st
import torch
from pypdf import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
14
 
15
# -------------------------
# PAGE CONFIG
# -------------------------
# Must be the first Streamlit call executed in the script.
st.set_page_config(page_title="ClauseWise Granite 3.2 (2B) Legal Assistant", page_icon="⚖️", layout="wide")

# -------------------------
# MODEL SETUP
# -------------------------
# Hugging Face model id for the instruction-tuned Granite 3.2 2B model.
MODEL_ID = "ibm-granite/granite-3.2-2b-instruct"
# Prefer GPU when available; bfloat16 halves memory on CUDA, float32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
26
+
27
@st.cache_resource
def load_llm_model():
    """Load and cache the Granite tokenizer/model once per Streamlit session.

    Returns:
        (tokenizer, model) tuple. On CUDA, weights are dispatched with
        device_map="auto"; on CPU the model is moved explicitly.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        device_map="auto" if DEVICE == "cuda" else None
    )
    if DEVICE != "cuda":
        model.to(DEVICE)
    return tokenizer, model

# Loaded at import time; st.cache_resource makes script reruns cheap.
tokenizer, model = load_llm_model()

# spaCy pipeline used for NER; the en_core_web_sm model must be installed.
nlp = spacy.load("en_core_web_sm")
42
+
43
+ # -------------------------
44
+ # HELPER FUNCTIONS
45
+ # -------------------------
46
def build_chat_prompt(system_prompt: str, user_prompt: str) -> str:
    """Render a single-turn chat prompt, preferring the tokenizer's template."""
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.append({"role": "user", "content": user_prompt})
    try:
        return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Plain-text fallback; llm_generate strips everything up to [ASSISTANT].
        prefix = f"[SYSTEM]\n{system_prompt}\n" if system_prompt else ""
        return prefix + f"[USER]\n{user_prompt}\n[ASSISTANT]\n"
57
+
58
def llm_generate(system_prompt: str, user_prompt: str, max_new_tokens=512, temperature=0.3, top_p=0.9) -> str:
    """Run one sampled generation and return only the assistant's reply text."""
    prompt = build_chat_prompt(system_prompt, user_prompt)
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Strip the echoed prompt: fallback-template marker first, then prefix match.
    if "[ASSISTANT]" in decoded:
        return decoded.split("[ASSISTANT]")[-1].strip()
    if decoded.startswith(prompt):
        return decoded[len(prompt):].strip()
    return decoded.strip()
76
+
77
+ # -------------------------
78
+ # DOCUMENT LOADING
79
+ # -------------------------
80
def load_text_from_pdf(file_obj) -> str:
    """Extract text from every page of a PDF; a page that fails yields ''."""
    extracted = []
    for page in PdfReader(file_obj).pages:
        try:
            extracted.append(page.extract_text() or "")
        except Exception:
            extracted.append("")
    return "\n".join(extracted).strip()
89
+
90
def load_text_from_docx(file_obj) -> str:
    """Read all paragraph text from a .docx upload, leaving the stream rewound."""
    raw = file_obj.read()
    file_obj.seek(0)  # keep the upload reusable for other readers
    document = docx.Document(io.BytesIO(raw))
    return "\n".join(paragraph.text for paragraph in document.paragraphs).strip()
97
+
98
def load_text_from_txt(file_obj) -> str:
    """Decode an uploaded text file to a stripped string.

    Bytes are decoded as strict UTF-8 first, falling back to latin-1
    (which accepts any byte sequence) on a decode error.  The original
    bare ``except`` around ``decode("utf-8", errors="ignore")`` was dead
    code: with errors="ignore" that call can never raise.
    """
    data = file_obj.read()
    if isinstance(data, bytes):
        try:
            data = data.decode("utf-8")
        except UnicodeDecodeError:
            data = data.decode("latin-1")
    return str(data).strip()
106
+
107
def load_document(file) -> str:
    """Extract text from an uploaded file, dispatching on extension.

    Unknown extensions fall back to trying PDF, then DOCX, then plain
    text.  The stream is rewound before each fallback attempt: a failed
    parser may have consumed the stream, which previously left every
    subsequent attempt reading empty input.

    Returns "" when the file is missing or no loader succeeds.
    """
    if not file:
        return ""
    name = (file.name or "").lower()
    if name.endswith(".pdf"):
        return load_text_from_pdf(file)
    elif name.endswith(".docx"):
        return load_text_from_docx(file)
    elif name.endswith(".txt"):
        return load_text_from_txt(file)
    else:
        for loader in (load_text_from_pdf, load_text_from_docx, load_text_from_txt):
            try:
                file.seek(0)  # rewind: the previous loader may have read to EOF
                return loader(file)
            except Exception:
                continue
        return ""
131
+
132
def get_text_from_inputs(file, text):
    """Prefer whichever source yields more text: the upload or the paste box."""
    uploaded = load_document(file) if file else ""
    pasted = (text or "").strip()
    return uploaded if len(uploaded) > len(pasted) else pasted
136
+
137
# -------------------------
# CLAUSE PROCESSING
# -------------------------
# Split points: "1." / "2.3)" numbered headings, "A." / "B)" lettered
# headings, or a line break (optionally preceded by ";").  Fixes vs. the
# original pattern: "(?:\.\d+)" required a dotted sub-number, so a plain
# "1. " heading never matched; "^\s" required exactly one leading space
# before a lettered heading instead of any (or no) indentation.
CLAUSE_SPLIT_REGEX = re.compile(
    r"(?:(?:^\s*\d+(?:\.\d+)*[.)]\s+)|(?:^\s*[A-Z]\s*[.)]\s+)|(?:;?\s*\n))",
    re.MULTILINE,
)

def split_into_clauses(text: str, min_len: int = 40) -> List[str]:
    """Split a document into deduplicated clause strings.

    Args:
        text: Full document text.
        min_len: Minimum clause length to keep (filters headings/noise).

    Returns:
        Clauses in document order; whitespace-normalized duplicates removed.
    """
    if not text:
        return []
    parts = re.split(CLAUSE_SPLIT_REGEX, text)
    if len(parts) < 2:
        # No structural markers found: fall back to sentence-ish boundaries.
        parts = re.split(r"(?<=[.;])\s+\n?\s*", text)
    clauses = [p.strip() for p in parts if len(p.strip()) >= min_len]
    seen = set()
    unique = []
    for c in clauses:
        key = re.sub(r"\s+", " ", c.lower())
        if key not in seen:
            seen.add(key)
            unique.append(c)
    return unique
157
+
158
def simplify_clause(clause: str) -> str:
    """Ask the LLM for a plain-English rewrite of one clause (risks as bullets)."""
    sys_msg = "You are a legal assistant that rewrites clauses into plain English while preserving meaning."
    usr_msg = f"Rewrite the following clause in plain English with bullet points for risks.\n\nClause:\n{clause}"
    return llm_generate(sys_msg, usr_msg, max_new_tokens=400)
162
+
163
def ner_entities(text: str) -> Dict[str, List[str]]:
    """Run spaCy NER and group entity texts by label (sorted, deduplicated)."""
    if not text:
        return {}
    grouped: Dict[str, List[str]] = {}
    for ent in nlp(text).ents:
        grouped.setdefault(ent.label_, []).append(ent.text)
    return {label: sorted(set(values)) for label, values in grouped.items()}
172
+
173
def extract_clauses(text: str) -> List[str]:
    """Thin alias over split_into_clauses for the Clause Extraction tab."""
    return split_into_clauses(text)
175
+
176
# -------------------------
# DOCUMENT CLASSIFICATION
# -------------------------
# Closed label set offered to the classifier prompt; classify_document
# falls back to keyword heuristics when none of these appear in the reply.
DOC_TYPES = [
    "Non-Disclosure Agreement (NDA)",
    "Lease Agreement",
    "Employment Contract",
    "Service Agreement",
    "Sales Agreement",
    "Consulting Agreement",
    "End User License Agreement (EULA)",
    "Terms of Service",
]
189
+
190
def classify_document(text: str) -> str:
    """Classify a document into one of DOC_TYPES.

    The LLM reply is scanned for a known label (first match in DOC_TYPES
    order — the same tie-break the original max() produced).  If the
    model mentioned no label, keyword heuristics on the document text
    decide, defaulting to the first entry in DOC_TYPES.
    """
    system = "You are a legal document classifier. Choose the best matching document type."
    labels = "\n".join(f"- {t}" for t in DOC_TYPES)
    user = f"Classify the following document:\n{labels}\n\n{text[:5000]}"
    resp = llm_generate(system, user, max_new_tokens=200).lower()
    best = next((t for t in DOC_TYPES if t.lower() in resp), None)
    if best is None:
        best = DOC_TYPES[0]
        lower = text.lower()
        if "confidential" in lower or "non-disclosure" in lower or "nda" in lower:
            best = "Non-Disclosure Agreement (NDA)"
        elif "lease" in lower or "tenant" in lower or "landlord" in lower:
            best = "Lease Agreement"
        elif "employment" in lower or "employee" in lower or "employer" in lower:
            best = "Employment Contract"
        elif "service" in lower or "statement of work" in lower:
            # "service" subsumes the original's redundant extra "services" test.
            best = "Service Agreement"
    return best
208
+
209
+ # -------------------------
210
+ # Negotiation Coach
211
+ # -------------------------
212
def negotiation_coach(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Ask the LLM for ranked alternative clause wordings.

    Returns:
        (pretty_json, alternatives) where alternatives is a list of dicts.
        When the reply contains no parseable JSON object, numbered list
        items ("1. ... 2. ...") are harvested heuristically instead.
    """
    system = "You are an AI negotiation coach."
    user = (
        "Propose 3 alternative versions ranked by acceptance rate in JSON.\n"
        f"Clause:\n{clause}"
    )
    resp = llm_generate(system, user, max_new_tokens=700)
    data = None
    try:
        match = re.search(r"\{[\s\S]*\}", resp)
        data = json.loads(match.group(0)) if match else None
    except ValueError:  # json.JSONDecodeError is a ValueError
        data = None
    if not isinstance(data, dict):
        # Fallback: harvest numbered chunks from the raw model text.
        data = {"alternatives": []}
        alts = re.split(r"\n\s*\d+[.)]\s*", resp)
        for i, chunk in enumerate(alts[1:4], start=1):
            data["alternatives"].append({
                "rank": i,
                "acceptance_rate_percent": max(50, 90 - (i - 1) * 10),
                "title": f"Alternative {i}",
                "clause_text": chunk.strip()[:800],
                "rationale": "Heuristic parse from model response.",
            })
    # Valid JSON without the expected key still yields a well-formed result.
    data.setdefault("alternatives", [])
    return json.dumps(data, indent=2), data.get("alternatives", [])
235
+
236
+ # -------------------------
237
+ # Future Risk Predictor
238
+ # -------------------------
239
def future_risk_predictor(clause: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Forecast 1-5 year risks for a clause as a JSON timeline.

    Returns:
        (pretty_json, timeline).  A heuristic timeline is synthesized when
        the reply has no parseable JSON.  A parsed object missing the
        "timeline" key no longer raises KeyError (original bug); it now
        yields an empty timeline.
    """
    system = "Forecast future risks over 1–5 years."
    user = f"Analyze clause and return JSON timeline.\nClause:\n{clause}"
    resp = llm_generate(system, user, max_new_tokens=700)
    data = None
    try:
        match = re.search(r"\{[\s\S]*\}", resp)
        data = json.loads(match.group(0)) if match else None
    except ValueError:
        data = None
    if not isinstance(data, dict):
        # Fallback: fabricate an escalating 5-year heuristic timeline.
        data = {"timeline": []}
        for y in range(1, 6):
            data["timeline"].append({
                "year": y,
                "risk_score_0_100": min(95, 40 + y * 8),
                "key_risks": ["Heuristic timeline due to JSON parse fallback."],
                "mitigation": ["Seek legal review", "Adjust clause terms"],
            })
    data.setdefault("timeline", [])
    return json.dumps(data, indent=2), data["timeline"]
257
+
258
+ # -------------------------
259
+ # Fairness Balance Meter
260
+ # -------------------------
261
def fairness_balance_meter(clause: str) -> Tuple[str, int, str]:
    """Score how one-sided a clause is (0=Party A, 50=balanced, 100=Party B).

    Returns:
        (pretty_json, score, rationale).  The score is clamped to 0-100
        (the UI slider's range) and defaults to a balanced 50 whenever the
        model reply cannot be parsed.
    """
    system = "Evaluate clause fairness (0=Party A,50=balanced,100=Party B)."
    user = f"Return JSON: score_0_100 and rationale.\nClause:\n{clause}"
    resp = llm_generate(system, user, max_new_tokens=400)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
        score = int(data.get("score_0_100", 50))
        rationale = data.get("rationale", "")
    except (AttributeError, TypeError, ValueError):
        # AttributeError: no JSON found; ValueError: bad JSON or non-int score.
        score, rationale = 50, "Fallback balanced score."
    score = max(0, min(100, score))  # keep the downstream slider in range
    return json.dumps({"score_0_100": score, "rationale": rationale, "notes": []}, indent=2), score, rationale
272
+
273
+ # -------------------------
274
+ # Clause Battle Arena
275
+ # -------------------------
276
def clause_battle_arena(text_a: str, text_b: str) -> Tuple[str, str]:
    """Compare two contract drafts category-by-category via the LLM.

    Returns:
        (pretty_json, markdown_summary).  Falls back to an all-"Draw"
        scorecard when the reply has no parseable JSON; malformed round
        entries no longer raise KeyError on missing "category"/"winner".
    """
    system = "Compare 2 contract drafts across categories."
    user = f"Compare Document A vs Document B and return JSON.\nA:\n{text_a[:4000]}\nB:\n{text_b[:4000]}"
    resp = llm_generate(system, user, max_new_tokens=900)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, ValueError):
        categories = ["Liability", "Termination", "IP", "Payment", "Confidentiality", "Governing Law"]
        data = {
            "rounds": [{"category": c, "winner": "Draw", "rationale": "Fallback"} for c in categories],
            "overall_winner": "Draw",
            "summary": "Fallback",
        }
    pretty = json.dumps(data, indent=2)
    rounds_md = "\n".join(
        f"- {r.get('category', '?')}: {r.get('winner', 'Draw')} — {r.get('rationale', '')}"
        for r in data.get("rounds", [])
    )
    md = f"Overall Winner: {data.get('overall_winner','Draw')}\n\nRounds:\n{rounds_md}\n\nSummary:\n{data.get('summary','')}"
    return pretty, md
289
+
290
# -------------------------
# Sensitive Data Sniffer
# -------------------------
# Quick PII patterns applied alongside the LLM scan.  NOTE(review): these
# are deliberately loose — "Phone" matches most long digit runs and
# "Credit Card" matches any 13-16 digit sequence with no Luhn check, so
# expect false positives; treat hits as leads, not confirmations.
PII_REGEXES = {
    "Email": r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}",
    "Phone": r"\+?\d[\d\-\s]{7,}\d",
    "SSN (US)": r"\b\d{3}-\d{2}-\d{4}\b",
    "Credit Card": r"\b(?:\d[ -]*?){13,16}\b",
}
299
+
300
def sensitive_data_sniffer(text: str) -> Tuple[str, Dict[str, List[str]]]:
    """Scan text for privacy issues (LLM) and literal PII hits (regex).

    Returns:
        (pretty_json, regex_hits) where regex_hits maps a PII label from
        PII_REGEXES to the sorted unique matches found in the text.
    """
    system = "Find hidden privacy traps and personal data."
    user = f"Return JSON.\nText:\n{text[:6000]}"
    resp = llm_generate(system, user, max_new_tokens=700)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, ValueError):
        # Generic placeholder findings when the reply isn't valid JSON.
        data = {
            "data_categories": ["Name", "Email"],
            "sharing_parties": ["Service Provider"],
            "processing_purposes": ["Service delivery"],
            "risks": ["Potential over-collection"],
            "recommendations": ["Narrow purpose", "Limit retention"],
        }
    regex_hits = {}
    for label, pattern in PII_REGEXES.items():
        hits = re.findall(pattern, text or "", flags=re.IGNORECASE)
        if hits:
            regex_hits[label] = sorted({h.strip() for h in hits})
    pretty = json.dumps({"llm": data, "regex_hits": regex_hits}, indent=2)
    return pretty, regex_hits
312
+
313
# -------------------------
# Litigation Risk Radar
# -------------------------
def litigation_risk_radar(text: str) -> Tuple[str, str]:
    """Flag clauses most likely to trigger disputes.

    Returns:
        (pretty_json, markdown_list) of dispute "hotspots"; one heuristic
        hotspot is synthesized when the reply isn't valid JSON.
    """
    clauses = split_into_clauses(text)
    sample = "\n\n".join(clauses[:8]) if clauses else text[:4000]
    system = "Identify clauses likely to trigger disputes."
    user = f"Return JSON of hotspots.\nClauses:\n{sample}"
    resp = llm_generate(system, user, max_new_tokens=900)
    try:
        data = json.loads(re.search(r"\{[\s\S]*\}", resp).group(0))
    except (AttributeError, ValueError):
        data = {"hotspots": [{
            "clause_excerpt": (clauses[0][:280] if clauses else text[:280]),
            "risk_level": "Medium",
            "why": "Ambiguous obligations.",
            "sample_dispute_scenario": "Party A alleges non-performance due to unclear milestones.",
        }]}
    pretty = json.dumps(data, indent=2)
    md = "\n".join(
        f"- [{h.get('risk_level','Medium')}] {h.get('clause_excerpt','')}\n  Why: {h.get('why','')}\n  Scenario: {h.get('sample_dispute_scenario','')}"
        for h in data.get("hotspots", [])
    )
    return pretty, md
327
+
328
# -------------------------
# STREAMLIT UI
# -------------------------
st.title("ClauseWise – Granite 3.2 (2B) Legal Assistant")
st.markdown("Upload a PDF/DOCX/TXT or paste text below. Tabs provide different legal analysis tools.")

with st.sidebar:
    uploaded_file = st.file_uploader("Upload PDF/DOCX/TXT (optional)", type=["pdf", "docx", "txt"])
    pasted_text = st.text_area("Or paste text here", height=200)

# Whichever source has more text wins (upload vs. paste box).
text_data = get_text_from_inputs(uploaded_file, pasted_text)

tabs = st.tabs([
    "Clause Simplification", "Named Entity Recognition", "Clause Extraction",
    "Document Classification", "Negotiation Coach", "Future Risk Predictor",
    "Fairness Balance Meter", "Clause Battle Arena", "Sensitive Data Sniffer", "Litigation Risk Radar"
])

with tabs[0]:
    clause_input = st.text_area("Clause (optional)", height=150)
    if st.button("Simplify Clause", key="simplify"):
        target = clause_input.strip() or text_data
        st.text_area("Plain English Output", simplify_clause(target), height=250)

with tabs[1]:
    if st.button("Run NER", key="ner"):
        st.json(ner_entities(text_data[:12000]))

with tabs[2]:
    if st.button("Extract Clauses", key="extract"):
        clauses = extract_clauses(text_data)
        # st.dataframe has no `columns` keyword (original bug raised a
        # TypeError) — build a labelled DataFrame instead.
        st.dataframe(pd.DataFrame({"Clause": clauses}))

with tabs[3]:
    if st.button("Classify Document", key="classify"):
        st.text_area("Predicted Type", classify_document(text_data))

with tabs[4]:
    negotiation_clause = st.text_area("Clause to Optimize", height=150)
    if st.button("Suggest Alternatives", key="negotiation"):
        pretty, alts = negotiation_coach(negotiation_clause.strip() or text_data)
        st.json(json.loads(pretty))
        table = [[a.get("rank", ""), a.get("acceptance_rate_percent", ""), a.get("title", ""),
                  a.get("clause_text", ""), a.get("rationale", "")] for a in alts]
        st.dataframe(pd.DataFrame(table, columns=["Rank", "Acceptance %", "Title", "Clause Text", "Rationale"]))

with tabs[5]:
    risk_clause = st.text_area("Clause for Risk Prediction", height=150)
    if st.button("Predict 1–5 Year Risks", key="risk"):
        pretty, timeline = future_risk_predictor(risk_clause.strip() or text_data)
        st.json(json.loads(pretty))
        table = [[t.get("year", ""), t.get("risk_score_0_100", ""),
                  "; ".join(t.get("key_risks", [])), "; ".join(t.get("mitigation", []))] for t in timeline]
        st.dataframe(pd.DataFrame(table, columns=["Year", "Risk Score (0–100)", "Key Risks", "Mitigation"]))

with tabs[6]:
    fairness_clause = st.text_area("Clause", height=150)
    if st.button("Compute Fairness", key="fairness"):
        pretty, score, rationale = fairness_balance_meter(fairness_clause.strip() or text_data)
        st.json(json.loads(pretty))
        st.slider("Balance Score", min_value=0, max_value=100, value=score)
        st.text_area("Rationale / Notes", rationale, height=100)

with tabs[7]:
    clause_a = st.text_area("Document A", height=150)
    clause_b = st.text_area("Document B", height=150)
    if st.button("Compare Clauses", key="battle"):
        pretty, md = clause_battle_arena(clause_a.strip() or text_data, clause_b.strip() or text_data)
        st.text_area("Battle JSON", pretty, height=300)
        st.markdown(md)

with tabs[8]:
    if st.button("Scan for Sensitive Data", key="sensitive"):
        pretty, hits = sensitive_data_sniffer(text_data)
        st.text_area("Sensitive Data JSON", pretty, height=300)
        st.json(hits)

with tabs[9]:
    if st.button("Identify Litigation Risk Hotspots", key="litigation"):
        pretty, md = litigation_risk_radar(text_data)
        st.text_area("Litigation JSON", pretty, height=300)
        st.markdown(md)