Knowledge_Graphs / Graphs.py
Phani-ISB's picture
Initial commit - clean setup for HF Space
2f32550
# -*- coding: utf-8 -*-
"""Learn with Knowledge Graphs.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/16UX6wbUmaLG6YBJKzH5YouNYYnw2mL8H
"""
# app.py
import json
import os
import re

import matplotlib.pyplot as plt
import networkx as nx
import requests
import streamlit as st
import wikipediaapi
from neo4j import GraphDatabase
# ---------------------------
# CONFIGURATION
# ---------------------------
# Secrets are read from environment variables (e.g. Hugging Face Space
# secrets) instead of being hard-coded: credentials committed to source
# control are readable by anyone with access to the repository. The key and
# password that were previously committed here must be considered leaked and
# should be rotated.

# API key for Perplexity (empty string when unset; requests will then fail
# with an authentication error reported by perplexity_chat).
PPLX_API_KEY = os.environ.get("PPLX_API_KEY", "")

# Optional Neo4j connection settings. Leave NEO4J_URI / NEO4J_PASSWORD unset
# to run without Neo4j persistence.
NEO4J_URI = os.environ.get("NEO4J_URI", "")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "")

# Module-level Neo4j driver used by insert_triples; None disables persistence.
driver = None
if NEO4J_URI and NEO4J_USER and NEO4J_PASSWORD:
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
# ---------------------------
# FUNCTIONS
# ---------------------------
def perplexity_chat(prompt, model="sonar-medium-online", timeout=60):
    """Send a single-turn chat completion request to the Perplexity API.

    Args:
        prompt: Text sent as the only user message.
        model: Perplexity model name. NOTE(review): ``sonar-medium-online``
            appears to have been retired by Perplexity (current names look
            like ``sonar`` / ``sonar-pro``) — confirm and update the default.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The assistant message text on success. On failure (non-200 status,
        network error, or an unexpected response body) a string starting
        with ``"❌ Error"`` is returned instead of raising, so the Streamlit
        page keeps rendering.
    """
    url = "https://api.perplexity.ai/chat/completions"
    headers = {
        "Authorization": f"Bearer {PPLX_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
    }
    try:
        # Without a timeout a stalled connection would block the app forever.
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        return f"❌ Error: request failed: {exc}"
    if resp.status_code != 200:
        return f"❌ Error {resp.status_code}: {resp.text}"
    try:
        return resp.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        # ValueError covers a non-JSON body; the rest cover a JSON body
        # that lacks the expected choices/message/content structure.
        return f"❌ Error: unexpected response format: {exc}"
def _parse_triples(content, max_triples):
    """Parse up to *max_triples* ``(subject, predicate, object)`` tuples from LLM text.

    Parenthesised groups anywhere in the text are preferred; only when none
    exist is each line (minus a leading bullet or "1." / "1)" marker) treated
    as a bare comma-separated triple. Each candidate is split on at most two
    commas, so commas inside the object are kept, e.g.
    ``(X, born in, London, England)`` -> ``("X", "born in", "London, England")``.
    Candidates with an empty part, and the literal format placeholder
    ``(subject, predicate, object)``, are discarded.
    """
    # Drop citation markers such as "[1]" that would otherwise stick to values.
    cleaned = re.sub(r"\[\d+\]", "", content)
    candidates = re.findall(r"\(([^()]*)\)", cleaned)
    if not candidates:
        candidates = [
            re.sub(r"^\s*(?:[-*•]|\d+[.)])\s*", "", line)
            for line in cleaned.splitlines()
        ]
    triples = []
    for candidate in candidates:
        if len(triples) >= max_triples:
            break
        parts = tuple(p.strip(" \t\"'`*[]{}") for p in candidate.split(",", 2))
        if len(parts) != 3 or not all(parts):
            continue
        # The model sometimes echoes the requested format back verbatim.
        if tuple(p.lower() for p in parts) == ("subject", "predicate", "object"):
            continue
        triples.append(parts)
    return triples


