wuhp committed on
Commit
9ef7e16
·
verified ·
1 Parent(s): b4b271e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -171
app.py CHANGED
@@ -15,27 +15,14 @@ DetectorFactory.seed = 0
15
  import gradio as gr
16
  from tqdm import tqdm
17
 
18
- from sentence_transformers import SentenceTransformer
19
- import faiss
20
-
21
- try:
22
- import hdbscan
23
- HDBSCAN_AVAILABLE = True
24
- except Exception:
25
- HDBSCAN_AVAILABLE = False
26
 
27
  # ------------------- Helpers -------------------
28
  URL = re.compile(r"https?://\S+", re.I)
29
- SKIP_LANGDETECT = True # CPU-friendly default; can be toggled in the UI
30
-
31
-
32
def torch_cuda_available():
    """Best-effort CUDA probe: True iff PyTorch imports and reports a usable GPU.

    Never raises — any failure (torch missing, broken CUDA runtime) yields False.
    """
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        # torch not installed or CUDA query failed: treat as CPU-only.
        return False
38
-
39
 
40
  def html_to_text(html: str) -> str:
41
  if not html:
@@ -45,7 +32,6 @@ def html_to_text(html: str) -> str:
45
  tag.decompose()
46
  return soup.get_text(separator="\n")
47
 
48
-
49
  def strip_quotes_and_sigs(text: str) -> str:
50
  if not text:
51
  return ""
@@ -54,7 +40,6 @@ def strip_quotes_and_sigs(text: str) -> str:
54
  res = re.split(r"\nSent from my ", res)[0]
55
  return res.strip()
56
 
57
-
58
  def parse_name_email(s: str) -> Tuple[str, str]:
59
  if not s:
60
  return "", ""
@@ -64,7 +49,6 @@ def parse_name_email(s: str) -> Tuple[str, str]:
64
  return "", s.strip()
65
 
66
  # ------------------- Normalization -------------------
67
-
68
  def normalize_email_record(raw: Dict[str, Any]) -> Dict[str, Any]:
69
  subject = raw.get("subject") or raw.get("Subject") or ""
70
  body_html = raw.get("body_html") or raw.get("html") or ""
@@ -128,134 +112,46 @@ def normalize_email_record(raw: Dict[str, Any]) -> Dict[str, Any]:
128
  "text_hash": text_hash,
129
  }
130
 
131
# ------------------- Embeddings & Clustering -------------------

def embed_texts(
    model: SentenceTransformer,
    texts: List[str],
    batch_size: int,
    use_gpu: bool,
    use_multiprocess: bool = True
) -> np.ndarray:
    """
    Faster CPU path: try multi-process first; fall back to single-process batching.
    """
    # Multi-process encoding only makes sense on CPU with at least two cores.
    if not use_gpu and use_multiprocess and (os.cpu_count() or 1) >= 2:
        try:
            pool = model.start_multi_process_pool()
            arr = model.encode_multi_process(texts, pool, normalize_embeddings=True)
            model.stop_multi_process_pool(pool)
            return np.asarray(arr, dtype=np.float32)
        except Exception:
            pass  # fallback below

    device = "cuda" if use_gpu else "cpu"
    chunks = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Embedding", leave=False):
        batch = texts[start:start + batch_size]
        chunks.append(model.encode(
            batch,
            batch_size=min(batch_size, len(batch)),
            show_progress_bar=False,
            normalize_embeddings=True,
            convert_to_numpy=True,
            device=device,
        ))
    return np.vstack(chunks).astype(np.float32)
164
-
165
-
166
def cluster_embeddings(embs: np.ndarray, method: str, min_cluster_size: int, k_hint: int, use_gpu: bool) -> np.ndarray:
    """Cluster embedding rows; returns one integer label per row.

    Uses HDBSCAN when requested and importable, otherwise FAISS k-means
    (k from k_hint, or a sqrt-of-corpus heuristic when the hint is falsy).
    """
    if method == "HDBSCAN" and HDBSCAN_AVAILABLE:
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=min_cluster_size,
            min_samples=max(5, min_cluster_size // 5),
            metric='euclidean',
        )
        return clusterer.fit_predict(embs)

    # k-means fallback: honor the hint, else scale k with corpus size (>= 10).
    k = max(10, k_hint or int(max(20, math.sqrt(len(embs) / 50))))
    km = faiss.Kmeans(d=embs.shape[1], k=k, niter=25, verbose=False, gpu=use_gpu)
    km.train(embs)
    _, assigned = km.index.search(embs, 1)
    return assigned.reshape(-1)
