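"""Ragmint MCP Server.

A FastAPI app exposing ragmint's RAG-pipeline tooling over HTTP:
/optimize_rag runs a configuration search, /autotune_rag lets the
autotuner pick the search space first, /generate_validation_qa builds
a QA validation set from a docs folder, and /health reports whether
the ragmint imports succeeded.
"""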
from __future__ import annotations

import os
import json
import logging
import time

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from models import OptimizeRequest, QARequest, AutotuneRequest

# Import ragmint defensively: if anything fails, keep the server up and
# surface the error through /health instead of crashing at startup.
try:
    from ragmint.autotuner import AutoRAGTuner
    from ragmint.qa_generator import generate_validation_qa
    from ragmint.explainer import explain_results
    from ragmint.leaderboard import Leaderboard
    from ragmint.tuner import RAGMint
except Exception as e:
    AutoRAGTuner = None
    generate_validation_qa = None
    explain_results = None
    Leaderboard = None
    RAGMint = None
    _import_error = e
else:
    _import_error = None

load_dotenv()

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ragmint_mcp_server")

# FastAPI
app = FastAPI(title="Ragmint MCP Server", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
DEFAULT_DATA_DIR = "../data/docs"
LEADERBOARD_STORAGE = "../experiments/leaderboard.jsonl"
# Ensure the leaderboard directory exists, derived from the storage path so
# the two can never drift apart.
os.makedirs(os.path.dirname(LEADERBOARD_STORAGE), exist_ok=True)


@app.get("/health")
def health():
    return {
        "status": "ok",
        "ragmint_imported": _import_error is None,
        "import_error": str(_import_error) if _import_error else None,
    }
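
# For reference, a healthy instance responds with something like:
#   {"status": "ok", "ragmint_imported": true, "import_error": null}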


@app.post("/optimize_rag")
def optimize_rag(req: OptimizeRequest):
    logger.info("Received optimize_rag request: %s", req.json())
    if RAGMint is None:
        raise HTTPException(
            status_code=500,
            detail=f"Ragmint imports failed or RAGMint unavailable: {_import_error}",
        )
    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
    try:
        # Build RAGMint exactly from the request
        rag = RAGMint(
            docs_path=docs_path,
            retrievers=req.retriever,
            embeddings=req.embedding_model,
            rerankers=(req.rerankers or ["mmr"]),
            chunk_sizes=req.chunk_sizes,
            overlaps=req.overlaps,
            strategies=req.strategy,
        )

        # Validation selection
        validation_set = None
        validation_choice = (req.validation_choice or "").strip()
        default_val_path = os.path.join(docs_path, "validation_qa.json")

        # Auto: fall back to the default file next to the docs, if present
        if not validation_choice:
            if os.path.exists(default_val_path):
                validation_set = default_val_path
                logger.info("Using default validation set: %s", validation_set)
            else:
                logger.warning("No validation_choice provided and no default found.")
                validation_set = None
        # Remote HF dataset (looks like "org/name" and is not a local path)
        elif "/" in validation_choice and not os.path.exists(validation_choice):
            validation_set = validation_choice
            logger.info("Using Hugging Face validation dataset: %s", validation_set)
        # Local file
        elif os.path.exists(validation_choice):
            validation_set = validation_choice
            logger.info("Using local validation dataset: %s", validation_set)
        # Generate a fresh QA set
        elif validation_choice.lower() == "generate":
            try:
                gen_path = os.path.join(docs_path, "validation_qa.json")
                generate_validation_qa(
                    docs_path=docs_path,
                    output_path=gen_path,
                    llm_model=getattr(req, "llm_model", "gemini-2.5-flash-lite"),
                )
                validation_set = gen_path
                logger.info("Generated new validation QA set at: %s", validation_set)
            except Exception as e:
                logger.exception("Failed to generate validation QA dataset: %s", e)
                raise HTTPException(status_code=500, detail=f"Failed to generate validation QA dataset: {e}")
        else:
            # Nothing matched: run without a validation set rather than failing silently.
            logger.warning("Unrecognized validation_choice %r; optimizing without a validation set.", validation_choice)

        # Optimize
        start_time = time.time()
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            trials=req.trials,
            search_type=req.search_type,
        )
        elapsed = time.time() - start_time
        run_id = f"opt_{int(time.time())}"

        # Corpus stats (best effort; skipped if the corpus is not exposed)
        try:
            corpus_stats = {
                "num_docs": len(rag.documents),
                "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
                "corpus_size": sum(len(d) for d in rag.documents),
            }
        except Exception:
            corpus_stats = None

        # Leaderboard (failures are logged but never fail the request)
        try:
            if Leaderboard:
                lb = Leaderboard()
                lb.upload(
                    run_id=run_id,
                    best_config=best,
                    best_score=best.get("faithfulness", best.get("score", 0.0)),
                    all_results=results,
                    documents=os.listdir(docs_path),
                    model=best.get("embedding_model", req.embedding_model),
                    corpus_stats=corpus_stats,
                )
        except Exception:
            logger.exception("Leaderboard persistence failed for optimize_rag")

        return {
            "status": "finished",
            "run_id": run_id,
            "elapsed_seconds": elapsed,
            "best_config": best,
            "results": results,
            "corpus_stats": corpus_stats,
        }
    except HTTPException:
        # Re-raise intact so inner HTTP errors keep their status and detail.
        raise
    except Exception as exc:
        logger.exception("optimize_rag failed")
        raise HTTPException(status_code=500, detail=str(exc))
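
# A minimal client sketch (field names follow how the handler reads
# OptimizeRequest above; values are illustrative, not defaults):
#
#   curl -X POST http://localhost:8000/optimize_rag \
#        -H "Content-Type: application/json" \
#        -d '{"docs_path": "../data/docs", "validation_choice": "generate",
#             "metric": "faithfulness", "trials": 10}'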


@app.post("/autotune_rag")
def autotune_rag(req: AutotuneRequest):
    logger.info("Received autotune_rag request: %s", req.json())
    if AutoRAGTuner is None or RAGMint is None:
        raise HTTPException(
            status_code=500,
            detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}",
        )
    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
    try:
        start_time = time.time()

        # Ask the autotuner for a recommended retriever/embedding/strategy,
        # then derive chunk-size and overlap candidates from that model.
        tuner = AutoRAGTuner(docs_path=docs_path)
        rec = tuner.recommend(
            embedding_model=req.embedding_model,
            num_chunk_pairs=req.num_chunk_pairs,
        )
        chunk_candidates = tuner.suggest_chunk_sizes(
            model_name=rec.get("embedding_model"),
            num_pairs=int(req.num_chunk_pairs),
            step=20,
        )
        chunk_sizes = sorted({c for c, _ in chunk_candidates})
        overlaps = sorted({o for _, o in chunk_candidates})

        rag = RAGMint(
            docs_path=docs_path,
            retrievers=[rec["retriever"]],
            embeddings=[rec["embedding_model"]],
            rerankers=["mmr"],
            chunk_sizes=chunk_sizes,
            overlaps=overlaps,
            strategies=[rec["strategy"]],
        )

        # Validation selection (same precedence as /optimize_rag)
        validation_set = None
        validation_choice = (req.validation_choice or "").strip()
        # Look for the same file name the "generate" branch writes.
        default_val_path = os.path.join(docs_path, "validation_qa.json")
        if not validation_choice:
            if os.path.exists(default_val_path):
                validation_set = default_val_path
                logger.info("Using default validation set: %s", validation_set)
            else:
                logger.warning("No validation_choice provided and no default found.")
                validation_set = None
        elif "/" in validation_choice and not os.path.exists(validation_choice):
            validation_set = validation_choice
        elif os.path.exists(validation_choice):
            validation_set = validation_choice
        elif validation_choice.lower() == "generate":
            try:
                gen_path = os.path.join(docs_path, "validation_qa.json")
                generate_validation_qa(
                    docs_path=docs_path,
                    output_path=gen_path,
                    llm_model=getattr(req, "llm_model", "gemini-2.5-flash-lite"),
                )
                validation_set = gen_path
            except Exception as e:
                logger.exception("Failed to generate validation QA dataset: %s", e)
                raise HTTPException(status_code=500, detail=f"Failed to generate validation QA dataset: {e}")
        else:
            logger.warning("Unrecognized validation_choice %r; optimizing without a validation set.", validation_choice)

        # Full optimize over the recommended search space
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            search_type=req.search_type,
            trials=req.trials,
        )
        elapsed = time.time() - start_time
        run_id = f"autotune_{int(time.time())}"

        # Corpus stats (best effort)
        try:
            corpus_stats = {
                "num_docs": len(rag.documents),
                "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
                "corpus_size": sum(len(d) for d in rag.documents),
            }
        except Exception:
            corpus_stats = None

        # Leaderboard (failures are logged but never fail the request)
        try:
            if Leaderboard:
                lb = Leaderboard()
                lb.upload(
                    run_id=run_id,
                    best_config=best,
                    best_score=best.get("faithfulness", best.get("score", 0.0)),
                    all_results=results,
                    documents=os.listdir(docs_path),
                    model=best.get("embedding_model", rec.get("embedding_model")),
                    corpus_stats=corpus_stats,
                )
        except Exception:
            logger.exception("Leaderboard persistence failed for autotune_rag")

        return {
            "status": "finished",
            "run_id": run_id,
            "elapsed_seconds": elapsed,
            "recommendation": rec,
            "chunk_candidates": chunk_candidates,
            "best_config": best,
            "results": results,
            "corpus_stats": corpus_stats,
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("autotune_rag failed")
        raise HTTPException(status_code=500, detail=str(exc))
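
# A minimal client sketch (field names follow how the handler reads
# AutotuneRequest above; values are illustrative):
#
#   curl -X POST http://localhost:8000/autotune_rag \
#        -H "Content-Type: application/json" \
#        -d '{"docs_path": "../data/docs", "num_chunk_pairs": 5,
#             "metric": "faithfulness", "trials": 10}'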


@app.post("/generate_validation_qa")
def generate_qa(req: QARequest):
    logger.info("Received generate_validation_qa request: %s", req.json())
    if generate_validation_qa is None:
        raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
    try:
        out_path = "data/docs/validation_qa.json"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        generate_validation_qa(
            docs_path=req.docs_path,
            output_path=out_path,
            llm_model=req.llm_model,
            batch_size=req.batch_size,
            min_q=req.min_q,
            max_q=req.max_q,
        )
        # Read the generated file back to return a small preview to the caller.
        with open(out_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return {
            "status": "finished",
            "output_path": out_path,
            "preview_count": len(data),
            "sample": data[:5],
        }
    except Exception as exc:
        logger.exception("generate_validation_qa failed")
        raise HTTPException(status_code=500, detail=str(exc))
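
# A minimal client sketch (field names follow how the handler reads QARequest
# above; values are illustrative):
#
#   curl -X POST http://localhost:8000/generate_validation_qa \
#        -H "Content-Type: application/json" \
#        -d '{"docs_path": "../data/docs", "llm_model": "gemini-2.5-flash-lite",
#             "batch_size": 8, "min_q": 1, "max_q": 3}'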


# -----------------------
# FastAPI launch
# -----------------------
def main():
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")


if __name__ == "__main__":
    main()
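
# Running this file directly starts uvicorn on 0.0.0.0:8000. Assuming the
# module is saved as server.py, the equivalent CLI form would be:
#   uvicorn server:app --host 0.0.0.0 --port 8000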