shahidshaikh's picture
Upload 4 files
ca12042 verified
"""
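tools.py — five agent tools (fetch_papers, save_papers, save_output, read_output,
run_clustering) plus Plotly chart helpers. Data fetch, persistence and clustering only:
ALL judgement — filtering, ranking, labelling — lives in the agent prompts.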
Fixes applied:
- Sr No off-by-one bug corrected (counts data rows, not raw lines)
- fetch_papers exception now logs reason before arXiv fallback
- read_output path traversal vulnerability patched (realpath check)
- _embed() is no longer dead code — exported cleanly for app.py
- Consistent error returns across all tools (always a string)
"""
import csv, json, os, urllib.request, urllib.parse
from xml.etree import ElementTree as ET
from langchain_core.tools import tool
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff
from scipy.spatial.distance import pdist
# Directory that holds this module; the fallback storage location.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
# ── Hugging Face Persistent Storage Support ─────────────────────
# Use /data (HF Space persistent storage mount) only when it both exists and
# is writable; otherwise keep the CSV and outputs next to this file.
PERSISTENT_ROOT = "/data"
if not os.path.exists(PERSISTENT_ROOT) or not os.access(PERSISTENT_ROOT, os.W_OK):
    CSV_PATH = os.path.join(LOCAL_PATH, "papers.csv")
    OUT_DIR = os.path.join(LOCAL_PATH, "outputs")
else:
    print(f">>> tools.py: Detected HF persistent storage at {PERSISTENT_ROOT}")
    CSV_PATH = os.path.join(PERSISTENT_ROOT, "papers.csv")
    OUT_DIR = os.path.join(PERSISTENT_ROOT, "outputs")

os.makedirs(OUT_DIR, exist_ok=True)
print(f">>> tools.py: storage ready -> {CSV_PATH}")

# Column order of papers.csv (used as the DictWriter fieldnames).
HEADERS = [
    "Sr No", "Title", "Authors", "Year", "Journal",
    "Link", "DOI", "Abstract", "Citations", "Keywords",
]
def _out(filename: str) -> str:
    """Return *filename* joined onto OUT_DIR.

    No sanitising happens here; callers strip directory parts first.
    """
    target = os.path.join(OUT_DIR, filename)
    return target
# ── Embedding model — lazy-loaded, exported for app.py ──────────
# Module-level cache; stays None until _embed() first builds the model.
_MODEL = None


def _embed():
    """Return the shared sentence-transformer model, building it on first use.

    The model is created once and cached in the module-level ``_MODEL``; every
    later call returns that same object. ``sentence_transformers`` is imported
    lazily so importing this module does not pay the model-loading cost.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    from sentence_transformers import SentenceTransformer
    _MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    return _MODEL
# ════════════════════════════════════════════════════════════════
# TOOL 1 — Fetch papers (Semantic Scholar primary, arXiv fallback)
# Returns raw JSON lines. LLM decides what to keep.
# ════════════════════════════════════════════════════════════════
def _squash_ws(text) -> str:
    """Collapse every whitespace run (incl. newlines) to one space; None -> ''."""
    return " ".join((text or "").split())


def _str_or_empty(value) -> str:
    """str(value), but '' for None — keeps legitimate zeros such as 0 citations."""
    return "" if value is None else str(value)


@tool
def fetch_papers(query: str) -> str:
    """Fetch raw papers from Semantic Scholar (falls back to arXiv on failure).
    Returns raw JSON lines — one paper per line. No filtering applied."""
    # NOTE: the docstring above is the tool description the LLM sees.
    ss_error = None  # primary failure reason, surfaced in the arXiv note
    q = urllib.parse.quote(query)
    # ── Primary: Semantic Scholar ────────────────────────────────
    try:
        url = (
            f"https://api.semanticscholar.org/graph/v1/paper/search"
            f"?query={q}&limit=15"
            f"&fields=title,authors,year,venue,externalIds,abstract,citationCount"
        )
        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req, timeout=12) as r:
            data = json.loads(r.read()).get("data") or []
        if data:
            lines = []
            for p in data:
                # The API sends explicit nulls for missing fields; .get(k, default)
                # would return None here and emit "None" strings or raise, which
                # needlessly drops the whole batch into the arXiv fallback.
                authors = ", ".join(
                    a.get("name") or "" for a in (p.get("authors") or [])[:3]
                )
                pid = p.get("paperId") or ""
                lines.append(json.dumps({
                    "Title": _squash_ws(p.get("title")),
                    "Authors": authors,
                    "Year": _str_or_empty(p.get("year")),
                    "Journal": p.get("venue") or "",
                    "Link": f"https://www.semanticscholar.org/paper/{pid}",
                    "DOI": (p.get("externalIds") or {}).get("DOI") or "",
                    # Single-line abstracts keep papers.csv one record per line.
                    "Abstract": _squash_ws(p.get("abstract"))[:500],
                    "Citations": _str_or_empty(p.get("citationCount")),
                    "Keywords": "",
                }))
            return "\n".join(lines)
    except Exception as exc:
        ss_error = str(exc)  # log reason; fall through to arXiv
    # ── Fallback: arXiv ─────────────────────────────────────────
    try:
        url = f"https://export.arxiv.org/api/query?search_query=all:{q}&max_results=15"
        with urllib.request.urlopen(url, timeout=12) as r:
            tree = ET.fromstring(r.read())
        ns = {"a": "http://www.w3.org/2005/Atom"}
        lines = []
        for entry in tree.findall("a:entry", ns):
            authors = ", ".join(
                _squash_ws(author.findtext("a:name", "", ns))
                for author in entry.findall("a:author", ns)[:3]
            )
            lines.append(json.dumps({
                # arXiv wraps titles/summaries across indented lines.
                "Title": _squash_ws(entry.findtext("a:title", "", ns)),
                "Authors": authors,
                "Year": (entry.findtext("a:published", "", ns) or "")[:4],
                "Journal": "arXiv",
                "Link": (entry.findtext("a:id", "", ns) or "").strip(),
                "DOI": "",
                "Abstract": _squash_ws(entry.findtext("a:summary", "", ns))[:500],
                "Citations": "",
                "Keywords": "",
            }))
        note = f"[Note: Semantic Scholar failed ({ss_error}), results from arXiv]\n" if ss_error else ""
        return note + ("\n".join(lines) if lines else "No results.")
    except Exception as exc:
        return f"Fetch error — Semantic Scholar: {ss_error} | arXiv: {exc}"
# ════════════════════════════════════════════════════════════════
# TOOL 2 — Save papers the LLM approved
# LLM passes only the papers it judged relevant.
# Fix: sr now correctly counts data rows (not raw lines).
# ════════════════════════════════════════════════════════════════
@tool
def save_papers(papers_json: str) -> str:
    """Persist LLM-approved papers to papers.csv. Pass a JSON array of paper objects.
    Required keys: Title, Authors, Year, Journal, Link, DOI, Abstract, Citations, Keywords."""
    # NOTE: the docstring above is the tool description the LLM sees.
    # Returns a status string; all errors are reported as strings, never raised.
    try:
        papers = json.loads(papers_json)
        if not isinstance(papers, list):
            papers = [papers]
        # ── Deduplicate against existing titles ─────────────────
        existing_titles: set[str] = set()
        existing_count = 0  # number of DATA rows (DictReader skips header)
        if os.path.isfile(CSV_PATH):
            # newline="" lets csv parse quoted fields that contain newlines.
            with open(CSV_PATH, encoding="utf-8", newline="") as f:
                for row in csv.DictReader(f):
                    # Short rows yield None for missing columns.
                    existing_titles.add((row.get("Title") or "").strip().lower())
                    existing_count += 1
        new_papers = []
        for p in papers:
            if not isinstance(p, dict):
                continue  # ignore stray strings/numbers in the LLM's array
            key = str(p.get("Title") or "").strip().lower()
            if not key or key in existing_titles:
                continue
            existing_titles.add(key)  # also dedupe repeats within this batch
            new_papers.append(p)
        if not new_papers:
            return f"0 new papers saved (all {len(papers)} were duplicates)."
        # A zero-byte file (e.g. truncated) exists but still needs a header row.
        needs_header = not os.path.isfile(CSV_PATH) or os.path.getsize(CSV_PATH) == 0
        next_sr = existing_count + 1  # Sr No is 1-based and continues the file
        with open(CSV_PATH, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=HEADERS, extrasaction="ignore")
            if needs_header:
                writer.writeheader()
            for offset, p in enumerate(new_papers):
                writer.writerow({**p, "Sr No": str(next_sr + offset)})
        total = existing_count + len(new_papers)
        return f"Saved {len(new_papers)} new papers. Total: {total}/200."
    except json.JSONDecodeError as e:
        return f"Save error — invalid JSON: {e}"
    except Exception as e:
        return f"Save error: {e}"
# ════════════════════════════════════════════════════════════════
# TOOL 3 — Save any LLM output (labels, themes, taxonomy, narrative)
# ════════════════════════════════════════════════════════════════
@tool
def save_output(filename: str, content: str) -> str:
    """Save any LLM-generated content to outputs/<filename>.
    Use for: labels.json, themes.json, taxonomy_map.json, narrative.txt, etc.
    content: raw string (JSON or plain text) to write."""
    # NOTE: the docstring above is the tool description the LLM sees.
    # Overwrites any existing file; returns a status string, never raises.
    try:
        # Strip surrounding whitespace exactly as read_output does, so a file
        # saved here can be read back under the same name. basename() drops
        # any directory components ("../x" -> "x").
        safe_name = os.path.basename(filename.strip())
        if not safe_name:
            return "Save error: filename cannot be empty."
        if safe_name in (".", ".."):
            return f"Save error: invalid filename '{filename}'."
        path = _out(safe_name)
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        kb = os.path.getsize(path) // 1024
        return f"Saved {safe_name} ({kb}KB)."
    except Exception as e:
        return f"Save error: {e}"
# ════════════════════════════════════════════════════════════════
# TOOL 4 — Read any saved file (LLM loads its own state)
# Fix: path traversal vulnerability patched with realpath check.
# ════════════════════════════════════════════════════════════════
@tool
def read_output(filename: str) -> str:
    """Read a previously saved file from outputs/. Pass filename only (no paths).
    Use 'list' to see all available files.
    Use 'papers.csv' to get current paper count."""
    # NOTE: the docstring above is the tool description the LLM sees.
    # Every outcome, including unexpected I/O errors, is returned as a string.
    try:
        filename = filename.strip()
        if filename.lower() == "list":
            if not os.path.isdir(OUT_DIR):
                return "No outputs yet."
            files = sorted(os.listdir(OUT_DIR))
            if not files:
                return "No output files yet."
            return "Files:\n" + "\n".join(
                f" {fn} ({os.path.getsize(_out(fn)) // 1024}KB)" for fn in files
            )
        if filename == "papers.csv":
            if not os.path.isfile(CSV_PATH):
                return "No papers saved yet."
            # Count parsed records, not raw lines: quoted fields (abstracts)
            # may span several lines. DictReader consumes the header itself.
            with open(CSV_PATH, encoding="utf-8", newline="") as f:
                n = sum(1 for _ in csv.DictReader(f))
            return f"{n} papers in CSV."
        # ── Path traversal guard ─────────────────────────────────────
        safe_name = os.path.basename(filename)  # strips any ../
        # realpath also resolves symlinks inside OUT_DIR pointing elsewhere.
        resolved = os.path.realpath(_out(safe_name))
        allowed_root = os.path.realpath(OUT_DIR)
        if not resolved.startswith(allowed_root + os.sep) and resolved != allowed_root:
            return f"Access denied: '{filename}' is outside the outputs directory."
        if not os.path.isfile(resolved):
            return f"'{safe_name}' not found. Call read_output('list') to see available files."
        with open(resolved, encoding="utf-8", errors="replace") as f:
            content = f.read()
        if len(content) > 8000:
            # Cap what is returned to the LLM; report the full length.
            content = content[:8000] + f"\n...[truncated — {len(content)} chars total]"
        return content
    except Exception as e:
        return f"Read error: {e}"
# ════════════════════════════════════════════════════════════════
# NLP CLUSTERING
# ════════════════════════════════════════════════════════════════
# Texts with this many characters or fewer are dropped as noise. Titles are
# naturally short ("Attention Is All You Need" is 25 chars), so applying the
# abstract threshold of 40 to titles silently discarded legitimate papers.
_MIN_TEXT_CHARS = {"abstract": 40, "title": 10}


def _load_cluster_texts(csv_path: str, column: str, min_chars: int):
    """Return (texts, paper_ids) for rows whose scrubbed *column* exceeds *min_chars*.

    URLs and e-mail-like tokens are removed before the length test. paper_ids
    holds each kept row's "Sr No", or its 1-based row number when that is blank.
    """
    import re
    url_re = re.compile(r'https?://\S+|www\.\S+')
    email_re = re.compile(r'\S+@\S+')
    texts, paper_ids = [], []
    # newline="" so quoted fields containing newlines parse as one record.
    with open(csv_path, encoding="utf-8", newline="") as f:
        for row_no, row in enumerate(csv.DictReader(f), start=1):
            # Short rows yield None for missing columns.
            text = (row.get(column) or "").strip()
            text = email_re.sub('', url_re.sub('', text)).strip()
            if len(text) > min_chars:
                texts.append(text)
                paper_ids.append(row.get("Sr No") or str(row_no))
    return texts, paper_ids


def _representative_sentences(texts, idx, embeddings):
    """Pick up to 3 near-distinct texts closest to the centroid of cluster *idx*.

    Candidates are the 10 members most cosine-similar to the mean embedding;
    texts sharing their first 60 lower-cased chars count as duplicates. Each
    kept text is cropped to 250 chars ("..." appended when cut) to save tokens.
    """
    members = embeddings[idx]
    centroid = np.mean(members, axis=0).reshape(1, -1)
    sim_scores = cosine_similarity(centroid, members)[0]
    top_indices = np.argsort(sim_scores)[::-1][:min(10, len(idx))]
    picked, seen = [], set()
    for j in top_indices:
        text = texts[idx[j]].strip()
        sig = text.lower()[:60]  # near-duplicate signature
        if sig in seen:
            continue
        seen.add(sig)
        picked.append(text[:250].replace('\n', ' ') + ("..." if len(text) > 250 else ""))
        if len(picked) == 3:
            break
    return picked


@tool
def run_clustering(csv_path: str, mode: str, max_clusters: int=15) -> str:
    """
    Reads the CSV, extracts either 'Abstract' or 'Title', embeds them,
    clusters them using AgglomerativeClustering, and extracts top representative
    sentences per cluster. Also generates Plotly HTML visualizations.
    Returns a JSON string of raw topic data for the LLM:
    {"<cluster id>": {"size", "top_sentences", "paper_ids"}}.
    """
    # Raises ValueError when fewer than 3 usable texts remain after filtering.
    if mode not in ["abstract", "title"]:
        mode = "abstract"
    target_column = "Abstract" if mode == "abstract" else "Title"
    texts, paper_ids = _load_cluster_texts(csv_path, target_column, _MIN_TEXT_CHARS[mode])
    if len(texts) < 3:
        raise ValueError("Not enough papers to cluster.")
    print(f"[{mode}] Embedding {len(texts)} texts...")
    model = _embed()
    embeddings = model.encode(texts, show_progress_bar=False)
    # Heuristic: ~1 cluster per 10 texts, at least 3, capped by max_clusters;
    # floored at 1 so a non-positive max_clusters cannot crash sklearn.
    n_clusters = max(1, min(max_clusters, max(3, len(texts) // 10)))
    print(f"[{mode}] Clustering into {n_clusters} clusters...")
    clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='cosine', linkage='average')
    labels = clusterer.fit_predict(embeddings)
    topic_map = {}
    for cluster_id in range(n_clusters):
        idx = np.where(labels == cluster_id)[0]
        if len(idx) == 0:
            continue
        topic_map[str(cluster_id)] = {
            "size": len(idx),
            "top_sentences": _representative_sentences(texts, idx, embeddings),
            "paper_ids": [str(paper_ids[i]) for i in idx],
        }
    # Persist core matrices to allow decoupling visualization from clustering
    np.save(os.path.join(OUT_DIR, f"{mode}_emb.npy"), embeddings)
    np.save(os.path.join(OUT_DIR, f"{mode}_cluster_labels.npy"), labels)
    # Generate Initial Visualizations (Unlabeled)
    _generate_visuals(mode, embeddings, labels, n_clusters, topic_map)
    # Save raw clusters for LLM context
    raw_path = os.path.join(OUT_DIR, f"{mode}_raw_clusters.json")
    with open(raw_path, "w", encoding="utf-8") as f:
        json.dump(topic_map, f, indent=2)
    # First representative sentence per topic, saved as evidence for the UI table.
    summ_path = os.path.join(OUT_DIR, f"{mode}_summaries.json")
    summaries = {
        tid: (v["top_sentences"][0] if v["top_sentences"] else "")
        for tid, v in topic_map.items()
    }
    with open(summ_path, "w", encoding="utf-8") as f:
        json.dump(summaries, f, indent=2)
    return json.dumps(topic_map)
def update_charts_with_labels(mode: str):
    """Regenerate the Plotly charts for *mode* using the LLM's topic labels.

    Reads {mode}_emb.npy, {mode}_cluster_labels.npy, {mode}_raw_clusters.json
    and {mode}_labels.json from OUT_DIR; returns None without doing anything
    if any of them is missing.

    {mode}_labels.json may map "<cluster id>" to either {"label": "..."} or a
    plain "..." string, optionally wrapped as {"labels": {...}}. Entries in
    any other shape are ignored and fall back to the default "Topic <id>".
    """
    emb_path = os.path.join(OUT_DIR, f"{mode}_emb.npy")
    lbl_path = os.path.join(OUT_DIR, f"{mode}_cluster_labels.npy")
    raw_path = os.path.join(OUT_DIR, f"{mode}_raw_clusters.json")
    names_path = os.path.join(OUT_DIR, f"{mode}_labels.json")
    if not all(os.path.exists(p) for p in (emb_path, lbl_path, raw_path, names_path)):
        return
    embeddings = np.load(emb_path)
    labels = np.load(lbl_path)
    with open(raw_path, "r", encoding="utf-8") as f:
        topic_map = json.load(f)
    with open(names_path, "r", encoding="utf-8") as f:
        names_data = json.load(f)
    if isinstance(names_data, dict) and "labels" in names_data:
        names_data = names_data["labels"]
    custom_names = {}
    if isinstance(names_data, dict):
        for k, v in names_data.items():
            if isinstance(v, dict) and "label" in v:
                custom_names[str(k)] = v["label"]
            elif isinstance(v, str) and v.strip():
                # Plain {"0": "Name"} form, previously ignored silently.
                custom_names[str(k)] = v.strip()
    # Cluster ids are "0".."n-1" strings; highest id + 1 = cluster count.
    n_clusters = max(int(k) for k in topic_map) + 1 if topic_map else 0
    _generate_visuals(mode, embeddings, labels, n_clusters, topic_map, custom_names)
def _generate_visuals(mode: str, embeddings, labels, n_clusters, topic_map, custom_names=None):
    """Write the Plotly HTML charts for *mode* into OUT_DIR.

    Always writes {mode}_intertopic.html (2-D PCA scatter of every embedding)
    and {mode}_bars.html (papers per topic). When at least two topics exist it
    also writes {mode}_hierarchy.html (dendrogram of topic centroids) and
    {mode}_heatmap.html (cosine similarity between centroids).

    embeddings: per-text embedding matrix, row-aligned with *labels*.
    labels: cluster id per row.
    topic_map: {"<id>": {"size": int, ...}}; only ids present here are charted.
    custom_names: optional {"<id>": display name}; default is "Topic <id>".
    """
    from collections import Counter
    print(f"[{mode}] Generating Plotly visuals (Dark theme -> White template applied)...")
    if custom_names is None:
        custom_names = {}
    # Shared styling: white template on transparent backgrounds.
    layout = dict(template="plotly_white", paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)')
    title_mode = mode.capitalize()
    label_arr = np.asarray(labels)

    # Topic ids that have data in topic_map, in numeric order.
    valid_ids = [str(i) for i in range(n_clusters) if str(i) in topic_map]
    # Display names. If the LLM gave two topics the same label, suffix the id:
    # plotly merges identical categories on axes and in legends.
    base_names = {vid: str(custom_names.get(vid, f"Topic {vid}")) for vid in valid_ids}
    name_counts = Counter(base_names.values())
    display = {
        vid: (f"{name} [{vid}]" if name_counts[name] > 1 else name)
        for vid, name in base_names.items()
    }

    def get_name(tid_str):
        if tid_str in display:
            return display[tid_str]
        return str(custom_names.get(tid_str, f"Topic {tid_str}"))

    # 1. Intertopic Distance (2D PCA Scatter)
    coords = PCA(n_components=2).fit_transform(embeddings)
    fig_scatter = px.scatter(
        x=coords[:, 0], y=coords[:, 1],
        color=[get_name(str(lbl)) for lbl in label_arr],
        title=f"Intertopic Distance ({title_mode})",
        labels={'color': 'Topic'},
    )
    fig_scatter.update_layout(**layout)
    fig_scatter.write_html(os.path.join(OUT_DIR, f"{mode}_intertopic.html"), include_plotlyjs="cdn")

    # 2. Topic Frequency (Bar chart)
    names = [display[vid] for vid in valid_ids]
    fig_bar = px.bar(
        x=names, y=[topic_map[vid]['size'] for vid in valid_ids],
        title=f"Topic Frequency ({title_mode})",
        labels={'x': 'Topic', 'y': 'Number of Papers'},
    )
    fig_bar.update_layout(**layout)
    fig_bar.write_html(os.path.join(OUT_DIR, f"{mode}_bars.html"), include_plotlyjs="cdn")

    if len(valid_ids) < 2:
        return  # hierarchy/similarity need at least two topics
    centroids = np.array([embeddings[label_arr == int(vid)].mean(axis=0) for vid in valid_ids])

    # 3. Hierarchy (Dendrogram of centroids). Cosine distance matches the
    # metric used for clustering; clipped because float error can go < 0.
    fig_dendro = ff.create_dendrogram(
        centroids,
        labels=names,
        distfun=lambda x: np.clip(pdist(x, metric="cosine"), 0.0, None),
    )
    fig_dendro.update_layout(title=f"Topic Hierarchy ({title_mode})", **layout)
    fig_dendro.write_html(os.path.join(OUT_DIR, f"{mode}_hierarchy.html"), include_plotlyjs="cdn")

    # 4. Similarity Heatmap (Cosine similarity between centroids)
    fig_heat = px.imshow(
        cosine_similarity(centroids),
        x=names, y=names,
        title=f"Topic Similarity Heatmap ({title_mode})",
        color_continuous_scale="Viridis",
    )
    fig_heat.update_layout(**layout)
    fig_heat.write_html(os.path.join(OUT_DIR, f"{mode}_heatmap.html"), include_plotlyjs="cdn")