File size: 10,894 Bytes
6205b94
 
 
 
 
af5defe
 
6205b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af5defe
6205b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8476db7
 
 
6205b94
8476db7
 
 
6205b94
8476db7
 
 
 
 
 
6205b94
 
 
 
 
8476db7
 
 
 
 
 
6205b94
8476db7
 
6205b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8476db7
 
 
6205b94
 
 
8476db7
 
 
6205b94
8476db7
 
6205b94
 
 
 
 
 
8476db7
 
 
 
 
 
 
 
 
6205b94
 
 
 
8476db7
6205b94
 
 
 
8476db7
 
 
6205b94
8476db7
 
 
6205b94
 
 
 
 
 
 
 
 
8476db7
 
6205b94
 
 
 
8476db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6205b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
"""Submission queue management using HuggingFace Datasets.

Manages the lifecycle of benchmark submissions:
  pending -> approved -> dispatching -> boltz -> scoring -> complete / failed
  (`rejected` is also a valid status.)

Rate limiting: 1 submission per calendar month (UTC) per organization;
rejected and failed submissions do not count toward the limit.
LLM-judge API costs are paid by Romero Lab, so the limit is intentionally low.

HF Dataset: RomeroLab-Duke/biodesignbench-submissions (private)
Schema: Each row is a submission with per-task results stored as JSON.
"""

from __future__ import annotations

import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
#  Constants
# ---------------------------------------------------------------------------

# Hub repo that stores the submission queue; overridable via the environment.
_DEFAULT_SUBMISSIONS_DATASET = "RomeroLab-Duke/biodesignbench-submissions"
SUBMISSIONS_DATASET = os.environ.get(
    "BDB_SUBMISSIONS_DATASET", _DEFAULT_SUBMISSIONS_DATASET
)

# Hub access token; None when the HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Per-organization quota per calendar month.
MAX_SUBMISSIONS_PER_MONTH = 1

# Statuses a submission passes through on the normal processing path.
_PIPELINE_STATUSES = ("pending", "approved", "dispatching", "boltz", "scoring")
# Remaining statuses: "complete" / "failed" end the pipeline; "rejected" is
# also accepted.
_OUTCOME_STATUSES = ("complete", "failed", "rejected")

# Every status value update_status() will accept.
VALID_STATUSES = {*_PIPELINE_STATUSES, *_OUTCOME_STATUSES}


# ---------------------------------------------------------------------------
#  Data model
# ---------------------------------------------------------------------------


def _make_submission_row(
    agent_name: str,
    organization: str,
    provider: str,
    model_name: str,
    api_key: str,
    description: str = "",
    custom_mcp_url: str = "",
    custom_mcp_token: str = "",
    canary_token: str = "",
) -> dict[str, Any]:
    """Create a new submission row.

    The submitter's `api_key` is stored on the row only between
    submission and dispatch; `scrub_credentials()` removes it
    immediately after the agent loop completes (or fails).
    """
    now = datetime.now(timezone.utc).isoformat()
    return {
        "submission_id": str(uuid.uuid4())[:12],
        "agent_name": agent_name,
        "organization": organization,
        "provider": provider,
        "model_name": model_name,
        # Transient credentials -- scrubbed after dispatch
        "api_key": api_key,
        "custom_mcp_url": custom_mcp_url,
        "custom_mcp_token": custom_mcp_token,
        "description": description,
        "mcp_custom": bool(custom_mcp_url),
        "canary_token": canary_token,
        "status": "pending",
        "created_at": now,
        "updated_at": now,
        "tasks_dispatched": 0,
        "tasks_total": 76,
        "tasks_boltz_done": 0,
        "overall_score": None,
        "component_scores": None,
        "taxonomy_scores": None,
        "per_task_results": "{}",  # JSON string of task_id β†’ result
        "error_message": None,
    }


# ---------------------------------------------------------------------------
#  Queue operations (HF Datasets API)
# ---------------------------------------------------------------------------