def extract_triples_from_chunk(text, max_triples=5):
    """Ask Perplexity for up to *max_triples* triples describing *text*.

    Args:
        text: Source passage to extract facts from.
        max_triples: Upper bound on the number of triples requested and returned.

    Returns:
        A list of ``(subject, predicate, object)`` string tuples. The list is
        empty when the API call fails (``perplexity_chat`` returns a
        ``"❌ Error..."`` string) or nothing parseable came back.
    """
    prompt = (
        f"Extract up to {max_triples} subject-predicate-object triples\n"
        "from the text below. Return only triples in the format "
        "(subject, predicate, object).\n"
        f"Text: {text}"
    )
    content = perplexity_chat(prompt)
    # An error message is not model output; never parse it into triples.
    if not content or content.startswith("❌ Error"):
        return []
    return _parse_triples(content, max_triples)
def build_kg_from_wiki_title(
    title,
    lang="en",
    chunk_chars=800,
    max_triples_per_chunk=5,
    user_agent="MyKGApp/1.0 (contact: your_email@example.com)",
    max_chunks=None,
):
    """Extract knowledge-graph triples from a Wikipedia article.

    The article text is split into fixed-size character chunks, and each
    chunk is sent to ``extract_triples_from_chunk``.

    Args:
        title: Wikipedia page title.
        lang: Wikipedia language edition (e.g. ``"en"``).
        chunk_chars: Characters per chunk; must be positive.
        max_triples_per_chunk: Triple limit passed to each extraction call.
        user_agent: HTTP User-Agent for the Wikipedia API, which asks
            clients to identify themselves. Replace the placeholder contact.
        max_chunks: Optional cap on how many chunks are processed (one
            Perplexity call per chunk). ``None`` processes the whole article.

    Returns:
        A list of ``(subject, predicate, object)`` tuples. Empty if the page
        does not exist.

    Raises:
        ValueError: If ``chunk_chars`` is not positive.
    """
    if chunk_chars <= 0:
        raise ValueError("chunk_chars must be a positive integer")
    # wikipedia-api >= 0.6 takes ``user_agent`` as its first positional
    # parameter, so ``Wikipedia(lang, user_agent=...)`` fails with "multiple
    # values for argument 'user_agent'". Keywords work regardless of order.
    wiki = wikipediaapi.Wikipedia(user_agent=user_agent, language=lang)
    page = wiki.page(title)
    if not page.exists():
        return []
    text = page.text
    chunks = [text[i:i + chunk_chars] for i in range(0, len(text), chunk_chars)]
    if max_chunks is not None:
        chunks = chunks[:max_chunks]
    triples = []
    for chunk in chunks:
        triples.extend(
            extract_triples_from_chunk(chunk, max_triples=max_triples_per_chunk)
        )
    return triples
def insert_triple(tx, subject, predicate, obj):
    """Upsert a single triple using the open transaction *tx*.

    Subject and object become ``:Entity`` nodes keyed by ``name``. The
    predicate is stored as the ``type`` property of a ``:RELATION`` edge
    between them. All three clauses use MERGE, so an identical triple does
    not create duplicate nodes or edges.
    """
    cypher = (
        "\n"
        "MERGE (s:Entity {name: $subject})\n"
        "MERGE (o:Entity {name: $object})\n"
        "MERGE (s)-[:RELATION {type: $predicate}]->(o)\n"
    )
    params = {"subject": subject, "predicate": predicate, "object": obj}
    tx.run(cypher, **params)
def insert_triples(triples):
    """Persist triples to Neo4j in one managed write transaction.

    Does nothing when Neo4j is not configured (module-level ``driver`` is
    None) or when there are no triples.

    Args:
        triples: Iterable of ``(subject, predicate, object)`` tuples.
    """
    if not driver:
        return
    # Materialise first: execute_write may retry the work function, and a
    # one-shot iterator would be empty on the retry.
    triples = list(triples)
    if not triples:
        return

    def _write_all(tx):
        for s, p, o in triples:
            insert_triple(tx, s, p, o)

    with driver.session() as session:
        # One transaction and commit round trip for the whole batch, instead
        # of one per triple. MERGE keeps a retried batch idempotent.
        session.execute_write(_write_all)
# Common question words that would otherwise match almost every triple.
_QUESTION_STOPWORDS = frozenset({
    "the", "and", "who", "what", "when", "where", "which", "why", "how",
    "does", "did", "was", "were", "are", "for", "with", "from", "that",
    "this", "has", "have", "had", "its", "his", "her", "their",
})


