# api.py from __future__ import annotations import os import json import logging import time import shutil from models import OptimizeRequest, QARequest, AutotuneRequest from fastapi import FastAPI, HTTPException, UploadFile, File, Form from fastapi.middleware.cors import CORSMiddleware try: from ragmint.autotuner import AutoRAGTuner from ragmint.qa_generator import generate_validation_qa from ragmint.explainer import explain_results from ragmint.leaderboard import Leaderboard from ragmint.tuner import RAGMint except Exception as e: AutoRAGTuner = None generate_validation_qa = None explain_results = None Leaderboard = None RAGMint = None _import_error = e else: _import_error = None from dotenv import load_dotenv load_dotenv() # Logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger("ragmint_mcp_server") # FastAPI app (exported for mounting) app = FastAPI(title="Ragmint MCP Server", version="0.1.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Use repo-local data folder (not parent dirs) DEFAULT_DATA_DIR = "data/docs" LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl" # ensure folders exist os.makedirs(DEFAULT_DATA_DIR, exist_ok=True) os.makedirs("experiments", exist_ok=True) @app.get("/health") def health(): return { "status": "ok", "ragmint_imported": _import_error is None, "import_error": str(_import_error) if _import_error else None, } @app.post("/upload_docs") async def upload_docs( docs_path: str = Form(...), files: list[UploadFile] = File(...) ): os.makedirs(docs_path, exist_ok=True) saved_files = [] for file in files: file_path = os.path.join(docs_path, file.filename) with open(file_path, "wb") as f: shutil.copyfileobj(file.file, f) saved_files.append(file.filename) return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path} @app.post("/optimize_rag") def optimize_rag(req: OptimizeRequest): logger.info("Received optimize_rag request: %s", req.json()) if RAGMint is None: raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}") docs_path = req.docs_path or DEFAULT_DATA_DIR if not os.path.isdir(docs_path): raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}") try: rag = RAGMint( docs_path=docs_path, retrievers=req.retriever, embeddings=req.embedding_model, rerankers=(req.rerankers or ["mmr"]), chunk_sizes=req.chunk_sizes, overlaps=req.overlaps, strategies=req.strategy, ) # validation set handling validation_set = None validation_choice = (req.validation_choice or "").strip() default_val_path = os.path.join(docs_path, "validation_qa.json") if not validation_choice: if os.path.exists(default_val_path): validation_set = default_val_path logger.info("Using default validation set: %s", validation_set) else: logger.warning("No validation_choice provided and no default found.") validation_set = None elif "/" in validation_choice and not os.path.exists(validation_choice): validation_set = validation_choice logger.info("Using HF dataset as validation: %s", validation_set) elif os.path.exists(validation_choice): validation_set = validation_choice logger.info("Using local validation dataset: %s", validation_set) elif validation_choice.lower() == "generate": gen_path = os.path.join(docs_path, "validation_qa.json") generate_validation_qa( docs_path=docs_path, output_path=gen_path, llm_model=req.llm_model if hasattr(req, "llm_model") else "gemini-2.5-flash-lite" ) validation_set = gen_path logger.info("Generated validation QA at: %s", validation_set) start_time = time.time() best, results = rag.optimize( validation_set=validation_set, metric=req.metric, trials=req.trials, search_type=req.search_type ) elapsed = time.time() - start_time run_id = f"opt_{int(time.time())}" try: corpus_stats = { "num_docs": len(rag.documents), "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)), "corpus_size": sum(len(d) for d in rag.documents), } except Exception: corpus_stats = None try: if Leaderboard: lb = Leaderboard() lb.upload( run_id=run_id, best_config=best, best_score=best.get("faithfulness", best.get("score", 0.0)), all_results=results, documents=os.listdir(docs_path), model=best.get("embedding_model", req.embedding_model), corpus_stats=corpus_stats, ) except Exception: logger.exception("Leaderboard persistence failed for optimize_rag") return { "status": "finished", "run_id": run_id, "elapsed_seconds": elapsed, "best_config": best, "results": results, "corpus_stats": corpus_stats, } except Exception as exc: logger.exception("optimize_rag failed") raise HTTPException(status_code=500, detail=str(exc)) @app.post("/autotune_rag") def autotune_rag(req: AutotuneRequest): logger.info("Received autotune_rag request: %s", req.json()) if AutoRAGTuner is None or RAGMint is None: raise HTTPException(status_code=500, detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}") docs_path = req.docs_path or DEFAULT_DATA_DIR if not os.path.isdir(docs_path): raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}") try: start_time = time.time() tuner = AutoRAGTuner(docs_path=docs_path) rec = tuner.recommend(embedding_model=req.embedding_model, num_chunk_pairs=req.num_chunk_pairs) chunk_candidates = tuner.suggest_chunk_sizes( model_name=rec.get("embedding_model"), num_pairs=int(req.num_chunk_pairs), step=20 ) chunk_sizes = sorted({c for c, _ in chunk_candidates}) overlaps = sorted({o for _, o in chunk_candidates}) rag = RAGMint( docs_path=docs_path, retrievers=[rec["retriever"]], embeddings=[rec["embedding_model"]], rerankers=["mmr"], chunk_sizes=chunk_sizes, overlaps=overlaps, strategies=[rec["strategy"]], ) validation_set = None validation_choice = (req.validation_choice or "").strip() default_val_path = os.path.join(docs_path, "validation_qa.jsonl") if not validation_choice: if os.path.exists(default_val_path): validation_set = default_val_path else: validation_set = None elif "/" in validation_choice and not os.path.exists(validation_choice): validation_set = validation_choice elif os.path.exists(validation_choice): validation_set = validation_choice elif validation_choice.lower() == "generate": gen_path = os.path.join(docs_path, "validation_qa.json") generate_validation_qa( docs_path=docs_path, output_path=gen_path, llm_model=req.llm_model if hasattr(req, "llm_model") else "gemini-2.5-flash-lite", ) validation_set = gen_path best, results = rag.optimize( validation_set=validation_set, metric=req.metric, search_type=req.search_type, trials=req.trials, ) elapsed = time.time() - start_time run_id = f"autotune_{int(time.time())}" try: corpus_stats = { "num_docs": len(rag.documents), "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)), "corpus_size": sum(len(d) for d in rag.documents), } except Exception: corpus_stats = None try: if Leaderboard: lb = Leaderboard() lb.upload( run_id=run_id, best_config=best, best_score=best.get("faithfulness", best.get("score", 0.0)), all_results=results, documents=os.listdir(docs_path), model=best.get("embedding_model", rec.get("embedding_model")), corpus_stats=corpus_stats, ) except Exception: logger.exception("Leaderboard persistence failed for autotune_rag") return { "status": "finished", "run_id": run_id, "elapsed_seconds": elapsed, "recommendation": rec, "chunk_candidates": chunk_candidates, "best_config": best, "results": results, "corpus_stats": corpus_stats, } except Exception as exc: logger.exception("autotune_rag failed") raise HTTPException(status_code=500, detail=str(exc)) @app.post("/generate_validation_qa") def generate_qa(req: QARequest): logger.info("Received generate_validation_qa request: %s", req.json()) if generate_validation_qa is None: raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}") try: out_path = os.path.join("data", "docs", "validation_qa.json") os.makedirs(os.path.dirname(out_path), exist_ok=True) generate_validation_qa( docs_path=req.docs_path, output_path=out_path, llm_model=req.llm_model, batch_size=req.batch_size, min_q=req.min_q, max_q=req.max_q, ) with open(out_path, "r", encoding="utf-8") as f: data = json.load(f) return {"status": "finished", "output_path": out_path, "preview_count": len(data), "sample": data[:5]} except Exception as exc: logger.exception("generate_validation_qa failed") raise HTTPException(status_code=500, detail=str(exc)) def start_api(): import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")