def _get_dataset():
    """Download the ``train`` split of the submissions dataset from the Hub.

    Returns:
        The loaded dataset object, or ``None`` if importing ``datasets`` or
        loading the split raised anything (the error is logged as a warning).
    """
    # NOTE(review): every failure (missing repo, auth problem, network
    # outage) collapses to None here, so callers cannot tell "no data yet"
    # apart from "could not reach the Hub".
    loaded = None
    try:
        from datasets import load_dataset

        loaded = load_dataset(SUBMISSIONS_DATASET, split="train", token=HF_TOKEN)
    except Exception as exc:
        logger.warning(f"Could not load submissions dataset: {exc}")
    return loaded


def _save_rows(rows: list[dict[str, Any]]) -> bool:
    """Push ``rows`` to the Hub as the private submissions dataset.

    The dataset is rebuilt from ``rows`` alone, so callers must pass the
    complete, up-to-date list of rows rather than only the changed ones.

    Args:
        rows: Every submission row, as plain dicts.

    Returns:
        True if the push succeeded. False if ``rows`` is empty, or if
        importing ``datasets`` or pushing failed; the error is logged with
        its traceback.
    """
    if not rows:
        # Pushing an empty table would replace every stored submission
        # with nothing; no caller should ever need that, so refuse.
        logger.error("Refusing to save an empty submissions table")
        return False
    try:
        from datasets import Dataset

        Dataset.from_list(rows).push_to_hub(
            SUBMISSIONS_DATASET,
            token=HF_TOKEN,
            private=True,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception(f"Failed to save submissions: {e}")
        return False
    return True


def _load_all_rows() -> list[dict[str, Any]]:
    """Return every submission row as a plain ``dict``.

    An empty list is returned both when the dataset has no rows and when it
    could not be loaded at all (``_get_dataset()`` returned None).
    """
    # NOTE(review): because a load failure also yields [], a caller that
    # appends to this list and saves it (as submit() does) could overwrite
    # the stored queue after a transient Hub error. Consider failing closed.
    dataset = _get_dataset()
    return [] if dataset is None else list(map(dict, dataset))


SUPPORTED_PROVIDERS = {"anthropic", "openai", "deepseek", "google"}


def submit(
    agent_name: str,
    organization: str,
    provider: str,
    model_name: str,
    api_key: str,
    description: str = "",
    custom_mcp_url: str = "",
    custom_mcp_token: str = "",
) -> dict[str, Any]:
    """Validate a submission request and append it to the queue as ``pending``.

    Surrounding whitespace is stripped from the required fields and from
    ``custom_mcp_url``, and ``provider`` is lower-cased, before validation.
    So ``"   "`` counts as missing and ``"OpenAI"`` is accepted as ``"openai"``.

    Args:
        agent_name: Display name of the agent (required).
        organization: Submitting organization (required); used for the
            monthly rate limit.
        provider: One of ``SUPPORTED_PROVIDERS``, case-insensitive.
        model_name: Model to run (required).
        api_key: Submitter's provider key (required).
        description: Optional free-text description.
        custom_mcp_url: Optional custom MCP server URL; must be http(s).
        custom_mcp_token: Optional token for the custom MCP server.

    Returns:
        On success, a dict with ``submission_id``, ``status`` (``"pending"``),
        ``canary_token`` and ``message``. On any failure, ``{"error": reason}``.
    """
    # Normalise first: whitespace-only values must not satisfy the
    # "required" check, and padded organization names must not slip past the
    # per-organization rate limit.
    agent_name = (agent_name or "").strip()
    organization = (organization or "").strip()
    model_name = (model_name or "").strip()
    api_key = (api_key or "").strip()
    provider = (provider or "").strip().lower()
    custom_mcp_url = (custom_mcp_url or "").strip()

    if not agent_name or not organization or not model_name or not api_key:
        return {"error": "agent_name, organization, model_name, and api_key are required"}

    if provider not in SUPPORTED_PROVIDERS:
        return {"error": f"provider must be one of {sorted(SUPPORTED_PROVIDERS)}"}

    if custom_mcp_url and not custom_mcp_url.startswith(("http://", "https://")):
        return {"error": "custom_mcp_url must start with http:// or https://"}

    error = check_rate_limit(organization)
    if error:
        return {"error": error}

    canary = uuid.uuid4().hex[:16]

    row = _make_submission_row(
        agent_name=agent_name,
        organization=organization,
        provider=provider,
        model_name=model_name,
        api_key=api_key,
        description=description,
        custom_mcp_url=custom_mcp_url,
        custom_mcp_token=custom_mcp_token,
        canary_token=canary,
    )

    # NOTE(review): _load_all_rows() also returns [] when the Hub load fails,
    # and _save_rows() would then push only this one row, which would likely
    # overwrite the whole queue after a transient outage. Consider failing
    # closed here.
    rows = _load_all_rows()
    rows.append(row)

    if _save_rows(rows):
        return {
            "submission_id": row["submission_id"],
            "status": "pending",
            "canary_token": canary,
            "message": "Submission created. Awaiting admin approval.",
        }
    return {"error": "Failed to save submission. Please try again."}


def scrub_credentials(submission_id: str) -> bool:
    """Blank out the stored ``api_key`` and ``custom_mcp_token`` of one submission.

    Intended to run right after the dispatch phase, whether or not the agent
    loop succeeded: the key is forwarded to the agent loop once and is not
    needed afterwards. Also refreshes ``updated_at``.

    Returns:
        True if the row was found and the table was saved; False if no row
        has ``submission_id`` (logged) or the save failed.
    """
    rows = _load_all_rows()
    target = next(
        (row for row in rows if row.get("submission_id") == submission_id),
        None,
    )
    if target is None:
        logger.error(f"scrub_credentials: submission {submission_id} not found")
        return False

    target.update(
        api_key="",
        custom_mcp_token="",
        updated_at=datetime.now(timezone.utc).isoformat(),
    )
    return _save_rows(rows)


def check_rate_limit(organization: str) -> str | None:
    """Check whether ``organization`` has used up its monthly submission quota.

    Counts this organization's submissions whose ``created_at`` falls in the
    current UTC calendar month, ignoring ``rejected`` and ``failed`` ones.
    Organization names are compared case-insensitively and without
    surrounding whitespace.

    Returns:
        An error message string if the count has reached
        ``MAX_SUBMISSIONS_PER_MONTH``, otherwise ``None``.
    """
    rows = _load_all_rows()
    current_month = datetime.now(timezone.utc).strftime("%Y-%m")
    wanted = (organization or "").strip().casefold()

    monthly_count = 0
    for row in rows:
        # Coerce missing/None column values to "" so a malformed row cannot
        # raise AttributeError on the string methods below.
        if (row.get("organization") or "").strip().casefold() != wanted:
            continue
        if row.get("status") in ("rejected", "failed"):
            continue
        # created_at is written as a UTC ISO-8601 timestamp ("YYYY-MM-DD..."),
        # so a "YYYY-MM" prefix match selects the current calendar month.
        if (row.get("created_at") or "").startswith(current_month):
            monthly_count += 1

    if monthly_count >= MAX_SUBMISSIONS_PER_MONTH:
        noun = "submission" if MAX_SUBMISSIONS_PER_MONTH == 1 else "submissions"
        return (
            f"Organization '{organization}' has reached the limit of "
            f"{MAX_SUBMISSIONS_PER_MONTH} {noun} for {current_month}."
        )
    return None


def update_status(
    submission_id: str,
    status: str,
    **extra_fields: Any,
) -> bool:
    """Update a submission's status, ``updated_at``, and optional extra fields.

    Args:
        submission_id: The submission to update.
        status: New status (must be in VALID_STATUSES).
        **extra_fields: Additional existing columns to overwrite
            (e.g. ``tasks_dispatched=10``). Keys that are not already columns
            of the row are ignored and logged as a warning.

    Returns:
        True if the row was found and the table was saved successfully;
        False for an invalid status, an unknown submission, or a failed save.
    """
    if status not in VALID_STATUSES:
        logger.error(f"Invalid status: {status}")
        return False

    rows = _load_all_rows()
    row = next((r for r in rows if r.get("submission_id") == submission_id), None)
    if row is None:
        logger.error(f"Submission {submission_id} not found")
        return False

    row["status"] = status
    row["updated_at"] = datetime.now(timezone.utc).isoformat()

    ignored = []
    for key, value in extra_fields.items():
        if key in row:
            row[key] = value
        else:
            ignored.append(key)
    if ignored:
        # Unknown keys are still not added (preserving the existing column
        # set). Previously they were dropped silently, which hid typos in
        # callers' field names.
        logger.warning(
            f"update_status({submission_id}): ignoring unknown fields {ignored}"
        )

    return _save_rows(rows)


def save_task_result(
    submission_id: str,
    task_id: str,
    result: dict[str, Any],
) -> bool:
    """Record one task's result in a submission's ``per_task_results`` JSON.

    Also sets ``tasks_dispatched`` to the number of tasks that now have a
    stored result, and refreshes ``updated_at``.

    Args:
        submission_id: The submission to update.
        task_id: Task identifier (key in the per-task JSON object); an
            existing entry for the same task is overwritten.
        result: Score result dict from eval_scorer.score_submission_task().

    Returns:
        True if the row was found and the table was saved; False otherwise.
    """
    rows = _load_all_rows()
    for row in rows:
        if row.get("submission_id") != submission_id:
            continue
        # A missing, null or empty column means "no results yet"; without
        # the `or` guard json.loads(None) / json.loads("") would raise.
        per_task = json.loads(row.get("per_task_results") or "{}")
        per_task[task_id] = result
        row["per_task_results"] = json.dumps(per_task)
        # NOTE(review): this counts tasks with a saved result, which may not
        # equal the number actually dispatched; confirm the intended meaning.
        row["tasks_dispatched"] = len(per_task)
        row["updated_at"] = datetime.now(timezone.utc).isoformat()
        return _save_rows(rows)

    logger.error(f"Submission {submission_id} not found")
    return False


def get_submission(submission_id: str) -> dict[str, Any] | None:
    """Return the first row whose ``submission_id`` matches, or None if absent."""
    matches = (
        row for row in _load_all_rows() if row.get("submission_id") == submission_id
    )
    return next(matches, None)


def get_pending_submissions() -> list[dict[str, Any]]:
    """Return the submissions whose status is ``pending`` (awaiting admin approval)."""
    all_rows = _load_all_rows()
    return list(filter(lambda row: row.get("status") == "pending", all_rows))


def get_approved_submissions() -> list[dict[str, Any]]:
    """Return the submissions whose status is ``approved`` (ready for dispatch)."""
    all_rows = _load_all_rows()
    return list(filter(lambda row: row.get("status") == "approved", all_rows))


def get_all_submissions() -> list[dict[str, Any]]:
    """Return every submission row, unfiltered (used by the admin panel)."""
    every_row = _load_all_rows()
    return every_row


def finalize_submission(
    submission_id: str,
    overall_score: float,
    component_scores: dict[str, float],
    taxonomy_scores: dict[str, dict[str, float]],
) -> bool:
    """Mark a submission ``complete`` and record its aggregated scores.

    ``overall_score`` is stored as given; both score mappings are serialised
    to JSON strings before being written via ``update_status``.

    Args:
        submission_id: The submission to finalize.
        overall_score: Overall score (0-100).
        component_scores: Mapping of component -> averaged score.
        taxonomy_scores: Nested mapping of task_type -> context -> average score.

    Returns:
        The result of ``update_status``: True if the row was found and saved.
    """
    score_fields = {
        "overall_score": overall_score,
        "component_scores": json.dumps(component_scores),
        "taxonomy_scores": json.dumps(taxonomy_scores),
    }
    return update_status(submission_id, "complete", **score_fields)