# uni_qa / app.py
# Uploaded by avojarot (commit f0754b4, verified)
#!/usr/bin/env python3
"""GraphRAG v4 - Gradio Chat Demo (Pre-built graph + index support)
Startup behavior:
1. Check PREBUILT_DIR (default ./data/prebuilt) for pre-built artifacts
2. If found β†’ load graph + index instantly (no ML models needed for search)
3. If not found β†’ show upload panel for on-the-fly building
4. Users can always upload additional PDFs to rebuild
The pre-built artifacts are created by prebuild.py running offline.
"""
import json
import os
import shutil
import time
import logging
from pathlib import Path
from typing import List, Tuple, Optional
import gradio as gr
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("graphrag")
# ── Configuration ────────────────────────────────────────────────────
PREBUILT_DIR = Path(os.environ.get("PREBUILT_DIR", "./data/prebuilt"))
CORPUS_DIR = Path("/tmp/graphrag_corpus")
INDEX_DIR = Path("/tmp/graphrag_index")
# ── Global state ─────────────────────────────────────────────────────
_builder = None
_retriever = None
_qa = None
_system_ready = False
_load_mode = "none" # "prebuilt", "uploaded", "none"
# ── Pre-built loading (fast path β€” no ML models) ────────────────────
def try_load_prebuilt() -> Optional[str]:
    """Attempt to load pre-built graph and index. Returns status message or None.

    Fast path: the graph is plain JSONL and the index deserializes from
    FAISS/pickle, so no ML models are required. On success the module
    globals (_builder, _retriever, _qa, _system_ready, _load_mode) are
    populated and a human-readable status string is returned; on any
    failure None is returned so the caller can fall back to the
    upload-and-build flow.
    """
    global _builder, _retriever, _qa, _system_ready, _load_mode
    graph_dir = PREBUILT_DIR / "graph"
    index_dir = PREBUILT_DIR / "index"
    # Cheap existence checks before importing the heavy project modules.
    if not (graph_dir / "graph_nodes.jsonl").exists():
        logger.info(f"No prebuilt graph at {graph_dir}")
        return None
    if not (index_dir / "dense_index.faiss").exists():
        logger.info(f"No prebuilt index at {index_dir}")
        return None
    start = time.time()
    logger.info("Loading pre-built graph and index...")
    try:
        from graphrag_v4.graph_builder import KnowledgeGraphBuilder
        from graphrag_v4.retriever import HybridRetriever
        from graphrag_v4.qa import GraphRAGQA, LLMClient
        # Load graph (fast: just JSON parsing, no PageRank/Leiden/embedding)
        _builder = KnowledgeGraphBuilder()
        if not _builder.load_graph(graph_dir):
            logger.error("Failed to load pre-built graph")
            return None
        stats = _builder.get_stats()
        logger.info(f" Graph: {stats['total_nodes']} nodes, {stats['total_edges']} edges, "
                    f"{stats['communities']} communities")
        # Load index (fast: FAISS deserialize + pickle)
        _retriever = HybridRetriever(graph=_builder.graph, communities=_builder.communities)
        _retriever.load(index_dir)
        # Rebuild node→community mapping from loaded communities (apparently
        # not restored by _retriever.load itself — confirm against retriever).
        for comm in _builder.communities.values():
            for nid in comm.node_ids:
                _retriever.node_to_community[nid] = comm.community_id
        logger.info(f" Index: {len(_retriever.dense_index.doc_ids)} vectors, "
                    f"{len(_retriever.sparse_index.sparse_vectors)} sparse docs")
        # Initialize QA; the LLM is optional — search still works without it.
        llm = None
        api_key = os.environ.get("OPENAI_API_KEY")
        if api_key:
            llm = LLMClient(api_key=api_key)
            if not llm.available:
                llm = None
        _qa = GraphRAGQA(retriever=_retriever, llm_client=llm)
        _system_ready = True
        _load_mode = "prebuilt"
        elapsed = time.time() - start
        status = (
            f"βœ… Pre-built graph loaded in {elapsed:.1f}s\n"
            f"πŸ“Š Nodes: {stats['total_nodes']} | Edges: {stats['total_edges']} | "
            f"Communities: {stats['communities']} | Cross-doc: {stats.get('cross_doc_entities', 0)}\n"
            f"πŸ” Index: {len(_retriever.dense_index.doc_ids)} chunks indexed\n"
            f"πŸ€– LLM: {'connected' if llm and llm.available else 'off (set OPENAI_API_KEY)'}"
        )
        logger.info(f"Pre-built system ready in {elapsed:.1f}s")
        return status
    except Exception:
        # FIX: logger.exception records the full traceback; the previous
        # f-string logged only str(e) and hid *where* loading failed.
        logger.exception("Failed to load prebuilt")
        return None
