| """ |
| Main entry point for the RAG Gradio application. |
| |
| Loads environment variables, sets up context directory and model parameters, |
| initializes retrieval and generation functions, and launches the interactive chat UI. |
| Handles file uploads, user queries, and streaming LLM responses. |
| """ |
|
|
|
|
# --- Bootstrap: load environment configuration, then import the RAG pipeline
# --- and the graph-visualisation dependencies.
import os
import queue
from threading import Thread
from dotenv import load_dotenv
# First load: no explicit path is given, so python-dotenv locates a .env itself.
load_dotenv()

# Second load: the RAG-LangChain/.env that sits one directory above this script.
# NOTE(review): load_dotenv does not override variables that are already set
# (override=False by default), so for keys defined in both files the value from
# the first call wins — confirm that precedence is intended.
dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'RAG-LangChain', '.env')
print(f"Start loading .env from {dotenv_path}")
load_dotenv(dotenv_path=dotenv_path)
print(f"Finish loading .env")
# The LangChain and project imports come only after the .env files are loaded,
# presumably so these modules see the configuration at import time (TODO confirm).
from langchain.callbacks.base import BaseCallbackHandler
print(f"Start importing from rag_func")
from prepare import prepare_RAG
from retrieve import retrieve_RAG
from generate import generate_RAG
from prepare import build_knowledge_graph

# Knowledge-graph visualisation stack.
# NOTE(review): math and random are not referenced anywhere else in this file as
# seen here; they are candidates for removal.
import json, math, random
import networkx as nx
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
# scipy is optional: SCIPY_AVAILABLE gates the 3D community-hull rendering.
# The broad except presumably also covers import failures other than
# ImportError (e.g. a broken binary install).
try:
    from scipy.spatial import ConvexHull
    SCIPY_AVAILABLE = True
except Exception:
    SCIPY_AVAILABLE = False

print(f"Finish importing from rag_func")
import gradio as gr
|
|
| |
|
|
# Folder whose documents are indexed; passed to prepare_RAG as ``dir_name``.
user_dir = "context"

print(f"[Info] Using context directory: {user_dir}")

# Deployment settings read from the environment (filled from the .env files).
pinecone_API, index_name, llm_model = (
    os.getenv(env_var) for env_var in ("PINECONE_API", "INDEX_NAME", "MODELNAME")
)

# Shared RAG state; stays None until the first chat request initialises it.
index = pc = llm = kg_index = None
|
|
| |
def add_user_message(message, history):
    """
    Append the user's message to the chat history (messages format).

    Blank or whitespace-only messages are ignored, so pressing Enter on an
    empty box does not start a RAG run. The input list is copied rather than
    mutated in place.

    Args:
        message: Text from the input box.
        history: Existing list of ``{"role", "content"}`` dicts, or None.

    Returns:
        ``("", history, history)``. The empty string clears the textbox; the
        updated history feeds both the chatbot and the state component.
    """
    history = list(history or [])
    if not message or (isinstance(message, str) and not message.strip()):
        return "", history, history
    history.append({"role": "user", "content": message})
    return "", history, history
|
|
| import time |
|
|
| |
class StreamHandler(BaseCallbackHandler):
    """
    LangChain callback that forwards streamed LLM tokens to a queue for the UI.

    Each token is appended to ``buffer`` and put on ``q`` so a consumer can
    update the chat incrementally. Timing is measured from ``start_time``:

    * ``ttft`` is the number of seconds until the first token arrives.
    * ``total_time`` is the number of seconds until ``on_llm_end`` fires.

    Both stay None until the corresponding callback runs.
    """

    def __init__(self, q: queue.Queue):
        self.q = q
        self.first_token_received = False
        self.ttft = None
        self.total_time = None
        # Default to construction time so the timing callbacks never subtract
        # from None; callers may still overwrite this (with time.time()) right
        # before starting generation.
        self.start_time = time.time()
        self.buffer = []

    def _elapsed(self):
        """Return seconds since ``start_time``, or None if it was cleared."""
        if self.start_time is None:
            return None
        return time.time() - self.start_time

    def on_llm_new_token(self, token: str, **kwargs):
        """Record time-to-first-token once, then buffer and enqueue ``token``."""
        if not self.first_token_received:
            self.ttft = self._elapsed()
            self.first_token_received = True
        self.buffer.append(token)
        self.q.put(token)

    def on_llm_end(self, *args, **kwargs):
        """Record the total generation time."""
        self.total_time = self._elapsed()
