Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import pipeline | |
| import networkx as nx | |
| from pyvis.network import Network | |
| import tempfile | |
| import openai | |
| import requests | |
| import xml.etree.ElementTree as ET | |
| import pandas as pd | |
| from io import StringIO | |
| import asyncio | |
| import base64 | |
| # --------------------------- | |
| # Model Loading & Caching | |
| # --------------------------- | |
@st.cache_resource
def load_summarizer():
    """Load and cache the abstract-summarization pipeline.

    Streamlit re-executes this whole script on every user interaction, so
    without caching the facebook/bart-large-cnn model would be re-instantiated
    on each rerun. ``@st.cache_resource`` (the section header above already
    promises "Model Loading & Caching") keeps one instance per server process.

    Returns:
        A Hugging Face ``summarization`` pipeline.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")
@st.cache_resource
def load_text_generator():
    """Load and cache the GPT-2 text-generation pipeline.

    Cached with ``@st.cache_resource`` so the model is loaded once per server
    process instead of on every Streamlit rerun (the "Model Loading & Caching"
    section header implies caching, but none was applied).

    Returns:
        A Hugging Face ``text-generation`` pipeline (GPT-2, for demonstration).
    """
    return pipeline("text-generation", model="gpt2")
# Instantiate the shared pipelines at import time. NOTE(review): Streamlit
# reruns this script on every interaction, so these calls repeat unless the
# loader functions themselves cache the models.
summarizer = load_summarizer()
generator = load_text_generator()
| # --------------------------- | |
| # Idea Generation Functions | |
| # --------------------------- | |
def generate_ideas_with_hf(prompt):
    """Generate research ideas with the local Hugging Face model.

    Runs the module-level GPT-2 ``generator`` pipeline on *prompt* and returns
    the full generated text (the prompt followed by up to 50 new tokens).
    """
    outputs = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    return outputs[0]['generated_text']
def generate_ideas_with_openai(prompt, api_key):
    """
    Generates research ideas using OpenAI's GPT-3.5 with streaming output.

    Bug fix: the original code called ``openai.ChatCompletion.acreate``, which
    was removed in openai SDK v1.0 (the app title explicitly targets v1.0), so
    it would raise at runtime. This version uses the v1 client's synchronous
    streaming interface, which also removes the fragile ``asyncio.run`` call
    inside a Streamlit script.

    Args:
        prompt: The user prompt describing the abstract/summary context.
        api_key: The caller-supplied OpenAI API key.

    Returns:
        The accumulated completion text (also rendered incrementally into a
        Streamlit placeholder while streaming).
    """
    client = openai.OpenAI(api_key=api_key)
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    output_text = ""
    placeholder = st.empty()  # Placeholder for streaming output
    for chunk in stream:
        # Each streamed chunk carries an incremental delta; content may be None
        # on role/termination chunks.
        piece = chunk.choices[0].delta.content or ""
        output_text += piece
        placeholder.text(output_text)
    return output_text
| # --------------------------- | |
| # arXiv API Integration using xml.etree.ElementTree | |
| # --------------------------- | |
def fetch_arxiv_results(query, max_results=5):
    """
    Queries arXiv's free API and parses the Atom result with ElementTree.

    Bug fixes: the query was previously concatenated into the URL without any
    encoding (a query containing spaces or ``&`` produced a malformed request),
    and the HTTP call had no timeout. Passing ``params=`` lets requests
    URL-encode the values, and ``timeout=`` prevents the Streamlit app from
    hanging indefinitely on a stalled connection.

    Args:
        query: Free-text search terms (matched against all arXiv fields).
        max_results: Maximum number of entries to return (default 5).

    Returns:
        A list of dicts with keys ``title``, ``summary``, ``published``,
        ``link`` and ``authors`` (comma-joined); empty list on HTTP failure.
    """
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }
    try:
        response = requests.get("http://export.arxiv.org/api/query",
                                params=params, timeout=10)
    except requests.RequestException:
        # Network failure: preserve the original contract of returning [].
        return []
    if response.status_code != 200:
        return []
    results = []
    root = ET.fromstring(response.content)
    ns = {"atom": "http://www.w3.org/2005/Atom"}
    for entry in root.findall("atom:entry", ns):
        # Each field may legitimately be absent; fall back to "".
        title_elem = entry.find("atom:title", ns)
        title = title_elem.text.strip() if title_elem is not None else ""
        summary_elem = entry.find("atom:summary", ns)
        summary = summary_elem.text.strip() if summary_elem is not None else ""
        published_elem = entry.find("atom:published", ns)
        published = published_elem.text.strip() if published_elem is not None else ""
        link_elem = entry.find("atom:id", ns)
        link = link_elem.text.strip() if link_elem is not None else ""
        authors = [author.find("atom:name", ns).text.strip()
                   for author in entry.findall("atom:author", ns)
                   if author.find("atom:name", ns) is not None]
        results.append({
            "title": title,
            "summary": summary,
            "published": published,
            "link": link,
            "authors": ", ".join(authors)
        })
    return results
| # --------------------------- | |
| # Utility Function: Graph Download Link | |
| # --------------------------- | |
def get_download_link(file_path, filename="graph.html"):
    """Convert an HTML file into a base64 data-URI download link.

    Bug fix: the ``filename`` parameter was previously ignored — the anchor's
    ``download`` attribute was hard-coded — so the browser never used the
    caller's requested file name.

    Args:
        file_path: Path of the HTML file to embed.
        filename: Name the browser should save the download as.

    Returns:
        An ``<a>`` tag whose href embeds the file content as base64.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        html_data = f.read()
    b64 = base64.b64encode(html_data.encode()).decode()
    href = f'<a href="data:text/html;base64,{b64}" download="{filename}">Download Graph as HTML</a>'
    return href
