"""Streamlit page that imports Stack Overflow questions and answers into Neo4j.

Questions are fetched from the Stack Exchange API, embedded with the configured
embedding model, and merged into the graph as Question / Answer / Tag / User
nodes.
"""

import os
from urllib.parse import quote

import requests
import streamlit as st
from dotenv import load_dotenv
from langchain_community.graphs import Neo4jGraph
from PIL import Image
from streamlit.logger import get_logger

from chains import load_embedding_model
from utils import create_constraints, create_vector_index

load_dotenv(".env")

url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
# Remapping for Langchain Neo4j integration
os.environ["NEO4J_URL"] = url

logger = get_logger(__name__)

so_api_base_url = "https://api.stackexchange.com/2.3/search/advanced"

# Upper bound (seconds) for a single Stack Exchange API call, so a stalled
# connection cannot block the Streamlit spinner forever.
SO_REQUEST_TIMEOUT_SECONDS = 30

embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)

# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(
    url=url, username=username, password=password, refresh_schema=False
)

create_constraints(neo4j_graph)
create_vector_index(neo4j_graph)


def _fetch_so_data(parameters: str) -> dict:
    """Call the Stack Exchange search endpoint and return the decoded JSON.

    Args:
        parameters: Query string (including the leading ``?``) appended to
            ``so_api_base_url``.

    Returns:
        The response wrapper as a dict; it is guaranteed to contain ``items``.

    Raises:
        requests.RequestException: On network failure, timeout, or an HTTP
            error status whose body is not JSON.
        RuntimeError: If the API answers with an error payload or without an
            ``items`` field.
    """
    response = requests.get(
        so_api_base_url + parameters, timeout=SO_REQUEST_TIMEOUT_SECONDS
    )
    try:
        data = response.json()
    except ValueError:
        # Non-JSON body (e.g. a gateway error page): prefer reporting the HTTP
        # status when it indicates a failure, otherwise re-raise the decode error.
        response.raise_for_status()
        raise

    # The API reports failures (throttling, bad filter, ...) inside the JSON
    # wrapper; surface that message instead of a bare KeyError on "items".
    if "error_id" in data or "items" not in data:
        raise RuntimeError(
            "Stack Exchange API request failed: "
            f"{data.get('error_name', 'unknown error')} - "
            f"{data.get('error_message', 'no items in response')}"
        )
    return data


def load_so_data(tag: str = "neo4j", page: int = 1) -> None:
    """Fetch one page of answered questions for ``tag`` and store it in Neo4j.

    Args:
        tag: Stack Overflow tag to search for. Several tags may be joined
            with ``;``.
        page: 1-based result page to fetch (100 questions per page).
    """
    # Percent-encode the user-supplied tag so characters such as "#" or "+"
    # (e.g. "c#", "c++") do not corrupt the query string. ";" stays literal
    # so existing multi-tag input produces the same URL as before.
    encoded_tag = quote(tag, safe=";")
    parameters = (
        f"?pagesize=100&page={page}&order=desc&sort=creation&answers=1&tagged={encoded_tag}"
        "&site=stackoverflow&filter=!*236eb_eL9rai)MOSNZ-6D3Q6ZKb0buI*IVotWaTb"
    )
    insert_so_data(_fetch_so_data(parameters))


def load_high_score_so_data() -> None:
    """Fetch the highest-voted questions since the fixed ``fromdate`` and store them."""
    parameters = (
        "?fromdate=1664150400&order=desc&sort=votes&site=stackoverflow&"
        "filter=!.DK56VBPooplF.)bWW5iOX32Fh1lcCkw1b_Y6Zkb7YD8.ZMhrR5.FRRsR6Z1uK8*Z5wPaONvyII"
    )
    insert_so_data(_fetch_so_data(parameters))