| |
|
|
| |
|
|
|
|
|
|
def _status_box(text, background="#f5f5f5"):
    """Return the progress-box HTML shown above the chat for status ``text``."""
    return f"<div style='background:{background};padding:10px;border-radius:8px;'>{text}</div>"


def _format_seconds(value):
    """Format an optional duration for the timing log ("n/a" when unknown)."""
    return f"{value:.3f} s" if value is not None else "n/a"


def _ensure_rag_infrastructure():
    """
    Build the shared index, Pinecone client and LLM via prepare_RAG on first use.

    Returns:
        The documents returned by prepare_RAG if it ran during this call,
        otherwise None (everything was already initialised).
    """
    global index, pc, llm
    if index is not None and pc is not None and llm is not None:
        return None
    # prepare_RAG also returns the LLM, so no separate client is constructed here.
    index, pc, llm, documents = prepare_RAG(
        pinecone_API,
        index_name,
        llm_model=llm_model,
        dir_name=user_dir,
        info=True,
    )
    return documents


def _decide_graph_rag_usage(llm_, user_text: str) -> bool:
    """
    Ask the LLM whether graph RAG should be used for ``user_text``.

    The model is asked to answer True or False. Before comparing, the reply is
    stripped of surrounding whitespace, quotes, markdown markers and
    punctuation, then lower-cased, so "**True**" or "true." are recognised.
    Anything that does not start with "true" counts as False.
    """
    prompt = (
        "Given the following user prompt, determine whether graph RAG should be used (True or False):\n"
        f"{user_text}\n"
        "Use 'False' only if the prompt is focused on retrieving a single fact.\n"
        "Use 'True' if the prompt suggests reasoning over a large portion or the entirety of a dataset or corpus."
    )
    resp = llm_.invoke(prompt)
    decision = str(getattr(resp, "content", str(resp)) or "").strip()
    print("[Debug] Graph RAG decision response:", decision)
    normalized = decision.strip("`*'\".:; \t\r\n").lower()
    return normalized.startswith("true")


def _stream_answer(history, user_msg, retrieved_chunks, graph_context, graph_rag_flag):
    """
    Run generate_RAG in a worker thread and stream its output into ``history``.

    The worker puts these items on the queue:

    * raw tokens, via StreamHandler;
    * a FINAL-prefixed full answer once generate_RAG returns, which replaces
      whatever was streamed;
    * an ERROR-prefixed message if generation raised;
    * an END sentinel, always sent last.

    Yields ``(history, history, status_html)`` tuples for the Gradio outputs.
    """
    import html

    final_prefix = "[[FINAL]]"
    error_prefix = "[[ERROR]]"
    end_token = "[[END]]"
    generating = _status_box("Generating response...")

    q = queue.Queue()
    handler = StreamHandler(q)
    handler.start_time = time.time()

    # Re-create the client with streaming enabled and our callback attached.
    model_name = getattr(llm, "model_name", getattr(llm, "model", None))
    streaming_llm = llm.__class__(model=model_name, streaming=True, callbacks=[handler])

    def run_llm():
        try:
            resp = generate_RAG(
                user_msg,
                streaming_llm,
                retrieved_chunks,
                graph_context,
                graphRAG=graph_rag_flag,
            )
            final_text = str(getattr(resp, "content", str(resp)) or "").strip()
            if final_text:
                q.put(final_prefix + final_text)
        except Exception as exc:  # worker-thread boundary: report instead of losing it
            q.put(f"{error_prefix}{type(exc).__name__}: {exc}")
        finally:
            q.put(end_token)

    Thread(target=run_llm, daemon=True).start()

    partial = ""
    error = None
    history.append({"role": "assistant", "content": ""})

    while True:
        token = q.get()
        if token == end_token:
            break
        if token.startswith(error_prefix):
            error = token[len(error_prefix):]
            continue
        if token.startswith(final_prefix):
            partial = token[len(final_prefix):]
        else:
            partial += token
        history[-1]["content"] = partial
        yield history, history, generating

    if error is not None:
        print(f"[Error] Response generation failed: {error}")
        if not partial:
            history[-1]["content"] = "⚠️ Sorry, generating the response failed."
        yield history, history, _status_box(f"Error: {html.escape(error)}", "#f8d7da")
    else:
        yield history, history, _status_box("Completed!", "#d4edda")
    print(f"[Timing] TTFT: {_format_seconds(handler.ttft)}, Total: {_format_seconds(handler.total_time)}")


