Jasonkim8652 commited on
Commit
6205b94
·
verified ·
1 Parent(s): eecaec9

feat: add submission & scoring infrastructure (eval_scorer, dispatcher, boltz, queue, tasks) + fix gradio 5.x for Python 3.13

Browse files
Files changed (9) hide show
  1. app.py +407 -1
  2. eval_boltz.py +272 -0
  3. eval_dispatcher.py +361 -0
  4. eval_queue.py +312 -0
  5. eval_scorer.py +1643 -0
  6. eval_tasks.py +236 -0
  7. example_server.py +205 -0
  8. mcp_tool_schemas.json +468 -0
  9. requirements.txt +6 -1
app.py CHANGED
@@ -2,14 +2,26 @@
2
 
3
  Evaluating LLM Agents on Protein Design via MCP Tools
4
  Romero Lab, Duke University
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
  import json
 
8
  from pathlib import Path
9
 
10
  import gradio as gr
11
  import plotly.graph_objects as go
12
 
 
 
13
 
14
  # ═══════════════════════════════════════════════════════════════════
15
  # Configuration — change these when deploying
@@ -916,7 +928,401 @@ def create_app() -> gr.Blocks:
916
  gr.Plot(chart_mode_comparison(entries))
917
  gr.HTML(build_mode_cards(entries))
918
 
919
- # ════════ Tab 5: About ════════
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920
  with gr.Tab("\u2139\ufe0f About"):
921
  gr.HTML(build_about())
922
 
 
2
 
3
  Evaluating LLM Agents on Protein Design via MCP Tools
4
  Romero Lab, Duke University
