File size: 5,290 Bytes
e76c3e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a18334
e76c3e2
 
5a18334
 
 
e76c3e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f32550
e76c3e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# -*- coding: utf-8 -*-
"""Learn with Knowledge Graphs.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/16UX6wbUmaLG6YBJKzH5YouNYYnw2mL8H
"""

# app.py
import json
import os
import re

import matplotlib.pyplot as plt
import networkx as nx
import requests
import streamlit as st
import wikipediaapi
from neo4j import GraphDatabase

# ---------------------------
# CONFIGURATION
# ---------------------------
# Credentials come from environment variables rather than source code: secrets
# committed to a repository (or a shared notebook) must be considered leaked.
# NOTE(review): the Perplexity key and Neo4j password that used to be
# hard-coded here should be revoked/rotated.

# API Key for Perplexity (needed for triple extraction and question answering).
PPLX_API_KEY = os.environ.get("PPLX_API_KEY", "")

# Optional Neo4j credentials (leave NEO4J_URI / NEO4J_PASSWORD unset if not
# using Neo4j; the driver is only created when all three are non-empty).
NEO4J_URI = os.environ.get("NEO4J_URI", "")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "")


@st.cache_resource
def _get_neo4j_driver(uri, user, password):
    """Return a Neo4j driver for the given credentials, created once per process.

    Streamlit re-executes this script on every widget interaction; caching the
    driver keeps each rerun from opening a fresh, never-closed connection pool.
    """
    return GraphDatabase.driver(uri, auth=(user, password))


driver = None
if NEO4J_URI and NEO4J_USER and NEO4J_PASSWORD:
    driver = _get_neo4j_driver(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)


# ---------------------------
# FUNCTIONS
# ---------------------------
def perplexity_chat(prompt, model="sonar-medium-online", timeout=60):
    """Send *prompt* as a single user message to Perplexity and return the reply.

    Args:
        prompt: Text of the user message.
        model: Perplexity model name. NOTE(review): "sonar-medium-online" looks
            like a retired model id; confirm against Perplexity's current model
            list (e.g. "sonar") before relying on the default.
        timeout: Seconds to wait for the HTTP exchange before giving up.

    Returns:
        The assistant's message content on success. Failures are reported as a
        string beginning with "❌" (non-200 status, network error, or an
        unexpected response body) instead of raising, so callers can display
        it directly.
    """
    url = "https://api.perplexity.ai/chat/completions"
    headers = {
        "Authorization": f"Bearer {PPLX_API_KEY}",
        "Content-Type": "application/json",
    }
    data = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
    }
    try:
        # Without a timeout a stalled connection blocks the Streamlit run forever.
        resp = requests.post(
            url, headers=headers, data=json.dumps(data), timeout=timeout
        )
    except requests.RequestException as exc:
        return f"❌ Request failed: {exc}"
    if resp.status_code != 200:
        return f"❌ Error {resp.status_code}: {resp.text}"
    try:
        return resp.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        # Body was not JSON or lacked the expected choices/message structure.
        return f"❌ Unexpected response format: {exc!r}"


# Leading list markers LLMs tend to prepend ("1.", "2)", "-", "*", "•").
_LIST_MARKER_RE = re.compile(r"^\s*(?:\d+[.)]|[-*•])\s+")


def _parse_triples(content, max_triples):
    """Parse ``(subject, predicate, object)`` lines out of free-form LLM output.

    Each line loses a leading list marker and surrounding brackets / markdown
    emphasis, then is split on commas. Only lines giving exactly three
    non-empty parts are kept, and at most *max_triples* triples are returned.
    """
    triples = []
    for line in content.splitlines():
        line = _LIST_MARKER_RE.sub("", line).strip(" ()[]{}*`")
        if not line:
            continue
        parts = [part.strip(" \t\"'*`") for part in line.split(",")]
        # A comma inside an element makes the split ambiguous; skip such lines.
        if len(parts) == 3 and all(parts):
            triples.append(tuple(parts))
    return triples[: max(max_triples, 0)]