def generate_bot_response(history):
    """
    Answer the latest user message with RAG, streaming the reply into the chat.

    Steps, each announced in the progress box:

    1. Initialise the index, client and LLM on first use.
    2. Ask the LLM whether graph RAG is needed and, if so, build the knowledge graph.
    3. Retrieve context.
    4. Stream the generated answer.

    If ``history`` is empty or does not end with a user message, only "Ready"
    is yielded.

    Yields:
        ``(history, history, status_html)`` for the chatbot, state and progress box.
    """
    global kg_index

    if not history or history[-1]["role"] != "user":
        yield history, history, _status_box("Ready")
        return

    user_msg = history[-1]["content"]

    yield history, history, _status_box("Initializing LLM and infrastructure...")
    documents = _ensure_rag_infrastructure()

    yield history, history, _status_box("Deciding Graph RAG usage...")
    graph_rag_flag = _decide_graph_rag_usage(llm, user_msg)
    print(f"[Info] Graph RAG usage decision: {graph_rag_flag}")

    if graph_rag_flag:
        if not documents:
            # NOTE(review): re-runs the full prepare_RAG only to obtain the
            # documents when the infrastructure was initialised on an earlier call.
            _, _, _, documents = prepare_RAG(
                pinecone_API,
                index_name,
                llm_model=llm_model,
                dir_name=user_dir,
                info=True,
            )
        # NOTE(review): the knowledge graph is rebuilt on every graph-RAG query.
        kg_index = build_knowledge_graph(documents, llm, pc, index, info=True)

    yield history, history, _status_box("Retrieving context...")
    retrieved_chunks, graph_context = retrieve_RAG(
        user_msg,
        pc,
        index,
        kg_index,
        top_k=5,
        use_query_reformulation=True,
        llm=llm,
        graphRAG=graph_rag_flag,
    )

    yield history, history, _status_box("Generating response...")
    yield from _stream_answer(history, user_msg, retrieved_chunks, graph_context, graph_rag_flag)
|
|
|
|
|
|
| |
| from pathlib import Path |
| import gradio as gr |
|
|
| |
# Stylesheet for the Gradio UI. Look next to this script first so the app can be
# launched from any working directory; fall back to the CWD-relative path used
# previously.
_css_path = Path(__file__).resolve().parent / "app.css"
if not _css_path.exists():
    _css_path = Path("app.css")
custom_css = _css_path.read_text(encoding="utf-8")

# Passed to gr.Blocks(js=...): if the URL lacks ?__theme=light, set it and
# reload, forcing Gradio's light theme.
js_force_light = """ function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.replace(url);
    }
} """

import os
import shutil

