# File size: 7,295 Bytes
# 4b22ce1 be89e03
from flask import Flask, render_template, request, redirect, url_for, session
import networkx as nx
from pyvis.network import Network
import os, re, pickle
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from docx import Document
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import csv
from flask import Response
import io
# --- Application and model bootstrap ---------------------------------------
# Load .env first so every os.getenv() call below can see values defined there.
load_dotenv()

app = Flask(__name__)
# Prefer a session-signing key supplied through the environment. The literal
# fallback keeps existing deployments working, but it is public in the source
# and should be overridden with SECRET_KEY wherever sessions matter.
app.secret_key = os.getenv("SECRET_KEY") or "secret_key_for_session"

# Checkpoint loaded below as a seq2seq model; extract_triples() parses the
# tagged text it generates.
model_name = "Babelscape/rebel-large"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face access token; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
rebel_tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
rebel_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name, token=HF_TOKEN, low_cpu_mem_usage=True
).to(device)

# Pickle file (resolved against the working directory) that persists the graph.
DB_FILE = "graph_database.pkl"
def save_db(graph):
    """Persist *graph* to ``DB_FILE`` with pickle.

    The data is first written to a sibling temporary file and then moved over
    ``DB_FILE`` with ``os.replace``, so an interrupted write cannot leave a
    truncated database behind.

    Args:
        graph: The object to serialise (the app passes its ``nx.DiGraph``).
    """
    tmp_path = DB_FILE + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(graph, f)
        # Readers see either the previous file or the complete new one.
        os.replace(tmp_path, DB_FILE)
    finally:
        # After a successful replace the temp file no longer exists; after a
        # failure, remove the partial file so it does not linger.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def load_db():
    """Load the persisted graph from ``DB_FILE``.

    Returns:
        The unpickled graph, or a fresh empty ``nx.DiGraph`` when the file is
        missing, cannot be read/unpickled, or does not hold a directed graph.
    """
    if not os.path.exists(DB_FILE):
        return nx.DiGraph()
    try:
        # NOTE(security): unpickling can execute arbitrary code. DB_FILE must
        # only ever be a file written by this application itself.
        with open(DB_FILE, "rb") as f:
            graph = pickle.load(f)
    except Exception as exc:
        # Best effort: a damaged database must not stop start-up, but report it
        # instead of discarding it silently. Unlike a bare ``except`` this lets
        # KeyboardInterrupt and SystemExit propagate.
        print(f"Could not load {DB_FILE} ({exc!r}); starting with an empty graph")
        return nx.DiGraph()
    if not isinstance(graph, nx.DiGraph):
        # The rest of the app calls DiGraph methods (successors, predecessors).
        print(f"{DB_FILE} does not contain a DiGraph; starting with an empty graph")
        return nx.DiGraph()
    return graph
# Module-level knowledge graph shared by all routes; populated from DB_FILE
# when the module is imported (load_db() yields an empty DiGraph otherwise).
G = load_db()
def _parse_rebel_output(decoded):
    """Parse REBEL's linearised output into ``(subject, relation, object)`` tuples.

    The expected layout is
    ``<triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation]...``
    where one ``<triplet>`` head may be followed by several tail/relation
    pairs. Only triples whose three parts are all non-empty are returned.
    """
    for marker in ("<s>", "</s>", "<pad>"):
        decoded = decoded.replace(marker, "")
    # Pad the structural tags with spaces so split() isolates them even when
    # they are glued to a neighbouring word.
    for tag in ("<triplet>", "<subj>", "<obj>"):
        decoded = decoded.replace(tag, f" {tag} ")

    triples = []
    parts = {"subject": [], "relation": [], "object": []}
    field = None  # which part ordinary tokens are currently appended to

    def emit():
        if all(parts.values()):
            triples.append(
                tuple(" ".join(parts[key]) for key in ("subject", "relation", "object"))
            )

    for token in decoded.split():
        if token == "<triplet>":
            # New head entity: flush the pending triple and start from scratch.
            emit()
            for words in parts.values():
                words.clear()
            field = "subject"
        elif token == "<subj>":
            # A repeated <subj> opens another tail/relation pair for the same
            # head, so flush the previous pair instead of appending to it.
            emit()
            parts["object"].clear()
            parts["relation"].clear()
            field = "object"
        elif token == "<obj>":
            parts["relation"].clear()
            field = "relation"
        elif field is not None:
            # Tokens seen before the first tag are ignored.
            parts[field].append(token)
    emit()
    return triples


