jayasrees committed on
Commit
59a54c2
·
1 Parent(s): ccecc05

Change model to TinyLlama from Hub and remove local-only loading

Browse files
Files changed (2) hide show
  1. analysis/llama_legal_verifier.py +6 -6
  2. app.py +107 -35
analysis/llama_legal_verifier.py CHANGED
@@ -13,17 +13,13 @@ class LlamaLegalVerifier:
13
  """
14
 
15
  def __init__(self, model_path: str):
16
- if not os.path.isdir(model_path):
17
- raise FileNotFoundError(f"Model path not found: {model_path}")
18
-
19
  self.model_path = model_path
20
  self.device = 0 if torch.cuda.is_available() else -1
21
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
22
 
23
- tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_path,
26
- local_files_only=True,
27
  torch_dtype=dtype,
28
  )
29
  if tokenizer.pad_token_id is None:
@@ -41,7 +37,11 @@ class LlamaLegalVerifier:
41
  lowered = text.lower()
42
  if "contradiction" in lowered:
43
  return "Contradiction"
44
- if "entailment" in lowered or "duplicate" in lowered or "same meaning" in lowered:
 
 
 
 
45
  return "Entailment"
46
  if "neutral" in lowered:
47
  return "Neutral"
 
13
  """
14
 
15
  def __init__(self, model_path: str):
 
 
 
16
  self.model_path = model_path
17
  self.device = 0 if torch.cuda.is_available() else -1
18
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
19
 
20
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
21
  model = AutoModelForCausalLM.from_pretrained(
22
  model_path,
 
23
  torch_dtype=dtype,
24
  )
25
  if tokenizer.pad_token_id is None:
 
37
  lowered = text.lower()
38
  if "contradiction" in lowered:
39
  return "Contradiction"
40
+ if (
41
+ "entailment" in lowered
42
+ or "duplicate" in lowered
43
+ or "same meaning" in lowered
44
+ ):
45
  return "Entailment"
46
  if "neutral" in lowered:
47
  return "Neutral"
app.py CHANGED
@@ -3,9 +3,6 @@ import sys
3
  from pathlib import Path
4
 
5
 
6
-
7
-
8
-
9
  import importlib
10
  import json
11
  import base64
@@ -14,9 +11,10 @@ import re
14
  import pandas as pd
15
  import plotly.express as px
16
  import streamlit as st
 
17
  sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
18
 
19
- #sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
20
 
21
  from preprocessing.text_extractor import extract_text_from_file
22
  from preprocessing.clause_extraction import extract_clauses
@@ -25,6 +23,7 @@ from storage.faiss_index import create_faiss_index
25
  from analysis.similarity_search import get_similar
26
 
27
  import analysis.common_analyzer
 
28
  importlib.reload(analysis.common_analyzer)
29
  from analysis.common_analyzer import analyze_pair
30
 
@@ -35,7 +34,7 @@ from auth.user_store import authenticate_user, create_user
35
 
36
 
37
  APP_TITLE = "Legal Semantic Integrity"
38
- DEFAULT_MODEL_PATH = "merged_tinyllama_instruction"
39
  PROJECT_ROOT = Path(__file__).resolve().parents[1]
40
 
41
 
@@ -73,7 +72,9 @@ def _extract_party_name(text: str, role: str) -> str:
73
  if m:
74
  name = " ".join(m.group(1).split())
75
  # Filter generic captures like "hereinafter called"
76
- if name and not re.search(r"hereinafter|called|referred|party|agreement", name, re.IGNORECASE):
 
 
77
  return name[:80]
78
 
79
  if re.search(rf"\b{role_l}\b", t, flags=re.IGNORECASE):
@@ -121,7 +122,9 @@ def _extract_document_parties(text_data):
121
  parties[role] = cleaned
122
  break
123
  # Secondary fallback: explicit role in text without name
