# api.py
from __future__ import annotations
import os
import json
import logging
import time
import shutil
from typing import List, Optional
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from models import OptimizeRequest, QARequest, AutotuneRequest
# Load environment
load_dotenv()
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ragmint_mcp_server")
# FastAPI app
app = FastAPI(title="Ragmint MCP Server", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Directories
DEFAULT_DATA_DIR = "data/docs"
LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl"
os.makedirs(DEFAULT_DATA_DIR, exist_ok=True)
os.makedirs("experiments", exist_ok=True)
# Try importing ragmint modules
try:
    from ragmint.autotuner import AutoRAGTuner
    from ragmint.qa_generator import generate_validation_qa
    from ragmint.explainer import explain_results
    from ragmint.leaderboard import Leaderboard
    from ragmint.tuner import RAGMint
except Exception as e:
    # Degrade gracefully: keep the server importable and surface the error via /health
    AutoRAGTuner = None
    generate_validation_qa = None
    explain_results = None
    Leaderboard = None
    RAGMint = None
    _import_error = e
else:
    _import_error = None
@app.get("/health")
def health():
    return {
        "status": "ok",
        "ragmint_imported": _import_error is None,
        "import_error": str(_import_error) if _import_error else None,
    }
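# Upload one or more documents into docs_path so later optimization runs can index them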
@app.post("/upload_docs")
async def upload_docs(
docs_path: str = Form(...),
files: List[UploadFile] = File(...)
):
os.makedirs(docs_path, exist_ok=True)
saved_files = []
for file in files:
file_path = os.path.join(docs_path, file.filename)
with open(file_path, "wb") as f:
shutil.copyfileobj(file.file, f)
saved_files.append(file.filename)
return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path}
def handle_validation_choice(docs_path: str, validation_choice: Optional[str], llm_model: str) -> Optional[str]:
    """Determine which validation QA set to use or generate one."""
    validation_choice = (validation_choice or "").strip()
    default_path = os.path.join(docs_path, "validation_qa.json")
    if not validation_choice:
        if os.path.exists(default_path):
            logger.info("Using default validation QA: %s", default_path)
            return default_path
        return None
    if validation_choice.lower() == "generate":
        generate_validation_qa(
            docs_path=docs_path,
            output_path=default_path,
            llm_model=llm_model
        )
        logger.info("Generated validation QA at: %s", default_path)
        return default_path
    if os.path.exists(validation_choice) or "/" in validation_choice:
        logger.info("Using specified validation dataset: %s", validation_choice)
        return validation_choice
    logger.warning("Validation choice provided but not found: %s", validation_choice)
    return None
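# Full optimization: build a RAGMint search space from the request, run trials, and log the result to the leaderboard if available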
@app.post("/optimize_rag")
def optimize_rag(req: OptimizeRequest):
logger.info("Received optimize_rag request: %s", req.json())
if RAGMint is None:
raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
docs_path = req.docs_path or DEFAULT_DATA_DIR
if not os.path.isdir(docs_path):
raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
try:
rag = RAGMint(
docs_path=docs_path,
retrievers=req.retriever,
embeddings=req.embedding_model,
rerankers=req.rerankers or ["mmr"],
chunk_sizes=req.chunk_sizes,
overlaps=req.overlaps,
strategies=req.strategy,
)
validation_set = handle_validation_choice(docs_path, req.validation_choice,
getattr(req, "llm_model", "gemini-2.5-flash-lite"))
start_time = time.time()
best, results = rag.optimize(
validation_set=validation_set,
metric=req.metric,
trials=req.trials,
search_type=req.search_type
)
elapsed = time.time() - start_time
run_id = f"opt_{int(time.time())}"
corpus_stats = {
"num_docs": len(rag.documents),
"avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
"corpus_size": sum(len(d) for d in rag.documents),
}
if Leaderboard:
lb = Leaderboard()
lb.upload(
run_id=run_id,
best_config=best,
best_score=best.get("faithfulness", best.get("score", 0.0)),
all_results=results,
documents=os.listdir(docs_path),
model=best.get("embedding_model", req.embedding_model),
corpus_stats=corpus_stats,
)
return {
"status": "finished",
"run_id": run_id,
"elapsed_seconds": elapsed,
"best_config": best,
"results": results,
"corpus_stats": corpus_stats,
}
except Exception as exc:
logger.exception("optimize_rag failed")
raise HTTPException(status_code=500, detail=str(exc))
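# Auto-tuning: ask AutoRAGTuner for a recommended retriever/embedding/strategy plus chunking candidates, then optimize around that recommendation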
@app.post("/autotune_rag")
def autotune_rag(req: AutotuneRequest):
logger.info("Received autotune_rag request: %s", req.json())
if AutoRAGTuner is None or RAGMint is None:
raise HTTPException(status_code=500, detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}")
docs_path = req.docs_path or DEFAULT_DATA_DIR
if not os.path.isdir(docs_path):
raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
try:
start_time = time.time()
tuner = AutoRAGTuner(docs_path=docs_path)
rec = tuner.recommend(embedding_model=req.embedding_model, num_chunk_pairs=req.num_chunk_pairs)
chunk_candidates = tuner.suggest_chunk_sizes(
model_name=rec.get("embedding_model"),
num_pairs=int(req.num_chunk_pairs),
step=20
)
chunk_sizes = sorted({c for c, _ in chunk_candidates})
overlaps = sorted({o for _, o in chunk_candidates})
rag = RAGMint(
docs_path=docs_path,
retrievers=[rec["retriever"]],
embeddings=[rec["embedding_model"]],
rerankers=["mmr"],
chunk_sizes=chunk_sizes,
overlaps=overlaps,
strategies=[rec["strategy"]],
)
validation_set = handle_validation_choice(docs_path, req.validation_choice,
getattr(req, "llm_model", "gemini-2.5-flash-lite"))
best, results = rag.optimize(
validation_set=validation_set,
metric=req.metric,
search_type=req.search_type,
trials=req.trials,
)
elapsed = time.time() - start_time
run_id = f"autotune_{int(time.time())}"
corpus_stats = {
"num_docs": len(rag.documents),
"avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
"corpus_size": sum(len(d) for d in rag.documents),
}
if Leaderboard:
lb = Leaderboard()
lb.upload(
run_id=run_id,
best_config=best,
best_score=best.get("faithfulness", best.get("score", 0.0)),
all_results=results,
documents=os.listdir(docs_path),
model=best.get("embedding_model", rec.get("embedding_model")),
corpus_stats=corpus_stats,
)
return {
"status": "finished",
"run_id": run_id,
"elapsed_seconds": elapsed,
"recommendation": rec,
"chunk_candidates": chunk_candidates,
"best_config": best,
"results": results,
"corpus_stats": corpus_stats,
}
except Exception as exc:
logger.exception("autotune_rag failed")
raise HTTPException(status_code=500, detail=str(exc))
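# Generate a validation QA dataset (validation_qa.json under docs_path) from the documents using the configured LLM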
@app.post("/generate_validation_qa")
def generate_validation_qa_endpoint(req: QARequest):
logger.info("Received generate_validation_qa request: %s", req.json())
if generate_validation_qa is None:
raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
try:
out_path = os.path.join(req.docs_path or DEFAULT_DATA_DIR, "validation_qa.json")
os.makedirs(os.path.dirname(out_path), exist_ok=True)
generate_validation_qa(
docs_path=req.docs_path,
output_path=out_path,
llm_model=req.llm_model,
batch_size=req.batch_size,
min_q=req.min_q,
max_q=req.max_q,
)
with open(out_path, "r", encoding="utf-8") as f:
data = json.load(f)
return {
"status": "finished",
"output_path": out_path,
"preview_count": len(data),
"sample": data[:5]
}
except Exception as exc:
logger.exception("generate_validation_qa failed")
raise HTTPException(status_code=500, detail=str(exc))
@app.post("/clear_cache")
async def clear_cache(docs_path: str = Form(DEFAULT_DATA_DIR)):
"""
Delete all files inside docs_path but keep the directory.
Useful to reset uploaded documents for RAG runs.
"""
if not os.path.exists(docs_path):
raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}")
removed = []
for root, dirs, files in os.walk(docs_path, topdown=False):
for name in files:
file_path = os.path.join(root, name)
try:
os.remove(file_path)
removed.append(name)
except Exception as e:
logger.error(f"Failed to remove {file_path}: {e}")
for name in dirs:
dir_path = os.path.join(root, name)
try:
shutil.rmtree(dir_path)
removed.append(f"{name}/")
except Exception as e:
logger.error(f"Failed to remove {dir_path}: {e}")
return {
"status": "cleared",
"docs_path": docs_path,
"removed_items": removed,
"total_removed": len(removed),
}
def start_api():
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
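# Convenience entry point; assumes the server may also be launched directly with `python api.py`
# (start_api can still be imported and called from elsewhere).
if __name__ == "__main__":
    start_api()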