soft.engineer
add setting tab
58259d1
import gradio as gr
import os
import logging
from dotenv import load_dotenv
import tempfile
import pandas as pd
from typing import List, Dict, Any, Tuple
import shutil
import json
from core.ingest import DocumentChunker, HierarchyManager
from core.index import VectorStore
from core.retrieval import RAGManager
from core.eval import RAGEvaluator
from core.utils import generate_id
import os as _os
# Optional OpenAI client.
# `_OpenAI` is always bound (None when the package is missing or fails to
# import), so later `_OpenAI is not None` checks cannot raise NameError.
# `_OPENAI_ON` is True only when the client imported AND OPENAI_API_KEY is set.
_OpenAI = None
_OPENAI_ON = False
try:
    from openai import OpenAI as _OpenAI
    _OPENAI_ON = bool(_os.getenv("OPENAI_API_KEY"))
except Exception:
    _OpenAI = None
    _OPENAI_ON = False
def _auto_detect_query_filters(query: str) -> Dict[str, Any]:
    """Infer hierarchy filters (level1/level2/level3/doc_type) from a query.

    Detection order:
      1. OpenAI, when OPENAI_API_KEY is set and the client library imported.
      2. A keyword scan over the hierarchies known to HierarchyManager. This is
         used when OpenAI is unavailable, fails, or yields no usable filter.
         It only infers ``level1`` and ``doc_type``.

    Args:
        query: Free-text user query.

    Returns:
        A dict that may contain ``level1``, ``level2``, ``level3`` and
        ``doc_type``. An empty dict means nothing could be inferred. Callers
        should read it with ``.get``.
    """
    if not query or not query.strip():
        return {}

    # _OpenAI may be unbound when the optional `openai` import failed.
    openai_cls = globals().get("_OpenAI")
    if os.getenv("OPENAI_API_KEY") and openai_cls is not None:
        try:
            logger.debug("Calling OpenAI for query filter detection.")
            client = openai_cls()
            prompt = (
                "Given a user query, infer optional filters for a hierarchical RAG system."
                " Return JSON with any of: level1, level2, level3, doc_type."
                " Use concise strings; omit fields you cannot infer. Query: " + query
            )
            resp = client.chat.completions.create(
                model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
            )
            content = (resp.choices[0].message.content or "").strip()
            # Models often wrap JSON in a ```json fence; unwrap it before parsing.
            if content.startswith("```"):
                content = content.strip("`").strip()
                if content.lower().startswith("json"):
                    content = content[4:]
            data = json.loads(content)
            logger.debug("OpenAI filters inferred: %s", data)
            if isinstance(data, dict) and any(data.get(k) for k in ("level1", "level2", "level3", "doc_type")):
                return data
        except Exception:
            logger.exception("OpenAI query filter detection failed; will try heuristic.")

    # Heuristic fallback: score each level1 by keyword hits on it and its children.
    try:
        hm = HierarchyManager()
        ql = query.lower()
        best_level1 = None
        # Start at 0 so a level1 is selected only when some keyword actually matched.
        best_score = 0
        for hdef in hm.hierarchies.values():
            levels = hdef['levels']
            # level2/level3 'values' are mappings keyed by the parent level (read via .get).
            level2_map = levels['level2']['values']
            level3_map = levels['level3']['values']
            for l1 in levels['level1']['values']:
                score = 2 if l1.lower() in ql else 0
                for l2 in level2_map.get(l1, []):
                    if l2.lower() in ql:
                        score += 2
                    for l3 in level3_map.get(l2, []):
                        if l3.lower() in ql:
                            score += 1
                if score > best_score:
                    best_score = score
                    best_level1 = l1

        inferred: Dict[str, Any] = {}
        if best_level1 is not None:
            inferred["level1"] = best_level1
        doc_type = next(
            (dt for dt in ("Policy", "Manual", "FAQ", "Report", "Note", "Guideline") if dt.lower() in ql),
            None,
        )
        if doc_type:
            inferred["doc_type"] = doc_type
        logger.debug("Heuristic filters inferred: %s", inferred)
        return inferred
    except Exception:
        logger.debug("Heuristic detection failed; returning no filters.")
        return {}
# Load environment variables from .env if present, then configure logging.
load_dotenv()
# logging only accepts upper-case level names, so normalize LOG_LEVEL
# ("debug" -> "DEBUG"). An empty value falls back to INFO.
logging.basicConfig(level=(os.getenv("LOG_LEVEL") or "INFO").strip().upper())
logger = logging.getLogger("rag_app")
if os.getenv("OPENAI_API_KEY"):
    logger.info("OpenAI API key detected. OpenAI-powered auto-detection is ENABLED.")
    if os.getenv("OPENAI_MODEL"):
        logger.info("OpenAI model: %s", os.getenv("OPENAI_MODEL"))
else:
    logger.info("OpenAI API key not set. Falling back to heuristic auto-detection.")
