André Oliveira commited on
Commit
282d875
·
1 Parent(s): 188a5d8

refactored api

Browse files
Files changed (1) hide show
  1. api.py +102 -124
api.py CHANGED
@@ -5,35 +5,22 @@ import json
5
  import logging
6
  import time
7
  import shutil
 
8
 
9
- from models import OptimizeRequest, QARequest, AutotuneRequest
10
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
11
  from fastapi.middleware.cors import CORSMiddleware
 
12
 
13
- try:
14
- from ragmint.autotuner import AutoRAGTuner
15
- from ragmint.qa_generator import generate_validation_qa
16
- from ragmint.explainer import explain_results
17
- from ragmint.leaderboard import Leaderboard
18
- from ragmint.tuner import RAGMint
19
- except Exception as e:
20
- AutoRAGTuner = None
21
- generate_validation_qa = None
22
- explain_results = None
23
- Leaderboard = None
24
- RAGMint = None
25
- _import_error = e
26
- else:
27
- _import_error = None
28
 
29
- from dotenv import load_dotenv
30
  load_dotenv()
31
 
32
  # Logging
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger("ragmint_mcp_server")
35
 
36
- # FastAPI app (exported for mounting)
37
  app = FastAPI(title="Ragmint MCP Server", version="0.1.0")
38
  app.add_middleware(
39
  CORSMiddleware,
@@ -43,14 +30,30 @@ app.add_middleware(
43
  allow_headers=["*"],
44
  )
45
 
46
- # Use repo-local data folder (not parent dirs)
47
  DEFAULT_DATA_DIR = "data/docs"
48
  LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl"
49
-
50
- # ensure folders exist
51
  os.makedirs(DEFAULT_DATA_DIR, exist_ok=True)
52
  os.makedirs("experiments", exist_ok=True)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  @app.get("/health")
55
  def health():
56
  return {
@@ -59,10 +62,11 @@ def health():
59
  "import_error": str(_import_error) if _import_error else None,
60
  }
61
 
 
62
  @app.post("/upload_docs")
63
  async def upload_docs(
64
- docs_path: str = Form(...),
65
- files: list[UploadFile] = File(...)
66
  ):
67
  os.makedirs(docs_path, exist_ok=True)
68
  saved_files = []
@@ -74,6 +78,34 @@ async def upload_docs(
74
  return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path}
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @app.post("/optimize_rag")
78
  def optimize_rag(req: OptimizeRequest):
79
  logger.info("Received optimize_rag request: %s", req.json())
@@ -89,40 +121,14 @@ def optimize_rag(req: OptimizeRequest):
89
  docs_path=docs_path,
90
  retrievers=req.retriever,
91
  embeddings=req.embedding_model,
92
- rerankers=(req.rerankers or ["mmr"]),
93
  chunk_sizes=req.chunk_sizes,
94
  overlaps=req.overlaps,
95
  strategies=req.strategy,
96
  )
97
 
98
- # validation set handling
99
- validation_set = None
100
- validation_choice = (req.validation_choice or "").strip()
101
- default_val_path = os.path.join(docs_path, "validation_qa.json")
102
-
103
- if not validation_choice:
104
- if os.path.exists(default_val_path):
105
- validation_set = default_val_path
106
- logger.info("Using default validation set: %s", validation_set)
107
- else:
108
- logger.warning("No validation_choice provided and no default found.")
109
- validation_set = None
110
- elif "/" in validation_choice and not os.path.exists(validation_choice):
111
- validation_set = validation_choice
112
- logger.info("Using HF dataset as validation: %s", validation_set)
113
- elif os.path.exists(validation_choice):
114
- validation_set = validation_choice
115
- logger.info("Using local validation dataset: %s", validation_set)
116
- elif validation_choice.lower() == "generate":
117
- gen_path = os.path.join(docs_path, "validation_qa.json")
118
- generate_validation_qa(
119
- docs_path=docs_path,
120
- output_path=gen_path,
121
- llm_model=req.llm_model if hasattr(req, "llm_model") else "gemini-2.5-flash-lite"
122
- )
123
- validation_set = gen_path
124
- logger.info("Generated validation QA at: %s", validation_set)
125
-
126
  start_time = time.time()
