"""api/routes/experiments.py — Experiment tracking dashboard (Feature 1).""" from fastapi import APIRouter, Query from typing import Optional import json import math from pydantic import BaseModel from infra.database import get_db, ExperimentRun, ModelRegistryEntry, TeamNote from infra.result_contract import sanitize_for_json from services.studio_service import experiment_diff router = APIRouter(prefix="/api", tags=["experiments"]) @router.get("/experiments") def list_experiments(limit: int = 50, task_type: Optional[str] = None): """Return all experiment runs, most recent first.""" with get_db() as db: q = db.query(ExperimentRun).order_by(ExperimentRun.created_at.desc()) if task_type: q = q.filter(ExperimentRun.task_type == task_type) runs = q.limit(limit).all() result = [] for r in runs: try: hyperparams = json.loads(r.hyperparams_json) if r.hyperparams_json else {} except Exception: hyperparams = {} try: metrics = json.loads(r.metrics_json) if r.metrics_json else {} except Exception: metrics = {} try: leaderboard = json.loads(r.leaderboard_json) if r.leaderboard_json else [] except Exception: leaderboard = [] result.append({ "id": r.id, "job_id": r.job_id, "dataset_id": r.dataset_id, "model_name": r.model_name, "metric_name": r.metric_name, "score": _safe_float(r.score), "task_type": r.task_type, "mode": r.mode, "goal": r.goal, "feature_count": r.feature_count, "row_count": r.row_count, "created_at": r.created_at.isoformat() if r.created_at else None, "hyperparams": sanitize_for_json(hyperparams), "metrics": sanitize_for_json(metrics), "leaderboard": sanitize_for_json(leaderboard), }) return result @router.get("/experiments/compare") def compare_experiments(ids: str = Query(..., description="Comma-separated experiment run IDs")): """Side-by-side comparison of multiple experiment runs.""" id_list = [i.strip() for i in ids.split(",") if i.strip()] with get_db() as db: runs = db.query(ExperimentRun).filter(ExperimentRun.id.in_(id_list)).all() comparison = [] for r in runs: try: hyperparams = json.loads(r.hyperparams_json) if r.hyperparams_json else {} except Exception: hyperparams = {} try: metrics = json.loads(r.metrics_json) if r.metrics_json else {} except Exception: metrics = {} comparison.append({ "id": r.id, "job_id": r.job_id, "model_name": r.model_name, "metric_name": r.metric_name, "score": _safe_float(r.score), "task_type": r.task_type, "mode": r.mode, "goal": r.goal, "feature_count": r.feature_count, "row_count": r.row_count, "created_at": r.created_at.isoformat() if r.created_at else None, "hyperparams": sanitize_for_json(hyperparams), "metrics": sanitize_for_json(metrics), }) return {"comparison": comparison, "count": len(comparison)} @router.get("/experiments/diff") def diff_experiments(run_a: str = Query(...), run_b: str = Query(...)): return experiment_diff(run_a, run_b) def _safe_float(value): try: numeric = float(value) return numeric if math.isfinite(numeric) else None except (TypeError, ValueError): return None @router.get("/experiments/{run_id}") def get_experiment(run_id: str): """Get a single experiment run by ID.""" with get_db() as db: r = db.query(ExperimentRun).filter(ExperimentRun.id == run_id).first() if not r: return {"error": "Experiment not found"} try: hyperparams = json.loads(r.hyperparams_json) if r.hyperparams_json else {} except Exception: hyperparams = {} try: metrics = json.loads(r.metrics_json) if r.metrics_json else {} except Exception: metrics = {} try: leaderboard = json.loads(r.leaderboard_json) if r.leaderboard_json else [] except Exception: leaderboard = [] return { "id": r.id, "job_id": r.job_id, "model_name": r.model_name, "metric_name": r.metric_name, "score": _safe_float(r.score), "task_type": r.task_type, "mode": r.mode, "goal": r.goal, "feature_count": r.feature_count, "row_count": r.row_count, "created_at": r.created_at.isoformat() if r.created_at else None, "hyperparams": sanitize_for_json(hyperparams), "metrics": sanitize_for_json(metrics), "leaderboard": sanitize_for_json(leaderboard), } @router.get("/experiments/{run_id}/registry") def get_registry(run_id: str): with get_db() as db: row = db.query(ModelRegistryEntry).filter(ModelRegistryEntry.run_id == run_id).first() if not row: return {"run_id": run_id, "label": None, "note": None} return {"run_id": run_id, "label": row.label, "note": row.note} class RegistryRequest(BaseModel): label: str note: Optional[str] = None @router.post("/experiments/{run_id}/registry") def save_registry(run_id: str, req: RegistryRequest): with get_db() as db: row = db.query(ModelRegistryEntry).filter(ModelRegistryEntry.run_id == run_id).first() if row: row.label = req.label row.note = req.note else: db.add(ModelRegistryEntry(run_id=run_id, label=req.label, note=req.note)) return {"run_id": run_id, "label": req.label, "note": req.note} @router.get("/notes/{entity_type}/{entity_id}") def get_notes(entity_type: str, entity_id: str): with get_db() as db: rows = ( db.query(TeamNote) .filter(TeamNote.entity_type == entity_type, TeamNote.entity_id == entity_id) .order_by(TeamNote.created_at.desc()) .all() ) return { "notes": [ { "id": row.id, "note": row.note, "created_at": row.created_at.isoformat() if row.created_at else None, } for row in rows ] } class TeamNoteRequest(BaseModel): note: str @router.post("/notes/{entity_type}/{entity_id}") def add_note(entity_type: str, entity_id: str, req: TeamNoteRequest): text = (req.note or "").strip() if not text: return {"error": "Note cannot be empty."} with get_db() as db: db.add(TeamNote(entity_type=entity_type, entity_id=entity_id, note=text)) return {"ok": True}