Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from langchain_community.graphs import Neo4jGraph | |
| import pandas as pd | |
| import json | |
| import time | |
| from ki_gen.planner import build_planner_graph | |
| from ki_gen.utils import init_app, memory | |
| from ki_gen.prompts import get_initial_prompt | |
| from neo4j import GraphDatabase | |
# Streamlit page setup; this is the first st.* call made by this script.
st.set_page_config(page_title="Key Issue Generator", layout="wide")

# Neo4j connection settings. The URI and user name are fixed; the password
# is read from the environment and is None when the variable is unset.
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = os.environ.get("neo4j_password")

# Credentials for the LLM / tracing services, each None when unset.
OPENAI_API_KEY = os.environ.get("openai_api_key")
GROQ_API_KEY = os.environ.get("groq_api_key")
LANGSMITH_API_KEY = os.environ.get("langsmith_api_key")
def verify_neo4j_connectivity():
    """Probe the Neo4j instance using the module-level connection settings.

    Returns:
        The value returned by ``driver.verify_connectivity()`` when the probe
        succeeds, otherwise a string of the form ``"Error: <details>"``.
    """
    try:
        # The driver is only needed for the probe, so it lives inside a
        # ``with`` block and is released as soon as the check is done.
        with GraphDatabase.driver(
            NEO4J_URI,
            auth=(NEO4J_USERNAME, NEO4J_PASSWORD),
        ) as driver:
            outcome = driver.verify_connectivity()
    except Exception as exc:
        # Failures are reported as text rather than raised; the caller
        # inspects the returned value for the substring "Error".
        return f"Error: {exc!s}"
    return outcome
def load_config():
    """Build the run configuration, including a connected Neo4j graph object.

    Returns:
        A dict of the form ``{"configurable": settings}`` where ``settings``
        holds the fixed pipeline parameters plus the ``"graph"`` entry, or
        ``None`` (after showing a Streamlit error) when the Neo4j graph
        object cannot be created.
    """
    # Connect first: without a graph there is nothing useful to return.
    try:
        neo_graph = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD,
        )
    except Exception as err:
        st.error(f"Error connecting to Neo4j: {err}")
        return None

    # Fixed pipeline parameters; the graph handle is added as the last key.
    return {
        "configurable": {
            "main_llm": "deepseek-r1-distill-llama-70b",
            "plan_method": "generation",
            "use_detailed_query": False,
            "cypher_gen_method": "guided",
            "validate_cypher": False,
            "summarize_model": "deepseek-r1-distill-llama-70b",
            "eval_method": "binary",
            "eval_threshold": 0.7,
            "max_docs": 15,
            "compression_method": "llm_lingua",
            "compress_rate": 0.33,
            # Kept as a list, the format the original comment says the
            # application expects.
            "force_tokens": ["."],
            "eval_model": "deepseek-r1-distill-llama-70b",
            "thread_id": "3",
            "graph": neo_graph,
        }
    }
def _stream_and_print(graph, graph_input, config):
    """Drive ``graph.stream`` until it stops, pretty-printing each event's latest message."""
    for event in graph.stream(graph_input, config, stream_mode="values"):
        # Skip events without messages (or with an empty list) instead of
        # failing on the [-1] lookup.
        if "messages" in event and event["messages"]:
            event["messages"][-1].pretty_print()