# Upper bound, in MB, for the combined size of one batch of uploaded files.
MAX_TOTAL_SIZE_MB = 5
# Folder uploaded files are copied into (same path as user_dir).
CONTEXT_DIR = "context"
|
|
def handle_file_upload(uploaded_files):
    """
    Validate uploaded files and copy them into the context directory for RAG.

    Every file is checked first: its extension must be allowed, and the combined
    size must not exceed MAX_TOTAL_SIZE_MB. Files are copied only if all checks
    pass, so a rejected batch never leaves a partial upload behind.

    Args:
        uploaded_files: Iterable of uploads. Each is either a filesystem path
            (str / os.PathLike) or an object exposing its temp path as ``.name``.
            None or an empty iterable means nothing was uploaded.

    Returns:
        A user-facing status string starting with "✅" on success or "❌" on failure.
    """
    if not uploaded_files:
        return "❌ No files were provided."

    allowed_extensions = {".txt", ".json", ".md", ".csv", ".pdf", ".docx", ".pptx", ".py"}

    # Pass 1: validate the whole batch before touching the context directory.
    source_paths = []
    total_size_mb = 0
    for file_obj in uploaded_files:
        if isinstance(file_obj, (str, os.PathLike)):
            src_path = os.fspath(file_obj)
        else:
            src_path = file_obj.name
        ext = os.path.splitext(src_path)[1].lower()
        if ext not in allowed_extensions:
            return f"❌ Unsupported file type: {ext}. Allowed types are: {', '.join(sorted(allowed_extensions))}"
        total_size_mb += os.path.getsize(src_path) / (1024 * 1024)
        if total_size_mb > MAX_TOTAL_SIZE_MB:
            return f"❌ Total upload size exceeds the limit of {MAX_TOTAL_SIZE_MB}MB."
        source_paths.append(src_path)

    # Pass 2: copy. basename() keeps uploads inside the context directory.
    context_dir = CONTEXT_DIR
    os.makedirs(context_dir, exist_ok=True)
    saved_files = []
    for src_path in source_paths:
        dest_path = os.path.join(context_dir, os.path.basename(src_path))
        shutil.copyfile(src_path, dest_path)
        saved_files.append(dest_path)

    return f"✅ Uploaded {len(saved_files)} file(s) to '{context_dir}': {', '.join(os.path.basename(f) for f in saved_files)}"
|
|
|
|
|
|
| |
| |
# --- Knowledge-graph visualisation settings ---
GRAPH_JSON_PATH = "knowledge_graph.json"  # adjacency file read by load_graph_from_json
COMMUNITY_MIN_SIZE = 3                    # communities below this size count as "small"
MERGE_SMALLS_POLICY = "bucket"            # how small communities are handled: "bucket" or "attach"
LAYOUT_SEED = 42                          # fixed seed for a reproducible spring layout
LAYOUT_ITERS = 30                         # spring-layout iterations