# ── Upload-based building (slow path β€” needs ML models) ─────────────
def process_uploads(files, progress=gr.Progress()) -> str:
    """Process uploaded PDFs → build corpus → build KG → index.

    Slow path: runs the full pipeline (corpus extraction, graph
    construction, community detection, hybrid indexing) on freshly
    uploaded files, mutating the module globals so the chat tab can
    serve queries afterwards.

    Args:
        files: Gradio upload objects (tempfile wrappers with .name) or
            plain path strings.
        progress: Gradio progress tracker (injected by Gradio per call;
            the gr.Progress() default is the documented Gradio idiom).

    Returns:
        A multi-line build log on success, or an error message starting
        with ❌ on failure.
    """
    global _builder, _retriever, _qa, _system_ready, _load_mode
    if not files:
        return "❌ No files uploaded."
    # Mark the system unavailable while rebuilding so chat refuses queries.
    _system_ready = False
    start = time.time()
    log_lines = []

    # Accumulate UI-visible log lines; joined and returned at the end.
    def log(msg):
        log_lines.append(msg)

    try:
        # Start from a clean slate: drop any previous corpus/index on disk.
        if CORPUS_DIR.exists():
            shutil.rmtree(CORPUS_DIR)
        if INDEX_DIR.exists():
            shutil.rmtree(INDEX_DIR)
        pdf_dir = CORPUS_DIR / "pdfs"
        pdf_dir.mkdir(parents=True, exist_ok=True)
        # Copy only .pdf uploads into the working directory.
        pdf_count = 0
        for f in files:
            # Gradio may hand back tempfile wrappers (.name) or raw paths.
            src = Path(f.name) if hasattr(f, 'name') else Path(f)
            if src.suffix.lower() == ".pdf":
                dst = pdf_dir / src.name
                shutil.copy2(str(src), str(dst))
                pdf_count += 1
                log(f"πŸ“„ {src.name}")
        if pdf_count == 0:
            return "❌ No PDF files found in upload."
        log(f"\nπŸ”§ Processing {pdf_count} PDF(s)...")
        progress(0.1, desc="Building corpus...")
        # Imports are deferred so the heavy ML stack loads only on this path.
        from graphrag_v4.corpus_builder import build_corpus
        corpus_out = CORPUS_DIR / "output"
        build_corpus(input_path=pdf_dir, output_dir=corpus_out, max_chunk_tokens=384)
        # stats.json is treated as the success marker for corpus building.
        stats_file = corpus_out / "stats.json"
        if stats_file.exists():
            stats = json.loads(stats_file.read_text())
            log(f" Chunks: {stats['total_chunks']}, Entities: {stats['total_entities']}, Relations: {stats['total_relations']}")
        else:
            return "❌ Corpus building failed β€” no output produced."
        progress(0.4, desc="Building knowledge graph...")
        from graphrag_v4.graph_builder import KnowledgeGraphBuilder
        # Full graph build: edges, cross-doc links, PageRank, communities.
        _builder = KnowledgeGraphBuilder()
        _builder.load_corpus(corpus_out)
        _builder.build_cooccurrence_edges()
        cross_doc = _builder.build_cross_document_edges()
        _builder.compute_pagerank()
        _builder.detect_communities(n_levels=2)
        _builder.generate_community_summaries()
        kg_stats = _builder.get_stats()
        log(f"\nπŸ•ΈοΈ Knowledge Graph:")
        log(f" Nodes: {kg_stats['total_nodes']}, Edges: {kg_stats['total_edges']}")
        log(f" Communities: {kg_stats['communities']}, Cross-doc edges: {cross_doc}")
        progress(0.7, desc="Building search index...")
        from graphrag_v4.retriever import HybridRetriever
        _retriever = HybridRetriever(graph=_builder.graph, communities=_builder.communities)
        # Re-read the chunk texts produced by build_corpus for indexing.
        chunks = []
        chunks_file = corpus_out / "chunks.jsonl"
        if chunks_file.exists():
            with open(chunks_file, "r", encoding="utf-8") as fh:
                for line in fh:
                    chunks.append(json.loads(line))
        _retriever.index_chunks(chunks)
        _retriever.index_communities(_builder.communities)
        progress(0.9, desc="Initializing QA...")
        from graphrag_v4.qa import GraphRAGQA, LLMClient
        # LLM is optional: only wired up when an API key is configured
        # and the client reports itself available.
        llm = None
        api_key = os.environ.get("OPENAI_API_KEY")
        if api_key:
            llm = LLMClient(api_key=api_key)
            if llm.available:
                log(f"\nπŸ€– LLM: {llm.model} connected")
            else:
                llm = None
        _qa = GraphRAGQA(retriever=_retriever, llm_client=llm)
        _system_ready = True
        _load_mode = "uploaded"
        elapsed = time.time() - start
        log(f"\nβœ… Ready in {elapsed:.1f}s")
        progress(1.0, desc="Done!")
        return "\n".join(log_lines)
    except Exception as e:
        # Surface the full traceback in the UI log instead of crashing Gradio.
        import traceback
        return f"❌ Error: {str(e)}\n{traceback.format_exc()}"
