# api.py from __future__ import annotations import os import json import logging import time import shutil from typing import List, Optional from fastapi import FastAPI, HTTPException, UploadFile, File, Form from fastapi.middleware.cors import CORSMiddleware from dotenv import load_dotenv from models import OptimizeRequest, QARequest, AutotuneRequest # Load environment load_dotenv() # Logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger("ragmint_mcp_server") # FastAPI app app = FastAPI(title="Ragmint MCP Server", version="0.1.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Directories DEFAULT_DATA_DIR = "data/docs" LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl" os.makedirs(DEFAULT_DATA_DIR, exist_ok=True) os.makedirs("experiments", exist_ok=True) # Try importing ragmint modules try: from ragmint.autotuner import AutoRAGTuner from ragmint.qa_generator import generate_validation_qa from ragmint.explainer import explain_results from ragmint.leaderboard import Leaderboard from ragmint.tuner import RAGMint except Exception as e: AutoRAGTuner = None generate_validation_qa = None explain_results = None Leaderboard = None RAGMint = None _import_error = e else: _import_error = None @app.get("/health") def health(): return { "status": "ok", "ragmint_imported": _import_error is None, "import_error": str(_import_error) if _import_error else None, } @app.post("/upload_docs") async def upload_docs( docs_path: str = Form(...), files: List[UploadFile] = File(...) ): os.makedirs(docs_path, exist_ok=True) saved_files = [] for file in files: file_path = os.path.join(docs_path, file.filename) with open(file_path, "wb") as f: shutil.copyfileobj(file.file, f) saved_files.append(file.filename) return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path} def handle_validation_choice(docs_path: str, validation_choice: Optional[str], llm_model: str) -> Optional[str]: """Determine which validation QA set to use or generate one.""" validation_choice = (validation_choice or "").strip() default_path = os.path.join(docs_path, "validation_qa.json") if not validation_choice: if os.path.exists(default_path): logger.info("Using default validation QA: %s", default_path) return default_path return None if validation_choice.lower() == "generate": generate_validation_qa( docs_path=docs_path, output_path=default_path, llm_model=llm_model ) logger.info("Generated validation QA at: %s", default_path) return default_path if os.path.exists(validation_choice) or "/" in validation_choice: logger.info("Using specified validation dataset: %s", validation_choice) return validation_choice logger.warning("Validation choice provided but not found: %s", validation_choice) return None @app.post("/optimize_rag") def optimize_rag(req: OptimizeRequest): logger.info("Received optimize_rag request: %s", req.json()) if RAGMint is None: raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}") docs_path = req.docs_path or DEFAULT_DATA_DIR if not os.path.isdir(docs_path): raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}") try: rag = RAGMint( docs_path=docs_path, retrievers=req.retriever, embeddings=req.embedding_model, rerankers=req.rerankers or ["mmr"], chunk_sizes=req.chunk_sizes, overlaps=req.overlaps, strategies=req.strategy, ) validation_set = handle_validation_choice(docs_path, req.validation_choice, getattr(req, "llm_model", "gemini-2.5-flash-lite")) start_time = time.time() best, results = rag.optimize( validation_set=validation_set, metric=req.metric, trials=req.trials, search_type=req.search_type ) elapsed = time.time() - start_time run_id = f"opt_{int(time.time())}" corpus_stats = { "num_docs": len(rag.documents), "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)), "corpus_size": sum(len(d) for d in rag.documents), } if Leaderboard: lb = Leaderboard() lb.upload( run_id=run_id, best_config=best, best_score=best.get("faithfulness", best.get("score", 0.0)), all_results=results, documents=os.listdir(docs_path), model=best.get("embedding_model", req.embedding_model), corpus_stats=corpus_stats, ) return { "status": "finished", "run_id": run_id, "elapsed_seconds": elapsed, "best_config": best, "results": results, "corpus_stats": corpus_stats, } except Exception as exc: logger.exception("optimize_rag failed") raise HTTPException(status_code=500, detail=str(exc)) @app.post("/autotune_rag") def autotune_rag(req: AutotuneRequest): logger.info("Received autotune_rag request: %s", req.json()) if AutoRAGTuner is None or RAGMint is None: raise HTTPException(status_code=500, detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}") docs_path = req.docs_path or DEFAULT_DATA_DIR if not os.path.isdir(docs_path): raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}") try: start_time = time.time() tuner = AutoRAGTuner(docs_path=docs_path) rec = tuner.recommend(embedding_model=req.embedding_model, num_chunk_pairs=req.num_chunk_pairs) chunk_candidates = tuner.suggest_chunk_sizes( model_name=rec.get("embedding_model"), num_pairs=int(req.num_chunk_pairs), step=20 ) chunk_sizes = sorted({c for c, _ in chunk_candidates}) overlaps = sorted({o for _, o in chunk_candidates}) rag = RAGMint( docs_path=docs_path, retrievers=[rec["retriever"]], embeddings=[rec["embedding_model"]], rerankers=["mmr"], chunk_sizes=chunk_sizes, overlaps=overlaps, strategies=[rec["strategy"]], ) validation_set = handle_validation_choice(docs_path, req.validation_choice, getattr(req, "llm_model", "gemini-2.5-flash-lite")) best, results = rag.optimize( validation_set=validation_set, metric=req.metric, search_type=req.search_type, trials=req.trials, ) elapsed = time.time() - start_time run_id = f"autotune_{int(time.time())}" corpus_stats = { "num_docs": len(rag.documents), "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)), "corpus_size": sum(len(d) for d in rag.documents), } if Leaderboard: lb = Leaderboard() lb.upload( run_id=run_id, best_config=best, best_score=best.get("faithfulness", best.get("score", 0.0)), all_results=results, documents=os.listdir(docs_path), model=best.get("embedding_model", rec.get("embedding_model")), corpus_stats=corpus_stats, ) return { "status": "finished", "run_id": run_id, "elapsed_seconds": elapsed, "recommendation": rec, "chunk_candidates": chunk_candidates, "best_config": best, "results": results, "corpus_stats": corpus_stats, } except Exception as exc: logger.exception("autotune_rag failed") raise HTTPException(status_code=500, detail=str(exc)) @app.post("/generate_validation_qa") def generate_validation_qa_endpoint(req: QARequest): logger.info("Received generate_validation_qa request: %s", req.json()) if generate_validation_qa is None: raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}") try: out_path = os.path.join(req.docs_path or DEFAULT_DATA_DIR, "validation_qa.json") os.makedirs(os.path.dirname(out_path), exist_ok=True) generate_validation_qa( docs_path=req.docs_path, output_path=out_path, llm_model=req.llm_model, batch_size=req.batch_size, min_q=req.min_q, max_q=req.max_q, ) with open(out_path, "r", encoding="utf-8") as f: data = json.load(f) return { "status": "finished", "output_path": out_path, "preview_count": len(data), "sample": data[:5] } except Exception as exc: logger.exception("generate_validation_qa failed") raise HTTPException(status_code=500, detail=str(exc)) @app.post("/clear_cache") async def clear_cache(docs_path: str = Form(DEFAULT_DATA_DIR)): """ Delete all files inside docs_path but keep the directory. Useful to reset uploaded documents for RAG runs. """ if not os.path.exists(docs_path): raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}") removed = [] for root, dirs, files in os.walk(docs_path, topdown=False): for name in files: file_path = os.path.join(root, name) try: os.remove(file_path) removed.append(name) except Exception as e: logger.error(f"Failed to remove {file_path}: {e}") for name in dirs: dir_path = os.path.join(root, name) try: shutil.rmtree(dir_path) removed.append(f"{name}/") except Exception as e: logger.error(f"Failed to remove {dir_path}: {e}") return { "status": "cleared", "docs_path": docs_path, "removed_items": removed, "total_removed": len(removed), } def start_api(): import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")