# --- Module-level cache of the loaded graph, its 3D layout and communities ---
_g_G = _g_pos3d = _g_node2comm = _g_comm2nodes = _g_edges = _g_node_names = None
|
|
def load_graph_from_json(path=GRAPH_JSON_PATH):
    """
    Build a directed knowledge graph from a JSON adjacency file.

    Expected layout: ``{source: [[relation, target], ...], ...}``. Each pair
    becomes an edge ``source -> target`` with attribute ``label=relation``.
    Every source key becomes a node, even if its edge list is empty.

    Error handling:

    * A missing file yields an empty graph silently (e.g. before any graph has
      been built).
    * An unreadable or malformed file is reported and treated as empty.
    * Entries that are not ``[relation, target]`` pairs are skipped.

    Returns:
        nx.DiGraph. When nothing was loaded it holds a single "(empty)"
        placeholder node, so the plot always has something to draw.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            graph_dict = json.load(f)
    except FileNotFoundError:
        graph_dict = {}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[Warning] Could not read knowledge graph from {path}: {exc}")
        graph_dict = {}

    if not isinstance(graph_dict, dict):
        print(f"[Warning] Unexpected knowledge graph format in {path}; expected a JSON object.")
        graph_dict = {}

    G = nx.DiGraph()
    for source, edges_list in graph_dict.items():
        G.add_node(source)
        for entry in edges_list or ():
            if not isinstance(entry, (list, tuple)) or len(entry) != 2:
                continue
            relation, target = entry
            if isinstance(target, (list, dict)):  # unhashable -> cannot be a node
                continue
            G.add_edge(source, target, label=relation)
    if G.number_of_nodes() == 0:
        G.add_node("(empty)")
    return G
|
|
def precompute_layout_and_communities(G: nx.DiGraph):
    """
    Compute a 3D spring layout and a flat community assignment for ``G``.

    Communities come from greedy modularity maximisation on the undirected edge
    set. Communities with fewer than COMMUNITY_MIN_SIZE nodes are handled
    according to MERGE_SMALLS_POLICY:

    * "bucket": merged into a single "C_other" community.
    * "attach": each is merged into the large community it shares the most
      edges with (the first such community on ties).
    * anything else: dropped, so their nodes fall through to "C_isolated".

    Nodes covered by no community (e.g. nodes without any edge) are grouped
    under "C_isolated".

    Returns:
        A tuple ``(pos3d, node2comm, comm2nodes, edges, node_names)``:

        * ``pos3d``: node -> 3D position.
        * ``node2comm``: node -> community id.
        * ``comm2nodes``: community id -> set of nodes.
        * ``edges``: ``list(G.edges())``.
        * ``node_names``: ``list(G.nodes())``.
    """
    from networkx.algorithms.community import greedy_modularity_communities

    pos3d = nx.spring_layout(G, dim=3, seed=LAYOUT_SEED, iterations=LAYOUT_ITERS)
    node_names = list(G.nodes())
    edges = list(G.edges())

    # Only nodes that take part in at least one edge enter community detection.
    UG = nx.Graph()
    UG.add_edges_from(G.to_undirected().edges())
    # With no edges (e.g. the "(empty)" placeholder graph) there is nothing to
    # cluster, and some networkx versions reject an empty graph here.
    communities = list(greedy_modularity_communities(UG)) if UG.number_of_edges() > 0 else []
    large = [set(c) for c in communities if len(c) >= COMMUNITY_MIN_SIZE]
    small = [set(c) for c in communities if len(c) < COMMUNITY_MIN_SIZE]

    if MERGE_SMALLS_POLICY == "bucket" and small:
        comm_ids = [f"C{i}" for i in range(len(large))] + ["C_other"]
        large.append(set().union(*small))
    elif MERGE_SMALLS_POLICY == "attach" and small and large:
        for s in small:
            link_counts = [sum(1 for u in s for v in L if UG.has_edge(u, v)) for L in large]
            # list.index(max(...)) picks the first community on ties.
            large[link_counts.index(max(link_counts))].update(s)
        comm_ids = [f"C{i}" for i in range(len(large))]
    else:
        comm_ids = [f"C{i}" for i in range(len(large))]

    node2comm, comm2nodes = {}, {}
    for cid, nodeset in zip(comm_ids, large):
        comm2nodes[cid] = set(nodeset)
        for n in nodeset:
            node2comm[n] = cid
    for n in G.nodes():
        if n not in node2comm:
            node2comm[n] = "C_isolated"
            comm2nodes.setdefault("C_isolated", set()).add(n)

    return pos3d, node2comm, comm2nodes, edges, node_names
|
|
def _make_comm_colors(comm2nodes_dict):
    """Assign each community id, in sorted order, a colour from a combined Plotly palette (cycling)."""
    qualitative = px.colors.qualitative
    palette = [
        *qualitative.Alphabet,
        *qualitative.Set3,
        *qualitative.Bold,
        *qualitative.Dark24,
        *qualitative.Light24,
    ]
    colour_of = {}
    for position, community_id in enumerate(sorted(comm2nodes_dict)):
        colour_of[community_id] = palette[position % len(palette)]
    return colour_of
|
|
def _community_hulls_traces(pos3d, comm2nodes, comm_colors, opacity=0.12):
    """
    Build one translucent convex-hull mesh per community for the 3D plot.

    Returns an empty list when scipy is unavailable. A community is skipped when
    it has fewer than 4 positioned nodes (a 3D hull needs at least 4 points) or
    when Qhull cannot triangulate its points (e.g. coplanar points).

    Args:
        pos3d: Mapping node -> 3D coordinates.
        comm2nodes: Mapping community id -> iterable of nodes.
        comm_colors: Mapping community id -> colour; missing ids use "#cccccc".
        opacity: Mesh opacity.

    Returns:
        A list of ``go.Mesh3d`` traces.
    """
    if not SCIPY_AVAILABLE:
        return []
    hull_traces = []
    for cid, nodeset in comm2nodes.items():
        pts = np.array([pos3d[n] for n in nodeset if n in pos3d])
        if pts.shape[0] < 4:
            continue
        try:
            hull = ConvexHull(pts)
        except Exception:
            # Best effort: degenerate point sets make Qhull raise; skip that community.
            continue
        simplices = hull.simplices
        hull_traces.append(go.Mesh3d(
            x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
            i=simplices[:, 0], j=simplices[:, 1], k=simplices[:, 2],
            color=comm_colors.get(cid, "#cccccc"),
            opacity=opacity, name=f"{cid} region",
            hoverinfo="skip", showlegend=False
        ))
    return hull_traces
| |
def build_plotly_figure(mode="community", highlight_node=None,
                        highlight_comm_id=None, dim_inter_edges=True,
                        show_hulls=False):
    """
    Render the cached knowledge graph as an interactive 3D Plotly figure.

    On first use the graph is loaded from disk, and its layout and communities
    are computed and cached in the module-level ``_g_*`` globals.

    Args:
        mode: Either "community" or "neighbors".
            "community" colours nodes by community and shows the legend.
            "neighbors" draws ``highlight_node`` in red and its successors and
            predecessors in orange. Other nodes in its community keep their
            community colour; everything else is light blue. The legend is hidden.
        highlight_node: Node to focus on in "neighbors" mode.
        highlight_comm_id: In "community" mode, nodes of this community are
            drawn larger.
        dim_inter_edges: Draw edges between different communities semi-transparent.
        show_hulls: Add a translucent convex hull per community (needs scipy).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    global _g_G, _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names

    if _g_G is None:
        graph = load_graph_from_json()
        cached = precompute_layout_and_communities(graph)
        # Publish the cache only once everything is computed. Otherwise a failure
        # would leave _g_G set without layout data and break every later call.
        _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names = cached
        _g_G = graph

    G = _g_G
    pos3d = _g_pos3d
    node2comm = _g_node2comm
    comm2nodes = _g_comm2nodes

    # Split edges into intra- and inter-community polylines.
    # Each segment ends with None, which breaks the line between segments.
    segments = {"intra": ([], [], []), "inter": ([], [], [])}
    for u, v in _g_edges:
        kind = "intra" if node2comm.get(u) == node2comm.get(v) else "inter"
        for axis, coords in enumerate(segments[kind]):
            coords.extend((pos3d[u][axis], pos3d[v][axis], None))

    edge_traces = []
    inter_x, inter_y, inter_z = segments["inter"]
    if inter_x:
        edge_traces.append(go.Scatter3d(
            x=inter_x, y=inter_y, z=inter_z,
            mode="lines",
            line=dict(width=1, color="rgba(180,180,180,0.30)" if dim_inter_edges else "#BBBBBB"),
            hoverinfo="none", showlegend=False, name="Inter-community"
        ))
    intra_x, intra_y, intra_z = segments["intra"]
    if intra_x:
        edge_traces.append(go.Scatter3d(
            x=intra_x, y=intra_y, z=intra_z,
            mode="lines",
            line=dict(width=2, color="rgba(120,120,120,0.55)"),
            hoverinfo="none", showlegend=False, name="Intra-community"
        ))

    comm_colors = _make_comm_colors(comm2nodes)
    hull_traces = _community_hulls_traces(pos3d, comm2nodes, comm_colors) if show_hulls else []

    neighbours = set()
    if mode == "neighbors" and highlight_node and highlight_node in G:
        neighbours = set(G.neighbors(highlight_node)) | set(G.predecessors(highlight_node))

    def style_for(node, base_color):
        """Return (colour, marker size) for ``node`` under the current mode."""
        if mode != "neighbors":
            emphasised = bool(highlight_comm_id) and node2comm.get(node) == highlight_comm_id
            return base_color, (6.5 if emphasised else 5.0)
        if highlight_node == node:
            return "red", 8.0
        if node in neighbours:
            return "orange", 6.5
        if highlight_node and node2comm.get(node) == node2comm.get(highlight_node):
            return base_color, 5.5
        return "lightblue", 5.0

    node_traces = []
    for cid in sorted(comm2nodes):
        nodes = list(comm2nodes[cid])
        if not nodes:
            continue
        base_color = comm_colors.get(cid, "#66c2a5")
        styles = [style_for(n, base_color) for n in nodes]
        node_traces.append(go.Scatter3d(
            x=[pos3d[n][0] for n in nodes],
            y=[pos3d[n][1] for n in nodes],
            z=[pos3d[n][2] for n in nodes],
            mode="markers",
            hovertext=nodes, hoverinfo="text",
            marker=dict(size=[s for _, s in styles], color=[c for c, _ in styles], opacity=0.95),
            name=cid, showlegend=True
        ))

    fig = go.Figure(data=[*hull_traces, *edge_traces, *node_traces])
    fig.update_layout(
        title="3D Knowledge Graph — Communities & Neighbors",
        showlegend=(mode == "community"),
        height=800,
        margin=dict(l=0, r=0, t=40, b=0),
        scene=dict(
            xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
            aspectmode="data"
        ),
        scene_camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
        # Keep camera and zoom when the figure is re-rendered.
        uirevision=True,
    )
    return fig