175
-
176
-
177
def zero_shot_embed_sim(embs: np.ndarray, model: SentenceTransformer, label_texts: List[str], use_gpu: bool) -> Tuple[np.ndarray, np.ndarray]:
    """Zero-shot labeling via embedding similarity.

    Encodes each label as a prompt ("This email is about: <label>") and
    returns, per email embedding, the index of the most similar label and
    that similarity score (dot product of normalized vectors).
    """
    prompts = [f"This email is about: {t}" for t in label_texts]
    device = "cuda" if use_gpu else "cpu"
    label_embs = model.encode(
        prompts,
        normalize_embeddings=True,
        convert_to_numpy=True,
        device=device,
    ).astype(np.float32)
    sims = embs @ label_embs.T
    top_idx = sims.argmax(axis=1)
    top_score = sims[np.arange(len(embs)), top_idx]
    return top_idx, top_score
184
-
185
# ------------------- Defaults -------------------
# Candidate category labels for zero-shot assignment; "Unknown" is last
# (presumably the catch-all bucket — confirm against the labeling caller).
DEFAULT_LABELS = [
    "Newsletters/Subscriptions",
    "Receipts & Billing",
    "Personal/Family",
    "Work/Colleagues",
    "Meetings & Calendars",
    "Travel/Itineraries",
    "Legal/Contracts",
    "System Notifications",
    "Security/2FA",
    "Hiring/Recruiting",
    "Support Tickets",
    "Politics/Government",
    "Media/Press",
    "Unknown",
]
202
-
203
# ------------------- Search -------------------
class EmailSearch:
    """Semantic search over embedded emails using an inner-product FAISS index.

    Embeddings are expected to be L2-normalized (see embed_texts), so the
    inner product behaves as cosine similarity.
    """

    def __init__(self, df, embs, model):
        self.df = df
        self.embs = embs
        self.model = model
        self.index = faiss.IndexFlatIP(embs.shape[1])
        self.index.add(embs)

    def query(self, q: str, top_k=20):
        """Return the top_k most similar rows of df, with a 'score' column."""
        q_emb = self.model.encode([q], normalize_embeddings=True, convert_to_numpy=True)
        scores, idx = self.index.search(q_emb.astype(np.float32), top_k)
        hits = self.df.iloc[idx[0]].copy()
        hits["score"] = scores[0]
        return hits
218
-
219
  # ------------------- Gradio UI -------------------
220
- with gr.Blocks(title="Email Organizer & Browser") as demo:
221
  gr.Markdown("""
222
  # Email Organizer & Browser (No-Redaction)
223
- Upload a **.jsonl** or **.json** of emails. The app normalizes, deduplicates, embeds, clusters, labels, and lets you **search** your inbox semantically.
224
-
225
- **CPU mode defaults**: smaller model, CPU multiprocessing, and skipped language detection for speed. You can change these below.
226
  """)
227
 
228
  with gr.Row():
229
  inbox_file = gr.File(label="Upload emails (.jsonl or .json)", file_types=[".jsonl", ".json"])
230
 
231
  with gr.Row():
232
- model_choice = gr.Dropdown(
233
- label="Embedding model",
234
- choices=[
235
- "sentence-transformers/paraphrase-MiniLM-L3-v2", # fast 384-dim (default)
236
- "sentence-transformers/all-MiniLM-L6-v2", # slower 768-dim
237
- ],
238
- value="sentence-transformers/paraphrase-MiniLM-L3-v2"
239
- )
240
- batch_size_in = gr.Number(label="Batch size (CPU)", value=128, precision=0)
241
- mp_cpu = gr.Checkbox(label="Use CPU multiprocessing", value=True)
242
  skip_lang = gr.Checkbox(label="Skip language detection (faster)", value=True)
243
 
 
 
 
 
 
244
  run_btn = gr.Button("Process", variant="primary")
245
  status = gr.Textbox(label="Status", interactive=False)
246
- label_counts_df = gr.Dataframe(label="Label counts (by sender domain)", interactive=False)
247
- html_samples = gr.HTML(label="Samples")
 
248
 
249
  with gr.Row():
250
  search_query = gr.Textbox(label="Search emails (keywords, names, etc.)")
251
  search_btn = gr.Button("Search")
252
  search_results = gr.Dataframe(label="Search results", interactive=False)
253
 
 
254
  state_df = gr.State()
255
- state_embs = gr.State()
256
- state_model = gr.State()
257
- state_search = gr.State()
258
 
 
259
  def _load_json_records(local_path: str) -> List[Dict[str, Any]]:
260
  recs: List[Dict[str, Any]] = []
261
  if local_path.endswith(".jsonl"):
@@ -277,75 +173,136 @@ with gr.Blocks(title="Email Organizer & Browser") as demo:
277
  recs = [obj]
278
  return recs
279
 
280
- def process_file(inbox_file, model_choice, batch_size_in, mp_cpu, skip_lang):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  if inbox_file is None:
282
- return "Please upload a file", None, None, None, None, None, None
283
 
284
- # apply fast flags
285
  global SKIP_LANGDETECT
286
  SKIP_LANGDETECT = bool(skip_lang)
287
 
288
- local_path = inbox_file.name
289
- recs = _load_json_records(local_path)
290
  if not recs:
291
- return "No valid records found.", None, None, None, None, None, None
292
 
293
- # Normalize
294
  normd = [normalize_email_record(r) for r in recs]
295
  df = pd.DataFrame(normd)
296
-
297
- # Deduplicate
298
  df = df.drop_duplicates(subset=["message_id", "subject", "text_hash"]).reset_index(drop=True)
299
 
300
- # Build texts WITHOUT cap (as requested)
301
  texts = (df["subject"].fillna("") + "\n\n" + df["body_text"].fillna("")).tolist()
302
 
303
- # Model (CPU only for free tier)
304
- model = SentenceTransformer(str(model_choice))
305
-
306
- # Embeddings (CPU multiprocessing optional)
307
- embs = embed_texts(
308
- model=model,
309
- texts=texts,
310
- batch_size=int(batch_size_in) if batch_size_in else 128,
311
- use_gpu=False,
312
- use_multiprocess=bool(mp_cpu),
313
  )
 
314
 
315
- # Build simple domain label counts as a quick organizer view
316
- label_counts = df.groupby("from_domain").size().reset_index(name="count").sort_values("count", ascending=False)
317
-
318
- # Build searcher
319
- searcher = EmailSearch(df, embs, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
- # Show a small HTML preview of the first 20
322
- sample_html = df.head(20)[["date", "from_email", "subject", "body_text"]].to_html(escape=False)
 
 
323
 
324
- return (
325
- f"Processed {len(df)} emails with model {model_choice} (dim={embs.shape[1]}).",
326
- label_counts,
327
- sample_html,
328
- df,
329
- embs,
330
- model,
331
- searcher
332
- )
333
 
334
  run_btn.click(
335
  process_file,
336
- inputs=[inbox_file, model_choice, batch_size_in, mp_cpu, skip_lang],
337
- outputs=[status, label_counts_df, html_samples, state_df, state_embs, state_model, state_search]
338
  )
339
 
340
- def search_fn(q, df, embs, model, searcher):
341
- if searcher is None or not q:
 
342
  return pd.DataFrame()
343
- results = searcher.query(q, top_k=20)
344
- return results[["date","from_email","subject","body_text","score"]]
 
 
 
 
 
345
 
346
  search_btn.click(
347
  search_fn,
348
- inputs=[search_query, state_df, state_embs, state_model, state_search],
349
  outputs=[search_results]
350
  )
351
 
 
15
  import gradio as gr
16
  from tqdm import tqdm
17
 
18
+ # ---- NEW: classic ML (CPU-fast) ----
19
+ from sklearn.feature_extraction.text import TfidfVectorizer
20
+ from sklearn.cluster import MiniBatchKMeans
21
+ from sklearn.neighbors import NearestNeighbors
 
 
 
 
22
 
23
  # ------------------- Helpers -------------------
24
  URL = re.compile(r"https?://\S+", re.I)
25
+ SKIP_LANGDETECT = True # can be toggled in UI
 
 
 
 
 
 
 
 
 
26
 
27
  def html_to_text(html: str) -> str:
28
  if not html:
 
32
  tag.decompose()
33
  return soup.get_text(separator="\n")
34
 
 
35
  def strip_quotes_and_sigs(text: str) -> str:
36
  if not text:
37
  return ""
 
40
  res = re.split(r"\nSent from my ", res)[0]
41
  return res.strip()
42
 
 
43
  def parse_name_email(s: str) -> Tuple[str, str]:
44
  if not s:
45
  return "", ""
 
49
  return "", s.strip()
50
 
51
  # ------------------- Normalization -------------------
 
52
  def normalize_email_record(raw: Dict[str, Any]) -> Dict[str, Any]:
53
  subject = raw.get("subject") or raw.get("Subject") or ""
54
  body_html = raw.get("body_html") or raw.get("html") or ""
 
112
  "text_hash": text_hash,
113
  }
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  # ------------------- Gradio UI -------------------
116
+ with gr.Blocks(title="Email Organizer & Browser (TF-IDF + MiniBatchKMeans)") as demo:
117
  gr.Markdown("""
118
  # Email Organizer & Browser (No-Redaction)
119
+ **Engine:** TF-IDF (sparse) + MiniBatchKMeans clustering + cosine NearestNeighbors search.
120
+ CPU-fast, no text cap. Upload **.jsonl** or **.json**.
 
121
  """)