def insert_so_data(data: dict) -> None:
    """Embed the questions and answers in ``data`` and merge them into Neo4j.

    Each item in ``data["items"]`` is mutated in place: an ``embedding`` entry
    is added to the question and to every answer before the import query runs.

    Args:
        data: Stack Exchange API response wrapper containing an ``items`` list.
    """
    # Calculate embedding values for questions and answers
    for q in data["items"]:
        question_text = q["title"] + "\n" + q["body_markdown"]
        q["embedding"] = embeddings.embed_query(question_text)
        # The high-score query does not filter on answers=1, so a question may
        # arrive without an "answers" key (presumably omitted by the API when
        # there are none); default to an empty list instead of raising KeyError.
        for a in q.setdefault("answers", []):
            a["embedding"] = embeddings.embed_query(
                question_text + "\n" + a["body_markdown"]
            )

    # Cypher, the query language of Neo4j, is used to import the data
    # https://neo4j.com/docs/getting-started/cypher-intro/
    # https://neo4j.com/docs/cypher-cheat-sheet/5/auradb-enterprise/
    import_query = """
    UNWIND $data AS q
    MERGE (question:Question {id:q.question_id})
    ON CREATE SET question.title = q.title, question.link = q.link, question.score = q.score,
        question.favorite_count = q.favorite_count, question.creation_date = datetime({epochSeconds: q.creation_date}),
        question.body = q.body_markdown, question.embedding = q.embedding
    FOREACH (tagName IN q.tags |
        MERGE (tag:Tag {name:tagName})
        MERGE (question)-[:TAGGED]->(tag)
    )
    FOREACH (a IN q.answers |
        MERGE (question)<-[:ANSWERS]-(answer:Answer {id:a.answer_id})
        SET answer.is_accepted = a.is_accepted,
            answer.score = a.score,
            answer.creation_date = datetime({epochSeconds:a.creation_date}),
            answer.body = a.body_markdown,
            answer.embedding = a.embedding
        MERGE (answerer:User {id:coalesce(a.owner.user_id, "deleted")})
        ON CREATE SET answerer.display_name = a.owner.display_name,
                      answerer.reputation= a.owner.reputation
        MERGE (answer)<-[:PROVIDED]-(answerer)
    )
    WITH * WHERE NOT q.owner.user_id IS NULL
    MERGE (owner:User {id:q.owner.user_id})
    ON CREATE SET owner.display_name = q.owner.display_name,
                  owner.reputation = q.owner.reputation
    MERGE (owner)-[:ASKED]->(question)
    """
    neo4j_graph.query(import_query, {"data": data["items"]})


# Streamlit
def get_tag() -> str:
    """Render the tag text box and return its current value."""
    input_text = st.text_input(
        "Which tag questions do you want to import?", value="neo4j"
    )
    return input_text


def get_pages():
    """Render the paging inputs and return ``(num_pages, start_page)`` as ints."""
    col1, col2 = st.columns(2)
    with col1:
        num_pages = st.number_input(
            "Number of pages (100 questions per page)", step=1, min_value=1
        )
    with col2:
        start_page = st.number_input("Start page", step=1, min_value=1)
    st.caption("Only questions with answers will be imported.")
    return (int(num_pages), int(start_page))


def render_page() -> None:
    """Draw the loader UI and run the selected import when a button is pressed."""
    datamodel_image = Image.open("./images/datamodel.png")
    st.header("StackOverflow Loader")
    st.subheader("Choose StackOverflow tags to load into Neo4j")
    st.caption("Go to http://localhost:7474/ to explore the graph.")

    user_input = get_tag()
    num_pages, start_page = get_pages()

    if st.button("Import", type="primary"):
        with st.spinner("Loading... This might take a minute or two."):
            # Top-level UI boundary: report any failure to the user instead of
            # letting the Streamlit script crash.
            try:
                for page in range(start_page, start_page + num_pages):
                    load_so_data(user_input, page)
                st.success("Import successful", icon="✅")
                st.caption("Data model")
                st.image(datamodel_image)
                st.caption("Go to http://localhost:7474/ to interact with the database")
            except Exception as e:
                st.error(f"Error: {e}", icon="🚨")

    with st.expander("Highly ranked questions rather than tags?"):
        if st.button("Import highly ranked questions"):
            with st.spinner("Loading... This might take a minute or two."):
                try:
                    load_high_score_so_data()
                    st.success("Import successful", icon="✅")
                except Exception as e:
                    st.error(f"Error: {e}", icon="🚨")


render_page()