|
|
def reload_graph_cache():
    """
    Re-read the knowledge-graph JSON, recompute layout and communities, and redraw.

    The module-level cache is replaced only after both steps succeed, so a
    failure keeps the previous, consistent cache instead of pairing a new graph
    with a stale layout.

    Returns:
        A figure in "community" mode with dimmed inter-community edges and no hulls.
    """
    global _g_G, _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names
    graph = load_graph_from_json()
    cached = precompute_layout_and_communities(graph)
    _g_pos3d, _g_node2comm, _g_comm2nodes, _g_edges, _g_node_names = cached
    _g_G = graph
    return build_plotly_figure(mode="community", highlight_comm_id=None, dim_inter_edges=True, show_hulls=False)
|
|
| |
|
|
|
|
|
|
with gr.Blocks(css=custom_css, fill_height=True, js=js_force_light) as demo:
    with gr.Row():
        # ---- Left column: branding and tagline ----
        with gr.Column(scale=1, elem_id="left-column"):
            with gr.Row(elem_id="branding-row"):
                import base64
                from pathlib import Path

                # The logo lives next to this script and is inlined as a base64 data URI.
                HERE = Path(__file__).resolve().parent
                logo_path = HERE / "logo_mono.png"
                with open(logo_path, "rb") as f:
                    encoded = base64.b64encode(f.read()).decode()

                gr.HTML(f"""
                <div id="branding">
                    <img id="company-logo" src="data:image/png;base64,{encoded}" alt="Logo" />
                    <span id="brand-text">mosaiicRAG</span>
                </div>
                """)

            # Tagline (German): "Understand data. Connect knowledge. Strengthen decisions."
            gr.Markdown(
                "<p>Daten verstehen. Wissen vernetzen. Entscheidungen stärken.</p>",
                elem_id="header2"
            )

        # ---- Right column: chat and knowledge-graph tabs ----
        with gr.Column(scale=4, elem_id="right-column"):
            with gr.Tabs():

                with gr.Tab("Chat"):
                    with gr.Column(elem_id="chat-area"):
                        progress_box = gr.HTML("<div style='background:#f5f5f5;padding:10px;border-radius:8px;margin-bottom:10px;'>Ready</div>")
                        chatbot = gr.Chatbot(type="messages", label="Conversation", elem_id="chatbot")
                        with gr.Row(elem_id="input-row"):
                            msg = gr.Textbox(placeholder="Type your question here...", lines=1)
                            send_btn = gr.Button("➤", elem_id="send-button", size="sm")
                        state = gr.State([])
                        # Enter and the send button share the same two-step flow:
                        # record the user message, then stream the bot response.
                        msg.submit(add_user_message, inputs=[msg, state], outputs=[msg, chatbot, state])\
                            .then(generate_bot_response, inputs=[state], outputs=[chatbot, state, progress_box])
                        send_btn.click(add_user_message, inputs=[msg, state], outputs=[msg, chatbot, state])\
                            .then(generate_bot_response, inputs=[state], outputs=[chatbot, state, progress_box])

                with gr.Tab("Knowledge Graph"):
                    with gr.Row():
                        color_mode = gr.Radio(
                            ["community"],
                            value="community",
                            label="Color mode"
                        )
                        community_select = gr.Dropdown(
                            label="Highlight community (optional)",
                            choices=[],
                            value=None
                        )
                        view_opts = gr.CheckboxGroup(
                            choices=[
                                "Dim inter-community edges",
                                f"Show 3D community hulls{' (requires scipy)' if not SCIPY_AVAILABLE else ''}"
                            ],
                            value=["Dim inter-community edges"],
                            label="View options"
                        )
                        reload_btn = gr.Button("Reload graph")

                    graph_plot = gr.Plot(label="3D Knowledge Graph")
                    node_info = gr.Markdown("")

                    def _init_graph():
                        """Load the graph on page load and populate the community dropdown."""
                        fig = reload_graph_cache()
                        cids = sorted(_g_comm2nodes) if _g_comm2nodes else []
                        info = "Select a community or click a node to highlight its community."
                        return fig, gr.update(choices=cids, value=None), info

                    def _refresh(mode, selected_cid, opts):
                        """Redraw the cached graph for the current mode, selection and view options."""
                        dim_edges = isinstance(opts, list) and ("Dim inter-community edges" in opts)
                        show_hulls = isinstance(opts, list) and any("Show 3D community hulls" in s for s in opts)

                        fig = build_plotly_figure(
                            mode="community" if mode == "community" else "neighbors",
                            highlight_comm_id=(selected_cid if mode == "community" else None),
                            dim_inter_edges=dim_edges,
                            show_hulls=(show_hulls if mode == "community" else False)
                        )
                        info = (
                            "Select a community or click a node to highlight its community."
                            if mode == "community"
                            else "Click a node to see its neighbors (community tint applied)."
                        )
                        return fig, info

                    def _reload(mode, selected_cid, opts):
                        """Re-read the graph from disk, refresh the dropdown choices and redraw."""
                        reload_graph_cache()
                        cids = sorted(_g_comm2nodes) if _g_comm2nodes else []
                        # After a reload the previous community id may no longer exist;
                        # clear it so the dropdown never receives a value outside its choices.
                        if selected_cid not in cids:
                            selected_cid = None
                        fig, info = _refresh(mode, selected_cid, opts)
                        return fig, gr.update(choices=cids, value=selected_cid), info

                    # Any control change re-renders from the cache without re-reading the file.
                    color_mode.change(_refresh, inputs=[color_mode, community_select, view_opts],
                                      outputs=[graph_plot, node_info])
                    community_select.change(_refresh, inputs=[color_mode, community_select, view_opts],
                                            outputs=[graph_plot, node_info])
                    view_opts.change(_refresh, inputs=[color_mode, community_select, view_opts],
                                     outputs=[graph_plot, node_info])

                    reload_btn.click(_reload, inputs=[color_mode, community_select, view_opts],
                                     outputs=[graph_plot, community_select, node_info])

    demo.load(_init_graph, inputs=[], outputs=[graph_plot, community_select, node_info])

if __name__ == "__main__":
    demo.launch(inbrowser=True)