# Shared module state used by the Gradio handlers.
# initialize_system() / reset_index() populate the manager, evaluator and directory.
rag_manager: "RAGManager | None" = None  # set by initialize_system()
evaluator: "RAGEvaluator | None" = None  # wraps rag_manager for batch evaluation
current_collection: str = "documents"  # vector-store collection used for indexing/stats
persist_directory: "str | None" = None  # directory handed to RAGManager
def initialize_system():
    """Create the global RAGManager/RAGEvaluator on a writable persist directory.

    Candidate directories, in order:
      1. ``/data/chroma``, only if it already exists (e.g. a mounted volume on
         HF Spaces).
      2. ``./chroma_data``.
      3. ``chroma_data`` under the system temp directory.

    The first candidate that can be created and written to is used. If none
    is writable, ``./chroma_data`` is used anyway, and RAGManager will likely
    fail on it.

    Sets the module globals ``rag_manager``, ``evaluator`` and
    ``persist_directory``.

    Returns:
        A status message naming the directory in use.
    """
    global rag_manager, evaluator, persist_directory

    def _is_writable_dir(path: str) -> bool:
        """Create *path* if missing and verify a probe file can be written in it."""
        try:
            os.makedirs(path, exist_ok=True, mode=0o755)
            probe = os.path.join(path, ".test_write")
            with open(probe, 'w') as f:
                f.write("test")
            os.remove(probe)
            return True
        except OSError:  # includes PermissionError
            return False

    candidates = []
    if os.path.exists("/data/chroma"):
        candidates.append("/data/chroma")
    candidates.append("./chroma_data")
    # Last resort for read-only working directories.
    candidates.append(os.path.join(tempfile.gettempdir(), "chroma_data"))
    persist_dir = next((d for d in candidates if _is_writable_dir(d)), "./chroma_data")

    persist_directory = persist_dir
    rag_manager = RAGManager(persist_directory=persist_dir)
    evaluator = RAGEvaluator(rag_manager)
    return f"System initialized successfully! Using persist directory: {persist_dir}"
def reset_index() -> str:
    """Wipe the Chroma persistence directory and rebuild the RAG globals.

    Uses the directory chosen by initialize_system() when one is recorded.
    Otherwise it uses ``/data/chroma`` if that exists, else ``./chroma_data``.
    Failures are reported in the returned message rather than raised.
    """
    global rag_manager, evaluator, persist_directory
    try:
        if persist_directory:
            target = persist_directory
        elif os.path.exists("/data/chroma"):
            target = "/data/chroma"
        else:
            target = "./chroma_data"

        if os.path.exists(target):
            shutil.rmtree(target, ignore_errors=True)
        os.makedirs(target, exist_ok=True, mode=0o755)

        persist_directory = target
        rag_manager = RAGManager(persist_directory=target)
        evaluator = RAGEvaluator(rag_manager)
    except Exception as err:
        return f"Failed to reset index: {err}"
    return f"Index reset complete. Using fresh directory: {target}"
def upload_documents(files: List[str], hierarchy: str, doc_type: str, language: str, progress: Any = None) -> Tuple[str, List[Dict[str, Any]], Dict[str, Any], List[Dict[str, Any]]]:
    """Chunk each file and index all resulting chunks into the current collection.

    Args:
        files: Local paths of the documents to process.
        hierarchy: Hierarchy name forwarded to ``DocumentChunker.chunk_document``
            (None when the caller wants it auto-detected).
        doc_type: Document type forwarded to the chunker, or None.
        language: Language code forwarded to the chunker, or None.
        progress: Optional callback invoked as
            ``progress(idx, total=..., desc=...)``. Errors it raises are ignored.

    Returns:
        ``(status_text, per_file_summaries, collection_stats, chunk_rows)``:
          - ``per_file_summaries`` rows have keys Filename/Chunks/Language/
            Doc Type/Hierarchy, each a majority vote over chunk metadata.
          - ``chunk_rows`` have keys Filename/Level1/Level2/Level3/Doc Type/
            Language/Preview.
        Per-file and indexing failures are listed in ``status_text`` instead
        of being raised.
    """
    global rag_manager
    from collections import Counter

    if not files:
        return "No files provided!", [], {}, []

    # Lazily initialize so callers do not have to.
    if not rag_manager:
        initialize_system()

    def _majority(values):
        """Most common truthy value in *values*, or None if there is none."""
        counts = Counter(v for v in values if v)
        return counts.most_common(1)[0][0] if counts else None

    chunker = DocumentChunker()
    all_chunks = []
    per_file_summaries: List[Dict[str, Any]] = []
    chunk_rows: List[Dict[str, Any]] = []
    processed_count = 0
    errors: List[str] = []
    total = len(files)

    for idx, file_path in enumerate(files, start=1):
        filename = os.path.basename(file_path)
        if progress:
            try:
                progress(idx, total=total, desc=f"Processing {idx}/{total}: {filename}")
            except Exception:
                pass  # progress reporting is best-effort only
        try:
            chunks = chunker.chunk_document(file_path, hierarchy, doc_type, language)
            if not chunks:
                errors.append(f"Warning: {filename} produced no chunks")
                continue
            # Chunk metadata may be None; treat it as empty everywhere.
            metas = [c.metadata or {} for c in chunks]
            file_lang = _majority(md.get('lang') for md in metas)
            file_doc_type = _majority(md.get('doc_type') for md in metas)
            file_hierarchy = _majority(md.get('hierarchy') for md in metas)
            per_file_summaries.append({
                'Filename': filename,
                'Chunks': len(chunks),
                'Language': file_lang,
                'Doc Type': file_doc_type,
                'Hierarchy': file_hierarchy,
            })
            for c, md in zip(chunks, metas):
                text = c.content or ''
                chunk_rows.append({
                    'Filename': filename,
                    'Level1': md.get('level1'),
                    'Level2': md.get('level2'),
                    'Level3': md.get('level3'),
                    'Doc Type': md.get('doc_type'),
                    'Language': md.get('lang'),
                    # Only mark the preview as truncated when it actually is.
                    'Preview': text[:160] + '...' if len(text) > 160 else text,
                })
            logger.info("Indexed %s: chunks=%d, lang=%s, doc_type=%s, hierarchy=%s",
                        filename, len(chunks), file_lang, file_doc_type, file_hierarchy)
            all_chunks.extend(chunks)
            processed_count += 1
        except Exception as e:
            # Record the failure and keep processing the remaining files.
            errors.append(f"{filename}: {str(e)}")

    stats: Dict[str, Any] = {"document_count": 0, "collection_name": current_collection}
    indexed_count = 0
    if all_chunks:
        try:
            vector_store = rag_manager.vector_store
            vector_store.add_documents(current_collection, all_chunks)
            indexed_count = len(all_chunks)
            stats = vector_store.get_collection_stats(current_collection)
        except Exception as ex:
            # Keep the per-file summaries/previews visible even if indexing fails.
            logger.exception("Indexing into collection %s failed", current_collection)
            errors.append(f"Indexing failed: {ex}")

    status_lines = [
        f"Processed {processed_count}/{total} files",
        f"Indexed chunks: {indexed_count}",
    ]
    if errors:
        status_lines.append("\nErrors/Warnings:\n" + "\n".join(f"- {e}" for e in errors))
    return "\n".join(status_lines), per_file_summaries, stats, chunk_rows
