# app/agent/tools/citation_analyzer.py
# Citation network analysis tool (LangChain BaseTool implementation).
from typing import List, Dict, Any
from langchain_core.tools import BaseTool
import networkx as nx
import matplotlib.pyplot as plt
import json
import os
from pydantic import Field, PrivateAttr
class CitationAnalyzerTool(BaseTool):
    """LangChain tool that builds and analyzes a persistent citation network.

    Papers are stored as nodes in a directed graph (edge: citing paper ->
    cited paper) that is persisted to ``graph_path`` in node-link JSON format
    between invocations. Each run merges new papers in, computes network
    statistics and centrality-based rankings, detects communities, and renders
    a PNG visualization of the network.
    """

    name: str = "citation_analyzer"
    description: str = """Use this tool to analyze citation networks and research impact.
Input should be a list of papers or a specific paper to analyze.
The tool will:
- Build a citation network
- Calculate impact metrics
- Visualize connections
- Identify key papers in the field"""

    # Location of the serialized graph persisted between runs.
    graph_path: str = Field(default="data/citation_graph.json")
    # In-memory directed citation graph; private so pydantic does not validate it.
    _graph: PrivateAttr = PrivateAttr()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._graph = nx.DiGraph()
        # Resume from a previously saved graph if one exists on disk.
        if os.path.exists(self.graph_path):
            with open(self.graph_path, 'r') as f:
                graph_data = json.load(f)
            self._graph = nx.node_link_graph(graph_data)

    def _run(self, papers: str) -> str:
        """Parse JSON input, update the graph, and return a JSON analysis.

        Args:
            papers: JSON text — either a list of paper objects or a single
                paper object (the tool description promises both forms).

        Returns:
            Pretty-printed JSON with network stats, key papers, communities,
            and the visualization path; or an ``Error analyzing citations``
            string on failure (LangChain tool convention — never raises).
        """
        try:
            parsed = json.loads(papers)
            # A single paper dict would otherwise be iterated key-by-key,
            # adding bare strings as "papers"; normalize to a list.
            papers_list = [parsed] if isinstance(parsed, dict) else parsed
            for paper in papers_list:
                self._add_paper_to_graph(paper)
            analysis = {
                "network_stats": self._calculate_network_stats(),
                "key_papers": self._identify_key_papers(),
                "communities": self._identify_communities(),
                "visualization_path": self._generate_visualization()
            }
            # Persist the merged graph only after analysis succeeds.
            self._save_graph()
            return json.dumps(analysis, indent=2)
        except Exception as e:
            return f"Error analyzing citations: {str(e)}"

    def _add_paper_to_graph(self, paper: Dict[str, Any]) -> None:
        """Add one paper node plus edges to every paper it cites."""
        paper_id = paper.get("id", paper.get("title", "unknown"))
        # All paper fields become node attributes (merged if node exists).
        self._graph.add_node(paper_id, **paper)
        for citation in paper.get("citations", []):
            citation_id = citation.get("id", citation.get("title", "unknown"))
            self._graph.add_edge(paper_id, citation_id)

    def _calculate_network_stats(self) -> Dict[str, Any]:
        """Return basic size, density, and connectivity statistics."""
        n_nodes = self._graph.number_of_nodes()
        # An empty graph would raise ZeroDivisionError on the average and
        # NetworkXPointlessConcept in is_strongly_connected; short-circuit.
        if n_nodes == 0:
            return {
                "number_of_papers": 0,
                "number_of_citations": 0,
                "average_citations": 0.0,
                "density": 0.0,
                "diameter": "Not connected"
            }
        return {
            "number_of_papers": n_nodes,
            "number_of_citations": self._graph.number_of_edges(),
            "average_citations": sum(dict(self._graph.degree()).values()) / n_nodes,
            "density": nx.density(self._graph),
            # Diameter is only defined for strongly connected digraphs.
            "diameter": nx.diameter(self._graph) if nx.is_strongly_connected(self._graph) else "Not connected"
        }

    def _identify_key_papers(self) -> List[Dict[str, Any]]:
        """Rank papers by combined degree, betweenness, and PageRank scores."""
        degree_centrality = nx.degree_centrality(self._graph)
        betweenness_centrality = nx.betweenness_centrality(self._graph)
        pagerank = nx.pagerank(self._graph)
        key_papers = []
        for paper_id in self._graph.nodes():
            key_papers.append({
                "paper_id": paper_id,
                "title": self._graph.nodes[paper_id].get("title", "Unknown"),
                "degree_centrality": degree_centrality[paper_id],
                "betweenness_centrality": betweenness_centrality[paper_id],
                "pagerank": pagerank[paper_id]
            })
        # Simple unweighted sum of the three metrics as the importance score.
        key_papers.sort(
            key=lambda x: x["degree_centrality"] + x["betweenness_centrality"] + x["pagerank"],
            reverse=True
        )
        return key_papers[:10]  # Top 10 key papers

    def _identify_communities(self) -> List[List[str]]:
        """Group papers into communities (lists of titles) by modularity."""
        # Modularity is undefined for an edgeless graph and networkx raises;
        # treat every paper as its own singleton community instead.
        if self._graph.number_of_edges() == 0:
            return [[self._graph.nodes[n].get("title", "Unknown")] for n in self._graph.nodes()]
        # Community detection operates on the undirected projection.
        undirected_graph = self._graph.to_undirected()
        communities = list(nx.algorithms.community.greedy_modularity_communities(undirected_graph))
        return [
            [self._graph.nodes[node_id].get("title", "Unknown") for node_id in community]
            for community in communities
        ]

    def _generate_visualization(self) -> str:
        """Render the citation network to a PNG and return its path."""
        viz_path = "data/citation_network.png"
        # Ensure the output directory exists (fresh checkouts have no data/).
        os.makedirs(os.path.dirname(viz_path), exist_ok=True)
        plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(self._graph)
        nx.draw(self._graph, pos, with_labels=True, node_size=500, font_size=8)
        plt.savefig(viz_path)
        plt.close()  # Free the figure so repeated runs don't leak memory
        return viz_path

    def _save_graph(self) -> None:
        """Persist the graph to ``graph_path`` in node-link JSON format."""
        parent = os.path.dirname(self.graph_path)
        # graph_path may be a bare filename; only create a dir when present.
        if parent:
            os.makedirs(parent, exist_ok=True)
        graph_data = nx.node_link_data(self._graph)
        with open(self.graph_path, 'w') as f:
            json.dump(graph_data, f)