| # --------------------------- | |
| # Streamlit Application Layout | |
| # --------------------------- | |
| st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0") | |
| # Sidebar: Configuration and Layout Options | |
| st.sidebar.header("Configuration") | |
| generation_mode = st.sidebar.selectbox("Select Idea Generation Mode", | |
| ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"]) | |
| openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password") | |
| layout_option = st.sidebar.selectbox("Select Graph Layout", ["Default", "Force Atlas 2"]) | |
# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")
if st.button("Search arXiv"):
    if arxiv_query.strip():
        # Fetch up to 5 matching papers; fetch_arxiv_results returns [] on
        # any HTTP error, which falls through to the error message below.
        with st.spinner("Searching arXiv..."):
            results = fetch_arxiv_results(arxiv_query, max_results=5)
        if results:
            st.subheader("arXiv Search Results:")
            # Render each paper as a small markdown card, 1-indexed.
            for idx, paper in enumerate(results):
                st.markdown(f"**{idx+1}. {paper['title']}**")
                st.markdown(f"*Authors:* {paper['authors']}")
                st.markdown(f"*Published:* {paper['published']}")
                st.markdown(f"*Summary:* {paper['summary']}")
                st.markdown(f"[Read more]({paper['link']})")
                st.markdown("---")
        else:
            st.error("No results found or an error occurred with the arXiv API.")
    else:
        st.error("Please enter a valid query for the arXiv search.")
# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
if st.button("Generate Ideas"):
    if paper_abstract.strip():
        # First condense the abstract with the BART summarizer so the idea
        # prompt carries both the full text and a short summary.
        st.subheader("Summarized Abstract")
        summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
        summary_text = summary[0]['summary_text']
        st.write(summary_text)
        st.subheader("Generated Research Ideas")
        # Build a single prompt shared by both generation backends.
        prompt = (
            f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            f"Research Ideas:"
        )
        # Dispatch on the sidebar mode selection (exact string match).
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
                    ideas = generate_ideas_with_openai(prompt, openai_api_key)
                st.write(ideas)
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                ideas = generate_ideas_with_hf(prompt)
            st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")
# --- Section 3: Knowledge Graph Visualization with Additional Features ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Enter paper details and citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n\n```\n1,Graph of AI Ideas: Leveraging Knowledge Graphs and LLMs for AI Research Idea Generation,2;3\n2,Fundamental Approaches in AI Literature,\n3,Applications of LLMs in Research Idea Generation,2\n```"
)
# Optional filter input for node titles.
filter_text = st.text_input("Optional: Enter keyword to filter nodes in the graph:")
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        # Parse the pasted CSV by hand. NOTE(review): a naive split(',')
        # breaks on titles that contain commas, and rows without the trailing
        # comma (fewer than 3 fields) are silently dropped — consider the
        # stdlib csv module.
        data = []
        for line in papers_csv.splitlines():
            parts = line.split(',')
            if len(parts) >= 3:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited = parts[2].strip()
                # Cited IDs are ';'-separated; empty entries are discarded.
                cited_list = [c.strip() for c in cited.split(';') if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
        if data:
            # Build the full directed citation graph; an edge points from the
            # citing paper to the cited paper ID (cited nodes are created
            # implicitly by add_edge even if they have no CSV row).
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)
            # Filter nodes if a keyword is provided (case-insensitive
            # substring match on the node title attribute).
            if filter_text.strip():
                filtered_nodes = [n for n, d in G.nodes(data=True) if filter_text.lower() in d.get("title", "").lower()]
                if filtered_nodes:
                    H = G.subgraph(filtered_nodes).copy()
                else:
                    # No title matched: render an empty graph rather than all nodes.
                    H = nx.DiGraph()
            else:
                H = G
            st.subheader("Knowledge Graph")
            # Create the Pyvis network mirroring H.
            net = Network(height="500px", width="100%", directed=True)
            # Add nodes with tooltips (show title on hover).
            for node, node_data in H.nodes(data=True):
                net.add_node(node, label=node_data.get("title", str(node)), title=node_data.get("title", "No Title"))
            for source, target in H.edges():
                net.add_edge(source, target)
            # Apply layout based on the user's sidebar selection.
            if layout_option == "Force Atlas 2":
                net.force_atlas_2based()
            # Write graph to a temporary HTML file. NOTE(review): delete=False
            # means the file is never removed — it leaks one temp file per
            # render; consider cleaning it up after the download link is built.
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
            net.write_html(temp_file.name)
            # Embed the generated HTML directly into the Streamlit page.
            with open(temp_file.name, 'r', encoding='utf-8') as f:
                html_content = f.read()
            st.components.v1.html(html_content, height=500)
            # Provide a download link for the graph (base64 data URI).
            st.markdown(get_download_link(temp_file.name), unsafe_allow_html=True)
    else:
        st.error("Please enter paper details for the knowledge graph.")