def build_rag_index(files: List[Any], hierarchy: str, doc_type: str, language: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
    """Build the RAG index from files uploaded through Gradio.

    Args:
        files: Uploaded files. Each is either a path string or an object whose
            ``.name`` holds the full temp-file path (e.g. /tmp/gradio/.../x.txt).
        hierarchy: Hierarchy choice; "Auto" or empty means auto-detect.
        doc_type: Document type; "Auto" or empty means auto-detect.
        language: Language code; "Auto" or empty means auto-detect.

    Returns:
        ``(status_text, per_file_df, chunks_df)``. On input errors both tables
        are None, so all three Gradio outputs are always supplied.
    """
    if not files:
        return "No files provided!", None, None

    file_paths = []
    for file in files:
        if isinstance(file, str):
            file_path = file
        elif getattr(file, 'name', None):
            file_path = file.name
        else:
            return "Error: Unable to get file path from uploaded file", None, None
        # Normalize the path to handle any double slashes.
        file_path = os.path.normpath(file_path)
        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}", None, None
        file_paths.append(file_path)

    def _auto_to_none(value):
        """Map empty or "Auto" (any case) to None so downstream auto-detects."""
        return None if not value or str(value).lower() == 'auto' else value

    status_text, file_summaries, stats, chunk_rows = upload_documents(
        file_paths,
        _auto_to_none(hierarchy),
        _auto_to_none(doc_type),
        _auto_to_none(language),
        progress=None,
    )

    # Keep the column headers stable even when nothing was produced.
    per_file_df = pd.DataFrame(file_summaries) if file_summaries else pd.DataFrame(columns=['Filename', 'Chunks', 'Language', 'Doc Type', 'Hierarchy'])
    chunks_df = pd.DataFrame(chunk_rows) if chunk_rows else pd.DataFrame(columns=['Filename', 'Level1', 'Level2', 'Level3', 'Doc Type', 'Language', 'Preview'])
    return status_text, per_file_df, chunks_df
def search_documents(query: str, k: int, level1: str, level2: str, level3: str, doc_type: str) -> Tuple[str, str, str, pd.DataFrame]:
    """Run one query through both Base-RAG and Hier-RAG and compare the results.

    Empty filter values mean "no filter". When no filter is supplied at all,
    filters are inferred from the query (_auto_detect_query_filters). All
    filters are then passed to ``rag_manager.compare_retrieval``.

    Returns:
        ``(base_answer, hier_answer, comparison_markdown, results_df)``. For a
        blank query or an uninitialized system:
          - the first element is a message,
          - the next two are empty strings,
          - the table is None.
    """
    if not query or not str(query).strip():
        return "Please enter a search query.", "", "", None
    if not rag_manager:
        return "System not initialized!", "", "", None

    level1 = level1 or None
    level2 = level2 or None
    level3 = level3 or None
    doc_type = doc_type or None

    # Auto-detect filters from the query only when none were provided.
    if not any((level1, level2, level3, doc_type)):
        inferred = _auto_detect_query_filters(query)
        level1 = inferred.get('level1')
        level2 = inferred.get('level2')
        level3 = inferred.get('level3')
        doc_type = inferred.get('doc_type')

    base_result, hier_result = rag_manager.compare_retrieval(query, k, level1, level2, level3, doc_type)

    def _rows(label, sources):
        """Table rows for one pipeline's sources, ranked from 1."""
        rows = []
        for rank, source in enumerate(sources, start=1):
            md = source.get('metadata') or {}
            text = source.get('content') or ''
            rows.append({
                'Pipeline': label,
                'Rank': rank,
                'Content': text[:100] + '...' if len(text) > 100 else text,
                'Domain': md.get('level1', 'N/A'),
                'Section': md.get('level2', 'N/A'),
                'Topic': md.get('level3', 'N/A'),
                'Score': f"{source['score']:.3f}",
            })
        return rows

    results_df = pd.DataFrame(_rows('Base-RAG', base_result.sources) + _rows('Hier-RAG', hier_result.sources))

    comparison_text = f"""
## Retrieval Comparison
### Base-RAG
- Latency: {base_result.latency:.3f}s
- Retrieved: {len(base_result.sources)} documents
### Hier-RAG
- Latency: {hier_result.latency:.3f}s
- Retrieved: {len(hier_result.sources)} documents
- Filters: level1={level1 or 'None'}, level2={level2 or 'None'}, level3={level3 or 'None'}, doc_type={doc_type or 'None'}
"""
    return base_result.content, hier_result.content, comparison_text, results_df