124
- if parties[role] == "Not found" and re.search(rf"\b{role.lower()}\b", compact, flags=re.IGNORECASE):
 
 
125
  parties[role] = f"{role} mentioned (name not parsed)"
126
 
127
  return parties
@@ -137,9 +140,15 @@ def _extract_parties(text1: str, text2: str, doc_parties=None):
137
  vendee = _extract_party_name(text2, "vendee")
138
 
139
  if doc_parties:
140
- if vendor in ["Not found", "Vendor mentioned (name not parsed)"] and doc_parties.get("Vendor"):
 
 
 
141
  vendor = doc_parties.get("Vendor")
142
- if vendee in ["Not found", "Vendee mentioned (name not parsed)"] and doc_parties.get("Vendee"):
 
 
 
143
  vendee = doc_parties.get("Vendee")
144
 
145
  return vendor, vendee
@@ -299,7 +308,9 @@ def login_page():
299
  )
300
 
301
  with col_auth:
302
- st.markdown('<div class="step">Step 1 of 3: Login</div>', unsafe_allow_html=True)
 
 
303
  tab_login, tab_signup = st.tabs(["Sign In", "Create Account"])
304
 
305
  with tab_login:
@@ -338,7 +349,9 @@ def login_page():
338
  st.caption("Local accounts are saved in data/users.db")
339
 
340
 
341
- def run_analysis(uploaded_file, sensitivity: float, backend: str, llama_model_path: str):
 
 
342
  file_ext = uploaded_file.name.split(".")[-1].lower()
343
 
344
  with st.spinner("Extracting text..."):
@@ -412,9 +425,13 @@ def run_analysis(uploaded_file, sensitivity: float, backend: str, llama_model_pa
412
  result["Vendee"] = vendee_name
413
 
414
  if backend == "llama":
415
- _, llm_conf, llm_label, llm_reason = verifier.predict(result["Clause 1"], result["Clause 2"])
 
 
416
  else:
417
- _, llm_conf, llm_label = verifier.predict(result["Clause 1"], result["Clause 2"])
 
 
418
  llm_reason = f"NLI label: {llm_label}"
419
 
420
  if llm_label == "Neutral":
@@ -483,7 +500,9 @@ def upload_page():
483
  """,
484
  unsafe_allow_html=True,
485
  )
486
- st.markdown('<div class="step">Step 2 of 3: Upload Document</div>', unsafe_allow_html=True)
 
 
487
 
488
  with st.sidebar:
489
  st.header("Scan Settings")
@@ -514,7 +533,7 @@ def upload_page():
514
  f"""
515
  <div class="mini-card">
516
  <div class="mini-label">Active Mode</div>
517
- <div class="mini-value">{scan_mode.split('(')[0].strip()}</div>
518
  <div class="mono">Sensitivity: {sensitivity} | Backend: {model_backend}</div>
519
  </div>
520
  """,
@@ -578,7 +597,9 @@ def dashboard_page():
578
  """,
579
  unsafe_allow_html=True,
580
  )
581
- st.markdown('<div class="step">Step 3 of 3: Dashboard</div>', unsafe_allow_html=True)
 
 
582
 
583
  results = st.session_state.results
584
  line_issues = st.session_state.line_issues
@@ -653,10 +674,16 @@ def dashboard_page():
653
  st.caption(f"Single issue page: {page_min}")
654
  page_sel = (page_min, page_max)
655
  else:
656
- page_sel = st.slider("Page Range (analytics)", page_min, page_max, (page_min, page_max))
 
 
657
  with filter_col3:
658
- vendors = ["All"] + sorted(line_df["Vendor"].dropna().astype(str).unique().tolist())
659
- vendees = ["All"] + sorted(line_df["Vendee"].dropna().astype(str).unique().tolist())
 
 
 
 
660
  vendor_sel = st.selectbox("Vendor", vendors, index=0)
661
  vendee_sel = st.selectbox("Vendee", vendees, index=0)
662
 