def extract_triples_from_chunk(text, max_triples=5):
    """Ask the LLM for subject-predicate-object triples contained in *text*.

    Args:
        text: Source text (one chunk of a Wikipedia article in this app).
        max_triples: Upper bound requested from the model and also enforced
            on the parsed result.

    Returns:
        A list of ``(subject, predicate, object)`` string tuples; empty when
        the API call failed or no parseable triple came back.
    """
    prompt = f"""Extract up to {max_triples} subject-predicate-object triples
from the text below. Return only triples in the format (subject, predicate, object).

Text: {text}"""

    content = perplexity_chat(prompt)
    # perplexity_chat reports failures as a "❌ ..." string instead of raising;
    # parsing that message would inject garbage triples into the graph.
    if content.startswith("❌"):
        return []
    return _parse_triples(content, max_triples)


def build_kg_from_wiki_title(
    title,
    lang="en",
    chunk_chars=800,
    max_triples_per_chunk=5,
    user_agent="MyKGApp/1.0 (contact: your_email@example.com)",
    max_chunks=None,
):
    """Fetch a Wikipedia article and extract triples from it chunk by chunk.

    The article text is cut into consecutive fixed-size character windows and
    each window is passed to ``extract_triples_from_chunk`` (one LLM call per
    chunk).

    Args:
        title: Article title to look up.
        lang: Wikipedia language edition code (e.g. "en").
        chunk_chars: Characters per chunk; must be at least 1.
        max_triples_per_chunk: Forwarded to ``extract_triples_from_chunk``.
        user_agent: HTTP User-Agent sent to the Wikipedia API. The default
            contains a placeholder contact address that should be replaced.
        max_chunks: Optional cap on how many chunks (and therefore LLM calls)
            are processed; ``None`` processes the whole article.

    Returns:
        List of ``(subject, predicate, object)`` tuples; empty if the page
        does not exist.

    Raises:
        ValueError: If ``chunk_chars`` is less than 1.
    """
    if chunk_chars < 1:
        raise ValueError(f"chunk_chars must be >= 1, got {chunk_chars!r}")

    # Pass both by keyword: recent wikipedia-api releases take user_agent as
    # the first positional parameter, so a positional language code would
    # collide with the user_agent keyword.
    wiki = wikipediaapi.Wikipedia(user_agent=user_agent, language=lang)
    page = wiki.page(title)
    if not page.exists():
        return []

    text = page.text
    starts = range(0, len(text), chunk_chars)
    if max_chunks is not None:
        starts = starts[: max(max_chunks, 0)]

    triples = []
    for start in starts:
        chunk = text[start:start + chunk_chars]
        triples.extend(
            extract_triples_from_chunk(chunk, max_triples=max_triples_per_chunk)
        )
    return triples


def insert_triple(tx, subject, predicate, obj):
    """Upsert a single triple through the transaction *tx*.

    Subject and object are MERGEd as ``Entity`` nodes keyed by ``name`` and
    joined by a ``RELATION`` relationship whose ``type`` property carries the
    predicate; MERGE matches existing nodes/relationships before creating new
    ones.
    """
    cypher = """
        MERGE (s:Entity {name: $subject})
        MERGE (o:Entity {name: $object})
        MERGE (s)-[:RELATION {type: $predicate}]->(o)
        """
    params = {"subject": subject, "predicate": predicate, "object": obj}
    tx.run(cypher, **params)

def _insert_triple_batch(tx, triples):
    """Write each ``(s, p, o)`` in *triples* via ``insert_triple`` on *tx*."""
    for s, p, o in triples:
        insert_triple(tx, s, p, o)


def insert_triples(triples):
    """Store *triples* in Neo4j using one write transaction.

    Does nothing when Neo4j is not configured (module-level ``driver`` is
    falsy) or when there are no triples. A single transaction makes the batch
    all-or-nothing and avoids one commit round-trip per triple.
    """
    if not driver or not triples:
        return
    # execute_write may re-invoke its work function on transient errors, so it
    # needs a re-iterable sequence rather than a one-shot iterator.
    batch = list(triples)
    with driver.session() as session:
        session.execute_write(_insert_triple_batch, batch)


_WORD_RE = re.compile(r"\w+")
# Frequent question words that would otherwise dominate the overlap score.
_STOPWORDS = frozenset({
    "the", "and", "for", "with", "who", "what", "which", "when", "where",
    "why", "how", "was", "were", "are", "did", "does", "has", "have",
})


