wuhp committed on
Commit
d5cce46
·
verified ·
1 Parent(s): 7a250cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +307 -111
app.py CHANGED
@@ -17,11 +17,15 @@ import gradio as gr
17
  from tqdm import tqdm
18
 
19
  # sklearn (CPU-friendly)
20
- from sklearn.feature_extraction.text import TfidfVectorizer
21
  from sklearn.cluster import MiniBatchKMeans
22
  from sklearn.neighbors import NearestNeighbors
23
  from sklearn.decomposition import TruncatedSVD
24
  from sklearn.preprocessing import Normalizer
 
 
 
 
25
 
26
  # Optional fast ANN (CPU)
27
  try:
@@ -71,6 +75,20 @@ SUSPECT_PHRASES = [
71
  "contract splitting", "grease payment", "unreported", "unrecorded",
72
  ]
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # =================== Label cleanup helpers ===================
75
  EN_STOP = {
76
  "the","of","and","to","in","is","for","on","at","with","from","by","or","as",
@@ -78,7 +96,7 @@ EN_STOP = {
78
  "re","fwd","fw","hi","hello","thanks","thank","regards","best","please","dear","mr","mrs",
79
  "message","original","forwarded","attached","attachment","confidential","notice","disclaimer",
80
  "herein","thereof","hereby","therein","regarding","subject","url","via","kind","regard",
81
- "ny" # short common noise in your set
82
  }
83
  HE_STOP = {
84
  "של","על","זה","גם","אם","לא","את","אתה","אני","הוא","היא","הם","הן","כי","מה",
@@ -109,7 +127,6 @@ def _is_junk_term(t: str) -> bool:
109
  return False
110
 
111
  def _sanitize_top_terms(names: np.ndarray, idxs: np.ndarray, mean_vec: np.ndarray, want:int) -> list:
112
- # Keep order by descending weight in idxs
113
  ordered = idxs[np.argsort(-mean_vec[idxs])]
114
  cleaned = []
115
  for i in ordered:
@@ -119,7 +136,6 @@ def _sanitize_top_terms(names: np.ndarray, idxs: np.ndarray, mean_vec: np.ndarra
119
  cleaned.append(term)
120
  if len(cleaned) >= want:
121
  break
122
- # If we filtered too hard, allow some not-too-bad tokens (but still avoid email-like)
123
  if len(cleaned) < max(2, want//2):
124
  for i in ordered:
125
  term = names[i]
@@ -141,19 +157,14 @@ def html_to_text(html: str) -> str:
141
  return soup.get_text(separator="\n")
142
 
143
  def strip_quotes_and_sigs(text: str) -> str:
144
- """Drop quoted lines, signatures, device footers, forwarded chains."""
145
  if not text:
146
  return ""
147
- # remove > quoted lines
148
  text = QUOTE_LINE_RE.sub("", text)
149
- # cut everything after signature separator
150
  parts = SIG_RE.split(text)
151
  if parts:
152
  text = parts[0]
153
- # remove device footers
154
  text = SENT_FROM_RE.sub("", text)
155
  text = HEBREW_SENT_FROM_RE.sub("", text)
156
- # trim forwarded/quoted chains
157
  cut = None
158
  for pat in (FWD_BEGIN_RE, FWD_MSG_RE, ON_WROTE_RE):
159
  m = pat.search(text)
@@ -165,7 +176,6 @@ def strip_quotes_and_sigs(text: str) -> str:
165
  return text.strip()
166
 
167
  def parse_name_email(s: str) -> Tuple[str, str]:
168
- """Split 'Name <email>' into (name, email)."""
169
  if not s:
170
  return "", ""
171
  m = re.match(r'(?:"?([^"]*)"?\s)?<?([^<>]+@[^<>]+)>?', s)
@@ -174,16 +184,11 @@ def parse_name_email(s: str) -> Tuple[str, str]:
174
  return "", s.strip()
175
 
176
  def parse_email_headers(text: str) -> Tuple[Dict[str, str], str]:
177
- """
178
- Extract inline headers (From, To, CC, Date, Subject) from the text blob.
179
- Returns (headers_dict, remaining_body_text).
180
- """
181
  headers: Dict[str, str] = {}
182
  lines = (text or "").splitlines()
183
  header_pat = re.compile(r'^(From|To|Cc|CC|Bcc|Date|Subject):')
184
  i = 0
185
  saw_header = False
186
-
187
  while i < len(lines):
188
  line = lines[i].rstrip("\r")
189
  stripped = line.strip()
@@ -221,13 +226,11 @@ def parse_email_headers(text: str) -> Tuple[Dict[str, str], str]:
221
  break
222
  else:
223
  break
224
-
225
  body_text = "\n".join(lines[i:]) if i < len(lines) else ""
226
  return headers, body_text
227
 
228
  # =================== Normalization & Utilities ===================
229
  def normalize_email_record(raw: Dict[str, Any], use_langdetect: bool) -> Dict[str, Any]:
230
- """Normalize a single raw record into a structured row."""
231
  if str(raw.get("type", "")).lower() == "meta":
232
  return {}
233
 
@@ -248,7 +251,6 @@ def normalize_email_record(raw: Dict[str, Any], use_langdetect: bool) -> Dict[st
248
  sender = headers.get("From", "") or raw.get("from") or raw.get("From") or ""
249
  date_val = headers.get("Date", "") or date_val
250
 
251
- # Clean body
252
  body_clean = strip_quotes_and_sigs(ftfy.fix_text(body_only or ""))
253
  body_clean = URL_RE.sub(" URL ", body_clean)
254
  body_clean = re.sub(r"\s+", " ", body_clean).strip()
@@ -307,7 +309,6 @@ def normalize_email_record(raw: Dict[str, Any], use_langdetect: bool) -> Dict[st
307
  }
308
 
309
  def has_suspect_tag(text: str) -> List[str]:
310
- """Return list of corruption/suspicion tags present in text."""
311
  tags = []
312
  if not text:
313
  return tags
@@ -330,7 +331,6 @@ def compute_sentiment_column(df: pd.DataFrame) -> pd.DataFrame:
330
  return df
331
  analyzer = SentimentIntensityAnalyzer()
332
  scores = df["body_text"].fillna("").map(lambda t: analyzer.polarity_scores(t)["compound"])
333
- # VADER thresholds: [-1,-0.05), (-0.05,0.05), (0.05,1]
334
  bins = [-1.01, -0.05, 0.05, 1.01]
335
  labels = ["negative", "neutral", "positive"]
336
  df["sentiment_score"] = scores
@@ -338,7 +338,6 @@ def compute_sentiment_column(df: pd.DataFrame) -> pd.DataFrame:
338
  return df
339
 
340
  def build_highlighted_html(row: pd.Series, query_terms: Optional[List[str]] = None, cluster_label: Optional[str] = None) -> str:
341
- """Email reader HTML with highlighted query terms and visible tags."""
342
  subject = (row.get("subject") or "").strip()
343
  body = (row.get("body_text") or "").strip()
344
  from_email = row.get("from_email") or ""
@@ -363,7 +362,6 @@ def build_highlighted_html(row: pd.Series, query_terms: Optional[List[str]] = No
363
  subject_h = hi(subject)
364
  body_h = hi(body)
365
 
366
- # Basic RTL detection for Hebrew/Arabic chars → add dir="rtl"
367
  rtl = bool(re.search(r"[\u0590-\u08FF]", body_h))
368
  dir_attr = ' dir="rtl"' if rtl else ""
369
  body_html = body_h.replace("\n", "<br/>")
@@ -394,31 +392,185 @@ def build_highlighted_html(row: pd.Series, query_terms: Optional[List[str]] = No
394
  )
395
  return html
396
 
397
- def top_terms_per_cluster(X, labels, vectorizer, topn=6):
398
- names = vectorizer.get_feature_names_out()
399
- out = {}
400
- uniq = np.unique(labels)
401
- for c in uniq:
402
- mask = (labels == c)
403
- if mask.sum() == 0:
404
- out[int(c)] = f"cluster_{c}"
405
- continue
406
- # mean TF-IDF per feature inside cluster
407
- mean_vec = X[mask].mean(axis=0).A1
408
- if mean_vec.size == 0:
409
- out[int(c)] = f"cluster_{c}"
410
- continue
411
- # oversample candidates, then filter junk
412
- take = max(topn * 4, topn)
413
- idx = np.argpartition(mean_vec, -take)[-take:]
414
- terms = _sanitize_top_terms(names, idx, mean_vec, want=topn)
415
- out[int(c)] = ", ".join(terms) if terms else f"cluster_{c}"
416
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
  def auto_k_rule(n_docs: int) -> int:
419
  # Sublinear scaling; keeps clusters between ~120 and 600 for big corpora
420
  return int(max(120, min(600, math.sqrt(max(n_docs, 1) / 50.0) * 110)))
421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  # =================== Gradio UI ===================
423
  CSS = """
424
  :root { --pill:#eef2ff; --pill-text:#1f2937; --tag:#e5e7eb; --tag-text:#111827; }
@@ -439,8 +591,8 @@ hr.sep { border:none; border-top:1px solid #e5e7eb; margin:10px 0; }
439
 
440
  with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="soft") as demo:
441
  gr.Markdown("""
442
- # Email Investigator — TF-IDF + LSA + MiniBatchKMeans
443
- **Goal:** quickly surface potentially corruption-related emails via topic clusters, tags, and sentiment.
444
  """)
445
 
446
  with gr.Row():
@@ -448,7 +600,7 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
448
 
449
  with gr.Accordion("Vectorization & Clustering", open=True):
450
  with gr.Row():
451
- max_features = gr.Number(label="TF-IDF max_features", value=120_000, precision=0)
452
  min_df = gr.Number(label="min_df (doc freq ≥)", value=2, precision=0)
453
  max_df = gr.Slider(label="max_df (fraction ≤)", minimum=0.1, maximum=0.95, value=0.7, step=0.05)
454
  use_bigrams = gr.Checkbox(label="Use bigrams (1–2)", value=True)
@@ -456,11 +608,11 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
456
  with gr.Row():
457
  use_lsa = gr.Checkbox(label="Use LSA (TruncatedSVD) before KMeans", value=True)
458
  lsa_dim = gr.Number(label="LSA components", value=150, precision=0)
459
- auto_k = gr.Checkbox(label="Auto choose k", value=True)
460
  k_clusters = gr.Number(label="k (MiniBatchKMeans)", value=350, precision=0)
461
  mb_batch = gr.Number(label="KMeans batch_size", value=4096, precision=0)
462
  with gr.Row():
463
- use_faiss = gr.Checkbox(label="Use Faiss ANN for search (if available)", value=True)
464
 
465
  with gr.Accordion("Filters", open=True):
466
  with gr.Row():
@@ -495,7 +647,7 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
495
 
496
  # State
497
  state_df = gr.State() # full dataframe
498
- state_vec = gr.State() # TfidfVectorizer
499
  state_X_reduced = gr.State() # np.ndarray (LSA normalized) or None
500
  state_index = gr.State() # Faiss index or sklearn NN
501
  state_term_names = gr.State() # dict cluster_id -> label
@@ -504,6 +656,7 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
504
  state_use_faiss = gr.State()
505
  state_svd = gr.State()
506
  state_norm = gr.State()
 
507
 
508
  # -------- IO helpers --------
509
  def _load_json_records(local_path: str) -> List[Dict[str, Any]]:
@@ -545,7 +698,6 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
545
  ) -> pd.DataFrame:
546
  out = df
547
  if cluster and cluster != "(any)":
548
- # cluster values like "12 — payment, contract (534)"
549
  m = re.match(r"^(\d+)\s+—", cluster)
550
  if m:
551
  cid = int(m.group(1))
@@ -555,9 +707,7 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
555
  if sentiment and sentiment != "(any)" and "sentiment" in out.columns:
556
  out = out[out["sentiment"].astype(str) == sentiment]
557
  if tag_value and tag_value != "(any)":
558
- # tags is a list; check membership robustly
559
  out = out[out["tags"].apply(lambda ts: isinstance(ts, list) and (tag_value in ts))]
560
- # date bounds
561
  if start:
562
  try:
563
  dt = pd.to_datetime(start, utc=True, errors="coerce")
@@ -577,14 +727,14 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
577
  use_lsa, lsa_dim, auto_k, k_clusters, mb_batch, use_faiss):
578
  if inbox_file is None:
579
  return ("**Please upload a file.**",
580
- None, None, None, None, None, None, None, None, None, None, None, None, None, None)
581
 
582
  use_lang = not bool(skip_lang)
583
 
584
  recs = _load_json_records(inbox_file.name)
585
  if not recs:
586
  return ("**No valid records found.**",
587
- None, None, None, None, None, None, None, None, None, None, None, None, None, None)
588
 
589
  # Normalize
590
  normd = []
@@ -595,7 +745,7 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
595
  df = pd.DataFrame(normd)
596
  if df.empty:
597
  return ("**No usable email records after normalization.**",
598
- None, None, None, None, None, None, None, None, None, None, None, None, None, None)
599
 
600
  # Deduplicate conservatively
601
  df = df.drop_duplicates(subset=["message_id", "subject", "text_hash"]).reset_index(drop=True)
@@ -604,12 +754,12 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
604
  df["tags"] = df["body_text"].fillna("").map(has_suspect_tag)
605
  df = compute_sentiment_column(df)
606
 
607
- # Texts for modeling
608
- texts = (df["subject"].fillna("") + "\n\n" + df["body_text"].fillna("")).tolist()
609
 
610
- # TF-IDF (sparse CSR float32)
611
  ngram_range = (1, 2) if use_bigrams else (1, 1)
612
- vec = TfidfVectorizer(
613
  analyzer="word",
614
  ngram_range=ngram_range,
615
  max_features=int(max_features) if max_features else None,
@@ -617,10 +767,22 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
617
  max_df=float(max_df) if max_df else 0.7,
618
  token_pattern=TOKEN_PATTERN,
619
  lowercase=True,
620
- sublinear_tf=True,
621
  dtype=np.float32,
622
  )
623
- X = vec.fit_transform(texts) # CSR float32
 
 
 
 
 
 
 
 
 
 
 
 
 
624
 
625
  # LSA (TruncatedSVD + Normalizer) for stability/quality
626
  use_lsa = bool(use_lsa)
@@ -629,43 +791,70 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
629
  norm_obj = None
630
  if use_lsa:
631
  svd_obj = TruncatedSVD(n_components=int(lsa_dim or 150), random_state=0)
632
- X_reduced_tmp = svd_obj.fit_transform(X) # dense (n_docs x lsa_dim)
633
  norm_obj = Normalizer(copy=False)
634
  X_reduced = norm_obj.fit_transform(X_reduced_tmp).astype(np.float32)
635
  del X_reduced_tmp
636
  gc.collect()
637
 
638
- # KMeans clustering
639
  if bool(auto_k):
640
- k = auto_k_rule(X.shape[0])
 
 
 
 
641
  else:
642
  k = max(10, int(k_clusters or 350))
643
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
644
  kmeans = MiniBatchKMeans(
645
  n_clusters=k,
646
  batch_size=int(mb_batch or 4096),
647
  random_state=0,
648
- n_init="auto",
 
649
  )
650
- labels = kmeans.fit_predict(X_reduced if use_lsa else X)
 
 
 
 
 
651
  df["cluster_id"] = labels
652
 
653
- # Name clusters by top terms (use original TF-IDF for interpretability)
654
- term_names = top_terms_per_cluster(X, labels, vec, topn=6)
655
- df["cluster_name"] = [term_names[int(c)] for c in labels]
 
 
 
656
 
657
  # Build search index
658
- use_faiss = bool(use_faiss) and FAISS_OK
659
  index_obj = None
660
- if use_faiss and use_lsa:
661
- # cosine ≈ inner product on normalized vectors
662
  d = X_reduced.shape[1]
663
- index_obj = faiss.IndexFlatIP(d)
664
  index_obj.add(X_reduced)
665
  else:
666
- # fallback to brute-force cosine on TF-IDF or reduced vectors
667
  nn = NearestNeighbors(metric="cosine", algorithm="brute")
668
- nn.fit(X_reduced if use_lsa else X)
669
  index_obj = nn
670
 
671
  # Summaries
@@ -675,7 +864,6 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
675
  .sort_values("count", ascending=False)
676
  .head(500)
677
  )
678
- # For dropdown labels: "id — label (count)"
679
  cluster_counts["label"] = cluster_counts.apply(
680
  lambda r: f'{int(r["cluster_id"])} — {r["cluster_name"]} ({int(r["count"])})', axis=1
681
  )
@@ -689,28 +877,28 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
689
  )
690
  domain_choices = ["(any)"] + domain_counts["from_domain"].tolist()
691
 
692
- # Results preview default (latest 500 by date if available)
693
- if "date" in df.columns and df["date"].notna().any():
694
- show_df = df.copy()
695
- # coerce to datetime for sort
696
  show_df["_dt"] = pd.to_datetime(show_df["date"], utc=True, errors="coerce")
697
- show_df = show_df.sort_values("_dt", ascending=False).drop(columns=["_dt"])
698
  else:
699
- show_df = df.copy()
 
700
 
701
- cols_out = ["date", "from_email", "from_domain", "subject", "cluster_name", "tags", "sentiment"]
702
  out_table = show_df[cols_out].head(500)
703
 
 
 
704
  status_md = (
705
  f"**Processed {len(df):,} emails** \n"
706
- f"TF-IDF shape = {X.shape[0]:,} × {X.shape[1]:,} | "
707
  f"{'LSA: ' + str(X_reduced.shape[1]) + ' dims | ' if use_lsa else ''}"
708
- f"k = {k} | Search = {'Faiss (IP on LSA)' if (use_faiss and use_lsa and FAISS_OK) else 'cosine brute-force'}"
709
  )
710
 
711
  gc.collect()
712
 
713
- # Use gr.update to set dropdown choices + default values safely
714
  cluster_update = gr.update(choices=cluster_choices, value="(any)")
715
  domain_update = gr.update(choices=domain_choices, value="(any)")
716
 
@@ -718,10 +906,11 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
718
  status_md,
719
  cluster_counts, domain_counts,
720
  out_table,
721
- df, vec, (X_reduced if use_lsa else None), index_obj, term_names,
722
- use_lsa, (use_faiss and use_lsa and FAISS_OK),
723
  cluster_update, domain_update,
724
- svd_obj, norm_obj
 
725
  )
726
 
727
  (run_btn.click)(
@@ -734,7 +923,8 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
734
  state_df, state_vec, state_X_reduced, state_index, state_term_names,
735
  state_use_lsa, state_use_faiss,
736
  cluster_drop, domain_drop,
737
- state_svd, state_norm]
 
738
  )
739
 
740
  # -------- Filtering & Search --------
@@ -742,14 +932,13 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
742
  if df is None or len(df) == 0:
743
  return pd.DataFrame()
744
  filt = _apply_filters(df, cluster_choice, domain_choice, sentiment_choice, tag_choice, start, end)
745
- cols_out = ["date", "from_email", "from_domain", "subject", "cluster_name", "tags", "sentiment"]
746
- # default: sort by date desc if possible
747
  if "date" in filt.columns and filt["date"].notna().any():
748
  tmp = filt.copy()
749
  tmp["_dt"] = pd.to_datetime(tmp["date"], utc=True, errors="coerce")
750
- tmp = tmp.sort_values("_dt", ascending=False).drop(columns=["_dt"])
751
  return tmp[cols_out].head(500)
752
- return filt[cols_out].head(500)
753
 
754
  for ctrl in [cluster_drop, domain_drop, sentiment_drop, tag_drop, date_start, date_end]:
755
  ctrl.change(
@@ -758,7 +947,6 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
758
  outputs=[results_df]
759
  )
760
 
761
- # Safer reset: set dropdowns to None (always valid), others to defaults
762
  reset_btn.click(
763
  lambda: [None, None, "(any)", "(any)", "", ""],
764
  inputs=[],
@@ -772,38 +960,43 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
772
  def _tokenize_query(q: str) -> List[str]:
773
  if not q:
774
  return []
775
- # split on spaces, keep simple tokens; dedupe while preserving order
776
  parts = [p.strip() for p in re.split(r"\s+", q) if p.strip()]
777
  seen, out = set(), []
778
  for p in parts:
779
  if p.lower() not in seen:
780
  out.append(p)
781
  seen.add(p.lower())
782
- return out[:8] # limit highlights for performance
783
 
784
  def _project_query_to_lsa(q_vec, svd_obj, norm_obj) -> Optional[np.ndarray]:
785
  try:
786
- q_red = svd_obj.transform(q_vec) # (1, lsa_dim)
787
- q_red = norm_obj.transform(q_red) # normalize
788
  return q_red.astype(np.float32)
789
  except Exception:
790
  return None
791
 
792
- def search_fn(q, df, vec, X_reduced, index_obj, use_lsa_flag, use_faiss_flag, svd_obj, norm_obj):
793
- if (not q) or (df is None) or (vec is None) or (index_obj is None):
 
 
 
 
 
 
 
 
 
 
794
  return pd.DataFrame(), []
795
  q_terms = _tokenize_query(q)
796
-
797
- # Vectorize the query
798
- q_vec = vec.transform([q])
799
-
800
- # Decide which space the index uses and project accordingly
801
  if use_lsa_flag and (X_reduced is not None):
802
- q_emb = _project_query_to_lsa(q_vec, svd_obj, norm_obj)
803
  if q_emb is None:
804
  return pd.DataFrame(), q_terms
805
  else:
806
- q_emb = q_vec
807
 
808
  if isinstance(index_obj, NearestNeighbors):
809
  distances, indices = index_obj.kneighbors(q_emb, n_neighbors=min(50, len(df)))
@@ -811,7 +1004,7 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
811
  sims = 1.0 - distances[0]
812
  results = df.iloc[inds].copy()
813
  results["score"] = sims
814
- elif FAISS_OK and isinstance(index_obj, faiss.Index):
815
  D, I = index_obj.search(q_emb.astype(np.float32), min(50, len(df)))
816
  inds = I[0]
817
  sims = D[0]
@@ -820,7 +1013,12 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
820
  else:
821
  return pd.DataFrame(), q_terms
822
 
823
- cols = ["date", "from_email", "from_domain", "subject", "cluster_name", "tags", "sentiment", "score"]
 
 
 
 
 
824
  return results[cols].head(50), q_terms
825
 
826
  search_btn.click(
@@ -836,12 +1034,10 @@ with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="so
836
  row_idx = evt.index if hasattr(evt, "index") else None
837
  if row_idx is None or table is None or len(table) == 0 or df is None or len(df) == 0:
838
  return ""
839
- # Get identifying columns from the table row to map back to original df row
840
  sel = table.iloc[row_idx]
841
  subj = sel.get("subject", None)
842
  frm = sel.get("from_email", None)
843
  dstr = sel.get("date", None)
844
- # match in original df
845
  cand = df
846
  if subj is not None:
847
  cand = cand[cand["subject"] == subj]
 
17
  from tqdm import tqdm
18
 
19
  # sklearn (CPU-friendly)
20
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer as CharTfidf
21
  from sklearn.cluster import MiniBatchKMeans
22
  from sklearn.neighbors import NearestNeighbors
23
  from sklearn.decomposition import TruncatedSVD
24
  from sklearn.preprocessing import Normalizer
25
+ from sklearn.preprocessing import normalize as sk_normalize
26
+ from sklearn.metrics.pairwise import cosine_similarity
27
+
28
+ from scipy.sparse import hstack
29
 
30
  # Optional fast ANN (CPU)
31
  try:
 
75
  "contract splitting", "grease payment", "unreported", "unrecorded",
76
  ]
77
 
78
+ # Entity regexes for enrichment/scoring
79
+ MONEY_RE = re.compile(r'(\$|USD|EUR|ILS|NIS)\s?\d[\d,.\s]*', re.I)
80
+ PHONE_RE = re.compile(r'(\+?\d{1,3}[-\s.]?)?(\(?\d{2,4}\)?[-\s.]?)?\d{3,4}[-\s.]?\d{4}')
81
+ INVOICE_RE = re.compile(r'\b(invoice|inv\.\s?\d+|po\s?#?\d+|purchase order)\b', re.I)
82
+ COMPANY_RE = re.compile(r'\b(LLC|Ltd|Limited|Inc|GmbH|S\.A\.|S\.p\.A\.)\b')
83
+
84
+ # Optional seeded themes for semi-supervised init (used only when LSA is ON)
85
+ CORR_LEX = {
86
+ "kickback" : ["kickback","bribe","under the table","gift","cash"],
87
+ "invoice_fraud" : ["false invoice","ghost employee","contract splitting","slush fund","shell company","front company"],
88
+ "procurement" : ["bid rigging","tender","vendor","sole source","rfp","rfq","purchase order","po"],
89
+ "money_flow" : ["wire transfer","transfer","swift","iban","routing number","account number","cash"]
90
+ }
91
+
92
  # =================== Label cleanup helpers ===================
93
  EN_STOP = {
94
  "the","of","and","to","in","is","for","on","at","with","from","by","or","as",
 
96
  "re","fwd","fw","hi","hello","thanks","thank","regards","best","please","dear","mr","mrs",
97
  "message","original","forwarded","attached","attachment","confidential","notice","disclaimer",
98
  "herein","thereof","hereby","therein","regarding","subject","url","via","kind","regard",
99
+ "ny"
100
  }
101
  HE_STOP = {
102
  "של","על","זה","גם","אם","לא","את","אתה","אני","הוא","היא","הם","הן","כי","מה",
 
127
  return False
128
 
129
  def _sanitize_top_terms(names: np.ndarray, idxs: np.ndarray, mean_vec: np.ndarray, want:int) -> list:
 
130
  ordered = idxs[np.argsort(-mean_vec[idxs])]
131
  cleaned = []
132
  for i in ordered:
 
136
  cleaned.append(term)
137
  if len(cleaned) >= want:
138
  break
 
139
  if len(cleaned) < max(2, want//2):
140
  for i in ordered:
141
  term = names[i]
 
157
  return soup.get_text(separator="\n")
158
 
159
  def strip_quotes_and_sigs(text: str) -> str:
 
160
  if not text:
161
  return ""
 
162
  text = QUOTE_LINE_RE.sub("", text)
 
163
  parts = SIG_RE.split(text)
164
  if parts:
165
  text = parts[0]
 
166
  text = SENT_FROM_RE.sub("", text)
167
  text = HEBREW_SENT_FROM_RE.sub("", text)
 
168
  cut = None
169
  for pat in (FWD_BEGIN_RE, FWD_MSG_RE, ON_WROTE_RE):
170
  m = pat.search(text)
 
176
  return text.strip()
177
 
178
  def parse_name_email(s: str) -> Tuple[str, str]:
 
179
  if not s:
180
  return "", ""
181
  m = re.match(r'(?:"?([^"]*)"?\s)?<?([^<>]+@[^<>]+)>?', s)
 
184
  return "", s.strip()
185
 
186
  def parse_email_headers(text: str) -> Tuple[Dict[str, str], str]:
 
 
 
 
187
  headers: Dict[str, str] = {}
188
  lines = (text or "").splitlines()
189
  header_pat = re.compile(r'^(From|To|Cc|CC|Bcc|Date|Subject):')
190
  i = 0
191
  saw_header = False
 
192
  while i < len(lines):
193
  line = lines[i].rstrip("\r")
194
  stripped = line.strip()
 
226
  break
227
  else:
228
  break
 
229
  body_text = "\n".join(lines[i:]) if i < len(lines) else ""
230
  return headers, body_text
231
 
232
  # =================== Normalization & Utilities ===================
233
  def normalize_email_record(raw: Dict[str, Any], use_langdetect: bool) -> Dict[str, Any]:
 
234
  if str(raw.get("type", "")).lower() == "meta":
235
  return {}
236
 
 
251
  sender = headers.get("From", "") or raw.get("from") or raw.get("From") or ""
252
  date_val = headers.get("Date", "") or date_val
253
 
 
254
  body_clean = strip_quotes_and_sigs(ftfy.fix_text(body_only or ""))
255
  body_clean = URL_RE.sub(" URL ", body_clean)
256
  body_clean = re.sub(r"\s+", " ", body_clean).strip()
 
309
  }
310
 
311
  def has_suspect_tag(text: str) -> List[str]:
 
312
  tags = []
313
  if not text:
314
  return tags
 
331
  return df
332
  analyzer = SentimentIntensityAnalyzer()
333
  scores = df["body_text"].fillna("").map(lambda t: analyzer.polarity_scores(t)["compound"])
 
334
  bins = [-1.01, -0.05, 0.05, 1.01]
335
  labels = ["negative", "neutral", "positive"]
336
  df["sentiment_score"] = scores
 
338
  return df
339
 
340
  def build_highlighted_html(row: pd.Series, query_terms: Optional[List[str]] = None, cluster_label: Optional[str] = None) -> str:
 
341
  subject = (row.get("subject") or "").strip()
342
  body = (row.get("body_text") or "").strip()
343
  from_email = row.get("from_email") or ""
 
362
  subject_h = hi(subject)
363
  body_h = hi(body)
364
 
 
365
  rtl = bool(re.search(r"[\u0590-\u08FF]", body_h))
366
  dir_attr = ' dir="rtl"' if rtl else ""
367
  body_html = body_h.replace("\n", "<br/>")
 
392
  )
393
  return html
394
 
395
# =================== Feature engineering (BM25 + char) ===================
class BM25Transformer:
    """Okapi BM25 weighting on top of a term-frequency matrix.

    Fit on the raw counts produced by a CountVectorizer-style matrix
    (sparse, one row per document), then transform counts into BM25
    scores: idf * tf*(k1+1) / (tf + k1*(1 - b + b*dl/avgdl)).
    """

    def __init__(self, k1=1.2, b=0.75):
        self.k1 = k1          # term-frequency saturation
        self.b = b            # document-length normalization strength
        self.idf_ = None      # per-term idf, set by fit()
        self.avgdl_ = None    # average document length, set by fit()

    def fit(self, X):
        """Learn idf and average document length from term-frequency matrix X."""
        # X is term-frequency (CountVectorizer output); assumes canonical
        # sparse form with no duplicate entries — TODO confirm upstream.
        N = X.shape[0]
        # Document frequency per term = stored entries per CSC column.
        Xc = X.tocsc()
        df = np.diff(Xc.indptr).astype(np.float64)
        # Probabilistic idf; +0.5 smoothing, epsilon guards df == N + 0.5.
        self.idf_ = np.log((N - df + 0.5) / (df + 0.5 + 1e-12))
        dl = np.asarray(X.sum(axis=1)).ravel()
        self.avgdl_ = float(dl.mean() if dl.size else 1.0)
        return self

    def transform(self, X):
        """Return a new CSR matrix with counts replaced by BM25 weights."""
        X = X.tocsr().astype(np.float32)  # astype copies; original left intact
        dl = np.asarray(X.sum(axis=1)).ravel()
        # Length-normalization term, one value per row, expanded to one per
        # stored entry via the CSR row pointer. Working directly on
        # indices/data (instead of a Python loop over nonzero()) is both
        # O(nnz) in C and immune to nonzero()/data misalignment when
        # explicit zeros are present.
        row_norm = self.k1 * (1.0 - self.b + self.b * (dl / (self.avgdl_ + 1e-12)))
        per_entry_norm = np.repeat(row_norm, np.diff(X.indptr))
        tf = X.data
        scores = (self.idf_[X.indices] * (tf * (self.k1 + 1))) / (tf + per_entry_norm + 1e-12)
        X.data = scores.astype(np.float32)
        return X
425
+
426
# Add enrichment tokens to help the model lock onto key signals
def enrich_text(row: pd.Series) -> str:
    """Concatenate subject and body, appending a marker token per entity class found."""
    subject = row.get("subject", "") or ""
    body = row.get("body_text", "") or ""
    combined = subject + "\n\n" + body
    markers = []
    # Each regex hit contributes exactly one synthetic vocabulary token.
    for pattern, token in (
        (MONEY_RE, "__HAS_MONEY__"),
        (PHONE_RE, "__HAS_PHONE__"),
        (INVOICE_RE, "__HAS_INVOICE__"),
        (COMPANY_RE, "__HAS_COMPANY__"),
    ):
        if pattern.search(combined):
            markers.append(token)
    return (combined + " " + " ".join(markers)).strip()
437
+
438
# =================== Cluster labeling: PMI bigrams ===================
def cluster_labels_pmi_bigram(texts, labels, topn=6):
    """Label each cluster with the bigrams most over-represented vs. the corpus.

    Counts document-level bigram presence (a set per document), then scores
    each cluster's bigrams by log(p(bg | cluster)) - log(p(bg | corpus))
    and returns {cluster_id: "bg1, bg2, ..."} (or "cluster_<id>" fallback).
    """
    from collections import Counter
    import math as _math

    def doc_bigram_set(text):
        toks = re.findall(TOKEN_PATTERN, text.lower())
        return {" ".join(pair) for pair in zip(toks, toks[1:])}

    corpus_counts = Counter()
    cluster_counts = {int(c): Counter() for c in np.unique(labels)}
    for text, cid in zip(texts, labels):
        present = doc_bigram_set(text)   # set: one count per document
        corpus_counts.update(present)
        cluster_counts[int(cid)].update(present)

    out = {}
    corpus_total = sum(corpus_counts.values()) + 1e-9
    for cid in np.unique(labels):
        cid = int(cid)
        in_cluster = cluster_counts[cid]
        cluster_total = sum(in_cluster.values()) + 1e-9
        ranked = []
        # Cap candidates at the 1000 most frequent bigrams in the cluster.
        for bg, cnt in in_cluster.most_common(1000):
            p_in = cnt / cluster_total
            p_all = corpus_counts[bg] / corpus_total
            if p_all > 0 and p_in > 0:
                ranked.append((_math.log(p_in) - _math.log(p_all), bg))
        ranked.sort(reverse=True)
        best = [bg for _, bg in ranked[:topn]]
        out[cid] = ", ".join(best) if best else f"cluster_{cid}"
    return out
468
+
469
# =================== Auto-k (Kneedle on inertia) ===================
def choose_k_by_kneedle(X, ks=(50, 100, 150, 200, 300, 400, 500)):
    """Pick a cluster count via the elbow (Kneedle) of the k -> inertia curve.

    Runs MiniBatchKMeans for each candidate k (on a 40k subsample for large
    corpora) and returns (best_k, {k: inertia}). Candidate k values larger
    than the number of samples are dropped — MiniBatchKMeans requires
    n_samples >= n_clusters and would raise otherwise.
    """
    n = X.shape[0]
    # Subsample large corpora so the k sweep stays cheap.
    if n > 40000:
        rs = np.random.RandomState(0)
        idx = rs.choice(n, size=40000, replace=False)
        Xs = X[idx]
    else:
        Xs = X
    # Keep only feasible candidates for this sample size.
    feasible = tuple(k for k in ks if k <= Xs.shape[0])
    if not feasible:
        # Tiny corpus: no candidate fits; fall back to one cluster per sample (capped).
        return max(1, int(Xs.shape[0])), {}
    inertias = []
    for k in feasible:
        km = MiniBatchKMeans(n_clusters=k, batch_size=4096, random_state=0, n_init="auto")
        km.fit(Xs)
        inertias.append(km.inertia_)
    if len(feasible) == 1:
        # Kneedle needs at least two points; the single candidate wins by default.
        return int(feasible[0]), dict(zip(feasible, inertias))
    x = np.array(list(feasible), dtype=float)
    y = np.array(inertias, dtype=float)
    y_norm = (y - y.min()) / (y.max() - y.min() + 1e-9)
    x_norm = (x - x.min()) / (x.max() - x.min() + 1e-9)
    # Kneedle: the elbow is where the normalized curve dips furthest below
    # the straight chord joining its endpoints.
    chord = y_norm[0] + (y_norm[-1] - y_norm[0]) * (x_norm - x_norm[0]) / (x_norm[-1] - x_norm[0] + 1e-9)
    dist = chord - y_norm
    k_best = int(x[np.argmax(dist)])
    return k_best, dict(zip(feasible, inertias))
491
 
492
def auto_k_rule(n_docs: int) -> int:
    """Heuristic cluster count: grows with sqrt(corpus size), clamped to [120, 600]."""
    raw = math.sqrt(max(n_docs, 1) / 50.0) * 110
    return int(min(600, max(120, raw)))
495
 
496
# =================== Merge close clusters (LSA space only to save RAM) ===================
def merge_close_clusters(labels, centers, thresh=0.92):
    """Union-merge clusters whose centroid cosine similarity >= thresh.

    Returns a relabeled copy of `labels` with merged groups renumbered
    compactly (0..m-1) in first-seen order of the original cluster ids.
    """
    unit_centers = sk_normalize(centers)
    sim = cosine_similarity(unit_centers, unit_centers)
    k = unit_centers.shape[0]
    parent = list(range(k))

    def find(node):
        # Walk up to the representative (no path compression needed at this scale).
        while parent[node] != node:
            node = parent[node]
        return node

    # Union every similar pair.
    for i in range(k):
        for j in range(i + 1, k):
            if sim[i, j] >= thresh:
                ri, rj = find(i), find(j)
                if ri != rj:
                    parent[rj] = ri

    # Renumber the surviving roots densely, preserving first-seen order.
    root = {i: find(i) for i in range(k)}
    idmap = {}
    next_id = 0
    for i in range(k):
        r = root[i]
        if r not in idmap:
            idmap[r] = next_id
            next_id += 1
    return np.array([idmap[root[int(c)]] for c in labels], dtype=int)
519
+
520
# =================== Seeded centroids (only if LSA enabled) ===================
def seeded_centroids_in_lsa(lexicons: Dict[str, List[str]], count_vec: CountVectorizer,
                            lsa_components: np.ndarray, norm_obj: Normalizer,
                            d_word: int, d_full: int, k: int) -> Optional[np.ndarray]:
    """Build theme seed centroids in LSA space from keyword lexicons.

    Each lexicon becomes a unit vector over the word vocabulary, is padded
    with zeros for the char-gram dimensions, projected through the fitted
    LSA components, and normalized. Returns an (s, lsa_dim) float32 array,
    or None when fewer than 2 (or more than k) usable seeds exist.
    """
    vocab = count_vec.vocabulary_
    word_seeds = []
    for _, phrase_list in lexicons.items():
        hits = [vocab.get(p.lower()) for p in phrase_list if vocab.get(p.lower()) is not None]
        if not hits:
            continue  # lexicon has no overlap with the fitted vocabulary
        vec = np.zeros((d_word,), dtype=np.float32)
        vec[hits] = 1.0
        mag = np.linalg.norm(vec)
        if mag > 0:
            vec /= mag
        word_seeds.append(vec)
    if not word_seeds:
        return None
    # Lift word-space seeds into the full (word + char) feature space by
    # zero-padding the char-gram dimensions.
    padded = []
    for vec in word_seeds:
        full = np.zeros((d_full,), dtype=np.float32)
        full[:d_word] = vec
        padded.append(full)
    stacked = np.stack(padded, axis=0)  # (s, n_features)
    # Project into LSA space (x @ components_.T) and L2-normalize, mirroring
    # the document pipeline.
    reduced = (stacked @ lsa_components.T).astype(np.float32)
    reduced = norm_obj.transform(reduced)
    # scikit-learn KMeans init requires shape (k, d); only meaningful with
    # at least 2 seeds and no more than k of them.
    if 2 <= reduced.shape[0] <= k:
        return reduced
    return None
554
+
555
+ # =================== Corruption scoring ===================
556
def corruption_score(row):
    """Heuristic corruption-risk score for one email row (higher = more suspect).

    Signals: any suspect phrase in subject+body (+2.0, counted once),
    a 🚩suspect/finance tag (+1.5), a money mention (+0.7), an invoice/PO
    mention (+0.7), negative sentiment (+0.3), and a short body (<160 chars)
    that still contains a phone number (+0.5).

    Fix: the original crashed with TypeError when pandas supplied NaN/None
    instead of a str for subject/body (``len(NaN)``, ``PHONE_RE.search(NaN)``)
    and leaked the literal text "nan"/"None" into the matched string; all
    text fields are now coerced to "" unless they are real strings.
    """
    def _text(value):
        # Rows coming from pandas may carry None/NaN instead of strings.
        return value if isinstance(value, str) else ""

    subject = _text(row.get("subject"))
    body = _text(row.get("body_text"))
    txt = f"{subject} {body}".lower()

    score = 0.0
    if any(ph in txt for ph in SUSPECT_PHRASES):
        score += 2.0
    tags = row.get("tags")
    if isinstance(tags, list) and ("🚩suspect" in tags or "finance" in tags):
        score += 1.5
    if MONEY_RE.search(txt):
        score += 0.7
    if INVOICE_RE.search(txt):
        score += 0.7
    if str(row.get("sentiment", "")) == "negative":
        score += 0.3
    # Very short bodies that still include a phone number look like solicitations.
    if len(body) < 160 and PHONE_RE.search(body):
        score += 0.5
    return score
573
+
574
  # =================== Gradio UI ===================
575
  CSS = """
576
  :root { --pill:#eef2ff; --pill-text:#1f2937; --tag:#e5e7eb; --tag-text:#111827; }
 
591
 
592
  with gr.Blocks(title="Email Investigator (Corruption Focus)", css=CSS, theme="soft") as demo:
593
  gr.Markdown("""
594
+ # Email Investigator — BM25 + Char-grams + (optional) LSA MiniBatchKMeans
595
+ **Goal:** quickly surface potentially corruption-related emails via topic clusters, tags, corruption score, and sentiment.
596
  """)
597
 
598
  with gr.Row():
 
600
 
601
  with gr.Accordion("Vectorization & Clustering", open=True):
602
  with gr.Row():
603
+ max_features = gr.Number(label="Word max_features (BM25)", value=120_000, precision=0)
604
  min_df = gr.Number(label="min_df (doc freq ≥)", value=2, precision=0)
605
  max_df = gr.Slider(label="max_df (fraction ≤)", minimum=0.1, maximum=0.95, value=0.7, step=0.05)
606
  use_bigrams = gr.Checkbox(label="Use bigrams (1–2)", value=True)
 
608
  with gr.Row():
609
  use_lsa = gr.Checkbox(label="Use LSA (TruncatedSVD) before KMeans", value=True)
610
  lsa_dim = gr.Number(label="LSA components", value=150, precision=0)
611
+ auto_k = gr.Checkbox(label="Auto choose k (kneedle)", value=True)
612
  k_clusters = gr.Number(label="k (MiniBatchKMeans)", value=350, precision=0)
613
  mb_batch = gr.Number(label="KMeans batch_size", value=4096, precision=0)
614
  with gr.Row():
615
+ use_faiss = gr.Checkbox(label="Use Faiss ANN for search (if available & LSA on)", value=True)
616
 
617
  with gr.Accordion("Filters", open=True):
618
  with gr.Row():
 
647
 
648
  # State
649
  state_df = gr.State() # full dataframe
650
+ state_vec = gr.State() # {"count_vec":..., "char_vec":..., "bm25":...}
651
  state_X_reduced = gr.State() # np.ndarray (LSA normalized) or None
652
  state_index = gr.State() # Faiss index or sklearn NN
653
  state_term_names = gr.State() # dict cluster_id -> label
 
656
  state_use_faiss = gr.State()
657
  state_svd = gr.State()
658
  state_norm = gr.State()
659
+ state_dims = gr.State() # (d_word, d_char)
660
 
661
  # -------- IO helpers --------
662
  def _load_json_records(local_path: str) -> List[Dict[str, Any]]:
 
698
  ) -> pd.DataFrame:
699
  out = df
700
  if cluster and cluster != "(any)":
 
701
  m = re.match(r"^(\d+)\s+—", cluster)
702
  if m:
703
  cid = int(m.group(1))
 
707
  if sentiment and sentiment != "(any)" and "sentiment" in out.columns:
708
  out = out[out["sentiment"].astype(str) == sentiment]
709
  if tag_value and tag_value != "(any)":
 
710
  out = out[out["tags"].apply(lambda ts: isinstance(ts, list) and (tag_value in ts))]
 
711
  if start:
712
  try:
713
  dt = pd.to_datetime(start, utc=True, errors="coerce")
 
727
  use_lsa, lsa_dim, auto_k, k_clusters, mb_batch, use_faiss):
728
  if inbox_file is None:
729
  return ("**Please upload a file.**",
730
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
731
 
732
  use_lang = not bool(skip_lang)
733
 
734
  recs = _load_json_records(inbox_file.name)
735
  if not recs:
736
  return ("**No valid records found.**",
737
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
738
 
739
  # Normalize
740
  normd = []
 
745
  df = pd.DataFrame(normd)
746
  if df.empty:
747
  return ("**No usable email records after normalization.**",
748
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
749
 
750
  # Deduplicate conservatively
751
  df = df.drop_duplicates(subset=["message_id", "subject", "text_hash"]).reset_index(drop=True)
 
754
  df["tags"] = df["body_text"].fillna("").map(has_suspect_tag)
755
  df = compute_sentiment_column(df)
756
 
757
+ # Enriched texts (adds __HAS_*__ flags)
758
+ texts = list(df.apply(enrich_text, axis=1))
759
 
760
+ # === Vectorization: BM25 word + char tf-idf, then optional LSA ===
761
  ngram_range = (1, 2) if use_bigrams else (1, 1)
762
+ count_vec = CountVectorizer(
763
  analyzer="word",
764
  ngram_range=ngram_range,
765
  max_features=int(max_features) if max_features else None,
 
767
  max_df=float(max_df) if max_df else 0.7,
768
  token_pattern=TOKEN_PATTERN,
769
  lowercase=True,
 
770
  dtype=np.float32,
771
  )
772
+ TF = count_vec.fit_transform(texts)
773
+ bm25 = BM25Transformer(k1=1.2, b=0.75).fit(TF)
774
+ X_word = bm25.transform(TF) # sparse BM25 word matrix
775
+
776
+ char_vec = CharTfidf(
777
+ analyzer="char", ngram_range=(3,5), min_df=2, max_features=100_000,
778
+ lowercase=True, dtype=np.float32
779
+ )
780
+ X_char = char_vec.fit_transform(texts)
781
+
782
+ X_full = hstack([X_word, X_char], format="csr")
783
+ d_word = X_word.shape[1]
784
+ d_char = X_char.shape[1]
785
+ d_full = X_full.shape[1]
786
 
787
  # LSA (TruncatedSVD + Normalizer) for stability/quality
788
  use_lsa = bool(use_lsa)
 
791
  norm_obj = None
792
  if use_lsa:
793
  svd_obj = TruncatedSVD(n_components=int(lsa_dim or 150), random_state=0)
794
+ X_reduced_tmp = svd_obj.fit_transform(X_full) # dense (n_docs x lsa_dim)
795
  norm_obj = Normalizer(copy=False)
796
  X_reduced = norm_obj.fit_transform(X_reduced_tmp).astype(np.float32)
797
  del X_reduced_tmp
798
  gc.collect()
799
 
800
+ # K selection
801
  if bool(auto_k):
802
+ if use_lsa:
803
+ k, _ = choose_k_by_kneedle(X_reduced, ks=(50,100,150,200,300,400,500))
804
+ else:
805
+ # fallback: heuristic rule on doc count
806
+ k = auto_k_rule(X_full.shape[0])
807
  else:
808
  k = max(10, int(k_clusters or 350))
809
 
810
+ # Optional seeded init (only in LSA space to keep memory sane)
811
+ init = None
812
+ if use_lsa:
813
+ seeds = seeded_centroids_in_lsa(
814
+ CORR_LEX, count_vec, svd_obj.components_, norm_obj,
815
+ d_word=d_word, d_full=d_full, k=k
816
+ )
817
+ if seeds is not None and seeds.shape[0] <= k:
818
+ # If fewer seeds than k, KMeans will handle by k-means++ for remaining centers internally only for KMeans.
819
+ # For MiniBatchKMeans, we must provide exactly k centers or fall back to k-means++.
820
+ # So use seeds only if seeds.shape[0] == k; otherwise None.
821
+ if seeds.shape[0] == k:
822
+ init = seeds
823
+
824
+ # KMeans clustering (use LSA space if enabled)
825
+ X_space = (X_reduced if use_lsa else X_full)
826
  kmeans = MiniBatchKMeans(
827
  n_clusters=k,
828
  batch_size=int(mb_batch or 4096),
829
  random_state=0,
830
+ n_init="auto" if init is None else 1,
831
+ init="k-means++" if init is None else init
832
  )
833
+ labels = kmeans.fit_predict(X_space)
834
+
835
+ # Optional: merge very-similar clusters (only when LSA enabled)
836
+ if use_lsa:
837
+ labels = merge_close_clusters(labels, kmeans.cluster_centers_, thresh=0.92)
838
+
839
  df["cluster_id"] = labels
840
 
841
+ # Name clusters by PMI bigrams on raw enriched texts
842
+ term_names = cluster_labels_pmi_bigram(texts, labels, topn=6)
843
+ df["cluster_name"] = [term_names.get(int(c), f"cluster_{int(c)}") for c in labels]
844
+
845
+ # CorruptionScore
846
+ df["corruption_score"] = df.apply(corruption_score, axis=1)
847
 
848
  # Build search index
849
+ use_faiss = bool(use_faiss) and FAISS_OK and use_lsa and (X_reduced is not None)
850
  index_obj = None
851
+ if use_faiss:
 
852
  d = X_reduced.shape[1]
853
+ index_obj = faiss.IndexFlatIP(d) # cosine ~ inner product on normalized vectors
854
  index_obj.add(X_reduced)
855
  else:
 
856
  nn = NearestNeighbors(metric="cosine", algorithm="brute")
857
+ nn.fit(X_space)
858
  index_obj = nn
859
 
860
  # Summaries
 
864
  .sort_values("count", ascending=False)
865
  .head(500)
866
  )
 
867
  cluster_counts["label"] = cluster_counts.apply(
868
  lambda r: f'{int(r["cluster_id"])} — {r["cluster_name"]} ({int(r["count"])})', axis=1
869
  )
 
877
  )
878
  domain_choices = ["(any)"] + domain_counts["from_domain"].tolist()
879
 
880
+ # Results preview default: rank by corruption_score then date desc
881
+ show_df = df.copy()
882
+ if "date" in show_df.columns and show_df["date"].notna().any():
 
883
  show_df["_dt"] = pd.to_datetime(show_df["date"], utc=True, errors="coerce")
 
884
  else:
885
+ show_df["_dt"] = pd.NaT
886
+ show_df = show_df.sort_values(["corruption_score","_dt"], ascending=[False, False]).drop(columns=["_dt"])
887
 
888
+ cols_out = ["date", "from_email", "from_domain", "subject", "cluster_name", "tags", "sentiment", "corruption_score"]
889
  out_table = show_df[cols_out].head(500)
890
 
891
+ vec_state = {"count_vec": count_vec, "char_vec": char_vec, "bm25": bm25}
892
+
893
  status_md = (
894
  f"**Processed {len(df):,} emails** \n"
895
+ f"Word feats (BM25): {d_word:,} | Char feats: {d_char:,} | Total: {d_full:,} \n"
896
  f"{'LSA: ' + str(X_reduced.shape[1]) + ' dims | ' if use_lsa else ''}"
897
+ f"k = {k} | Search = {'Faiss (IP on LSA)' if use_faiss else 'cosine brute-force'}"
898
  )
899
 
900
  gc.collect()
901
 
 
902
  cluster_update = gr.update(choices=cluster_choices, value="(any)")
903
  domain_update = gr.update(choices=domain_choices, value="(any)")
904
 
 
906
  status_md,
907
  cluster_counts, domain_counts,
908
  out_table,
909
+ df, vec_state, (X_reduced if use_lsa else None), index_obj, term_names,
910
+ use_lsa, bool(use_faiss),
911
  cluster_update, domain_update,
912
+ svd_obj, norm_obj,
913
+ (d_word, d_char)
914
  )
915
 
916
  (run_btn.click)(
 
923
  state_df, state_vec, state_X_reduced, state_index, state_term_names,
924
  state_use_lsa, state_use_faiss,
925
  cluster_drop, domain_drop,
926
+ state_svd, state_norm,
927
+ state_dims]
928
  )
929
 
930
  # -------- Filtering & Search --------
 
932
  if df is None or len(df) == 0:
933
  return pd.DataFrame()
934
  filt = _apply_filters(df, cluster_choice, domain_choice, sentiment_choice, tag_choice, start, end)
935
+ cols_out = ["date", "from_email", "from_domain", "subject", "cluster_name", "tags", "sentiment", "corruption_score"]
 
936
  if "date" in filt.columns and filt["date"].notna().any():
937
  tmp = filt.copy()
938
  tmp["_dt"] = pd.to_datetime(tmp["date"], utc=True, errors="coerce")
939
+ tmp = tmp.sort_values(["corruption_score","_dt"], ascending=[False, False]).drop(columns=["_dt"])
940
  return tmp[cols_out].head(500)
941
+ return filt.sort_values(["corruption_score"], ascending=False)[cols_out].head(500)
942
 
943
  for ctrl in [cluster_drop, domain_drop, sentiment_drop, tag_drop, date_start, date_end]:
944
  ctrl.change(
 
947
  outputs=[results_df]
948
  )
949
 
 
950
  reset_btn.click(
951
  lambda: [None, None, "(any)", "(any)", "", ""],
952
  inputs=[],
 
960
  def _tokenize_query(q: str) -> List[str]:
961
  if not q:
962
  return []
 
963
  parts = [p.strip() for p in re.split(r"\s+", q) if p.strip()]
964
  seen, out = set(), []
965
  for p in parts:
966
  if p.lower() not in seen:
967
  out.append(p)
968
  seen.add(p.lower())
969
+ return out[:8]
970
 
971
  def _project_query_to_lsa(q_vec, svd_obj, norm_obj) -> Optional[np.ndarray]:
972
  try:
973
+ q_red = svd_obj.transform(q_vec)
974
+ q_red = norm_obj.transform(q_red)
975
  return q_red.astype(np.float32)
976
  except Exception:
977
  return None
978
 
979
+ def _vectorize_query(q: str, vec_state: Dict[str, Any]):
980
+ count_vec = vec_state["count_vec"]
981
+ char_vec = vec_state["char_vec"]
982
+ bm25 = vec_state["bm25"]
983
+ q_word_tf = count_vec.transform([q])
984
+ q_word = bm25.transform(q_word_tf)
985
+ q_char = char_vec.transform([q])
986
+ q_full = hstack([q_word, q_char], format="csr")
987
+ return q_full
988
+
989
+ def search_fn(q, df, vec_state, X_reduced, index_obj, use_lsa_flag, use_faiss_flag, svd_obj, norm_obj):
990
+ if (not q) or (df is None) or (vec_state is None) or (index_obj is None):
991
  return pd.DataFrame(), []
992
  q_terms = _tokenize_query(q)
993
+ q_vec_full = _vectorize_query(q, vec_state)
 
 
 
 
994
  if use_lsa_flag and (X_reduced is not None):
995
+ q_emb = _project_query_to_lsa(q_vec_full, svd_obj, norm_obj)
996
  if q_emb is None:
997
  return pd.DataFrame(), q_terms
998
  else:
999
+ q_emb = q_vec_full
1000
 
1001
  if isinstance(index_obj, NearestNeighbors):
1002
  distances, indices = index_obj.kneighbors(q_emb, n_neighbors=min(50, len(df)))
 
1004
  sims = 1.0 - distances[0]
1005
  results = df.iloc[inds].copy()
1006
  results["score"] = sims
1007
+ elif FAISS_OK and use_faiss_flag and isinstance(index_obj, faiss.Index):
1008
  D, I = index_obj.search(q_emb.astype(np.float32), min(50, len(df)))
1009
  inds = I[0]
1010
  sims = D[0]
 
1013
  else:
1014
  return pd.DataFrame(), q_terms
1015
 
1016
+ cols = ["date", "from_email", "from_domain", "subject", "cluster_name", "tags", "sentiment", "corruption_score", "score"]
1017
+ # Rerank by a blend: 0.7 * ANN score + 0.3 * corruption_score (scaled)
1018
+ cs = results["corruption_score"].fillna(0.0)
1019
+ cs = (cs - cs.min()) / (cs.max() - cs.min() + 1e-9)
1020
+ results["_blend"] = 0.7*results["score"].values + 0.3*cs.values
1021
+ results = results.sort_values("_blend", ascending=False).drop(columns=["_blend"])
1022
  return results[cols].head(50), q_terms
1023
 
1024
  search_btn.click(
 
1034
  row_idx = evt.index if hasattr(evt, "index") else None
1035
  if row_idx is None or table is None or len(table) == 0 or df is None or len(df) == 0:
1036
  return ""
 
1037
  sel = table.iloc[row_idx]
1038
  subj = sel.get("subject", None)
1039
  frm = sel.get("from_email", None)
1040
  dstr = sel.get("date", None)
 
1041
  cand = df
1042
  if subj is not None:
1043
  cand = cand[cand["subject"] == subj]