@@ -664,7 +691,9 @@ def dashboard_page():
664
  if issue_sel:
665
  filtered = filtered[filtered["Issue Type"].isin(issue_sel)]
666
  filtered = filtered[filtered["Confidence"] >= conf_min]
667
- filtered = filtered[(filtered["Page"] >= page_sel[0]) & (filtered["Page"] <= page_sel[1])]
 
 
668
  if vendor_sel != "All":
669
  filtered = filtered[filtered["Vendor"] == vendor_sel]
670
  if vendee_sel != "All":
@@ -672,9 +701,13 @@ def dashboard_page():
672
 
673
  total_issues = len(filtered)
674
  conflict_rate = (len(issues_df) / len(df) * 100.0) if len(df) else 0.0
675
- top_issue = filtered["Issue Type"].mode().iloc[0] if not filtered.empty else "N/A"
 
 
676
  highest_risk_page = (
677
- int(filtered.groupby("Page")["Confidence"].mean().idxmax()) if not filtered.empty else "N/A"
 
 
678
  )
679
  k1, k2, k3, k4 = st.columns(4)
680
  k1.metric("Filtered Issues", total_issues)
@@ -697,10 +730,23 @@ def dashboard_page():
697
  pie_fig.update_layout(margin=dict(l=10, r=10, t=50, b=10))
698
  st.plotly_chart(pie_fig, use_container_width=True)
699
 
700
- top_lines = filtered.sort_values(by=["Confidence"], ascending=False).head(10)
 
 
701
  st.markdown("**Top 10 High-Risk Lines**")
702
  st.dataframe(
703
- top_lines[["Issue Type", "Confidence", "Page", "Line", "Vendor", "Vendee", "Snippet", "Reason"]],
 
 
 
 
 
 
 
 
 
 
 
704
  use_container_width=True,
705
  )
706
  else:
@@ -757,19 +803,35 @@ def dashboard_page():
757
  st.caption(f"Only one page with issues: Page {page_min}")
758
  page_range = (page_min, page_max)
759
  else:
760
- page_range = st.slider("Page range", page_min, page_max, (page_min, page_max))
 
 
761
 
762
  if selected:
763
  line_df = line_df[line_df["Issue Type"].isin(selected)]
764
- line_df = line_df[(line_df["Page"] >= page_range[0]) & (line_df["Page"] <= page_range[1])]
 
 
765
 
766
  st.dataframe(line_df, use_container_width=True)
767
 
768
  st.markdown("**Issue Occurrence By Line With Parties**")
769
  by_line = line_df.copy()
770
- by_line = by_line.sort_values(by=["Page", "Line", "Confidence"], ascending=[True, True, False])
 
 
771
  st.dataframe(
772
- by_line[["Issue Type", "Page", "Line", "Vendor", "Vendee", "Confidence", "Reason"]],
 
 
 
 
 
 
 
 
 
 
773
  use_container_width=True,
774
  )
775
 
@@ -778,10 +840,14 @@ def dashboard_page():
778
  line_df = line_df.reset_index(drop=True)
779
  line_df.insert(0, "Item", range(1, len(line_df) + 1))
780
  line_df["Jump"] = line_df.apply(
781
- lambda r: f"#{r['Item']} | Pg {int(r['Page'])}, Ln {int(r['Line'])} | {r['Issue Type']}",
 
 
782
  axis=1,
783
  )
784
- selected_jump = st.selectbox("Select issue line", line_df["Jump"].tolist())
 
 
785
  chosen = line_df[line_df["Jump"] == selected_jump].iloc[0]
786
 
787
  c1, c2 = st.columns([1.1, 1], gap="large")
@@ -790,8 +856,8 @@ def dashboard_page():
790
  f"""
791
  <div class="mini-card">
792
  <div class="mini-label">Selected Line</div>
793
- <div class="mini-value">Pg {int(chosen['Page'])} · Ln {int(chosen['Line'])}</div>
794
- <div class="mono">{chosen['Issue Type']} | Confidence: {float(chosen['Confidence']):.2f}</div>
795
  </div>
796
  """,