def search_documents_auto(query: str) -> Tuple[str, str, str, pd.DataFrame]:
    """Run a comparison search driven only by *query*.

    ``k`` comes from the DEFAULT_SEARCH_K environment variable (default 5).
    Every filter is passed as None, so search_documents infers them from the
    query.
    """
    k_default = int(os.getenv("DEFAULT_SEARCH_K", 5))
    return search_documents(
        query,
        k=k_default,
        level1=None,
        level2=None,
        level3=None,
        doc_type=None,
    )
def search_documents_unified(query: str, manual: bool, k: int = None,
                             level1: str = None, level2: str = None,
                             level3: str = None, doc_type: str = None) -> Tuple[str, str, str, pd.DataFrame]:
    """Single search entry point used by the Search tab.

    When ``manual`` is false, the controls are ignored and
    search_documents_auto runs.

    When ``manual`` is true:
      - ``k`` and the filters are used as given.
      - Filter values that are None, blank, or "Auto" (any case) become None,
        and surrounding whitespace is stripped from the others.
      - A missing or zero ``k`` falls back to DEFAULT_SEARCH_K (default 5).
    """
    if not manual:
        return search_documents_auto(query)

    def _norm(value):
        """Map None/blank/"Auto" to None; strip whitespace from other strings."""
        if isinstance(value, str):
            value = value.strip()
            return None if value == '' or value.lower() == 'auto' else value
        return value

    k = k or int(os.getenv("DEFAULT_SEARCH_K", 5))
    return search_documents(query, k, _norm(level1), _norm(level2), _norm(level3), _norm(doc_type))
def _toggle_manual_controls(manual: bool):
    """Enable or disable the manual search controls.

    Returns six Gradio updates, in this order:
      1. the accordion (opened when *manual*),
      2. the k slider,
      3-6. the level1/level2/level3/doc_type dropdowns.

    The dropdowns are reset to "Auto" on every toggle.
    """
    reset_to_auto = gr.update(value="Auto", interactive=manual)
    dropdown_updates = (reset_to_auto,) * 4
    return (gr.update(open=manual), gr.update(interactive=manual)) + dropdown_updates
def _llm_answer(user_message: str, contexts: List[Dict[str, Any]]) -> str:
    """Generate a natural-language answer grounded in retrieved contexts.

    Two strategies:
      1. OpenAI: used when OPENAI_API_KEY is set and the OpenAI client
         imported. The model is told to answer only from the supplied
         contexts.
      2. Extractive fallback: used when OpenAI is unavailable, fails (the
         failure is logged), or returns an empty reply. It lists up to four
         sentences taken from the first three contexts.

    Args:
        user_message: The user's question.
        contexts: Retrieved sources, read via ``.get('content')`` and
            ``.get('metadata')``. ``metadata['source_name']`` labels each
            snippet.

    Returns:
        The answer text; never empty.
    """
    contexts = contexts or []

    openai_cls = globals().get('_OpenAI')
    if os.getenv("OPENAI_API_KEY") and openai_cls is not None:
        try:
            # Build a compact, numbered context string (400 chars per snippet).
            ctx_blocks = []
            for i, c in enumerate(contexts, 1):
                md = c.get('metadata') or {}
                src = md.get('source_name', 'unknown')
                snippet = (c.get('content', '') or '')[:400]
                ctx_blocks.append(f"[{i}] ({src}) {snippet}")
            ctx_text = "\n\n".join(ctx_blocks)

            client = openai_cls()
            system_prompt = (
                "You are a helpful, professional assistant. Answer in a warm, natural, and concise tone. "
                "ALWAYS ground the answer ONLY in the provided contexts. If information is missing, say so. "
                "Style: Start with a clear 1-2 sentence answer. Then, if helpful, add 2-5 short bullet points with key facts. "
                "Avoid hedging, avoid citations inline, avoid repeating the question."
            )
            content = (
                f"User question:\n{user_message}\n\n"
                f"Contexts (each begins with [n]):\n{ctx_text}\n\n"
                "Write a natural, human-like response as specified."
            )
            resp = client.chat.completions.create(
                model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content},
                ],
                temperature=0.2,
            )
            answer = (resp.choices[0].message.content or "").strip()
            if answer:
                return answer
        except Exception:
            logger.exception("OpenAI answer generation failed; using extractive fallback.")

    # Fallback: naive sentence split over the top three snippets.
    snippets = [
        txt for txt in ((src.get('content') or '').strip() for src in contexts[:3]) if txt
    ]
    joined = " ".join(snippets)
    sentences = [s.strip() for s in joined.replace('\n', ' ').split('.') if s.strip()]
    if sentences:
        bullets_text = "\n".join(f"- {b}." for b in sentences[:4])
        return f"Here’s what the documents say:\n\n{bullets_text}"
    return "I don't have enough information to answer that yet."
