""" Main entry point for the RAG Gradio application. Loads environment variables, sets up context directory and model parameters, initializes retrieval and generation functions, and launches the interactive chat UI. Handles file uploads, user queries, and streaming LLM responses. """ import os import queue from threading import Thread from dotenv import load_dotenv load_dotenv() # Construct the path to the .env file relative to this script's location dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'RAG-LangChain', '.env') print(f"Start loading .env from {dotenv_path}") load_dotenv(dotenv_path=dotenv_path) print(f"Finish loading .env") from langchain.callbacks.base import BaseCallbackHandler print(f"Start importing from rag_func") from prepare import prepare_RAG from retrieve import retrieve_RAG from generate import generate_RAG from prepare import build_knowledge_graph ###### # --- Graph viz imports (Plotly + NetworkX) --- import json, math, random import networkx as nx import numpy as np import plotly.graph_objects as go import plotly.express as px try: from scipy.spatial import ConvexHull SCIPY_AVAILABLE = True except Exception: SCIPY_AVAILABLE = False ###### print(f"Finish importing from rag_func") import gradio as gr # -------------------- Context Setup -------------------- user_dir = "context" #print default print(f"[Info] Using context directory: {user_dir}") pinecone_API = os.getenv("PINECONE_API") index_name = os.getenv("INDEX_NAME") llm_model = os.getenv("MODELNAME") #index, pc, llm, kg_index = prepare_RAG(pinecone_API, index_name, llm_model=llm_model, dir_name=user_dir, graph_rag=(graph_rag=="True")) index, pc, llm, kg_index = None, None, None, None # -------------------- Chat Functions -------------------- def add_user_message(message, history): """ Adds a new user message to the chat history. Ensures the message is appended in the correct format for downstream processing. Returns updated history for use in the chat UI. """ history = history or [] history.append({"role": "user", "content": message}) return "", history, history import time # -------------------- Streaming Handler -------------------- class StreamHandler(BaseCallbackHandler): """ Callback handler for streaming LLM tokens to the UI. Tracks timing for first token and total response, buffers tokens, and manages the flow of streamed content for real-time display. """ def __init__(self, q: queue.Queue): self.q = q self.first_token_received = False self.ttft = None # time to first token self.total_time = None self.start_time = None self.buffer = [] # optional: accumulate tokens def on_llm_new_token(self, token: str, **kwargs): if not self.first_token_received: self.ttft = time.time() - self.start_time self.first_token_received = True self.buffer.append(token) self.q.put(token) def on_llm_end(self, *args, **kwargs): # IMPORTANT: do NOT end the consumer here. # Let the worker thread send [[FINAL]] (if any) and then [[END]]. self.total_time = time.time() - self.start_time # self.q.put("[[END]]") # <-- REMOVED (this was breaking before we could send [[FINAL]]) # -------------------- Chat Functions with timing -------------------- def generate_bot_response(history): """ Streams the first pass from the LLM to the UI and updates a styled progress box above the chat. """ global index, pc, llm, kg_index if not history or history[-1]["role"] != "user": yield history, history, "
Ready
" return user_msg = history[-1]["content"] documents = None # --- Stage 1: Initialize LLM / vector infra --- yield history, history, "
Initializing LLM and infrastructure...
" if not index or not pc or not llm: from langchain_mistralai import ChatMistralAI from langchain_openai import ChatOpenAI llm = ChatOpenAI(model=llm_model) if "gpt" in llm_model else ChatMistralAI(model=llm_model) index, pc, llm, documents = prepare_RAG( pinecone_API, index_name, llm_model=llm_model, dir_name=user_dir, info=True ) # --- Stage 2: Decide Graph RAG usage --- yield history, history, "
Deciding Graph RAG usage...
" def decide_graph_rag_usage(llm_, user_text: str) -> bool: prompt = ( "Given the following user prompt, determine whether graph RAG should be used (True or False):\n" f"{user_text}\n" "Use 'False' only if the prompt is focused on retrieving a single fact.\n" "Use 'True' if the prompt suggests reasoning over a large portion or the entirety of a dataset or corpus." ) resp = llm_.invoke(prompt) decision = (getattr(resp, "content", str(resp)) or "").strip() print("[Debug] Graph RAG decision response:", decision) return decision == "True" graph_rag_flag = decide_graph_rag_usage(llm, user_msg) print(f"[Info] Graph RAG usage decision: {graph_rag_flag}") if graph_rag_flag and not documents: _, _, _, documents = prepare_RAG( pinecone_API, index_name, llm_model=llm_model, dir_name=user_dir, info=True ) if graph_rag_flag: kg_index = build_knowledge_graph(documents, llm, pc, index, info=True) # --- Stage 3: Retrieve context --- yield history, history, "
Retrieving context...
" retrieved_chunks, graph_context = retrieve_RAG( user_msg, pc, index, kg_index, top_k=5, use_query_reformulation=True, llm=llm, graphRAG=graph_rag_flag ) # --- Stage 4: Generating response --- yield history, history, "
Generating response...
" FINAL_PREFIX = "[[FINAL]]" q = queue.Queue() handler = StreamHandler(q) handler.start_time = time.time() model_name = getattr(llm, "model_name", getattr(llm, "model", None)) streaming_llm = llm.__class__(model=model_name, streaming=True, callbacks=[handler]) def run_llm(): try: resp = generate_RAG( user_msg, streaming_llm, retrieved_chunks, graph_context, graphRAG=graph_rag_flag ) final_text = (getattr(resp, "content", str(resp)) or "").strip() if final_text: q.put(FINAL_PREFIX + final_text) finally: q.put("[[END]]") Thread(target=run_llm, daemon=True).start() partial = "" history.append({"role": "assistant", "content": ""}) while True: token = q.get() if token == "[[END]]": yield history, history, "
Completed!
" print(f"[Timing] TTFT: {handler.ttft:.3f} s, Total: {handler.total_time:.3f} s") break if token.startswith(FINAL_PREFIX): final = token[len(FINAL_PREFIX):] history[-1]["content"] = final yield history, history, "
Generating response...
" partial = final continue partial += token history[-1]["content"] = partial yield history, history, "
Generating response...
" # -------------------- Simplified CSS for Default Gradio Font -------------------- from pathlib import Path import gradio as gr # Load external assets custom_css = Path("app.css").read_text(encoding="utf-8") js_force_light = """ function refresh() { const url = new URL(window.location); if (url.searchParams.get('__theme') !== 'light') { url.searchParams.set('__theme', 'light'); window.location.replace(url); } } """ # -------------------- Gradio App -------------------- import os import shutil MAX_TOTAL_SIZE_MB = 5 CONTEXT_DIR = "context" def handle_file_upload(uploaded_files): """ Validates and saves uploaded files to the context directory for RAG processing. Checks file extensions and total upload size against allowed limits. Returns a status message indicating success or failure for each upload attempt. """ context_dir = "context" os.makedirs(context_dir, exist_ok=True) saved_files = [] total_size_mb = 0 # Allowed extensions allowed_extensions = {".txt", ".json", ".md", ".csv", ".pdf", ".docx", ".pptx", ".py"} for file_obj in uploaded_files: # Check file extension ext = os.path.splitext(file_obj.name)[1].lower() if ext not in allowed_extensions: return f"❌ Unsupported file type: {ext}. Allowed types are: {', '.join(sorted(allowed_extensions))}" # Check size file_size_mb = os.path.getsize(file_obj.name) / (1024 * 1024) total_size_mb += file_size_mb if total_size_mb > MAX_TOTAL_SIZE_MB: return f"❌ Total upload size exceeds the limit of {MAX_TOTAL_SIZE_MB}MB." # Save file filename = os.path.basename(file_obj.name) dest_path = os.path.join(context_dir, filename) with open(file_obj.name, "rb") as src, open(dest_path, "wb") as dst: dst.write(src.read()) saved_files.append(dest_path) return f"✅ Uploaded {len(saved_files)} file(s) to '{context_dir}': {', '.join(os.path.basename(f) for f in saved_files)}" ######### # ---------- Graph viz core ---------- GRAPH_JSON_PATH = "knowledge_graph.json" COMMUNITY_MIN_SIZE = 3 MERGE_SMALLS_POLICY = "bucket" # or 'attach' LAYOUT_SEED = 42 LAYOUT_ITERS = 30 # Cached state (simple globals for now) _g_G = None _g_pos3d = None _g_node2comm = None _g_comm2nodes = None _g_edges = None _g_node_names = None def load_graph_from_json(path=GRAPH_JSON_PATH): """Read {source: [[rel, target], ...], ...} and return a DiGraph.""" try: with open(path, "r", encoding="utf-8") as f: graph_dict = json.load(f) except Exception: graph_dict = {} G = nx.DiGraph() for source, edges_list in graph_dict.items(): for relation, target in edges_list: G.add_edge(source, target, label=relation) if G.number_of_nodes() == 0: G.add_node("(empty)") return G def precompute_layout_and_communities(G: nx.DiGraph): """Compute 3D spring layout and top-level modularity communities.""" pos3d = nx.spring_layout(G, dim=3, seed=LAYOUT_SEED, iterations=LAYOUT_ITERS) node_names = list(G.nodes()) edges = list(G.edges()) # Greedy modularity communities (on undirected projection) from networkx.algorithms.community import greedy_modularity_communities UG = nx.Graph() UG.add_edges_from(G.to_undirected().edges()) communities = list(greedy_modularity_communities(UG)) large = [set(c) for c in communities if len(c) >= COMMUNITY_MIN_SIZE] small = [set(c) for c in communities if len(c) < COMMUNITY_MIN_SIZE] if MERGE_SMALLS_POLICY == "bucket" and small: other = set().union(*small) if small else set() if other: large.append(other) comm_ids = [f"C{i}" for i in range(len(large) - (1 if other else 0))] if other: comm_ids.append("C_other") elif MERGE_SMALLS_POLICY == "attach" and small and large: for s in small: # attach to the large community with the most cross-edges best_i, best_links = None, -1 for i, L in enumerate(large): links = sum(1 for u in s for v in L if UG.has_edge(u, v)) if links > best_links: best_i, best_links = i, links if best_i is None: best_i = max(range(len(large)), key=lambda i: len(large[i])) large[best_i].update(s) comm_ids = [f"C{i}" for i in range(len(large))] else: comm_ids = [f"C{i}" for i in range(len(large))] node2comm, comm2nodes = {}, {} for cid, nodeset in zip(comm_ids, large): comm2nodes[cid] = set(nodeset) for n in nodeset: node2comm[n] = cid for n in G.nodes(): if n not in node2comm: node2comm[n] = "C_isolated" comm2nodes.setdefault("C_isolated", set()).add(n) return pos3d, node2comm, comm2nodes, edges, node_names def _make_comm_colors(comm2nodes_dict): palette = (px.colors.qualitative.Alphabet + px.colors.qualitative.Set3 + px.colors.qualitative.Bold + px.colors.qualitative.Dark24 + px.colors.qualitative.Light24) cids = sorted(comm2nodes_dict.keys()) return {cid: palette[i % len(palette)] for i, cid in enumerate(cids)} def _community_hulls_traces(pos3d, comm2nodes, comm_colors, opacity=0.12): if not SCIPY_AVAILABLE: return [] hull_traces = [] for cid, nodeset in comm2nodes.items(): pts = np.array([pos3d[n] for n in nodeset if n in pos3d]) if pts.shape[0] < 4: continue try: hull = ConvexHull(pts) simplices = hull.simplices hull_traces.append(go.Mesh3d( x=pts[:, 0], y=pts[:, 1], z=pts[:, 2], i=simplices[:, 0], j=simplices[:, 1], k=simplices[:, 2], color=_make_comm_colors(comm2nodes).get(cid, "#cccccc"), opacity=opacity, name=f"{cid} region", hoverinfo="skip", showlegend=False )) except Exception: pass return hull_traces ######### def build_plotly_figure(mode="community", highlight_node=None, highlight_comm_id=None, dim_inter_edges=True, show_hulls=False): global _g_G, _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names # Load & cache if not present if _g_G is None: _g_G = load_graph_from_json() _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names = \ precompute_layout_and_communities(_g_G) G = _g_G pos3d = _g_pos3d node2comm = _g_node2comm comm2nodes = _g_comm2nodes edges = _g_edges # split intra/inter edges edge_x_intra, edge_y_intra, edge_z_intra = [], [], [] edge_x_inter, edge_y_inter, edge_z_inter = [], [], [] for (u, v) in edges: x0, y0, z0 = pos3d[u] x1, y1, z1 = pos3d[v] if node2comm.get(u) == node2comm.get(v): edge_x_intra += [x0, x1, None]; edge_y_intra += [y0, y1, None]; edge_z_intra += [z0, z1, None] else: edge_x_inter += [x0, x1, None]; edge_y_inter += [y0, y1, None]; edge_z_inter += [z0, z1, None] edge_traces = [] if edge_x_inter: edge_traces.append(go.Scatter3d( x=edge_x_inter, y=edge_y_inter, z=edge_z_inter, mode="lines", line=dict(width=1, color="rgba(180,180,180,0.30)" if dim_inter_edges else "#BBBBBB"), hoverinfo="none", showlegend=False, name="Inter-community" )) if edge_x_intra: edge_traces.append(go.Scatter3d( x=edge_x_intra, y=edge_y_intra, z=edge_z_intra, mode="lines", line=dict(width=2, color="rgba(120,120,120,0.55)"), hoverinfo="none", showlegend=False, name="Intra-community" )) comm_colors = _make_comm_colors(comm2nodes) hull_traces = _community_hulls_traces(pos3d, comm2nodes, comm_colors) if show_hulls else [] # neighbor sets (if needed) nbr_succ, nbr_pred = set(), set() if mode == "neighbors" and highlight_node and highlight_node in G: nbr_succ = set(G.neighbors(highlight_node)) nbr_pred = set(G.predecessors(highlight_node)) node_traces = [] for cid, nodeset in sorted(comm2nodes.items(), key=lambda kv: kv[0]): xs, ys, zs, texts, colors, sizes = [], [], [], [], [], [] base_color = comm_colors.get(cid, "#66c2a5") for n in nodeset: x, y, z = pos3d[n] xs.append(x); ys.append(y); zs.append(z); texts.append(n) if mode == "neighbors": if highlight_node == n: colors.append("red"); sizes.append(8.0) elif n in nbr_succ or n in nbr_pred: colors.append("orange"); sizes.append(6.5) elif highlight_node and node2comm.get(n) == node2comm.get(highlight_node): colors.append(base_color); sizes.append(5.5) else: colors.append("lightblue"); sizes.append(5.0) else: if highlight_comm_id and node2comm.get(n) == highlight_comm_id: colors.append(base_color); sizes.append(6.5) else: colors.append(base_color); sizes.append(5.0) if xs: node_traces.append(go.Scatter3d( x=xs, y=ys, z=zs, mode="markers", hovertext=texts, hoverinfo="text", marker=dict(size=sizes, color=colors, opacity=0.95), name=cid, showlegend=True )) fig = go.Figure(data=hull_traces + edge_traces + node_traces) fig.update_layout( title="3D Knowledge Graph — Communities & Neighbors", showlegend=True if mode == "community" else False, height=800, margin=dict(l=0, r=0, t=40, b=0), scene=dict( xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), aspectmode="data" ), scene_camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)), uirevision=True, ) return fig def reload_graph_cache(): """Force re-read knowledge_graph.json and recompute layout/communities.""" global _g_G, _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names _g_G = load_graph_from_json() _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names = \ precompute_layout_and_communities(_g_G) # Return a default figure return build_plotly_figure(mode="community", highlight_comm_id=None, dim_inter_edges=True, show_hulls=False) ######### with gr.Blocks(css=custom_css, fill_height=True, js=js_force_light) as demo: with gr.Row(): # LEFT SIDE: Branding + Upload with gr.Column(scale=1, elem_id="left-column"): # Branding row: logo and title side by side with gr.Row(elem_id="branding-row"): import base64 from pathlib import Path HERE = Path(__file__).resolve().parent logo_path = HERE / "logo_mono.png" with open(logo_path, "rb") as f: encoded = base64.b64encode(f.read()).decode() gr.HTML(f"""
mosaiicRAG
""") gr.Markdown( "

Daten verstehen. Wissen vernetzen. Entscheidungen stärken.

", elem_id="header2" ) with gr.Column(scale=4, elem_id="right-column"): with gr.Tabs(): # ------------------------- Chat tab (unchanged) ------------------------- with gr.Tab("Chat"): with gr.Column(elem_id="chat-area"): progress_box = gr.HTML("
Ready
") chatbot = gr.Chatbot(type="messages", label="Conversation", elem_id="chatbot") with gr.Row(elem_id="input-row"): msg = gr.Textbox(placeholder="Type your question here...", lines=1) send_btn = gr.Button("➤", elem_id="send-button", size="sm") state = gr.State([]) msg.submit(add_user_message, inputs=[msg, state], outputs=[msg, chatbot, state])\ .then(generate_bot_response, inputs=[state], outputs=[chatbot, state, progress_box]) send_btn.click(add_user_message, inputs=[msg, state], outputs=[msg, chatbot, state])\ .then(generate_bot_response, inputs=[state], outputs=[chatbot, state, progress_box]) # --------------------- Knowledge Graph tab (updated) --------------------- with gr.Tab("Knowledge Graph"): with gr.Row(): color_mode = gr.Radio( ["community"], value="community", label="Color mode" ) community_select = gr.Dropdown( label="Highlight community (optional)", choices=[], value=None ) view_opts = gr.CheckboxGroup( choices=[ "Dim inter-community edges", f"Show 3D community hulls{' (requires scipy)' if not SCIPY_AVAILABLE else ''}" ], value=["Dim inter-community edges"], label="View options" ) reload_btn = gr.Button("Reload graph") graph_plot = gr.Plot(label="3D Knowledge Graph") node_info = gr.Markdown("") # ---- functions bound to UI (defined above or inline) ---- def _init_graph(): # Rebuild cache from knowledge_graph.json and return default figure fig = reload_graph_cache() cids = sorted(list(_g_comm2nodes.keys())) if _g_comm2nodes else [] info = "Select a community or click a node to highlight its community." # Use gr.update to set dropdown choices return fig, gr.update(choices=cids, value=None), info def _refresh(mode, selected_cid, opts): dim_edges = isinstance(opts, list) and ("Dim inter-community edges" in opts) show_hulls = isinstance(opts, list) and any("Show 3D community hulls" in s for s in opts) fig = build_plotly_figure( mode="community" if mode == "community" else "neighbors", highlight_comm_id=(selected_cid if mode == "community" else None), dim_inter_edges=dim_edges, show_hulls=(show_hulls if mode == "community" else False) ) info = ( "Select a community or click a node to highlight its community." if mode == "community" else "Click a node to see its neighbors (community tint applied)." ) return fig, info def _reload(mode, selected_cid, opts): # Reload data and recompute communities/layout _ = reload_graph_cache() cids = sorted(list(_g_comm2nodes.keys())) if _g_comm2nodes else [] # Immediately apply current UI options on the new graph state fig, info = _refresh(mode, selected_cid, opts) return fig, gr.update(choices=cids, value=selected_cid), info # wire controls color_mode.change(_refresh, inputs=[color_mode, community_select, view_opts], outputs=[graph_plot, node_info]) community_select.change(_refresh, inputs=[color_mode, community_select, view_opts], outputs=[graph_plot, node_info]) view_opts.change(_refresh, inputs=[color_mode, community_select, view_opts], outputs=[graph_plot, node_info]) reload_btn.click(_reload, inputs=[color_mode, community_select, view_opts], outputs=[graph_plot, community_select, node_info]) # ------------------------ IMPORTANT: INSIDE THE BLOCKS ------------------------ # Initialize the graph once when the app loads (now inside the Blocks context) demo.load(_init_graph, inputs=[], outputs=[graph_plot, community_select, node_info]) # -------------------- Launch App -------------------- if __name__ == "__main__": demo.launch(inbrowser=True)