122
 
123
  with gr.Row():
124
  inbox_file = gr.File(label="Upload emails (.jsonl or .json)", file_types=[".jsonl", ".json"])
125
 
126
  with gr.Row():
127
+ max_features = gr.Number(label="TF-IDF max_features", value=100_000, precision=0)
128
+ min_df = gr.Number(label="min_df (doc freq ≥)", value=3, precision=0)
129
+ max_df = gr.Slider(label="max_df (fraction ≤)", minimum=0.1, maximum=0.95, value=0.6, step=0.05)
130
+ use_bigrams = gr.Checkbox(label="Use bigrams (1–2)", value=True)
 
 
 
 
 
 
131
  skip_lang = gr.Checkbox(label="Skip language detection (faster)", value=True)
132
 
133
+ with gr.Row():
134
+ auto_k = gr.Checkbox(label="Auto choose k", value=True)
135
+ k_clusters = gr.Number(label="k (MiniBatchKMeans)", value=300, precision=0)
136
+ mb_batch = gr.Number(label="KMeans batch_size", value=4096, precision=0)
137
+
138
  run_btn = gr.Button("Process", variant="primary")
139
  status = gr.Textbox(label="Status", interactive=False)
140
+
141
+ cluster_counts_df = gr.Dataframe(label="Cluster summary", interactive=False)
142
+ html_samples = gr.HTML(label="Preview & Domain counts")
143
 
144
  with gr.Row():
145
  search_query = gr.Textbox(label="Search emails (keywords, names, etc.)")
146
  search_btn = gr.Button("Search")
147
  search_results = gr.Dataframe(label="Search results", interactive=False)
148
 
149
+ # States we need for search:
150
  state_df = gr.State()
151
+ state_vectorizer = gr.State()
152
+ state_nn = gr.State()
 
153
 
154
+ # -------- IO helpers --------
155
  def _load_json_records(local_path: str) -> List[Dict[str, Any]]:
156
  recs: List[Dict[str, Any]] = []
157
  if local_path.endswith(".jsonl"):
 
173
  recs = [obj]
174
  return recs
175
 
176
# -------- Cluster naming --------
def top_terms_per_cluster(X, labels, vectorizer, topn=5):
    """Return dict: cluster_id -> 'term1, term2, ...' using mean TF-IDF weights.

    Args:
        X: sparse TF-IDF matrix, shape (n_docs, n_terms).
        labels: array-like of cluster ids, one per row of X.
        vectorizer: fitted vectorizer exposing get_feature_names_out().
        topn: max number of terms per cluster name.

    Returns:
        Mapping of int cluster id to a comma-joined term string, or
        'cluster_<id>' when a cluster has no positive-weight terms.
    """
    names = vectorizer.get_feature_names_out()
    out = {}
    # np.unique(labels) yields only labels that are present, so every mask
    # selects at least one row — no empty-cluster branch needed.
    for c in np.unique(labels):
        mask = (labels == c)
        # mean tfidf for cluster c, as a dense 1-D array
        mean_vec = X[mask].mean(axis=0).A1
        # BUGFIX: clamp topn — np.argpartition raises when topn exceeds the
        # vocabulary size (small corpora / aggressive min_df).
        k = min(int(topn), mean_vec.size)
        if k <= 0:
            out[int(c)] = f"cluster_{c}"
            continue
        idx = np.argpartition(mean_vec, -k)[-k:]
        idx = idx[np.argsort(-mean_vec[idx])]
        terms = [names[i] for i in idx if mean_vec[i] > 0]
        out[int(c)] = ", ".join(terms) if terms else f"cluster_{c}"
    return out
194
+
195
def auto_k_rule(n_docs: int) -> int:
    """Heuristic cluster count: grows with sqrt(corpus/50), clamped to [100, 600]."""
    scaled = math.sqrt(max(n_docs, 1) / 50.0) * 100
    return int(max(100, min(600, scaled)))