797
  unsafe_allow_html=True,
@@ -806,7 +872,9 @@ def dashboard_page():
806
  if is_pdf and st.session_state.uploaded_bytes:
807
  st.caption("PDF Preview (jumped to selected page)")
808
  page_number = int(chosen["Page"])
809
- pdf_b64 = base64.b64encode(st.session_state.uploaded_bytes).decode("utf-8")
 
 
810
  pdf_html = f"""
811
  <iframe
812
  src="data:application/pdf;base64,{pdf_b64}#page={page_number}&zoom=110"
@@ -817,7 +885,9 @@ def dashboard_page():
817
  """
818
  st.markdown(pdf_html, unsafe_allow_html=True)
819
  else:
820
- st.info("Inline PDF preview is available for PDF uploads. Current file is not PDF.")
 
 
821
  else:
822
  st.info("No line-level issues to display.")
823
 
@@ -830,7 +900,9 @@ def dashboard_page():
830
  file_name="semantic_integrity_report.json",
831
  mime="application/json",
832
  )
833
- pdf_bytes = generate_pdf_report([r for r in results if r["Label"] != "NO_CONFLICT"])
 
 
834
  st.download_button(
835
  label="Download PDF Report",
836
  data=pdf_bytes,
 
3
  from pathlib import Path
4
 
5
 
 
 
 
6
  import importlib
7
  import json
8
  import base64
 
11
  import pandas as pd
12
  import plotly.express as px
13
  import streamlit as st
14
+
15
  sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
16
 
17
+ # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
18
 
19
  from preprocessing.text_extractor import extract_text_from_file
20
  from preprocessing.clause_extraction import extract_clauses
 
23
  from analysis.similarity_search import get_similar
24
 
25
  import analysis.common_analyzer
26
+
27
  importlib.reload(analysis.common_analyzer)
28
  from analysis.common_analyzer import analyze_pair
29
 
 
34
 
35
 
36
  APP_TITLE = "Legal Semantic Integrity"
37
+ DEFAULT_MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
38
  PROJECT_ROOT = Path(__file__).resolve().parents[1]
39
 
40
 
 
72
  if m:
73
  name = " ".join(m.group(1).split())
74
  # Filter generic captures like "hereinafter called"
75
+ if name and not re.search(
76
+ r"hereinafter|called|referred|party|agreement", name, re.IGNORECASE
77
+ ):
78
  return name[:80]
79
 
80
  if re.search(rf"\b{role_l}\b", t, flags=re.IGNORECASE):
 
122
  parties[role] = cleaned
123
  break
124
  # Secondary fallback: explicit role in text without name
125
+ if parties[role] == "Not found" and re.search(
126
+ rf"\b{role.lower()}\b", compact, flags=re.IGNORECASE
127
+ ):
128
  parties[role] = f"{role} mentioned (name not parsed)"
129
 
130
  return parties
 
140
  vendee = _extract_party_name(text2, "vendee")
141
 
142
  if doc_parties:
143
+ if vendor in [
144
+ "Not found",
145
+ "Vendor mentioned (name not parsed)",
146
+ ] and doc_parties.get("Vendor"):
147
  vendor = doc_parties.get("Vendor")
148
+ if vendee in [
149
+ "Not found",
150
+ "Vendee mentioned (name not parsed)",
151
+ ] and doc_parties.get("Vendee"):
152
  vendee = doc_parties.get("Vendee")
153
 
154
  return vendor, vendee
 
308
  )
309
 
310
  with col_auth:
311
+ st.markdown(
312
+ '<div class="step">Step 1 of 3: Login</div>', unsafe_allow_html=True
313
+ )
314
  tab_login, tab_signup = st.tabs(["Sign In", "Create Account"])
315
 
