| """
|
| tools.py — agent tools for data fetch, persistence, and clustering.
|
| ALL logic, filtering, ranking, labelling lives in agent prompts.
|
|
|
| Fixes applied:
|
| - Sr No off-by-one bug corrected (counts data rows, not raw lines)
|
| - fetch_papers exception now logs reason before arXiv fallback
|
| - read_output path traversal vulnerability patched (realpath check)
|
| - _embed() is no longer dead code — exported cleanly for app.py
|
| - Consistent error returns across all tools (always a string)
|
| """
|
| import csv, json, os, urllib.request, urllib.parse
|
| from xml.etree import ElementTree as ET
|
| from langchain_core.tools import tool
|
|
|
| import numpy as np
|
| from sklearn.cluster import AgglomerativeClustering
|
| from sklearn.metrics.pairwise import cosine_similarity
|
| from sklearn.decomposition import PCA
|
| import plotly.graph_objects as go
|
| import plotly.express as px
|
| import plotly.figure_factory as ff
|
| from scipy.spatial.distance import pdist
|
|
|
# Directory that contains this module; fallback location for all stored data.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))

# Root of the persistent volume. It is used only when it exists and is writable.
PERSISTENT_ROOT = "/data"

if not (os.path.exists(PERSISTENT_ROOT) and os.access(PERSISTENT_ROOT, os.W_OK)):
    # No writable persistent volume: keep the CSV and outputs next to this module.
    CSV_PATH = os.path.join(LOCAL_PATH, "papers.csv")
    OUT_DIR = os.path.join(LOCAL_PATH, "outputs")
else:
    print(f">>> tools.py: Detected HF persistent storage at {PERSISTENT_ROOT}")
    CSV_PATH = os.path.join(PERSISTENT_ROOT, "papers.csv")
    OUT_DIR = os.path.join(PERSISTENT_ROOT, "outputs")

# The outputs directory must exist before any tool tries to write into it.
os.makedirs(OUT_DIR, exist_ok=True)
print(f">>> tools.py: storage ready -> {CSV_PATH}")

# Column order used when writing papers.csv.
HEADERS = [
    "Sr No",
    "Title",
    "Authors",
    "Year",
    "Journal",
    "Link",
    "DOI",
    "Abstract",
    "Citations",
    "Keywords",
]
|
|
|
def _out(filename: str) -> str:
    """Return the path of ``filename`` inside the outputs directory (``OUT_DIR``)."""
    target = os.path.join(OUT_DIR, filename)
    return target
|
|
|
|
|
|
|
# Module-level cache for the embedding model; filled by the first _embed() call.
_MODEL = None


