|
|
import gradio as gr |
|
|
import json, re, math, os |
|
|
from collections import Counter, defaultdict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tokenize(text):
    """Split *text* into lowercase word tokens.

    A token is a maximal run of ASCII letters, digits and apostrophes;
    every other character acts as a separator.
    """
    lowered = text.lower()
    return re.findall(r"[A-Za-z0-9']+", lowered)
|
|
|
|
|
def text_vector(text):
    """Return a bag-of-words Counter (token -> occurrence count) for *text*."""
    tokens = tokenize(text)
    return Counter(tokens)
|
|
|
|
|
def centroid(docs):
    """Sum the term counts of every document in *docs* into one Counter.

    Each element of *docs* must be a mapping with a ``"text"`` entry. The
    result is an unnormalized total of term counts, not a mean.
    """
    totals = Counter()
    for doc in docs:
        doc_counts = text_vector(doc["text"])
        totals.update(doc_counts)
    return totals
|
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of two sparse term-count mappings.

    Keys missing from one mapping count as 0 there. When either vector has
    zero magnitude (e.g. an empty mapping) the result is ``0`` instead of a
    division by zero.
    """
    dot = norm_a = norm_b = 0
    for term in set(a.keys()) | set(b.keys()):
        x, y = a.get(term, 0), b.get(term, 0)
        dot += x * y
        norm_a += x * x
        norm_b += y * y

    if not (norm_a and norm_b):
        return 0
    return dot / math.sqrt(norm_a * norm_b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_records_from_path(path):
    """Load a JSONL dataset from *path* and build the explorer state.

    Each line should hold one JSON object. Lines that are not valid JSON,
    or that decode to something other than an object (array, number, ...),
    are skipped so a single bad line does not abort the whole load.

    Returns:
        ``(state, message)`` -- the same two-tuple produced by
        ``initialize_state``. ``state`` is ``None`` when *path* does not exist.
    """
    if not os.path.exists(path):
        # Always a two-tuple, matching the success path, so callers can
        # unpack ``state, msg = load_records_from_path(...)`` unconditionally.
        return None, "β JSONL file not found."

    records = []
    with open(path, "r", encoding="utf8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # initialize_state reads and writes records with dict indexing
            # (e.g. r["id"] = i), so only JSON objects are usable.
            if isinstance(obj, dict):
                records.append(obj)

    return initialize_state(records)
|
|
|
|
|
|
|
|
def load_jsonl(user_file):
    """Load an uploaded JSONL file and build the explorer state.

    Args:
        user_file: the value of a ``gr.File`` component. This is either a
            plain filepath string or an object exposing the temp-file path
            as ``.name``, depending on the Gradio version. ``None`` means
            nothing was uploaded.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, are skipped.

    Returns:
        ``(state, message)``; ``state`` is ``None`` when no file was given.
    """
    if user_file is None:
        return None, "β No file uploaded."

    # Gradio 4's gr.File yields a filepath string by default; older versions
    # yield a tempfile wrapper whose ``.name`` is the path. Accept either.
    if isinstance(user_file, (str, os.PathLike)):
        path = user_file
    else:
        path = user_file.name

    records = []
    with open(path, "r", encoding="utf8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # initialize_state indexes records as dicts; skip arrays/scalars.
            if isinstance(obj, dict):
                records.append(obj)

    state, msg = initialize_state(records)
    return state, msg
|
|
|
|
|
|
|
|
def initialize_state(records):
    """Build the in-memory search state for a list of record dicts.

    Records are mutated in place:
      * ``id`` is set to the record's list position when absent;
      * a missing or ``None`` ``text`` becomes ``""`` and a non-string
        ``text`` becomes ``str(text)``, so tokenizing and display never hit
        a KeyError/TypeError on a malformed record.

    Records without a ``cluster`` key are grouped under cluster ``-1``.

    Returns:
        ``(state, message)`` where ``state`` is a dict with keys
        ``records``, ``cluster_map`` (cluster -> list of records),
        ``tokenized_docs`` (token list per record, same order),
        ``doc_freq`` (term -> number of records containing it), ``Ndocs``,
        ``avg_len`` (mean tokens per record) and ``centroids``
        (cluster -> summed term Counter).
    """
    for i, r in enumerate(records):
        if "id" not in r:
            r["id"] = i
        text = r.get("text")
        if text is None:
            r["text"] = ""
        elif not isinstance(text, str):
            r["text"] = str(text)

    cluster_map = defaultdict(list)
    for r in records:
        cluster_map[r.get("cluster", -1)].append(r)

    tokenized_docs = [tokenize(r["text"]) for r in records]

    # Document frequency: each distinct term counts once per document.
    doc_freq = Counter()
    for toks in tokenized_docs:
        doc_freq.update(set(toks))

    Ndocs = len(records)
    avg_len = sum(len(t) for t in tokenized_docs) / max(Ndocs, 1)

    # Per-cluster term totals, built from the tokens computed above rather
    # than re-tokenizing every text. Clusters are inserted in first-seen
    # order, the same order as cluster_map.
    centroids = {}
    for r, toks in zip(records, tokenized_docs):
        centroids.setdefault(r.get("cluster", -1), Counter()).update(toks)

    state = {
        "records": records,
        "cluster_map": cluster_map,
        "tokenized_docs": tokenized_docs,
        "doc_freq": doc_freq,
        "Ndocs": Ndocs,
        "avg_len": avg_len,
        "centroids": centroids,
    }

    return state, f"Loaded {len(records)} records."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bm25_score(query, doc_toks, doc_freq, Ndocs, avg_len):
    """Return the Okapi BM25 relevance of one tokenized document to *query*.

    Args:
        query: raw query string; split with ``tokenize``.
        doc_toks: token list of the document being scored.
        doc_freq: mapping term -> number of documents containing the term.
        Ndocs: total number of documents in the collection.
        avg_len: mean document length in tokens.

    Uses k1=1.5 and b=0.75 with the ``log(... + 1)`` IDF variant, which is
    non-negative whenever a term's df <= Ndocs. Query terms absent from
    ``doc_freq`` are ignored, and a term repeated in the query contributes
    once per occurrence. Returns ``0`` when no query term is known.
    """
    k, b = 1.5, 0.75

    def term_score(term, df):
        idf = math.log((Ndocs - df + 0.5) / (df + 0.5) + 1)
        tf = doc_toks.count(term)
        # Length normalization: longer-than-average documents are damped.
        denom = tf + k * (1 - b + b * (len(doc_toks) / avg_len))
        return idf * (tf * (k + 1)) / denom

    total = 0
    for term in tokenize(query):
        df = doc_freq.get(term, 0)
        if df != 0:
            total += term_score(term, df)
    return total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def do_view_cluster(state, cid):
    """Render every document in cluster *cid* as plain text.

    Args:
        state: explorer state dict (needs ``cluster_map``), or ``None``.
        cid: cluster number; anything ``int()`` accepts (a float from the
            Number widget, or a numeric string). Values ``int()`` rejects,
            including ``None``, NaN and infinity, produce a prompt instead.

    Returns:
        The formatted cluster listing, or a short status message.
    """
    if state is None:
        return "β No dataset loaded."

    try:
        cid = int(cid)
    except (TypeError, ValueError, OverflowError):
        # None -> TypeError; bad text / NaN -> ValueError; inf -> OverflowError.
        return "Enter a valid cluster number."

    cluster_map = state["cluster_map"]
    # Membership test first: indexing a defaultdict would insert an empty
    # cluster for an unknown id.
    if cid not in cluster_map:
        return "β Cluster not found."

    docs = cluster_map[cid]
    out = [f"=== Cluster {cid} ({len(docs)} docs) ===\n"]
    for d in docs:
        out.append(f"\n--- id={d['id']} ---\n{d['text']}\n")

    return "\n".join(out)
|
|
|
|
|
|
|
|
def do_search(state, query):
    """Rank all records against *query* with BM25 and format the top 30.

    Only records with a positive score are listed. Ties keep dataset order
    because the sort is stable.

    Args:
        state: explorer state dict (``records``, ``tokenized_docs``,
            ``doc_freq``, ``Ndocs``, ``avg_len``), or ``None``.
        query: free-text query string.
    """
    if state is None:
        return "β No dataset loaded."

    doc_freq = state["doc_freq"]
    n_docs = state["Ndocs"]
    avg_len = state["avg_len"]

    results = []
    for r, toks in zip(state["records"], state["tokenized_docs"]):
        score = bm25_score(query, toks, doc_freq, n_docs, avg_len)
        if score > 0:
            results.append((score, r))

    results.sort(key=lambda x: x[0], reverse=True)

    out = [f"=== Results for '{query}' ==="]
    for score, r in results[:30]:
        # Records without a "cluster" key are grouped under -1 when the state
        # is built; report the same label instead of raising KeyError.
        cluster = r.get("cluster", -1)
        out.append(f"\nScore {score:.2f} β Cluster {cluster} β id={r['id']}\n{r['text']}\n")

    return "\n".join(out)
|
|
|
|
|
|
|
|
def do_entity_search(state, name):
    """Count, per cluster, how many documents mention *name*.

    Matching is a case-insensitive substring test on each document's text.
    Clusters are ordered by hit count (descending), and at most 30 are shown.

    Args:
        state: explorer state dict (needs ``cluster_map``), or ``None``.
        name: the text to look for. An empty or whitespace-only name is
            rejected with a prompt, because it would match every document.
    """
    if state is None:
        return "β No dataset loaded."

    if not name or not name.strip():
        return "Enter a name to search."

    needle = name.lower()  # lowered once, not once per document
    hits = []
    for cid, docs in state["cluster_map"].items():
        count = sum(needle in d["text"].lower() for d in docs)
        if count:
            hits.append((count, cid))

    hits.sort(reverse=True)

    out = [f"=== Clusters mentioning '{name}' ==="]
    for count, cid in hits[:30]:
        out.append(f"Cluster {cid}: {count} hits")

    return "\n".join(out)
|
|
|
|
|
|
|
|
def do_show_topics(state):
    """List up to ten characteristic words for every cluster.

    Words come from each cluster's centroid term counts. A word is kept
    only if it is not a stop word, is longer than two characters and occurs
    more than once in the cluster. Ties keep the centroid's insertion order.
    """
    if state is None:
        return "β No dataset loaded."

    stop_words = {
        "the", "and", "to", "of", "a", "in", "is", "this", "that", "for",
        "on", "with", "as", "be", "or", "by", "from", "at",
        "an", "it", "are", "was", "you", "your", "if", "but", "have", "we",
        "they", "his", "her", "she", "their", "our",
        "subject", "re", "fw", "message", "thereof", "all", "may", "any",
        "doc", "email",
    }

    lines = ["=== Cluster Topics ==="]
    for cid, counts in state["centroids"].items():
        kept = Counter({
            word: n
            for word, n in counts.items()
            if word not in stop_words and len(word) > 2 and n > 1
        })
        top_words = " ".join(word for word, _ in kept.most_common(10))
        lines.append(f"Cluster {cid:<4} | {top_words}")

    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dataset loaded automatically at startup when it sits next to the app.
DEFAULT_PATH = "epstein_semantic.jsonl"

# (state, message) for the initial session: the parsed default dataset if
# the file exists, otherwise no state and an explanatory message.
startup_state, startup_msg = (
    load_records_from_path(DEFAULT_PATH)
    if os.path.exists(DEFAULT_PATH)
    else (None, "β No default dataset found.")
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_dataset_keep_current(user_file, current_state):
    """Handle "Load Dataset" clicks without unloading the active dataset.

    ``load_jsonl`` returns ``(None, message)`` when no file was uploaded.
    Writing that ``None`` into the session state would silently discard the
    dataset already in use, so in that case keep ``current_state`` and only
    show the message.
    """
    new_state, msg = load_jsonl(user_file)
    if user_file is None:
        return current_state, msg
    return new_state, msg


with gr.Blocks(title="Epstein Semantic Explorer") as demo:
    gr.Markdown("# Epstein Semantic Explorer")
    gr.Markdown(startup_msg)

    # Per-session dataset state, seeded with the startup dataset (or None).
    state_box = gr.State(startup_state)

    cluster_input = gr.Number(label="Cluster #", value=0)
    keyword_input = gr.Textbox(label="Keyword Search")
    entity_input = gr.Textbox(label="Entity Search (name)")
    jsonl_file = gr.File(label="Upload different JSONL dataset")

    out_box = gr.Textbox(label="Output", lines=40)

    cluster_input.change(do_view_cluster, [state_box, cluster_input], out_box)
    keyword_input.submit(do_search, [state_box, keyword_input], out_box)
    entity_input.submit(do_entity_search, [state_box, entity_input], out_box)

    gr.Button("Show Topics").click(do_show_topics, state_box, out_box)
    # The current state is passed in so a click with no upload can keep it.
    gr.Button("Load Dataset").click(
        _load_dataset_keep_current, [jsonl_file, state_box], [state_box, out_box]
    )

demo.launch()
|
|
|