316
  with tab_login:
 
349
  st.caption("Local accounts are saved in data/users.db")
350
 
351
 
352
+ def run_analysis(
353
+ uploaded_file, sensitivity: float, backend: str, llama_model_path: str
354
+ ):
355
  file_ext = uploaded_file.name.split(".")[-1].lower()
356
 
357
  with st.spinner("Extracting text..."):
 
425
  result["Vendee"] = vendee_name
426
 
427
  if backend == "llama":
428
+ _, llm_conf, llm_label, llm_reason = verifier.predict(
429
+ result["Clause 1"], result["Clause 2"]
430
+ )
431
  else:
432
+ _, llm_conf, llm_label = verifier.predict(
433
+ result["Clause 1"], result["Clause 2"]
434
+ )
435
  llm_reason = f"NLI label: {llm_label}"
436
 
437
  if llm_label == "Neutral":
 
500
  """,
501
  unsafe_allow_html=True,
502
  )
503
+ st.markdown(
504
+ '<div class="step">Step 2 of 3: Upload Document</div>', unsafe_allow_html=True
505
+ )
506
 
507
  with st.sidebar:
508
  st.header("Scan Settings")
 
533
  f"""
534
  <div class="mini-card">
535
  <div class="mini-label">Active Mode</div>
536
+ <div class="mini-value">{scan_mode.split("(")[0].strip()}</div>
537
  <div class="mono">Sensitivity: {sensitivity} | Backend: {model_backend}</div>
538
  </div>
539
  """,
 
597
  """,
598
  unsafe_allow_html=True,
599
  )
600
+ st.markdown(
601
+ '<div class="step">Step 3 of 3: Dashboard</div>', unsafe_allow_html=True
602
+ )
603
 
604
  results = st.session_state.results
605
  line_issues = st.session_state.line_issues
 
674
  st.caption(f"Single issue page: {page_min}")
675
  page_sel = (page_min, page_max)
676
  else:
677
+ page_sel = st.slider(
678
+ "Page Range (analytics)", page_min, page_max, (page_min, page_max)
679
+ )
680
  with filter_col3:
681
+ vendors = ["All"] + sorted(
682
+ line_df["Vendor"].dropna().astype(str).unique().tolist()
683
+ )
684
+ vendees = ["All"] + sorted(
685
+ line_df["Vendee"].dropna().astype(str).unique().tolist()
686
+ )
687
  vendor_sel = st.selectbox("Vendor", vendors, index=0)
688
  vendee_sel = st.selectbox("Vendee", vendees, index=0)
689
 
 
691
  if issue_sel:
692
  filtered = filtered[filtered["Issue Type"].isin(issue_sel)]
693
  filtered = filtered[filtered["Confidence"] >= conf_min]
694
+ filtered = filtered[
695
+ (filtered["Page"] >= page_sel[0]) & (filtered["Page"] <= page_sel[1])
696
+ ]
697
  if vendor_sel != "All":
698
  filtered = filtered[filtered["Vendor"] == vendor_sel]
699
  if vendee_sel != "All":
 
701
 
702
  total_issues = len(filtered)
703
  conflict_rate = (len(issues_df) / len(df) * 100.0) if len(df) else 0.0
704
+ top_issue = (
705
+ filtered["Issue Type"].mode().iloc[0] if not filtered.empty else "N/A"
706
+ )
707
  highest_risk_page = (
708
+ int(filtered.groupby("Page")["Confidence"].mean().idxmax())
709
+ if not filtered.empty
710
+ else "N/A"
711
  )
712
  k1, k2, k3, k4 = st.columns(4)
713
  k1.metric("Filtered Issues", total_issues)
 
730
  pie_fig.update_layout(margin=dict(l=10, r=10, t=50, b=10))
731
  st.plotly_chart(pie_fig, use_container_width=True)
732
 