# ── Chat ─────────────────────────────────────────────────────────────
def chat_respond(message: str, history: List) -> Tuple[List, str]:
    """Answer *message* via GraphRAG and append the exchange to *history*.

    History uses the modern Gradio 'messages' format (a list of
    role/content dicts). Always returns (updated_history, "") — the
    empty string clears the input textbox.
    """
    # The user's turn is recorded on every path, including errors.
    history.append({"role": "user", "content": message})
    if not _system_ready or _qa is None:
        history.append({
            "role": "assistant",
            "content": "⚠️ System not ready. Please upload PDF files or wait for the pre-built graph to finish loading."
        })
        return history, ""
    try:
        t0 = time.time()
        # Perform GraphRAG QA
        result = _qa.answer(message, top_k=8, use_communities=True)
        took = time.time() - t0
        # Assemble the answer text plus a sources/confidence footer.
        parts = [result.answer]
        if result.sources:
            parts.append("\n\n---\n**πŸ“š Sources:**")
            for rank, src in enumerate(result.sources[:5], 1):
                per_channel = [
                    f"{label}={value:.2f}"
                    for label, value in (
                        ("dense", src.dense_score),
                        ("sparse", src.sparse_score),
                        ("graph", src.graph_score),
                    )
                    if value > 0
                ]
                suffix = f" ({', '.join(per_channel)})" if per_channel else ""
                parts.append(
                    f"{rank}. **{src.title[:60]}** β€” p.{src.page} β€” score: {src.score:.4f}{suffix}"
                )
        parts.append(
            f"\n*Confidence: {result.confidence:.0%} | Time: {took:.2f}s | {', '.join(result.reasoning)}*"
        )
        history.append({"role": "assistant", "content": "\n".join(parts)})
    except Exception as e:
        history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
    # Second element clears the input textbox in the UI.
    return history, ""
def get_graph_info() -> str:
    """Render the current knowledge-graph statistics as Markdown."""
    if not _system_ready or _builder is None:
        return "No graph loaded."
    stats = _builder.get_stats()
    out = [
        f"## πŸ“Š Knowledge Graph Statistics\n",
        f"**Load mode:** {_load_mode}",
        f"**Nodes:** {stats['total_nodes']}",
    ]
    # Per-type breakdowns are optional keys in the stats dict.
    for node_type, count in (stats.get('node_types') or {}).items():
        out.append(f" - {node_type}: {count}")
    out.append(f"\n**Edges:** {stats['total_edges']}")
    for edge_type, count in (stats.get('edge_types') or {}).items():
        out.append(f" - {edge_type}: {count}")
    out.append(f"\n**Communities:** {stats['communities']}")
    out.append(f"**Cross-doc entities:** {stats.get('cross_doc_entities', 0)}")
    if _builder.communities:
        out.append("\n## 🏘️ Top Communities\n")
        # Show the five largest communities with up to three key entities each.
        largest = sorted(_builder.communities.values(), key=lambda c: c.size, reverse=True)[:5]
        for comm in largest:
            names = ", ".join(comm.key_entities[:3]) if comm.key_entities else "β€”"
            out.append(f"**{comm.community_id}** ({comm.size} nodes): {names}")
    return "\n".join(out)