def _rank_triples_by_overlap(question, triples):
    """Return *triples* ordered by how many question words each contains.

    Words are ``\\w+`` tokens, lower-cased, with stopwords and tokens shorter
    than 3 characters removed. The sort is stable, so ties (including the
    all-zero case) keep their original order.
    """
    q_words = {
        w for w in re.findall(r"\w+", question.lower())
        if len(w) > 2 and w not in _QUESTION_STOPWORDS
    }
    if not q_words:
        return list(triples)

    def overlap(triple):
        return len(q_words & set(re.findall(r"\w+", " ".join(triple).lower())))

    # reverse=True still preserves the relative order of equal keys.
    return sorted(triples, key=overlap, reverse=True)


def answer_with_kg(question, triples, top_k=10, model="sonar-medium-online"):
    """Answer *question* with Perplexity, using KG triples as context.

    The *top_k* triples sharing the most words with the question are sent as
    context. Previously the first *top_k* triples of the article were always
    used, whatever the question asked.

    Args:
        question: User question.
        triples: ``(subject, predicate, object)`` tuples for the article.
        top_k: Maximum number of triples included in the prompt.
        model: Perplexity model name forwarded to ``perplexity_chat``.

    Returns:
        The model's answer, or an ``"❌ Error..."`` string from ``perplexity_chat``.
    """
    context_triples = _rank_triples_by_overlap(question, triples)[:top_k]
    context_str = "\n".join(f"({s}, {p}, {o})" for s, p, o in context_triples)
    prompt = (
        "\n"
        "You are a QA assistant.\n"
        "Use the following knowledge graph triples as context to answer the question.\n"
        "Knowledge Graph Triples:\n"
        f"{context_str}\n"
        f"Question: {question}\n"
        "Answer in a clear, concise way. If you don't find enough info in triples,\n"
        "say 'Not found in knowledge graph'.\n"
    )
    return perplexity_chat(prompt, model=model)
# ---------------------------
# STREAMLIT APP
# ---------------------------
def _get_triples_for_title(title):
    """Return the triples for *title*, building and storing them once per session.

    Streamlit re-executes this script on every widget interaction, including
    submitting a question. Without this cache, each question would redo every
    Wikipedia and Perplexity call and re-insert into Neo4j. Non-empty results
    are kept in ``st.session_state``. Empty results are not cached, so a
    transient API failure can be retried.
    """
    if "kg_cache" not in st.session_state:
        st.session_state["kg_cache"] = {}
    cache = st.session_state["kg_cache"]
    if title in cache:
        return cache[title]
    st.write(f"🔍 Building Knowledge Graph for: **{title}** ...")
    triples = build_kg_from_wiki_title(title)
    if triples:
        # Save in Neo4j if configured (only once per built graph).
        if driver:
            insert_triples(triples)
        cache[title] = triples
    return triples


def _render_graph(triples, max_edges=30):
    """Draw the first *max_edges* triples as a labelled directed graph."""
    graph = nx.DiGraph()
    for s, p, o in triples[:max_edges]:
        graph.add_edge(s, o, label=p)
    fig, ax = plt.subplots(figsize=(12, 8))
    # Fixed seed so the layout does not jump around on every rerun.
    pos = nx.spring_layout(graph, k=0.5, seed=42)
    nx.draw(graph, pos, ax=ax, with_labels=True, node_size=2500,
            node_color="lightblue", font_size=10, font_weight="bold",
            arrows=True)
    edge_labels = nx.get_edge_attributes(graph, "label")
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels,
                                 font_size=8, ax=ax)
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; close it so figures do
    # not pile up across reruns.
    plt.close(fig)


st.title("📚 Knowledge Graph Chatbot (Wikipedia + Perplexity)")
# Input for Wikipedia Title
title = st.text_input("Enter a Wikipedia Title (e.g., Harry Potter):").strip()
if title:
    triples = _get_triples_for_title(title)
    if triples:
        st.success(f"Extracted {len(triples)} triples ✅")
        if driver:
            st.info("📑 Triples also stored in Neo4j.")
        # Show sample triples
        st.subheader("Sample Triples")
        st.json(triples[:10])
        # Visualization inside Streamlit
        st.subheader("Graph Visualization")
        _render_graph(triples)
        # Chat interface
        st.subheader("💬 Ask Questions")
        user_question = st.text_input("Your question:")
        if user_question:
            answer = answer_with_kg(user_question, triples)
            st.write("🤖", answer)
    else:
        st.error("Page not found or no triples extracted.")