def _embed():
    """Return the shared sentence-transformer model, loading it on first use.

    The ``sentence_transformers`` import is deferred until the first call. The
    loaded ``all-MiniLM-L6-v2`` model is stored in the module-level ``_MODEL``
    so that later calls return the same instance.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    from sentence_transformers import SentenceTransformer

    _MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    return _MODEL
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool
def fetch_papers(query: str) -> str:
    """Fetch raw papers from Semantic Scholar (falls back to arXiv on failure).
    Returns raw JSON lines — one paper per line. No filtering applied."""
    # NOTE(review): the docstring above is kept verbatim — presumably @tool
    # exposes it to the agent as the tool description; confirm before rewording.
    #
    # Returns one of:
    #   * newline-separated JSON objects (Title, Authors, Year, Journal, Link,
    #     DOI, Abstract, Citations, Keywords) from Semantic Scholar;
    #   * the same shape from arXiv, prefixed by a "[Note: ...]" line when the
    #     Semantic Scholar request raised;
    #   * "No results." when arXiv returns no entries;
    #   * a "Fetch error — ..." string when both sources fail.

    _ss_error = None

    # ── Primary source: Semantic Scholar ──────────────────────────────────
    try:
        q = urllib.parse.quote(query)
        url = (
            f"https://api.semanticscholar.org/graph/v1/paper/search"
            f"?query={q}&limit=15"
            f"&fields=title,authors,year,venue,externalIds,abstract,citationCount"
        )
        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req, timeout=12) as r:
            # "data" may be absent or JSON null; both mean "no results".
            data = json.loads(r.read()).get("data") or []
        if data:
            lines = []
            for p in data:
                # Any field can come back as JSON null. dict.get(key, default)
                # returns None in that case, so every field is normalised with
                # an explicit fallback instead of relying on the default.
                authors = ", ".join(
                    (a or {}).get("name") or ""
                    for a in (p.get("authors") or [])[:3]
                )
                pid = p.get("paperId") or ""
                year = p.get("year")
                citations = p.get("citationCount")
                lines.append(json.dumps({
                    "Title": p.get("title") or "",
                    "Authors": authors,
                    # "is None" (not truthiness) so a legitimate 0 survives.
                    "Year": "" if year is None else str(year),
                    "Journal": p.get("venue") or "",
                    "Link": f"https://www.semanticscholar.org/paper/{pid}",
                    "DOI": (p.get("externalIds") or {}).get("DOI") or "",
                    "Abstract": (p.get("abstract") or "")[:500],
                    "Citations": "" if citations is None else str(citations),
                    "Keywords": "",
                }))
            return "\n".join(lines)
    except Exception as e:
        # Remember why the primary source failed so the fallback can report it.
        _ss_error = str(e)

    # ── Fallback source: arXiv Atom feed ──────────────────────────────────
    try:
        q = urllib.parse.quote(query)
        url = f"https://export.arxiv.org/api/query?search_query=all:{q}&max_results=15"
        with urllib.request.urlopen(url, timeout=12) as r:
            tree = ET.fromstring(r.read())
        ns = {"a": "http://www.w3.org/2005/Atom"}
        lines = []
        for entry in tree.findall("a:entry", ns):
            aid = (entry.findtext("a:id", "", ns) or "").strip()
            authors = ", ".join(
                (n.findtext("a:name", "", ns) or "").strip()
                for n in entry.findall("a:author", ns)[:3]
            )
            # Feed text is hard-wrapped; collapse every whitespace run to a
            # single space so titles compare cleanly later.
            title = " ".join((entry.findtext("a:title", "", ns) or "").split())
            summary = " ".join((entry.findtext("a:summary", "", ns) or "").split())
            lines.append(json.dumps({
                "Title": title,
                "Authors": authors,
                "Year": (entry.findtext("a:published", "", ns) or "")[:4],
                "Journal": "arXiv",
                "Link": aid,
                "DOI": "",
                "Abstract": summary[:500],
                "Citations": "",
                "Keywords": "",
            }))
        note = f"[Note: Semantic Scholar failed ({_ss_error}), results from arXiv]\n" if _ss_error else ""
        return note + ("\n".join(lines) if lines else "No results.")
    except Exception as e:
        return f"Fetch error — Semantic Scholar: {_ss_error} | arXiv: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool
def save_papers(papers_json: str) -> str:
    """Persist LLM-approved papers to papers.csv. Pass a JSON array of paper objects.
    Required keys: Title, Authors, Year, Journal, Link, DOI, Abstract, Citations, Keywords."""
    # NOTE(review): the docstring above is kept verbatim — presumably @tool
    # exposes it to the agent as the tool description; confirm before rewording.
    #
    # Returns a status string in every case; failures are reported as
    # "Save error ..." text rather than raised.
    try:
        papers = json.loads(papers_json)
        if not isinstance(papers, list):
            papers = [papers]

        # An existing but zero-byte CSV is treated like a missing one so that
        # it still receives a header row.
        has_rows = os.path.isfile(CSV_PATH) and os.path.getsize(CSV_PATH) > 0

        # Case-insensitive titles already stored, plus the data-row count that
        # the next "Sr No" continues from.
        seen_titles: set[str] = set()
        existing_count = 0
        if has_rows:
            with open(CSV_PATH, newline="", encoding="utf-8") as f:
                for row in csv.DictReader(f):
                    # Short rows yield None cells; treat them as empty titles.
                    seen_titles.add((row.get("Title") or "").strip().lower())
                    existing_count += 1

        # Keep papers with a non-empty title not seen before. Titles are added
        # to seen_titles as we go, so duplicates inside this batch are dropped.
        new_papers = []
        for p in papers:
            key = str(p.get("Title") or "").strip().lower()
            if not key or key in seen_titles:
                continue
            seen_titles.add(key)
            new_papers.append(p)

        if not new_papers:
            return f"0 new papers saved (all {len(papers)} were duplicates)."

        with open(CSV_PATH, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=HEADERS, extrasaction="ignore")
            if not has_rows:
                writer.writeheader()
            for sr_no, p in enumerate(new_papers, start=existing_count + 1):
                # Build a new row instead of mutating the parsed input object.
                writer.writerow({**p, "Sr No": str(sr_no)})

        total = existing_count + len(new_papers)
        return f"Saved {len(new_papers)} new papers. Total: {total}/200."

    except json.JSONDecodeError as e:
        return f"Save error — invalid JSON: {e}"
    except Exception as e:
        return f"Save error: {e}"
|
|
|
|
|
|
|
|
|
|
|
@tool
def save_output(filename: str, content: str) -> str:
    """Save any LLM-generated content to outputs/<filename>.
    Use for: labels.json, themes.json, taxonomy_map.json, narrative.txt, etc.
    content: raw string (JSON or plain text) to write."""
    # NOTE(review): the docstring above is kept verbatim — presumably @tool
    # exposes it to the agent as the tool description; confirm before rewording.
    try:
        # Only the final path component is used, so the write lands in outputs/.
        safe_name = os.path.basename(filename)
        if safe_name:
            destination = _out(safe_name)
            with open(destination, "w", encoding="utf-8") as handle:
                handle.write(content)
            size_kb = os.path.getsize(destination) // 1024
            return f"Saved {safe_name} ({size_kb}KB)."
        return "Save error: filename cannot be empty."
    except Exception as exc:
        return f"Save error: {exc}"
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool
def read_output(filename: str) -> str:
    """Read a previously saved file from outputs/. Pass filename only (no paths).
    Use 'list' to see all available files.
    Use 'papers.csv' to get current paper count."""
    # NOTE(review): the docstring above is kept verbatim — presumably @tool
    # exposes it to the agent as the tool description; confirm before rewording.
    #
    # Always returns a string: the file content (truncated to 8000 chars), a
    # listing, a paper count, or a human-readable error/denial message.
    try:
        filename = filename.strip()

        # ── Special name: directory listing ──
        if filename.lower() == "list":
            if not os.path.isdir(OUT_DIR):
                return "No outputs yet."
            files = sorted(os.listdir(OUT_DIR))
            if not files:
                return "No output files yet."
            return "Files:\n" + "\n".join(
                f" {fn} ({os.path.getsize(_out(fn)) // 1024}KB)" for fn in files
            )

        # ── Special name: paper count ──
        if filename == "papers.csv":
            if not os.path.isfile(CSV_PATH):
                return "No papers saved yet."
            with open(CSV_PATH, newline="", encoding="utf-8") as f:
                # Count parsed data rows, not raw lines: a quoted abstract may
                # span several physical lines. DictReader consumes the header.
                n = sum(1 for _ in csv.DictReader(f))
            return f"{n} papers in CSV."

        # ── Regular file inside outputs/ ──
        safe_name = os.path.basename(filename)
        resolved = os.path.realpath(_out(safe_name))
        allowed_root = os.path.realpath(OUT_DIR)
        # realpath follows symlinks, so a link pointing outside outputs/ is
        # rejected here.
        if resolved != allowed_root and not resolved.startswith(allowed_root + os.sep):
            return f"Access denied: '{filename}' is outside the outputs directory."

        if not os.path.isfile(resolved):
            return f"'{safe_name}' not found. Call read_output('list') to see available files."

        with open(resolved, encoding="utf-8", errors="replace") as f:
            content = f.read()

        # Cap the returned text; the suffix reports the full original length.
        if len(content) > 8000:
            content = content[:8000] + f"\n...[truncated — {len(content)} chars total]"

        return content
    except Exception as e:
        # Report failures as text, like the other tools in this module do.
        return f"Read error: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool
def run_clustering(csv_path: str, mode: str, max_clusters: int=15) -> str:
    """
    Reads the CSV, extracts either 'Abstract' or 'Title', embeds them,
    clusters them using AgglomerativeClustering, and extracts top representative
    sentences per cluster. Also generates Plotly HTML visualizations.

    Returns a dictionary of raw topic data for the LLM.
    """
    # NOTE(review): the docstring above is kept verbatim — presumably @tool
    # exposes it to the agent as the tool description; confirm before rewording.
    # The value actually returned is that dictionary serialised with json.dumps.
    #
    # Raises ValueError when fewer than 3 usable texts are found.
    # Side effects (all written to OUT_DIR): {mode}_emb.npy,
    # {mode}_cluster_labels.npy, {mode}_raw_clusters.json,
    # {mode}_summaries.json, plus the HTML charts from _generate_visuals.
    if mode not in ["abstract", "title"]:
        mode = "abstract"

    target_column = "Abstract" if mode == "abstract" else "Title"

    texts = []
    paper_ids = []

    import re

    # Compiled once instead of being looked up again for every CSV row.
    url_pattern = re.compile(r'https?://\S+|www\.\S+')
    email_pattern = re.compile(r'\S+@\S+')

    with open(csv_path, encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for i, row in enumerate(reader):
            # DictReader fills missing cells of short rows with None, which
            # would crash .strip(); treat those cells as empty text.
            text = (row.get(target_column) or "").strip()

            # Strip URLs and e-mail addresses before embedding.
            text = url_pattern.sub('', text)
            text = email_pattern.sub('', text)
            text = text.strip()

            # Texts of 40 characters or fewer are skipped.
            if len(text) > 40:
                texts.append(text)
                # Fall back to the row index when "Sr No" is blank or missing.
                paper_ids.append(row.get("Sr No") or str(i))

    if len(texts) < 3:
        raise ValueError("Not enough papers to cluster.")

    print(f"[{mode}] Embedding {len(texts)} texts...")
    model = _embed()
    embeddings = model.encode(texts, show_progress_bar=False)

    # Roughly one cluster per ten texts, at least 3, capped by max_clusters.
    n_clusters = min(max_clusters, max(3, len(texts) // 10))
    print(f"[{mode}] Clustering into {n_clusters} clusters...")

    clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='cosine', linkage='average')
    labels = clusterer.fit_predict(embeddings)

    # Per-cluster summary: size, up to 3 representative texts, member paper ids.
    topic_map = {}
    for cluster_id in range(n_clusters):
        idx = np.where(labels == cluster_id)[0]
        if len(idx) == 0:
            continue

        cluster_embeddings = embeddings[idx]
        centroid = np.mean(cluster_embeddings, axis=0).reshape(1, -1)
        sim_scores = cosine_similarity(centroid, cluster_embeddings)[0]

        # Candidates are the (up to) 10 members most similar to the centroid.
        top_k = min(10, len(idx))
        top_indices = np.argsort(sim_scores)[::-1][:top_k]

        top_sentences = []
        seen_signatures = set()

        for i in top_indices:
            text = texts[idx[i]].strip()

            # Near-duplicate guard: texts sharing the same lower-cased first
            # 60 characters count as one.
            sig = text.lower()[:60]
            if sig not in seen_signatures:
                seen_signatures.add(sig)

                trunc = text[:250].replace('\n', ' ') + ("..." if len(text) > 250 else "")
                top_sentences.append(trunc)

            if len(top_sentences) == 3:
                break

        topic_map[str(cluster_id)] = {
            "size": len(idx),
            "top_sentences": top_sentences,
            "paper_ids": [str(paper_ids[i]) for i in idx]
        }

    # Persist the arrays so update_charts_with_labels can redraw the charts.
    np.save(os.path.join(OUT_DIR, f"{mode}_emb.npy"), embeddings)
    np.save(os.path.join(OUT_DIR, f"{mode}_cluster_labels.npy"), labels)

    _generate_visuals(mode, embeddings, labels, n_clusters, topic_map)

    raw_path = os.path.join(OUT_DIR, f"{mode}_raw_clusters.json")
    with open(raw_path, "w", encoding="utf-8") as f:
        json.dump(topic_map, f, indent=2)

    # One-line summary per topic: its most central representative text.
    summ_path = os.path.join(OUT_DIR, f"{mode}_summaries.json")
    summaries = {
        tid: v["top_sentences"][0] if v["top_sentences"] else ""
        for tid, v in topic_map.items()
    }
    with open(summ_path, "w", encoding="utf-8") as f:
        json.dump(summaries, f, indent=2)

    return json.dumps(topic_map)
|
|
|
def update_charts_with_labels(mode: str):
    """Redraw the Plotly charts for ``mode`` using saved cluster labels.

    Loads ``{mode}_emb.npy``, ``{mode}_cluster_labels.npy``,
    ``{mode}_raw_clusters.json`` and ``{mode}_labels.json`` from ``OUT_DIR``
    and passes them to ``_generate_visuals``. Returns ``None`` without doing
    anything if any of the four files is missing.
    """
    emb_path, lbl_path, raw_path, names_path = (
        os.path.join(OUT_DIR, f"{mode}{suffix}")
        for suffix in ("_emb.npy", "_cluster_labels.npy", "_raw_clusters.json", "_labels.json")
    )

    for required in (emb_path, lbl_path, raw_path, names_path):
        if not os.path.exists(required):
            return

    embeddings = np.load(emb_path)
    labels = np.load(lbl_path)

    with open(raw_path, "r", encoding="utf-8") as fh:
        topic_map = json.load(fh)
    with open(names_path, "r", encoding="utf-8") as fh:
        names_data = json.load(fh)

    # The labels file may wrap the mapping in a top-level "labels" key.
    if isinstance(names_data, dict) and "labels" in names_data:
        names_data = names_data["labels"]

    # Only entries shaped like {"label": ...} contribute a display name.
    custom_names = {}
    if isinstance(names_data, dict):
        custom_names = {
            key: entry["label"]
            for key, entry in names_data.items()
            if isinstance(entry, dict) and "label" in entry
        }

    # Cluster ids are the integer keys of topic_map; highest id + 1 covers them.
    n_clusters = 0
    if topic_map:
        n_clusters = max([int(key) for key in topic_map.keys()]) + 1
    _generate_visuals(mode, embeddings, labels, n_clusters, topic_map, custom_names)
|
|
|
def _generate_visuals(mode: str, embeddings, labels, n_clusters, topic_map, custom_names=None):
    """Write the Plotly HTML charts for one clustering run into ``OUT_DIR``.

    Always writes ``{mode}_intertopic.html`` (2-D PCA scatter) and
    ``{mode}_bars.html`` (papers per topic). When more than one cluster has an
    entry in ``topic_map`` it also writes ``{mode}_hierarchy.html`` (dendrogram
    of cluster centroids) and ``{mode}_heatmap.html`` (centroid cosine
    similarity). ``custom_names`` maps a cluster-id string to a display name;
    ids without an entry are shown as ``"Topic <id>"``.
    """
    print(f"[{mode}] Generating Plotly visuals (Dark theme -> White template applied)...")
    if custom_names is None:
        custom_names = {}

    # Layout settings shared by every figure.
    shared_layout = dict(
        template="plotly_white",
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )

    def get_name(tid_str):
        return custom_names.get(tid_str, f"Topic {tid_str}")

    # 1) Scatter of the embeddings projected onto two principal components.
    coords = PCA(n_components=2).fit_transform(embeddings)
    point_names = [get_name(str(lab)) for lab in labels]

    fig_scatter = px.scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        color=point_names,
        title=f"Intertopic Distance ({mode.capitalize()})",
        labels={'color': 'Topic'},
    )
    fig_scatter.update_layout(**shared_layout)
    fig_scatter.write_html(os.path.join(OUT_DIR, f"{mode}_intertopic.html"), include_plotlyjs="cdn")

    # 2) Bar chart of cluster sizes, restricted to ids present in topic_map.
    valid_ids = [str(cid) for cid in range(n_clusters) if str(cid) in topic_map]
    sizes = [topic_map[vid]['size'] for vid in valid_ids]
    bar_names = [get_name(vid) for vid in valid_ids]

    fig_bar = px.bar(
        x=bar_names,
        y=sizes,
        title=f"Topic Frequency ({mode.capitalize()})",
        labels={'x': 'Topic', 'y': 'Number of Papers'},
    )
    fig_bar.update_layout(**shared_layout)
    fig_bar.write_html(os.path.join(OUT_DIR, f"{mode}_bars.html"), include_plotlyjs="cdn")

    # 3) Mean embedding per cluster, used by both remaining charts.
    centroids = [
        np.mean(embeddings[np.where(labels == int(vid))[0]], axis=0)
        for vid in valid_ids
    ]
    valid_clusters = [get_name(vid) for vid in valid_ids]

    # Both remaining charts need at least two centroids.
    if len(centroids) > 1:
        fig_dendro = ff.create_dendrogram(np.array(centroids), labels=valid_clusters)
        fig_dendro.update_layout(title=f"Topic Hierarchy ({mode.capitalize()})", **shared_layout)
        fig_dendro.write_html(os.path.join(OUT_DIR, f"{mode}_hierarchy.html"), include_plotlyjs="cdn")

        sim_matrix = cosine_similarity(np.array(centroids))
        fig_heat = px.imshow(
            sim_matrix,
            x=valid_clusters,
            y=valid_clusters,
            title=f"Topic Similarity Heatmap ({mode.capitalize()})",
            color_continuous_scale="Viridis",
        )
        fig_heat.update_layout(**shared_layout)
        fig_heat.write_html(os.path.join(OUT_DIR, f"{mode}_heatmap.html"), include_plotlyjs="cdn")