5
+
6
+ Tabs:
7
+ 1. Overall Leaderboard
8
+ 2. Taxonomy Breakdown
9
+ 3. Component Analysis
10
+ 4. Benchmark vs User
11
+ 5. Submit (new submission form)
12
+ 6. Status & Admin (password-protected pipeline control)
13
+ 7. About
14
  """
15
 
16
  import json
17
+ import os
18
  from pathlib import Path
19
 
20
  import gradio as gr
21
  import plotly.graph_objects as go
22
 
23
+ ADMIN_PASSWORD = os.environ.get("BDB_ADMIN_PASSWORD", "biodesignbench2026")
24
+
25
 
26
  # ═══════════════════════════════════════════════════════════════════
27
  # Configuration — change these when deploying
 
928
  gr.Plot(chart_mode_comparison(entries))
929
  gr.HTML(build_mode_cards(entries))
930
 
931
+ # ══════ Tab 5: Submit ══════
932
+ with gr.Tab("\U0001f4e4 Submit"):
933
+ gr.HTML("""
934
+ <div style="max-width:700px;margin:0 auto;padding:1rem">
935
+ <h2 style="color:#1a365d;margin:0 0 0.5rem">
936
+ Submit Your Agent</h2>
937
+ <p style="color:#4a5568;margin-bottom:1rem;line-height:1.5">
938
+ Submit your protein design agent for benchmarking.
939
+ Your agent must be hosted as a POST endpoint that accepts
940
+ task descriptions and returns designed sequences.
941
+ <strong>You bear all LLM and MCP tool costs</strong>;
942
+ we only run Boltz structure prediction on our end.</p>
943
+ <div style="background:#fefcbf;border-left:4px solid #d69e2e;
944
+ padding:0.8rem;border-radius:4px;margin-bottom:1rem;
945
+ font-size:0.85rem;color:#744210">
946
+ <strong>Rate limit:</strong> 2 submissions per calendar
947
+ month per organization.</div>
948
+ </div>""")
949
+
950
+ with gr.Column(scale=1):
951
+ sub_agent = gr.Textbox(
952
+ label="Agent Name",
953
+ placeholder="e.g., GPT-5 + Custom MCP Tools",
954
+ )
955
+ sub_org = gr.Textbox(
956
+ label="Organization",
957
+ placeholder="e.g., OpenAI",
958
+ )
959
+ sub_url = gr.Textbox(
960
+ label="Endpoint URL",
961
+ placeholder="https://your-server.com/api/run",
962
+ )
963
+ sub_desc = gr.Textbox(
964
+ label="Description (optional)",
965
+ placeholder="Brief description of your agent...",
966
+ lines=3,
967
+ )
968
+ sub_mcp = gr.Checkbox(
969
+ label="Uses custom MCP tools (not reference)",
970
+ value=False,
971
+ )
972
+ sub_btn = gr.Button(
973
+ "Submit for Review",
974
+ variant="primary",
975
+ )
976
+ sub_result = gr.HTML()
977
+
978
+ def _handle_submit(name, org, url, desc, mcp):
979
+ if not name or not org or not url:
980
+ return ('<div style="color:#e53e3e;padding:0.5rem">'
981
+ "Please fill in all required fields.</div>")
982
+ if not url.startswith(("http://", "https://")):
983
+ return ('<div style="color:#e53e3e;padding:0.5rem">'
984
+ "URL must start with http:// or https://</div>")
985
+ try:
986
+ from eval_queue import submit
987
+ result = submit(
988
+ agent_name=name,
989
+ organization=org,
990
+ endpoint_url=url,
991
+ description=desc,
992
+ mcp_custom=mcp,
993
+ )
994
+ if "error" in result:
995
+ return (f'<div style="color:#e53e3e;padding:0.5rem">'
996
+ f'{result["error"]}</div>')
997
+ return (
998
+ f'<div style="background:#c6f6d5;padding:1rem;'
999
+ f'border-radius:8px;margin-top:0.5rem">'
1000
+ f'<strong>Submitted!</strong> '
1001
+ f'ID: <code>{result["submission_id"]}</code><br>'
1002
+ f'Status: {result["status"]}<br>'
1003
+ f'{result.get("message", "")}</div>'
1004
+ )
1005
+ except Exception as e:
1006
+ return (f'<div style="color:#e53e3e;padding:0.5rem">'
1007
+ f"Error: {str(e)[:200]}</div>")
1008
+
1009
+ sub_btn.click(
1010
+ _handle_submit,
1011
+ [sub_agent, sub_org, sub_url, sub_desc, sub_mcp],
1012
+ sub_result,
1013
+ )
1014
+
1015
+ # ══════ Tab 6: Status & Admin ══════
1016
+ with gr.Tab("\U0001f6e0 Status"):
1017
+ gr.HTML("""
1018
+ <div style="max-width:800px;margin:0 auto;padding:1rem">
1019
+ <h2 style="color:#1a365d;margin:0 0 0.5rem">
1020
+ Submission Status & Admin</h2>
1021
+ <p style="color:#4a5568;margin-bottom:0.5rem">
1022
+ Check your submission status or manage the pipeline
1023
+ (admin only).</p>
1024
+ </div>""")
1025
+
1026
+ # --- Public status check ---
1027
+ with gr.Accordion("Check Submission Status", open=True):
1028
+ status_id = gr.Textbox(
1029
+ label="Submission ID",
1030
+ placeholder="Enter your submission ID...",
1031
+ )
1032
+ status_btn = gr.Button("Check Status")
1033
+ status_out = gr.HTML()
1034
+
1035
+ def _check_status(sid):
1036
+ if not sid:
1037
+ return '<div style="color:#718096">Enter an ID above.</div>'
1038
+ try:
1039
+ from eval_queue import get_submission
1040
+ sub = get_submission(sid.strip())
1041
+ if sub is None:
1042
+ return ('<div style="color:#e53e3e">'
1043
+ "Submission not found.</div>")
1044
+ status_color = {
1045
+ "pending": "#d69e2e", "approved": "#38a169",
1046
+ "dispatching": "#3182ce", "boltz": "#805ad5",
1047
+ "scoring": "#805ad5", "complete": "#38a169",
1048
+ "failed": "#e53e3e", "rejected": "#e53e3e",
1049
+ }.get(sub["status"], "#718096")
1050
+ score_html = ""
1051
+ if sub.get("overall_score") is not None:
1052
+ score_html = (
1053
+ f'<div style="font-size:1.2rem;'
1054
+ f'font-weight:700;color:#1a365d;'
1055
+ f'margin-top:0.5rem">'
1056
+ f'Score: {sub["overall_score"]:.1f}/100'
1057
+ f'</div>'
1058
+ )
1059
+ return (
1060
+ f'<div style="background:white;padding:1rem;'
1061
+ f'border-radius:8px;border:1px solid #e2e8f0">'
1062
+ f'<strong>{sub["agent_name"]}</strong> '
1063
+ f'({sub["organization"]})<br>'
1064
+ f'Status: <span style="color:{status_color};'
1065
+ f'font-weight:700">{sub["status"]}</span><br>'
1066
+ f'Tasks: {sub.get("tasks_dispatched", 0)}'
1067
+ f'/{sub.get("tasks_total", 76)}<br>'
1068
+ f'Created: {sub.get("created_at", "")[:10]}'
1069
+ f'{score_html}</div>'
1070
+ )
1071
+ except Exception as e:
1072
+ return f'<div style="color:#e53e3e">{e}</div>'
1073
+
1074
+ status_btn.click(_check_status, [status_id], status_out)
1075
+
1076
+ # --- Admin panel (password-protected) ---
1077
+ with gr.Accordion("Admin Panel", open=False):
1078
+ admin_pw = gr.Textbox(
1079
+ label="Admin Password", type="password",
1080
+ )
1081
+ admin_auth_btn = gr.Button("Authenticate")
1082
+ admin_panel = gr.Column(visible=False)
1083
+ admin_msg = gr.HTML()
1084
+
1085
+ with admin_panel:
1086
+ gr.HTML('<h3 style="color:#1a365d">'
1087
+ 'Pending Submissions</h3>')
1088
+ pending_html = gr.HTML()
1089
+ refresh_btn = gr.Button("Refresh List")
1090
+
1091
+ with gr.Row():
1092
+ approve_id = gr.Textbox(
1093
+ label="Submission ID to Approve/Reject",
1094
+ scale=2,
1095
+ )
1096
+ approve_btn = gr.Button(
1097
+ "Approve", variant="primary", scale=1,
1098
+ )
1099
+ reject_btn = gr.Button(
1100
+ "Reject", variant="stop", scale=1,
1101
+ )
1102
+ approve_msg = gr.HTML()
1103
+
1104
+ gr.HTML('<h3 style="color:#1a365d;margin-top:1rem">'
1105
+ 'Pipeline Control</h3>')
1106
+ with gr.Row():
1107
+ dispatch_id = gr.Textbox(
1108
+ label="Submission ID", scale=2,
1109
+ )
1110
+ dispatch_btn = gr.Button(
1111
+ "Phase A: Dispatch Tasks", scale=1,
1112
+ )
1113
+ with gr.Row():
1114
+ boltz_id = gr.Textbox(
1115
+ label="Submission ID", scale=2,
1116
+ )
1117
+ boltz_btn = gr.Button(
1118
+ "Phase B: Run Boltz", scale=1,
1119
+ )
1120
+ with gr.Row():
1121
+ final_id = gr.Textbox(
1122
+ label="Submission ID", scale=2,
1123
+ )
1124
+ final_btn = gr.Button(
1125
+ "Phase C: Finalize & Publish", scale=1,
1126
+ )
1127
+ pipeline_out = gr.HTML()
1128
+
1129
+ def _admin_auth(pw):
1130
+ if pw == ADMIN_PASSWORD:
1131
+ return (
1132
+ gr.update(visible=True),
1133
+ '<div style="color:#38a169">'
1134
+ 'Authenticated.</div>',
1135
+ )
1136
+ return (
1137
+ gr.update(visible=False),
1138
+ '<div style="color:#e53e3e">'
1139
+ 'Wrong password.</div>',
1140
+ )
1141
+
1142
+ admin_auth_btn.click(
1143
+ _admin_auth, [admin_pw],
1144
+ [admin_panel, admin_msg],
1145
+ )
1146
+
1147
+ def _refresh_pending():
1148
+ try:
1149
+ from eval_queue import get_pending_submissions
1150
+ pending = get_pending_submissions()
1151
+ if not pending:
1152
+ return "<p>No pending submissions.</p>"
1153
+ rows = []
1154
+ for s in pending:
1155
+ rows.append(
1156
+ f'<tr><td>{s["submission_id"]}</td>'
1157
+ f'<td>{s["agent_name"]}</td>'
1158
+ f'<td>{s["organization"]}</td>'
1159
+ f'<td>{s.get("endpoint_url","")[:40]}'
1160
+ f'...</td>'
1161
+ f'<td>{s.get("created_at","")[:10]}'
1162
+ f'</td></tr>'
1163
+ )
1164
+ return (
1165
+ '<table style="width:100%;font-size:0.85rem;'
1166
+ 'border-collapse:collapse">'
1167
+ "<tr><th>ID</th><th>Agent</th><th>Org</th>"
1168
+ "<th>URL</th><th>Date</th></tr>"
1169
+ + "".join(rows) + "</table>"
1170
+ )
1171
+ except Exception as e:
1172
+ return f"<p>Error: {e}</p>"
1173
+
1174
+ refresh_btn.click(
1175
+ _refresh_pending, [], pending_html,
1176
+ )
1177
+
1178
+ def _approve_sub(sid):
1179
+ try:
1180
+ from eval_queue import update_status
1181
+ ok = update_status(sid.strip(), "approved")
1182
+ if ok:
1183
+ return (
1184
+ f'<div style="color:#38a169">'
1185
+ f'Approved: {sid}</div>'
1186
+ )
1187
+ return (
1188
+ f'<div style="color:#e53e3e">'
1189
+ f'Failed to approve {sid}</div>'
1190
+ )
1191
+ except Exception as e:
1192
+ return f'<div style="color:#e53e3e">{e}</div>'
1193
+
1194
+ def _reject_sub(sid):
1195
+ try:
1196
+ from eval_queue import update_status
1197
+ ok = update_status(sid.strip(), "rejected")
1198
+ if ok:
1199
+ return (
1200
+ f'<div style="color:#d69e2e">'
1201
+ f'Rejected: {sid}</div>'
1202
+ )
1203
+ return (
1204
+ f'<div style="color:#e53e3e">'
1205
+ f'Failed to reject {sid}</div>'
1206
+ )
1207
+ except Exception as e:
1208
+ return f'<div style="color:#e53e3e">{e}</div>'
1209
+
1210
+ approve_btn.click(
1211
+ _approve_sub, [approve_id], approve_msg,
1212
+ )
1213
+ reject_btn.click(
1214
+ _reject_sub, [approve_id], approve_msg,
1215
+ )
1216
+
1217
+ def _run_dispatch(sid):
1218
+ try:
1219
+ import asyncio as _aio
1220
+ from eval_queue import get_submission
1221
+ from eval_dispatcher import dispatch_all_tasks
1222
+
1223
+ sub = get_submission(sid.strip())
1224
+ if sub is None:
1225
+ return (
1226
+ '<div style="color:#e53e3e">'
1227
+ 'Not found</div>'
1228
+ )
1229
+ if sub["status"] not in (
1230
+ "approved", "dispatching"
1231
+ ):
1232
+ return (
1233
+ f'<div style="color:#e53e3e">'
1234
+ f'Cannot dispatch: status='
1235
+ f'{sub["status"]}</div>'
1236
+ )
1237
+ loop = _aio.new_event_loop()
1238
+ results = loop.run_until_complete(
1239
+ dispatch_all_tasks(
1240
+ sid.strip(),
1241
+ sub["endpoint_url"],
1242
+ )
1243
+ )
1244
+ loop.close()
1245
+ ok = sum(
1246
+ 1 for r in results if r.get("success")
1247
+ )
1248
+ return (
1249
+ f'<div style="color:#38a169">'
1250
+ f'Dispatched: {ok}/{len(results)} '
1251
+ f'tasks succeeded.</div>'
1252
+ )
1253
+ except Exception as e:
1254
+ return f'<div style="color:#e53e3e">{e}</div>'
1255
+
1256
+ def _run_boltz(sid):
1257
+ try:
1258
+ from eval_queue import get_submission
1259
+ from eval_boltz import run_boltz_posteval
1260
+
1261
+ sub = get_submission(sid.strip())
1262
+ if sub is None:
1263
+ return (
1264
+ '<div style="color:#e53e3e">'
1265
+ 'Not found</div>'
1266
+ )
1267
+ per_task = json.loads(
1268
+ sub.get("per_task_results", "{}")
1269
+ )
1270
+ if not per_task:
1271
+ return (
1272
+ '<div style="color:#e53e3e">'
1273
+ "No task results to process.</div>"
1274
+ )
1275
+ run_boltz_posteval(per_task)
1276
+ return (
1277
+ '<div style="color:#38a169">'
1278
+ "Boltz post-assessment complete.</div>"
1279
+ )
1280
+ except Exception as e:
1281
+ return f'<div style="color:#e53e3e">{e}</div>'
1282
+
1283
+ def _run_finalize(sid):
1284
+ try:
1285
+ from eval_queue import (
1286
+ finalize_submission,
1287
+ get_submission,
1288
+ )
1289
+ from eval_scorer import aggregate_scores
1290
+
1291
+ sub = get_submission(sid.strip())
1292
+ if sub is None:
1293
+ return (
1294
+ '<div style="color:#e53e3e">'
1295
+ 'Not found</div>'
1296
+ )
1297
+ per_task = json.loads(
1298
+ sub.get("per_task_results", "{}")
1299
+ )
1300
+ agg = aggregate_scores(per_task)
1301
+ finalize_submission(
1302
+ sid.strip(),
1303
+ overall_score=agg["overall_score"],
1304
+ component_scores=agg["component_scores"],
1305
+ taxonomy_scores=agg["taxonomy_scores"],
1306
+ )
1307
+ return (
1308
+ f'<div style="color:#38a169">'
1309
+ f'Finalized! Score: '
1310
+ f'{agg["overall_score"]:.1f}</div>'
1311
+ )
1312
+ except Exception as e:
1313
+ return f'<div style="color:#e53e3e">{e}</div>'
1314
+
1315
+ dispatch_btn.click(
1316
+ _run_dispatch, [dispatch_id], pipeline_out,
1317
+ )
1318
+ boltz_btn.click(
1319
+ _run_boltz, [boltz_id], pipeline_out,
1320
+ )
1321
+ final_btn.click(
1322
+ _run_finalize, [final_id], pipeline_out,
1323
+ )
1324
+
1325
+ # ══════ Tab 7: About ══════
1326
  with gr.Tab("\u2139\ufe0f About"):
1327
  gr.HTML(build_about())
1328
 
eval_boltz.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Boltz structure prediction for post-assessment scoring.
2
+
3
+ Uses @spaces.GPU decorator for ZeroGPU on HuggingFace Spaces.
4
+
5
+ Two prediction modes:
6
+ - Monomer: Non-binding tasks -> pLDDT, pTM
7
+ - Complex: Binding tasks (binder + target) -> ipTM, i_pAE
8
+
9
+ Batch chunking respects ZeroGPU time limits (~180-240s per burst).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import time
16
+ from typing import Any
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Chunking limits for ZeroGPU (free tier: ~300s max per burst)
21
+ MONOMER_CHUNK_SIZE = 5 # ~30-60s per monomer
22
+ COMPLEX_CHUNK_SIZE = 2 # ~60-120s per complex
23
+ MAX_GPU_TIME = 240 # safety margin under 300s ZeroGPU limit
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Boltz prediction (GPU-accelerated)
28
+ # ---------------------------------------------------------------------------
29
+
30
+
31
+ def _predict_monomer(sequence: str) -> dict[str, float]:
32
+ """Predict structure of a single protein sequence using Boltz.
33
+
34
+ Returns:
35
+ Dict with: pLDDT, pTM (or error).
36
+ """
37
+ try:
38
+ import torch
39
+ from boltz import Boltz
40
+
41
+ model = Boltz.from_pretrained("boltz2")
42
+ result = model.predict(sequence)
43
+
44
+ plddt = float(result.confidence.plddt.mean())
45
+ ptm = float(result.confidence.ptm)
46
+
47
+ return {
48
+ "pLDDT": round(plddt, 2),
49
+ "pTM": round(ptm, 4),
50
+ "success": True,
51
+ }
52
+ except Exception as e:
53
+ logger.error(f"Boltz monomer prediction failed: {e}")
54
+ return {"pLDDT": 0.0, "pTM": 0.0, "success": False, "error": str(e)}
55
+
56
+
57
+ def _predict_complex(
58
+ binder_seq: str,
59
+ target_seq: str,
60
+ ) -> dict[str, float]:
61
+ """Predict complex structure and binding metrics using Boltz.
62
+
63
+ Returns:
64
+ Dict with: ipTM, i_pAE, pLDDT, pTM (or error).
65
+ """
66
+ try:
67
+ import torch
68
+ from boltz import Boltz
69
+
70
+ model = Boltz.from_pretrained("boltz2")
71
+ result = model.predict([binder_seq, target_seq])
72
+
73
+ plddt = float(result.confidence.plddt.mean())
74
+ ptm = float(result.confidence.ptm)
75
+ iptm = float(result.confidence.iptm) if hasattr(result.confidence, "iptm") else 0.0
76
+ ipae = float(result.confidence.ipae) if hasattr(result.confidence, "ipae") else 0.0
77
+
78
+ return {
79
+ "pLDDT": round(plddt, 2),
80
+ "pTM": round(ptm, 4),
81
+ "ipTM": round(iptm, 4),
82
+ "i_pAE": round(ipae, 2),
83
+ "success": True,
84
+ }
85
+ except Exception as e:
86
+ logger.error(f"Boltz complex prediction failed: {e}")
87
+ return {
88
+ "pLDDT": 0.0, "pTM": 0.0, "ipTM": 0.0, "i_pAE": 0.0,
89
+ "success": False, "error": str(e),
90
+ }
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # GPU-decorated entry points (for HF Spaces with ZeroGPU)
95
+ # ---------------------------------------------------------------------------
96
+
97
+ try:
98
+ import spaces
99
+
100
+ @spaces.GPU(duration=MAX_GPU_TIME)
101
+ def predict_monomer_batch(sequences: list[str]) -> list[dict[str, float]]:
102
+ """Predict structures for a batch of monomer sequences.
103
+
104
+ Decorated with @spaces.GPU for ZeroGPU allocation.
105
+
106
+ Args:
107
+ sequences: List of amino acid sequences (max MONOMER_CHUNK_SIZE).
108
+
109
+ Returns:
110
+ List of prediction result dicts with pLDDT, pTM.
111
+ """
112
+ results = []
113
+ for seq in sequences[:MONOMER_CHUNK_SIZE]:
114
+ results.append(_predict_monomer(seq))
115
+ return results
116
+
117
+ @spaces.GPU(duration=MAX_GPU_TIME)
118
+ def predict_complex_batch(
119
+ pairs: list[tuple[str, str]],
120
+ ) -> list[dict[str, float]]:
121
+ """Predict structures for a batch of binder-target pairs.
122
+
123
+ Args:
124
+ pairs: List of (binder_seq, target_seq) tuples.
125
+
126
+ Returns:
127
+ List of prediction result dicts with ipTM, i_pAE, pLDDT, pTM.
128
+ """
129
+ results = []
130
+ for binder, target in pairs[:COMPLEX_CHUNK_SIZE]:
131
+ results.append(_predict_complex(binder, target))
132
+ return results
133
+
134
+ except ImportError:
135
+ # Not running on HF Spaces -- provide un-decorated versions
136
+ def predict_monomer_batch(sequences: list[str]) -> list[dict[str, float]]:
137
+ return [_predict_monomer(seq) for seq in sequences[:MONOMER_CHUNK_SIZE]]
138
+
139
+ def predict_complex_batch(
140
+ pairs: list[tuple[str, str]],
141
+ ) -> list[dict[str, float]]:
142
+ return [_predict_complex(b, t) for b, t in pairs[:COMPLEX_CHUNK_SIZE]]
143
+
144
+
145
+ # ---------------------------------------------------------------------------
146
+ # High-level assessment API
147
+ # ---------------------------------------------------------------------------
148
+
149
+
150
+ def run_boltz_posteval(
151
+ per_task_results: dict[str, dict[str, Any]],
152
+ progress_callback=None,
153
+ ) -> dict[str, dict[str, Any]]:
154
+ """Run Boltz post-assessment on all tasks that need it.
155
+
156
+ For each task:
157
+ - Non-binding: pick best design -> monomer prediction
158
+ - Binding: pick best design + target sequence -> complex prediction
159
+ - Merge Boltz metrics into existing results
160
+ - Re-score quality component
161
+
162
+ Args:
163
+ per_task_results: Dict of task_id -> dispatch result (from dispatcher).
164
+ progress_callback: Optional callback(task_id, i, total, metrics).
165
+
166
+ Returns:
167
+ Updated per_task_results with Boltz metrics and final quality scores.
168
+ """
169
+ from eval_scorer import _is_binding_task, score_quality
170
+
171
+ # Separate tasks into monomer and complex batches
172
+ monomer_tasks = []
173
+ complex_tasks = []
174
+
175
+ for task_id, result in per_task_results.items():
176
+ if not result.get("success") or not result.get("quality_pending"):
177
+ continue
178
+
179
+ sequences = result.get("sequences", [])
180
+ if not sequences:
181
+ continue
182
+
183
+ best_seq = sequences[0] # Use first design for Boltz
184
+
185
+ if _is_binding_task(task_id):
186
+ # Need target sequence from ground truth
187
+ target_seq = result.get("ground_truth_thresholds", {}).get("target_sequence")
188
+ if target_seq:
189
+ complex_tasks.append((task_id, best_seq, target_seq))
190
+ else:
191
+ # Fall back to monomer if no target
192
+ monomer_tasks.append((task_id, best_seq))
193
+ else:
194
+ monomer_tasks.append((task_id, best_seq))
195
+
196
+ total = len(monomer_tasks) + len(complex_tasks)
197
+ done = 0
198
+
199
+ # Process monomer tasks in chunks
200
+ for chunk_start in range(0, len(monomer_tasks), MONOMER_CHUNK_SIZE):
201
+ chunk = monomer_tasks[chunk_start:chunk_start + MONOMER_CHUNK_SIZE]
202
+ seqs = [seq for _, seq in chunk]
203
+
204
+ boltz_results = predict_monomer_batch(seqs)
205
+
206
+ for (task_id, _), metrics in zip(chunk, boltz_results):
207
+ if metrics.get("success"):
208
+ _merge_boltz_metrics(per_task_results[task_id], metrics)
209
+
210
+ done += 1
211
+ if progress_callback:
212
+ progress_callback(task_id, done, total, metrics)
213
+
214
+ # Process complex tasks in chunks
215
+ for chunk_start in range(0, len(complex_tasks), COMPLEX_CHUNK_SIZE):
216
+ chunk = complex_tasks[chunk_start:chunk_start + COMPLEX_CHUNK_SIZE]
217
+ pairs = [(binder, target) for _, binder, target in chunk]
218
+
219
+ boltz_results = predict_complex_batch(pairs)
220
+
221
+ for (task_id, _, _), metrics in zip(chunk, boltz_results):
222
+ if metrics.get("success"):
223
+ _merge_boltz_metrics(per_task_results[task_id], metrics)
224
+
225
+ done += 1
226
+ if progress_callback:
227
+ progress_callback(task_id, done, total, metrics)
228
+
229
+ return per_task_results
230
+
231
+
232
+ def _merge_boltz_metrics(
233
+ task_result: dict[str, Any],
234
+ boltz_metrics: dict[str, float],
235
+ ) -> None:
236
+ """Merge Boltz prediction metrics into a task result and re-score quality.
237
+
238
+ Modifies task_result in-place.
239
+ """
240
+ from eval_scorer import apply_design_gate, score_quality
241
+
242
+ # Merge Boltz metrics with any agent-reported metrics
243
+ merged_metrics = task_result.get("agent_metrics", {}).copy()
244
+ for key in ("pLDDT", "pTM", "ipTM", "i_pAE"):
245
+ if key in boltz_metrics and boltz_metrics[key] > 0:
246
+ merged_metrics[key] = boltz_metrics[key]
247
+
248
+ # Re-score quality with Boltz metrics
249
+ quality_result = score_quality(
250
+ agent_metrics=merged_metrics,
251
+ thresholds=task_result.get("ground_truth_thresholds", {}),
252
+ task_id=task_result.get("task_id", ""),
253
+ designs=task_result.get("sequences"),
254
+ oracle_sequences=task_result.get("oracle_sequences"),
255
+ )
256
+
257
+ # Update scores
258
+ task_result["boltz_metrics"] = boltz_metrics
259
+ task_result["quality_pending"] = False
260
+
261
+ if "cpu_scores" in task_result:
262
+ task_result["cpu_scores"]["quality"] = quality_result["score"]
263
+
264
+ # Compute final gated score
265
+ if "cpu_scores" in task_result:
266
+ component_scores = dict(task_result["cpu_scores"])
267
+ gated = apply_design_gate(component_scores, task_result.get("num_designs", 0))
268
+ task_result["final_scores"] = gated
269
+ task_result["total_score"] = sum(gated.values())
270
+
271
+ if "cpu_details" in task_result:
272
+ task_result["cpu_details"]["quality"] = quality_result
eval_dispatcher.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HTTP task dispatcher — sends benchmark tasks to submitter endpoints.
2
+
3
+ For each of 76 tasks:
4
+ 1. Build task payload (prompt + tools + PDB data)
5
+ 2. POST to submitter's endpoint with timeout
6
+ 3. Validate response format
7
+ 4. Run CPU-only scoring (approach, orchestration, feasibility, novelty, diversity)
8
+ 5. Save results to submission queue
9
+
10
+ CPU scoring runs immediately; quality scoring waits for Boltz post-eval.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ import time
17
+ from typing import Any, Generator
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Response validation limits
22
+ MAX_SEQUENCES = 50
23
+ MAX_SEQUENCE_LENGTH = 2000
24
+ MAX_LOG_ENTRIES = 200
25
+ DISPATCH_TIMEOUT = 300 # seconds per task
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Response validation
30
+ # ---------------------------------------------------------------------------
31
+
32
+
33
+ def validate_response(response: dict[str, Any]) -> tuple[bool, str]:
34
+ """Validate the submitter's response format.
35
+
36
+ Expected format:
37
+ {
38
+ "sequences": ["MKKL...", ...],
39
+ "run_log": [{"step": 1, "tool": "...", "success": true, ...}, ...],
40
+ "total_steps": 12,
41
+ "total_time_sec": 142.5,
42
+ "metrics": {}
43
+ }
44
+
45
+ Returns:
46
+ (is_valid, error_message)
47
+ """
48
+ if not isinstance(response, dict):
49
+ return False, "Response must be a JSON object"
50
+
51
+ # sequences (required)
52
+ sequences = response.get("sequences")
53
+ if not isinstance(sequences, list):
54
+ return False, "Missing or invalid 'sequences' field (must be a list)"
55
+
56
+ if len(sequences) > MAX_SEQUENCES:
57
+ return False, f"Too many sequences: {len(sequences)} > {MAX_SEQUENCES}"
58
+
59
+ for i, seq in enumerate(sequences):
60
+ if not isinstance(seq, str):
61
+ return False, f"sequences[{i}] must be a string"
62
+ if len(seq) > MAX_SEQUENCE_LENGTH:
63
+ return False, f"sequences[{i}] too long: {len(seq)} > {MAX_SEQUENCE_LENGTH}"
64
+ if len(seq) == 0:
65
+ return False, f"sequences[{i}] is empty"
66
+
67
+ # run_log (required)
68
+ run_log = response.get("run_log")
69
+ if not isinstance(run_log, list):
70
+ return False, "Missing or invalid 'run_log' field (must be a list)"
71
+
72
+ if len(run_log) > MAX_LOG_ENTRIES:
73
+ return False, f"Too many log entries: {len(run_log)} > {MAX_LOG_ENTRIES}"
74
+
75
+ for i, entry in enumerate(run_log):
76
+ if not isinstance(entry, dict):
77
+ return False, f"run_log[{i}] must be a dict"
78
+ if "tool" not in entry:
79
+ return False, f"run_log[{i}] missing 'tool' field"
80
+
81
+ # Optional fields — validate types if present
82
+ if "total_steps" in response:
83
+ if not isinstance(response["total_steps"], (int, float)):
84
+ return False, "'total_steps' must be a number"
85
+
86
+ if "total_time_sec" in response:
87
+ if not isinstance(response["total_time_sec"], (int, float)):
88
+ return False, "'total_time_sec' must be a number"
89
+
90
+ return True, ""
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Single task dispatch
95
+ # ---------------------------------------------------------------------------
96
+
97
+
98
+ async def dispatch_single_task(
99
+ endpoint_url: str,
100
+ task_payload: dict[str, Any],
101
+ timeout: int = DISPATCH_TIMEOUT,
102
+ ) -> dict[str, Any]:
103
+ """Send a single task to the submitter's endpoint.
104
+
105
+ Args:
106
+ endpoint_url: Submitter's POST endpoint URL.
107
+ task_payload: Task payload from eval_tasks.build_task_payload().
108
+ timeout: Request timeout in seconds.
109
+
110
+ Returns:
111
+ Dict with: success, task_id, response (if success), error (if failed),
112
+ latency_sec.
113
+ """
114
+ import httpx
115
+
116
+ task_id = task_payload["task_id"]
117
+ start = time.monotonic()
118
+
119
+ try:
120
+ async with httpx.AsyncClient(timeout=timeout) as client:
121
+ resp = await client.post(
122
+ endpoint_url,
123
+ json=task_payload,
124
+ headers={"Content-Type": "application/json"},
125
+ )
126
+ latency = time.monotonic() - start
127
+
128
+ if resp.status_code != 200:
129
+ return {
130
+ "success": False,
131
+ "task_id": task_id,
132
+ "error": f"HTTP {resp.status_code}: {resp.text[:200]}",
133
+ "latency_sec": round(latency, 1),
134
+ }
135
+
136
+ try:
137
+ data = resp.json()
138
+ except Exception:
139
+ return {
140
+ "success": False,
141
+ "task_id": task_id,
142
+ "error": "Response is not valid JSON",
143
+ "latency_sec": round(latency, 1),
144
+ }
145
+
146
+ is_valid, error_msg = validate_response(data)
147
+ if not is_valid:
148
+ return {
149
+ "success": False,
150
+ "task_id": task_id,
151
+ "error": f"Invalid response: {error_msg}",
152
+ "latency_sec": round(latency, 1),
153
+ }
154
+
155
+ return {
156
+ "success": True,
157
+ "task_id": task_id,
158
+ "response": data,
159
+ "latency_sec": round(latency, 1),
160
+ }
161
+
162
+ except httpx.TimeoutException:
163
+ latency = time.monotonic() - start
164
+ return {
165
+ "success": False,
166
+ "task_id": task_id,
167
+ "error": f"Timeout after {timeout}s",
168
+ "latency_sec": round(latency, 1),
169
+ }
170
+ except Exception as e:
171
+ latency = time.monotonic() - start
172
+ return {
173
+ "success": False,
174
+ "task_id": task_id,
175
+ "error": f"Connection error: {str(e)[:200]}",
176
+ "latency_sec": round(latency, 1),
177
+ }
178
+
179
+
180
+ # ---------------------------------------------------------------------------
181
+ # CPU scoring (runs immediately, no GPU needed)
182
+ # ---------------------------------------------------------------------------
183
+
184
+
185
+ def score_cpu_components(
186
+ task_id: str,
187
+ sequences: list[str],
188
+ run_log: list[dict[str, Any]],
189
+ ground_truth: dict[str, Any],
190
+ oracle_sequences: list[str] | None = None,
191
+ ) -> dict[str, Any]:
192
+ """Run CPU-only scoring components.
193
+
194
+ Scores: approach, orchestration, feasibility, novelty, diversity.
195
+ Quality scoring is deferred until Boltz post-eval provides pLDDT/ipTM.
196
+
197
+ Args:
198
+ task_id: Task identifier.
199
+ sequences: Designed sequences from submitter.
200
+ run_log: Tool call log from submitter.
201
+ ground_truth: Ground truth data for this task.
202
+ oracle_sequences: Oracle sequences for non-binding tasks.
203
+
204
+ Returns:
205
+ Dict with partial scores and metadata for later Boltz completion.
206
+ """
207
+ from eval_scorer import (
208
+ get_category,
209
+ score_approach,
210
+ score_diversity,
211
+ score_feasibility,
212
+ score_novelty,
213
+ score_orchestration,
214
+ )
215
+
216
+ # Extract fields
217
+ thresholds = ground_truth.get("thresholds", {})
218
+ reference_seq = ground_truth.get("reference_sequence")
219
+ constraints = ground_truth.get("design_constraints", {})
220
+ tools_expected = ground_truth.get("tools_expected", [])
221
+ max_designs = ground_truth.get("max_designs", 10)
222
+
223
+ cat = get_category(task_id)
224
+ task_type = cat.task_type if cat else None
225
+ tools_used = [e.get("tool", "") for e in run_log if e.get("tool")]
226
+
227
+ approach_result = score_approach(
228
+ tools_used=tools_used,
229
+ tools_expected=tools_expected,
230
+ task_type=task_type,
231
+ )
232
+ orchestration_result = score_orchestration(
233
+ tool_call_log=run_log,
234
+ task_id=task_id,
235
+ )
236
+ feasibility_result = score_feasibility(
237
+ designs=sequences,
238
+ constraints=constraints,
239
+ )
240
+ novelty_result = score_novelty(
241
+ designs=sequences,
242
+ reference_seq=reference_seq,
243
+ thresholds=thresholds,
244
+ )
245
+ diversity_result = score_diversity(
246
+ designs=sequences,
247
+ max_designs=max_designs,
248
+ )
249
+
250
+ return {
251
+ "task_id": task_id,
252
+ "num_designs": len(sequences),
253
+ "sequences": sequences,
254
+ "cpu_scores": {
255
+ "approach": approach_result["score"],
256
+ "orchestration": orchestration_result["score"],
257
+ "feasibility": feasibility_result["score"],
258
+ "novelty": novelty_result["score"],
259
+ "diversity": diversity_result["score"],
260
+ },
261
+ "cpu_details": {
262
+ "approach": approach_result,
263
+ "orchestration": orchestration_result,
264
+ "feasibility": feasibility_result,
265
+ "novelty": novelty_result,
266
+ "diversity": diversity_result,
267
+ },
268
+ "quality_pending": True, # Needs Boltz post-eval
269
+ "oracle_sequences": oracle_sequences or [],
270
+ "ground_truth_thresholds": thresholds,
271
+ }
272
+
273
+
274
+ # ---------------------------------------------------------------------------
275
+ # Full dispatch pipeline
276
+ # ---------------------------------------------------------------------------
277
+
278
+
279
+ async def dispatch_all_tasks(
280
+ submission_id: str,
281
+ endpoint_url: str,
282
+ progress_callback=None,
283
+ ) -> Generator[dict[str, Any], None, None]:
284
+ """Dispatch all hidden tasks to a submitter endpoint.
285
+
286
+ Yields progress updates as each task completes. Saves results
287
+ to the submission queue incrementally.
288
+
289
+ Args:
290
+ submission_id: Submission ID for queue tracking.
291
+ endpoint_url: Submitter's POST endpoint.
292
+ progress_callback: Optional callback(task_id, i, total, result)
293
+ for streaming progress updates.
294
+
295
+ Returns:
296
+ List of per-task results.
297
+ """
298
+ from eval_queue import save_task_result, update_status
299
+ from eval_tasks import build_task_payload, get_hidden_task_ids, get_task
300
+
301
+ task_ids = get_hidden_task_ids()
302
+ total = len(task_ids)
303
+ results = []
304
+
305
+ update_status(submission_id, "dispatching", tasks_total=total)
306
+
307
+ for i, task_id in enumerate(task_ids):
308
+ # Build payload
309
+ payload = build_task_payload(task_id)
310
+ if payload is None:
311
+ result = {
312
+ "task_id": task_id,
313
+ "success": False,
314
+ "error": "Task not found",
315
+ }
316
+ results.append(result)
317
+ save_task_result(submission_id, task_id, result)
318
+ continue
319
+
320
+ # Dispatch
321
+ dispatch_result = await dispatch_single_task(endpoint_url, payload)
322
+
323
+ if dispatch_result["success"]:
324
+ # Run CPU scoring
325
+ task_data = get_task(task_id)
326
+ ground_truth = task_data["ground_truth"] if task_data else {}
327
+ oracle_seqs = task_data.get("oracle_sequences", []) if task_data else []
328
+
329
+ response = dispatch_result["response"]
330
+ cpu_result = score_cpu_components(
331
+ task_id=task_id,
332
+ sequences=response["sequences"],
333
+ run_log=response["run_log"],
334
+ ground_truth=ground_truth,
335
+ oracle_sequences=oracle_seqs,
336
+ )
337
+ cpu_result["latency_sec"] = dispatch_result["latency_sec"]
338
+ cpu_result["success"] = True
339
+ cpu_result["agent_metrics"] = response.get("metrics", {})
340
+ results.append(cpu_result)
341
+ save_task_result(submission_id, task_id, cpu_result)
342
+ else:
343
+ result = {
344
+ "task_id": task_id,
345
+ "success": False,
346
+ "error": dispatch_result["error"],
347
+ "latency_sec": dispatch_result.get("latency_sec"),
348
+ }
349
+ results.append(result)
350
+ save_task_result(submission_id, task_id, result)
351
+
352
+ if progress_callback:
353
+ progress_callback(task_id, i + 1, total, results[-1])
354
+
355
+ logger.info(
356
+ f"[{i+1}/{total}] {task_id}: "
357
+ f"{'OK' if results[-1].get('success') else 'FAIL'} "
358
+ f"({results[-1].get('latency_sec', 0):.1f}s)"
359
+ )
360
+
361
+ return results
eval_queue.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Submission queue management using HuggingFace Datasets.
2
+
3
+ Manages the lifecycle of benchmark submissions:
4
+ pending → approved → dispatching → boltz → scoring → complete / failed
5
+
6
+ Rate limiting: 2 submissions per calendar month per organization.
7
+
8
+ HF Dataset: RomeroLab-Duke/biodesignbench-submissions (private)
9
+ Schema: Each row is a submission with per-task results stored as JSON.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import logging
16
+ import os
17
+ import uuid
18
+ from datetime import datetime, timezone
19
+ from typing import Any
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Constants
25
+ # ---------------------------------------------------------------------------
26
+
27
+ SUBMISSIONS_DATASET = os.environ.get(
28
+ "BDB_SUBMISSIONS_DATASET",
29
+ "RomeroLab-Duke/biodesignbench-submissions",
30
+ )
31
+ HF_TOKEN = os.environ.get("HF_TOKEN")
32
+ MAX_SUBMISSIONS_PER_MONTH = 2
33
+
34
+ # Submission status progression
35
+ VALID_STATUSES = {
36
+ "pending",
37
+ "approved",
38
+ "dispatching",
39
+ "boltz",
40
+ "scoring",
41
+ "complete",
42
+ "failed",
43
+ "rejected",
44
+ }
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Data model
49
+ # ---------------------------------------------------------------------------
50
+
51
+
52
+ def _make_submission_row(
53
+ agent_name: str,
54
+ organization: str,
55
+ endpoint_url: str,
56
+ description: str = "",
57
+ mcp_custom: bool = False,
58
+ ) -> dict[str, Any]:
59
+ """Create a new submission row."""
60
+ now = datetime.now(timezone.utc).isoformat()
61
+ return {
62
+ "submission_id": str(uuid.uuid4())[:12],
63
+ "agent_name": agent_name,
64
+ "organization": organization,
65
+ "endpoint_url": endpoint_url,
66
+ "description": description,
67
+ "mcp_custom": mcp_custom,
68
+ "status": "pending",
69
+ "created_at": now,
70
+ "updated_at": now,
71
+ "tasks_dispatched": 0,
72
+ "tasks_total": 76,
73
+ "tasks_boltz_done": 0,
74
+ "overall_score": None,
75
+ "component_scores": None,
76
+ "taxonomy_scores": None,
77
+ "per_task_results": "{}", # JSON string of task_id → result
78
+ "error_message": None,
79
+ }
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # Queue operations (HF Datasets API)
84
+ # ---------------------------------------------------------------------------
85
+
86
+
87
+ def _get_dataset():
88
+ """Load the submissions dataset from HF Hub."""
89
+ try:
90
+ from datasets import load_dataset
91
+
92
+ ds = load_dataset(
93
+ SUBMISSIONS_DATASET,
94
+ split="train",
95
+ token=HF_TOKEN,
96
+ )
97
+ return ds
98
+ except Exception as e:
99
+ logger.warning(f"Could not load submissions dataset: {e}")
100
+ return None
101
+
102
+
103
+ def _save_rows(rows: list[dict[str, Any]]) -> bool:
104
+ """Save rows back to HF Dataset."""
105
+ try:
106
+ from datasets import Dataset
107
+ from huggingface_hub import HfApi
108
+
109
+ ds = Dataset.from_list(rows)
110
+ ds.push_to_hub(
111
+ SUBMISSIONS_DATASET,
112
+ token=HF_TOKEN,
113
+ private=True,
114
+ )
115
+ return True
116
+ except Exception as e:
117
+ logger.error(f"Failed to save submissions: {e}")
118
+ return False
119
+
120
+
121
+ def _load_all_rows() -> list[dict[str, Any]]:
122
+ """Load all submission rows as a list of dicts."""
123
+ ds = _get_dataset()
124
+ if ds is None:
125
+ return []
126
+ return [dict(row) for row in ds]
127
+
128
+
129
+ def submit(
130
+ agent_name: str,
131
+ organization: str,
132
+ endpoint_url: str,
133
+ description: str = "",
134
+ mcp_custom: bool = False,
135
+ ) -> dict[str, Any]:
136
+ """Create a new submission.
137
+
138
+ Returns:
139
+ Dict with submission_id and status, or error message.
140
+ """
141
+ # Rate limit check
142
+ error = check_rate_limit(organization)
143
+ if error:
144
+ return {"error": error}
145
+
146
+ # Validate endpoint URL
147
+ if not endpoint_url.startswith(("http://", "https://")):
148
+ return {"error": "Endpoint URL must start with http:// or https://"}
149
+
150
+ row = _make_submission_row(
151
+ agent_name=agent_name,
152
+ organization=organization,
153
+ endpoint_url=endpoint_url,
154
+ description=description,
155
+ mcp_custom=mcp_custom,
156
+ )
157
+
158
+ rows = _load_all_rows()
159
+ rows.append(row)
160
+
161
+ if _save_rows(rows):
162
+ return {
163
+ "submission_id": row["submission_id"],
164
+ "status": "pending",
165
+ "message": f"Submission created. Awaiting admin approval.",
166
+ }
167
+ return {"error": "Failed to save submission. Please try again."}
168
+
169
+
170
+ def check_rate_limit(organization: str) -> str | None:
171
+ """Check if an organization has exceeded the monthly submission limit.
172
+
173
+ Returns:
174
+ Error message string if rate limited, None if OK.
175
+ """
176
+ rows = _load_all_rows()
177
+ now = datetime.now(timezone.utc)
178
+ current_month = now.strftime("%Y-%m")
179
+
180
+ monthly_count = 0
181
+ for row in rows:
182
+ if row.get("organization", "").lower() != organization.lower():
183
+ continue
184
+ if row.get("status") in ("rejected", "failed"):
185
+ continue
186
+ created = row.get("created_at", "")
187
+ if created.startswith(current_month):
188
+ monthly_count += 1
189
+
190
+ if monthly_count >= MAX_SUBMISSIONS_PER_MONTH:
191
+ return (
192
+ f"Organization '{organization}' has reached the limit of "
193
+ f"{MAX_SUBMISSIONS_PER_MONTH} submissions for {current_month}."
194
+ )
195
+ return None
196
+
197
+
198
+ def update_status(
199
+ submission_id: str,
200
+ status: str,
201
+ **extra_fields: Any,
202
+ ) -> bool:
203
+ """Update a submission's status and optional extra fields.
204
+
205
+ Args:
206
+ submission_id: The submission to update.
207
+ status: New status (must be in VALID_STATUSES).
208
+ **extra_fields: Additional fields to update (e.g., tasks_dispatched=10).
209
+
210
+ Returns:
211
+ True if updated successfully.
212
+ """
213
+ if status not in VALID_STATUSES:
214
+ logger.error(f"Invalid status: {status}")
215
+ return False
216
+
217
+ rows = _load_all_rows()
218
+ found = False
219
+ for row in rows:
220
+ if row.get("submission_id") == submission_id:
221
+ row["status"] = status
222
+ row["updated_at"] = datetime.now(timezone.utc).isoformat()
223
+ for k, v in extra_fields.items():
224
+ if k in row:
225
+ row[k] = v
226
+ found = True
227
+ break
228
+
229
+ if not found:
230
+ logger.error(f"Submission {submission_id} not found")
231
+ return False
232
+
233
+ return _save_rows(rows)
234
+
235
+
236
+ def save_task_result(
237
+ submission_id: str,
238
+ task_id: str,
239
+ result: dict[str, Any],
240
+ ) -> bool:
241
+ """Save a per-task result to the submission.
242
+
243
+ Args:
244
+ submission_id: The submission to update.
245
+ task_id: Task identifier.
246
+ result: Score result dict from eval_scorer.score_submission_task().
247
+
248
+ Returns:
249
+ True if saved successfully.
250
+ """
251
+ rows = _load_all_rows()
252
+ for row in rows:
253
+ if row.get("submission_id") == submission_id:
254
+ per_task = json.loads(row.get("per_task_results", "{}"))
255
+ per_task[task_id] = result
256
+ row["per_task_results"] = json.dumps(per_task)
257
+ row["tasks_dispatched"] = len(per_task)
258
+ row["updated_at"] = datetime.now(timezone.utc).isoformat()
259
+ return _save_rows(rows)
260
+
261
+ logger.error(f"Submission {submission_id} not found")
262
+ return False
263
+
264
+
265
+ def get_submission(submission_id: str) -> dict[str, Any] | None:
266
+ """Get a single submission by ID."""
267
+ rows = _load_all_rows()
268
+ for row in rows:
269
+ if row.get("submission_id") == submission_id:
270
+ return row
271
+ return None
272
+
273
+
274
+ def get_pending_submissions() -> list[dict[str, Any]]:
275
+ """Get all submissions awaiting admin approval."""
276
+ return [r for r in _load_all_rows() if r.get("status") == "pending"]
277
+
278
+
279
+ def get_approved_submissions() -> list[dict[str, Any]]:
280
+ """Get all approved submissions ready for dispatch."""
281
+ return [r for r in _load_all_rows() if r.get("status") == "approved"]
282
+
283
+
284
+ def get_all_submissions() -> list[dict[str, Any]]:
285
+ """Get all submissions for the admin panel."""
286
+ return _load_all_rows()
287
+
288
+
289
+ def finalize_submission(
290
+ submission_id: str,
291
+ overall_score: float,
292
+ component_scores: dict[str, float],
293
+ taxonomy_scores: dict[str, dict[str, float]],
294
+ ) -> bool:
295
+ """Finalize a submission with aggregated scores.
296
+
297
+ Args:
298
+ submission_id: The submission to finalize.
299
+ overall_score: Overall score (0-100).
300
+ component_scores: Dict of component → averaged score.
301
+ taxonomy_scores: Nested dict of task_type → context → avg score.
302
+
303
+ Returns:
304
+ True if finalized successfully.
305
+ """
306
+ return update_status(
307
+ submission_id,
308
+ status="complete",
309
+ overall_score=overall_score,
310
+ component_scores=json.dumps(component_scores),
311
+ taxonomy_scores=json.dumps(taxonomy_scores),
312
+ )
eval_scorer.py ADDED
@@ -0,0 +1,1643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standalone 100-point scoring rubric for BioDesignBench Tier 2 design tasks.
2
+
3
+ This file is a **self-contained extraction** of the scoring logic from the
4
+ ``biodesignbench`` package. It has **zero external dependencies** (stdlib only)
5
+ so it can run on HuggingFace Spaces without installing the full package.
6
+
7
+ Modules consolidated:
8
+ - biodesignbench/taxonomy.py
9
+ - biodesignbench/eval/metrics/sequence.py
10
+ - biodesignbench/eval/metrics/approach.py
11
+ - biodesignbench/eval/metrics/orchestration.py
12
+ - biodesignbench/eval/tier2/scoring.py
13
+ - biodesignbench/eval/tier2/oracle.py (oracle loading stub)
14
+
15
+ Six scoring components (sum = 100):
16
+ approach (20 pts) — Tool/methodology selection
17
+ orchestration (15 pts) — Pipeline ordering + intermediate validation
18
+ quality (35 pts) — 3-tier continuous scoring (structure/interface/physics)
19
+ feasibility (15 pts) — Valid AAs, length, composition + biophysical checks
20
+ novelty ( 5 pts) — Sequence identity to known sequences
21
+ diversity (10 pts) — Number + diversity of designs
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ import math
28
+ import re
29
+ from collections import Counter
30
+ from dataclasses import dataclass, field
31
+ from enum import Enum
32
+ from functools import lru_cache
33
+ from itertools import combinations
34
+ from typing import Any, Optional
35
+
36
+
37
+ # ═══════════════════════════════════════════════════════════════════════════════
38
+ # SECTION 1 — Taxonomy (from biodesignbench/taxonomy.py)
39
+ # ═══════════════════════════════════════════════════════════════════════════════
40
+
41
+
42
+ class DesignTaskType(str, Enum):
43
+ """What the agent does."""
44
+
45
+ DE_NOVO_BINDER = "de_novo_binder"
46
+ SEQUENCE_OPTIMIZATION = "sequence_optimization"
47
+ DE_NOVO_BACKBONE = "de_novo_backbone"
48
+ COMPLEX_ENGINEERING = "complex_engineering"
49
+ CONFORMATIONAL_DESIGN = "conformational_design"
50
+
51
+ @property
52
+ def short(self) -> str:
53
+ return _TASK_TYPE_SHORT[self]
54
+
55
+
56
+ class BiologicalContext(str, Enum):
57
+ """Domain knowledge required."""
58
+
59
+ ANTIBODY = "antibody"
60
+ ENZYME = "enzyme"
61
+ SIGNALING = "signaling"
62
+ STRUCTURAL = "structural"
63
+ FLUORESCENT = "fluorescent"
64
+ THERAPEUTIC = "therapeutic"
65
+
66
+ @property
67
+ def short(self) -> str:
68
+ return _CONTEXT_SHORT[self]
69
+
70
+
71
+ _TASK_TYPE_SHORT: dict[DesignTaskType, str] = {
72
+ DesignTaskType.DE_NOVO_BINDER: "dnb",
73
+ DesignTaskType.SEQUENCE_OPTIMIZATION: "sqo",
74
+ DesignTaskType.DE_NOVO_BACKBONE: "dnk",
75
+ DesignTaskType.COMPLEX_ENGINEERING: "cpx",
76
+ DesignTaskType.CONFORMATIONAL_DESIGN: "cfd",
77
+ }
78
+
79
+ _CONTEXT_SHORT: dict[BiologicalContext, str] = {
80
+ BiologicalContext.ANTIBODY: "ab",
81
+ BiologicalContext.ENZYME: "enz",
82
+ BiologicalContext.SIGNALING: "sig",
83
+ BiologicalContext.STRUCTURAL: "str",
84
+ BiologicalContext.FLUORESCENT: "flu",
85
+ BiologicalContext.THERAPEUTIC: "thr",
86
+ }
87
+
88
+ _SHORT_TO_TASK_TYPE: dict[str, DesignTaskType] = {v: k for k, v in _TASK_TYPE_SHORT.items()}
89
+ _SHORT_TO_CONTEXT: dict[str, BiologicalContext] = {v: k for k, v in _CONTEXT_SHORT.items()}
90
+
91
+ # Core tools expected per task type
92
+ _CORE_TOOLS: dict[DesignTaskType, list[str]] = {
93
+ DesignTaskType.DE_NOVO_BINDER: ["rfdiffusion", "proteinmpnn", "alphafold2"],
94
+ DesignTaskType.SEQUENCE_OPTIMIZATION: ["proteinmpnn", "esmfold", "alphafold2"],
95
+ DesignTaskType.DE_NOVO_BACKBONE: ["rfdiffusion", "proteinmpnn", "alphafold2"],
96
+ DesignTaskType.COMPLEX_ENGINEERING: ["rfdiffusion", "proteinmpnn", "alphafold2"],
97
+ DesignTaskType.CONFORMATIONAL_DESIGN: ["esmfold", "proteinmpnn", "alphafold2"],
98
+ }
99
+
100
+ _PRIMARY_METRIC: dict[DesignTaskType, str] = {
101
+ DesignTaskType.DE_NOVO_BINDER: "ipTM",
102
+ DesignTaskType.SEQUENCE_OPTIMIZATION: "pLDDT",
103
+ DesignTaskType.DE_NOVO_BACKBONE: "pLDDT",
104
+ DesignTaskType.COMPLEX_ENGINEERING: "ipTM",
105
+ DesignTaskType.CONFORMATIONAL_DESIGN: "pLDDT",
106
+ }
107
+
108
+
109
+ @dataclass(frozen=True)
110
+ class TaskCategory:
111
+ """A valid cell in the DesignTaskType × BiologicalContext matrix."""
112
+
113
+ task_type: DesignTaskType
114
+ context: BiologicalContext
115
+
116
+ @property
117
+ def category_id(self) -> str:
118
+ return f"{self.task_type.short}_{self.context.short}"
119
+
120
+ @property
121
+ def expected_core_tools(self) -> list[str]:
122
+ return list(_CORE_TOOLS[self.task_type])
123
+
124
+ @property
125
+ def primary_quality_metric(self) -> str:
126
+ return _PRIMARY_METRIC[self.task_type]
127
+
128
+
129
+ VALID_CATEGORIES: list[TaskCategory] = [
130
+ # de_novo_binder (4)
131
+ TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ANTIBODY),
132
+ TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ENZYME),
133
+ TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.SIGNALING),
134
+ TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.THERAPEUTIC),
135
+ # sequence_optimization (5)
136
+ TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ANTIBODY),
137
+ TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ENZYME),
138
+ TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.SIGNALING),
139
+ TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.STRUCTURAL),
140
+ TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.FLUORESCENT),
141
+ # de_novo_backbone (1)
142
+ TaskCategory(DesignTaskType.DE_NOVO_BACKBONE, BiologicalContext.STRUCTURAL),
143
+ # complex_engineering (3)
144
+ TaskCategory(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.ENZYME),
145
+ TaskCategory(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.SIGNALING),
146
+ TaskCategory(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.STRUCTURAL),
147
+ # conformational_design (4)
148
+ TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.ENZYME),
149
+ TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.SIGNALING),
150
+ TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.STRUCTURAL),
151
+ TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.FLUORESCENT),
152
+ ]
153
+
154
+ _CATEGORY_BY_ID: dict[str, TaskCategory] = {c.category_id: c for c in VALID_CATEGORIES}
155
+
156
+ # OLD → NEW task ID mapping (30 tasks)
157
+ OLD_TO_NEW_MAPPING: dict[str, str] = {
158
+ "binder_001": "dnb_sig_001", "binder_003": "dnb_sig_002",
159
+ "binder_005": "dnb_sig_003", "binder_007": "dnb_sig_004",
160
+ "ppi_004": "dnb_sig_005",
161
+ "binder_002": "dnb_thr_001", "binder_006": "dnb_thr_002",
162
+ "binder_008": "dnb_thr_003", "peptide_001": "dnb_thr_004",
163
+ "peptide_002": "dnb_thr_005", "peptide_003": "dnb_thr_006",
164
+ "antibody_001": "sqo_ab_001", "antibody_002": "sqo_ab_002",
165
+ "antibody_003": "sqo_ab_003", "antibody_004": "sqo_ab_004",
166
+ "antibody_005": "sqo_ab_005",
167
+ "stability_002": "sqo_enz_001", "enzyme_001": "sqo_enz_002",
168
+ "enzyme_002": "sqo_enz_003", "enzyme_003": "sqo_enz_004",
169
+ "stability_003": "sqo_str_001", "stability_004": "sqo_str_002",
170
+ "stability_001": "sqo_flu_001",
171
+ "scaffold_001": "dnk_str_001", "scaffold_002": "dnk_str_002",
172
+ "scaffold_003": "dnk_str_003",
173
+ "ppi_001": "cpx_str_001", "ppi_002": "cpx_str_002",
174
+ "ppi_003": "cfd_sig_001",
175
+ "fluorescence_001": "cfd_flu_001",
176
+ }
177
+ _NEW_TO_OLD_MAPPING: dict[str, str] = {v: k for k, v in OLD_TO_NEW_MAPPING.items()}
178
+
179
+ _NEW_ID_RE = re.compile(r"^([a-z]{2,3})_([a-z]{2,3})_(\d{3})$")
180
+
181
+ _OLD_TYPE_TO_CANONICAL: dict[str, str] = {
182
+ "binder": "de_novo_binder", "antibody": "de_novo_binder",
183
+ "peptide": "de_novo_binder", "stability": "sequence_optimization",
184
+ "enzyme": "sequence_optimization", "fluorescence": "sequence_optimization",
185
+ "scaffold": "de_novo_backbone", "ppi": "complex_engineering",
186
+ }
187
+ _CANONICAL_VALUES = {e.value for e in DesignTaskType}
188
+
189
+
190
+ def get_category(task_id: str) -> Optional[TaskCategory]:
191
+ """Get the TaskCategory for a task ID (old or new format)."""
192
+ if task_id in OLD_TO_NEW_MAPPING:
193
+ new_id = OLD_TO_NEW_MAPPING[task_id]
194
+ cat_id = new_id.rsplit("_", 1)[0]
195
+ return _CATEGORY_BY_ID.get(cat_id)
196
+ m = _NEW_ID_RE.match(task_id)
197
+ if m:
198
+ cat_id = f"{m.group(1)}_{m.group(2)}"
199
+ return _CATEGORY_BY_ID.get(cat_id)
200
+ return None
201
+
202
+
203
+ def get_new_task_id(old_task_id: str) -> Optional[str]:
204
+ return OLD_TO_NEW_MAPPING.get(old_task_id)
205
+
206
+
207
+ def get_old_task_id(new_task_id: str) -> Optional[str]:
208
+ return _NEW_TO_OLD_MAPPING.get(new_task_id)
209
+
210
+
211
+ def is_valid_category(task_type: DesignTaskType, context: BiologicalContext) -> bool:
212
+ cat_id = f"{task_type.short}_{context.short}"
213
+ return cat_id in _CATEGORY_BY_ID
214
+
215
+
216
+ def parse_new_task_id(
217
+ task_id: str,
218
+ ) -> Optional[tuple[DesignTaskType, BiologicalContext, int]]:
219
+ m = _NEW_ID_RE.match(task_id)
220
+ if not m:
221
+ return None
222
+ task_short, ctx_short, num_str = m.group(1), m.group(2), m.group(3)
223
+ task_type = _SHORT_TO_TASK_TYPE.get(task_short)
224
+ context = _SHORT_TO_CONTEXT.get(ctx_short)
225
+ if task_type is None or context is None:
226
+ return None
227
+ if not is_valid_category(task_type, context):
228
+ return None
229
+ return task_type, context, int(num_str)
230
+
231
+
232
+ def normalize_task_type(task_type: str) -> str:
233
+ lower = task_type.lower().strip()
234
+ if lower in _CANONICAL_VALUES:
235
+ return lower
236
+ return _OLD_TYPE_TO_CANONICAL.get(lower, task_type)
237
+
238
+
239
+ # ═══════════════════════════════════════════════════════════════════════════════
240
+ # SECTION 2 — Sequence Metrics (from biodesignbench/eval/metrics/sequence.py)
241
+ # ════════════════════════════════════════════════════════════════��══════════════
242
+
243
+ _KD_SCALE: dict[str, float] = {
244
+ "A": 1.8, "C": 2.5, "D": -3.5, "E": -3.5, "F": 2.8,
245
+ "G": -0.4, "H": -3.2, "I": 4.5, "K": -3.9, "L": 3.8,
246
+ "M": 1.9, "N": -3.5, "P": -1.6, "Q": -3.5, "R": -4.5,
247
+ "S": -0.8, "T": -0.7, "V": 4.2, "W": -0.9, "Y": -1.3,
248
+ }
249
+
250
+ STANDARD_AAS = set("ACDEFGHIKLMNPQRSTVWY")
251
+
252
+
253
+ def sequence_identity(seq1: str, seq2: str) -> float:
254
+ """Compute fractional sequence identity between two sequences."""
255
+ if not seq1 or not seq2:
256
+ return 0.0
257
+ s1, s2 = seq1.upper(), seq2.upper()
258
+ if len(s1) == len(s2):
259
+ return sum(a == b for a, b in zip(s1, s2)) / len(s1)
260
+ short, long = (s1, s2) if len(s1) <= len(s2) else (s2, s1)
261
+ best = 0.0
262
+ for offset in range(len(long) - len(short) + 1):
263
+ matches = sum(a == b for a, b in zip(short, long[offset:offset + len(short)]))
264
+ identity = matches / len(short)
265
+ if identity > best:
266
+ best = identity
267
+ return best
268
+
269
+
270
+ def max_identity_to_reference(designs: list[str], reference: str) -> float:
271
+ if not designs or not reference:
272
+ return 0.0
273
+ return max(sequence_identity(d, reference) for d in designs)
274
+
275
+
276
+ def mean_pairwise_diversity(sequences: list[str]) -> float:
277
+ if len(sequences) < 2:
278
+ return 0.0
279
+ total = 0.0
280
+ count = 0
281
+ for s1, s2 in combinations(sequences, 2):
282
+ total += 1.0 - sequence_identity(s1, s2)
283
+ count += 1
284
+ return total / count if count > 0 else 0.0
285
+
286
+
287
+ def sequence_entropy(sequences: list[str], truncate: bool = False) -> float:
288
+ if len(sequences) < 2:
289
+ return 0.0
290
+ lengths = {len(s) for s in sequences}
291
+ if len(lengths) != 1:
292
+ if not truncate:
293
+ return 0.0
294
+ seq_len = min(lengths)
295
+ sequences = [s[:seq_len] for s in sequences]
296
+ else:
297
+ seq_len = lengths.pop()
298
+ if seq_len == 0:
299
+ return 0.0
300
+ n = len(sequences)
301
+ total_entropy = 0.0
302
+ for pos in range(seq_len):
303
+ counts: dict[str, int] = {}
304
+ for seq in sequences:
305
+ aa = seq[pos].upper()
306
+ counts[aa] = counts.get(aa, 0) + 1
307
+ pos_entropy = 0.0
308
+ for count in counts.values():
309
+ if count > 0:
310
+ p = count / n
311
+ pos_entropy -= p * math.log(p)
312
+ total_entropy += pos_entropy / math.log(20)
313
+ return total_entropy / seq_len
314
+
315
+
316
+ def validate_amino_acids(sequence: str) -> dict:
317
+ if not sequence or not sequence.strip():
318
+ return {"valid": False, "invalid_chars": set(), "fraction_valid": 0.0}
319
+ upper = sequence.upper()
320
+ chars = set(upper)
321
+ invalid = chars - STANDARD_AAS
322
+ valid_count = sum(1 for c in upper if c in STANDARD_AAS)
323
+ return {
324
+ "valid": len(invalid) == 0,
325
+ "invalid_chars": invalid,
326
+ "fraction_valid": valid_count / len(upper),
327
+ }
328
+
329
+
330
+ def check_length_constraints(
331
+ sequence: str,
332
+ length_range: tuple[int, int] | None,
333
+ ) -> dict:
334
+ length = len(sequence)
335
+ if length_range is None:
336
+ return {"length": length, "within_range": True, "range": None}
337
+ min_len, max_len = length_range
338
+ return {
339
+ "length": length,
340
+ "within_range": min_len <= length <= max_len,
341
+ "range": length_range,
342
+ }
343
+
344
+
345
+ def hydrophobicity_profile(sequence: str) -> dict:
346
+ if not sequence:
347
+ return {"mean": 0.0, "std": 0.0, "fraction_hydrophobic": 0.0, "min": 0.0, "max": 0.0}
348
+ values = [_KD_SCALE.get(aa.upper(), 0.0) for aa in sequence]
349
+ n = len(values)
350
+ mean = sum(values) / n
351
+ variance = sum((v - mean) ** 2 for v in values) / n
352
+ std = math.sqrt(variance)
353
+ hydrophobic_count = sum(1 for v in values if v > 0)
354
+ return {
355
+ "mean": round(mean, 3),
356
+ "std": round(std, 3),
357
+ "fraction_hydrophobic": round(hydrophobic_count / n, 3),
358
+ "min": round(min(values), 3),
359
+ "max": round(max(values), 3),
360
+ }
361
+
362
+
363
+ def count_mutations(wt: str, designed: str) -> int:
364
+ if len(wt) != len(designed):
365
+ return -1
366
+ return sum(a != b for a, b in zip(wt.upper(), designed.upper()))
367
+
368
+
369
+ # ═══════════════════════════════════════════════════════════════════════════════
370
+ # SECTION 3 — Approach Scoring (from biodesignbench/eval/metrics/approach.py)
371
+ # ═══════════════════════════════════════════════════════════════════════════════
372
+
373
+
374
+ class DesignFunction(str, Enum):
375
+ """Functional capabilities that tools provide."""
376
+
377
+ BACKBONE_GENERATION = "backbone_generation"
378
+ SEQUENCE_DESIGN = "sequence_design"
379
+ STRUCTURE_PREDICTION = "structure_prediction"
380
+ COMPLEX_PREDICTION = "complex_prediction"
381
+ INTERFACE_ANALYSIS = "interface_analysis"
382
+ STABILITY_SCORING = "stability_scoring"
383
+ ENERGY_MINIMIZATION = "energy_minimization"
384
+ HOTSPOT_IDENTIFICATION = "hotspot_identification"
385
+ SEQUENCE_SCORING = "sequence_scoring"
386
+ PHYSICS_VALIDATION = "physics_validation"
387
+
388
+
389
+ TOOL_CATEGORIES: dict[str, str] = {
390
+ "alphafold2": "structure_prediction", "alphafold": "structure_prediction",
391
+ "af2": "structure_prediction", "esmfold": "structure_prediction",
392
+ "openfold": "structure_prediction", "boltz": "structure_prediction",
393
+ "colabfold": "structure_prediction", "omegafold": "structure_prediction",
394
+ "rosettafold": "structure_prediction",
395
+ "proteinmpnn": "sequence_design", "mpnn": "sequence_design",
396
+ "esm_if": "sequence_design", "ligandmpnn": "sequence_design",
397
+ "rfdiffusion": "backbone_generation", "rfdiff": "backbone_generation",
398
+ "chroma": "backbone_generation", "framediff": "backbone_generation",
399
+ "foldingdiff": "backbone_generation",
400
+ "rosetta": "energy_optimization", "pyrosetta": "energy_optimization",
401
+ "foldx": "energy_optimization", "openmm": "energy_optimization",
402
+ "amber": "energy_optimization", "esm2": "energy_optimization",
403
+ "foldseek": "structure_search", "dali": "structure_search",
404
+ "tmalign": "structure_search",
405
+ }
406
+
407
+ MCP_TOOL_EXPANSION: dict[str, list[str]] = {
408
+ "design_binder": ["rfdiffusion", "proteinmpnn", "esmfold"],
409
+ "validate_design": ["esmfold", "alphafold2"],
410
+ "optimize_sequence": ["proteinmpnn"],
411
+ "predict_complex": ["alphafold2"],
412
+ "analyze_interface": ["pyrosetta"],
413
+ "predict_structure": ["esmfold", "alphafold2"],
414
+ "score_stability": ["esm2"],
415
+ "energy_minimize": ["openmm"],
416
+ "suggest_hotspots": [],
417
+ "get_design_status": [],
418
+ "generate_backbone": ["rfdiffusion"],
419
+ "rosetta_score": ["pyrosetta"],
420
+ "rosetta_relax": ["pyrosetta"],
421
+ "rosetta_interface_score": ["pyrosetta"],
422
+ "rosetta_design": ["pyrosetta"],
423
+ "predict_structure_boltz": ["boltz"],
424
+ "predict_affinity_boltz": ["boltz"],
425
+ }
426
+
427
+ TOOL_TO_FUNCTION: dict[str, set[DesignFunction]] = {
428
+ # MCP wrappers
429
+ "design_binder": {DesignFunction.BACKBONE_GENERATION, DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
430
+ "validate_design": {DesignFunction.STRUCTURE_PREDICTION},
431
+ "optimize_sequence": {DesignFunction.SEQUENCE_DESIGN},
432
+ "predict_complex": {DesignFunction.COMPLEX_PREDICTION, DesignFunction.STRUCTURE_PREDICTION},
433
+ "analyze_interface": {DesignFunction.INTERFACE_ANALYSIS},
434
+ "predict_structure": {DesignFunction.STRUCTURE_PREDICTION},
435
+ "score_stability": {DesignFunction.STABILITY_SCORING},
436
+ "energy_minimize": {DesignFunction.ENERGY_MINIMIZATION},
437
+ "suggest_hotspots": {DesignFunction.HOTSPOT_IDENTIFICATION},
438
+ "get_design_status": set(),
439
+ "generate_backbone": {DesignFunction.BACKBONE_GENERATION},
440
+ "rosetta_score": {DesignFunction.PHYSICS_VALIDATION},
441
+ "rosetta_relax": {DesignFunction.ENERGY_MINIMIZATION},
442
+ "rosetta_interface_score": {DesignFunction.INTERFACE_ANALYSIS},
443
+ "rosetta_design": {DesignFunction.SEQUENCE_DESIGN},
444
+ "predict_structure_boltz": {DesignFunction.STRUCTURE_PREDICTION},
445
+ "predict_affinity_boltz": {DesignFunction.COMPLEX_PREDICTION, DesignFunction.INTERFACE_ANALYSIS},
446
+ # Bio-level tools
447
+ "rfdiffusion": {DesignFunction.BACKBONE_GENERATION},
448
+ "proteinmpnn": {DesignFunction.SEQUENCE_DESIGN},
449
+ "alphafold2": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
450
+ "alphafold": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
451
+ "esmfold": {DesignFunction.STRUCTURE_PREDICTION},
452
+ "esm2": {DesignFunction.STABILITY_SCORING, DesignFunction.SEQUENCE_SCORING},
453
+ "pyrosetta": {DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION, DesignFunction.INTERFACE_ANALYSIS},
454
+ "rosetta": {DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION, DesignFunction.INTERFACE_ANALYSIS},
455
+ "openmm": {DesignFunction.ENERGY_MINIMIZATION},
456
+ "boltz": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
457
+ "foldx": {DesignFunction.STABILITY_SCORING, DesignFunction.PHYSICS_VALIDATION},
458
+ "colabfold": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
459
+ "foldseek": {DesignFunction.STRUCTURE_PREDICTION},
460
+ "chroma": {DesignFunction.BACKBONE_GENERATION},
461
+ "ligandmpnn": {DesignFunction.SEQUENCE_DESIGN},
462
+ "esm_if": {DesignFunction.SEQUENCE_DESIGN},
463
+ "mpnn": {DesignFunction.SEQUENCE_DESIGN},
464
+ }
465
+
466
+
467
+ class _TaskTypeDict(dict):
468
+ """Dict that accepts both DesignTaskType enum and string keys."""
469
+
470
+ def __init__(self, raw: dict[str, set[DesignFunction]]):
471
+ super().__init__()
472
+ self._raw = raw
473
+ for k, v in raw.items():
474
+ super().__setitem__(k, v)
475
+
476
+ def __contains__(self, key):
477
+ k = key.value if hasattr(key, "value") else key
478
+ return super().__contains__(k)
479
+
480
+ def __getitem__(self, key):
481
+ k = key.value if hasattr(key, "value") else key
482
+ return super().__getitem__(k)
483
+
484
+ def get(self, key, default=None):
485
+ k = key.value if hasattr(key, "value") else key
486
+ return super().get(k, default)
487
+
488
+
489
+ REQUIRED_FUNCTIONS = _TaskTypeDict({
490
+ "de_novo_binder": {DesignFunction.BACKBONE_GENERATION, DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
491
+ "sequence_optimization": {DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
492
+ "de_novo_backbone": {DesignFunction.BACKBONE_GENERATION, DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
493
+ "complex_engineering": {DesignFunction.SEQUENCE_DESIGN, DesignFunction.COMPLEX_PREDICTION},
494
+ "conformational_design": {DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
495
+ })
496
+
497
+ BONUS_FUNCTIONS = _TaskTypeDict({
498
+ "de_novo_binder": {DesignFunction.COMPLEX_PREDICTION, DesignFunction.INTERFACE_ANALYSIS, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.HOTSPOT_IDENTIFICATION},
499
+ "sequence_optimization": {DesignFunction.STABILITY_SCORING, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION},
500
+ "de_novo_backbone": {DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION},
501
+ "complex_engineering": {DesignFunction.BACKBONE_GENERATION, DesignFunction.INTERFACE_ANALYSIS, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.STRUCTURE_PREDICTION},
502
+ "conformational_design": {DesignFunction.STABILITY_SCORING, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.COMPLEX_PREDICTION},
503
+ })
504
+
505
+ _GENERATION_TOOLS: set[str] = {
506
+ "rfdiffusion", "proteinmpnn", "design_binder", "optimize_sequence",
507
+ "generate_backbone", "rosetta_design", "chroma", "ligandmpnn",
508
+ "esm_if", "mpnn",
509
+ }
510
+
511
+ _VALIDATION_TOOLS: set[str] = {
512
+ "esmfold", "alphafold2", "validate_design", "predict_structure",
513
+ "predict_complex", "score_stability", "rosetta_score",
514
+ "rosetta_interface_score", "predict_structure_boltz",
515
+ "predict_affinity_boltz", "analyze_interface",
516
+ }
517
+
518
+ _REFINEMENT_TOOLS: set[str] = {
519
+ "energy_minimize", "rosetta_relax", "openmm", "pyrosetta", "rosetta",
520
+ }
521
+
522
+
523
+ def expand_mcp_tools(tools: list[str]) -> list[str]:
524
+ """Expand MCP wrapper tool names to their underlying bio tools."""
525
+ seen: set[str] = set()
526
+ expanded: list[str] = []
527
+ for tool in tools:
528
+ if tool in MCP_TOOL_EXPANSION:
529
+ underlying = MCP_TOOL_EXPANSION[tool]
530
+ if not underlying:
531
+ if tool not in seen:
532
+ expanded.append(tool)
533
+ seen.add(tool)
534
+ else:
535
+ for ut in underlying:
536
+ if ut not in seen:
537
+ expanded.append(ut)
538
+ seen.add(ut)
539
+ else:
540
+ if tool not in seen:
541
+ expanded.append(tool)
542
+ seen.add(tool)
543
+ return expanded
544
+
545
+
546
+ def normalize_tool_name(tool: str) -> str:
547
+ return tool.lower().strip().replace(" ", "").replace("-", "").replace("_", "")
548
+
549
+
550
+ def get_tool_category(tool: str) -> str | None:
551
+ normalized = normalize_tool_name(tool)
552
+ for name, category in TOOL_CATEGORIES.items():
553
+ if normalize_tool_name(name) == normalized:
554
+ return category
555
+ return None
556
+
557
+
558
+ def _extract_functions_from_tools(tools: list[str]) -> set[DesignFunction]:
559
+ functions: set[DesignFunction] = set()
560
+ for tool in tools:
561
+ if tool in TOOL_TO_FUNCTION:
562
+ functions.update(TOOL_TO_FUNCTION[tool])
563
+ else:
564
+ norm = normalize_tool_name(tool)
565
+ for known, funcs in TOOL_TO_FUNCTION.items():
566
+ if normalize_tool_name(known) == norm:
567
+ functions.update(funcs)
568
+ break
569
+ return functions
570
+
571
+
572
+ def _check_validation(tools_used: list[str]) -> float:
573
+ if not tools_used:
574
+ return 0.0
575
+ has_generation = False
576
+ has_validation_after_generation = False
577
+ has_any_validation = False
578
+ for tool in tools_used:
579
+ if tool in _GENERATION_TOOLS:
580
+ has_generation = True
581
+ if tool in _VALIDATION_TOOLS:
582
+ has_any_validation = True
583
+ if has_generation:
584
+ has_validation_after_generation = True
585
+ if has_validation_after_generation:
586
+ return 4.0
587
+ if has_any_validation:
588
+ return 2.0
589
+ return 0.0
590
+
591
+
592
+ def _check_refinement(tools_used: list[str]) -> float:
593
+ if not tools_used:
594
+ return 0.0
595
+ for tool in tools_used:
596
+ if tool in _REFINEMENT_TOOLS:
597
+ return 4.0
598
+ counts = Counter(tools_used)
599
+ for tool, count in counts.items():
600
+ if count >= 2 and (tool in _GENERATION_TOOLS or tool in _VALIDATION_TOOLS):
601
+ return 4.0
602
+ return 0.0
603
+
604
+
605
+ def _score_approach_legacy(
606
+ tools_used: list[str],
607
+ tools_expected: list[str],
608
+ max_points: int = 20,
609
+ ) -> dict:
610
+ if not tools_expected:
611
+ return {
612
+ "score": max_points, "max": max_points,
613
+ "breakdown": [], "tools_matched": [], "tools_missing": [],
614
+ "mode": "legacy",
615
+ }
616
+ expanded_used = expand_mcp_tools(tools_used)
617
+ per_tool = max_points / len(tools_expected)
618
+ used_normalized = [normalize_tool_name(t) for t in expanded_used]
619
+ used_categories = [get_tool_category(t) for t in expanded_used]
620
+ total = 0.0
621
+ breakdown = []
622
+ matched = []
623
+ missing = []
624
+ for expected in tools_expected:
625
+ expected_norm = normalize_tool_name(expected)
626
+ expected_cat = get_tool_category(expected)
627
+ if expected_norm in used_normalized:
628
+ total += per_tool
629
+ breakdown.append({"tool": expected, "match": "exact", "points": per_tool})
630
+ matched.append(expected)
631
+ elif expected_cat and expected_cat in used_categories:
632
+ points = per_tool * 0.7
633
+ total += points
634
+ breakdown.append({"tool": expected, "match": "category", "points": points})
635
+ matched.append(expected)
636
+ else:
637
+ breakdown.append({"tool": expected, "match": "none", "points": 0})
638
+ missing.append(expected)
639
+ return {
640
+ "score": int(round(total)), "max": max_points,
641
+ "breakdown": breakdown, "tools_matched": matched,
642
+ "tools_missing": missing, "mode": "legacy",
643
+ }
644
+
645
+
646
+ def score_approach(
647
+ tools_used: list[str],
648
+ tools_expected: list[str],
649
+ max_points: int = 20,
650
+ task_type: DesignTaskType | str | None = None,
651
+ ) -> dict:
652
+ """Score the agent's tool/methodology selection."""
653
+ if task_type is None:
654
+ return _score_approach_legacy(tools_used, tools_expected, max_points)
655
+
656
+ tt_key = task_type.value if hasattr(task_type, "value") else str(task_type)
657
+ scale = max_points / 20.0
658
+ func_max = 12.0 * scale
659
+
660
+ agent_functions = _extract_functions_from_tools(tools_used)
661
+ required = REQUIRED_FUNCTIONS.get(tt_key, set())
662
+ bonus = BONUS_FUNCTIONS.get(tt_key, set())
663
+
664
+ if required:
665
+ covered_required = agent_functions & required
666
+ required_ratio = len(covered_required) / len(required)
667
+ else:
668
+ required_ratio = 1.0 if agent_functions else 0.0
669
+ covered_required = set()
670
+
671
+ covered_bonus = agent_functions & bonus
672
+ bonus_count = min(len(covered_bonus), 3)
673
+ func_score = (required_ratio * 9.0 + bonus_count * 1.0) * scale
674
+ func_score = min(func_score, func_max)
675
+
676
+ val_score = _check_validation(tools_used) * scale
677
+ ref_score = _check_refinement(tools_used) * scale
678
+
679
+ total = min(func_score + val_score + ref_score, float(max_points))
680
+
681
+ return {
682
+ "score": int(round(total)), "max": max_points, "mode": "function",
683
+ "function_coverage": round(func_score, 1),
684
+ "validation_inclusion": round(val_score, 1),
685
+ "iterative_refinement": round(ref_score, 1),
686
+ "required_functions": sorted(f.value for f in required),
687
+ "covered_required": sorted(f.value for f in covered_required),
688
+ "covered_bonus": sorted(f.value for f in covered_bonus),
689
+ "agent_functions": sorted(f.value for f in agent_functions),
690
+ }
691
+
692
+
693
+ # ═══════════════════════════════════════════════════════════════════════════════
694
+ # SECTION 4 — Orchestration Scoring (from biodesignbench/eval/metrics/orchestration.py)
695
+ # ═══════════════════════════════════════════════════════════════════════════════
696
+
697
+ EXPECTED_PIPELINES: dict[str, list[str]] = {
698
+ "de_novo_binder": ["rfdiffusion", "proteinmpnn", "esmfold"],
699
+ "sequence_optimization": ["proteinmpnn", "esmfold"],
700
+ "de_novo_backbone": ["rfdiffusion", "proteinmpnn", "esmfold"],
701
+ "complex_engineering": ["rfdiffusion", "proteinmpnn", "esmfold"],
702
+ "conformational_design": ["proteinmpnn", "esmfold"],
703
+ # Old category names (backward compat)
704
+ "binder": ["rfdiffusion", "proteinmpnn", "esmfold"],
705
+ "antibody": ["proteinmpnn", "esmfold"],
706
+ "stability": ["proteinmpnn", "esmfold"],
707
+ "enzyme": ["rfdiffusion", "proteinmpnn", "esmfold"],
708
+ }
709
+
710
+ ORCHESTRATION_VALIDATION_TOOLS: set[str] = {
711
+ "validate_design", "predict_complex", "analyze_interface",
712
+ "esmfold", "score_stability", "rosetta_score",
713
+ "rosetta_interface_score", "predict_structure_boltz",
714
+ "predict_affinity_boltz",
715
+ }
716
+
717
+
718
+ def _expand_tool_name(tool: str) -> list[str]:
719
+ if tool in MCP_TOOL_EXPANSION:
720
+ underlying = MCP_TOOL_EXPANSION[tool]
721
+ return underlying if underlying else [tool]
722
+ return [tool]
723
+
724
+
725
+ def _extract_ordered_bio_tools(tool_call_log: list[dict[str, Any]]) -> list[str]:
726
+ utility_tools = {"execute_python", "read_file", "write_file"}
727
+ ordered: list[str] = []
728
+ for entry in tool_call_log:
729
+ tool = entry.get("tool", "")
730
+ if tool in utility_tools:
731
+ continue
732
+ expanded = _expand_tool_name(tool)
733
+ for t in expanded:
734
+ ordered.append(normalize_tool_name(t))
735
+ return ordered
736
+
737
+
738
+ def _longest_ordered_subsequence_length(
739
+ actual: list[str], expected: list[str]
740
+ ) -> int:
741
+ if not expected or not actual:
742
+ return 0
743
+ j = 0
744
+ matched = 0
745
+ for tool in actual:
746
+ k = j
747
+ while k < len(expected):
748
+ if tool == normalize_tool_name(expected[k]):
749
+ matched += 1
750
+ j = k + 1
751
+ break
752
+ k += 1
753
+ return matched
754
+
755
+
756
+ def _count_validation_steps(tool_call_log: list[dict[str, Any]]) -> int:
757
+ count = 0
758
+ for entry in tool_call_log:
759
+ tool = entry.get("tool", "")
760
+ if tool in ORCHESTRATION_VALIDATION_TOOLS:
761
+ count += 1
762
+ expanded = _expand_tool_name(tool)
763
+ for t in expanded:
764
+ if t in ORCHESTRATION_VALIDATION_TOOLS and tool not in ORCHESTRATION_VALIDATION_TOOLS:
765
+ count += 1
766
+ return count
767
+
768
+
769
+ def _has_adaptive_behavior(tool_call_log: list[dict[str, Any]]) -> bool:
770
+ tool_args: dict[str, list[dict]] = {}
771
+ for entry in tool_call_log:
772
+ tool = entry.get("tool", "")
773
+ args = entry.get("args_summary", {})
774
+ if tool not in tool_args:
775
+ tool_args[tool] = []
776
+ tool_args[tool].append(args)
777
+ for tool, args_list in tool_args.items():
778
+ if len(args_list) >= 2:
779
+ for i in range(1, len(args_list)):
780
+ if args_list[i] != args_list[i - 1]:
781
+ return True
782
+ return False
783
+
784
+
785
+ def _get_task_category_for_orchestration(task_id: str) -> str | None:
786
+ """Extract category from task_id using taxonomy, with legacy fallback."""
787
+ category = get_category(task_id)
788
+ if category is not None:
789
+ return category.task_type.value
790
+ for cat in ("binder", "antibody", "stability", "enzyme"):
791
+ if task_id.startswith(cat):
792
+ return cat
793
+ return None
794
+
795
+
796
+ def score_orchestration(
797
+ tool_call_log: list[dict[str, Any]],
798
+ task_id: str,
799
+ max_points: int = 15,
800
+ ) -> dict[str, Any]:
801
+ """Score the agent's multi-step pipeline orchestration."""
802
+ if not tool_call_log:
803
+ return {
804
+ "score": 0, "max": max_points,
805
+ "pipeline_order_score": 0.0, "validation_score": 0.0,
806
+ "adaptive_score": 0.0, "details": "No tool calls recorded",
807
+ }
808
+
809
+ category = _get_task_category_for_orchestration(task_id)
810
+ expected_pipeline = EXPECTED_PIPELINES.get(category, [])
811
+
812
+ ordered_tools = _extract_ordered_bio_tools(tool_call_log)
813
+ if expected_pipeline:
814
+ matched = _longest_ordered_subsequence_length(ordered_tools, expected_pipeline)
815
+ order_ratio = matched / len(expected_pipeline)
816
+ else:
817
+ order_ratio = 1.0 if ordered_tools else 0.0
818
+
819
+ pipeline_points = order_ratio * max_points * 0.5
820
+
821
+ validation_count = _count_validation_steps(tool_call_log)
822
+ if validation_count >= 2:
823
+ validation_ratio = 1.0
824
+ elif validation_count == 1:
825
+ validation_ratio = 0.6
826
+ else:
827
+ validation_ratio = 0.0
828
+ validation_points = validation_ratio * max_points * 0.3
829
+
830
+ adaptive = _has_adaptive_behavior(tool_call_log)
831
+ adaptive_points = max_points * 0.2 if adaptive else 0.0
832
+
833
+ total = int(round(pipeline_points + validation_points + adaptive_points))
834
+
835
+ return {
836
+ "score": min(total, max_points), "max": max_points,
837
+ "pipeline_order_score": round(pipeline_points, 1),
838
+ "validation_score": round(validation_points, 1),
839
+ "adaptive_score": round(adaptive_points, 1),
840
+ "expected_pipeline": expected_pipeline,
841
+ "actual_tool_order": ordered_tools,
842
+ "validation_steps": validation_count,
843
+ "adaptive_behavior": adaptive,
844
+ }
845
+
846
+
847
+ # ═══════════════════════════════════════════════════════════════════════════════
848
+ # SECTION 5 — Quality + Scoring (from biodesignbench/eval/tier2/scoring.py)
849
+ # ═══════════════════════════════════════════════════════════════════════════════
850
+
851
+ DEFAULT_DESIGN_RUBRIC = {
852
+ "approach": 20, "orchestration": 15, "quality": 35,
853
+ "feasibility": 15, "novelty": 5, "diversity": 10,
854
+ }
855
+
856
+ METRIC_RANGES: dict[str, tuple[float, float]] = {
857
+ "pLDDT": (0, 100), "pTM": (0, 1), "ipTM": (0, 1),
858
+ "i_pAE": (0, 50), "predicted_kd": (0, 1e6),
859
+ "predicted_ddG": (-100, 100), "active_site_rmsd": (0, 50),
860
+ "max_sequence_identity": (0, 1), "TM_score": (0, 1),
861
+ }
862
+
863
+ THRESHOLD_TO_METRIC: dict[str, tuple[str, str]] = {
864
+ "pLDDT_good": ("pLDDT", "higher_is_better"),
865
+ "ipTM_good": ("ipTM", "higher_is_better"),
866
+ "kd_nM_good": ("predicted_kd", "lower_is_better"),
867
+ "predicted_ddG_good": ("predicted_ddG", "lower_is_better"),
868
+ "active_site_rmsd_good": ("active_site_rmsd", "lower_is_better"),
869
+ }
870
+
871
+ # Tier A: Structure Confidence
872
+ _TIER_A_THRESHOLDS: dict[str, dict[str, float]] = {
873
+ "pLDDT": {"pass": 65, "good": 80, "excellent": 90},
874
+ "pTM": {"pass": 0.45, "good": 0.65, "excellent": 0.80},
875
+ }
876
+
877
+ # Tier B: Interface Confidence (binding only)
878
+ _TIER_B_THRESHOLDS: dict[str, dict[str, float]] = {
879
+ "ipTM": {"pass": 0.15, "good": 0.40, "excellent": 0.70},
880
+ "i_pAE": {"pass": 25.0, "good": 15.0, "excellent": 8.0},
881
+ }
882
+ _TIER_B_DIRECTIONS: dict[str, str] = {"i_pAE": "lower_is_better"}
883
+
884
+ # Tier C: Interface Physics
885
+ _TIER_C_METRICS: dict[str, tuple[str, str]] = {
886
+ "kd_nM_good": ("predicted_kd", "lower_is_better"),
887
+ "predicted_ddG_good": ("predicted_ddG", "lower_is_better"),
888
+ "active_site_rmsd_good": ("active_site_rmsd", "lower_is_better"),
889
+ }
890
+ _TIER_C_PHYSICS: dict[str, dict[str, float]] = {
891
+ "buried_surface_area": {"pass": 800, "good": 1500, "excellent": 2500},
892
+ "hydrogen_bonds": {"pass": 5, "good": 15, "excellent": 30},
893
+ }
894
+
895
+ _TIER_A_BASE = 15
896
+ _TIER_B_BASE = 10
897
+ _TIER_C_BASE = 10
898
+ _QUALITY_BASE = _TIER_A_BASE + _TIER_B_BASE + _TIER_C_BASE # 35
899
+
900
+ _BINDING_TASK_TYPES: set[DesignTaskType] = {
901
+ DesignTaskType.DE_NOVO_BINDER,
902
+ DesignTaskType.COMPLEX_ENGINEERING,
903
+ }
904
+ _BINDING_OLD_PREFIXES: set[str] = {"binder", "antibody", "ppi", "peptide"}
905
+
906
+
907
+ def _is_binding_task(task_id: str | None) -> bool:
908
+ if not task_id:
909
+ return False
910
+ cat = get_category(task_id)
911
+ if cat is not None:
912
+ return cat.task_type in _BINDING_TASK_TYPES
913
+ prefix = task_id.split("_")[0]
914
+ return prefix in _BINDING_OLD_PREFIXES
915
+
916
+
917
+ def _get_tier_weights(
918
+ task_id: str | None = None,
919
+ max_points: int = 35,
920
+ ) -> tuple[int, int, int]:
921
+ if not task_id:
922
+ scale = max_points / _QUALITY_BASE if _QUALITY_BASE > 0 else 0
923
+ return (
924
+ int(round(_TIER_A_BASE * scale)),
925
+ int(round(_TIER_B_BASE * scale)),
926
+ int(round(_TIER_C_BASE * scale)),
927
+ )
928
+ is_binding = _is_binding_task(task_id)
929
+ cat = get_category(task_id)
930
+ if cat is None and not is_binding:
931
+ scale = max_points / _QUALITY_BASE if _QUALITY_BASE > 0 else 0
932
+ return (
933
+ int(round(_TIER_A_BASE * scale)),
934
+ int(round(_TIER_B_BASE * scale)),
935
+ int(round(_TIER_C_BASE * scale)),
936
+ )
937
+ if is_binding:
938
+ ratio_a = 12 / 35
939
+ ratio_b = 18 / 35
940
+ a = int(round(max_points * ratio_a))
941
+ b = int(round(max_points * ratio_b))
942
+ c = max_points - a - b
943
+ return (a, b, c)
944
+ else:
945
+ ratio_a = 25 / 35
946
+ ratio_b = 10 / 35
947
+ a = int(round(max_points * ratio_a))
948
+ b = int(round(max_points * ratio_b))
949
+ c = max_points - a - b
950
+ return (a, b, c)
951
+
952
+
953
+ def _continuous_score(
954
+ value: float,
955
+ thresholds: dict[str, float],
956
+ direction: str = "higher_is_better",
957
+ ) -> float:
958
+ """Return continuous fraction [0.0, 1.0] via linear interpolation."""
959
+ p, g, e = thresholds["pass"], thresholds["good"], thresholds["excellent"]
960
+
961
+ if direction == "lower_is_better":
962
+ floor = p + abs(p) * 0.3 if p != 0 else 0.3
963
+ if value <= e:
964
+ return 1.0
965
+ if value >= floor:
966
+ return 0.0
967
+ if value <= g:
968
+ span = g - e
969
+ if span == 0:
970
+ return 1.0
971
+ return 0.66 + (g - value) / span * 0.34
972
+ if value <= p:
973
+ span = p - g
974
+ if span == 0:
975
+ return 0.66
976
+ return 0.33 + (p - value) / span * 0.33
977
+ span = floor - p
978
+ if span == 0:
979
+ return 0.0
980
+ return 0.33 * (floor - value) / span
981
+
982
+ # higher_is_better
983
+ floor = p * 0.7
984
+ if value >= e:
985
+ return 1.0
986
+ if value <= floor:
987
+ return 0.0
988
+ if value >= g:
989
+ span = e - g
990
+ if span == 0:
991
+ return 1.0
992
+ return 0.66 + (value - g) / span * 0.34
993
+ if value >= p:
994
+ span = g - p
995
+ if span == 0:
996
+ return 0.66
997
+ return 0.33 + (value - p) / span * 0.33
998
+ span = p - floor
999
+ if span == 0:
1000
+ return 0.0
1001
+ return 0.33 * (value - floor) / span
1002
+
1003
+
1004
+ # Category-specific quality metrics (17 valid taxonomy cells)
1005
+ QUALITY_METRICS: dict[tuple[DesignTaskType, BiologicalContext], dict[str, Any]] = {
1006
+ # de_novo_binder (4 cells)
1007
+ (DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ANTIBODY): {
1008
+ "primary_metric": "ipTM",
1009
+ "thresholds": {"excellent": 0.75, "good": 0.50, "pass": 0.20},
1010
+ "secondary_metrics": ["pLDDT", "predicted_kd"],
1011
+ },
1012
+ (DesignTaskType.DE_NOVO_BINDER, BiologicalContext.SIGNALING): {
1013
+ "primary_metric": "ipTM",
1014
+ "thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
1015
+ "secondary_metrics": ["pLDDT", "predicted_kd"],
1016
+ },
1017
+ (DesignTaskType.DE_NOVO_BINDER, BiologicalContext.THERAPEUTIC): {
1018
+ "primary_metric": "ipTM",
1019
+ "thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
1020
+ "secondary_metrics": ["pLDDT", "predicted_kd"],
1021
+ },
1022
+ (DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ENZYME): {
1023
+ "primary_metric": "ipTM",
1024
+ "thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
1025
+ "secondary_metrics": ["pLDDT", "predicted_kd", "active_site_rmsd"],
1026
+ },
1027
+ # sequence_optimization (5 cells)
1028
+ (DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ANTIBODY): {
1029
+ "primary_metric": "pLDDT",
1030
+ "thresholds": {"excellent": 90, "good": 80, "pass": 65},
1031
+ "secondary_metrics": ["ipTM", "max_sequence_identity"],
1032
+ },
1033
+ (DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ENZYME): {
1034
+ "primary_metric": "pLDDT",
1035
+ "thresholds": {"excellent": 90, "good": 80, "pass": 65},
1036
+ "secondary_metrics": ["predicted_ddG", "active_site_rmsd"],
1037
+ },
1038
+ (DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.STRUCTURAL): {
1039
+ "primary_metric": "pLDDT",
1040
+ "thresholds": {"excellent": 92, "good": 82, "pass": 68},
1041
+ "secondary_metrics": ["TM_score", "predicted_ddG"],
1042
+ },
1043
+ (DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.FLUORESCENT): {
1044
+ "primary_metric": "pLDDT",
1045
+ "thresholds": {"excellent": 88, "good": 78, "pass": 62},
1046
+ "secondary_metrics": ["predicted_ddG", "max_sequence_identity"],
1047
+ },
1048
+ (DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.SIGNALING): {
1049
+ "primary_metric": "pLDDT",
1050
+ "thresholds": {"excellent": 90, "good": 80, "pass": 65},
1051
+ "secondary_metrics": ["ipTM", "predicted_ddG"],
1052
+ },
1053
+ # de_novo_backbone (1 cell)
1054
+ (DesignTaskType.DE_NOVO_BACKBONE, BiologicalContext.STRUCTURAL): {
1055
+ "primary_metric": "pLDDT",
1056
+ "thresholds": {"excellent": 88, "good": 78, "pass": 60},
1057
+ "secondary_metrics": ["TM_score", "predicted_ddG"],
1058
+ },
1059
+ # complex_engineering (3 cells)
1060
+ (DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.SIGNALING): {
1061
+ "primary_metric": "ipTM",
1062
+ "thresholds": {"excellent": 0.72, "good": 0.48, "pass": 0.20},
1063
+ "secondary_metrics": ["pLDDT", "predicted_kd"],
1064
+ },
1065
+ (DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.STRUCTURAL): {
1066
+ "primary_metric": "ipTM",
1067
+ "thresholds": {"excellent": 0.72, "good": 0.48, "pass": 0.20},
1068
+ "secondary_metrics": ["pLDDT", "TM_score"],
1069
+ },
1070
+ (DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.ENZYME): {
1071
+ "primary_metric": "ipTM",
1072
+ "thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
1073
+ "secondary_metrics": ["pLDDT", "predicted_kd", "active_site_rmsd"],
1074
+ },
1075
+ # conformational_design (4 cells)
1076
+ (DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.ENZYME): {
1077
+ "primary_metric": "pLDDT",
1078
+ "thresholds": {"excellent": 88, "good": 78, "pass": 62},
1079
+ "secondary_metrics": ["predicted_ddG", "active_site_rmsd"],
1080
+ },
1081
+ (DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.SIGNALING): {
1082
+ "primary_metric": "pLDDT",
1083
+ "thresholds": {"excellent": 85, "good": 75, "pass": 60},
1084
+ "secondary_metrics": ["ipTM", "predicted_kd"],
1085
+ },
1086
+ (DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.FLUORESCENT): {
1087
+ "primary_metric": "pLDDT",
1088
+ "thresholds": {"excellent": 85, "good": 75, "pass": 60},
1089
+ "secondary_metrics": ["predicted_ddG", "max_sequence_identity"],
1090
+ },
1091
+ (DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.STRUCTURAL): {
1092
+ "primary_metric": "pLDDT",
1093
+ "thresholds": {"excellent": 88, "good": 78, "pass": 62},
1094
+ "secondary_metrics": ["TM_score", "predicted_ddG"],
1095
+ },
1096
+ }
1097
+
1098
+
1099
+ def get_quality_config(task_id: str) -> dict[str, Any] | None:
1100
+ category = get_category(task_id)
1101
+ if category is None:
1102
+ return None
1103
+ key = (category.task_type, category.context)
1104
+ return QUALITY_METRICS.get(key)
1105
+
1106
+
1107
+ @dataclass
1108
+ class DesignScoringRubric:
1109
+ components: dict[str, int] = field(default_factory=lambda: dict(DEFAULT_DESIGN_RUBRIC))
1110
+
1111
+ @property
1112
+ def max_score(self) -> int:
1113
+ return sum(self.components.values())
1114
+
1115
+ def validate(self) -> None:
1116
+ total = sum(self.components.values())
1117
+ if total != 100:
1118
+ raise ValueError(f"Rubric total must be 100, got {total}")
1119
+
1120
+
1121
+ def _has_reasonable_composition(seq: str, min_length: int = 20) -> bool:
1122
+ upper = seq.upper()
1123
+ if len(upper) < min_length:
1124
+ return False
1125
+ unique_aas = len(set(upper))
1126
+ if unique_aas < 5:
1127
+ return False
1128
+ counts = Counter(upper)
1129
+ max_fraction = max(counts.values()) / len(upper)
1130
+ if max_fraction > 0.5:
1131
+ return False
1132
+ ala_fraction = counts.get("A", 0) / len(upper)
1133
+ if ala_fraction > 0.3:
1134
+ return False
1135
+ hp = hydrophobicity_profile(upper)
1136
+ if hp["mean"] > 2.0:
1137
+ return False
1138
+ return True
1139
+
1140
+
1141
+ def validate_metric_range(name: str, value: float) -> bool:
1142
+ if name not in METRIC_RANGES:
1143
+ return True
1144
+ low, high = METRIC_RANGES[name]
1145
+ return low <= value <= high
1146
+
1147
+
1148
+ # Functional Similarity thresholds for non-binding Tier B
1149
+ _FUNCTIONAL_SIM_DEFAULTS: dict[DesignTaskType, dict[str, float]] = {
1150
+ DesignTaskType.SEQUENCE_OPTIMIZATION: {"pass": 0.40, "good": 0.60, "excellent": 0.85},
1151
+ DesignTaskType.CONFORMATIONAL_DESIGN: {"pass": 0.15, "good": 0.30, "excellent": 0.50},
1152
+ DesignTaskType.DE_NOVO_BACKBONE: {"pass": 0.10, "good": 0.20, "excellent": 0.40},
1153
+ }
1154
+
1155
+
1156
+ def _derive_functional_sim_thresholds(value: float) -> dict[str, float]:
1157
+ return {
1158
+ "pass": value * 0.5,
1159
+ "good": value,
1160
+ "excellent": min(value * 2, 1.0),
1161
+ }
1162
+
1163
+
1164
+ def _get_functional_sim_thresholds(
1165
+ thresholds: dict[str, float],
1166
+ task_id: str,
1167
+ ) -> dict[str, float] | None:
1168
+ if _is_binding_task(task_id):
1169
+ return None
1170
+ gt_value = thresholds.get("max_seq_identity_good")
1171
+ if gt_value is not None:
1172
+ return _derive_functional_sim_thresholds(gt_value)
1173
+ cat = get_category(task_id)
1174
+ if cat is None:
1175
+ return None
1176
+ return _FUNCTIONAL_SIM_DEFAULTS.get(cat.task_type)
1177
+
1178
+
1179
+ def _score_functional_similarity(
1180
+ designs: list[str],
1181
+ oracle_sequences: list[str],
1182
+ thresholds: dict[str, float],
1183
+ ) -> float | None:
1184
+ if not designs or not oracle_sequences:
1185
+ return None
1186
+ best_identity = 0.0
1187
+ for design in designs:
1188
+ for oracle in oracle_sequences:
1189
+ ident = sequence_identity(design, oracle)
1190
+ if ident > best_identity:
1191
+ best_identity = ident
1192
+ return _continuous_score(best_identity, thresholds, "higher_is_better")
1193
+
1194
+
1195
+ def score_quality(
1196
+ agent_metrics: dict[str, float],
1197
+ thresholds: dict[str, float],
1198
+ max_points: int = 35,
1199
+ task_id: str | None = None,
1200
+ designs: list[str] | None = None,
1201
+ oracle_sequences: list[str] | None = None,
1202
+ ) -> dict[str, Any]:
1203
+ """Score quality using 3-tier continuous system."""
1204
+ valid_metrics = {
1205
+ k: v for k, v in agent_metrics.items() if validate_metric_range(k, v)
1206
+ }
1207
+ for extra_key in ("buried_surface_area", "hydrogen_bonds"):
1208
+ if extra_key in agent_metrics and extra_key not in valid_metrics:
1209
+ val = agent_metrics[extra_key]
1210
+ if isinstance(val, (int, float)) and val >= 0:
1211
+ valid_metrics[extra_key] = float(val)
1212
+
1213
+ tier_a_max, tier_b_max, tier_c_max = _get_tier_weights(task_id, max_points)
1214
+ is_binding = _is_binding_task(task_id)
1215
+
1216
+ overrides: dict[str, dict[str, float]] = {}
1217
+ if task_id:
1218
+ config = get_quality_config(task_id)
1219
+ if config and "thresholds" in config:
1220
+ primary = config["primary_metric"]
1221
+ overrides[primary] = config["thresholds"]
1222
+
1223
+ # Tier A: Structure Confidence
1224
+ tier_a_scores: dict[str, float] = {}
1225
+ for metric, default_thresh in _TIER_A_THRESHOLDS.items():
1226
+ if metric in valid_metrics:
1227
+ thresh = overrides.get(metric, default_thresh)
1228
+ tier_a_scores[metric] = _continuous_score(
1229
+ valid_metrics[metric], thresh, "higher_is_better"
1230
+ )
1231
+ tier_a_pts = (sum(tier_a_scores.values()) / len(tier_a_scores)) * tier_a_max if tier_a_scores else 0.0
1232
+
1233
+ # Tier B: Interface or Functional Similarity
1234
+ tier_b_scores: dict[str, float] = {}
1235
+ tier_b_pts = 0.0
1236
+ _use_functional_sim = (
1237
+ tier_b_max > 0
1238
+ and task_id is not None
1239
+ and not is_binding
1240
+ and get_category(task_id) is not None
1241
+ )
1242
+
1243
+ if tier_b_max > 0:
1244
+ if _use_functional_sim:
1245
+ if designs and oracle_sequences:
1246
+ func_thresh = _get_functional_sim_thresholds(thresholds, task_id)
1247
+ if func_thresh is not None:
1248
+ frac = _score_functional_similarity(designs, oracle_sequences, func_thresh)
1249
+ if frac is not None:
1250
+ tier_b_pts = frac * tier_b_max
1251
+ tier_b_scores["oracle_identity"] = frac
1252
+ else:
1253
+ for metric, default_thresh in _TIER_B_THRESHOLDS.items():
1254
+ if metric in valid_metrics:
1255
+ thresh = overrides.get(metric, default_thresh)
1256
+ direction = _TIER_B_DIRECTIONS.get(metric, "higher_is_better")
1257
+ tier_b_scores[metric] = _continuous_score(
1258
+ valid_metrics[metric], thresh, direction
1259
+ )
1260
+ if tier_b_scores:
1261
+ tier_b_pts = (sum(tier_b_scores.values()) / len(tier_b_scores)) * tier_b_max
1262
+
1263
+ # Tier C: Interface Physics
1264
+ tier_c_fractions: list[float] = []
1265
+ tier_c_breakdown: list[dict] = []
1266
+
1267
+ if tier_c_max > 0:
1268
+ if is_binding:
1269
+ for metric_key, phys_thresh in _TIER_C_PHYSICS.items():
1270
+ if metric_key in valid_metrics:
1271
+ frac = _continuous_score(valid_metrics[metric_key], phys_thresh, "higher_is_better")
1272
+ tier_c_fractions.append(frac)
1273
+ tier_c_breakdown.append({
1274
+ "threshold": metric_key, "metric": metric_key,
1275
+ "value": valid_metrics[metric_key],
1276
+ "threshold_value": phys_thresh, "fraction": round(frac, 3),
1277
+ })
1278
+
1279
+ for thresh_key, (metric_key, direction) in _TIER_C_METRICS.items():
1280
+ if thresh_key in thresholds and metric_key in valid_metrics:
1281
+ threshold_val = thresholds[thresh_key]
1282
+ agent_val = valid_metrics[metric_key]
1283
+ margin = abs(threshold_val) * 0.5 if threshold_val != 0 else 1.0
1284
+ if direction == "lower_is_better":
1285
+ gt_thresh = {
1286
+ "pass": threshold_val + margin,
1287
+ "good": threshold_val,
1288
+ "excellent": threshold_val - margin,
1289
+ }
1290
+ else:
1291
+ gt_thresh = {
1292
+ "pass": threshold_val - margin,
1293
+ "good": threshold_val,
1294
+ "excellent": threshold_val + margin,
1295
+ }
1296
+ frac = _continuous_score(agent_val, gt_thresh, direction)
1297
+ tier_c_fractions.append(frac)
1298
+ tier_c_breakdown.append({
1299
+ "threshold": thresh_key, "metric": metric_key,
1300
+ "value": agent_val, "threshold_value": threshold_val,
1301
+ "fraction": round(frac, 3),
1302
+ })
1303
+
1304
+ tier_c_pts = (sum(tier_c_fractions) / len(tier_c_fractions)) * tier_c_max if tier_c_fractions else 0.0
1305
+
1306
+ total = min(tier_a_pts + tier_b_pts + tier_c_pts, max_points)
1307
+ metrics_evaluated = len(tier_a_scores) + len(tier_b_scores) + len(tier_c_fractions)
1308
+
1309
+ return {
1310
+ "score": int(round(total)), "max": max_points,
1311
+ "tier_a": round(tier_a_pts, 1), "tier_b": round(tier_b_pts, 1),
1312
+ "tier_c": round(tier_c_pts, 1),
1313
+ "metrics_evaluated": metrics_evaluated,
1314
+ "breakdown": {
1315
+ "structure": tier_a_scores, "interface": tier_b_scores,
1316
+ "physics": tier_c_breakdown,
1317
+ },
1318
+ }
1319
+
1320
+
1321
+ def score_novelty(
1322
+ designs: list[str],
1323
+ reference_seq: str | None,
1324
+ thresholds: dict[str, float],
1325
+ max_points: int = 5,
1326
+ ) -> dict[str, Any]:
1327
+ """Score novelty by computing sequence identity to reference."""
1328
+ if not designs:
1329
+ return {"score": 0, "max": max_points, "max_identity": 0.0, "identity_threshold": None}
1330
+
1331
+ identity_threshold = thresholds.get("max_seq_identity_good")
1332
+ max_id = max_identity_to_reference(designs, reference_seq) if reference_seq else 0.0
1333
+
1334
+ if identity_threshold is None:
1335
+ if reference_seq:
1336
+ novelty_ratio = 1.0 - max_id
1337
+ score = int(round(max_points * min(novelty_ratio * 2, 1.0)))
1338
+ else:
1339
+ score = max_points
1340
+ elif identity_threshold >= 0.9:
1341
+ if max_id >= identity_threshold:
1342
+ score = max_points
1343
+ elif max_id >= identity_threshold * 0.9:
1344
+ score = int(round(max_points * 0.7))
1345
+ else:
1346
+ score = int(round(max_points * 0.3))
1347
+ else:
1348
+ if max_id <= identity_threshold:
1349
+ score = max_points
1350
+ elif max_id <= identity_threshold * 1.5:
1351
+ score = int(round(max_points * 0.5))
1352
+ else:
1353
+ score = int(round(max_points * 0.2))
1354
+
1355
+ return {
1356
+ "score": min(score, max_points), "max": max_points,
1357
+ "max_identity": round(max_id, 3), "identity_threshold": identity_threshold,
1358
+ }
1359
+
1360
+
1361
+ def score_diversity(
1362
+ designs: list[str],
1363
+ max_designs: int = 10,
1364
+ max_points: int = 5,
1365
+ ) -> dict[str, Any]:
1366
+ """Score diversity of designs."""
1367
+ if not designs:
1368
+ return {"score": 0, "max": max_points, "num_designs": 0, "pairwise_diversity": 0.0, "entropy": 0.0}
1369
+
1370
+ num = len(designs)
1371
+ count_fraction = min(num / max_designs, 1.0) if max_designs > 0 else 1.0
1372
+ diversity = mean_pairwise_diversity(designs)
1373
+ entropy = sequence_entropy(designs)
1374
+
1375
+ count_score = count_fraction * max_points * 0.4
1376
+ diversity_score = diversity * max_points * 0.4
1377
+ entropy_score = entropy * max_points * 0.2
1378
+ total = int(round(count_score + diversity_score + entropy_score))
1379
+
1380
+ return {
1381
+ "score": min(total, max_points), "max": max_points,
1382
+ "num_designs": num, "pairwise_diversity": round(diversity, 3),
1383
+ "entropy": round(entropy, 3),
1384
+ }
1385
+
1386
+
1387
+ def score_feasibility(
1388
+ designs: list[str],
1389
+ constraints: dict[str, Any],
1390
+ max_points: int = 25,
1391
+ ) -> dict[str, Any]:
1392
+ """Score feasibility of designed sequences."""
1393
+ if not designs:
1394
+ return {"score": 0, "max": max_points, "aa_validity": 0.0, "length_validity": 0.0, "composition_check": 0.0}
1395
+
1396
+ per_check = max_points / 3
1397
+ length_range = constraints.get("length_range")
1398
+ if isinstance(length_range, list):
1399
+ length_range = tuple(length_range)
1400
+
1401
+ comp_min_length = 20
1402
+ if length_range and length_range[1] < 20:
1403
+ comp_min_length = max(length_range[0], 5)
1404
+
1405
+ aa_valid_count = sum(1 for seq in designs if validate_amino_acids(seq)["valid"])
1406
+ aa_fraction = aa_valid_count / len(designs)
1407
+
1408
+ length_valid_count = sum(1 for seq in designs if check_length_constraints(seq, length_range)["within_range"])
1409
+ length_fraction = length_valid_count / len(designs)
1410
+
1411
+ composition_ok = sum(1 for seq in designs if _has_reasonable_composition(seq, min_length=comp_min_length))
1412
+ composition_fraction = composition_ok / len(designs)
1413
+
1414
+ aa_score = aa_fraction * per_check
1415
+ length_score = length_fraction * per_check
1416
+ comp_score = composition_fraction * per_check
1417
+ total = int(round(aa_score + length_score + comp_score))
1418
+
1419
+ return {
1420
+ "score": min(total, max_points), "max": max_points,
1421
+ "aa_validity": round(aa_fraction, 3),
1422
+ "length_validity": round(length_fraction, 3),
1423
+ "composition_check": round(composition_fraction, 3),
1424
+ }
1425
+
1426
+
1427
+ # ═══════════════════════════════════════════════════════════════════════════════
1428
+ # SECTION 6 — Design Gate + Final Score
1429
+ # ═══════════════════════════════════════════════════════════════════════════════
1430
+
1431
+ _DESIGN_GATE_ZEROED = {"quality", "novelty", "diversity", "feasibility"}
1432
+ _DESIGN_GATE_CAP = 30
1433
+
1434
+
1435
+ def apply_design_gate(
1436
+ component_scores: dict[str, int],
1437
+ num_designs: int,
1438
+ ) -> dict[str, int]:
1439
+ """If no designs produced, cap total at 30."""
1440
+ if num_designs >= 1:
1441
+ return dict(component_scores)
1442
+ gated = dict(component_scores)
1443
+ for key in _DESIGN_GATE_ZEROED:
1444
+ gated[key] = 0
1445
+ remaining_sum = sum(v for k, v in gated.items() if k not in _DESIGN_GATE_ZEROED)
1446
+ if remaining_sum > _DESIGN_GATE_CAP:
1447
+ scale = _DESIGN_GATE_CAP / remaining_sum
1448
+ for key in gated:
1449
+ if key not in _DESIGN_GATE_ZEROED:
1450
+ gated[key] = int(round(gated[key] * scale))
1451
+ return gated
1452
+
1453
+
1454
+ def calculate_design_score(
1455
+ rubric: DesignScoringRubric,
1456
+ results: dict[str, int],
1457
+ ) -> dict[str, Any]:
1458
+ """Calculate final design task score from component results."""
1459
+ breakdown = {}
1460
+ for component, max_pts in rubric.components.items():
1461
+ actual = min(results.get(component, 0), max_pts)
1462
+ breakdown[component] = {"score": actual, "max": max_pts}
1463
+ total = sum(v["score"] for v in breakdown.values())
1464
+ max_possible = rubric.max_score
1465
+ return {
1466
+ "breakdown": breakdown,
1467
+ "total": total,
1468
+ "max_possible": max_possible,
1469
+ "percentage": round(total / max_possible * 100, 1) if max_possible > 0 else 0,
1470
+ }
1471
+
1472
+
1473
+ # ═══════════════════════════════════════════════════════════════════════════════
1474
+ # SECTION 7 — Full Task Scorer (high-level API for eval pipeline)
1475
+ # ═══════════════════════════════════════════════════════════════════════════════
1476
+
1477
+
1478
+ def score_submission_task(
1479
+ task_id: str,
1480
+ sequences: list[str],
1481
+ run_log: list[dict[str, Any]],
1482
+ ground_truth: dict[str, Any],
1483
+ agent_metrics: dict[str, float] | None = None,
1484
+ oracle_sequences: list[str] | None = None,
1485
+ ) -> dict[str, Any]:
1486
+ """Score a single task submission end-to-end.
1487
+
1488
+ This is the main entry point for the evaluation pipeline.
1489
+
1490
+ Args:
1491
+ task_id: Task identifier (e.g., "dnb_sig_001").
1492
+ sequences: Designed amino acid sequences from the agent.
1493
+ run_log: Tool call log from the agent.
1494
+ ground_truth: Ground truth dict with thresholds, reference_sequence,
1495
+ design_constraints, tools_expected, max_designs.
1496
+ agent_metrics: Optional metrics reported by the agent or from Boltz
1497
+ (e.g., {"pLDDT": 85.0, "ipTM": 0.35}).
1498
+ oracle_sequences: Optional oracle sequences for functional similarity.
1499
+
1500
+ Returns:
1501
+ Dict with: total_score, component_scores, details, num_designs.
1502
+ """
1503
+ if agent_metrics is None:
1504
+ agent_metrics = {}
1505
+
1506
+ # Extract fields from ground truth
1507
+ thresholds = ground_truth.get("thresholds", {})
1508
+ reference_seq = ground_truth.get("reference_sequence")
1509
+ constraints = ground_truth.get("design_constraints", {})
1510
+ tools_expected = ground_truth.get("tools_expected", [])
1511
+ max_designs = ground_truth.get("max_designs", 10)
1512
+
1513
+ # Get task category for function-based scoring
1514
+ cat = get_category(task_id)
1515
+ task_type = cat.task_type if cat else None
1516
+
1517
+ # Extract tools used from run_log
1518
+ tools_used = [entry.get("tool", "") for entry in run_log if entry.get("tool")]
1519
+
1520
+ # Score all 6 components
1521
+ approach_result = score_approach(
1522
+ tools_used=tools_used,
1523
+ tools_expected=tools_expected,
1524
+ task_type=task_type,
1525
+ )
1526
+ orchestration_result = score_orchestration(
1527
+ tool_call_log=run_log,
1528
+ task_id=task_id,
1529
+ )
1530
+ quality_result = score_quality(
1531
+ agent_metrics=agent_metrics,
1532
+ thresholds=thresholds,
1533
+ task_id=task_id,
1534
+ designs=sequences,
1535
+ oracle_sequences=oracle_sequences,
1536
+ )
1537
+ feasibility_result = score_feasibility(
1538
+ designs=sequences,
1539
+ constraints=constraints,
1540
+ )
1541
+ novelty_result = score_novelty(
1542
+ designs=sequences,
1543
+ reference_seq=reference_seq,
1544
+ thresholds=thresholds,
1545
+ )
1546
+ diversity_result = score_diversity(
1547
+ designs=sequences,
1548
+ max_designs=max_designs,
1549
+ )
1550
+
1551
+ # Build component scores dict
1552
+ component_scores = {
1553
+ "approach": approach_result["score"],
1554
+ "orchestration": orchestration_result["score"],
1555
+ "quality": quality_result["score"],
1556
+ "feasibility": feasibility_result["score"],
1557
+ "novelty": novelty_result["score"],
1558
+ "diversity": diversity_result["score"],
1559
+ }
1560
+
1561
+ # Apply design gate
1562
+ num_designs = len(sequences)
1563
+ gated = apply_design_gate(component_scores, num_designs)
1564
+ total = sum(gated.values())
1565
+
1566
+ return {
1567
+ "total_score": total,
1568
+ "component_scores": gated,
1569
+ "num_designs": num_designs,
1570
+ "details": {
1571
+ "approach": approach_result,
1572
+ "orchestration": orchestration_result,
1573
+ "quality": quality_result,
1574
+ "feasibility": feasibility_result,
1575
+ "novelty": novelty_result,
1576
+ "diversity": diversity_result,
1577
+ },
1578
+ }
1579
+
1580
+
1581
+ def aggregate_scores(
1582
+ per_task_scores: dict[str, dict[str, Any]],
1583
+ ) -> dict[str, Any]:
1584
+ """Aggregate per-task scores into an overall submission result.
1585
+
1586
+ Args:
1587
+ per_task_scores: Dict mapping task_id → score_submission_task() result.
1588
+
1589
+ Returns:
1590
+ Dict with: overall_score, component_scores (averaged), taxonomy_scores,
1591
+ tasks_completed, tasks_with_zero.
1592
+ """
1593
+ if not per_task_scores:
1594
+ return {
1595
+ "overall_score": 0.0,
1596
+ "component_scores": {c: 0.0 for c in DEFAULT_DESIGN_RUBRIC},
1597
+ "taxonomy_scores": {},
1598
+ "tasks_completed": 0,
1599
+ "tasks_total": 0,
1600
+ "tasks_with_zero": 0,
1601
+ }
1602
+
1603
+ totals = {c: 0.0 for c in DEFAULT_DESIGN_RUBRIC}
1604
+ n = len(per_task_scores)
1605
+ tasks_with_zero = 0
1606
+
1607
+ # Taxonomy breakdown
1608
+ taxonomy_scores: dict[str, dict[str, list[float]]] = {}
1609
+
1610
+ for task_id, result in per_task_scores.items():
1611
+ total_score = result["total_score"]
1612
+ if total_score == 0:
1613
+ tasks_with_zero += 1
1614
+
1615
+ for comp, val in result["component_scores"].items():
1616
+ totals[comp] += val
1617
+
1618
+ # Taxonomy mapping
1619
+ cat = get_category(task_id)
1620
+ if cat:
1621
+ tt = cat.task_type.value
1622
+ ctx = cat.context.short
1623
+ taxonomy_scores.setdefault(tt, {}).setdefault(ctx, []).append(total_score)
1624
+
1625
+ # Average components
1626
+ avg_components = {c: round(v / n, 1) for c, v in totals.items()}
1627
+ overall = round(sum(avg_components.values()), 1)
1628
+
1629
+ # Average taxonomy scores
1630
+ taxonomy_avg: dict[str, dict[str, float]] = {}
1631
+ for tt, contexts in taxonomy_scores.items():
1632
+ taxonomy_avg[tt] = {}
1633
+ for ctx, scores in contexts.items():
1634
+ taxonomy_avg[tt][ctx] = round(sum(scores) / len(scores), 1)
1635
+
1636
+ return {
1637
+ "overall_score": overall,
1638
+ "component_scores": avg_components,
1639
+ "taxonomy_scores": taxonomy_avg,
1640
+ "tasks_completed": n,
1641
+ "tasks_total": n,
1642
+ "tasks_with_zero": tasks_with_zero,
1643
+ }
eval_tasks.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load hidden benchmark tasks from a private HuggingFace Dataset.
2
+
3
+ Each task row contains:
4
+ - task_id: e.g., "dnb_sig_001"
5
+ - task_json: Full task definition (JSON string)
6
+ - ground_truth: Ground truth thresholds + reference (JSON string)
7
+ - prompt_md: Task prompt in Markdown
8
+ - pdb_data: Base64-encoded PDB file (if needed)
9
+ - pdb_filename: Original PDB filename (e.g., "7n1j.pdb")
10
+ - oracle_sequences: JSON list of oracle sequences (for non-binding tasks)
11
+
12
+ Falls back to local files in development (when BDB_USE_LOCAL=1).
13
+
14
+ HF Dataset: RomeroLab-Duke/biodesignbench-hidden-tasks (private)
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import base64
20
+ import json
21
+ import logging
22
+ import os
23
+ from functools import lru_cache
24
+ from pathlib import Path
25
+ from typing import Any
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Configuration
31
+ # ---------------------------------------------------------------------------
32
+
33
+ TASKS_DATASET = os.environ.get(
34
+ "BDB_TASKS_DATASET",
35
+ "RomeroLab-Duke/biodesignbench-hidden-tasks",
36
+ )
37
+ HF_TOKEN = os.environ.get("HF_TOKEN")
38
+ USE_LOCAL = os.environ.get("BDB_USE_LOCAL", "0") == "1"
39
+
40
+ # Local paths (for development)
41
+ _PROJECT_ROOT = Path(__file__).resolve().parents[1]
42
+ _TASKS_DIR = _PROJECT_ROOT / "tasks" / "tier2"
43
+ _GT_DIR = _PROJECT_ROOT / "data" / "tier2" / "ground_truth"
44
+ _PROMPTS_DIR = _PROJECT_ROOT / "data" / "tier2" / "prompts"
45
+ _INPUT_DIR = _PROJECT_ROOT / "data" / "tier2" / "input"
46
+ _ORACLE_PATH = _PROJECT_ROOT / "data" / "oracle" / "sequences.json"
47
+ _TOOL_SCHEMAS_PATH = Path(__file__).parent / "mcp_tool_schemas.json"
48
+
49
+ # Public task IDs (for development/testing — not hidden)
50
+ PUBLIC_TASK_IDS = {"dnb_sig_001", "sqo_enz_005", "cpx_sig_001"}
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # HF Dataset loading
55
+ # ---------------------------------------------------------------------------
56
+
57
+
58
+ @lru_cache(maxsize=1)
59
+ def _load_from_hf() -> dict[str, dict[str, Any]]:
60
+ """Load all tasks from the private HF Dataset."""
61
+ try:
62
+ from datasets import load_dataset
63
+
64
+ ds = load_dataset(
65
+ TASKS_DATASET,
66
+ split="train",
67
+ token=HF_TOKEN,
68
+ )
69
+ tasks = {}
70
+ for row in ds:
71
+ task_id = row["task_id"]
72
+ tasks[task_id] = {
73
+ "task_id": task_id,
74
+ "task_json": json.loads(row["task_json"]),
75
+ "ground_truth": json.loads(row["ground_truth"]),
76
+ "prompt_md": row["prompt_md"],
77
+ "pdb_data": row.get("pdb_data"),
78
+ "pdb_filename": row.get("pdb_filename"),
79
+ "oracle_sequences": json.loads(row.get("oracle_sequences", "[]")),
80
+ }
81
+ logger.info(f"Loaded {len(tasks)} tasks from HF Dataset")
82
+ return tasks
83
+ except Exception as e:
84
+ logger.error(f"Failed to load tasks from HF: {e}")
85
+ return {}
86
+
87
+
88
+ @lru_cache(maxsize=1)
89
+ def _load_from_local() -> dict[str, dict[str, Any]]:
90
+ """Load tasks from local project files (development mode)."""
91
+ tasks = {}
92
+
93
+ # Load oracle data
94
+ oracle_data = {}
95
+ if _ORACLE_PATH.exists():
96
+ with open(_ORACLE_PATH) as f:
97
+ oracle_data = json.load(f)
98
+
99
+ # Enumerate task files
100
+ if not _TASKS_DIR.exists():
101
+ logger.warning(f"Tasks directory not found: {_TASKS_DIR}")
102
+ return tasks
103
+
104
+ for task_path in sorted(_TASKS_DIR.glob("*.json")):
105
+ task_id = task_path.stem
106
+ try:
107
+ with open(task_path) as f:
108
+ task_json = json.load(f)
109
+
110
+ # Ground truth
111
+ gt_path = _GT_DIR / f"{task_id}.json"
112
+ ground_truth = {}
113
+ if gt_path.exists():
114
+ with open(gt_path) as f:
115
+ ground_truth = json.load(f)
116
+
117
+ # Prompt
118
+ prompt_path = _PROMPTS_DIR / f"{task_id}.md"
119
+ prompt_md = ""
120
+ if prompt_path.exists():
121
+ prompt_md = prompt_path.read_text()
122
+
123
+ # PDB data
124
+ pdb_data = None
125
+ pdb_filename = None
126
+ input_pdb = task_json.get("input_pdb") or task_json.get("pdb_file")
127
+ if input_pdb:
128
+ pdb_path = _INPUT_DIR / input_pdb
129
+ if pdb_path.exists():
130
+ pdb_data = base64.b64encode(pdb_path.read_bytes()).decode()
131
+ pdb_filename = input_pdb
132
+
133
+ # Oracle sequences
134
+ oracle_entry = oracle_data.get(task_id, {})
135
+ oracle_seqs = oracle_entry.get("sequences", []) if isinstance(oracle_entry, dict) else []
136
+
137
+ tasks[task_id] = {
138
+ "task_id": task_id,
139
+ "task_json": task_json,
140
+ "ground_truth": ground_truth,
141
+ "prompt_md": prompt_md,
142
+ "pdb_data": pdb_data,
143
+ "pdb_filename": pdb_filename,
144
+ "oracle_sequences": oracle_seqs,
145
+ }
146
+ except Exception as e:
147
+ logger.warning(f"Failed to load task {task_id}: {e}")
148
+
149
+ logger.info(f"Loaded {len(tasks)} tasks from local files")
150
+ return tasks
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # Public API
155
+ # ---------------------------------------------------------------------------
156
+
157
+
158
+ def load_all_tasks() -> dict[str, dict[str, Any]]:
159
+ """Load all benchmark tasks.
160
+
161
+ Returns:
162
+ Dict mapping task_id → task data dict.
163
+ """
164
+ if USE_LOCAL:
165
+ return _load_from_local()
166
+ return _load_from_hf()
167
+
168
+
169
+ def get_task(task_id: str) -> dict[str, Any] | None:
170
+ """Load a single task by ID."""
171
+ tasks = load_all_tasks()
172
+ return tasks.get(task_id)
173
+
174
+
175
+ def get_hidden_task_ids() -> list[str]:
176
+ """Get the list of hidden (non-public) task IDs."""
177
+ tasks = load_all_tasks()
178
+ return sorted(tid for tid in tasks if tid not in PUBLIC_TASK_IDS)
179
+
180
+
181
+ def get_all_task_ids() -> list[str]:
182
+ """Get all task IDs (public + hidden)."""
183
+ return sorted(load_all_tasks().keys())
184
+
185
+
186
+ def get_public_task_ids() -> list[str]:
187
+ """Get the 3 public task IDs for development."""
188
+ tasks = load_all_tasks()
189
+ return sorted(tid for tid in tasks if tid in PUBLIC_TASK_IDS)
190
+
191
+
192
+ @lru_cache(maxsize=1)
193
+ def load_tool_schemas() -> list[dict[str, Any]]:
194
+ """Load the 17 MCP tool schemas for task payloads."""
195
+ if _TOOL_SCHEMAS_PATH.exists():
196
+ with open(_TOOL_SCHEMAS_PATH) as f:
197
+ return json.load(f)
198
+ return []
199
+
200
+
201
+ def build_task_payload(task_id: str) -> dict[str, Any] | None:
202
+ """Build the payload to send to a submitter's endpoint.
203
+
204
+ Returns:
205
+ Dict with: task_id, task_description, available_tools,
206
+ input_files, design_constraints, max_steps, timeout_sec.
207
+ Returns None if task not found.
208
+ """
209
+ task = get_task(task_id)
210
+ if task is None:
211
+ return None
212
+
213
+ task_json = task["task_json"]
214
+ prompt = task["prompt_md"]
215
+
216
+ # Build input files (base64-encoded PDBs)
217
+ input_files = {}
218
+ if task.get("pdb_data") and task.get("pdb_filename"):
219
+ input_files[task["pdb_filename"]] = task["pdb_data"]
220
+
221
+ # Extract constraints from task JSON
222
+ constraints = task_json.get("design_constraints", {})
223
+ max_designs = task_json.get("max_designs", 10)
224
+
225
+ return {
226
+ "task_id": task_id,
227
+ "task_description": prompt,
228
+ "available_tools": load_tool_schemas(),
229
+ "input_files": input_files,
230
+ "design_constraints": {
231
+ **constraints,
232
+ "max_designs": max_designs,
233
+ },
234
+ "max_steps": 50,
235
+ "timeout_sec": 300,
236
+ }
example_server.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reference FastAPI server for BioDesignBench submitters.
2
+
3
+ This example shows how to implement the API endpoint that BioDesignBench
4
+ will call during benchmarking. Replace the mock agent logic with your
5
+ actual LLM agent + MCP tool pipeline.
6
+
7
+ Usage:
8
+ pip install fastapi uvicorn
9
+ python example_server.py
10
+
11
+ # Or with uvicorn directly:
12
+ uvicorn example_server:app --host 0.0.0.0 --port 8000
13
+
14
+ Your endpoint will receive POST requests at /api/run with the task payload.
15
+
16
+ Task Payload Format:
17
+ {
18
+ "task_id": "dnb_sig_001",
19
+ "task_description": "Design a de novo binder for...",
20
+ "available_tools": [... 17 tool schemas ...],
21
+ "input_files": {"7n1j.pdb": "<base64>"},
22
+ "design_constraints": {"length_range": [80, 150], "max_designs": 10},
23
+ "max_steps": 50,
24
+ "timeout_sec": 300
25
+ }
26
+
27
+ Expected Response Format:
28
+ {
29
+ "sequences": ["MKKL...", "MFQR..."],
30
+ "run_log": [{"step": 1, "tool": "suggest_hotspots", "success": true}, ...],
31
+ "total_steps": 12,
32
+ "total_time_sec": 142.5,
33
+ "metrics": {}
34
+ }
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import base64
40
+ import random
41
+ import time
42
+ from pathlib import Path
43
+ from typing import Any
44
+
45
+ from fastapi import FastAPI
46
+ from pydantic import BaseModel
47
+
48
+ app = FastAPI(
49
+ title="BioDesignBench Example Agent",
50
+ description="Reference implementation for benchmark submission",
51
+ version="0.1.0",
52
+ )
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Request/Response models
57
+ # ---------------------------------------------------------------------------
58
+
59
+
60
+ class TaskPayload(BaseModel):
61
+ task_id: str
62
+ task_description: str
63
+ available_tools: list[dict[str, Any]] = []
64
+ input_files: dict[str, str] = {} # filename -> base64 data
65
+ design_constraints: dict[str, Any] = {}
66
+ max_steps: int = 50
67
+ timeout_sec: int = 300
68
+
69
+
70
+ class AgentResponse(BaseModel):
71
+ sequences: list[str]
72
+ run_log: list[dict[str, Any]]
73
+ total_steps: int
74
+ total_time_sec: float
75
+ metrics: dict[str, Any] = {}
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # Mock agent (replace with your real agent)
80
+ # ---------------------------------------------------------------------------
81
+
82
+ # Standard amino acids for mock sequence generation
83
+ _AAS = "ACDEFGHIKLMNPQRSTVWY"
84
+
85
+
86
+ def _generate_mock_sequence(length: int) -> str:
87
+ """Generate a random protein sequence with reasonable composition."""
88
+ # Weight towards common amino acids
89
+ weights = [
90
+ 7, 2, 5, 6, 4, 7, 2, 5, 6, 9, # A C D E F G H I K L
91
+ 2, 4, 5, 4, 5, 7, 6, 7, 1, 3, # M N P Q R S T V W Y
92
+ ]
93
+ return "".join(random.choices(_AAS, weights=weights, k=length))
94
+
95
+
96
+ def mock_agent(payload: TaskPayload) -> AgentResponse:
97
+ """Mock agent that generates random but valid designs.
98
+
99
+ Replace this with your actual LLM agent + MCP tool pipeline.
100
+ This mock demonstrates the expected response format.
101
+ """
102
+ start = time.monotonic()
103
+
104
+ # Determine design parameters
105
+ constraints = payload.design_constraints
106
+ length_range = constraints.get("length_range", [80, 150])
107
+ max_designs = constraints.get("max_designs", 10)
108
+ num_designs = min(max_designs, 5) # Generate 5 for this mock
109
+
110
+ # "Decode" input PDB files (in a real agent, you'd use these)
111
+ for filename, b64_data in payload.input_files.items():
112
+ pdb_bytes = base64.b64decode(b64_data)
113
+ # In a real agent: save to temp file and pass to MCP tools
114
+
115
+ # Simulate a multi-step design pipeline
116
+ run_log = [
117
+ {
118
+ "step": 1,
119
+ "tool": "suggest_hotspots",
120
+ "success": True,
121
+ "args_summary": {"target": "from_pdb"},
122
+ "output_summary": "Found 5 hotspot residues",
123
+ },
124
+ {
125
+ "step": 2,
126
+ "tool": "generate_backbone",
127
+ "success": True,
128
+ "args_summary": {"length": length_range[0]},
129
+ "output_summary": f"Generated {num_designs} backbones",
130
+ },
131
+ {
132
+ "step": 3,
133
+ "tool": "optimize_sequence",
134
+ "success": True,
135
+ "args_summary": {"optimization_target": "both"},
136
+ "output_summary": f"Optimized {num_designs} sequences",
137
+ },
138
+ {
139
+ "step": 4,
140
+ "tool": "predict_structure",
141
+ "success": True,
142
+ "args_summary": {"predictor": "esmfold"},
143
+ "output_summary": "Predicted structures for all designs",
144
+ },
145
+ {
146
+ "step": 5,
147
+ "tool": "validate_design",
148
+ "success": True,
149
+ "args_summary": {},
150
+ "output_summary": "Validated all designs",
151
+ },
152
+ ]
153
+
154
+ # Generate mock sequences
155
+ min_len, max_len = length_range
156
+ sequences = [
157
+ _generate_mock_sequence(random.randint(min_len, max_len))
158
+ for _ in range(num_designs)
159
+ ]
160
+
161
+ elapsed = time.monotonic() - start
162
+
163
+ return AgentResponse(
164
+ sequences=sequences,
165
+ run_log=run_log,
166
+ total_steps=len(run_log),
167
+ total_time_sec=round(elapsed, 2),
168
+ metrics={}, # Agent-reported metrics (optional)
169
+ )
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # API endpoint
174
+ # ---------------------------------------------------------------------------
175
+
176
+
177
+ @app.post("/api/run", response_model=AgentResponse)
178
+ async def run_task(payload: TaskPayload) -> AgentResponse:
179
+ """Run a single benchmark task.
180
+
181
+ This is the endpoint that BioDesignBench will POST to during benchmarking.
182
+ Replace mock_agent() with your actual agent logic.
183
+ """
184
+ return mock_agent(payload)
185
+
186
+
187
+ @app.get("/health")
188
+ async def health():
189
+ """Health check endpoint."""
190
+ return {"status": "ok", "agent": "example-mock"}
191
+
192
+
193
+ # ---------------------------------------------------------------------------
194
+ # Entry point
195
+ # ---------------------------------------------------------------------------
196
+
197
+ if __name__ == "__main__":
198
+ import uvicorn
199
+
200
+ print("Starting BioDesignBench example server...")
201
+ print("POST endpoint: http://localhost:8000/api/run")
202
+ print("Health check: http://localhost:8000/health")
203
+ print()
204
+ print("Replace mock_agent() with your real agent logic.")
205
+ uvicorn.run(app, host="0.0.0.0", port=8000)
mcp_tool_schemas.json ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "design_binder",
4
+ "description": "Design protein binders for a target protein. Runs RFdiffusion -> ProteinMPNN -> ESMFold pipeline.",
5
+ "parameters": {
6
+ "type": "object",
7
+ "properties": {
8
+ "target_pdb": {
9
+ "type": "string",
10
+ "description": "Path to target protein PDB file"
11
+ },
12
+ "hotspot_residues": {
13
+ "type": "array",
14
+ "items": {
15
+ "type": "string"
16
+ },
17
+ "description": "Target residues for binder interface, e.g. ['A45', 'A46']"
18
+ },
19
+ "num_designs": {
20
+ "type": "integer",
21
+ "description": "Number of designs to generate (default: 10)",
22
+ "default": 10
23
+ },
24
+ "binder_length": {
25
+ "type": "integer",
26
+ "description": "Binder length in residues (default: 80)",
27
+ "default": 80
28
+ }
29
+ },
30
+ "required": [
31
+ "target_pdb",
32
+ "hotspot_residues"
33
+ ]
34
+ }
35
+ },
36
+ {
37
+ "name": "analyze_interface",
38
+ "description": "Analyze protein-protein interface: buried surface area, H-bonds, salt bridges, hydrophobic contacts.",
39
+ "parameters": {
40
+ "type": "object",
41
+ "properties": {
42
+ "complex_pdb": {
43
+ "type": "string",
44
+ "description": "Path to complex PDB file"
45
+ },
46
+ "chain_a": {
47
+ "type": "string",
48
+ "description": "Chain ID of first protein"
49
+ },
50
+ "chain_b": {
51
+ "type": "string",
52
+ "description": "Chain ID of second protein"
53
+ }
54
+ },
55
+ "required": [
56
+ "complex_pdb",
57
+ "chain_a",
58
+ "chain_b"
59
+ ]
60
+ }
61
+ },
62
+ {
63
+ "name": "validate_design",
64
+ "description": "Validate a designed sequence by predicting its structure (ESMFold/AlphaFold2) and computing pLDDT, pTM.",
65
+ "parameters": {
66
+ "type": "object",
67
+ "properties": {
68
+ "sequence": {
69
+ "type": "string",
70
+ "description": "Amino acid sequence to validate"
71
+ },
72
+ "expected_structure": {
73
+ "type": "string",
74
+ "description": "Optional PDB path for RMSD comparison"
75
+ },
76
+ "predictor": {
77
+ "type": "string",
78
+ "enum": [
79
+ "esmfold",
80
+ "alphafold2"
81
+ ],
82
+ "default": "esmfold",
83
+ "description": "Structure predictor to use"
84
+ }
85
+ },
86
+ "required": [
87
+ "sequence"
88
+ ]
89
+ }
90
+ },
91
+ {
92
+ "name": "optimize_sequence",
93
+ "description": "Optimize binder sequence for improved stability and/or binding affinity.",
94
+ "parameters": {
95
+ "type": "object",
96
+ "properties": {
97
+ "current_sequence": {
98
+ "type": "string",
99
+ "description": "Starting amino acid sequence"
100
+ },
101
+ "target_pdb": {
102
+ "type": "string",
103
+ "description": "Path to target protein PDB"
104
+ },
105
+ "optimization_target": {
106
+ "type": "string",
107
+ "enum": [
108
+ "stability",
109
+ "affinity",
110
+ "both"
111
+ ],
112
+ "default": "both"
113
+ },
114
+ "fixed_positions": {
115
+ "type": "array",
116
+ "items": {
117
+ "type": "integer"
118
+ },
119
+ "description": "Positions to keep fixed (1-indexed)"
120
+ }
121
+ },
122
+ "required": [
123
+ "current_sequence",
124
+ "target_pdb"
125
+ ]
126
+ }
127
+ },
128
+ {
129
+ "name": "suggest_hotspots",
130
+ "description": "Analyze target protein and suggest binding hotspots using structure, conservation, and literature.",
131
+ "parameters": {
132
+ "type": "object",
133
+ "properties": {
134
+ "target": {
135
+ "type": "string",
136
+ "description": "Protein name, UniProt ID, PDB ID, or local PDB path"
137
+ },
138
+ "chain_id": {
139
+ "type": "string",
140
+ "description": "Chain to analyze (default: first)"
141
+ },
142
+ "criteria": {
143
+ "type": "string",
144
+ "enum": [
145
+ "druggable",
146
+ "exposed",
147
+ "conserved"
148
+ ],
149
+ "default": "exposed"
150
+ },
151
+ "include_literature": {
152
+ "type": "boolean",
153
+ "default": false,
154
+ "description": "Search PubMed for known binders"
155
+ }
156
+ },
157
+ "required": [
158
+ "target"
159
+ ]
160
+ }
161
+ },
162
+ {
163
+ "name": "get_design_status",
164
+ "description": "Check status of running design jobs.",
165
+ "parameters": {
166
+ "type": "object",
167
+ "properties": {
168
+ "job_id": {
169
+ "type": "string",
170
+ "description": "Job ID from design_binder call"
171
+ }
172
+ },
173
+ "required": [
174
+ "job_id"
175
+ ]
176
+ }
177
+ },
178
+ {
179
+ "name": "predict_complex",
180
+ "description": "Predict protein complex structure using AlphaFold2-Multimer.",
181
+ "parameters": {
182
+ "type": "object",
183
+ "properties": {
184
+ "sequences": {
185
+ "type": "array",
186
+ "items": {
187
+ "type": "string"
188
+ },
189
+ "description": "List of sequences, one per chain"
190
+ },
191
+ "chain_names": {
192
+ "type": "array",
193
+ "items": {
194
+ "type": "string"
195
+ },
196
+ "description": "Optional chain identifiers"
197
+ }
198
+ },
199
+ "required": [
200
+ "sequences"
201
+ ]
202
+ }
203
+ },
204
+ {
205
+ "name": "predict_structure",
206
+ "description": "Predict the 3D structure of a single protein chain using ESMFold or AlphaFold2. Returns predicted PDB, pLDDT, and pTM scores.",
207
+ "parameters": {
208
+ "type": "object",
209
+ "properties": {
210
+ "sequence": {
211
+ "type": "string",
212
+ "description": "Amino acid sequence to predict structure for"
213
+ },
214
+ "predictor": {
215
+ "type": "string",
216
+ "enum": [
217
+ "esmfold",
218
+ "alphafold2"
219
+ ],
220
+ "default": "esmfold",
221
+ "description": "Structure predictor to use"
222
+ }
223
+ },
224
+ "required": [
225
+ "sequence"
226
+ ]
227
+ }
228
+ },
229
+ {
230
+ "name": "score_stability",
231
+ "description": "Score protein stability using ESM2 pseudo-log-likelihood. Optionally compute per-mutation effects (delta log-likelihood).",
232
+ "parameters": {
233
+ "type": "object",
234
+ "properties": {
235
+ "sequence": {
236
+ "type": "string",
237
+ "description": "Amino acid sequence to score"
238
+ },
239
+ "mutations": {
240
+ "type": "array",
241
+ "items": {
242
+ "type": "string"
243
+ },
244
+ "description": "Optional mutations in 'X42Y' format for delta scoring"
245
+ },
246
+ "reference_sequence": {
247
+ "type": "string",
248
+ "description": "Optional wild-type sequence for mutation scoring"
249
+ }
250
+ },
251
+ "required": [
252
+ "sequence"
253
+ ]
254
+ }
255
+ },
256
+ {
257
+ "name": "energy_minimize",
258
+ "description": "Energy-minimize a protein structure using OpenMM with AMBER14 force field. Returns minimized PDB, energy change, and RMSD from initial structure.",
259
+ "parameters": {
260
+ "type": "object",
261
+ "properties": {
262
+ "pdb_path": {
263
+ "type": "string",
264
+ "description": "Path to input PDB file"
265
+ },
266
+ "force_field": {
267
+ "type": "string",
268
+ "default": "amber14-all.xml",
269
+ "description": "OpenMM force field XML"
270
+ },
271
+ "num_steps": {
272
+ "type": "integer",
273
+ "default": 500,
274
+ "description": "Maximum minimization iterations"
275
+ },
276
+ "solvent": {
277
+ "type": "string",
278
+ "enum": [
279
+ "implicit",
280
+ "none"
281
+ ],
282
+ "default": "implicit",
283
+ "description": "Solvent model: implicit (GBn2) or none (vacuum)"
284
+ }
285
+ },
286
+ "required": [
287
+ "pdb_path"
288
+ ]
289
+ }
290
+ },
291
+ {
292
+ "name": "generate_backbone",
293
+ "description": "Generate de novo protein backbones using RFdiffusion unconditional generation. No target protein required.",
294
+ "parameters": {
295
+ "type": "object",
296
+ "properties": {
297
+ "length": {
298
+ "type": "integer",
299
+ "description": "Backbone length in residues"
300
+ },
301
+ "num_designs": {
302
+ "type": "integer",
303
+ "default": 10,
304
+ "description": "Number of designs to generate"
305
+ }
306
+ },
307
+ "required": [
308
+ "length"
309
+ ]
310
+ }
311
+ },
312
+ {
313
+ "name": "rosetta_score",
314
+ "description": "Score a protein structure using Rosetta energy function (ref2015). Returns total score, per-residue energies, and energy breakdown.",
315
+ "parameters": {
316
+ "type": "object",
317
+ "properties": {
318
+ "pdb_path": {
319
+ "type": "string",
320
+ "description": "Path to input PDB file"
321
+ },
322
+ "score_function": {
323
+ "type": "string",
324
+ "default": "ref2015",
325
+ "description": "Rosetta score function name"
326
+ }
327
+ },
328
+ "required": [
329
+ "pdb_path"
330
+ ]
331
+ }
332
+ },
333
+ {
334
+ "name": "rosetta_relax",
335
+ "description": "Relax a protein structure using Rosetta FastRelax. Returns relaxed PDB, energy change, and CA-RMSD.",
336
+ "parameters": {
337
+ "type": "object",
338
+ "properties": {
339
+ "pdb_path": {
340
+ "type": "string",
341
+ "description": "Path to input PDB file"
342
+ },
343
+ "nstruct": {
344
+ "type": "integer",
345
+ "default": 1,
346
+ "description": "Number of relaxation trajectories"
347
+ },
348
+ "score_function": {
349
+ "type": "string",
350
+ "default": "ref2015",
351
+ "description": "Rosetta score function name"
352
+ }
353
+ },
354
+ "required": [
355
+ "pdb_path"
356
+ ]
357
+ }
358
+ },
359
+ {
360
+ "name": "rosetta_interface_score",
361
+ "description": "Compute interface energy metrics for a protein complex using Rosetta. Returns dG_separated, dSASA, interface hbonds, and packing stats.",
362
+ "parameters": {
363
+ "type": "object",
364
+ "properties": {
365
+ "pdb_path": {
366
+ "type": "string",
367
+ "description": "Path to complex PDB file"
368
+ },
369
+ "chains": {
370
+ "type": "string",
371
+ "default": "A_B",
372
+ "description": "Chain grouping, e.g. 'A_B' or 'AB_C'"
373
+ },
374
+ "score_function": {
375
+ "type": "string",
376
+ "default": "ref2015",
377
+ "description": "Rosetta score function name"
378
+ }
379
+ },
380
+ "required": [
381
+ "pdb_path"
382
+ ]
383
+ }
384
+ },
385
+ {
386
+ "name": "rosetta_design",
387
+ "description": "Fixed-backbone sequence design using Rosetta PackRotamers + MinMover. Composite convenience tool (hidden in benchmark mode).",
388
+ "parameters": {
389
+ "type": "object",
390
+ "properties": {
391
+ "pdb_path": {
392
+ "type": "string",
393
+ "description": "Path to input PDB file"
394
+ },
395
+ "chains": {
396
+ "type": "string",
397
+ "default": "A_B",
398
+ "description": "Chain grouping for interface detection"
399
+ },
400
+ "fixed_positions": {
401
+ "type": "array",
402
+ "items": {
403
+ "type": "integer"
404
+ },
405
+ "description": "1-indexed positions to keep fixed"
406
+ },
407
+ "score_function": {
408
+ "type": "string",
409
+ "default": "ref2015",
410
+ "description": "Rosetta score function name"
411
+ }
412
+ },
413
+ "required": [
414
+ "pdb_path"
415
+ ]
416
+ }
417
+ },
418
+ {
419
+ "name": "predict_structure_boltz",
420
+ "description": "Predict protein structure using Boltz (fast alternative to AF2/ESMFold). Returns predicted PDB, pLDDT, and pTM scores.",
421
+ "parameters": {
422
+ "type": "object",
423
+ "properties": {
424
+ "sequence": {
425
+ "type": "string",
426
+ "description": "Amino acid sequence to predict structure for"
427
+ },
428
+ "model": {
429
+ "type": "string",
430
+ "default": "boltz2",
431
+ "description": "Model name (default: boltz2)"
432
+ },
433
+ "num_samples": {
434
+ "type": "integer",
435
+ "default": 1,
436
+ "description": "Number of structure samples"
437
+ }
438
+ },
439
+ "required": [
440
+ "sequence"
441
+ ]
442
+ }
443
+ },
444
+ {
445
+ "name": "predict_affinity_boltz",
446
+ "description": "Predict binding affinity for a protein complex using Boltz. Returns affinity score, predicted structure, and confidence metrics.",
447
+ "parameters": {
448
+ "type": "object",
449
+ "properties": {
450
+ "sequences": {
451
+ "type": "array",
452
+ "items": {
453
+ "type": "string"
454
+ },
455
+ "description": "List of amino acid sequences, one per chain"
456
+ },
457
+ "model": {
458
+ "type": "string",
459
+ "default": "boltz2",
460
+ "description": "Model name (default: boltz2)"
461
+ }
462
+ },
463
+ "required": [
464
+ "sequences"
465
+ ]
466
+ }
467
+ }
468
+ ]
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
- gradio>=4.0
2
  pandas
3
  plotly
 
 
 
 
 
 
1
+ gradio>=5.6
2
  pandas
3
  plotly
4
+ httpx>=0.25
5
+ huggingface_hub>=0.20
6
+ datasets>=2.16
7
+ boltz>=0.4
8
+ pyaudioop