127
  best, results = rag.optimize(
128
  validation_set=validation_set,
@@ -133,29 +139,23 @@ def optimize_rag(req: OptimizeRequest):
133
  elapsed = time.time() - start_time
134
  run_id = f"opt_{int(time.time())}"
135
 
136
- try:
137
- corpus_stats = {
138
- "num_docs": len(rag.documents),
139
- "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
140
- "corpus_size": sum(len(d) for d in rag.documents),
141
- }
142
- except Exception:
143
- corpus_stats = None
144
-
145
- try:
146
- if Leaderboard:
147
- lb = Leaderboard()
148
- lb.upload(
149
- run_id=run_id,
150
- best_config=best,
151
- best_score=best.get("faithfulness", best.get("score", 0.0)),
152
- all_results=results,
153
- documents=os.listdir(docs_path),
154
- model=best.get("embedding_model", req.embedding_model),
155
- corpus_stats=corpus_stats,
156
- )
157
- except Exception:
158
- logger.exception("Leaderboard persistence failed for optimize_rag")
159
 
160
  return {
161
  "status": "finished",
@@ -165,7 +165,6 @@ def optimize_rag(req: OptimizeRequest):
165
  "results": results,
166
  "corpus_stats": corpus_stats,
167
  }
168
-
169
  except Exception as exc:
170
  logger.exception("optimize_rag failed")
171
  raise HTTPException(status_code=500, detail=str(exc))
@@ -191,7 +190,6 @@ def autotune_rag(req: AutotuneRequest):
191
  num_pairs=int(req.num_chunk_pairs),
192
  step=20
193
  )
194
-
195
  chunk_sizes = sorted({c for c, _ in chunk_candidates})
196
  overlaps = sorted({o for _, o in chunk_candidates})
197
 
@@ -205,27 +203,8 @@ def autotune_rag(req: AutotuneRequest):
205
  strategies=[rec["strategy"]],
206
  )
207
 