733
+ top_lines = filtered.sort_values(by=["Confidence"], ascending=False).head(
734
+ 10
735
+ )
736
  st.markdown("**Top 10 High-Risk Lines**")
737
  st.dataframe(
738
+ top_lines[
739
+ [
740
+ "Issue Type",
741
+ "Confidence",
742
+ "Page",
743
+ "Line",
744
+ "Vendor",
745
+ "Vendee",
746
+ "Snippet",
747
+ "Reason",
748
+ ]
749
+ ],
750
  use_container_width=True,
751
  )
752
  else:
 
803
  st.caption(f"Only one page with issues: Page {page_min}")
804
  page_range = (page_min, page_max)
805
  else:
806
+ page_range = st.slider(
807
+ "Page range", page_min, page_max, (page_min, page_max)
808
+ )
809
 
810
  if selected:
811
  line_df = line_df[line_df["Issue Type"].isin(selected)]
812
+ line_df = line_df[
813
+ (line_df["Page"] >= page_range[0]) & (line_df["Page"] <= page_range[1])
814
+ ]
815
 
816
  st.dataframe(line_df, use_container_width=True)
817
 
818
  st.markdown("**Issue Occurrence By Line With Parties**")
819
  by_line = line_df.copy()
820
+ by_line = by_line.sort_values(
821
+ by=["Page", "Line", "Confidence"], ascending=[True, True, False]
822
+ )
823
  st.dataframe(
824
+ by_line[
825
+ [
826
+ "Issue Type",
827
+ "Page",
828
+ "Line",
829
+ "Vendor",
830
+ "Vendee",
831
+ "Confidence",
832
+ "Reason",
833
+ ]
834
+ ],
835
  use_container_width=True,
836
  )
837
 
 
840
  line_df = line_df.reset_index(drop=True)
841
  line_df.insert(0, "Item", range(1, len(line_df) + 1))
842
  line_df["Jump"] = line_df.apply(
843
+ lambda r: (
844
+ f"#{r['Item']} | Pg {int(r['Page'])}, Ln {int(r['Line'])} | {r['Issue Type']}"
845
+ ),
846
  axis=1,
847
  )
848
+ selected_jump = st.selectbox(
849
+ "Select issue line", line_df["Jump"].tolist()
850
+ )
851
  chosen = line_df[line_df["Jump"] == selected_jump].iloc[0]
852
 
853
  c1, c2 = st.columns([1.1, 1], gap="large")
 
856
  f"""
857
  <div class="mini-card">
858
  <div class="mini-label">Selected Line</div>
859
+ <div class="mini-value">Pg {int(chosen["Page"])} · Ln {int(chosen["Line"])}</div>
860
+ <div class="mono">{chosen["Issue Type"]} | Confidence: {float(chosen["Confidence"]):.2f}</div>
861
  </div>
862
  """,
863
  unsafe_allow_html=True,
 
872
  if is_pdf and st.session_state.uploaded_bytes:
873
  st.caption("PDF Preview (jumped to selected page)")
874
  page_number = int(chosen["Page"])
875
+ pdf_b64 = base64.b64encode(
876
+ st.session_state.uploaded_bytes
877
+ ).decode("utf-8")
878
  pdf_html = f"""
879
  <iframe
880
  src="data:application/pdf;base64,{pdf_b64}#page={page_number}&zoom=110"
 
885
  """
886
  st.markdown(pdf_html, unsafe_allow_html=True)
887
  else:
888
+ st.info(
889
+ "Inline PDF preview is available for PDF uploads. Current file is not PDF."
890
+ )
891
  else:
892
  st.info("No line-level issues to display.")
893
 
 
900
  file_name="semantic_integrity_report.json",
901
  mime="application/json",
902
  )
903
+ pdf_bytes = generate_pdf_report(
904
+ [r for r in results if r["Label"] != "NO_CONFLICT"]
905
+ )
906
  st.download_button(
907
  label="Download PDF Report",
908
  data=pdf_bytes,