def _rank_triples(question, triples, top_k):
    """Order *triples* by word overlap with *question*; keep the first *top_k*.

    Ties (including "no overlap at all") keep their input order, so when no
    question word matches this is equivalent to ``triples[:top_k]``.
    """
    question_words = {
        word for word in _WORD_RE.findall(question.lower())
        if len(word) > 2 and word not in _STOPWORDS
    }

    def overlap(triple):
        triple_words = set(_WORD_RE.findall(" ".join(map(str, triple)).lower()))
        return len(question_words & triple_words)

    # sorted() stays stable with reverse=True, preserving input order on ties.
    return sorted(triples, key=overlap, reverse=True)[:top_k]


def answer_with_kg(question, triples, top_k=10, model="sonar-medium-online"):
    """Answer *question* with the LLM, grounded on the most relevant triples.

    Args:
        question: Free-text user question.
        triples: Sequence of ``(subject, predicate, object)`` tuples.
        top_k: Maximum number of triples placed in the prompt; triples sharing
            more words with the question are preferred.
        model: Perplexity model name forwarded to ``perplexity_chat``.

    Returns:
        The model's answer text, or the "❌ ..." error string produced by
        ``perplexity_chat``.
    """
    context_triples = _rank_triples(question, triples, top_k)
    context_str = "\n".join(f"({s}, {p}, {o})" for s, p, o in context_triples)

    prompt = f"""
    You are a QA assistant.
    Use the following knowledge graph triples as context to answer the question.

    Knowledge Graph Triples:
    {context_str}

    Question: {question}

    Answer in a clear, concise way. If you don't find enough info in triples,
    say 'Not found in knowledge graph'.
    """
    return perplexity_chat(prompt, model=model)


# ---------------------------
# STREAMLIT APP
# ---------------------------
@st.cache_data(show_spinner=False, ttl=3600)
def _cached_kg(title):
    """Return ``build_kg_from_wiki_title(title)``, memoized per title for an hour.

    Streamlit re-executes the whole script on every widget interaction (for
    example each question typed below). Without this cache every question
    would re-download the article and repeat one LLM extraction call per chunk.
    """
    return build_kg_from_wiki_title(title)


def _render_graph(triples, limit=30):
    """Draw the first *limit* triples as a labelled directed graph in the page."""
    G = nx.DiGraph()
    for s, p, o in triples[:limit]:
        G.add_edge(s, o, label=p)

    fig, ax = plt.subplots(figsize=(12, 8))
    pos = nx.spring_layout(G, k=0.5)
    nx.draw(G, pos, ax=ax, with_labels=True, node_size=2500,
            node_color="lightblue", font_size=10, font_weight="bold",
            arrows=True)
    edge_labels = nx.get_edge_attributes(G, "label")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,
                                 font_size=8, ax=ax)
    # Hand Streamlit this figure explicitly (not the pyplot module and its
    # implicit "current figure"), then close it so figures don't accumulate
    # across reruns.
    st.pyplot(fig)
    plt.close(fig)


st.title("📚 Knowledge Graph Chatbot (Wikipedia + Perplexity)")

# Input for Wikipedia Title (surrounding whitespace is ignored).
title = st.text_input("Enter a Wikipedia Title (e.g., Harry Potter):").strip()

if title:
    st.write(f"🔍 Building Knowledge Graph for: **{title}** ...")
    triples = _cached_kg(title)

    if triples:
        st.success(f"Extracted {len(triples)} triples ✅")

        # Save in Neo4j if configured -- once per title per browser session,
        # because every rerun would otherwise re-send all triples.
        if driver:
            if "neo4j_stored_titles" not in st.session_state:
                st.session_state["neo4j_stored_titles"] = set()
            stored_titles = st.session_state["neo4j_stored_titles"]
            if title not in stored_titles:
                insert_triples(triples)
                stored_titles.add(title)
            st.info("📑 Triples also stored in Neo4j.")

        # Show sample triples
        st.subheader("Sample Triples")
        st.json(triples[:10])

        # Visualization inside Streamlit
        st.subheader("Graph Visualization")
        _render_graph(triples)

        # Chat interface
        st.subheader("💬 Ask Questions")
        user_question = st.text_input("Your question:")
        if user_question:
            answer = answer_with_kg(user_question, triples)
            st.write("🤖", answer)
    else:
        st.error("Page not found or no triples extracted.")