208
- validation_set = None
209
- validation_choice = (req.validation_choice or "").strip()
210
- default_val_path = os.path.join(docs_path, "validation_qa.jsonl")
211
- if not validation_choice:
212
- if os.path.exists(default_val_path):
213
- validation_set = default_val_path
214
- else:
215
- validation_set = None
216
- elif "/" in validation_choice and not os.path.exists(validation_choice):
217
- validation_set = validation_choice
218
- elif os.path.exists(validation_choice):
219
- validation_set = validation_choice
220
- elif validation_choice.lower() == "generate":
221
- gen_path = os.path.join(docs_path, "validation_qa.json")
222
- generate_validation_qa(
223
- docs_path=docs_path,
224
- output_path=gen_path,
225
- llm_model=req.llm_model if hasattr(req, "llm_model") else "gemini-2.5-flash-lite",
226
- )
227
- validation_set = gen_path
228
-
229
  best, results = rag.optimize(
230
  validation_set=validation_set,
231
  metric=req.metric,
@@ -235,29 +214,23 @@ def autotune_rag(req: AutotuneRequest):
235
  elapsed = time.time() - start_time
236
  run_id = f"autotune_{int(time.time())}"
237
 
238
- try:
239
- corpus_stats = {
240
- "num_docs": len(rag.documents),
241
- "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
242
- "corpus_size": sum(len(d) for d in rag.documents),
243
- }
244
- except Exception:
245
- corpus_stats = None
246
-
247
- try:
248
- if Leaderboard:
249
- lb = Leaderboard()
250
- lb.upload(
251
- run_id=run_id,
252
- best_config=best,
253
- best_score=best.get("faithfulness", best.get("score", 0.0)),
254
- all_results=results,
255
- documents=os.listdir(docs_path),
256
- model=best.get("embedding_model", rec.get("embedding_model")),
257
- corpus_stats=corpus_stats,
258
- )
259
- except Exception:
260
- logger.exception("Leaderboard persistence failed for autotune_rag")
261
 
262
  return {
263
  "status": "finished",
@@ -276,13 +249,13 @@ def autotune_rag(req: AutotuneRequest):
276
 
277
 
278
  @app.post("/generate_validation_qa")
279
- def generate_qa(req: QARequest):
280
  logger.info("Received generate_validation_qa request: %s", req.json())
281
  if generate_validation_qa is None:
282
  raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
283
 
284
  try:
285
- out_path = os.path.join("data", "docs", "validation_qa.json")
286
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
287
 
288
  generate_validation_qa(
@@ -297,7 +270,12 @@ def generate_qa(req: QARequest):
297
  with open(out_path, "r", encoding="utf-8") as f:
298
  data = json.load(f)
299
 
300
- return {"status": "finished", "output_path": out_path, "preview_count": len(data), "sample": data[:5]}
 
 
 
 
 
301
 
302
  except Exception as exc:
303
  logger.exception("generate_validation_qa failed")
 
5
  import logging
6
  import time
7
  import shutil
8
+ from typing import List, Optional
9
 
 
10
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
11
  from fastapi.middleware.cors import CORSMiddleware
12
+ from dotenv import load_dotenv
13
 
14
+ from models import OptimizeRequest, QARequest, AutotuneRequest
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # Load environment
17
  load_dotenv()
18
 
19
  # Logging
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger("ragmint_mcp_server")
22
 
23
+ # FastAPI app
24
  app = FastAPI(title="Ragmint MCP Server", version="0.1.0")
25
  app.add_middleware(
26
  CORSMiddleware,
 
30
  allow_headers=["*"],
31
  )
32
 
33
+ # Directories
34
  DEFAULT_DATA_DIR = "data/docs"
35
  LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl"
 
 
36
  os.makedirs(DEFAULT_DATA_DIR, exist_ok=True)
37
  os.makedirs("experiments", exist_ok=True)
38
 
39
+ # Try importing ragmint modules
40
+ try:
41
+ from ragmint.autotuner import AutoRAGTuner
42
+ from ragmint.qa_generator import generate_validation_qa
43
+ from ragmint.explainer import explain_results
44
+ from ragmint.leaderboard import Leaderboard
45
+ from ragmint.tuner import RAGMint
46
+ except Exception as e:
47
+ AutoRAGTuner = None
48
+ generate_validation_qa = None
49
+ explain_results = None
50
+ Leaderboard = None
51
+ RAGMint = None
52
+ _import_error = e
53
+ else:
54
+ _import_error = None
55
+
56
+
57
  @app.get("/health")
58
  def health():
59
  return {
 
62
  "import_error": str(_import_error) if _import_error else None,
63
  }
64
 
65
+
66
  @app.post("/upload_docs")
67
  async def upload_docs(
68
+ docs_path: str = Form(...),
69
+ files: List[UploadFile] = File(...)
70
  ):
71
  os.makedirs(docs_path, exist_ok=True)
72
  saved_files = []
 
78
  return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path}
79
 
80
 
81
+ def handle_validation_choice(docs_path: str, validation_choice: Optional[str], llm_model: str) -> Optional[str]:
82
+ """Determine which validation QA set to use or generate one."""
83
+ validation_choice = (validation_choice or "").strip()
84
+ default_path = os.path.join(docs_path, "validation_qa.json")
85
+
86
+ if not validation_choice:
87
+ if os.path.exists(default_path):
88
+ logger.info("Using default validation QA: %s", default_path)
89
+ return default_path
90
+ return None
91
+
92
+ if validation_choice.lower() == "generate":
93
+ generate_validation_qa(
94
+ docs_path=docs_path,
95
+ output_path=default_path,
96
+ llm_model=llm_model
97
+ )
98
+ logger.info("Generated validation QA at: %s", default_path)
99
+ return default_path
100
+
101
+ if os.path.exists(validation_choice) or "/" in validation_choice:
102
+ logger.info("Using specified validation dataset: %s", validation_choice)
103
+ return validation_choice
104
+
105
+ logger.warning("Validation choice provided but not found: %s", validation_choice)
106
+ return None
107
+
108
+
109
  @app.post("/optimize_rag")
110
  def optimize_rag(req: OptimizeRequest):
111
  logger.info("Received optimize_rag request: %s", req.json())
 
121
  docs_path=docs_path,
122
  retrievers=req.retriever,
123
  embeddings=req.embedding_model,
124
+ rerankers=req.rerankers or ["mmr"],
125
  chunk_sizes=req.chunk_sizes,
126
  overlaps=req.overlaps,
127
  strategies=req.strategy,
128
  )
129
 
130
+ validation_set = handle_validation_choice(docs_path, req.validation_choice,
131
+ getattr(req, "llm_model", "gemini-2.5-flash-lite"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  start_time = time.time()
133
  best, results = rag.optimize(
134
  validation_set=validation_set,
 
139
  elapsed = time.time() - start_time
140
  run_id = f"opt_{int(time.time())}"
141
 
142
+ corpus_stats = {
143
+ "num_docs": len(rag.documents),
144
+ "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
145
+ "corpus_size": sum(len(d) for d in rag.documents),
146
+ }
147
+
148
+ if Leaderboard:
149
+ lb = Leaderboard()
150
+ lb.upload(
151
+ run_id=run_id,
152
+ best_config=best,
153
+ best_score=best.get("faithfulness", best.get("score", 0.0)),
154
+ all_results=results,
155
+ documents=os.listdir(docs_path),
156
+ model=best.get("embedding_model", req.embedding_model),
157
+ corpus_stats=corpus_stats,
158
+ )
 
 
 
 
 
 
159
 
160
  return {
161
  "status": "finished",
 
165
  "results": results,
166
  "corpus_stats": corpus_stats,
167
  }
 
168
  except Exception as exc:
169
  logger.exception("optimize_rag failed")
170
  raise HTTPException(status_code=500, detail=str(exc))
 
190
  num_pairs=int(req.num_chunk_pairs),
191
  step=20
192
  )
 
193
  chunk_sizes = sorted({c for c, _ in chunk_candidates})
194
  overlaps = sorted({o for _, o in chunk_candidates})
195
 
 
203
  strategies=[rec["strategy"]],
204
  )
205
 
206
+ validation_set = handle_validation_choice(docs_path, req.validation_choice,
207
+ getattr(req, "llm_model", "gemini-2.5-flash-lite"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  best, results = rag.optimize(
209
  validation_set=validation_set,
210
  metric=req.metric,
 
214
  elapsed = time.time() - start_time
215
  run_id = f"autotune_{int(time.time())}"
216
 
217
+ corpus_stats = {
218
+ "num_docs": len(rag.documents),
219
+ "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
220
+ "corpus_size": sum(len(d) for d in rag.documents),
221
+ }
222
+
223
+ if Leaderboard:
224
+ lb = Leaderboard()
225
+ lb.upload(
226
+ run_id=run_id,
227
+ best_config=best,
228
+ best_score=best.get("faithfulness", best.get("score", 0.0)),
229
+ all_results=results,
230
+ documents=os.listdir(docs_path),
231
+ model=best.get("embedding_model", rec.get("embedding_model")),
232
+ corpus_stats=corpus_stats,
233
+ )
 
 
 
 
 
 
234
 
235
  return {
236
  "status": "finished",
 
249
 
250
 
251
  @app.post("/generate_validation_qa")
252
+ def generate_validation_qa_endpoint(req: QARequest):
253
  logger.info("Received generate_validation_qa request: %s", req.json())
254
  if generate_validation_qa is None:
255
  raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}")
256
 
257
  try:
258
+ out_path = os.path.join(req.docs_path or DEFAULT_DATA_DIR, "validation_qa.json")
259
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
260
 
261
  generate_validation_qa(
 
270
  with open(out_path, "r", encoding="utf-8") as f:
271
  data = json.load(f)
272
 
273
+ return {
274
+ "status": "finished",
275
+ "output_path": out_path,
276
+ "preview_count": len(data),
277
+ "sample": data[:5]
278
+ }
279
 
280
  except Exception as exc:
281
  logger.exception("generate_validation_qa failed")