from __future__ import annotations

import json
import logging
import os
import time

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from models import OptimizeRequest, QARequest, AutotuneRequest

try:
    from ragmint.autotuner import AutoRAGTuner
    from ragmint.explainer import explain_results
    from ragmint.leaderboard import Leaderboard
    from ragmint.qa_generator import generate_validation_qa
    from ragmint.tuner import RAGMint
except Exception as e:  # keep the server importable even when ragmint is missing
    AutoRAGTuner = None
    generate_validation_qa = None
    explain_results = None
    Leaderboard = None
    RAGMint = None
    _import_error = e
else:
    _import_error = None

load_dotenv()

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ragmint_mcp_server")

# FastAPI
app = FastAPI(title="Ragmint MCP Server", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Paths are relative to the working directory the server is launched from.
DEFAULT_DATA_DIR = "../data/docs"
LEADERBOARD_STORAGE = "../experiments/leaderboard.jsonl"
os.makedirs(os.path.dirname(LEADERBOARD_STORAGE), exist_ok=True)


def _select_validation_set(validation_choice: str | None, docs_path: str, llm_model: str) -> str | None:
    """Resolve the validation dataset for an optimization run.

    Resolution order: default file in docs_path, Hugging Face dataset id,
    local file, or on-the-fly generation when the choice is "generate".
    """
    choice = (validation_choice or "").strip()
    default_val_path = os.path.join(docs_path, "validation_qa.json")

    # Auto: fall back to the default file if present
    if not choice:
        if os.path.exists(default_val_path):
            logger.info("Using default validation set: %s", default_val_path)
            return default_val_path
        logger.warning("No validation_choice provided and no default found.")
        return None

    # Remote HF dataset (looks like "org/name" and is not a local path)
    if "/" in choice and not os.path.exists(choice):
        logger.info("Using Hugging Face validation dataset: %s", choice)
        return choice

    # Local file
    if os.path.exists(choice):
        logger.info("Using local validation dataset: %s", choice)
        return choice

    # Generate a fresh QA set
    if choice.lower() == "generate":
        if generate_validation_qa is None:
            raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
        gen_path = os.path.join(docs_path, "validation_qa.json")
        try:
            generate_validation_qa(docs_path=docs_path, output_path=gen_path, llm_model=llm_model)
        except Exception as e:
            logger.exception("Failed to generate validation QA dataset: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to generate validation QA dataset: {e}")
        logger.info("Generated new validation QA set at: %s", gen_path)
        return gen_path

    return None


def _corpus_stats(rag) -> dict | None:
    """Summarize the loaded corpus; returns None if the documents are unavailable."""
    try:
        docs = rag.documents
        return {
            "num_docs": len(docs),
            "avg_len": sum(len(d.split()) for d in docs) / max(1, len(docs)),
            "corpus_size": sum(len(d) for d in docs),
        }
    except Exception:
        return None


@app.get("/health")
def health():
    return {
        "status": "ok",
        "ragmint_imported": _import_error is None,
        "import_error": str(_import_error) if _import_error else None,
    }


@app.post("/optimize_rag")
def optimize_rag(req: OptimizeRequest):
    logger.info("Received optimize_rag request: %s", req.json())
    if RAGMint is None:
        raise HTTPException(
            status_code=500,
            detail=f"Ragmint imports failed or RAGMint unavailable: {_import_error}",
        )

    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")

    try:
        # Build RAGMint exactly from the request
        rag = RAGMint(
            docs_path=docs_path,
            retrievers=req.retriever,
            embeddings=req.embedding_model,
            rerankers=(req.rerankers or ["mmr"]),
            chunk_sizes=req.chunk_sizes,
            overlaps=req.overlaps,
            strategies=req.strategy,
        )

        # Validation selection
        validation_set = _select_validation_set(
            req.validation_choice,
            docs_path,
            getattr(req, "llm_model", "gemini-2.5-flash-lite"),
        )

        # Optimize
        start_time = time.time()
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            trials=req.trials,
            search_type=req.search_type,
        )
        elapsed = time.time() - start_time
        run_id = f"opt_{int(time.time())}"
        # Corpus stats
        corpus_stats = _corpus_stats(rag)

        # Leaderboard
        try:
            if Leaderboard:
                lb = Leaderboard()
                lb.upload(
                    run_id=run_id,
                    best_config=best,
                    best_score=best.get("faithfulness", best.get("score", 0.0)),
                    all_results=results,
                    documents=os.listdir(docs_path),
                    model=best.get("embedding_model", req.embedding_model),
                    corpus_stats=corpus_stats,
                )
        except Exception:
            logger.exception("Leaderboard persistence failed for optimize_rag")

        return {
            "status": "finished",
            "run_id": run_id,
            "elapsed_seconds": elapsed,
            "best_config": best,
            "results": results,
            "corpus_stats": corpus_stats,
        }
    except HTTPException:
        raise  # preserve status codes raised inside the handler
    except Exception as exc:
        logger.exception("optimize_rag failed")
        raise HTTPException(status_code=500, detail=str(exc))


@app.post("/autotune_rag")
def autotune_rag(req: AutotuneRequest):
    logger.info("Received autotune_rag request: %s", req.json())
    if AutoRAGTuner is None or RAGMint is None:
        raise HTTPException(
            status_code=500,
            detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}",
        )

    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")

    try:
        start_time = time.time()

        tuner = AutoRAGTuner(docs_path=docs_path)
        rec = tuner.recommend(
            embedding_model=req.embedding_model,
            num_chunk_pairs=req.num_chunk_pairs,
        )
        chunk_candidates = tuner.suggest_chunk_sizes(
            model_name=rec.get("embedding_model"),
            num_pairs=int(req.num_chunk_pairs),
            step=20,
        )
        chunk_sizes = sorted({c for c, _ in chunk_candidates})
        overlaps = sorted({o for _, o in chunk_candidates})

        rag = RAGMint(
            docs_path=docs_path,
            retrievers=[rec["retriever"]],
            embeddings=[rec["embedding_model"]],
            rerankers=["mmr"],
            chunk_sizes=chunk_sizes,
            overlaps=overlaps,
            strategies=[rec["strategy"]],
        )

        # Validation selection
        validation_set = _select_validation_set(
            req.validation_choice,
            docs_path,
            getattr(req, "llm_model", "gemini-2.5-flash-lite"),
        )

        # Full optimize
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            search_type=req.search_type,
            trials=req.trials,
        )
        elapsed = time.time() - start_time
        run_id = f"autotune_{int(time.time())}"

        # Corpus stats
        corpus_stats = _corpus_stats(rag)

        # Leaderboard
        try:
            if Leaderboard:
                lb = Leaderboard()
                lb.upload(
                    run_id=run_id,
                    best_config=best,
                    best_score=best.get("faithfulness", best.get("score", 0.0)),
                    all_results=results,
                    documents=os.listdir(docs_path),
                    model=best.get("embedding_model", rec.get("embedding_model")),
                    corpus_stats=corpus_stats,
                )
        except Exception:
            logger.exception("Leaderboard persistence failed for autotune_rag")

        return {
            "status": "finished",
            "run_id": run_id,
            "elapsed_seconds": elapsed,
            "recommendation": rec,
            "chunk_candidates": chunk_candidates,
            "best_config": best,
            "results": results,
            "corpus_stats": corpus_stats,
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("autotune_rag failed")
        raise HTTPException(status_code=500, detail=str(exc))


@app.post("/generate_validation_qa")
def generate_qa(req: QARequest):
    logger.info("Received generate_validation_qa request: %s", req.json())
    if generate_validation_qa is None:
        raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")

    try:
        out_path = "data/docs/validation_qa.json"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        generate_validation_qa(
            docs_path=req.docs_path,
            output_path=out_path,
            llm_model=req.llm_model,
            batch_size=req.batch_size,
            min_q=req.min_q,
            max_q=req.max_q,
        )
        with open(out_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return {
            "status": "finished",
            "output_path": out_path,
            "preview_count": len(data),
            "sample": data[:5],
        }
    except Exception as exc:
        logger.exception("generate_validation_qa failed")
        raise HTTPException(status_code=500, detail=str(exc))


# -----------------------
# FastAPI launch
# -----------------------
def main():
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")


if __name__ == "__main__":
    main()
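
# Illustrative client sketch (not part of the server): assuming the server is
# running locally on port 8000, a minimal optimize_rag call might look like
# the snippet below. The payload keys mirror the attributes read from
# OptimizeRequest above; the concrete values are hypothetical and the exact
# schema lives in models.py.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/optimize_rag",
#       json={
#           "docs_path": "../data/docs",
#           "retriever": ["faiss"],                    # hypothetical retriever name
#           "embedding_model": ["all-MiniLM-L6-v2"],   # hypothetical model id
#           "rerankers": ["mmr"],
#           "chunk_sizes": [256, 512],
#           "overlaps": [32],
#           "strategy": ["fixed"],                     # hypothetical strategy name
#           "validation_choice": "generate",
#           "metric": "faithfulness",
#           "trials": 10,
#           "search_type": "random",
#       },
#   )
#   print(resp.json()["best_config"])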