# ── Gradio UI ────────────────────────────────────────────────────────
def build_ui():
    """Construct and return the Gradio Blocks demo.

    Attempts the pre-built fast path first (try_load_prebuilt); the layout
    then adapts to the result: status-box styling, whether the upload
    accordion starts open, and the chat placeholder text.
    """
    # Try loading prebuilt on startup
    prebuilt_status = try_load_prebuilt()
    with gr.Blocks(
        title="GraphRAG v4 β€” Cross-Document Knowledge Graph QA",
        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate"),
        css="""
        .main-header { text-align: center; margin-bottom: 1rem; }
        .status-ready { background: #e8f5e9; border-radius: 8px; padding: 0.75rem; font-family: monospace; font-size: 0.85rem; white-space: pre-wrap; }
        .status-waiting { background: #fff3e0; border-radius: 8px; padding: 0.75rem; }
        footer { display: none !important; }
        """
    ) as demo:
        gr.Markdown(
            "# πŸ•ΈοΈ GraphRAG v4 β€” Cross-Document Knowledge Graph QA\n"
            "**Pipeline:** PDF β†’ Chunks β†’ GLiNER Entities β†’ Proximity Relations β†’ Cross-Doc Linking β†’ Leiden Communities β†’ BGE-M3 Hybrid Search β†’ Answer",
            elem_classes="main-header",
        )
        with gr.Row():
            # Left panel: status, upload, graph stats
            with gr.Column(scale=1):
                # System status — styling depends on whether prebuilt loaded.
                if prebuilt_status:
                    gr.Markdown("### βœ… System Status")
                    status_box = gr.Textbox(
                        value=prebuilt_status,
                        label="",
                        lines=5,
                        interactive=False,
                        elem_classes="status-ready",
                    )
                else:
                    gr.Markdown("### ⏳ System Status")
                    status_box = gr.Textbox(
                        value="No pre-built data found. Upload PDFs to get started.\n\n"
                        f"Looking for: {PREBUILT_DIR}/graph/\n"
                        "Tip: Run prebuild.py offline to create instant-load artifacts.",
                        label="",
                        lines=5,
                        interactive=False,
                        elem_classes="status-waiting",
                    )
                # Upload accordion opens by default only when nothing is loaded.
                with gr.Accordion("πŸ“ Upload New Documents", open=not bool(prebuilt_status)):
                    gr.Markdown(
                        "*Upload PDFs to build a new knowledge graph. "
                        "This replaces any pre-loaded data.*" if prebuilt_status
                        else "*Upload PDFs to build the knowledge graph.*"
                    )
                    upload = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDFs")
                    build_btn = gr.Button("πŸ”§ Build Knowledge Graph", variant="primary", size="lg")
                    build_log = gr.Textbox(label="Build Log", lines=10, interactive=False)
                with gr.Accordion("πŸ“Š Graph Info", open=bool(prebuilt_status)):
                    graph_info = gr.Markdown(get_graph_info() if prebuilt_status else "No graph loaded.")
                    refresh_btn = gr.Button("πŸ”„ Refresh Stats", size="sm")
            # Right panel: Chat
            with gr.Column(scale=2):
                gr.Markdown("### πŸ’¬ Ask Questions")
                # FIX: chat_respond appends role/content dicts (the modern
                # 'messages' format), so the Chatbot must be created with
                # type="messages"; recent Gradio otherwise assumes the
                # deprecated (user, bot) tuple format and rejects dicts.
                chatbot = gr.Chatbot(
                    label="GraphRAG Chat",
                    height=520,
                    type="messages",
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask a question about the documents..."
                        if prebuilt_status else
                        "Upload PDFs first, then ask questions...",
                        label="", scale=5, show_label=False,
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", size="sm")
                gr.Markdown("### πŸ’‘ Example Questions")
                gr.Examples(
                    examples=[
                        "Π―ΠΊΡ– Π΄ΠΎΠΊΡƒΠΌΠ΅Π½Ρ‚ΠΈ ΠΏΠΎΡ‚Ρ€Ρ–Π±Π½Ρ– для вступу?",
                        "Π₯Ρ‚ΠΎ Ρ” Π³Π°Ρ€Π°Π½Ρ‚ΠΎΠΌ ΠΎΡΠ²Ρ–Ρ‚Π½ΡŒΠΎΡ— ΠΏΡ€ΠΎΠ³Ρ€Π°ΠΌΠΈ?",
                        "Π’ΠΈΠΌΠΎΠ³ΠΈ Π΄ΠΎ Π±Π°ΠΊΠ°Π»Π°Π²Ρ€Π° Π· ΠΊΡ–Π±Π΅Ρ€Π±Π΅Π·ΠΏΠ΅ΠΊΠΈ",
                        "What are the admission requirements?",
                        "Summarize the main topics across all documents",
                    ],
                    inputs=msg,
                )
        # Wire events: build, chat (button click + Enter key), clear, refresh.
        build_btn.click(fn=process_uploads, inputs=[upload], outputs=[build_log])
        send_btn.click(fn=chat_respond, inputs=[msg, chatbot], outputs=[chatbot, msg])
        msg.submit(fn=chat_respond, inputs=[msg, chatbot], outputs=[chatbot, msg])
        clear_btn.click(fn=lambda: [], outputs=[chatbot])
        refresh_btn.click(fn=get_graph_info, outputs=[graph_info])
    return demo
if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from inside a
    # container (e.g. a Hugging Face Space); 7860 is Gradio's default port.
    demo = build_ui()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)