File size: 10,720 Bytes
434392c
59e6760
 
 
 
 
434392c
282d875
59e6760
9d761b8
59e6760
282d875
59e6760
282d875
59e6760
282d875
59e6760
 
 
 
 
 
282d875
59e6760
 
 
 
 
 
 
 
 
282d875
434392c
59e6760
434392c
 
59e6760
282d875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59e6760
 
 
 
 
 
 
 
282d875
9d761b8
 
282d875
 
9d761b8
434392c
9d761b8
 
 
 
 
 
 
 
188a5d8
282d875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59e6760
 
 
 
434392c
59e6760
 
 
 
 
 
 
 
 
 
282d875
59e6760
 
 
 
 
282d875
 
59e6760
 
 
 
 
 
 
 
 
 
282d875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59e6760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434392c
59e6760
 
 
 
 
 
 
 
434392c
59e6760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282d875
 
59e6760
 
 
 
 
 
 
 
 
282d875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59e6760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282d875
59e6760
 
 
 
 
282d875
59e6760
 
 
 
 
 
 
 
 
 
 
 
 
 
282d875
 
 
 
 
 
59e6760
 
 
 
 
 
30720a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7d462d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# api.py
from __future__ import annotations
import os
import json
import logging
import time
import shutil
from typing import List, Optional

from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv

from models import OptimizeRequest, QARequest, AutotuneRequest

# Load environment variables from a .env file via python-dotenv. This runs
# before the ragmint imports below so they can see the variables
# (presumably e.g. LLM API keys — TODO confirm).
load_dotenv()

# Logging: configure the root logger at INFO; this module logs through the
# named logger "ragmint_mcp_server".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ragmint_mcp_server")

# FastAPI app
app = FastAPI(title="Ragmint MCP Server", version="0.1.0")
# NOTE(review): CORS is fully open (any origin, method and header) and also
# allows credentials. Any web page can therefore call these endpoints from a
# browser and read the responses, including the ones that write or delete
# files at caller-supplied paths. Consider restricting allow_origins outside
# local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directories
# Corpus directory used when a request omits docs_path.
DEFAULT_DATA_DIR = "data/docs"
# Not referenced in the visible part of this file; presumably the location
# used for leaderboard records (TODO confirm against ragmint's Leaderboard).
LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl"
# Created at import time so the endpoints can assume the defaults exist.
os.makedirs(DEFAULT_DATA_DIR, exist_ok=True)
# "experiments" is the parent directory of LEADERBOARD_STORAGE.
os.makedirs("experiments", exist_ok=True)

# Try importing ragmint modules.
# All five imports share one try block. If ANY of them fails, every name is
# set to None and the exception is kept in _import_error. The endpoints check
# for None and report _import_error as an HTTP 500, so the server still starts;
# /health exposes whether the import succeeded.
# NOTE(review): catching Exception (not just ImportError) also hides errors
# raised while the ragmint modules execute at import time — confirm intended.
try:
    from ragmint.autotuner import AutoRAGTuner
    from ragmint.qa_generator import generate_validation_qa
    from ragmint.explainer import explain_results
    from ragmint.leaderboard import Leaderboard
    from ragmint.tuner import RAGMint
except Exception as e:
    AutoRAGTuner = None
    generate_validation_qa = None
    explain_results = None
    Leaderboard = None
    RAGMint = None
    _import_error = e  # copy needed: `e` is unbound once the except block ends
else:
    _import_error = None


@app.get("/health")
def health():
    """Report liveness and whether the ragmint package imported cleanly.

    Returns a dict whose ``import_error`` is the stringified import exception,
    or None when every ragmint import succeeded.
    """
    imported_ok = _import_error is None
    return {
        "status": "ok",
        "ragmint_imported": imported_ok,
        "import_error": None if imported_ok else str(_import_error),
    }