def chat_with_rag(message: str, history: List[Dict[str, str]], pipeline: str, k: int) -> Tuple[str, List[Dict[str, str]]]:
    """Answer a chat message with one RAG pipeline and list the sources used.

    Args:
        message: The user's message.
        history: Chat history as role/content dicts (Gradio "messages"
            format). None is treated as empty.
        pipeline: "hier_rag" retrieves with filters inferred from the message;
            any other value uses plain Base-RAG retrieval.
        k: Number of chunks to retrieve.

    Returns:
        ``(textbox_value, updated_history)``:
          - A blank message leaves the history unchanged.
          - If the system is not initialized, the error text is returned as
            the textbox value.
    """
    history = list(history or [])
    if not message or not message.strip():
        return "", history
    if not rag_manager:
        return "System not initialized! Please build the RAG index first.", history
    k = int(k)  # slice indices and retrieval counts must be integers

    filters_note = ""
    if pipeline == "hier_rag":
        inferred = _auto_detect_query_filters(message)
        h_l1 = inferred.get('level1')
        h_l2 = inferred.get('level2')
        h_l3 = inferred.get('level3')
        h_dt = inferred.get('doc_type')
        result = rag_manager.hier_rag.retrieve(message, k, h_l1, h_l2, h_l3, h_dt)
        # Show filters only when some were detected.
        found = [
            f"{name}={value}"
            for name, value in (("level1", h_l1), ("level2", h_l2), ("level3", h_l3), ("doc_type", h_dt))
            if value
        ]
        if found:
            filters_note = " (filters: " + ", ".join(found) + ")"
        label = 'Hier-RAG'
    else:
        result = rag_manager.base_rag.retrieve(message, k)
        label = 'Base-RAG'

    answer = _llm_answer(message, result.sources)

    src_lines = [
        f"• {(s.get('metadata') or {}).get('source_name', 'unknown')}"
        for s in result.sources[:k]
    ]
    response = (
        f"{answer}\n\n"
        f"─────────────────────────────────────\n"
        f"📎 Sources ({label}{filters_note}):\n" + ("\n".join(src_lines) or "(none)")
    )
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return "", history
def run_evaluation(queries_json: str, output_filename: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
    """Run the batch evaluation and summarize Base-RAG vs Hier-RAG.

    Args:
        queries_json: JSON array of query objects (see the Evaluation tab
            placeholder for the expected keys).
        output_filename: Passed to ``evaluator.batch_evaluate`` as the results
            file.

    Returns:
        ``(status, per_query_df, plot_df)``:
          - ``plot_df`` has one row per k with columns k, base_rag_hit@k,
            hier_rag_hit@k, base_rag_mrr and hier_rag_mrr (NaN where a
            pipeline has no result for that k).
          - On invalid input or an uninitialized system, both tables are None.
    """
    # Validate input first: it is cheap and independent of system state.
    try:
        queries = json.loads(queries_json)
    except (json.JSONDecodeError, TypeError) as e:
        return f"Invalid JSON format: {str(e)}", None, None
    if not isinstance(queries, list):
        return "Invalid JSON format: expected a JSON array of query objects", None, None
    if not evaluator:
        return "System not initialized!", None, None

    df, _ = evaluator.batch_evaluate(queries, output_filename)

    # Mean metrics per (pipeline, k).
    summary = df.groupby(['pipeline', 'k']).agg({
        'hit_at_k': 'mean',
        'mrr': 'mean',
        'semantic_similarity': 'mean',
        'latency': 'mean',
        'retrieved_count': 'mean'
    }).reset_index()

    def _pipeline_metrics(pipeline):
        """hit@k and MRR per k for one pipeline, with pipeline-prefixed column names."""
        rows = summary.loc[summary['pipeline'] == pipeline, ['k', 'hit_at_k', 'mrr']]
        return rows.rename(columns={'hit_at_k': f'{pipeline}_hit@k', 'mrr': f'{pipeline}_mrr'})

    # Join on k. Building the frame from the filtered Series directly would
    # align them on their disjoint row indices and produce NaN-filled rows.
    plot_data = (
        _pipeline_metrics('base_rag')
        .merge(_pipeline_metrics('hier_rag'), on='k', how='outer')
        .sort_values('k')
        .reset_index(drop=True)
    )[['k', 'base_rag_hit@k', 'hier_rag_hit@k', 'base_rag_mrr', 'hier_rag_mrr']]

    return f"Evaluation completed! Processed {len(queries)} queries. Results saved to {output_filename}", df, plot_data
# Diagnostics: the OpenAI connectivity test helper (test_openai_connectivity) has been removed.
# Initialize the RAG system at import time.
# This gives the Gradio handlers a ready rag_manager/evaluator, including when
# a host (e.g. Spaces) imports `demo` instead of running this file.
initialize_system()
# Create Gradio interface
# Minimal CSS:
#   - reserve scrollbar space so the layout does not shift when the vertical
#     scrollbar appears;
#   - tighten side padding on narrow (<= 768px) screens.
APP_CSS = """
html, body { scrollbar-gutter: stable both-edges; }
body { overflow-y: scroll; }
* { box-sizing: border-box; }
@media (max-width: 768px) {
.gradio-container { padding-left: 8px; padding-right: 8px; }
}
"""
# Build the Gradio UI.
# Component creation order inside each `with` context determines the
# on-screen layout. Event wiring follows the tab definitions.
with gr.Blocks(title="RAG Evaluation System", css=APP_CSS) as demo:
    gr.Markdown("# RAG Evaluation System: Hierarchical vs Standard RAG")
    # Upload tab: pick files plus optional hierarchy / doc type / language overrides.
    with gr.Tab("Upload Documents"):
        gr.Markdown("## Upload and Process Documents")
        with gr.Row():
            with gr.Column():
                file_upload = gr.File(
                    label="Upload PDF/TXT Files",
                    file_count="multiple",
                    file_types=[".pdf", ".txt"]
                )
                # build_rag_index converts "Auto" to None so the value is auto-detected.
                hierarchy_dropdown = gr.Dropdown(
                    choices=["Auto", "hospital", "bank", "fluid_simulation"],
                    label="Select Hierarchy",
                    value="Auto"
                )
                doc_type_dropdown = gr.Dropdown(
                    choices=["Auto", "Policy", "Manual", "FAQ", "Report", "Note", "Guideline"],
                    label="Document Type",
                    value="Auto"
                )
                language_dropdown = gr.Dropdown(
                    choices=["Auto", "en", "ja"],
                    label="Language",
                    value="Auto"
                )
                build_btn = gr.Button("Build RAG Index")
            with gr.Column():
                build_output = gr.Textbox(label="Build Status", lines=4)
                stats_table = gr.DataFrame(label="File Summary")
                chunks_table = gr.DataFrame(label="Indexed Chunks (preview)")
                reset_btn = gr.Button("Reset Index (Clear chroma_data)", variant="secondary")
    # Search tab: runs Base-RAG and Hier-RAG side by side.
    with gr.Tab("Search"):
        gr.Markdown("## Compare Retrieval Pipelines")
        with gr.Row():
            with gr.Column():
                # When unchecked, search_documents_unified ignores k and the filters below.
                manual_toggle = gr.Checkbox(label="Manual controls", value=False)
                search_query = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your query...",
                    lines=2
                )
                # Single Search button
                search_btn = gr.Button("Search", variant="primary")
                # Manual controls start disabled; _toggle_manual_controls enables them.
                with gr.Accordion("Manual (optional)", open=False) as manual_section:
                    k_slider = gr.Slider(minimum=1, maximum=20, value=int(os.getenv("DEFAULT_SEARCH_K", 5)), step=1, label="Number of results (k)", interactive=False)
                    with gr.Row():
                        level1_filter = gr.Dropdown(label="Domain (Level1)", choices=["Auto"], value="Auto", allow_custom_value=True, interactive=False)
                        level2_filter = gr.Dropdown(label="Section (Level2)", choices=["Auto"], value="Auto", allow_custom_value=True, interactive=False)
                    with gr.Row():
                        level3_filter = gr.Dropdown(label="Topic (Level3)", choices=["Auto"], value="Auto", allow_custom_value=True, interactive=False)
                        doc_type_filter = gr.Dropdown(label="Document Type", choices=["Auto", "Policy", "Manual", "FAQ", "Report", "Note", "Guideline"], value="Auto", allow_custom_value=True, interactive=False)
            with gr.Column():
                base_results = gr.Textbox(
                    label="Base-RAG Results",
                    lines=8,
                    max_lines=12
                )
                hier_results = gr.Textbox(
                    label="Hier-RAG Results",
                    lines=8,
                    max_lines=12
                )
                comparison_text = gr.Markdown()
                results_table = gr.DataFrame(label="Detailed Results")
    # Chat tab: single-pipeline chat; history uses the role/content "messages" format.
    with gr.Tab("Chat"):
        gr.Markdown("## Chat with RAG System")
        with gr.Row():
            with gr.Column(scale=1):
                pipeline_radio = gr.Radio(
                    choices=["base_rag", "hier_rag"],
                    label="RAG Pipeline",
                    value="hier_rag"
                )
                chat_k_slider = gr.Slider(
                    minimum=1, maximum=10, value=3, step=1,
                    label="Number of results"
                )
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="RAG Chat", type="messages")
                chat_input = gr.Textbox(
                    label="Message",
                    placeholder="Ask a question...",
                    lines=2
                )
                chat_btn = gr.Button("Send")
    # Evaluation tab: the JSON array of queries is parsed by run_evaluation.
    with gr.Tab("Evaluation"):
        gr.Markdown("## Quantitative Evaluation")
        with gr.Row():
            with gr.Column():
                eval_queries = gr.Textbox(
                    label="Evaluation Queries (JSON)",
                    lines=10,
                    placeholder='''[
{
"query": "question 1",
"ground_truth": ["expected answer 1", "expected answer 2"],
"k_values": [1, 3, 5],
"level1": "Clinical",
"level2": "Emergency",
"level3": "Triage",
"doc_type": "Report"
}
]''',
                    value='''[
{
"query": "What are the emergency procedures?",
"ground_truth": ["Emergency protocols for triage", "Patient assessment guidelines"],
"k_values": [1, 3, 5]
}
]'''
                )
                eval_output_name = gr.Textbox(
                    label="Output Filename",
                    value="evaluation_results.csv"
                )
                eval_btn = gr.Button("Run Evaluation", variant="primary")
                # Diagnostics removed
            with gr.Column():
                eval_output = gr.Textbox(label="Evaluation Status", lines=3)
                eval_results_table = gr.DataFrame(label="Evaluation Results")
                # Plots the hit@k columns of the plot frame returned by run_evaluation.
                eval_plot = gr.LinePlot(
                    label="Performance Comparison",
                    x="k",
                    y=["base_rag_hit@k", "hier_rag_hit@k"],
                    title="Hit@k Comparison",
                    width=600,
                    height=400
                )
    # Settings tab: choose the embedding backend used by the vector store.
    with gr.Tab("Settings"):
        gr.Markdown("## Settings")
        gr.Markdown("Configure embedding models and system preferences.")
        with gr.Accordion("Embedding Configuration", open=True):
            gr.Markdown("**Select the embedding provider and model.** Switching providers requires re-indexing your documents.")
            with gr.Row():
                with gr.Column():
                    emb_provider = gr.Radio(
                        choices=["SentenceTransformers", "OpenAI"],
                        value="SentenceTransformers",
                        label="Embeddings Provider",
                        info="Choose between local SentenceTransformers models or OpenAI embeddings (requires API key)"
                    )
            with gr.Row():
                apply_embed_btn = gr.Button("Apply Embedding Settings", variant="primary")
            with gr.Row():
                with gr.Column():
                    # Model names come from env vars and are read-only in the UI.
                    st_model_in = gr.Textbox(
                        label="SentenceTransformers Model",
                        value=os.getenv("ST_EMBED_MODEL", "all-MiniLM-L6-v2"),
                        interactive=False,
                        info="Local embedding model (384 dimensions)"
                    )
                with gr.Column():
                    oai_model_in = gr.Textbox(
                        label="OpenAI Embedding Model",
                        value=os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small"),
                        interactive=False,
                        info="OpenAI embedding model (1536 dimensions for small, 3072 for large)"
                    )
            embed_status = gr.Textbox(
                label="Status",
                lines=3,
                interactive=False,
                placeholder="Embedding configuration status will appear here..."
            )
        # Define handler before wiring it
        def _apply_embeddings(provider, st_model, oai_model):
            """Reconfigure the vector store's embeddings from the Settings inputs.

            Returns a status string. Errors (including an uninitialized
            rag_manager) are reported in that string, not raised.
            """
            try:
                use_oai = (provider == "OpenAI")
                rag_manager.vector_store.configure_embeddings(use_oai, openai_model=oai_model, st_model_name=st_model)
                status_msg = f"✅ Embeddings successfully configured!\n\n"
                status_msg += f"Provider: {provider}\n"
                # NOTE(review): dimensions shown are hard-coded, not queried from the model.
                if use_oai:
                    status_msg += f"Model: {oai_model} (OpenAI)\n"
                    status_msg += f"Dimensions: {3072 if 'large' in oai_model.lower() else 1536}\n"
                else:
                    status_msg += f"Model: {st_model} (SentenceTransformers)\n"
                    status_msg += f"Dimensions: ~384\n"
                status_msg += f"\n⚠️ Note: If switching providers, reset and rebuild your index in the Upload tab."
                return status_msg
            except Exception as ex:
                return f"❌ Failed to set embeddings: {ex}\n\nPlease check your configuration and try again."
        apply_embed_btn.click(
            fn=_apply_embeddings,
            inputs=[emb_provider, st_model_in, oai_model_in],
            outputs=embed_status
        )
    # Event handlers
    build_btn.click(
        fn=build_rag_index,
        inputs=[file_upload, hierarchy_dropdown, doc_type_dropdown, language_dropdown],
        outputs=[build_output, stats_table, chunks_table],
        api_name="build_rag"
    )
    reset_btn.click(
        fn=reset_index,
        inputs=None,
        outputs=build_output
    )
    # One button: uses manual if checked; otherwise auto
    search_btn.click(
        fn=search_documents_unified,
        inputs=[search_query, manual_toggle, k_slider, level1_filter, level2_filter, level3_filter, doc_type_filter],
        outputs=[base_results, hier_results, comparison_text, results_table],
        api_name="search"
    )
    # Toggle manual controls interactive and accordion state
    manual_toggle.change(
        fn=_toggle_manual_controls,
        inputs=[manual_toggle],
        outputs=[manual_section, k_slider, level1_filter, level2_filter, level3_filter, doc_type_filter]
    )
    # After each chat turn, the follow-up step sets the message box to None (clears it).
    chat_btn.click(
        fn=chat_with_rag,
        inputs=[chat_input, chatbot, pipeline_radio, chat_k_slider],
        outputs=[chat_input, chatbot],
        api_name="chat"
    ).then(
        lambda: None,
        None,
        chat_input,
        queue=False
    )
    eval_btn.click(
        fn=run_evaluation,
        inputs=[eval_queries, eval_output_name],
        outputs=[eval_output, eval_results_table, eval_plot],
        api_name="evaluate"
    )
    # Diagnostics trigger removed