199
+
200
# -------- Main pipeline --------
def process_file(inbox_file, max_features, min_df, max_df, use_bigrams, skip_lang, auto_k, k_clusters, mb_batch):
    """Load, normalize, dedupe, vectorize, cluster, and index the uploaded mailbox.

    Returns a 6-tuple matching the click outputs:
    (status text, cluster-count table, preview HTML, dataframe, vectorizer, NN index).
    The last three are stashed in gr.State for the search handler.
    """
    if inbox_file is None:
        return "Please upload a file", None, None, None, None, None

    # Toggle the module-level language-detection switch before normalization.
    global SKIP_LANGDETECT
    SKIP_LANGDETECT = bool(skip_lang)

    # Load -> normalize
    recs = _load_json_records(inbox_file.name)
    if not recs:
        return "No valid records found.", None, None, None, None, None

    df = pd.DataFrame([normalize_email_record(r) for r in recs])
    # Deduplicate conservatively
    df = df.drop_duplicates(subset=["message_id", "subject", "text_hash"]).reset_index(drop=True)

    # Corpus: subject + body, no length cap.
    texts = (df["subject"].fillna("") + "\n\n" + df["body_text"].fillna("")).tolist()

    # Sparse TF-IDF features.
    vec = TfidfVectorizer(
        analyzer="word",
        ngram_range=(1, 2) if use_bigrams else (1, 1),
        max_features=int(max_features) if max_features else None,
        min_df=int(min_df) if min_df else 1,
        max_df=float(max_df) if max_df else 1.0,
        dtype=np.float32,
    )
    X = vec.fit_transform(texts)

    # Pick k (heuristic or user-supplied), then MiniBatchKMeans.
    k = auto_k_rule(X.shape[0]) if bool(auto_k) else max(10, int(k_clusters or 300))
    kmeans = MiniBatchKMeans(
        n_clusters=k,
        batch_size=int(mb_batch or 4096),
        random_state=0,
        n_init="auto",
    )
    labels = kmeans.fit_predict(X)
    df["cluster_id"] = labels

    # Human-readable cluster names from top mean-TF-IDF terms.
    term_names = top_terms_per_cluster(X, labels, vec, topn=5)
    df["cluster_name"] = [term_names[int(c)] for c in labels]

    # Cosine NN over the TF-IDF rows powers the search tab.
    nn = NearestNeighbors(metric="cosine", algorithm="brute")
    nn.fit(X)

    # Summaries
    cluster_counts = (
        df.groupby(["cluster_id", "cluster_name"]).size()
        .reset_index(name="count")
        .sort_values("count", ascending=False)
    ).head(500)
    domain_counts = (
        df.groupby("from_domain").size()
        .reset_index(name="count")
        .sort_values("count", ascending=False)
        .head(50)
    )

    # HTML preview: first 20 rows plus top sender domains.
    sample_cols = ["date", "from_email", "subject", "cluster_name", "body_text"]
    preview_html = "<h3>Sample (first 20)</h3>" + df.head(20)[sample_cols].to_html(escape=False, index=False)
    preview_html += "<br/><h3>Top sender domains</h3>" + domain_counts.to_html(escape=False, index=False)

    status = f"Processed {len(df)} emails | TF-IDF shape={X.shape} | k={k}"
    # Only df, vec, nn are needed by search; X lives on inside nn.
    return status, cluster_counts, preview_html, df, vec, nn
 
 
 
 
 
 
284
 
285
  run_btn.click(
286
  process_file,
287
+ inputs=[inbox_file, max_features, min_df, max_df, use_bigrams, skip_lang, auto_k, k_clusters, mb_batch],
288
+ outputs=[status, cluster_counts_df, html_samples, state_df, state_vectorizer, state_nn]
289
  )
290
 
291
# -------- Search: cosine NN over TF-IDF --------
def search_fn(q, df, vectorizer, nn):
    """Return up to 20 emails nearest to query q by TF-IDF cosine similarity.

    Yields an empty DataFrame when the query is blank or the pipeline state
    (df / vectorizer / nn) has not been populated yet.
    """
    if (not q) or (df is None) or (vectorizer is None) or (nn is None):
        return pd.DataFrame()
    q_vec = vectorizer.transform([q])
    distances, indices = nn.kneighbors(q_vec, n_neighbors=min(20, len(df)))
    hits = df.iloc[indices[0]].copy()
    hits["score"] = 1.0 - distances[0]  # cosine distance -> similarity
    return hits[["date","from_email","subject","cluster_name","body_text","score"]]
302
 
303
  search_btn.click(
304
  search_fn,
305
+ inputs=[search_query, state_df, state_vectorizer, state_nn],
306
  outputs=[search_results]
307
  )
308