@app.post("/upload_docs")
async def upload_docs(
        docs_path: str = Form(...),
        files: List[UploadFile] = File(...)
):
    """Save the uploaded files into ``docs_path``, creating it if needed.

    Only the final path component of each client-supplied filename is used, so
    names such as ``../../etc/x`` or ``/abs/path`` cannot write outside
    ``docs_path``. Uploads whose filename is missing, empty, or reduces to
    ``.``/``..`` are skipped with a warning.

    Returns:
        ``{"status": "ok", "uploaded_files": [...], "docs_path": docs_path}``,
        where ``uploaded_files`` lists the names actually written.
    """
    # NOTE(review): docs_path itself is caller-controlled and not confined to
    # DEFAULT_DATA_DIR; any path the server process can write to is accepted.
    os.makedirs(docs_path, exist_ok=True)
    saved_files = []
    for upload in files:
        # Normalise Windows separators first so basename strips those too.
        raw_name = (upload.filename or "").replace("\\", "/")
        safe_name = os.path.basename(raw_name)
        if safe_name in ("", ".", ".."):
            logger.warning("Skipping upload with unusable filename: %r", upload.filename)
            continue
        file_path = os.path.join(docs_path, safe_name)
        with open(file_path, "wb") as f:
            shutil.copyfileobj(upload.file, f)
        saved_files.append(safe_name)
    return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path}


def handle_validation_choice(docs_path: str, validation_choice: Optional[str], llm_model: str) -> Optional[str]:
    """Resolve which validation QA dataset an optimization run should use.

    ``validation_choice`` is stripped of surrounding whitespace and then
    resolved as follows:

    * empty or None: return ``<docs_path>/validation_qa.json`` if that file
      exists, otherwise None.
    * ``"generate"`` (any case): call ``generate_validation_qa`` to write a new
      set to ``<docs_path>/validation_qa.json`` using ``llm_model``, and return
      that path.
    * an existing path, or any string containing ``"/"``: return it unchanged.
      The ``"/"`` rule skips the existence check, presumably to allow
      identifiers such as ``owner/name`` (TODO confirm).
    * anything else: log a warning and return None.
    """
    choice = (validation_choice or "").strip()
    qa_path = os.path.join(docs_path, "validation_qa.json")

    if choice == "":
        if not os.path.exists(qa_path):
            return None
        logger.info("Using default validation QA: %s", qa_path)
        return qa_path

    if choice.lower() == "generate":
        generate_validation_qa(docs_path=docs_path, output_path=qa_path, llm_model=llm_model)
        logger.info("Generated validation QA at: %s", qa_path)
        return qa_path

    looks_like_dataset = os.path.exists(choice) or "/" in choice
    if not looks_like_dataset:
        logger.warning("Validation choice provided but not found: %s", choice)
        return None

    logger.info("Using specified validation dataset: %s", choice)
    return choice