def extract_triples(text):
    """Extract relation triples from *text* with the REBEL model.

    Args:
        text: Input passage. It is truncated to 256 tokens before generation;
            empty or whitespace-only input yields no triples.

    Returns:
        A list of ``(subject, relation, object)`` string tuples, possibly empty.
    """
    if not text or not text.strip():
        return []
    inputs = rebel_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=256
    ).to(device)
    gen_kwargs = {"max_length": 128, "length_penalty": 0, "num_beams": 1, "num_return_sequences": 1}
    generated_tokens = rebel_model.generate(**inputs, **gen_kwargs)
    # Special tokens are kept because the <triplet>/<subj>/<obj> markers are
    # what the parser relies on.
    decoded = rebel_tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)[0]
    return _parse_rebel_output(decoded)
def visualize_graph():
    """Render the global graph ``G`` to ``static/graph.html`` next to this file.

    The ``static`` directory is created when it does not exist yet. Every node
    and edge of ``G`` is copied into a pyvis ``Network``; an edge without a
    ``label`` attribute is drawn with an empty label.
    """
    # Anchor paths at this module's location instead of the working directory.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    static_dir = os.path.join(module_dir, 'static')
    if not os.path.exists(static_dir):
        os.makedirs(static_dir)

    net = Network(
        height="600px",
        width="100%",
        directed=True,
        bgcolor="#ffffff",
        font_color="black",
        cdn_resources='remote',
    )
    net.force_atlas_2based(
        gravity=-50,
        central_gravity=0.01,
        spring_length=150,
        damping=0.4,
    )

    for name in G.nodes():
        net.add_node(
            name,
            label=name,
            color="#00d2ff",
            size=25,
            shadow={'enabled': True, 'color': 'rgba(0,210,255,0.6)', 'size': 10},
        )

    for head, tail, attrs in G.edges(data=True):
        edge_label = attrs.get("label", "")
        net.add_edge(head, tail, label=edge_label, color="#a29bfe")

    # Write the page into the static directory resolved above.
    net.save_graph(os.path.join(static_dir, "graph.html"))
def _read_upload(upload):
    """Return the text of an uploaded .pdf/.docx/.txt file, or "" if unsupported."""
    ext = upload.filename.rsplit('.', 1)[-1].lower()
    if ext == "pdf":
        # Guard against pages for which no text comes back.
        return " ".join(page.extract_text() or "" for page in PdfReader(upload).pages)
    if ext == "docx":
        return " ".join(p.text for p in Document(upload).paragraphs)
    if ext == "txt":
        # Replace undecodable bytes instead of failing the whole request.
        return upload.read().decode("utf-8", errors="replace")
    return ""


def _search_graph(user_query):
    """Build the HTML answer listing up to 8 edges that touch matching nodes."""
    from html import escape

    # Only words longer than 3 characters are used as search keywords.
    keywords = [w.lower() for w in user_query.split() if len(w) > 3]
    results = []
    for node in G.nodes():
        if not any(k in node.lower() for k in keywords):
            continue
        # Node names and labels originate from user-supplied documents, so they
        # are escaped before being placed inside markup.
        # NOTE(review): assumes the template renders `answer` as raw HTML
        # (e.g. with |safe), as the <b>/<br> tags imply -- confirm.
        for n in G.successors(node):
            label = G[node][n].get('label', '')
            results.append(f"<b>{escape(node)}</b> {escape(label)} <b>{escape(n)}</b>")
        for p in G.predecessors(node):
            label = G[p][node].get('label', '')
            results.append(f"<b>{escape(p)}</b> {escape(label)} <b>{escape(node)}</b>")
    if not results:
        return f"Nothing found for '{escape(user_query)}'."
    # dict.fromkeys de-duplicates while keeping a stable, repeatable order.
    unique = list(dict.fromkeys(results))[:8]
    return " • " + "<br> • ".join(unique)