def generate_key_issues(user_query):
    """Run the full key-issue pipeline for *user_query* and report progress in the UI.

    The pipeline runs in three streamed phases against the planner graph:
    plan generation, document retrieval, and document processing
    (summarisation). Status placeholders in the Streamlit page are updated
    after each phase.

    Args:
        user_query: The free-text question entered by the user.

    Returns:
        A ``(final_result, valid_docs)`` tuple. ``final_result`` is the
        content of the last message in the final graph state and
        ``valid_docs`` is the list stored under ``'valid_docs'`` (empty if
        absent). ``(None, [])`` is returned when the configuration cannot be
        loaded or the final state holds no messages, so callers can always
        unpack two values.
    """
    # Initialize application with API keys
    init_app(
        openai_key=OPENAI_API_KEY,
        groq_key=GROQ_API_KEY,
        langsmith_key=LANGSMITH_API_KEY,
    )

    # Load configuration with custom parameters. Return the same two-value
    # shape as every other exit so the caller's tuple unpacking cannot fail.
    config = load_config()
    if not config:
        return None, []

    # Create status containers
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()

    # Build planner graph
    plan_status.info("Building planner graph...")
    graph = build_planner_graph(memory, config["configurable"])

    # Phase 1: generate the plan from the initial prompt
    plan_status.info(f"Generating plan for query: {user_query}")
    _stream_and_print(graph, get_initial_prompt(config, user_query), config)

    # Read the generated plan back from the graph state and tabulate it
    state = graph.get_state(config)
    plan = state.values.get('store_plan', [])
    plan_df = pd.DataFrame({
        'Plan steps': list(range(1, len(plan) + 1)),
        'Description': plan,
    })

    # Display the plan
    plan_status.success("Plan generation complete!")
    plan_display.dataframe(plan_df, use_container_width=True)

    # Phase 2: continue execution for document retrieval. Streaming with a
    # None input continues from the state saved for this config
    # (NOTE(review): relies on the planner graph pausing between phases).
    retrieval_status.info("Retrieving documents...")
    _stream_and_print(graph, None, config)

    # Get updated state after document retrieval
    snapshot = graph.get_state(config)
    doc_count = len(snapshot.values.get('valid_docs', []))
    retrieval_status.success(f"Retrieved {doc_count} documents")

    # Phase 3: document processing, with "summarize" as the only step
    processing_status.info("Processing documents...")
    process_steps = ["summarize"]

    # Mark human validation as done on behalf of the user and record the
    # processing steps, writing the update as the "human_validation" node.
    graph.update_state(
        config,
        {'human_validated': True, 'process_steps': process_steps},
        as_node="human_validation",
    )
    _stream_and_print(graph, None, config)

    # Get final state after processing
    final_values = graph.get_state(config).values
    processing_status.success("Document processing complete!")

    messages = final_values.get("messages")
    if not messages:
        # Missing or empty message list: nothing to show.
        return None, []
    return messages[-1].content, final_values.get('valid_docs', [])
# App header
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")

# Check database connectivity and report it in the sidebar.
# verify_neo4j_connectivity() returns an "Error: ..." string on failure, so
# the status is decided by looking for that marker in its text form.
connectivity_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
if "Error" not in str(connectivity_status):
    st.sidebar.success("Connected to Neo4j database")
else:
    st.sidebar.error(f"Database connection issue: {connectivity_status}")

# User input section
st.header("Enter Your Query")
user_query = st.text_area(
    "What would you like to explore?",
    "What are the main challenges in AI adoption for healthcare systems?",
    height=100,
)

# Process button
if st.button("Generate Key Issues", type="primary"):
    required_secrets = (OPENAI_API_KEY, GROQ_API_KEY, LANGSMITH_API_KEY, NEO4J_PASSWORD)
    if not all(required_secrets):
        st.error("Required API keys or database credentials are missing. Please check your environment variables.")
    else:
        with st.spinner("Processing your query..."):
            # perf_counter is monotonic, so the elapsed time cannot be
            # skewed by system clock adjustments.
            start_time = time.perf_counter()
            outcome = generate_key_issues(user_query)
            end_time = time.perf_counter()

        # generate_key_issues returns a bare None when the configuration
        # cannot be loaded; normalise before unpacking so that path shows
        # the error message below instead of raising TypeError.
        final_result, valid_docs = outcome if outcome is not None else (None, [])

        if final_result:
            # Display execution time
            st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")

            # Display final result
            st.header("Generated Key Issues")
            st.markdown(final_result)

            # Option to download results
            st.download_button(
                label="Download Results",
                data=final_result,
                file_name="key_issues_results.txt",
                mime="text/plain",
            )

            # Display retrieved documents in expandable section
            if valid_docs:
                with st.expander("View Retrieved Documents"):
                    for i, doc in enumerate(valid_docs):
                        st.markdown(f"### Document {i+1}")
                        # Each document is iterated as a mapping of
                        # field name -> value.
                        for key in doc:
                            st.markdown(f"**{key}**: {doc[key]}")
                        st.divider()
        else:
            st.error("An error occurred during processing. Please check the logs for details.")

# Help information in sidebar
with st.sidebar:
    st.header("About")
    st.info("""
This application uses advanced language models to analyze a Neo4j knowledge graph and generate key issues
based on your query. The process involves:
1. Creating a plan based on your query
2. Retrieving relevant documents from the database
3. Processing and summarizing the information
4. Generating a comprehensive response
""")