@app.post("/optimize_rag")
def optimize_rag(req: OptimizeRequest):
    """Search the caller-specified RAG configuration grid with RAGMint.

    Builds a ``RAGMint`` over ``docs_path`` (default ``DEFAULT_DATA_DIR``)
    from the requested retrievers, embeddings, rerankers (default ``["mmr"]``),
    chunk sizes, overlaps and strategies. It then resolves the validation set
    via ``handle_validation_choice`` and runs ``optimize``. When the
    Leaderboard class is available, the run is also uploaded there.

    The leaderboard upload is best-effort: a failure is logged and the
    optimization results are still returned rather than discarded.

    Raises:
        HTTPException(500): ragmint failed to import, or building/optimizing failed.
        HTTPException(400): ``docs_path`` is not an existing directory.
    """
    logger.info("Received optimize_rag request: %s", req.json())
    if RAGMint is None:
        raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")

    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")

    try:
        rag = RAGMint(
            docs_path=docs_path,
            retrievers=req.retriever,
            embeddings=req.embedding_model,
            rerankers=req.rerankers or ["mmr"],
            chunk_sizes=req.chunk_sizes,
            overlaps=req.overlaps,
            strategies=req.strategy,
        )

        # Use the default model when the request has no llm_model attribute
        # or leaves it None/empty.
        llm_model = getattr(req, "llm_model", None) or "gemini-2.5-flash-lite"
        validation_set = handle_validation_choice(docs_path, req.validation_choice, llm_model)

        # elapsed_seconds covers only the optimize() call.
        start_time = time.time()
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            trials=req.trials,
            search_type=req.search_type
        )
        elapsed = time.time() - start_time
        run_id = f"opt_{int(time.time())}"

        corpus_stats = {
            "num_docs": len(rag.documents),
            # Mean whitespace-token count per document (guarded for an empty corpus).
            "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
            "corpus_size": sum(len(d) for d in rag.documents),
        }
    except Exception as exc:
        logger.exception("optimize_rag failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    if Leaderboard:
        try:
            lb = Leaderboard()
            lb.upload(
                run_id=run_id,
                best_config=best,
                best_score=best.get("faithfulness", best.get("score", 0.0)),
                all_results=results,
                documents=os.listdir(docs_path),
                model=best.get("embedding_model", req.embedding_model),
                corpus_stats=corpus_stats,
            )
        except Exception:
            # Don't throw away a finished optimization because recording it failed.
            logger.exception("Leaderboard upload failed for run %s; returning results anyway", run_id)

    return {
        "status": "finished",
        "run_id": run_id,
        "elapsed_seconds": elapsed,
        "best_config": best,
        "results": results,
        "corpus_stats": corpus_stats,
    }


@app.post("/autotune_rag")
def autotune_rag(req: AutotuneRequest):
    """Let AutoRAGTuner pick a configuration, then refine chunking with RAGMint.

    ``AutoRAGTuner.recommend`` supplies a retriever, embedding model and
    strategy for the corpus in ``docs_path`` (default ``DEFAULT_DATA_DIR``).
    ``suggest_chunk_sizes`` proposes (chunk_size, overlap) pairs. Their
    distinct sizes and distinct overlaps become the grid passed to
    ``RAGMint.optimize``. When the Leaderboard class is available, the run is
    also uploaded there.

    The leaderboard upload is best-effort: a failure is logged and the results
    are still returned.

    Raises:
        HTTPException(500): ragmint failed to import, or tuning/optimizing failed.
        HTTPException(400): ``docs_path`` is not an existing directory.
    """
    logger.info("Received autotune_rag request: %s", req.json())
    if AutoRAGTuner is None or RAGMint is None:
        raise HTTPException(status_code=500, detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}")

    docs_path = req.docs_path or DEFAULT_DATA_DIR
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")

    try:
        # elapsed_seconds covers the whole pipeline: recommendation + optimization.
        start_time = time.time()
        tuner = AutoRAGTuner(docs_path=docs_path)
        rec = tuner.recommend(embedding_model=req.embedding_model, num_chunk_pairs=req.num_chunk_pairs)

        chunk_candidates = tuner.suggest_chunk_sizes(
            model_name=rec.get("embedding_model"),
            num_pairs=int(req.num_chunk_pairs),
            step=20
        )
        # RAGMint takes independent lists, so the (size, overlap) pairing from
        # suggest_chunk_sizes is not preserved here.
        # NOTE(review): if RAGMint searches the cross product, combinations
        # that were never suggested may be tried — confirm this is intended.
        chunk_sizes = sorted({c for c, _ in chunk_candidates})
        overlaps = sorted({o for _, o in chunk_candidates})

        rag = RAGMint(
            docs_path=docs_path,
            retrievers=[rec["retriever"]],
            embeddings=[rec["embedding_model"]],
            rerankers=["mmr"],
            chunk_sizes=chunk_sizes,
            overlaps=overlaps,
            strategies=[rec["strategy"]],
        )

        # Use the default model when the request has no llm_model attribute
        # or leaves it None/empty.
        llm_model = getattr(req, "llm_model", None) or "gemini-2.5-flash-lite"
        validation_set = handle_validation_choice(docs_path, req.validation_choice, llm_model)
        best, results = rag.optimize(
            validation_set=validation_set,
            metric=req.metric,
            search_type=req.search_type,
            trials=req.trials,
        )
        elapsed = time.time() - start_time
        run_id = f"autotune_{int(time.time())}"

        corpus_stats = {
            "num_docs": len(rag.documents),
            # Mean whitespace-token count per document (guarded for an empty corpus).
            "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
            "corpus_size": sum(len(d) for d in rag.documents),
        }
    except Exception as exc:
        logger.exception("autotune_rag failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    if Leaderboard:
        try:
            lb = Leaderboard()
            lb.upload(
                run_id=run_id,
                best_config=best,
                best_score=best.get("faithfulness", best.get("score", 0.0)),
                all_results=results,
                documents=os.listdir(docs_path),
                model=best.get("embedding_model", rec.get("embedding_model")),
                corpus_stats=corpus_stats,
            )
        except Exception:
            # Don't throw away a finished tuning run because recording it failed.
            logger.exception("Leaderboard upload failed for run %s; returning results anyway", run_id)

    return {
        "status": "finished",
        "run_id": run_id,
        "elapsed_seconds": elapsed,
        "recommendation": rec,
        "chunk_candidates": chunk_candidates,
        "best_config": best,
        "results": results,
        "corpus_stats": corpus_stats,
    }


@app.post("/generate_validation_qa")
def generate_validation_qa_endpoint(req: QARequest):
    """Generate a validation QA set for a corpus and return a short preview.

    The set is written to ``<docs_path>/validation_qa.json``, where
    ``docs_path`` defaults to ``DEFAULT_DATA_DIR``. That same directory is
    passed to the generator as the document source, so input and output
    always agree.

    Returns:
        ``status``, ``output_path``, ``preview_count`` (number of entries) and
        ``sample`` (the first five entries).

    Raises:
        HTTPException(500): ragmint failed to import, or generation/reading failed.
    """
    logger.info("Received generate_validation_qa request: %s", req.json())
    if generate_validation_qa is None:
        raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")

    # Resolve the default once. Previously a missing docs_path sent the output
    # to DEFAULT_DATA_DIR while the generator still received None.
    docs_path = req.docs_path or DEFAULT_DATA_DIR

    try:
        out_path = os.path.join(docs_path, "validation_qa.json")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        generate_validation_qa(
            docs_path=docs_path,
            output_path=out_path,
            llm_model=req.llm_model,
            batch_size=req.batch_size,
            min_q=req.min_q,
            max_q=req.max_q,
        )

        with open(out_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Assumes the generator writes a JSON list (len/slicing below) — TODO confirm.
        return {
            "status": "finished",
            "output_path": out_path,
            "preview_count": len(data),
            "sample": data[:5]
        }

    except Exception as exc:
        logger.exception("generate_validation_qa failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/clear_cache")
async def clear_cache(docs_path: str = Form(DEFAULT_DATA_DIR)):
    """
    Delete all files inside docs_path but keep the directory.
    Useful to reset uploaded documents for RAG runs.

    The directory is walked bottom-up. Every file, at any depth, is removed and
    reported by its base name; every subdirectory is reported as ``"name/"``.
    A symlink that points to a directory is unlinked; its target is left
    untouched. Failures are logged and skipped so one bad entry does not abort
    the rest.

    Raises:
        HTTPException(400): ``docs_path`` is not an existing directory.
    """
    # NOTE(review): docs_path is caller-supplied and not confined to
    # DEFAULT_DATA_DIR, so this endpoint can empty any directory the server
    # process can write to. Consider restricting it.
    if not os.path.isdir(docs_path):
        raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")

    removed = []
    for root, dirs, files in os.walk(docs_path, topdown=False):
        for name in files:
            file_path = os.path.join(root, name)
            try:
                os.remove(file_path)
                removed.append(name)
            except OSError as e:
                logger.error("Failed to remove %s: %s", file_path, e)

        for name in dirs:
            dir_path = os.path.join(root, name)
            try:
                # os.walk lists symlinks to directories in `dirs` without
                # descending into them. rmtree refuses symlinks, so remove the
                # link itself rather than leaving it behind.
                if os.path.islink(dir_path):
                    os.unlink(dir_path)
                else:
                    shutil.rmtree(dir_path)
                removed.append(f"{name}/")
            except OSError as e:
                logger.error("Failed to remove %s: %s", dir_path, e)

    return {
        "status": "cleared",
        "docs_path": docs_path,
        "removed_items": removed,
        "total_removed": len(removed),
    }



def start_api(host: str = "0.0.0.0", port: int = 8000, log_level: str = "info"):
    """Serve the FastAPI ``app`` with uvicorn; blocks until the server stops.

    Args:
        host: Interface to bind. The default ``"0.0.0.0"`` listens on all
            interfaces; pass ``"127.0.0.1"`` for local-only access.
        port: TCP port to listen on.
        log_level: uvicorn log level name.

    The defaults match the previously hard-coded values, so existing
    ``start_api()`` calls behave as before.
    """
    import uvicorn  # imported lazily so importing this module doesn't require uvicorn
    uvicorn.run(app, host=host, port=port, log_level=log_level)