@app.route("/", methods=["GET", "POST"])
def index():
    """Main page: ingest new text into the graph and/or answer a search query.

    On POST, text is taken from an uploaded file (pdf/docx/txt) or, failing
    that, from the ``text`` form field. Newly supplied text is split into
    sentences, run through ``extract_triples`` and added to the global graph,
    which is then saved and re-rendered. A ``query`` form field triggers a
    keyword search over node names instead of ingestion.

    Returns:
        The rendered ``index.html`` template.
    """
    answer = None
    user_query = ""
    text = session.get('user_text', "")

    if request.method == "POST":
        is_query = "query" in request.form

        # 1. Collect text supplied with THIS request (upload wins over the box).
        new_text = ""
        upload = request.files.get("file")
        if upload is not None and upload.filename:
            new_text = _read_upload(upload)
        elif request.form.get("text", "").strip():
            new_text = request.form["text"]
        if new_text:
            text = new_text

        # 2. Extract triples only from newly supplied text; text remembered in
        #    the session was already processed and must not be re-run.
        if new_text and not is_query:
            # NOTE(review): with Flask's default cookie-backed session a large
            # text may exceed the browser's cookie size limit -- confirm.
            session['user_text'] = new_text
            sentences = [s.strip() for s in re.split(r'[\n.!?]', new_text) if len(s.strip()) > 10]
            print(f"--- 🚀 AI is extracting from {len(sentences)} sentences ---")
            for i, sent in enumerate(sentences):
                print(f"📄 Processing {i+1}/{len(sentences)}...")
                for s, r, o in extract_triples(sent):
                    G.add_edge(s.title().strip(), o.title().strip(), label=r.strip())
            save_db(G)
            visualize_graph()

        # 3. Answer a search query.
        if is_query:
            user_query = request.form["query"].strip()
            answer = _search_graph(user_query)

    # Look for graph.html where visualize_graph() writes it (next to this
    # module), independent of the current working directory.
    graph_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "static", "graph.html"
    )
    db_triples = [{"s": s, "r": d.get('label', ''), "o": t} for s, t, d in G.edges(data=True)]
    return render_template(
        "index.html",
        answer=answer,
        graph=os.path.exists(graph_path),
        user_query=user_query,
        user_text=text,
        db_triples=db_triples,
    )
@app.route("/export_csv")
def export_csv():
    """Serve every edge of the global graph ``G`` as a CSV download.

    The columns are Subject, Relationship and Object; an edge without a
    ``label`` attribute gets an empty relationship cell.

    Returns:
        A ``text/csv`` response sent as the attachment ``knowledge_graph.csv``.
    """
    buffer = io.StringIO()
    csv_out = csv.writer(buffer)

    # Header row first, then one row per edge.
    csv_out.writerow(['Subject', 'Relationship', 'Object'])
    csv_out.writerows(
        [src, attrs.get('label', ''), dst]
        for src, dst, attrs in G.edges(data=True)
    )

    # Rewind so the response is read from the first row.
    buffer.seek(0)
    download_headers = {"Content-disposition": "attachment; filename=knowledge_graph.csv"}
    return Response(buffer, mimetype="text/csv", headers=download_headers)
@app.route("/clear")
def clear_db():
    """Reset the app: empty the graph, clear the session, delete saved files.

    Returns:
        A redirect to the index page.
    """
    global G
    G = nx.DiGraph()
    session.clear()

    # visualize_graph() writes graph.html next to this module, so resolve the
    # same absolute location here instead of relying on the working directory.
    graph_html = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "static", "graph.html"
    )
    for path in (DB_FILE, graph_html):
        try:
            os.remove(path)
        except FileNotFoundError:
            # Nothing to delete; other OS errors still propagate.
            pass
    return redirect(url_for('index'))
if __name__ == "__main__":
    # Bind to all interfaces (0.0.0.0) so the server is reachable from outside
    # the host, on port 7860.
    app.run(host="0.0.0.0", port=7860)