# MCP Server Implementation
import asyncio
import sys
from typing import Any, List, Optional
# Optional MCP support.
# Every imported name is bound to None when the `mcp` package (or any of its
# submodules) cannot be imported, so the module still loads and callers can
# test MCP_AVAILABLE.
try:
    from mcp.server import Server
    from mcp.server.models import InitializationOptions
    from mcp.types import Tool, TextContent
    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False
    # Fallback for when MCP is not installed
    Server = None
    InitializationOptions = None
    Tool = None
    TextContent = None
class RAGMCPServer:
    """MCP server exposing RAG search and batch evaluation as tools.

    It builds its own RAGManager/RAGEvaluator on ``/data/chroma`` when that
    directory exists, otherwise on ``./chroma_data``.

    Return annotations are quoted so the class can be defined even when the
    optional ``mcp`` package (and therefore ``Tool``/``TextContent``) is
    unavailable.
    """

    def __init__(self):
        persist_dir = "/data/chroma" if os.path.exists("/data/chroma") else "./chroma_data"
        self.rag_manager = RAGManager(persist_directory=persist_dir)
        self.evaluator = RAGEvaluator(self.rag_manager)

    async def list_tools(self) -> "List[Tool]":
        """List the available MCP tools and their JSON input schemas."""
        return [
            Tool(
                name="search_documents",
                description="Search documents using RAG system (Base-RAG or Hier-RAG)",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "Search query"},
                        "k": {"type": "integer", "description": "Number of results", "default": 5},
                        "pipeline": {"type": "string", "enum": ["base_rag", "hier_rag"], "default": "base_rag"},
                        "level1": {"type": "string", "description": "Level1 filter (domain)"},
                        "level2": {"type": "string", "description": "Level2 filter (section)"},
                        "level3": {"type": "string", "description": "Level3 filter (topic)"},
                        "doc_type": {"type": "string", "description": "Document type filter"}
                    },
                    "required": ["query"]
                }
            ),
            Tool(
                name="evaluate_retrieval",
                description="Evaluate RAG performance with batch queries",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "queries": {
                            "type": "array",
                            "description": "List of query objects with query, ground_truth, k_values, and optional filters",
                            "items": {"type": "object"}
                        },
                        "output_file": {"type": "string", "description": "Output filename for results"}
                    },
                    "required": ["queries"]
                }
            )
        ]

    async def call_tool(self, name: str, arguments: dict) -> "List[TextContent]":
        """Run the MCP tool *name* with *arguments* and return its JSON result.

        Raises:
            ValueError: if *name* is not a known tool, or if
                ``search_documents`` is called without a non-blank ``query``
                string.
        """
        arguments = arguments or {}
        if name == "search_documents":
            return self._search_documents(arguments)
        if name == "evaluate_retrieval":
            return self._evaluate_retrieval(arguments)
        raise ValueError(f"Unknown tool: {name}")

    def _search_documents(self, arguments: dict) -> "List[TextContent]":
        """Retrieve with the requested pipeline; filters apply only to hier_rag."""
        query = arguments.get("query")
        if not isinstance(query, str) or not query.strip():
            raise ValueError("search_documents requires a non-empty 'query' string")
        # MCP clients may send k as a string or null; normalize it to an int.
        k = int(arguments.get("k") or 5)
        pipeline = arguments.get("pipeline", "base_rag")
        if pipeline == "base_rag":
            result = self.rag_manager.base_rag.retrieve(query, k)
        else:
            result = self.rag_manager.hier_rag.retrieve(
                query, k,
                arguments.get("level1"), arguments.get("level2"),
                arguments.get("level3"), arguments.get("doc_type"),
            )
        response = {
            "content": result.content,
            "sources": [
                {
                    "content": source['content'][:200],
                    "metadata": source['metadata'],
                    # Cast so numpy scalar scores (if any) stay JSON-serializable.
                    "score": None if source['score'] is None else float(source['score']),
                }
                for source in result.sources
            ],
            "latency": result.latency,
            "strategy": pipeline,
        }
        return [TextContent(type="text", text=json.dumps(response, indent=2))]

    def _evaluate_retrieval(self, arguments: dict) -> "List[TextContent]":
        """Run batch evaluation and return per-pipeline mean metrics."""
        queries = arguments.get("queries", [])
        output_file = arguments.get("output_file")
        df, _ = self.evaluator.batch_evaluate(queries, output_file)
        summary = df.groupby('pipeline').agg({
            'hit_at_k': 'mean',
            'mrr': 'mean',
            'semantic_similarity': 'mean',
            'latency': 'mean'
        }).reset_index()
        response = {
            "summary": summary.to_dict('records'),
            "total_queries": len(queries),
            "output_file": output_file
        }
        return [TextContent(type="text", text=json.dumps(response, indent=2))]
# Script entry point (e.g. `python app.py`).
# Hosted environments such as Spaces import `demo` directly and never reach this.
if __name__ == "__main__":
    # Respect common hosting env vars: HOST, then PORT or GRADIO_SERVER_PORT.
    launch_kwargs = dict(
        server_name=os.getenv("HOST", "0.0.0.0"),
        server_port=int(os.getenv("PORT", os.getenv("GRADIO_SERVER_PORT", 7860))),
        share=False,
        ssl_verify=False,
        # Avoid SSR and API schema on Spaces to prevent response length errors
        ssr_mode=False,
    )
    demo.launch(**launch_kwargs)