Spaces:
Runtime error
Runtime error
| from typing import List, Dict, Any | |
| from langchain_core.tools import BaseTool | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| import json | |
| import os | |
| from pydantic import Field, PrivateAttr | |
class CitationAnalyzerTool(BaseTool):
    """LangChain tool that builds, analyzes and persists a citation network.

    Papers passed to the tool are merged into a directed graph in which an
    edge ``A -> B`` means "paper A cites paper B". Each call reports summary
    statistics, the top-ranked papers by centrality, the detected communities
    and the path of a rendered PNG. It then writes the graph back to
    ``graph_path`` as networkx node-link JSON, so that later calls (and later
    instances) keep accumulating the same network.
    """

    name: str = "citation_analyzer"
    description: str = """Use this tool to analyze citation networks and research impact.
Input should be a list of papers or a specific paper to analyze.
The tool will:
- Build a citation network
- Calculate impact metrics
- Visualize connections
- Identify key papers in the field"""
    # Location of the node-link JSON the graph is loaded from and saved to.
    graph_path: str = Field(default="data/citation_graph.json")
    # Location of the PNG rendering written on every call.
    visualization_path: str = Field(default="data/citation_network.png")
    # In-memory citation graph (pydantic private attribute, not a model field).
    _graph: nx.DiGraph = PrivateAttr(default_factory=nx.DiGraph)

    def __init__(self, **kwargs):
        """Create the tool and load any graph previously saved at ``graph_path``.

        Errors from reading or decoding an existing graph file propagate.
        """
        super().__init__(**kwargs)
        self._graph = self._load_graph()

    def _load_graph(self) -> nx.DiGraph:
        """Return the graph saved at ``graph_path``, or an empty DiGraph if no file exists."""
        if not os.path.exists(self.graph_path):
            return nx.DiGraph()
        with open(self.graph_path, "r", encoding="utf-8") as f:
            graph_data = json.load(f)
        return nx.node_link_graph(graph_data)

    def _run(self, papers: str) -> str:
        """Merge ``papers`` into the graph and return a JSON analysis report.

        Args:
            papers: JSON text holding either a list of paper objects or a
                single paper object. A paper's node key is its ``id``, else
                its ``title``, else ``"unknown"``. Its optional ``citations``
                entry lists cited papers, given either as objects keyed the
                same way or as bare identifiers.

        Returns:
            Indented JSON with keys ``network_stats``, ``key_papers``,
            ``communities`` and ``visualization_path``. On any failure, the
            tool returns a string starting with ``"Error analyzing citations: "``
            instead, so the calling agent receives a message rather than an
            exception.
        """
        try:
            for paper in self._parse_papers(papers):
                self._add_paper_to_graph(paper)

            analysis = {
                "network_stats": self._calculate_network_stats(),
                "key_papers": self._identify_key_papers(),
                "communities": self._identify_communities(),
                "visualization_path": self._generate_visualization(),
            }

            self._save_graph()
            return json.dumps(analysis, indent=2)
        except Exception as e:  # tool boundary: report the error to the agent
            return f"Error analyzing citations: {str(e)}"

    @staticmethod
    def _parse_papers(papers: Any) -> List[Dict[str, Any]]:
        """Decode the tool input into a list of paper dicts.

        A single JSON object is wrapped in a one-element list, because the
        tool description allows "a specific paper" as input. Input that has
        already been decoded (not a ``str``) is used as-is.

        Raises:
            ValueError: if the decoded value is neither a list nor an object.
        """
        parsed = json.loads(papers) if isinstance(papers, str) else papers
        if isinstance(parsed, dict):
            return [parsed]
        if not isinstance(parsed, list):
            raise ValueError("expected a JSON list of papers or a single paper object")
        return parsed

    @staticmethod
    def _paper_id(ref: Any) -> Any:
        """Return the node key for a paper reference.

        For a dict the key is ``id``, falling back to ``title``, then
        ``"unknown"``. Any other value is treated as the identifier itself.
        """
        if isinstance(ref, dict):
            return ref.get("id", ref.get("title", "unknown"))
        return ref

    def _add_paper_to_graph(self, paper: Dict[str, Any]) -> None:
        """Add ``paper`` as a node (all its fields become node attributes) plus one
        edge to each paper it cites.

        Raises:
            ValueError: if ``paper`` is not a JSON object.
        """
        if not isinstance(paper, dict):
            raise ValueError(f"each paper must be a JSON object, got {type(paper).__name__}")
        paper_id = self._paper_id(paper)
        self._graph.add_node(paper_id, **paper)
        # ``or []`` also covers an explicit ``"citations": null``.
        for citation in paper.get("citations") or []:
            self._graph.add_edge(paper_id, self._paper_id(citation))

    def _calculate_network_stats(self) -> Dict[str, Any]:
        """Return node/edge counts, mean citations per paper, density and diameter.

        ``diameter`` is only defined for a non-empty, strongly connected graph.
        Otherwise it is reported as ``"Not connected"``.
        """
        n_papers = self._graph.number_of_nodes()
        n_citations = self._graph.number_of_edges()
        # is_strongly_connected raises on the null graph, so check emptiness first.
        if n_papers and nx.is_strongly_connected(self._graph):
            diameter: Any = nx.diameter(self._graph)
        else:
            diameter = "Not connected"
        return {
            "number_of_papers": n_papers,
            "number_of_citations": n_citations,
            # Each edge is exactly one citation. Summing in+out degree would count it twice.
            "average_citations": n_citations / n_papers if n_papers else 0.0,
            "density": nx.density(self._graph),
            "diameter": diameter,
        }

    def _identify_key_papers(self, top_n: int = 10) -> List[Dict[str, Any]]:
        """Rank papers by the sum of degree, betweenness and PageRank centrality.

        Args:
            top_n: number of highest-scoring papers to return.

        Returns:
            Up to ``top_n`` dicts with ``paper_id``, ``title`` and the three
            centrality scores, highest combined score first.
        """
        degree = nx.degree_centrality(self._graph)
        betweenness = nx.betweenness_centrality(self._graph)
        pagerank = nx.pagerank(self._graph)

        key_papers = [
            {
                "paper_id": paper_id,
                "title": attrs.get("title", "Unknown"),
                "degree_centrality": degree[paper_id],
                "betweenness_centrality": betweenness[paper_id],
                "pagerank": pagerank[paper_id],
            }
            for paper_id, attrs in self._graph.nodes(data=True)
        ]
        key_papers.sort(
            key=lambda p: p["degree_centrality"] + p["betweenness_centrality"] + p["pagerank"],
            reverse=True,
        )
        return key_papers[:top_n]

    def _identify_communities(self) -> List[List[str]]:
        """Group papers into communities by greedy modularity maximisation.

        Edge direction is ignored. Returns one list of titles per community.
        Nodes without a ``title`` attribute, such as papers known only as
        citation targets, appear as ``"Unknown"``.
        """
        undirected_graph = self._graph.to_undirected()
        if undirected_graph.number_of_edges() == 0:
            # Nothing to cluster on: each paper is its own community. Recent
            # networkx returns the same, while older releases can fail on
            # edgeless/empty graphs.
            communities = [{node} for node in undirected_graph]
        else:
            communities = nx.algorithms.community.greedy_modularity_communities(undirected_graph)
        return [
            [self._graph.nodes[node_id].get("title", "Unknown") for node_id in community]
            for community in communities
        ]

    def _generate_visualization(self) -> str:
        """Render the graph with a spring layout to ``visualization_path`` and return that path."""
        viz_path = self.visualization_path
        self._ensure_parent_dir(viz_path)
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            pos = nx.spring_layout(self._graph)
            nx.draw(self._graph, pos, ax=ax, with_labels=True, node_size=500, font_size=8)
            fig.savefig(viz_path)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        return viz_path

    def _save_graph(self) -> None:
        """Write the graph to ``graph_path`` as node-link JSON."""
        # Serialize before opening the file: opening with "w" truncates it, and
        # a failure mid-dump would leave a corrupt file that breaks __init__.
        payload = json.dumps(nx.node_link_data(self._graph))
        self._ensure_parent_dir(self.graph_path)
        with open(self.graph_path, "w", encoding="utf-8") as f:
            f.write(payload)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the directory containing ``path`` if it does not exist yet."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)