feat: add submission & scoring infrastructure (eval_scorer, dispatcher, boltz, queue, tasks) + fix gradio 5.x for Python 3.13
Browse files- app.py +407 -1
- eval_boltz.py +272 -0
- eval_dispatcher.py +361 -0
- eval_queue.py +312 -0
- eval_scorer.py +1643 -0
- eval_tasks.py +236 -0
- example_server.py +205 -0
- mcp_tool_schemas.json +468 -0
- requirements.txt +6 -1
app.py
CHANGED
|
@@ -2,14 +2,26 @@
|
|
| 2 |
|
| 3 |
Evaluating LLM Agents on Protein Design via MCP Tools
|
| 4 |
Romero Lab, Duke University
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import json
|
|
|
|
| 8 |
from pathlib import Path
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
import plotly.graph_objects as go
|
| 12 |
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# ═══════════════════════════════════════════════════════════════════
|
| 15 |
# Configuration — change these when deploying
|
|
@@ -916,7 +928,401 @@ def create_app() -> gr.Blocks:
|
|
| 916 |
gr.Plot(chart_mode_comparison(entries))
|
| 917 |
gr.HTML(build_mode_cards(entries))
|
| 918 |
|
| 919 |
-
# ══════
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 920 |
with gr.Tab("\u2139\ufe0f About"):
|
| 921 |
gr.HTML(build_about())
|
| 922 |
|
|
|
|
| 2 |
|
| 3 |
Evaluating LLM Agents on Protein Design via MCP Tools
|
| 4 |
Romero Lab, Duke University
|
| 5 |
+
|
| 6 |
+
Tabs:
|
| 7 |
+
1. Overall Leaderboard
|
| 8 |
+
2. Taxonomy Breakdown
|
| 9 |
+
3. Component Analysis
|
| 10 |
+
4. Benchmark vs User
|
| 11 |
+
5. Submit (new submission form)
|
| 12 |
+
6. Status & Admin (password-protected pipeline control)
|
| 13 |
+
7. About
|
| 14 |
"""
|
| 15 |
|
| 16 |
import json
|
| 17 |
+
import os
|
| 18 |
from pathlib import Path
|
| 19 |
|
| 20 |
import gradio as gr
|
| 21 |
import plotly.graph_objects as go
|
| 22 |
|
| 23 |
+
ADMIN_PASSWORD = os.environ.get("BDB_ADMIN_PASSWORD", "biodesignbench2026")
|
| 24 |
+
|
| 25 |
|
| 26 |
# ═══════════════════════════════════════════════════════════════════
|
| 27 |
# Configuration — change these when deploying
|
|
|
|
| 928 |
gr.Plot(chart_mode_comparison(entries))
|
| 929 |
gr.HTML(build_mode_cards(entries))
|
| 930 |
|
| 931 |
+
# ══════ Tab 5: Submit ══════
|
| 932 |
+
with gr.Tab("\U0001f4e4 Submit"):
|
| 933 |
+
gr.HTML("""
|
| 934 |
+
<div style="max-width:700px;margin:0 auto;padding:1rem">
|
| 935 |
+
<h2 style="color:#1a365d;margin:0 0 0.5rem">
|
| 936 |
+
Submit Your Agent</h2>
|
| 937 |
+
<p style="color:#4a5568;margin-bottom:1rem;line-height:1.5">
|
| 938 |
+
Submit your protein design agent for benchmarking.
|
| 939 |
+
Your agent must be hosted as a POST endpoint that accepts
|
| 940 |
+
task descriptions and returns designed sequences.
|
| 941 |
+
<strong>You bear all LLM and MCP tool costs</strong>;
|
| 942 |
+
we only run Boltz structure prediction on our end.</p>
|
| 943 |
+
<div style="background:#fefcbf;border-left:4px solid #d69e2e;
|
| 944 |
+
padding:0.8rem;border-radius:4px;margin-bottom:1rem;
|
| 945 |
+
font-size:0.85rem;color:#744210">
|
| 946 |
+
<strong>Rate limit:</strong> 2 submissions per calendar
|
| 947 |
+
month per organization.</div>
|
| 948 |
+
</div>""")
|
| 949 |
+
|
| 950 |
+
with gr.Column(scale=1):
|
| 951 |
+
sub_agent = gr.Textbox(
|
| 952 |
+
label="Agent Name",
|
| 953 |
+
placeholder="e.g., GPT-5 + Custom MCP Tools",
|
| 954 |
+
)
|
| 955 |
+
sub_org = gr.Textbox(
|
| 956 |
+
label="Organization",
|
| 957 |
+
placeholder="e.g., OpenAI",
|
| 958 |
+
)
|
| 959 |
+
sub_url = gr.Textbox(
|
| 960 |
+
label="Endpoint URL",
|
| 961 |
+
placeholder="https://your-server.com/api/run",
|
| 962 |
+
)
|
| 963 |
+
sub_desc = gr.Textbox(
|
| 964 |
+
label="Description (optional)",
|
| 965 |
+
placeholder="Brief description of your agent...",
|
| 966 |
+
lines=3,
|
| 967 |
+
)
|
| 968 |
+
sub_mcp = gr.Checkbox(
|
| 969 |
+
label="Uses custom MCP tools (not reference)",
|
| 970 |
+
value=False,
|
| 971 |
+
)
|
| 972 |
+
sub_btn = gr.Button(
|
| 973 |
+
"Submit for Review",
|
| 974 |
+
variant="primary",
|
| 975 |
+
)
|
| 976 |
+
sub_result = gr.HTML()
|
| 977 |
+
|
| 978 |
+
def _handle_submit(name, org, url, desc, mcp):
|
| 979 |
+
if not name or not org or not url:
|
| 980 |
+
return ('<div style="color:#e53e3e;padding:0.5rem">'
|
| 981 |
+
"Please fill in all required fields.</div>")
|
| 982 |
+
if not url.startswith(("http://", "https://")):
|
| 983 |
+
return ('<div style="color:#e53e3e;padding:0.5rem">'
|
| 984 |
+
"URL must start with http:// or https://</div>")
|
| 985 |
+
try:
|
| 986 |
+
from eval_queue import submit
|
| 987 |
+
result = submit(
|
| 988 |
+
agent_name=name,
|
| 989 |
+
organization=org,
|
| 990 |
+
endpoint_url=url,
|
| 991 |
+
description=desc,
|
| 992 |
+
mcp_custom=mcp,
|
| 993 |
+
)
|
| 994 |
+
if "error" in result:
|
| 995 |
+
return (f'<div style="color:#e53e3e;padding:0.5rem">'
|
| 996 |
+
f'{result["error"]}</div>')
|
| 997 |
+
return (
|
| 998 |
+
f'<div style="background:#c6f6d5;padding:1rem;'
|
| 999 |
+
f'border-radius:8px;margin-top:0.5rem">'
|
| 1000 |
+
f'<strong>Submitted!</strong> '
|
| 1001 |
+
f'ID: <code>{result["submission_id"]}</code><br>'
|
| 1002 |
+
f'Status: {result["status"]}<br>'
|
| 1003 |
+
f'{result.get("message", "")}</div>'
|
| 1004 |
+
)
|
| 1005 |
+
except Exception as e:
|
| 1006 |
+
return (f'<div style="color:#e53e3e;padding:0.5rem">'
|
| 1007 |
+
f"Error: {str(e)[:200]}</div>")
|
| 1008 |
+
|
| 1009 |
+
sub_btn.click(
|
| 1010 |
+
_handle_submit,
|
| 1011 |
+
[sub_agent, sub_org, sub_url, sub_desc, sub_mcp],
|
| 1012 |
+
sub_result,
|
| 1013 |
+
)
|
| 1014 |
+
|
| 1015 |
+
# ══════ Tab 6: Status & Admin ══════
|
| 1016 |
+
with gr.Tab("\U0001f6e0 Status"):
|
| 1017 |
+
gr.HTML("""
|
| 1018 |
+
<div style="max-width:800px;margin:0 auto;padding:1rem">
|
| 1019 |
+
<h2 style="color:#1a365d;margin:0 0 0.5rem">
|
| 1020 |
+
Submission Status & Admin</h2>
|
| 1021 |
+
<p style="color:#4a5568;margin-bottom:0.5rem">
|
| 1022 |
+
Check your submission status or manage the pipeline
|
| 1023 |
+
(admin only).</p>
|
| 1024 |
+
</div>""")
|
| 1025 |
+
|
| 1026 |
+
# --- Public status check ---
|
| 1027 |
+
with gr.Accordion("Check Submission Status", open=True):
|
| 1028 |
+
status_id = gr.Textbox(
|
| 1029 |
+
label="Submission ID",
|
| 1030 |
+
placeholder="Enter your submission ID...",
|
| 1031 |
+
)
|
| 1032 |
+
status_btn = gr.Button("Check Status")
|
| 1033 |
+
status_out = gr.HTML()
|
| 1034 |
+
|
| 1035 |
+
def _check_status(sid):
|
| 1036 |
+
if not sid:
|
| 1037 |
+
return '<div style="color:#718096">Enter an ID above.</div>'
|
| 1038 |
+
try:
|
| 1039 |
+
from eval_queue import get_submission
|
| 1040 |
+
sub = get_submission(sid.strip())
|
| 1041 |
+
if sub is None:
|
| 1042 |
+
return ('<div style="color:#e53e3e">'
|
| 1043 |
+
"Submission not found.</div>")
|
| 1044 |
+
status_color = {
|
| 1045 |
+
"pending": "#d69e2e", "approved": "#38a169",
|
| 1046 |
+
"dispatching": "#3182ce", "boltz": "#805ad5",
|
| 1047 |
+
"scoring": "#805ad5", "complete": "#38a169",
|
| 1048 |
+
"failed": "#e53e3e", "rejected": "#e53e3e",
|
| 1049 |
+
}.get(sub["status"], "#718096")
|
| 1050 |
+
score_html = ""
|
| 1051 |
+
if sub.get("overall_score") is not None:
|
| 1052 |
+
score_html = (
|
| 1053 |
+
f'<div style="font-size:1.2rem;'
|
| 1054 |
+
f'font-weight:700;color:#1a365d;'
|
| 1055 |
+
f'margin-top:0.5rem">'
|
| 1056 |
+
f'Score: {sub["overall_score"]:.1f}/100'
|
| 1057 |
+
f'</div>'
|
| 1058 |
+
)
|
| 1059 |
+
return (
|
| 1060 |
+
f'<div style="background:white;padding:1rem;'
|
| 1061 |
+
f'border-radius:8px;border:1px solid #e2e8f0">'
|
| 1062 |
+
f'<strong>{sub["agent_name"]}</strong> '
|
| 1063 |
+
f'({sub["organization"]})<br>'
|
| 1064 |
+
f'Status: <span style="color:{status_color};'
|
| 1065 |
+
f'font-weight:700">{sub["status"]}</span><br>'
|
| 1066 |
+
f'Tasks: {sub.get("tasks_dispatched", 0)}'
|
| 1067 |
+
f'/{sub.get("tasks_total", 76)}<br>'
|
| 1068 |
+
f'Created: {sub.get("created_at", "")[:10]}'
|
| 1069 |
+
f'{score_html}</div>'
|
| 1070 |
+
)
|
| 1071 |
+
except Exception as e:
|
| 1072 |
+
return f'<div style="color:#e53e3e">{e}</div>'
|
| 1073 |
+
|
| 1074 |
+
status_btn.click(_check_status, [status_id], status_out)
|
| 1075 |
+
|
| 1076 |
+
# --- Admin panel (password-protected) ---
|
| 1077 |
+
with gr.Accordion("Admin Panel", open=False):
|
| 1078 |
+
admin_pw = gr.Textbox(
|
| 1079 |
+
label="Admin Password", type="password",
|
| 1080 |
+
)
|
| 1081 |
+
admin_auth_btn = gr.Button("Authenticate")
|
| 1082 |
+
admin_panel = gr.Column(visible=False)
|
| 1083 |
+
admin_msg = gr.HTML()
|
| 1084 |
+
|
| 1085 |
+
with admin_panel:
|
| 1086 |
+
gr.HTML('<h3 style="color:#1a365d">'
|
| 1087 |
+
'Pending Submissions</h3>')
|
| 1088 |
+
pending_html = gr.HTML()
|
| 1089 |
+
refresh_btn = gr.Button("Refresh List")
|
| 1090 |
+
|
| 1091 |
+
with gr.Row():
|
| 1092 |
+
approve_id = gr.Textbox(
|
| 1093 |
+
label="Submission ID to Approve/Reject",
|
| 1094 |
+
scale=2,
|
| 1095 |
+
)
|
| 1096 |
+
approve_btn = gr.Button(
|
| 1097 |
+
"Approve", variant="primary", scale=1,
|
| 1098 |
+
)
|
| 1099 |
+
reject_btn = gr.Button(
|
| 1100 |
+
"Reject", variant="stop", scale=1,
|
| 1101 |
+
)
|
| 1102 |
+
approve_msg = gr.HTML()
|
| 1103 |
+
|
| 1104 |
+
gr.HTML('<h3 style="color:#1a365d;margin-top:1rem">'
|
| 1105 |
+
'Pipeline Control</h3>')
|
| 1106 |
+
with gr.Row():
|
| 1107 |
+
dispatch_id = gr.Textbox(
|
| 1108 |
+
label="Submission ID", scale=2,
|
| 1109 |
+
)
|
| 1110 |
+
dispatch_btn = gr.Button(
|
| 1111 |
+
"Phase A: Dispatch Tasks", scale=1,
|
| 1112 |
+
)
|
| 1113 |
+
with gr.Row():
|
| 1114 |
+
boltz_id = gr.Textbox(
|
| 1115 |
+
label="Submission ID", scale=2,
|
| 1116 |
+
)
|
| 1117 |
+
boltz_btn = gr.Button(
|
| 1118 |
+
"Phase B: Run Boltz", scale=1,
|
| 1119 |
+
)
|
| 1120 |
+
with gr.Row():
|
| 1121 |
+
final_id = gr.Textbox(
|
| 1122 |
+
label="Submission ID", scale=2,
|
| 1123 |
+
)
|
| 1124 |
+
final_btn = gr.Button(
|
| 1125 |
+
"Phase C: Finalize & Publish", scale=1,
|
| 1126 |
+
)
|
| 1127 |
+
pipeline_out = gr.HTML()
|
| 1128 |
+
|
| 1129 |
+
def _admin_auth(pw):
|
| 1130 |
+
if pw == ADMIN_PASSWORD:
|
| 1131 |
+
return (
|
| 1132 |
+
gr.update(visible=True),
|
| 1133 |
+
'<div style="color:#38a169">'
|
| 1134 |
+
'Authenticated.</div>',
|
| 1135 |
+
)
|
| 1136 |
+
return (
|
| 1137 |
+
gr.update(visible=False),
|
| 1138 |
+
'<div style="color:#e53e3e">'
|
| 1139 |
+
'Wrong password.</div>',
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
admin_auth_btn.click(
|
| 1143 |
+
_admin_auth, [admin_pw],
|
| 1144 |
+
[admin_panel, admin_msg],
|
| 1145 |
+
)
|
| 1146 |
+
|
| 1147 |
+
def _refresh_pending():
|
| 1148 |
+
try:
|
| 1149 |
+
from eval_queue import get_pending_submissions
|
| 1150 |
+
pending = get_pending_submissions()
|
| 1151 |
+
if not pending:
|
| 1152 |
+
return "<p>No pending submissions.</p>"
|
| 1153 |
+
rows = []
|
| 1154 |
+
for s in pending:
|
| 1155 |
+
rows.append(
|
| 1156 |
+
f'<tr><td>{s["submission_id"]}</td>'
|
| 1157 |
+
f'<td>{s["agent_name"]}</td>'
|
| 1158 |
+
f'<td>{s["organization"]}</td>'
|
| 1159 |
+
f'<td>{s.get("endpoint_url","")[:40]}'
|
| 1160 |
+
f'...</td>'
|
| 1161 |
+
f'<td>{s.get("created_at","")[:10]}'
|
| 1162 |
+
f'</td></tr>'
|
| 1163 |
+
)
|
| 1164 |
+
return (
|
| 1165 |
+
'<table style="width:100%;font-size:0.85rem;'
|
| 1166 |
+
'border-collapse:collapse">'
|
| 1167 |
+
"<tr><th>ID</th><th>Agent</th><th>Org</th>"
|
| 1168 |
+
"<th>URL</th><th>Date</th></tr>"
|
| 1169 |
+
+ "".join(rows) + "</table>"
|
| 1170 |
+
)
|
| 1171 |
+
except Exception as e:
|
| 1172 |
+
return f"<p>Error: {e}</p>"
|
| 1173 |
+
|
| 1174 |
+
refresh_btn.click(
|
| 1175 |
+
_refresh_pending, [], pending_html,
|
| 1176 |
+
)
|
| 1177 |
+
|
| 1178 |
+
def _approve_sub(sid):
|
| 1179 |
+
try:
|
| 1180 |
+
from eval_queue import update_status
|
| 1181 |
+
ok = update_status(sid.strip(), "approved")
|
| 1182 |
+
if ok:
|
| 1183 |
+
return (
|
| 1184 |
+
f'<div style="color:#38a169">'
|
| 1185 |
+
f'Approved: {sid}</div>'
|
| 1186 |
+
)
|
| 1187 |
+
return (
|
| 1188 |
+
f'<div style="color:#e53e3e">'
|
| 1189 |
+
f'Failed to approve {sid}</div>'
|
| 1190 |
+
)
|
| 1191 |
+
except Exception as e:
|
| 1192 |
+
return f'<div style="color:#e53e3e">{e}</div>'
|
| 1193 |
+
|
| 1194 |
+
def _reject_sub(sid):
|
| 1195 |
+
try:
|
| 1196 |
+
from eval_queue import update_status
|
| 1197 |
+
ok = update_status(sid.strip(), "rejected")
|
| 1198 |
+
if ok:
|
| 1199 |
+
return (
|
| 1200 |
+
f'<div style="color:#d69e2e">'
|
| 1201 |
+
f'Rejected: {sid}</div>'
|
| 1202 |
+
)
|
| 1203 |
+
return (
|
| 1204 |
+
f'<div style="color:#e53e3e">'
|
| 1205 |
+
f'Failed to reject {sid}</div>'
|
| 1206 |
+
)
|
| 1207 |
+
except Exception as e:
|
| 1208 |
+
return f'<div style="color:#e53e3e">{e}</div>'
|
| 1209 |
+
|
| 1210 |
+
approve_btn.click(
|
| 1211 |
+
_approve_sub, [approve_id], approve_msg,
|
| 1212 |
+
)
|
| 1213 |
+
reject_btn.click(
|
| 1214 |
+
_reject_sub, [approve_id], approve_msg,
|
| 1215 |
+
)
|
| 1216 |
+
|
| 1217 |
+
def _run_dispatch(sid):
|
| 1218 |
+
try:
|
| 1219 |
+
import asyncio as _aio
|
| 1220 |
+
from eval_queue import get_submission
|
| 1221 |
+
from eval_dispatcher import dispatch_all_tasks
|
| 1222 |
+
|
| 1223 |
+
sub = get_submission(sid.strip())
|
| 1224 |
+
if sub is None:
|
| 1225 |
+
return (
|
| 1226 |
+
'<div style="color:#e53e3e">'
|
| 1227 |
+
'Not found</div>'
|
| 1228 |
+
)
|
| 1229 |
+
if sub["status"] not in (
|
| 1230 |
+
"approved", "dispatching"
|
| 1231 |
+
):
|
| 1232 |
+
return (
|
| 1233 |
+
f'<div style="color:#e53e3e">'
|
| 1234 |
+
f'Cannot dispatch: status='
|
| 1235 |
+
f'{sub["status"]}</div>'
|
| 1236 |
+
)
|
| 1237 |
+
loop = _aio.new_event_loop()
|
| 1238 |
+
results = loop.run_until_complete(
|
| 1239 |
+
dispatch_all_tasks(
|
| 1240 |
+
sid.strip(),
|
| 1241 |
+
sub["endpoint_url"],
|
| 1242 |
+
)
|
| 1243 |
+
)
|
| 1244 |
+
loop.close()
|
| 1245 |
+
ok = sum(
|
| 1246 |
+
1 for r in results if r.get("success")
|
| 1247 |
+
)
|
| 1248 |
+
return (
|
| 1249 |
+
f'<div style="color:#38a169">'
|
| 1250 |
+
f'Dispatched: {ok}/{len(results)} '
|
| 1251 |
+
f'tasks succeeded.</div>'
|
| 1252 |
+
)
|
| 1253 |
+
except Exception as e:
|
| 1254 |
+
return f'<div style="color:#e53e3e">{e}</div>'
|
| 1255 |
+
|
| 1256 |
+
def _run_boltz(sid):
|
| 1257 |
+
try:
|
| 1258 |
+
from eval_queue import get_submission
|
| 1259 |
+
from eval_boltz import run_boltz_posteval
|
| 1260 |
+
|
| 1261 |
+
sub = get_submission(sid.strip())
|
| 1262 |
+
if sub is None:
|
| 1263 |
+
return (
|
| 1264 |
+
'<div style="color:#e53e3e">'
|
| 1265 |
+
'Not found</div>'
|
| 1266 |
+
)
|
| 1267 |
+
per_task = json.loads(
|
| 1268 |
+
sub.get("per_task_results", "{}")
|
| 1269 |
+
)
|
| 1270 |
+
if not per_task:
|
| 1271 |
+
return (
|
| 1272 |
+
'<div style="color:#e53e3e">'
|
| 1273 |
+
"No task results to process.</div>"
|
| 1274 |
+
)
|
| 1275 |
+
run_boltz_posteval(per_task)
|
| 1276 |
+
return (
|
| 1277 |
+
'<div style="color:#38a169">'
|
| 1278 |
+
"Boltz post-assessment complete.</div>"
|
| 1279 |
+
)
|
| 1280 |
+
except Exception as e:
|
| 1281 |
+
return f'<div style="color:#e53e3e">{e}</div>'
|
| 1282 |
+
|
| 1283 |
+
def _run_finalize(sid):
|
| 1284 |
+
try:
|
| 1285 |
+
from eval_queue import (
|
| 1286 |
+
finalize_submission,
|
| 1287 |
+
get_submission,
|
| 1288 |
+
)
|
| 1289 |
+
from eval_scorer import aggregate_scores
|
| 1290 |
+
|
| 1291 |
+
sub = get_submission(sid.strip())
|
| 1292 |
+
if sub is None:
|
| 1293 |
+
return (
|
| 1294 |
+
'<div style="color:#e53e3e">'
|
| 1295 |
+
'Not found</div>'
|
| 1296 |
+
)
|
| 1297 |
+
per_task = json.loads(
|
| 1298 |
+
sub.get("per_task_results", "{}")
|
| 1299 |
+
)
|
| 1300 |
+
agg = aggregate_scores(per_task)
|
| 1301 |
+
finalize_submission(
|
| 1302 |
+
sid.strip(),
|
| 1303 |
+
overall_score=agg["overall_score"],
|
| 1304 |
+
component_scores=agg["component_scores"],
|
| 1305 |
+
taxonomy_scores=agg["taxonomy_scores"],
|
| 1306 |
+
)
|
| 1307 |
+
return (
|
| 1308 |
+
f'<div style="color:#38a169">'
|
| 1309 |
+
f'Finalized! Score: '
|
| 1310 |
+
f'{agg["overall_score"]:.1f}</div>'
|
| 1311 |
+
)
|
| 1312 |
+
except Exception as e:
|
| 1313 |
+
return f'<div style="color:#e53e3e">{e}</div>'
|
| 1314 |
+
|
| 1315 |
+
dispatch_btn.click(
|
| 1316 |
+
_run_dispatch, [dispatch_id], pipeline_out,
|
| 1317 |
+
)
|
| 1318 |
+
boltz_btn.click(
|
| 1319 |
+
_run_boltz, [boltz_id], pipeline_out,
|
| 1320 |
+
)
|
| 1321 |
+
final_btn.click(
|
| 1322 |
+
_run_finalize, [final_id], pipeline_out,
|
| 1323 |
+
)
|
| 1324 |
+
|
| 1325 |
+
# ══════ Tab 7: About ══════
|
| 1326 |
with gr.Tab("\u2139\ufe0f About"):
|
| 1327 |
gr.HTML(build_about())
|
| 1328 |
|
eval_boltz.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Boltz structure prediction for post-assessment scoring.
|
| 2 |
+
|
| 3 |
+
Uses @spaces.GPU decorator for ZeroGPU on HuggingFace Spaces.
|
| 4 |
+
|
| 5 |
+
Two prediction modes:
|
| 6 |
+
- Monomer: Non-binding tasks -> pLDDT, pTM
|
| 7 |
+
- Complex: Binding tasks (binder + target) -> ipTM, i_pAE
|
| 8 |
+
|
| 9 |
+
Batch chunking respects ZeroGPU time limits (~180-240s per burst).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import time
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Chunking limits for ZeroGPU (free tier: ~300s max per burst)
|
| 21 |
+
MONOMER_CHUNK_SIZE = 5 # ~30-60s per monomer
|
| 22 |
+
COMPLEX_CHUNK_SIZE = 2 # ~60-120s per complex
|
| 23 |
+
MAX_GPU_TIME = 240 # safety margin under 300s ZeroGPU limit
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Boltz prediction (GPU-accelerated)
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _predict_monomer(sequence: str) -> dict[str, float]:
|
| 32 |
+
"""Predict structure of a single protein sequence using Boltz.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
Dict with: pLDDT, pTM (or error).
|
| 36 |
+
"""
|
| 37 |
+
try:
|
| 38 |
+
import torch
|
| 39 |
+
from boltz import Boltz
|
| 40 |
+
|
| 41 |
+
model = Boltz.from_pretrained("boltz2")
|
| 42 |
+
result = model.predict(sequence)
|
| 43 |
+
|
| 44 |
+
plddt = float(result.confidence.plddt.mean())
|
| 45 |
+
ptm = float(result.confidence.ptm)
|
| 46 |
+
|
| 47 |
+
return {
|
| 48 |
+
"pLDDT": round(plddt, 2),
|
| 49 |
+
"pTM": round(ptm, 4),
|
| 50 |
+
"success": True,
|
| 51 |
+
}
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"Boltz monomer prediction failed: {e}")
|
| 54 |
+
return {"pLDDT": 0.0, "pTM": 0.0, "success": False, "error": str(e)}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _predict_complex(
|
| 58 |
+
binder_seq: str,
|
| 59 |
+
target_seq: str,
|
| 60 |
+
) -> dict[str, float]:
|
| 61 |
+
"""Predict complex structure and binding metrics using Boltz.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Dict with: ipTM, i_pAE, pLDDT, pTM (or error).
|
| 65 |
+
"""
|
| 66 |
+
try:
|
| 67 |
+
import torch
|
| 68 |
+
from boltz import Boltz
|
| 69 |
+
|
| 70 |
+
model = Boltz.from_pretrained("boltz2")
|
| 71 |
+
result = model.predict([binder_seq, target_seq])
|
| 72 |
+
|
| 73 |
+
plddt = float(result.confidence.plddt.mean())
|
| 74 |
+
ptm = float(result.confidence.ptm)
|
| 75 |
+
iptm = float(result.confidence.iptm) if hasattr(result.confidence, "iptm") else 0.0
|
| 76 |
+
ipae = float(result.confidence.ipae) if hasattr(result.confidence, "ipae") else 0.0
|
| 77 |
+
|
| 78 |
+
return {
|
| 79 |
+
"pLDDT": round(plddt, 2),
|
| 80 |
+
"pTM": round(ptm, 4),
|
| 81 |
+
"ipTM": round(iptm, 4),
|
| 82 |
+
"i_pAE": round(ipae, 2),
|
| 83 |
+
"success": True,
|
| 84 |
+
}
|
| 85 |
+
except Exception as e:
|
| 86 |
+
logger.error(f"Boltz complex prediction failed: {e}")
|
| 87 |
+
return {
|
| 88 |
+
"pLDDT": 0.0, "pTM": 0.0, "ipTM": 0.0, "i_pAE": 0.0,
|
| 89 |
+
"success": False, "error": str(e),
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# GPU-decorated entry points (for HF Spaces with ZeroGPU)
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
import spaces
|
| 99 |
+
|
| 100 |
+
@spaces.GPU(duration=MAX_GPU_TIME)
|
| 101 |
+
def predict_monomer_batch(sequences: list[str]) -> list[dict[str, float]]:
|
| 102 |
+
"""Predict structures for a batch of monomer sequences.
|
| 103 |
+
|
| 104 |
+
Decorated with @spaces.GPU for ZeroGPU allocation.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
sequences: List of amino acid sequences (max MONOMER_CHUNK_SIZE).
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
List of prediction result dicts with pLDDT, pTM.
|
| 111 |
+
"""
|
| 112 |
+
results = []
|
| 113 |
+
for seq in sequences[:MONOMER_CHUNK_SIZE]:
|
| 114 |
+
results.append(_predict_monomer(seq))
|
| 115 |
+
return results
|
| 116 |
+
|
| 117 |
+
@spaces.GPU(duration=MAX_GPU_TIME)
|
| 118 |
+
def predict_complex_batch(
|
| 119 |
+
pairs: list[tuple[str, str]],
|
| 120 |
+
) -> list[dict[str, float]]:
|
| 121 |
+
"""Predict structures for a batch of binder-target pairs.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
pairs: List of (binder_seq, target_seq) tuples.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
List of prediction result dicts with ipTM, i_pAE, pLDDT, pTM.
|
| 128 |
+
"""
|
| 129 |
+
results = []
|
| 130 |
+
for binder, target in pairs[:COMPLEX_CHUNK_SIZE]:
|
| 131 |
+
results.append(_predict_complex(binder, target))
|
| 132 |
+
return results
|
| 133 |
+
|
| 134 |
+
except ImportError:
|
| 135 |
+
# Not running on HF Spaces -- provide un-decorated versions
|
| 136 |
+
def predict_monomer_batch(sequences: list[str]) -> list[dict[str, float]]:
|
| 137 |
+
return [_predict_monomer(seq) for seq in sequences[:MONOMER_CHUNK_SIZE]]
|
| 138 |
+
|
| 139 |
+
def predict_complex_batch(
|
| 140 |
+
pairs: list[tuple[str, str]],
|
| 141 |
+
) -> list[dict[str, float]]:
|
| 142 |
+
return [_predict_complex(b, t) for b, t in pairs[:COMPLEX_CHUNK_SIZE]]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# ---------------------------------------------------------------------------
|
| 146 |
+
# High-level assessment API
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def run_boltz_posteval(
|
| 151 |
+
per_task_results: dict[str, dict[str, Any]],
|
| 152 |
+
progress_callback=None,
|
| 153 |
+
) -> dict[str, dict[str, Any]]:
|
| 154 |
+
"""Run Boltz post-assessment on all tasks that need it.
|
| 155 |
+
|
| 156 |
+
For each task:
|
| 157 |
+
- Non-binding: pick best design -> monomer prediction
|
| 158 |
+
- Binding: pick best design + target sequence -> complex prediction
|
| 159 |
+
- Merge Boltz metrics into existing results
|
| 160 |
+
- Re-score quality component
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
per_task_results: Dict of task_id -> dispatch result (from dispatcher).
|
| 164 |
+
progress_callback: Optional callback(task_id, i, total, metrics).
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
Updated per_task_results with Boltz metrics and final quality scores.
|
| 168 |
+
"""
|
| 169 |
+
from eval_scorer import _is_binding_task, score_quality
|
| 170 |
+
|
| 171 |
+
# Separate tasks into monomer and complex batches
|
| 172 |
+
monomer_tasks = []
|
| 173 |
+
complex_tasks = []
|
| 174 |
+
|
| 175 |
+
for task_id, result in per_task_results.items():
|
| 176 |
+
if not result.get("success") or not result.get("quality_pending"):
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
sequences = result.get("sequences", [])
|
| 180 |
+
if not sequences:
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
best_seq = sequences[0] # Use first design for Boltz
|
| 184 |
+
|
| 185 |
+
if _is_binding_task(task_id):
|
| 186 |
+
# Need target sequence from ground truth
|
| 187 |
+
target_seq = result.get("ground_truth_thresholds", {}).get("target_sequence")
|
| 188 |
+
if target_seq:
|
| 189 |
+
complex_tasks.append((task_id, best_seq, target_seq))
|
| 190 |
+
else:
|
| 191 |
+
# Fall back to monomer if no target
|
| 192 |
+
monomer_tasks.append((task_id, best_seq))
|
| 193 |
+
else:
|
| 194 |
+
monomer_tasks.append((task_id, best_seq))
|
| 195 |
+
|
| 196 |
+
total = len(monomer_tasks) + len(complex_tasks)
|
| 197 |
+
done = 0
|
| 198 |
+
|
| 199 |
+
# Process monomer tasks in chunks
|
| 200 |
+
for chunk_start in range(0, len(monomer_tasks), MONOMER_CHUNK_SIZE):
|
| 201 |
+
chunk = monomer_tasks[chunk_start:chunk_start + MONOMER_CHUNK_SIZE]
|
| 202 |
+
seqs = [seq for _, seq in chunk]
|
| 203 |
+
|
| 204 |
+
boltz_results = predict_monomer_batch(seqs)
|
| 205 |
+
|
| 206 |
+
for (task_id, _), metrics in zip(chunk, boltz_results):
|
| 207 |
+
if metrics.get("success"):
|
| 208 |
+
_merge_boltz_metrics(per_task_results[task_id], metrics)
|
| 209 |
+
|
| 210 |
+
done += 1
|
| 211 |
+
if progress_callback:
|
| 212 |
+
progress_callback(task_id, done, total, metrics)
|
| 213 |
+
|
| 214 |
+
# Process complex tasks in chunks
|
| 215 |
+
for chunk_start in range(0, len(complex_tasks), COMPLEX_CHUNK_SIZE):
|
| 216 |
+
chunk = complex_tasks[chunk_start:chunk_start + COMPLEX_CHUNK_SIZE]
|
| 217 |
+
pairs = [(binder, target) for _, binder, target in chunk]
|
| 218 |
+
|
| 219 |
+
boltz_results = predict_complex_batch(pairs)
|
| 220 |
+
|
| 221 |
+
for (task_id, _, _), metrics in zip(chunk, boltz_results):
|
| 222 |
+
if metrics.get("success"):
|
| 223 |
+
_merge_boltz_metrics(per_task_results[task_id], metrics)
|
| 224 |
+
|
| 225 |
+
done += 1
|
| 226 |
+
if progress_callback:
|
| 227 |
+
progress_callback(task_id, done, total, metrics)
|
| 228 |
+
|
| 229 |
+
return per_task_results
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _merge_boltz_metrics(
|
| 233 |
+
task_result: dict[str, Any],
|
| 234 |
+
boltz_metrics: dict[str, float],
|
| 235 |
+
) -> None:
|
| 236 |
+
"""Merge Boltz prediction metrics into a task result and re-score quality.
|
| 237 |
+
|
| 238 |
+
Modifies task_result in-place.
|
| 239 |
+
"""
|
| 240 |
+
from eval_scorer import apply_design_gate, score_quality
|
| 241 |
+
|
| 242 |
+
# Merge Boltz metrics with any agent-reported metrics
|
| 243 |
+
merged_metrics = task_result.get("agent_metrics", {}).copy()
|
| 244 |
+
for key in ("pLDDT", "pTM", "ipTM", "i_pAE"):
|
| 245 |
+
if key in boltz_metrics and boltz_metrics[key] > 0:
|
| 246 |
+
merged_metrics[key] = boltz_metrics[key]
|
| 247 |
+
|
| 248 |
+
# Re-score quality with Boltz metrics
|
| 249 |
+
quality_result = score_quality(
|
| 250 |
+
agent_metrics=merged_metrics,
|
| 251 |
+
thresholds=task_result.get("ground_truth_thresholds", {}),
|
| 252 |
+
task_id=task_result.get("task_id", ""),
|
| 253 |
+
designs=task_result.get("sequences"),
|
| 254 |
+
oracle_sequences=task_result.get("oracle_sequences"),
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Update scores
|
| 258 |
+
task_result["boltz_metrics"] = boltz_metrics
|
| 259 |
+
task_result["quality_pending"] = False
|
| 260 |
+
|
| 261 |
+
if "cpu_scores" in task_result:
|
| 262 |
+
task_result["cpu_scores"]["quality"] = quality_result["score"]
|
| 263 |
+
|
| 264 |
+
# Compute final gated score
|
| 265 |
+
if "cpu_scores" in task_result:
|
| 266 |
+
component_scores = dict(task_result["cpu_scores"])
|
| 267 |
+
gated = apply_design_gate(component_scores, task_result.get("num_designs", 0))
|
| 268 |
+
task_result["final_scores"] = gated
|
| 269 |
+
task_result["total_score"] = sum(gated.values())
|
| 270 |
+
|
| 271 |
+
if "cpu_details" in task_result:
|
| 272 |
+
task_result["cpu_details"]["quality"] = quality_result
|
eval_dispatcher.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HTTP task dispatcher — sends benchmark tasks to submitter endpoints.
|
| 2 |
+
|
| 3 |
+
For each of 76 tasks:
|
| 4 |
+
1. Build task payload (prompt + tools + PDB data)
|
| 5 |
+
2. POST to submitter's endpoint with timeout
|
| 6 |
+
3. Validate response format
|
| 7 |
+
4. Run CPU-only scoring (approach, orchestration, feasibility, novelty, diversity)
|
| 8 |
+
5. Save results to submission queue
|
| 9 |
+
|
| 10 |
+
CPU scoring runs immediately; quality scoring waits for Boltz post-eval.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import time
|
| 17 |
+
from typing import Any, Generator
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Response validation limits
|
| 22 |
+
MAX_SEQUENCES = 50
|
| 23 |
+
MAX_SEQUENCE_LENGTH = 2000
|
| 24 |
+
MAX_LOG_ENTRIES = 200
|
| 25 |
+
DISPATCH_TIMEOUT = 300 # seconds per task
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Response validation
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def validate_response(response: dict[str, Any]) -> tuple[bool, str]:
|
| 34 |
+
"""Validate the submitter's response format.
|
| 35 |
+
|
| 36 |
+
Expected format:
|
| 37 |
+
{
|
| 38 |
+
"sequences": ["MKKL...", ...],
|
| 39 |
+
"run_log": [{"step": 1, "tool": "...", "success": true, ...}, ...],
|
| 40 |
+
"total_steps": 12,
|
| 41 |
+
"total_time_sec": 142.5,
|
| 42 |
+
"metrics": {}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
(is_valid, error_message)
|
| 47 |
+
"""
|
| 48 |
+
if not isinstance(response, dict):
|
| 49 |
+
return False, "Response must be a JSON object"
|
| 50 |
+
|
| 51 |
+
# sequences (required)
|
| 52 |
+
sequences = response.get("sequences")
|
| 53 |
+
if not isinstance(sequences, list):
|
| 54 |
+
return False, "Missing or invalid 'sequences' field (must be a list)"
|
| 55 |
+
|
| 56 |
+
if len(sequences) > MAX_SEQUENCES:
|
| 57 |
+
return False, f"Too many sequences: {len(sequences)} > {MAX_SEQUENCES}"
|
| 58 |
+
|
| 59 |
+
for i, seq in enumerate(sequences):
|
| 60 |
+
if not isinstance(seq, str):
|
| 61 |
+
return False, f"sequences[{i}] must be a string"
|
| 62 |
+
if len(seq) > MAX_SEQUENCE_LENGTH:
|
| 63 |
+
return False, f"sequences[{i}] too long: {len(seq)} > {MAX_SEQUENCE_LENGTH}"
|
| 64 |
+
if len(seq) == 0:
|
| 65 |
+
return False, f"sequences[{i}] is empty"
|
| 66 |
+
|
| 67 |
+
# run_log (required)
|
| 68 |
+
run_log = response.get("run_log")
|
| 69 |
+
if not isinstance(run_log, list):
|
| 70 |
+
return False, "Missing or invalid 'run_log' field (must be a list)"
|
| 71 |
+
|
| 72 |
+
if len(run_log) > MAX_LOG_ENTRIES:
|
| 73 |
+
return False, f"Too many log entries: {len(run_log)} > {MAX_LOG_ENTRIES}"
|
| 74 |
+
|
| 75 |
+
for i, entry in enumerate(run_log):
|
| 76 |
+
if not isinstance(entry, dict):
|
| 77 |
+
return False, f"run_log[{i}] must be a dict"
|
| 78 |
+
if "tool" not in entry:
|
| 79 |
+
return False, f"run_log[{i}] missing 'tool' field"
|
| 80 |
+
|
| 81 |
+
# Optional fields — validate types if present
|
| 82 |
+
if "total_steps" in response:
|
| 83 |
+
if not isinstance(response["total_steps"], (int, float)):
|
| 84 |
+
return False, "'total_steps' must be a number"
|
| 85 |
+
|
| 86 |
+
if "total_time_sec" in response:
|
| 87 |
+
if not isinstance(response["total_time_sec"], (int, float)):
|
| 88 |
+
return False, "'total_time_sec' must be a number"
|
| 89 |
+
|
| 90 |
+
return True, ""
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Single task dispatch
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
async def dispatch_single_task(
|
| 99 |
+
endpoint_url: str,
|
| 100 |
+
task_payload: dict[str, Any],
|
| 101 |
+
timeout: int = DISPATCH_TIMEOUT,
|
| 102 |
+
) -> dict[str, Any]:
|
| 103 |
+
"""Send a single task to the submitter's endpoint.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
endpoint_url: Submitter's POST endpoint URL.
|
| 107 |
+
task_payload: Task payload from eval_tasks.build_task_payload().
|
| 108 |
+
timeout: Request timeout in seconds.
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
Dict with: success, task_id, response (if success), error (if failed),
|
| 112 |
+
latency_sec.
|
| 113 |
+
"""
|
| 114 |
+
import httpx
|
| 115 |
+
|
| 116 |
+
task_id = task_payload["task_id"]
|
| 117 |
+
start = time.monotonic()
|
| 118 |
+
|
| 119 |
+
try:
|
| 120 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 121 |
+
resp = await client.post(
|
| 122 |
+
endpoint_url,
|
| 123 |
+
json=task_payload,
|
| 124 |
+
headers={"Content-Type": "application/json"},
|
| 125 |
+
)
|
| 126 |
+
latency = time.monotonic() - start
|
| 127 |
+
|
| 128 |
+
if resp.status_code != 200:
|
| 129 |
+
return {
|
| 130 |
+
"success": False,
|
| 131 |
+
"task_id": task_id,
|
| 132 |
+
"error": f"HTTP {resp.status_code}: {resp.text[:200]}",
|
| 133 |
+
"latency_sec": round(latency, 1),
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
data = resp.json()
|
| 138 |
+
except Exception:
|
| 139 |
+
return {
|
| 140 |
+
"success": False,
|
| 141 |
+
"task_id": task_id,
|
| 142 |
+
"error": "Response is not valid JSON",
|
| 143 |
+
"latency_sec": round(latency, 1),
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
is_valid, error_msg = validate_response(data)
|
| 147 |
+
if not is_valid:
|
| 148 |
+
return {
|
| 149 |
+
"success": False,
|
| 150 |
+
"task_id": task_id,
|
| 151 |
+
"error": f"Invalid response: {error_msg}",
|
| 152 |
+
"latency_sec": round(latency, 1),
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
return {
|
| 156 |
+
"success": True,
|
| 157 |
+
"task_id": task_id,
|
| 158 |
+
"response": data,
|
| 159 |
+
"latency_sec": round(latency, 1),
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
except httpx.TimeoutException:
|
| 163 |
+
latency = time.monotonic() - start
|
| 164 |
+
return {
|
| 165 |
+
"success": False,
|
| 166 |
+
"task_id": task_id,
|
| 167 |
+
"error": f"Timeout after {timeout}s",
|
| 168 |
+
"latency_sec": round(latency, 1),
|
| 169 |
+
}
|
| 170 |
+
except Exception as e:
|
| 171 |
+
latency = time.monotonic() - start
|
| 172 |
+
return {
|
| 173 |
+
"success": False,
|
| 174 |
+
"task_id": task_id,
|
| 175 |
+
"error": f"Connection error: {str(e)[:200]}",
|
| 176 |
+
"latency_sec": round(latency, 1),
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ---------------------------------------------------------------------------
|
| 181 |
+
# CPU scoring (runs immediately, no GPU needed)
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def score_cpu_components(
|
| 186 |
+
task_id: str,
|
| 187 |
+
sequences: list[str],
|
| 188 |
+
run_log: list[dict[str, Any]],
|
| 189 |
+
ground_truth: dict[str, Any],
|
| 190 |
+
oracle_sequences: list[str] | None = None,
|
| 191 |
+
) -> dict[str, Any]:
|
| 192 |
+
"""Run CPU-only scoring components.
|
| 193 |
+
|
| 194 |
+
Scores: approach, orchestration, feasibility, novelty, diversity.
|
| 195 |
+
Quality scoring is deferred until Boltz post-eval provides pLDDT/ipTM.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
task_id: Task identifier.
|
| 199 |
+
sequences: Designed sequences from submitter.
|
| 200 |
+
run_log: Tool call log from submitter.
|
| 201 |
+
ground_truth: Ground truth data for this task.
|
| 202 |
+
oracle_sequences: Oracle sequences for non-binding tasks.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
Dict with partial scores and metadata for later Boltz completion.
|
| 206 |
+
"""
|
| 207 |
+
from eval_scorer import (
|
| 208 |
+
get_category,
|
| 209 |
+
score_approach,
|
| 210 |
+
score_diversity,
|
| 211 |
+
score_feasibility,
|
| 212 |
+
score_novelty,
|
| 213 |
+
score_orchestration,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Extract fields
|
| 217 |
+
thresholds = ground_truth.get("thresholds", {})
|
| 218 |
+
reference_seq = ground_truth.get("reference_sequence")
|
| 219 |
+
constraints = ground_truth.get("design_constraints", {})
|
| 220 |
+
tools_expected = ground_truth.get("tools_expected", [])
|
| 221 |
+
max_designs = ground_truth.get("max_designs", 10)
|
| 222 |
+
|
| 223 |
+
cat = get_category(task_id)
|
| 224 |
+
task_type = cat.task_type if cat else None
|
| 225 |
+
tools_used = [e.get("tool", "") for e in run_log if e.get("tool")]
|
| 226 |
+
|
| 227 |
+
approach_result = score_approach(
|
| 228 |
+
tools_used=tools_used,
|
| 229 |
+
tools_expected=tools_expected,
|
| 230 |
+
task_type=task_type,
|
| 231 |
+
)
|
| 232 |
+
orchestration_result = score_orchestration(
|
| 233 |
+
tool_call_log=run_log,
|
| 234 |
+
task_id=task_id,
|
| 235 |
+
)
|
| 236 |
+
feasibility_result = score_feasibility(
|
| 237 |
+
designs=sequences,
|
| 238 |
+
constraints=constraints,
|
| 239 |
+
)
|
| 240 |
+
novelty_result = score_novelty(
|
| 241 |
+
designs=sequences,
|
| 242 |
+
reference_seq=reference_seq,
|
| 243 |
+
thresholds=thresholds,
|
| 244 |
+
)
|
| 245 |
+
diversity_result = score_diversity(
|
| 246 |
+
designs=sequences,
|
| 247 |
+
max_designs=max_designs,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
return {
|
| 251 |
+
"task_id": task_id,
|
| 252 |
+
"num_designs": len(sequences),
|
| 253 |
+
"sequences": sequences,
|
| 254 |
+
"cpu_scores": {
|
| 255 |
+
"approach": approach_result["score"],
|
| 256 |
+
"orchestration": orchestration_result["score"],
|
| 257 |
+
"feasibility": feasibility_result["score"],
|
| 258 |
+
"novelty": novelty_result["score"],
|
| 259 |
+
"diversity": diversity_result["score"],
|
| 260 |
+
},
|
| 261 |
+
"cpu_details": {
|
| 262 |
+
"approach": approach_result,
|
| 263 |
+
"orchestration": orchestration_result,
|
| 264 |
+
"feasibility": feasibility_result,
|
| 265 |
+
"novelty": novelty_result,
|
| 266 |
+
"diversity": diversity_result,
|
| 267 |
+
},
|
| 268 |
+
"quality_pending": True, # Needs Boltz post-eval
|
| 269 |
+
"oracle_sequences": oracle_sequences or [],
|
| 270 |
+
"ground_truth_thresholds": thresholds,
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# ---------------------------------------------------------------------------
|
| 275 |
+
# Full dispatch pipeline
|
| 276 |
+
# ---------------------------------------------------------------------------
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
async def dispatch_all_tasks(
|
| 280 |
+
submission_id: str,
|
| 281 |
+
endpoint_url: str,
|
| 282 |
+
progress_callback=None,
|
| 283 |
+
) -> Generator[dict[str, Any], None, None]:
|
| 284 |
+
"""Dispatch all hidden tasks to a submitter endpoint.
|
| 285 |
+
|
| 286 |
+
Yields progress updates as each task completes. Saves results
|
| 287 |
+
to the submission queue incrementally.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
submission_id: Submission ID for queue tracking.
|
| 291 |
+
endpoint_url: Submitter's POST endpoint.
|
| 292 |
+
progress_callback: Optional callback(task_id, i, total, result)
|
| 293 |
+
for streaming progress updates.
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
List of per-task results.
|
| 297 |
+
"""
|
| 298 |
+
from eval_queue import save_task_result, update_status
|
| 299 |
+
from eval_tasks import build_task_payload, get_hidden_task_ids, get_task
|
| 300 |
+
|
| 301 |
+
task_ids = get_hidden_task_ids()
|
| 302 |
+
total = len(task_ids)
|
| 303 |
+
results = []
|
| 304 |
+
|
| 305 |
+
update_status(submission_id, "dispatching", tasks_total=total)
|
| 306 |
+
|
| 307 |
+
for i, task_id in enumerate(task_ids):
|
| 308 |
+
# Build payload
|
| 309 |
+
payload = build_task_payload(task_id)
|
| 310 |
+
if payload is None:
|
| 311 |
+
result = {
|
| 312 |
+
"task_id": task_id,
|
| 313 |
+
"success": False,
|
| 314 |
+
"error": "Task not found",
|
| 315 |
+
}
|
| 316 |
+
results.append(result)
|
| 317 |
+
save_task_result(submission_id, task_id, result)
|
| 318 |
+
continue
|
| 319 |
+
|
| 320 |
+
# Dispatch
|
| 321 |
+
dispatch_result = await dispatch_single_task(endpoint_url, payload)
|
| 322 |
+
|
| 323 |
+
if dispatch_result["success"]:
|
| 324 |
+
# Run CPU scoring
|
| 325 |
+
task_data = get_task(task_id)
|
| 326 |
+
ground_truth = task_data["ground_truth"] if task_data else {}
|
| 327 |
+
oracle_seqs = task_data.get("oracle_sequences", []) if task_data else []
|
| 328 |
+
|
| 329 |
+
response = dispatch_result["response"]
|
| 330 |
+
cpu_result = score_cpu_components(
|
| 331 |
+
task_id=task_id,
|
| 332 |
+
sequences=response["sequences"],
|
| 333 |
+
run_log=response["run_log"],
|
| 334 |
+
ground_truth=ground_truth,
|
| 335 |
+
oracle_sequences=oracle_seqs,
|
| 336 |
+
)
|
| 337 |
+
cpu_result["latency_sec"] = dispatch_result["latency_sec"]
|
| 338 |
+
cpu_result["success"] = True
|
| 339 |
+
cpu_result["agent_metrics"] = response.get("metrics", {})
|
| 340 |
+
results.append(cpu_result)
|
| 341 |
+
save_task_result(submission_id, task_id, cpu_result)
|
| 342 |
+
else:
|
| 343 |
+
result = {
|
| 344 |
+
"task_id": task_id,
|
| 345 |
+
"success": False,
|
| 346 |
+
"error": dispatch_result["error"],
|
| 347 |
+
"latency_sec": dispatch_result.get("latency_sec"),
|
| 348 |
+
}
|
| 349 |
+
results.append(result)
|
| 350 |
+
save_task_result(submission_id, task_id, result)
|
| 351 |
+
|
| 352 |
+
if progress_callback:
|
| 353 |
+
progress_callback(task_id, i + 1, total, results[-1])
|
| 354 |
+
|
| 355 |
+
logger.info(
|
| 356 |
+
f"[{i+1}/{total}] {task_id}: "
|
| 357 |
+
f"{'OK' if results[-1].get('success') else 'FAIL'} "
|
| 358 |
+
f"({results[-1].get('latency_sec', 0):.1f}s)"
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
return results
|
eval_queue.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Submission queue management using HuggingFace Datasets.
|
| 2 |
+
|
| 3 |
+
Manages the lifecycle of benchmark submissions:
|
| 4 |
+
pending → approved → dispatching → boltz → scoring → complete / failed
|
| 5 |
+
|
| 6 |
+
Rate limiting: 2 submissions per calendar month per organization.
|
| 7 |
+
|
| 8 |
+
HF Dataset: RomeroLab-Duke/biodesignbench-submissions (private)
|
| 9 |
+
Schema: Each row is a submission with per-task results stored as JSON.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import uuid
|
| 18 |
+
from datetime import datetime, timezone
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Constants
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
SUBMISSIONS_DATASET = os.environ.get(
|
| 28 |
+
"BDB_SUBMISSIONS_DATASET",
|
| 29 |
+
"RomeroLab-Duke/biodesignbench-submissions",
|
| 30 |
+
)
|
| 31 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 32 |
+
MAX_SUBMISSIONS_PER_MONTH = 2
|
| 33 |
+
|
| 34 |
+
# Submission status progression
|
| 35 |
+
VALID_STATUSES = {
|
| 36 |
+
"pending",
|
| 37 |
+
"approved",
|
| 38 |
+
"dispatching",
|
| 39 |
+
"boltz",
|
| 40 |
+
"scoring",
|
| 41 |
+
"complete",
|
| 42 |
+
"failed",
|
| 43 |
+
"rejected",
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Data model
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _make_submission_row(
|
| 53 |
+
agent_name: str,
|
| 54 |
+
organization: str,
|
| 55 |
+
endpoint_url: str,
|
| 56 |
+
description: str = "",
|
| 57 |
+
mcp_custom: bool = False,
|
| 58 |
+
) -> dict[str, Any]:
|
| 59 |
+
"""Create a new submission row."""
|
| 60 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 61 |
+
return {
|
| 62 |
+
"submission_id": str(uuid.uuid4())[:12],
|
| 63 |
+
"agent_name": agent_name,
|
| 64 |
+
"organization": organization,
|
| 65 |
+
"endpoint_url": endpoint_url,
|
| 66 |
+
"description": description,
|
| 67 |
+
"mcp_custom": mcp_custom,
|
| 68 |
+
"status": "pending",
|
| 69 |
+
"created_at": now,
|
| 70 |
+
"updated_at": now,
|
| 71 |
+
"tasks_dispatched": 0,
|
| 72 |
+
"tasks_total": 76,
|
| 73 |
+
"tasks_boltz_done": 0,
|
| 74 |
+
"overall_score": None,
|
| 75 |
+
"component_scores": None,
|
| 76 |
+
"taxonomy_scores": None,
|
| 77 |
+
"per_task_results": "{}", # JSON string of task_id → result
|
| 78 |
+
"error_message": None,
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
# Queue operations (HF Datasets API)
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _get_dataset():
|
| 88 |
+
"""Load the submissions dataset from HF Hub."""
|
| 89 |
+
try:
|
| 90 |
+
from datasets import load_dataset
|
| 91 |
+
|
| 92 |
+
ds = load_dataset(
|
| 93 |
+
SUBMISSIONS_DATASET,
|
| 94 |
+
split="train",
|
| 95 |
+
token=HF_TOKEN,
|
| 96 |
+
)
|
| 97 |
+
return ds
|
| 98 |
+
except Exception as e:
|
| 99 |
+
logger.warning(f"Could not load submissions dataset: {e}")
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _save_rows(rows: list[dict[str, Any]]) -> bool:
|
| 104 |
+
"""Save rows back to HF Dataset."""
|
| 105 |
+
try:
|
| 106 |
+
from datasets import Dataset
|
| 107 |
+
from huggingface_hub import HfApi
|
| 108 |
+
|
| 109 |
+
ds = Dataset.from_list(rows)
|
| 110 |
+
ds.push_to_hub(
|
| 111 |
+
SUBMISSIONS_DATASET,
|
| 112 |
+
token=HF_TOKEN,
|
| 113 |
+
private=True,
|
| 114 |
+
)
|
| 115 |
+
return True
|
| 116 |
+
except Exception as e:
|
| 117 |
+
logger.error(f"Failed to save submissions: {e}")
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _load_all_rows() -> list[dict[str, Any]]:
|
| 122 |
+
"""Load all submission rows as a list of dicts."""
|
| 123 |
+
ds = _get_dataset()
|
| 124 |
+
if ds is None:
|
| 125 |
+
return []
|
| 126 |
+
return [dict(row) for row in ds]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def submit(
|
| 130 |
+
agent_name: str,
|
| 131 |
+
organization: str,
|
| 132 |
+
endpoint_url: str,
|
| 133 |
+
description: str = "",
|
| 134 |
+
mcp_custom: bool = False,
|
| 135 |
+
) -> dict[str, Any]:
|
| 136 |
+
"""Create a new submission.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Dict with submission_id and status, or error message.
|
| 140 |
+
"""
|
| 141 |
+
# Rate limit check
|
| 142 |
+
error = check_rate_limit(organization)
|
| 143 |
+
if error:
|
| 144 |
+
return {"error": error}
|
| 145 |
+
|
| 146 |
+
# Validate endpoint URL
|
| 147 |
+
if not endpoint_url.startswith(("http://", "https://")):
|
| 148 |
+
return {"error": "Endpoint URL must start with http:// or https://"}
|
| 149 |
+
|
| 150 |
+
row = _make_submission_row(
|
| 151 |
+
agent_name=agent_name,
|
| 152 |
+
organization=organization,
|
| 153 |
+
endpoint_url=endpoint_url,
|
| 154 |
+
description=description,
|
| 155 |
+
mcp_custom=mcp_custom,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
rows = _load_all_rows()
|
| 159 |
+
rows.append(row)
|
| 160 |
+
|
| 161 |
+
if _save_rows(rows):
|
| 162 |
+
return {
|
| 163 |
+
"submission_id": row["submission_id"],
|
| 164 |
+
"status": "pending",
|
| 165 |
+
"message": f"Submission created. Awaiting admin approval.",
|
| 166 |
+
}
|
| 167 |
+
return {"error": "Failed to save submission. Please try again."}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def check_rate_limit(organization: str) -> str | None:
|
| 171 |
+
"""Check if an organization has exceeded the monthly submission limit.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
Error message string if rate limited, None if OK.
|
| 175 |
+
"""
|
| 176 |
+
rows = _load_all_rows()
|
| 177 |
+
now = datetime.now(timezone.utc)
|
| 178 |
+
current_month = now.strftime("%Y-%m")
|
| 179 |
+
|
| 180 |
+
monthly_count = 0
|
| 181 |
+
for row in rows:
|
| 182 |
+
if row.get("organization", "").lower() != organization.lower():
|
| 183 |
+
continue
|
| 184 |
+
if row.get("status") in ("rejected", "failed"):
|
| 185 |
+
continue
|
| 186 |
+
created = row.get("created_at", "")
|
| 187 |
+
if created.startswith(current_month):
|
| 188 |
+
monthly_count += 1
|
| 189 |
+
|
| 190 |
+
if monthly_count >= MAX_SUBMISSIONS_PER_MONTH:
|
| 191 |
+
return (
|
| 192 |
+
f"Organization '{organization}' has reached the limit of "
|
| 193 |
+
f"{MAX_SUBMISSIONS_PER_MONTH} submissions for {current_month}."
|
| 194 |
+
)
|
| 195 |
+
return None
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def update_status(
|
| 199 |
+
submission_id: str,
|
| 200 |
+
status: str,
|
| 201 |
+
**extra_fields: Any,
|
| 202 |
+
) -> bool:
|
| 203 |
+
"""Update a submission's status and optional extra fields.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
submission_id: The submission to update.
|
| 207 |
+
status: New status (must be in VALID_STATUSES).
|
| 208 |
+
**extra_fields: Additional fields to update (e.g., tasks_dispatched=10).
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
True if updated successfully.
|
| 212 |
+
"""
|
| 213 |
+
if status not in VALID_STATUSES:
|
| 214 |
+
logger.error(f"Invalid status: {status}")
|
| 215 |
+
return False
|
| 216 |
+
|
| 217 |
+
rows = _load_all_rows()
|
| 218 |
+
found = False
|
| 219 |
+
for row in rows:
|
| 220 |
+
if row.get("submission_id") == submission_id:
|
| 221 |
+
row["status"] = status
|
| 222 |
+
row["updated_at"] = datetime.now(timezone.utc).isoformat()
|
| 223 |
+
for k, v in extra_fields.items():
|
| 224 |
+
if k in row:
|
| 225 |
+
row[k] = v
|
| 226 |
+
found = True
|
| 227 |
+
break
|
| 228 |
+
|
| 229 |
+
if not found:
|
| 230 |
+
logger.error(f"Submission {submission_id} not found")
|
| 231 |
+
return False
|
| 232 |
+
|
| 233 |
+
return _save_rows(rows)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def save_task_result(
|
| 237 |
+
submission_id: str,
|
| 238 |
+
task_id: str,
|
| 239 |
+
result: dict[str, Any],
|
| 240 |
+
) -> bool:
|
| 241 |
+
"""Save a per-task result to the submission.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
submission_id: The submission to update.
|
| 245 |
+
task_id: Task identifier.
|
| 246 |
+
result: Score result dict from eval_scorer.score_submission_task().
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
True if saved successfully.
|
| 250 |
+
"""
|
| 251 |
+
rows = _load_all_rows()
|
| 252 |
+
for row in rows:
|
| 253 |
+
if row.get("submission_id") == submission_id:
|
| 254 |
+
per_task = json.loads(row.get("per_task_results", "{}"))
|
| 255 |
+
per_task[task_id] = result
|
| 256 |
+
row["per_task_results"] = json.dumps(per_task)
|
| 257 |
+
row["tasks_dispatched"] = len(per_task)
|
| 258 |
+
row["updated_at"] = datetime.now(timezone.utc).isoformat()
|
| 259 |
+
return _save_rows(rows)
|
| 260 |
+
|
| 261 |
+
logger.error(f"Submission {submission_id} not found")
|
| 262 |
+
return False
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def get_submission(submission_id: str) -> dict[str, Any] | None:
|
| 266 |
+
"""Get a single submission by ID."""
|
| 267 |
+
rows = _load_all_rows()
|
| 268 |
+
for row in rows:
|
| 269 |
+
if row.get("submission_id") == submission_id:
|
| 270 |
+
return row
|
| 271 |
+
return None
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def get_pending_submissions() -> list[dict[str, Any]]:
|
| 275 |
+
"""Get all submissions awaiting admin approval."""
|
| 276 |
+
return [r for r in _load_all_rows() if r.get("status") == "pending"]
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def get_approved_submissions() -> list[dict[str, Any]]:
|
| 280 |
+
"""Get all approved submissions ready for dispatch."""
|
| 281 |
+
return [r for r in _load_all_rows() if r.get("status") == "approved"]
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def get_all_submissions() -> list[dict[str, Any]]:
|
| 285 |
+
"""Get all submissions for the admin panel."""
|
| 286 |
+
return _load_all_rows()
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def finalize_submission(
|
| 290 |
+
submission_id: str,
|
| 291 |
+
overall_score: float,
|
| 292 |
+
component_scores: dict[str, float],
|
| 293 |
+
taxonomy_scores: dict[str, dict[str, float]],
|
| 294 |
+
) -> bool:
|
| 295 |
+
"""Finalize a submission with aggregated scores.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
submission_id: The submission to finalize.
|
| 299 |
+
overall_score: Overall score (0-100).
|
| 300 |
+
component_scores: Dict of component → averaged score.
|
| 301 |
+
taxonomy_scores: Nested dict of task_type → context → avg score.
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
True if finalized successfully.
|
| 305 |
+
"""
|
| 306 |
+
return update_status(
|
| 307 |
+
submission_id,
|
| 308 |
+
status="complete",
|
| 309 |
+
overall_score=overall_score,
|
| 310 |
+
component_scores=json.dumps(component_scores),
|
| 311 |
+
taxonomy_scores=json.dumps(taxonomy_scores),
|
| 312 |
+
)
|
eval_scorer.py
ADDED
|
@@ -0,0 +1,1643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Standalone 100-point scoring rubric for BioDesignBench Tier 2 design tasks.
|
| 2 |
+
|
| 3 |
+
This file is a **self-contained extraction** of the scoring logic from the
|
| 4 |
+
``biodesignbench`` package. It has **zero external dependencies** (stdlib only)
|
| 5 |
+
so it can run on HuggingFace Spaces without installing the full package.
|
| 6 |
+
|
| 7 |
+
Modules consolidated:
|
| 8 |
+
- biodesignbench/taxonomy.py
|
| 9 |
+
- biodesignbench/eval/metrics/sequence.py
|
| 10 |
+
- biodesignbench/eval/metrics/approach.py
|
| 11 |
+
- biodesignbench/eval/metrics/orchestration.py
|
| 12 |
+
- biodesignbench/eval/tier2/scoring.py
|
| 13 |
+
- biodesignbench/eval/tier2/oracle.py (oracle loading stub)
|
| 14 |
+
|
| 15 |
+
Six scoring components (sum = 100):
|
| 16 |
+
approach (20 pts) — Tool/methodology selection
|
| 17 |
+
orchestration (15 pts) — Pipeline ordering + intermediate validation
|
| 18 |
+
quality (35 pts) — 3-tier continuous scoring (structure/interface/physics)
|
| 19 |
+
feasibility (15 pts) — Valid AAs, length, composition + biophysical checks
|
| 20 |
+
novelty ( 5 pts) — Sequence identity to known sequences
|
| 21 |
+
diversity (10 pts) — Number + diversity of designs
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
import math
|
| 28 |
+
import re
|
| 29 |
+
from collections import Counter
|
| 30 |
+
from dataclasses import dataclass, field
|
| 31 |
+
from enum import Enum
|
| 32 |
+
from functools import lru_cache
|
| 33 |
+
from itertools import combinations
|
| 34 |
+
from typing import Any, Optional
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 38 |
+
# SECTION 1 — Taxonomy (from biodesignbench/taxonomy.py)
|
| 39 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DesignTaskType(str, Enum):
|
| 43 |
+
"""What the agent does."""
|
| 44 |
+
|
| 45 |
+
DE_NOVO_BINDER = "de_novo_binder"
|
| 46 |
+
SEQUENCE_OPTIMIZATION = "sequence_optimization"
|
| 47 |
+
DE_NOVO_BACKBONE = "de_novo_backbone"
|
| 48 |
+
COMPLEX_ENGINEERING = "complex_engineering"
|
| 49 |
+
CONFORMATIONAL_DESIGN = "conformational_design"
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def short(self) -> str:
|
| 53 |
+
return _TASK_TYPE_SHORT[self]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class BiologicalContext(str, Enum):
|
| 57 |
+
"""Domain knowledge required."""
|
| 58 |
+
|
| 59 |
+
ANTIBODY = "antibody"
|
| 60 |
+
ENZYME = "enzyme"
|
| 61 |
+
SIGNALING = "signaling"
|
| 62 |
+
STRUCTURAL = "structural"
|
| 63 |
+
FLUORESCENT = "fluorescent"
|
| 64 |
+
THERAPEUTIC = "therapeutic"
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def short(self) -> str:
|
| 68 |
+
return _CONTEXT_SHORT[self]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
_TASK_TYPE_SHORT: dict[DesignTaskType, str] = {
|
| 72 |
+
DesignTaskType.DE_NOVO_BINDER: "dnb",
|
| 73 |
+
DesignTaskType.SEQUENCE_OPTIMIZATION: "sqo",
|
| 74 |
+
DesignTaskType.DE_NOVO_BACKBONE: "dnk",
|
| 75 |
+
DesignTaskType.COMPLEX_ENGINEERING: "cpx",
|
| 76 |
+
DesignTaskType.CONFORMATIONAL_DESIGN: "cfd",
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
_CONTEXT_SHORT: dict[BiologicalContext, str] = {
|
| 80 |
+
BiologicalContext.ANTIBODY: "ab",
|
| 81 |
+
BiologicalContext.ENZYME: "enz",
|
| 82 |
+
BiologicalContext.SIGNALING: "sig",
|
| 83 |
+
BiologicalContext.STRUCTURAL: "str",
|
| 84 |
+
BiologicalContext.FLUORESCENT: "flu",
|
| 85 |
+
BiologicalContext.THERAPEUTIC: "thr",
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
_SHORT_TO_TASK_TYPE: dict[str, DesignTaskType] = {v: k for k, v in _TASK_TYPE_SHORT.items()}
|
| 89 |
+
_SHORT_TO_CONTEXT: dict[str, BiologicalContext] = {v: k for k, v in _CONTEXT_SHORT.items()}
|
| 90 |
+
|
| 91 |
+
# Core tools expected per task type
|
| 92 |
+
_CORE_TOOLS: dict[DesignTaskType, list[str]] = {
|
| 93 |
+
DesignTaskType.DE_NOVO_BINDER: ["rfdiffusion", "proteinmpnn", "alphafold2"],
|
| 94 |
+
DesignTaskType.SEQUENCE_OPTIMIZATION: ["proteinmpnn", "esmfold", "alphafold2"],
|
| 95 |
+
DesignTaskType.DE_NOVO_BACKBONE: ["rfdiffusion", "proteinmpnn", "alphafold2"],
|
| 96 |
+
DesignTaskType.COMPLEX_ENGINEERING: ["rfdiffusion", "proteinmpnn", "alphafold2"],
|
| 97 |
+
DesignTaskType.CONFORMATIONAL_DESIGN: ["esmfold", "proteinmpnn", "alphafold2"],
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
_PRIMARY_METRIC: dict[DesignTaskType, str] = {
|
| 101 |
+
DesignTaskType.DE_NOVO_BINDER: "ipTM",
|
| 102 |
+
DesignTaskType.SEQUENCE_OPTIMIZATION: "pLDDT",
|
| 103 |
+
DesignTaskType.DE_NOVO_BACKBONE: "pLDDT",
|
| 104 |
+
DesignTaskType.COMPLEX_ENGINEERING: "ipTM",
|
| 105 |
+
DesignTaskType.CONFORMATIONAL_DESIGN: "pLDDT",
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass(frozen=True)
|
| 110 |
+
class TaskCategory:
|
| 111 |
+
"""A valid cell in the DesignTaskType × BiologicalContext matrix."""
|
| 112 |
+
|
| 113 |
+
task_type: DesignTaskType
|
| 114 |
+
context: BiologicalContext
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def category_id(self) -> str:
|
| 118 |
+
return f"{self.task_type.short}_{self.context.short}"
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def expected_core_tools(self) -> list[str]:
|
| 122 |
+
return list(_CORE_TOOLS[self.task_type])
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def primary_quality_metric(self) -> str:
|
| 126 |
+
return _PRIMARY_METRIC[self.task_type]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
VALID_CATEGORIES: list[TaskCategory] = [
|
| 130 |
+
# de_novo_binder (4)
|
| 131 |
+
TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ANTIBODY),
|
| 132 |
+
TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ENZYME),
|
| 133 |
+
TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.SIGNALING),
|
| 134 |
+
TaskCategory(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.THERAPEUTIC),
|
| 135 |
+
# sequence_optimization (5)
|
| 136 |
+
TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ANTIBODY),
|
| 137 |
+
TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ENZYME),
|
| 138 |
+
TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.SIGNALING),
|
| 139 |
+
TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.STRUCTURAL),
|
| 140 |
+
TaskCategory(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.FLUORESCENT),
|
| 141 |
+
# de_novo_backbone (1)
|
| 142 |
+
TaskCategory(DesignTaskType.DE_NOVO_BACKBONE, BiologicalContext.STRUCTURAL),
|
| 143 |
+
# complex_engineering (3)
|
| 144 |
+
TaskCategory(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.ENZYME),
|
| 145 |
+
TaskCategory(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.SIGNALING),
|
| 146 |
+
TaskCategory(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.STRUCTURAL),
|
| 147 |
+
# conformational_design (4)
|
| 148 |
+
TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.ENZYME),
|
| 149 |
+
TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.SIGNALING),
|
| 150 |
+
TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.STRUCTURAL),
|
| 151 |
+
TaskCategory(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.FLUORESCENT),
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
_CATEGORY_BY_ID: dict[str, TaskCategory] = {c.category_id: c for c in VALID_CATEGORIES}
|
| 155 |
+
|
| 156 |
+
# OLD → NEW task ID mapping (30 tasks)
|
| 157 |
+
OLD_TO_NEW_MAPPING: dict[str, str] = {
|
| 158 |
+
"binder_001": "dnb_sig_001", "binder_003": "dnb_sig_002",
|
| 159 |
+
"binder_005": "dnb_sig_003", "binder_007": "dnb_sig_004",
|
| 160 |
+
"ppi_004": "dnb_sig_005",
|
| 161 |
+
"binder_002": "dnb_thr_001", "binder_006": "dnb_thr_002",
|
| 162 |
+
"binder_008": "dnb_thr_003", "peptide_001": "dnb_thr_004",
|
| 163 |
+
"peptide_002": "dnb_thr_005", "peptide_003": "dnb_thr_006",
|
| 164 |
+
"antibody_001": "sqo_ab_001", "antibody_002": "sqo_ab_002",
|
| 165 |
+
"antibody_003": "sqo_ab_003", "antibody_004": "sqo_ab_004",
|
| 166 |
+
"antibody_005": "sqo_ab_005",
|
| 167 |
+
"stability_002": "sqo_enz_001", "enzyme_001": "sqo_enz_002",
|
| 168 |
+
"enzyme_002": "sqo_enz_003", "enzyme_003": "sqo_enz_004",
|
| 169 |
+
"stability_003": "sqo_str_001", "stability_004": "sqo_str_002",
|
| 170 |
+
"stability_001": "sqo_flu_001",
|
| 171 |
+
"scaffold_001": "dnk_str_001", "scaffold_002": "dnk_str_002",
|
| 172 |
+
"scaffold_003": "dnk_str_003",
|
| 173 |
+
"ppi_001": "cpx_str_001", "ppi_002": "cpx_str_002",
|
| 174 |
+
"ppi_003": "cfd_sig_001",
|
| 175 |
+
"fluorescence_001": "cfd_flu_001",
|
| 176 |
+
}
|
| 177 |
+
_NEW_TO_OLD_MAPPING: dict[str, str] = {v: k for k, v in OLD_TO_NEW_MAPPING.items()}
|
| 178 |
+
|
| 179 |
+
_NEW_ID_RE = re.compile(r"^([a-z]{2,3})_([a-z]{2,3})_(\d{3})$")
|
| 180 |
+
|
| 181 |
+
_OLD_TYPE_TO_CANONICAL: dict[str, str] = {
|
| 182 |
+
"binder": "de_novo_binder", "antibody": "de_novo_binder",
|
| 183 |
+
"peptide": "de_novo_binder", "stability": "sequence_optimization",
|
| 184 |
+
"enzyme": "sequence_optimization", "fluorescence": "sequence_optimization",
|
| 185 |
+
"scaffold": "de_novo_backbone", "ppi": "complex_engineering",
|
| 186 |
+
}
|
| 187 |
+
_CANONICAL_VALUES = {e.value for e in DesignTaskType}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_category(task_id: str) -> Optional[TaskCategory]:
|
| 191 |
+
"""Get the TaskCategory for a task ID (old or new format)."""
|
| 192 |
+
if task_id in OLD_TO_NEW_MAPPING:
|
| 193 |
+
new_id = OLD_TO_NEW_MAPPING[task_id]
|
| 194 |
+
cat_id = new_id.rsplit("_", 1)[0]
|
| 195 |
+
return _CATEGORY_BY_ID.get(cat_id)
|
| 196 |
+
m = _NEW_ID_RE.match(task_id)
|
| 197 |
+
if m:
|
| 198 |
+
cat_id = f"{m.group(1)}_{m.group(2)}"
|
| 199 |
+
return _CATEGORY_BY_ID.get(cat_id)
|
| 200 |
+
return None
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def get_new_task_id(old_task_id: str) -> Optional[str]:
|
| 204 |
+
return OLD_TO_NEW_MAPPING.get(old_task_id)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_old_task_id(new_task_id: str) -> Optional[str]:
|
| 208 |
+
return _NEW_TO_OLD_MAPPING.get(new_task_id)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def is_valid_category(task_type: DesignTaskType, context: BiologicalContext) -> bool:
|
| 212 |
+
cat_id = f"{task_type.short}_{context.short}"
|
| 213 |
+
return cat_id in _CATEGORY_BY_ID
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def parse_new_task_id(
|
| 217 |
+
task_id: str,
|
| 218 |
+
) -> Optional[tuple[DesignTaskType, BiologicalContext, int]]:
|
| 219 |
+
m = _NEW_ID_RE.match(task_id)
|
| 220 |
+
if not m:
|
| 221 |
+
return None
|
| 222 |
+
task_short, ctx_short, num_str = m.group(1), m.group(2), m.group(3)
|
| 223 |
+
task_type = _SHORT_TO_TASK_TYPE.get(task_short)
|
| 224 |
+
context = _SHORT_TO_CONTEXT.get(ctx_short)
|
| 225 |
+
if task_type is None or context is None:
|
| 226 |
+
return None
|
| 227 |
+
if not is_valid_category(task_type, context):
|
| 228 |
+
return None
|
| 229 |
+
return task_type, context, int(num_str)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def normalize_task_type(task_type: str) -> str:
|
| 233 |
+
lower = task_type.lower().strip()
|
| 234 |
+
if lower in _CANONICAL_VALUES:
|
| 235 |
+
return lower
|
| 236 |
+
return _OLD_TYPE_TO_CANONICAL.get(lower, task_type)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 240 |
+
# SECTION 2 — Sequence Metrics (from biodesignbench/eval/metrics/sequence.py)
|
| 241 |
+
# ════════════════════════════════════════════════════════════════��══════════════
|
| 242 |
+
|
| 243 |
+
_KD_SCALE: dict[str, float] = {
|
| 244 |
+
"A": 1.8, "C": 2.5, "D": -3.5, "E": -3.5, "F": 2.8,
|
| 245 |
+
"G": -0.4, "H": -3.2, "I": 4.5, "K": -3.9, "L": 3.8,
|
| 246 |
+
"M": 1.9, "N": -3.5, "P": -1.6, "Q": -3.5, "R": -4.5,
|
| 247 |
+
"S": -0.8, "T": -0.7, "V": 4.2, "W": -0.9, "Y": -1.3,
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
STANDARD_AAS = set("ACDEFGHIKLMNPQRSTVWY")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def sequence_identity(seq1: str, seq2: str) -> float:
|
| 254 |
+
"""Compute fractional sequence identity between two sequences."""
|
| 255 |
+
if not seq1 or not seq2:
|
| 256 |
+
return 0.0
|
| 257 |
+
s1, s2 = seq1.upper(), seq2.upper()
|
| 258 |
+
if len(s1) == len(s2):
|
| 259 |
+
return sum(a == b for a, b in zip(s1, s2)) / len(s1)
|
| 260 |
+
short, long = (s1, s2) if len(s1) <= len(s2) else (s2, s1)
|
| 261 |
+
best = 0.0
|
| 262 |
+
for offset in range(len(long) - len(short) + 1):
|
| 263 |
+
matches = sum(a == b for a, b in zip(short, long[offset:offset + len(short)]))
|
| 264 |
+
identity = matches / len(short)
|
| 265 |
+
if identity > best:
|
| 266 |
+
best = identity
|
| 267 |
+
return best
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def max_identity_to_reference(designs: list[str], reference: str) -> float:
|
| 271 |
+
if not designs or not reference:
|
| 272 |
+
return 0.0
|
| 273 |
+
return max(sequence_identity(d, reference) for d in designs)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def mean_pairwise_diversity(sequences: list[str]) -> float:
|
| 277 |
+
if len(sequences) < 2:
|
| 278 |
+
return 0.0
|
| 279 |
+
total = 0.0
|
| 280 |
+
count = 0
|
| 281 |
+
for s1, s2 in combinations(sequences, 2):
|
| 282 |
+
total += 1.0 - sequence_identity(s1, s2)
|
| 283 |
+
count += 1
|
| 284 |
+
return total / count if count > 0 else 0.0
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def sequence_entropy(sequences: list[str], truncate: bool = False) -> float:
|
| 288 |
+
if len(sequences) < 2:
|
| 289 |
+
return 0.0
|
| 290 |
+
lengths = {len(s) for s in sequences}
|
| 291 |
+
if len(lengths) != 1:
|
| 292 |
+
if not truncate:
|
| 293 |
+
return 0.0
|
| 294 |
+
seq_len = min(lengths)
|
| 295 |
+
sequences = [s[:seq_len] for s in sequences]
|
| 296 |
+
else:
|
| 297 |
+
seq_len = lengths.pop()
|
| 298 |
+
if seq_len == 0:
|
| 299 |
+
return 0.0
|
| 300 |
+
n = len(sequences)
|
| 301 |
+
total_entropy = 0.0
|
| 302 |
+
for pos in range(seq_len):
|
| 303 |
+
counts: dict[str, int] = {}
|
| 304 |
+
for seq in sequences:
|
| 305 |
+
aa = seq[pos].upper()
|
| 306 |
+
counts[aa] = counts.get(aa, 0) + 1
|
| 307 |
+
pos_entropy = 0.0
|
| 308 |
+
for count in counts.values():
|
| 309 |
+
if count > 0:
|
| 310 |
+
p = count / n
|
| 311 |
+
pos_entropy -= p * math.log(p)
|
| 312 |
+
total_entropy += pos_entropy / math.log(20)
|
| 313 |
+
return total_entropy / seq_len
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def validate_amino_acids(sequence: str) -> dict:
|
| 317 |
+
if not sequence or not sequence.strip():
|
| 318 |
+
return {"valid": False, "invalid_chars": set(), "fraction_valid": 0.0}
|
| 319 |
+
upper = sequence.upper()
|
| 320 |
+
chars = set(upper)
|
| 321 |
+
invalid = chars - STANDARD_AAS
|
| 322 |
+
valid_count = sum(1 for c in upper if c in STANDARD_AAS)
|
| 323 |
+
return {
|
| 324 |
+
"valid": len(invalid) == 0,
|
| 325 |
+
"invalid_chars": invalid,
|
| 326 |
+
"fraction_valid": valid_count / len(upper),
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def check_length_constraints(
|
| 331 |
+
sequence: str,
|
| 332 |
+
length_range: tuple[int, int] | None,
|
| 333 |
+
) -> dict:
|
| 334 |
+
length = len(sequence)
|
| 335 |
+
if length_range is None:
|
| 336 |
+
return {"length": length, "within_range": True, "range": None}
|
| 337 |
+
min_len, max_len = length_range
|
| 338 |
+
return {
|
| 339 |
+
"length": length,
|
| 340 |
+
"within_range": min_len <= length <= max_len,
|
| 341 |
+
"range": length_range,
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def hydrophobicity_profile(sequence: str) -> dict:
|
| 346 |
+
if not sequence:
|
| 347 |
+
return {"mean": 0.0, "std": 0.0, "fraction_hydrophobic": 0.0, "min": 0.0, "max": 0.0}
|
| 348 |
+
values = [_KD_SCALE.get(aa.upper(), 0.0) for aa in sequence]
|
| 349 |
+
n = len(values)
|
| 350 |
+
mean = sum(values) / n
|
| 351 |
+
variance = sum((v - mean) ** 2 for v in values) / n
|
| 352 |
+
std = math.sqrt(variance)
|
| 353 |
+
hydrophobic_count = sum(1 for v in values if v > 0)
|
| 354 |
+
return {
|
| 355 |
+
"mean": round(mean, 3),
|
| 356 |
+
"std": round(std, 3),
|
| 357 |
+
"fraction_hydrophobic": round(hydrophobic_count / n, 3),
|
| 358 |
+
"min": round(min(values), 3),
|
| 359 |
+
"max": round(max(values), 3),
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def count_mutations(wt: str, designed: str) -> int:
|
| 364 |
+
if len(wt) != len(designed):
|
| 365 |
+
return -1
|
| 366 |
+
return sum(a != b for a, b in zip(wt.upper(), designed.upper()))
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 370 |
+
# SECTION 3 — Approach Scoring (from biodesignbench/eval/metrics/approach.py)
|
| 371 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class DesignFunction(str, Enum):
|
| 375 |
+
"""Functional capabilities that tools provide."""
|
| 376 |
+
|
| 377 |
+
BACKBONE_GENERATION = "backbone_generation"
|
| 378 |
+
SEQUENCE_DESIGN = "sequence_design"
|
| 379 |
+
STRUCTURE_PREDICTION = "structure_prediction"
|
| 380 |
+
COMPLEX_PREDICTION = "complex_prediction"
|
| 381 |
+
INTERFACE_ANALYSIS = "interface_analysis"
|
| 382 |
+
STABILITY_SCORING = "stability_scoring"
|
| 383 |
+
ENERGY_MINIMIZATION = "energy_minimization"
|
| 384 |
+
HOTSPOT_IDENTIFICATION = "hotspot_identification"
|
| 385 |
+
SEQUENCE_SCORING = "sequence_scoring"
|
| 386 |
+
PHYSICS_VALIDATION = "physics_validation"
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
TOOL_CATEGORIES: dict[str, str] = {
|
| 390 |
+
"alphafold2": "structure_prediction", "alphafold": "structure_prediction",
|
| 391 |
+
"af2": "structure_prediction", "esmfold": "structure_prediction",
|
| 392 |
+
"openfold": "structure_prediction", "boltz": "structure_prediction",
|
| 393 |
+
"colabfold": "structure_prediction", "omegafold": "structure_prediction",
|
| 394 |
+
"rosettafold": "structure_prediction",
|
| 395 |
+
"proteinmpnn": "sequence_design", "mpnn": "sequence_design",
|
| 396 |
+
"esm_if": "sequence_design", "ligandmpnn": "sequence_design",
|
| 397 |
+
"rfdiffusion": "backbone_generation", "rfdiff": "backbone_generation",
|
| 398 |
+
"chroma": "backbone_generation", "framediff": "backbone_generation",
|
| 399 |
+
"foldingdiff": "backbone_generation",
|
| 400 |
+
"rosetta": "energy_optimization", "pyrosetta": "energy_optimization",
|
| 401 |
+
"foldx": "energy_optimization", "openmm": "energy_optimization",
|
| 402 |
+
"amber": "energy_optimization", "esm2": "energy_optimization",
|
| 403 |
+
"foldseek": "structure_search", "dali": "structure_search",
|
| 404 |
+
"tmalign": "structure_search",
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
MCP_TOOL_EXPANSION: dict[str, list[str]] = {
|
| 408 |
+
"design_binder": ["rfdiffusion", "proteinmpnn", "esmfold"],
|
| 409 |
+
"validate_design": ["esmfold", "alphafold2"],
|
| 410 |
+
"optimize_sequence": ["proteinmpnn"],
|
| 411 |
+
"predict_complex": ["alphafold2"],
|
| 412 |
+
"analyze_interface": ["pyrosetta"],
|
| 413 |
+
"predict_structure": ["esmfold", "alphafold2"],
|
| 414 |
+
"score_stability": ["esm2"],
|
| 415 |
+
"energy_minimize": ["openmm"],
|
| 416 |
+
"suggest_hotspots": [],
|
| 417 |
+
"get_design_status": [],
|
| 418 |
+
"generate_backbone": ["rfdiffusion"],
|
| 419 |
+
"rosetta_score": ["pyrosetta"],
|
| 420 |
+
"rosetta_relax": ["pyrosetta"],
|
| 421 |
+
"rosetta_interface_score": ["pyrosetta"],
|
| 422 |
+
"rosetta_design": ["pyrosetta"],
|
| 423 |
+
"predict_structure_boltz": ["boltz"],
|
| 424 |
+
"predict_affinity_boltz": ["boltz"],
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
TOOL_TO_FUNCTION: dict[str, set[DesignFunction]] = {
|
| 428 |
+
# MCP wrappers
|
| 429 |
+
"design_binder": {DesignFunction.BACKBONE_GENERATION, DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
|
| 430 |
+
"validate_design": {DesignFunction.STRUCTURE_PREDICTION},
|
| 431 |
+
"optimize_sequence": {DesignFunction.SEQUENCE_DESIGN},
|
| 432 |
+
"predict_complex": {DesignFunction.COMPLEX_PREDICTION, DesignFunction.STRUCTURE_PREDICTION},
|
| 433 |
+
"analyze_interface": {DesignFunction.INTERFACE_ANALYSIS},
|
| 434 |
+
"predict_structure": {DesignFunction.STRUCTURE_PREDICTION},
|
| 435 |
+
"score_stability": {DesignFunction.STABILITY_SCORING},
|
| 436 |
+
"energy_minimize": {DesignFunction.ENERGY_MINIMIZATION},
|
| 437 |
+
"suggest_hotspots": {DesignFunction.HOTSPOT_IDENTIFICATION},
|
| 438 |
+
"get_design_status": set(),
|
| 439 |
+
"generate_backbone": {DesignFunction.BACKBONE_GENERATION},
|
| 440 |
+
"rosetta_score": {DesignFunction.PHYSICS_VALIDATION},
|
| 441 |
+
"rosetta_relax": {DesignFunction.ENERGY_MINIMIZATION},
|
| 442 |
+
"rosetta_interface_score": {DesignFunction.INTERFACE_ANALYSIS},
|
| 443 |
+
"rosetta_design": {DesignFunction.SEQUENCE_DESIGN},
|
| 444 |
+
"predict_structure_boltz": {DesignFunction.STRUCTURE_PREDICTION},
|
| 445 |
+
"predict_affinity_boltz": {DesignFunction.COMPLEX_PREDICTION, DesignFunction.INTERFACE_ANALYSIS},
|
| 446 |
+
# Bio-level tools
|
| 447 |
+
"rfdiffusion": {DesignFunction.BACKBONE_GENERATION},
|
| 448 |
+
"proteinmpnn": {DesignFunction.SEQUENCE_DESIGN},
|
| 449 |
+
"alphafold2": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
|
| 450 |
+
"alphafold": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
|
| 451 |
+
"esmfold": {DesignFunction.STRUCTURE_PREDICTION},
|
| 452 |
+
"esm2": {DesignFunction.STABILITY_SCORING, DesignFunction.SEQUENCE_SCORING},
|
| 453 |
+
"pyrosetta": {DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION, DesignFunction.INTERFACE_ANALYSIS},
|
| 454 |
+
"rosetta": {DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION, DesignFunction.INTERFACE_ANALYSIS},
|
| 455 |
+
"openmm": {DesignFunction.ENERGY_MINIMIZATION},
|
| 456 |
+
"boltz": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
|
| 457 |
+
"foldx": {DesignFunction.STABILITY_SCORING, DesignFunction.PHYSICS_VALIDATION},
|
| 458 |
+
"colabfold": {DesignFunction.STRUCTURE_PREDICTION, DesignFunction.COMPLEX_PREDICTION},
|
| 459 |
+
"foldseek": {DesignFunction.STRUCTURE_PREDICTION},
|
| 460 |
+
"chroma": {DesignFunction.BACKBONE_GENERATION},
|
| 461 |
+
"ligandmpnn": {DesignFunction.SEQUENCE_DESIGN},
|
| 462 |
+
"esm_if": {DesignFunction.SEQUENCE_DESIGN},
|
| 463 |
+
"mpnn": {DesignFunction.SEQUENCE_DESIGN},
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
class _TaskTypeDict(dict):
|
| 468 |
+
"""Dict that accepts both DesignTaskType enum and string keys."""
|
| 469 |
+
|
| 470 |
+
def __init__(self, raw: dict[str, set[DesignFunction]]):
|
| 471 |
+
super().__init__()
|
| 472 |
+
self._raw = raw
|
| 473 |
+
for k, v in raw.items():
|
| 474 |
+
super().__setitem__(k, v)
|
| 475 |
+
|
| 476 |
+
def __contains__(self, key):
|
| 477 |
+
k = key.value if hasattr(key, "value") else key
|
| 478 |
+
return super().__contains__(k)
|
| 479 |
+
|
| 480 |
+
def __getitem__(self, key):
|
| 481 |
+
k = key.value if hasattr(key, "value") else key
|
| 482 |
+
return super().__getitem__(k)
|
| 483 |
+
|
| 484 |
+
def get(self, key, default=None):
|
| 485 |
+
k = key.value if hasattr(key, "value") else key
|
| 486 |
+
return super().get(k, default)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
REQUIRED_FUNCTIONS = _TaskTypeDict({
|
| 490 |
+
"de_novo_binder": {DesignFunction.BACKBONE_GENERATION, DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
|
| 491 |
+
"sequence_optimization": {DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
|
| 492 |
+
"de_novo_backbone": {DesignFunction.BACKBONE_GENERATION, DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
|
| 493 |
+
"complex_engineering": {DesignFunction.SEQUENCE_DESIGN, DesignFunction.COMPLEX_PREDICTION},
|
| 494 |
+
"conformational_design": {DesignFunction.SEQUENCE_DESIGN, DesignFunction.STRUCTURE_PREDICTION},
|
| 495 |
+
})
|
| 496 |
+
|
| 497 |
+
BONUS_FUNCTIONS = _TaskTypeDict({
|
| 498 |
+
"de_novo_binder": {DesignFunction.COMPLEX_PREDICTION, DesignFunction.INTERFACE_ANALYSIS, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.HOTSPOT_IDENTIFICATION},
|
| 499 |
+
"sequence_optimization": {DesignFunction.STABILITY_SCORING, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION},
|
| 500 |
+
"de_novo_backbone": {DesignFunction.ENERGY_MINIMIZATION, DesignFunction.PHYSICS_VALIDATION},
|
| 501 |
+
"complex_engineering": {DesignFunction.BACKBONE_GENERATION, DesignFunction.INTERFACE_ANALYSIS, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.STRUCTURE_PREDICTION},
|
| 502 |
+
"conformational_design": {DesignFunction.STABILITY_SCORING, DesignFunction.ENERGY_MINIMIZATION, DesignFunction.COMPLEX_PREDICTION},
|
| 503 |
+
})
|
| 504 |
+
|
| 505 |
+
_GENERATION_TOOLS: set[str] = {
|
| 506 |
+
"rfdiffusion", "proteinmpnn", "design_binder", "optimize_sequence",
|
| 507 |
+
"generate_backbone", "rosetta_design", "chroma", "ligandmpnn",
|
| 508 |
+
"esm_if", "mpnn",
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
_VALIDATION_TOOLS: set[str] = {
|
| 512 |
+
"esmfold", "alphafold2", "validate_design", "predict_structure",
|
| 513 |
+
"predict_complex", "score_stability", "rosetta_score",
|
| 514 |
+
"rosetta_interface_score", "predict_structure_boltz",
|
| 515 |
+
"predict_affinity_boltz", "analyze_interface",
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
_REFINEMENT_TOOLS: set[str] = {
|
| 519 |
+
"energy_minimize", "rosetta_relax", "openmm", "pyrosetta", "rosetta",
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def expand_mcp_tools(tools: list[str]) -> list[str]:
|
| 524 |
+
"""Expand MCP wrapper tool names to their underlying bio tools."""
|
| 525 |
+
seen: set[str] = set()
|
| 526 |
+
expanded: list[str] = []
|
| 527 |
+
for tool in tools:
|
| 528 |
+
if tool in MCP_TOOL_EXPANSION:
|
| 529 |
+
underlying = MCP_TOOL_EXPANSION[tool]
|
| 530 |
+
if not underlying:
|
| 531 |
+
if tool not in seen:
|
| 532 |
+
expanded.append(tool)
|
| 533 |
+
seen.add(tool)
|
| 534 |
+
else:
|
| 535 |
+
for ut in underlying:
|
| 536 |
+
if ut not in seen:
|
| 537 |
+
expanded.append(ut)
|
| 538 |
+
seen.add(ut)
|
| 539 |
+
else:
|
| 540 |
+
if tool not in seen:
|
| 541 |
+
expanded.append(tool)
|
| 542 |
+
seen.add(tool)
|
| 543 |
+
return expanded
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def normalize_tool_name(tool: str) -> str:
|
| 547 |
+
return tool.lower().strip().replace(" ", "").replace("-", "").replace("_", "")
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def get_tool_category(tool: str) -> str | None:
|
| 551 |
+
normalized = normalize_tool_name(tool)
|
| 552 |
+
for name, category in TOOL_CATEGORIES.items():
|
| 553 |
+
if normalize_tool_name(name) == normalized:
|
| 554 |
+
return category
|
| 555 |
+
return None
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def _extract_functions_from_tools(tools: list[str]) -> set[DesignFunction]:
|
| 559 |
+
functions: set[DesignFunction] = set()
|
| 560 |
+
for tool in tools:
|
| 561 |
+
if tool in TOOL_TO_FUNCTION:
|
| 562 |
+
functions.update(TOOL_TO_FUNCTION[tool])
|
| 563 |
+
else:
|
| 564 |
+
norm = normalize_tool_name(tool)
|
| 565 |
+
for known, funcs in TOOL_TO_FUNCTION.items():
|
| 566 |
+
if normalize_tool_name(known) == norm:
|
| 567 |
+
functions.update(funcs)
|
| 568 |
+
break
|
| 569 |
+
return functions
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
def _check_validation(tools_used: list[str]) -> float:
|
| 573 |
+
if not tools_used:
|
| 574 |
+
return 0.0
|
| 575 |
+
has_generation = False
|
| 576 |
+
has_validation_after_generation = False
|
| 577 |
+
has_any_validation = False
|
| 578 |
+
for tool in tools_used:
|
| 579 |
+
if tool in _GENERATION_TOOLS:
|
| 580 |
+
has_generation = True
|
| 581 |
+
if tool in _VALIDATION_TOOLS:
|
| 582 |
+
has_any_validation = True
|
| 583 |
+
if has_generation:
|
| 584 |
+
has_validation_after_generation = True
|
| 585 |
+
if has_validation_after_generation:
|
| 586 |
+
return 4.0
|
| 587 |
+
if has_any_validation:
|
| 588 |
+
return 2.0
|
| 589 |
+
return 0.0
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def _check_refinement(tools_used: list[str]) -> float:
|
| 593 |
+
if not tools_used:
|
| 594 |
+
return 0.0
|
| 595 |
+
for tool in tools_used:
|
| 596 |
+
if tool in _REFINEMENT_TOOLS:
|
| 597 |
+
return 4.0
|
| 598 |
+
counts = Counter(tools_used)
|
| 599 |
+
for tool, count in counts.items():
|
| 600 |
+
if count >= 2 and (tool in _GENERATION_TOOLS or tool in _VALIDATION_TOOLS):
|
| 601 |
+
return 4.0
|
| 602 |
+
return 0.0
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def _score_approach_legacy(
|
| 606 |
+
tools_used: list[str],
|
| 607 |
+
tools_expected: list[str],
|
| 608 |
+
max_points: int = 20,
|
| 609 |
+
) -> dict:
|
| 610 |
+
if not tools_expected:
|
| 611 |
+
return {
|
| 612 |
+
"score": max_points, "max": max_points,
|
| 613 |
+
"breakdown": [], "tools_matched": [], "tools_missing": [],
|
| 614 |
+
"mode": "legacy",
|
| 615 |
+
}
|
| 616 |
+
expanded_used = expand_mcp_tools(tools_used)
|
| 617 |
+
per_tool = max_points / len(tools_expected)
|
| 618 |
+
used_normalized = [normalize_tool_name(t) for t in expanded_used]
|
| 619 |
+
used_categories = [get_tool_category(t) for t in expanded_used]
|
| 620 |
+
total = 0.0
|
| 621 |
+
breakdown = []
|
| 622 |
+
matched = []
|
| 623 |
+
missing = []
|
| 624 |
+
for expected in tools_expected:
|
| 625 |
+
expected_norm = normalize_tool_name(expected)
|
| 626 |
+
expected_cat = get_tool_category(expected)
|
| 627 |
+
if expected_norm in used_normalized:
|
| 628 |
+
total += per_tool
|
| 629 |
+
breakdown.append({"tool": expected, "match": "exact", "points": per_tool})
|
| 630 |
+
matched.append(expected)
|
| 631 |
+
elif expected_cat and expected_cat in used_categories:
|
| 632 |
+
points = per_tool * 0.7
|
| 633 |
+
total += points
|
| 634 |
+
breakdown.append({"tool": expected, "match": "category", "points": points})
|
| 635 |
+
matched.append(expected)
|
| 636 |
+
else:
|
| 637 |
+
breakdown.append({"tool": expected, "match": "none", "points": 0})
|
| 638 |
+
missing.append(expected)
|
| 639 |
+
return {
|
| 640 |
+
"score": int(round(total)), "max": max_points,
|
| 641 |
+
"breakdown": breakdown, "tools_matched": matched,
|
| 642 |
+
"tools_missing": missing, "mode": "legacy",
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def score_approach(
|
| 647 |
+
tools_used: list[str],
|
| 648 |
+
tools_expected: list[str],
|
| 649 |
+
max_points: int = 20,
|
| 650 |
+
task_type: DesignTaskType | str | None = None,
|
| 651 |
+
) -> dict:
|
| 652 |
+
"""Score the agent's tool/methodology selection."""
|
| 653 |
+
if task_type is None:
|
| 654 |
+
return _score_approach_legacy(tools_used, tools_expected, max_points)
|
| 655 |
+
|
| 656 |
+
tt_key = task_type.value if hasattr(task_type, "value") else str(task_type)
|
| 657 |
+
scale = max_points / 20.0
|
| 658 |
+
func_max = 12.0 * scale
|
| 659 |
+
|
| 660 |
+
agent_functions = _extract_functions_from_tools(tools_used)
|
| 661 |
+
required = REQUIRED_FUNCTIONS.get(tt_key, set())
|
| 662 |
+
bonus = BONUS_FUNCTIONS.get(tt_key, set())
|
| 663 |
+
|
| 664 |
+
if required:
|
| 665 |
+
covered_required = agent_functions & required
|
| 666 |
+
required_ratio = len(covered_required) / len(required)
|
| 667 |
+
else:
|
| 668 |
+
required_ratio = 1.0 if agent_functions else 0.0
|
| 669 |
+
covered_required = set()
|
| 670 |
+
|
| 671 |
+
covered_bonus = agent_functions & bonus
|
| 672 |
+
bonus_count = min(len(covered_bonus), 3)
|
| 673 |
+
func_score = (required_ratio * 9.0 + bonus_count * 1.0) * scale
|
| 674 |
+
func_score = min(func_score, func_max)
|
| 675 |
+
|
| 676 |
+
val_score = _check_validation(tools_used) * scale
|
| 677 |
+
ref_score = _check_refinement(tools_used) * scale
|
| 678 |
+
|
| 679 |
+
total = min(func_score + val_score + ref_score, float(max_points))
|
| 680 |
+
|
| 681 |
+
return {
|
| 682 |
+
"score": int(round(total)), "max": max_points, "mode": "function",
|
| 683 |
+
"function_coverage": round(func_score, 1),
|
| 684 |
+
"validation_inclusion": round(val_score, 1),
|
| 685 |
+
"iterative_refinement": round(ref_score, 1),
|
| 686 |
+
"required_functions": sorted(f.value for f in required),
|
| 687 |
+
"covered_required": sorted(f.value for f in covered_required),
|
| 688 |
+
"covered_bonus": sorted(f.value for f in covered_bonus),
|
| 689 |
+
"agent_functions": sorted(f.value for f in agent_functions),
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 694 |
+
# SECTION 4 — Orchestration Scoring (from biodesignbench/eval/metrics/orchestration.py)
|
| 695 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 696 |
+
|
| 697 |
+
EXPECTED_PIPELINES: dict[str, list[str]] = {
|
| 698 |
+
"de_novo_binder": ["rfdiffusion", "proteinmpnn", "esmfold"],
|
| 699 |
+
"sequence_optimization": ["proteinmpnn", "esmfold"],
|
| 700 |
+
"de_novo_backbone": ["rfdiffusion", "proteinmpnn", "esmfold"],
|
| 701 |
+
"complex_engineering": ["rfdiffusion", "proteinmpnn", "esmfold"],
|
| 702 |
+
"conformational_design": ["proteinmpnn", "esmfold"],
|
| 703 |
+
# Old category names (backward compat)
|
| 704 |
+
"binder": ["rfdiffusion", "proteinmpnn", "esmfold"],
|
| 705 |
+
"antibody": ["proteinmpnn", "esmfold"],
|
| 706 |
+
"stability": ["proteinmpnn", "esmfold"],
|
| 707 |
+
"enzyme": ["rfdiffusion", "proteinmpnn", "esmfold"],
|
| 708 |
+
}
|
| 709 |
+
|
| 710 |
+
ORCHESTRATION_VALIDATION_TOOLS: set[str] = {
|
| 711 |
+
"validate_design", "predict_complex", "analyze_interface",
|
| 712 |
+
"esmfold", "score_stability", "rosetta_score",
|
| 713 |
+
"rosetta_interface_score", "predict_structure_boltz",
|
| 714 |
+
"predict_affinity_boltz",
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def _expand_tool_name(tool: str) -> list[str]:
|
| 719 |
+
if tool in MCP_TOOL_EXPANSION:
|
| 720 |
+
underlying = MCP_TOOL_EXPANSION[tool]
|
| 721 |
+
return underlying if underlying else [tool]
|
| 722 |
+
return [tool]
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def _extract_ordered_bio_tools(tool_call_log: list[dict[str, Any]]) -> list[str]:
|
| 726 |
+
utility_tools = {"execute_python", "read_file", "write_file"}
|
| 727 |
+
ordered: list[str] = []
|
| 728 |
+
for entry in tool_call_log:
|
| 729 |
+
tool = entry.get("tool", "")
|
| 730 |
+
if tool in utility_tools:
|
| 731 |
+
continue
|
| 732 |
+
expanded = _expand_tool_name(tool)
|
| 733 |
+
for t in expanded:
|
| 734 |
+
ordered.append(normalize_tool_name(t))
|
| 735 |
+
return ordered
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
def _longest_ordered_subsequence_length(
|
| 739 |
+
actual: list[str], expected: list[str]
|
| 740 |
+
) -> int:
|
| 741 |
+
if not expected or not actual:
|
| 742 |
+
return 0
|
| 743 |
+
j = 0
|
| 744 |
+
matched = 0
|
| 745 |
+
for tool in actual:
|
| 746 |
+
k = j
|
| 747 |
+
while k < len(expected):
|
| 748 |
+
if tool == normalize_tool_name(expected[k]):
|
| 749 |
+
matched += 1
|
| 750 |
+
j = k + 1
|
| 751 |
+
break
|
| 752 |
+
k += 1
|
| 753 |
+
return matched
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
def _count_validation_steps(tool_call_log: list[dict[str, Any]]) -> int:
|
| 757 |
+
count = 0
|
| 758 |
+
for entry in tool_call_log:
|
| 759 |
+
tool = entry.get("tool", "")
|
| 760 |
+
if tool in ORCHESTRATION_VALIDATION_TOOLS:
|
| 761 |
+
count += 1
|
| 762 |
+
expanded = _expand_tool_name(tool)
|
| 763 |
+
for t in expanded:
|
| 764 |
+
if t in ORCHESTRATION_VALIDATION_TOOLS and tool not in ORCHESTRATION_VALIDATION_TOOLS:
|
| 765 |
+
count += 1
|
| 766 |
+
return count
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
def _has_adaptive_behavior(tool_call_log: list[dict[str, Any]]) -> bool:
|
| 770 |
+
tool_args: dict[str, list[dict]] = {}
|
| 771 |
+
for entry in tool_call_log:
|
| 772 |
+
tool = entry.get("tool", "")
|
| 773 |
+
args = entry.get("args_summary", {})
|
| 774 |
+
if tool not in tool_args:
|
| 775 |
+
tool_args[tool] = []
|
| 776 |
+
tool_args[tool].append(args)
|
| 777 |
+
for tool, args_list in tool_args.items():
|
| 778 |
+
if len(args_list) >= 2:
|
| 779 |
+
for i in range(1, len(args_list)):
|
| 780 |
+
if args_list[i] != args_list[i - 1]:
|
| 781 |
+
return True
|
| 782 |
+
return False
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def _get_task_category_for_orchestration(task_id: str) -> str | None:
|
| 786 |
+
"""Extract category from task_id using taxonomy, with legacy fallback."""
|
| 787 |
+
category = get_category(task_id)
|
| 788 |
+
if category is not None:
|
| 789 |
+
return category.task_type.value
|
| 790 |
+
for cat in ("binder", "antibody", "stability", "enzyme"):
|
| 791 |
+
if task_id.startswith(cat):
|
| 792 |
+
return cat
|
| 793 |
+
return None
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
def score_orchestration(
|
| 797 |
+
tool_call_log: list[dict[str, Any]],
|
| 798 |
+
task_id: str,
|
| 799 |
+
max_points: int = 15,
|
| 800 |
+
) -> dict[str, Any]:
|
| 801 |
+
"""Score the agent's multi-step pipeline orchestration."""
|
| 802 |
+
if not tool_call_log:
|
| 803 |
+
return {
|
| 804 |
+
"score": 0, "max": max_points,
|
| 805 |
+
"pipeline_order_score": 0.0, "validation_score": 0.0,
|
| 806 |
+
"adaptive_score": 0.0, "details": "No tool calls recorded",
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
category = _get_task_category_for_orchestration(task_id)
|
| 810 |
+
expected_pipeline = EXPECTED_PIPELINES.get(category, [])
|
| 811 |
+
|
| 812 |
+
ordered_tools = _extract_ordered_bio_tools(tool_call_log)
|
| 813 |
+
if expected_pipeline:
|
| 814 |
+
matched = _longest_ordered_subsequence_length(ordered_tools, expected_pipeline)
|
| 815 |
+
order_ratio = matched / len(expected_pipeline)
|
| 816 |
+
else:
|
| 817 |
+
order_ratio = 1.0 if ordered_tools else 0.0
|
| 818 |
+
|
| 819 |
+
pipeline_points = order_ratio * max_points * 0.5
|
| 820 |
+
|
| 821 |
+
validation_count = _count_validation_steps(tool_call_log)
|
| 822 |
+
if validation_count >= 2:
|
| 823 |
+
validation_ratio = 1.0
|
| 824 |
+
elif validation_count == 1:
|
| 825 |
+
validation_ratio = 0.6
|
| 826 |
+
else:
|
| 827 |
+
validation_ratio = 0.0
|
| 828 |
+
validation_points = validation_ratio * max_points * 0.3
|
| 829 |
+
|
| 830 |
+
adaptive = _has_adaptive_behavior(tool_call_log)
|
| 831 |
+
adaptive_points = max_points * 0.2 if adaptive else 0.0
|
| 832 |
+
|
| 833 |
+
total = int(round(pipeline_points + validation_points + adaptive_points))
|
| 834 |
+
|
| 835 |
+
return {
|
| 836 |
+
"score": min(total, max_points), "max": max_points,
|
| 837 |
+
"pipeline_order_score": round(pipeline_points, 1),
|
| 838 |
+
"validation_score": round(validation_points, 1),
|
| 839 |
+
"adaptive_score": round(adaptive_points, 1),
|
| 840 |
+
"expected_pipeline": expected_pipeline,
|
| 841 |
+
"actual_tool_order": ordered_tools,
|
| 842 |
+
"validation_steps": validation_count,
|
| 843 |
+
"adaptive_behavior": adaptive,
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 848 |
+
# SECTION 5 — Quality + Scoring (from biodesignbench/eval/tier2/scoring.py)
|
| 849 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 850 |
+
|
| 851 |
+
DEFAULT_DESIGN_RUBRIC = {
|
| 852 |
+
"approach": 20, "orchestration": 15, "quality": 35,
|
| 853 |
+
"feasibility": 15, "novelty": 5, "diversity": 10,
|
| 854 |
+
}
|
| 855 |
+
|
| 856 |
+
METRIC_RANGES: dict[str, tuple[float, float]] = {
|
| 857 |
+
"pLDDT": (0, 100), "pTM": (0, 1), "ipTM": (0, 1),
|
| 858 |
+
"i_pAE": (0, 50), "predicted_kd": (0, 1e6),
|
| 859 |
+
"predicted_ddG": (-100, 100), "active_site_rmsd": (0, 50),
|
| 860 |
+
"max_sequence_identity": (0, 1), "TM_score": (0, 1),
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
THRESHOLD_TO_METRIC: dict[str, tuple[str, str]] = {
|
| 864 |
+
"pLDDT_good": ("pLDDT", "higher_is_better"),
|
| 865 |
+
"ipTM_good": ("ipTM", "higher_is_better"),
|
| 866 |
+
"kd_nM_good": ("predicted_kd", "lower_is_better"),
|
| 867 |
+
"predicted_ddG_good": ("predicted_ddG", "lower_is_better"),
|
| 868 |
+
"active_site_rmsd_good": ("active_site_rmsd", "lower_is_better"),
|
| 869 |
+
}
|
| 870 |
+
|
| 871 |
+
# Tier A: Structure Confidence
|
| 872 |
+
_TIER_A_THRESHOLDS: dict[str, dict[str, float]] = {
|
| 873 |
+
"pLDDT": {"pass": 65, "good": 80, "excellent": 90},
|
| 874 |
+
"pTM": {"pass": 0.45, "good": 0.65, "excellent": 0.80},
|
| 875 |
+
}
|
| 876 |
+
|
| 877 |
+
# Tier B: Interface Confidence (binding only)
|
| 878 |
+
_TIER_B_THRESHOLDS: dict[str, dict[str, float]] = {
|
| 879 |
+
"ipTM": {"pass": 0.15, "good": 0.40, "excellent": 0.70},
|
| 880 |
+
"i_pAE": {"pass": 25.0, "good": 15.0, "excellent": 8.0},
|
| 881 |
+
}
|
| 882 |
+
_TIER_B_DIRECTIONS: dict[str, str] = {"i_pAE": "lower_is_better"}
|
| 883 |
+
|
| 884 |
+
# Tier C: Interface Physics
|
| 885 |
+
_TIER_C_METRICS: dict[str, tuple[str, str]] = {
|
| 886 |
+
"kd_nM_good": ("predicted_kd", "lower_is_better"),
|
| 887 |
+
"predicted_ddG_good": ("predicted_ddG", "lower_is_better"),
|
| 888 |
+
"active_site_rmsd_good": ("active_site_rmsd", "lower_is_better"),
|
| 889 |
+
}
|
| 890 |
+
_TIER_C_PHYSICS: dict[str, dict[str, float]] = {
|
| 891 |
+
"buried_surface_area": {"pass": 800, "good": 1500, "excellent": 2500},
|
| 892 |
+
"hydrogen_bonds": {"pass": 5, "good": 15, "excellent": 30},
|
| 893 |
+
}
|
| 894 |
+
|
| 895 |
+
_TIER_A_BASE = 15
|
| 896 |
+
_TIER_B_BASE = 10
|
| 897 |
+
_TIER_C_BASE = 10
|
| 898 |
+
_QUALITY_BASE = _TIER_A_BASE + _TIER_B_BASE + _TIER_C_BASE # 35
|
| 899 |
+
|
| 900 |
+
_BINDING_TASK_TYPES: set[DesignTaskType] = {
|
| 901 |
+
DesignTaskType.DE_NOVO_BINDER,
|
| 902 |
+
DesignTaskType.COMPLEX_ENGINEERING,
|
| 903 |
+
}
|
| 904 |
+
_BINDING_OLD_PREFIXES: set[str] = {"binder", "antibody", "ppi", "peptide"}
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
def _is_binding_task(task_id: str | None) -> bool:
|
| 908 |
+
if not task_id:
|
| 909 |
+
return False
|
| 910 |
+
cat = get_category(task_id)
|
| 911 |
+
if cat is not None:
|
| 912 |
+
return cat.task_type in _BINDING_TASK_TYPES
|
| 913 |
+
prefix = task_id.split("_")[0]
|
| 914 |
+
return prefix in _BINDING_OLD_PREFIXES
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def _get_tier_weights(
|
| 918 |
+
task_id: str | None = None,
|
| 919 |
+
max_points: int = 35,
|
| 920 |
+
) -> tuple[int, int, int]:
|
| 921 |
+
if not task_id:
|
| 922 |
+
scale = max_points / _QUALITY_BASE if _QUALITY_BASE > 0 else 0
|
| 923 |
+
return (
|
| 924 |
+
int(round(_TIER_A_BASE * scale)),
|
| 925 |
+
int(round(_TIER_B_BASE * scale)),
|
| 926 |
+
int(round(_TIER_C_BASE * scale)),
|
| 927 |
+
)
|
| 928 |
+
is_binding = _is_binding_task(task_id)
|
| 929 |
+
cat = get_category(task_id)
|
| 930 |
+
if cat is None and not is_binding:
|
| 931 |
+
scale = max_points / _QUALITY_BASE if _QUALITY_BASE > 0 else 0
|
| 932 |
+
return (
|
| 933 |
+
int(round(_TIER_A_BASE * scale)),
|
| 934 |
+
int(round(_TIER_B_BASE * scale)),
|
| 935 |
+
int(round(_TIER_C_BASE * scale)),
|
| 936 |
+
)
|
| 937 |
+
if is_binding:
|
| 938 |
+
ratio_a = 12 / 35
|
| 939 |
+
ratio_b = 18 / 35
|
| 940 |
+
a = int(round(max_points * ratio_a))
|
| 941 |
+
b = int(round(max_points * ratio_b))
|
| 942 |
+
c = max_points - a - b
|
| 943 |
+
return (a, b, c)
|
| 944 |
+
else:
|
| 945 |
+
ratio_a = 25 / 35
|
| 946 |
+
ratio_b = 10 / 35
|
| 947 |
+
a = int(round(max_points * ratio_a))
|
| 948 |
+
b = int(round(max_points * ratio_b))
|
| 949 |
+
c = max_points - a - b
|
| 950 |
+
return (a, b, c)
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
def _continuous_score(
|
| 954 |
+
value: float,
|
| 955 |
+
thresholds: dict[str, float],
|
| 956 |
+
direction: str = "higher_is_better",
|
| 957 |
+
) -> float:
|
| 958 |
+
"""Return continuous fraction [0.0, 1.0] via linear interpolation."""
|
| 959 |
+
p, g, e = thresholds["pass"], thresholds["good"], thresholds["excellent"]
|
| 960 |
+
|
| 961 |
+
if direction == "lower_is_better":
|
| 962 |
+
floor = p + abs(p) * 0.3 if p != 0 else 0.3
|
| 963 |
+
if value <= e:
|
| 964 |
+
return 1.0
|
| 965 |
+
if value >= floor:
|
| 966 |
+
return 0.0
|
| 967 |
+
if value <= g:
|
| 968 |
+
span = g - e
|
| 969 |
+
if span == 0:
|
| 970 |
+
return 1.0
|
| 971 |
+
return 0.66 + (g - value) / span * 0.34
|
| 972 |
+
if value <= p:
|
| 973 |
+
span = p - g
|
| 974 |
+
if span == 0:
|
| 975 |
+
return 0.66
|
| 976 |
+
return 0.33 + (p - value) / span * 0.33
|
| 977 |
+
span = floor - p
|
| 978 |
+
if span == 0:
|
| 979 |
+
return 0.0
|
| 980 |
+
return 0.33 * (floor - value) / span
|
| 981 |
+
|
| 982 |
+
# higher_is_better
|
| 983 |
+
floor = p * 0.7
|
| 984 |
+
if value >= e:
|
| 985 |
+
return 1.0
|
| 986 |
+
if value <= floor:
|
| 987 |
+
return 0.0
|
| 988 |
+
if value >= g:
|
| 989 |
+
span = e - g
|
| 990 |
+
if span == 0:
|
| 991 |
+
return 1.0
|
| 992 |
+
return 0.66 + (value - g) / span * 0.34
|
| 993 |
+
if value >= p:
|
| 994 |
+
span = g - p
|
| 995 |
+
if span == 0:
|
| 996 |
+
return 0.66
|
| 997 |
+
return 0.33 + (value - p) / span * 0.33
|
| 998 |
+
span = p - floor
|
| 999 |
+
if span == 0:
|
| 1000 |
+
return 0.0
|
| 1001 |
+
return 0.33 * (value - floor) / span
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
# Category-specific quality metrics (17 valid taxonomy cells)
|
| 1005 |
+
QUALITY_METRICS: dict[tuple[DesignTaskType, BiologicalContext], dict[str, Any]] = {
|
| 1006 |
+
# de_novo_binder (4 cells)
|
| 1007 |
+
(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ANTIBODY): {
|
| 1008 |
+
"primary_metric": "ipTM",
|
| 1009 |
+
"thresholds": {"excellent": 0.75, "good": 0.50, "pass": 0.20},
|
| 1010 |
+
"secondary_metrics": ["pLDDT", "predicted_kd"],
|
| 1011 |
+
},
|
| 1012 |
+
(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.SIGNALING): {
|
| 1013 |
+
"primary_metric": "ipTM",
|
| 1014 |
+
"thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
|
| 1015 |
+
"secondary_metrics": ["pLDDT", "predicted_kd"],
|
| 1016 |
+
},
|
| 1017 |
+
(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.THERAPEUTIC): {
|
| 1018 |
+
"primary_metric": "ipTM",
|
| 1019 |
+
"thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
|
| 1020 |
+
"secondary_metrics": ["pLDDT", "predicted_kd"],
|
| 1021 |
+
},
|
| 1022 |
+
(DesignTaskType.DE_NOVO_BINDER, BiologicalContext.ENZYME): {
|
| 1023 |
+
"primary_metric": "ipTM",
|
| 1024 |
+
"thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
|
| 1025 |
+
"secondary_metrics": ["pLDDT", "predicted_kd", "active_site_rmsd"],
|
| 1026 |
+
},
|
| 1027 |
+
# sequence_optimization (5 cells)
|
| 1028 |
+
(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ANTIBODY): {
|
| 1029 |
+
"primary_metric": "pLDDT",
|
| 1030 |
+
"thresholds": {"excellent": 90, "good": 80, "pass": 65},
|
| 1031 |
+
"secondary_metrics": ["ipTM", "max_sequence_identity"],
|
| 1032 |
+
},
|
| 1033 |
+
(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.ENZYME): {
|
| 1034 |
+
"primary_metric": "pLDDT",
|
| 1035 |
+
"thresholds": {"excellent": 90, "good": 80, "pass": 65},
|
| 1036 |
+
"secondary_metrics": ["predicted_ddG", "active_site_rmsd"],
|
| 1037 |
+
},
|
| 1038 |
+
(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.STRUCTURAL): {
|
| 1039 |
+
"primary_metric": "pLDDT",
|
| 1040 |
+
"thresholds": {"excellent": 92, "good": 82, "pass": 68},
|
| 1041 |
+
"secondary_metrics": ["TM_score", "predicted_ddG"],
|
| 1042 |
+
},
|
| 1043 |
+
(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.FLUORESCENT): {
|
| 1044 |
+
"primary_metric": "pLDDT",
|
| 1045 |
+
"thresholds": {"excellent": 88, "good": 78, "pass": 62},
|
| 1046 |
+
"secondary_metrics": ["predicted_ddG", "max_sequence_identity"],
|
| 1047 |
+
},
|
| 1048 |
+
(DesignTaskType.SEQUENCE_OPTIMIZATION, BiologicalContext.SIGNALING): {
|
| 1049 |
+
"primary_metric": "pLDDT",
|
| 1050 |
+
"thresholds": {"excellent": 90, "good": 80, "pass": 65},
|
| 1051 |
+
"secondary_metrics": ["ipTM", "predicted_ddG"],
|
| 1052 |
+
},
|
| 1053 |
+
# de_novo_backbone (1 cell)
|
| 1054 |
+
(DesignTaskType.DE_NOVO_BACKBONE, BiologicalContext.STRUCTURAL): {
|
| 1055 |
+
"primary_metric": "pLDDT",
|
| 1056 |
+
"thresholds": {"excellent": 88, "good": 78, "pass": 60},
|
| 1057 |
+
"secondary_metrics": ["TM_score", "predicted_ddG"],
|
| 1058 |
+
},
|
| 1059 |
+
# complex_engineering (3 cells)
|
| 1060 |
+
(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.SIGNALING): {
|
| 1061 |
+
"primary_metric": "ipTM",
|
| 1062 |
+
"thresholds": {"excellent": 0.72, "good": 0.48, "pass": 0.20},
|
| 1063 |
+
"secondary_metrics": ["pLDDT", "predicted_kd"],
|
| 1064 |
+
},
|
| 1065 |
+
(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.STRUCTURAL): {
|
| 1066 |
+
"primary_metric": "ipTM",
|
| 1067 |
+
"thresholds": {"excellent": 0.72, "good": 0.48, "pass": 0.20},
|
| 1068 |
+
"secondary_metrics": ["pLDDT", "TM_score"],
|
| 1069 |
+
},
|
| 1070 |
+
(DesignTaskType.COMPLEX_ENGINEERING, BiologicalContext.ENZYME): {
|
| 1071 |
+
"primary_metric": "ipTM",
|
| 1072 |
+
"thresholds": {"excellent": 0.70, "good": 0.45, "pass": 0.18},
|
| 1073 |
+
"secondary_metrics": ["pLDDT", "predicted_kd", "active_site_rmsd"],
|
| 1074 |
+
},
|
| 1075 |
+
# conformational_design (4 cells)
|
| 1076 |
+
(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.ENZYME): {
|
| 1077 |
+
"primary_metric": "pLDDT",
|
| 1078 |
+
"thresholds": {"excellent": 88, "good": 78, "pass": 62},
|
| 1079 |
+
"secondary_metrics": ["predicted_ddG", "active_site_rmsd"],
|
| 1080 |
+
},
|
| 1081 |
+
(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.SIGNALING): {
|
| 1082 |
+
"primary_metric": "pLDDT",
|
| 1083 |
+
"thresholds": {"excellent": 85, "good": 75, "pass": 60},
|
| 1084 |
+
"secondary_metrics": ["ipTM", "predicted_kd"],
|
| 1085 |
+
},
|
| 1086 |
+
(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.FLUORESCENT): {
|
| 1087 |
+
"primary_metric": "pLDDT",
|
| 1088 |
+
"thresholds": {"excellent": 85, "good": 75, "pass": 60},
|
| 1089 |
+
"secondary_metrics": ["predicted_ddG", "max_sequence_identity"],
|
| 1090 |
+
},
|
| 1091 |
+
(DesignTaskType.CONFORMATIONAL_DESIGN, BiologicalContext.STRUCTURAL): {
|
| 1092 |
+
"primary_metric": "pLDDT",
|
| 1093 |
+
"thresholds": {"excellent": 88, "good": 78, "pass": 62},
|
| 1094 |
+
"secondary_metrics": ["TM_score", "predicted_ddG"],
|
| 1095 |
+
},
|
| 1096 |
+
}
|
| 1097 |
+
|
| 1098 |
+
|
| 1099 |
+
def get_quality_config(task_id: str) -> dict[str, Any] | None:
|
| 1100 |
+
category = get_category(task_id)
|
| 1101 |
+
if category is None:
|
| 1102 |
+
return None
|
| 1103 |
+
key = (category.task_type, category.context)
|
| 1104 |
+
return QUALITY_METRICS.get(key)
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
@dataclass
|
| 1108 |
+
class DesignScoringRubric:
|
| 1109 |
+
components: dict[str, int] = field(default_factory=lambda: dict(DEFAULT_DESIGN_RUBRIC))
|
| 1110 |
+
|
| 1111 |
+
@property
|
| 1112 |
+
def max_score(self) -> int:
|
| 1113 |
+
return sum(self.components.values())
|
| 1114 |
+
|
| 1115 |
+
def validate(self) -> None:
|
| 1116 |
+
total = sum(self.components.values())
|
| 1117 |
+
if total != 100:
|
| 1118 |
+
raise ValueError(f"Rubric total must be 100, got {total}")
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
def _has_reasonable_composition(seq: str, min_length: int = 20) -> bool:
|
| 1122 |
+
upper = seq.upper()
|
| 1123 |
+
if len(upper) < min_length:
|
| 1124 |
+
return False
|
| 1125 |
+
unique_aas = len(set(upper))
|
| 1126 |
+
if unique_aas < 5:
|
| 1127 |
+
return False
|
| 1128 |
+
counts = Counter(upper)
|
| 1129 |
+
max_fraction = max(counts.values()) / len(upper)
|
| 1130 |
+
if max_fraction > 0.5:
|
| 1131 |
+
return False
|
| 1132 |
+
ala_fraction = counts.get("A", 0) / len(upper)
|
| 1133 |
+
if ala_fraction > 0.3:
|
| 1134 |
+
return False
|
| 1135 |
+
hp = hydrophobicity_profile(upper)
|
| 1136 |
+
if hp["mean"] > 2.0:
|
| 1137 |
+
return False
|
| 1138 |
+
return True
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
def validate_metric_range(name: str, value: float) -> bool:
|
| 1142 |
+
if name not in METRIC_RANGES:
|
| 1143 |
+
return True
|
| 1144 |
+
low, high = METRIC_RANGES[name]
|
| 1145 |
+
return low <= value <= high
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
# Functional Similarity thresholds for non-binding Tier B
|
| 1149 |
+
_FUNCTIONAL_SIM_DEFAULTS: dict[DesignTaskType, dict[str, float]] = {
|
| 1150 |
+
DesignTaskType.SEQUENCE_OPTIMIZATION: {"pass": 0.40, "good": 0.60, "excellent": 0.85},
|
| 1151 |
+
DesignTaskType.CONFORMATIONAL_DESIGN: {"pass": 0.15, "good": 0.30, "excellent": 0.50},
|
| 1152 |
+
DesignTaskType.DE_NOVO_BACKBONE: {"pass": 0.10, "good": 0.20, "excellent": 0.40},
|
| 1153 |
+
}
|
| 1154 |
+
|
| 1155 |
+
|
| 1156 |
+
def _derive_functional_sim_thresholds(value: float) -> dict[str, float]:
|
| 1157 |
+
return {
|
| 1158 |
+
"pass": value * 0.5,
|
| 1159 |
+
"good": value,
|
| 1160 |
+
"excellent": min(value * 2, 1.0),
|
| 1161 |
+
}
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
def _get_functional_sim_thresholds(
|
| 1165 |
+
thresholds: dict[str, float],
|
| 1166 |
+
task_id: str,
|
| 1167 |
+
) -> dict[str, float] | None:
|
| 1168 |
+
if _is_binding_task(task_id):
|
| 1169 |
+
return None
|
| 1170 |
+
gt_value = thresholds.get("max_seq_identity_good")
|
| 1171 |
+
if gt_value is not None:
|
| 1172 |
+
return _derive_functional_sim_thresholds(gt_value)
|
| 1173 |
+
cat = get_category(task_id)
|
| 1174 |
+
if cat is None:
|
| 1175 |
+
return None
|
| 1176 |
+
return _FUNCTIONAL_SIM_DEFAULTS.get(cat.task_type)
|
| 1177 |
+
|
| 1178 |
+
|
| 1179 |
+
def _score_functional_similarity(
|
| 1180 |
+
designs: list[str],
|
| 1181 |
+
oracle_sequences: list[str],
|
| 1182 |
+
thresholds: dict[str, float],
|
| 1183 |
+
) -> float | None:
|
| 1184 |
+
if not designs or not oracle_sequences:
|
| 1185 |
+
return None
|
| 1186 |
+
best_identity = 0.0
|
| 1187 |
+
for design in designs:
|
| 1188 |
+
for oracle in oracle_sequences:
|
| 1189 |
+
ident = sequence_identity(design, oracle)
|
| 1190 |
+
if ident > best_identity:
|
| 1191 |
+
best_identity = ident
|
| 1192 |
+
return _continuous_score(best_identity, thresholds, "higher_is_better")
|
| 1193 |
+
|
| 1194 |
+
|
| 1195 |
+
def score_quality(
|
| 1196 |
+
agent_metrics: dict[str, float],
|
| 1197 |
+
thresholds: dict[str, float],
|
| 1198 |
+
max_points: int = 35,
|
| 1199 |
+
task_id: str | None = None,
|
| 1200 |
+
designs: list[str] | None = None,
|
| 1201 |
+
oracle_sequences: list[str] | None = None,
|
| 1202 |
+
) -> dict[str, Any]:
|
| 1203 |
+
"""Score quality using 3-tier continuous system."""
|
| 1204 |
+
valid_metrics = {
|
| 1205 |
+
k: v for k, v in agent_metrics.items() if validate_metric_range(k, v)
|
| 1206 |
+
}
|
| 1207 |
+
for extra_key in ("buried_surface_area", "hydrogen_bonds"):
|
| 1208 |
+
if extra_key in agent_metrics and extra_key not in valid_metrics:
|
| 1209 |
+
val = agent_metrics[extra_key]
|
| 1210 |
+
if isinstance(val, (int, float)) and val >= 0:
|
| 1211 |
+
valid_metrics[extra_key] = float(val)
|
| 1212 |
+
|
| 1213 |
+
tier_a_max, tier_b_max, tier_c_max = _get_tier_weights(task_id, max_points)
|
| 1214 |
+
is_binding = _is_binding_task(task_id)
|
| 1215 |
+
|
| 1216 |
+
overrides: dict[str, dict[str, float]] = {}
|
| 1217 |
+
if task_id:
|
| 1218 |
+
config = get_quality_config(task_id)
|
| 1219 |
+
if config and "thresholds" in config:
|
| 1220 |
+
primary = config["primary_metric"]
|
| 1221 |
+
overrides[primary] = config["thresholds"]
|
| 1222 |
+
|
| 1223 |
+
# Tier A: Structure Confidence
|
| 1224 |
+
tier_a_scores: dict[str, float] = {}
|
| 1225 |
+
for metric, default_thresh in _TIER_A_THRESHOLDS.items():
|
| 1226 |
+
if metric in valid_metrics:
|
| 1227 |
+
thresh = overrides.get(metric, default_thresh)
|
| 1228 |
+
tier_a_scores[metric] = _continuous_score(
|
| 1229 |
+
valid_metrics[metric], thresh, "higher_is_better"
|
| 1230 |
+
)
|
| 1231 |
+
tier_a_pts = (sum(tier_a_scores.values()) / len(tier_a_scores)) * tier_a_max if tier_a_scores else 0.0
|
| 1232 |
+
|
| 1233 |
+
# Tier B: Interface or Functional Similarity
|
| 1234 |
+
tier_b_scores: dict[str, float] = {}
|
| 1235 |
+
tier_b_pts = 0.0
|
| 1236 |
+
_use_functional_sim = (
|
| 1237 |
+
tier_b_max > 0
|
| 1238 |
+
and task_id is not None
|
| 1239 |
+
and not is_binding
|
| 1240 |
+
and get_category(task_id) is not None
|
| 1241 |
+
)
|
| 1242 |
+
|
| 1243 |
+
if tier_b_max > 0:
|
| 1244 |
+
if _use_functional_sim:
|
| 1245 |
+
if designs and oracle_sequences:
|
| 1246 |
+
func_thresh = _get_functional_sim_thresholds(thresholds, task_id)
|
| 1247 |
+
if func_thresh is not None:
|
| 1248 |
+
frac = _score_functional_similarity(designs, oracle_sequences, func_thresh)
|
| 1249 |
+
if frac is not None:
|
| 1250 |
+
tier_b_pts = frac * tier_b_max
|
| 1251 |
+
tier_b_scores["oracle_identity"] = frac
|
| 1252 |
+
else:
|
| 1253 |
+
for metric, default_thresh in _TIER_B_THRESHOLDS.items():
|
| 1254 |
+
if metric in valid_metrics:
|
| 1255 |
+
thresh = overrides.get(metric, default_thresh)
|
| 1256 |
+
direction = _TIER_B_DIRECTIONS.get(metric, "higher_is_better")
|
| 1257 |
+
tier_b_scores[metric] = _continuous_score(
|
| 1258 |
+
valid_metrics[metric], thresh, direction
|
| 1259 |
+
)
|
| 1260 |
+
if tier_b_scores:
|
| 1261 |
+
tier_b_pts = (sum(tier_b_scores.values()) / len(tier_b_scores)) * tier_b_max
|
| 1262 |
+
|
| 1263 |
+
# Tier C: Interface Physics
|
| 1264 |
+
tier_c_fractions: list[float] = []
|
| 1265 |
+
tier_c_breakdown: list[dict] = []
|
| 1266 |
+
|
| 1267 |
+
if tier_c_max > 0:
|
| 1268 |
+
if is_binding:
|
| 1269 |
+
for metric_key, phys_thresh in _TIER_C_PHYSICS.items():
|
| 1270 |
+
if metric_key in valid_metrics:
|
| 1271 |
+
frac = _continuous_score(valid_metrics[metric_key], phys_thresh, "higher_is_better")
|
| 1272 |
+
tier_c_fractions.append(frac)
|
| 1273 |
+
tier_c_breakdown.append({
|
| 1274 |
+
"threshold": metric_key, "metric": metric_key,
|
| 1275 |
+
"value": valid_metrics[metric_key],
|
| 1276 |
+
"threshold_value": phys_thresh, "fraction": round(frac, 3),
|
| 1277 |
+
})
|
| 1278 |
+
|
| 1279 |
+
for thresh_key, (metric_key, direction) in _TIER_C_METRICS.items():
|
| 1280 |
+
if thresh_key in thresholds and metric_key in valid_metrics:
|
| 1281 |
+
threshold_val = thresholds[thresh_key]
|
| 1282 |
+
agent_val = valid_metrics[metric_key]
|
| 1283 |
+
margin = abs(threshold_val) * 0.5 if threshold_val != 0 else 1.0
|
| 1284 |
+
if direction == "lower_is_better":
|
| 1285 |
+
gt_thresh = {
|
| 1286 |
+
"pass": threshold_val + margin,
|
| 1287 |
+
"good": threshold_val,
|
| 1288 |
+
"excellent": threshold_val - margin,
|
| 1289 |
+
}
|
| 1290 |
+
else:
|
| 1291 |
+
gt_thresh = {
|
| 1292 |
+
"pass": threshold_val - margin,
|
| 1293 |
+
"good": threshold_val,
|
| 1294 |
+
"excellent": threshold_val + margin,
|
| 1295 |
+
}
|
| 1296 |
+
frac = _continuous_score(agent_val, gt_thresh, direction)
|
| 1297 |
+
tier_c_fractions.append(frac)
|
| 1298 |
+
tier_c_breakdown.append({
|
| 1299 |
+
"threshold": thresh_key, "metric": metric_key,
|
| 1300 |
+
"value": agent_val, "threshold_value": threshold_val,
|
| 1301 |
+
"fraction": round(frac, 3),
|
| 1302 |
+
})
|
| 1303 |
+
|
| 1304 |
+
tier_c_pts = (sum(tier_c_fractions) / len(tier_c_fractions)) * tier_c_max if tier_c_fractions else 0.0
|
| 1305 |
+
|
| 1306 |
+
total = min(tier_a_pts + tier_b_pts + tier_c_pts, max_points)
|
| 1307 |
+
metrics_evaluated = len(tier_a_scores) + len(tier_b_scores) + len(tier_c_fractions)
|
| 1308 |
+
|
| 1309 |
+
return {
|
| 1310 |
+
"score": int(round(total)), "max": max_points,
|
| 1311 |
+
"tier_a": round(tier_a_pts, 1), "tier_b": round(tier_b_pts, 1),
|
| 1312 |
+
"tier_c": round(tier_c_pts, 1),
|
| 1313 |
+
"metrics_evaluated": metrics_evaluated,
|
| 1314 |
+
"breakdown": {
|
| 1315 |
+
"structure": tier_a_scores, "interface": tier_b_scores,
|
| 1316 |
+
"physics": tier_c_breakdown,
|
| 1317 |
+
},
|
| 1318 |
+
}
|
| 1319 |
+
|
| 1320 |
+
|
| 1321 |
+
def score_novelty(
|
| 1322 |
+
designs: list[str],
|
| 1323 |
+
reference_seq: str | None,
|
| 1324 |
+
thresholds: dict[str, float],
|
| 1325 |
+
max_points: int = 5,
|
| 1326 |
+
) -> dict[str, Any]:
|
| 1327 |
+
"""Score novelty by computing sequence identity to reference."""
|
| 1328 |
+
if not designs:
|
| 1329 |
+
return {"score": 0, "max": max_points, "max_identity": 0.0, "identity_threshold": None}
|
| 1330 |
+
|
| 1331 |
+
identity_threshold = thresholds.get("max_seq_identity_good")
|
| 1332 |
+
max_id = max_identity_to_reference(designs, reference_seq) if reference_seq else 0.0
|
| 1333 |
+
|
| 1334 |
+
if identity_threshold is None:
|
| 1335 |
+
if reference_seq:
|
| 1336 |
+
novelty_ratio = 1.0 - max_id
|
| 1337 |
+
score = int(round(max_points * min(novelty_ratio * 2, 1.0)))
|
| 1338 |
+
else:
|
| 1339 |
+
score = max_points
|
| 1340 |
+
elif identity_threshold >= 0.9:
|
| 1341 |
+
if max_id >= identity_threshold:
|
| 1342 |
+
score = max_points
|
| 1343 |
+
elif max_id >= identity_threshold * 0.9:
|
| 1344 |
+
score = int(round(max_points * 0.7))
|
| 1345 |
+
else:
|
| 1346 |
+
score = int(round(max_points * 0.3))
|
| 1347 |
+
else:
|
| 1348 |
+
if max_id <= identity_threshold:
|
| 1349 |
+
score = max_points
|
| 1350 |
+
elif max_id <= identity_threshold * 1.5:
|
| 1351 |
+
score = int(round(max_points * 0.5))
|
| 1352 |
+
else:
|
| 1353 |
+
score = int(round(max_points * 0.2))
|
| 1354 |
+
|
| 1355 |
+
return {
|
| 1356 |
+
"score": min(score, max_points), "max": max_points,
|
| 1357 |
+
"max_identity": round(max_id, 3), "identity_threshold": identity_threshold,
|
| 1358 |
+
}
|
| 1359 |
+
|
| 1360 |
+
|
| 1361 |
+
def score_diversity(
|
| 1362 |
+
designs: list[str],
|
| 1363 |
+
max_designs: int = 10,
|
| 1364 |
+
max_points: int = 5,
|
| 1365 |
+
) -> dict[str, Any]:
|
| 1366 |
+
"""Score diversity of designs."""
|
| 1367 |
+
if not designs:
|
| 1368 |
+
return {"score": 0, "max": max_points, "num_designs": 0, "pairwise_diversity": 0.0, "entropy": 0.0}
|
| 1369 |
+
|
| 1370 |
+
num = len(designs)
|
| 1371 |
+
count_fraction = min(num / max_designs, 1.0) if max_designs > 0 else 1.0
|
| 1372 |
+
diversity = mean_pairwise_diversity(designs)
|
| 1373 |
+
entropy = sequence_entropy(designs)
|
| 1374 |
+
|
| 1375 |
+
count_score = count_fraction * max_points * 0.4
|
| 1376 |
+
diversity_score = diversity * max_points * 0.4
|
| 1377 |
+
entropy_score = entropy * max_points * 0.2
|
| 1378 |
+
total = int(round(count_score + diversity_score + entropy_score))
|
| 1379 |
+
|
| 1380 |
+
return {
|
| 1381 |
+
"score": min(total, max_points), "max": max_points,
|
| 1382 |
+
"num_designs": num, "pairwise_diversity": round(diversity, 3),
|
| 1383 |
+
"entropy": round(entropy, 3),
|
| 1384 |
+
}
|
| 1385 |
+
|
| 1386 |
+
|
| 1387 |
+
def score_feasibility(
|
| 1388 |
+
designs: list[str],
|
| 1389 |
+
constraints: dict[str, Any],
|
| 1390 |
+
max_points: int = 25,
|
| 1391 |
+
) -> dict[str, Any]:
|
| 1392 |
+
"""Score feasibility of designed sequences."""
|
| 1393 |
+
if not designs:
|
| 1394 |
+
return {"score": 0, "max": max_points, "aa_validity": 0.0, "length_validity": 0.0, "composition_check": 0.0}
|
| 1395 |
+
|
| 1396 |
+
per_check = max_points / 3
|
| 1397 |
+
length_range = constraints.get("length_range")
|
| 1398 |
+
if isinstance(length_range, list):
|
| 1399 |
+
length_range = tuple(length_range)
|
| 1400 |
+
|
| 1401 |
+
comp_min_length = 20
|
| 1402 |
+
if length_range and length_range[1] < 20:
|
| 1403 |
+
comp_min_length = max(length_range[0], 5)
|
| 1404 |
+
|
| 1405 |
+
aa_valid_count = sum(1 for seq in designs if validate_amino_acids(seq)["valid"])
|
| 1406 |
+
aa_fraction = aa_valid_count / len(designs)
|
| 1407 |
+
|
| 1408 |
+
length_valid_count = sum(1 for seq in designs if check_length_constraints(seq, length_range)["within_range"])
|
| 1409 |
+
length_fraction = length_valid_count / len(designs)
|
| 1410 |
+
|
| 1411 |
+
composition_ok = sum(1 for seq in designs if _has_reasonable_composition(seq, min_length=comp_min_length))
|
| 1412 |
+
composition_fraction = composition_ok / len(designs)
|
| 1413 |
+
|
| 1414 |
+
aa_score = aa_fraction * per_check
|
| 1415 |
+
length_score = length_fraction * per_check
|
| 1416 |
+
comp_score = composition_fraction * per_check
|
| 1417 |
+
total = int(round(aa_score + length_score + comp_score))
|
| 1418 |
+
|
| 1419 |
+
return {
|
| 1420 |
+
"score": min(total, max_points), "max": max_points,
|
| 1421 |
+
"aa_validity": round(aa_fraction, 3),
|
| 1422 |
+
"length_validity": round(length_fraction, 3),
|
| 1423 |
+
"composition_check": round(composition_fraction, 3),
|
| 1424 |
+
}
|
| 1425 |
+
|
| 1426 |
+
|
| 1427 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 1428 |
+
# SECTION 6 — Design Gate + Final Score
|
| 1429 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 1430 |
+
|
| 1431 |
+
_DESIGN_GATE_ZEROED = {"quality", "novelty", "diversity", "feasibility"}
|
| 1432 |
+
_DESIGN_GATE_CAP = 30
|
| 1433 |
+
|
| 1434 |
+
|
| 1435 |
+
def apply_design_gate(
|
| 1436 |
+
component_scores: dict[str, int],
|
| 1437 |
+
num_designs: int,
|
| 1438 |
+
) -> dict[str, int]:
|
| 1439 |
+
"""If no designs produced, cap total at 30."""
|
| 1440 |
+
if num_designs >= 1:
|
| 1441 |
+
return dict(component_scores)
|
| 1442 |
+
gated = dict(component_scores)
|
| 1443 |
+
for key in _DESIGN_GATE_ZEROED:
|
| 1444 |
+
gated[key] = 0
|
| 1445 |
+
remaining_sum = sum(v for k, v in gated.items() if k not in _DESIGN_GATE_ZEROED)
|
| 1446 |
+
if remaining_sum > _DESIGN_GATE_CAP:
|
| 1447 |
+
scale = _DESIGN_GATE_CAP / remaining_sum
|
| 1448 |
+
for key in gated:
|
| 1449 |
+
if key not in _DESIGN_GATE_ZEROED:
|
| 1450 |
+
gated[key] = int(round(gated[key] * scale))
|
| 1451 |
+
return gated
|
| 1452 |
+
|
| 1453 |
+
|
| 1454 |
+
def calculate_design_score(
|
| 1455 |
+
rubric: DesignScoringRubric,
|
| 1456 |
+
results: dict[str, int],
|
| 1457 |
+
) -> dict[str, Any]:
|
| 1458 |
+
"""Calculate final design task score from component results."""
|
| 1459 |
+
breakdown = {}
|
| 1460 |
+
for component, max_pts in rubric.components.items():
|
| 1461 |
+
actual = min(results.get(component, 0), max_pts)
|
| 1462 |
+
breakdown[component] = {"score": actual, "max": max_pts}
|
| 1463 |
+
total = sum(v["score"] for v in breakdown.values())
|
| 1464 |
+
max_possible = rubric.max_score
|
| 1465 |
+
return {
|
| 1466 |
+
"breakdown": breakdown,
|
| 1467 |
+
"total": total,
|
| 1468 |
+
"max_possible": max_possible,
|
| 1469 |
+
"percentage": round(total / max_possible * 100, 1) if max_possible > 0 else 0,
|
| 1470 |
+
}
|
| 1471 |
+
|
| 1472 |
+
|
| 1473 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 1474 |
+
# SECTION 7 — Full Task Scorer (high-level API for eval pipeline)
|
| 1475 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 1476 |
+
|
| 1477 |
+
|
| 1478 |
+
def score_submission_task(
|
| 1479 |
+
task_id: str,
|
| 1480 |
+
sequences: list[str],
|
| 1481 |
+
run_log: list[dict[str, Any]],
|
| 1482 |
+
ground_truth: dict[str, Any],
|
| 1483 |
+
agent_metrics: dict[str, float] | None = None,
|
| 1484 |
+
oracle_sequences: list[str] | None = None,
|
| 1485 |
+
) -> dict[str, Any]:
|
| 1486 |
+
"""Score a single task submission end-to-end.
|
| 1487 |
+
|
| 1488 |
+
This is the main entry point for the evaluation pipeline.
|
| 1489 |
+
|
| 1490 |
+
Args:
|
| 1491 |
+
task_id: Task identifier (e.g., "dnb_sig_001").
|
| 1492 |
+
sequences: Designed amino acid sequences from the agent.
|
| 1493 |
+
run_log: Tool call log from the agent.
|
| 1494 |
+
ground_truth: Ground truth dict with thresholds, reference_sequence,
|
| 1495 |
+
design_constraints, tools_expected, max_designs.
|
| 1496 |
+
agent_metrics: Optional metrics reported by the agent or from Boltz
|
| 1497 |
+
(e.g., {"pLDDT": 85.0, "ipTM": 0.35}).
|
| 1498 |
+
oracle_sequences: Optional oracle sequences for functional similarity.
|
| 1499 |
+
|
| 1500 |
+
Returns:
|
| 1501 |
+
Dict with: total_score, component_scores, details, num_designs.
|
| 1502 |
+
"""
|
| 1503 |
+
if agent_metrics is None:
|
| 1504 |
+
agent_metrics = {}
|
| 1505 |
+
|
| 1506 |
+
# Extract fields from ground truth
|
| 1507 |
+
thresholds = ground_truth.get("thresholds", {})
|
| 1508 |
+
reference_seq = ground_truth.get("reference_sequence")
|
| 1509 |
+
constraints = ground_truth.get("design_constraints", {})
|
| 1510 |
+
tools_expected = ground_truth.get("tools_expected", [])
|
| 1511 |
+
max_designs = ground_truth.get("max_designs", 10)
|
| 1512 |
+
|
| 1513 |
+
# Get task category for function-based scoring
|
| 1514 |
+
cat = get_category(task_id)
|
| 1515 |
+
task_type = cat.task_type if cat else None
|
| 1516 |
+
|
| 1517 |
+
# Extract tools used from run_log
|
| 1518 |
+
tools_used = [entry.get("tool", "") for entry in run_log if entry.get("tool")]
|
| 1519 |
+
|
| 1520 |
+
# Score all 6 components
|
| 1521 |
+
approach_result = score_approach(
|
| 1522 |
+
tools_used=tools_used,
|
| 1523 |
+
tools_expected=tools_expected,
|
| 1524 |
+
task_type=task_type,
|
| 1525 |
+
)
|
| 1526 |
+
orchestration_result = score_orchestration(
|
| 1527 |
+
tool_call_log=run_log,
|
| 1528 |
+
task_id=task_id,
|
| 1529 |
+
)
|
| 1530 |
+
quality_result = score_quality(
|
| 1531 |
+
agent_metrics=agent_metrics,
|
| 1532 |
+
thresholds=thresholds,
|
| 1533 |
+
task_id=task_id,
|
| 1534 |
+
designs=sequences,
|
| 1535 |
+
oracle_sequences=oracle_sequences,
|
| 1536 |
+
)
|
| 1537 |
+
feasibility_result = score_feasibility(
|
| 1538 |
+
designs=sequences,
|
| 1539 |
+
constraints=constraints,
|
| 1540 |
+
)
|
| 1541 |
+
novelty_result = score_novelty(
|
| 1542 |
+
designs=sequences,
|
| 1543 |
+
reference_seq=reference_seq,
|
| 1544 |
+
thresholds=thresholds,
|
| 1545 |
+
)
|
| 1546 |
+
diversity_result = score_diversity(
|
| 1547 |
+
designs=sequences,
|
| 1548 |
+
max_designs=max_designs,
|
| 1549 |
+
)
|
| 1550 |
+
|
| 1551 |
+
# Build component scores dict
|
| 1552 |
+
component_scores = {
|
| 1553 |
+
"approach": approach_result["score"],
|
| 1554 |
+
"orchestration": orchestration_result["score"],
|
| 1555 |
+
"quality": quality_result["score"],
|
| 1556 |
+
"feasibility": feasibility_result["score"],
|
| 1557 |
+
"novelty": novelty_result["score"],
|
| 1558 |
+
"diversity": diversity_result["score"],
|
| 1559 |
+
}
|
| 1560 |
+
|
| 1561 |
+
# Apply design gate
|
| 1562 |
+
num_designs = len(sequences)
|
| 1563 |
+
gated = apply_design_gate(component_scores, num_designs)
|
| 1564 |
+
total = sum(gated.values())
|
| 1565 |
+
|
| 1566 |
+
return {
|
| 1567 |
+
"total_score": total,
|
| 1568 |
+
"component_scores": gated,
|
| 1569 |
+
"num_designs": num_designs,
|
| 1570 |
+
"details": {
|
| 1571 |
+
"approach": approach_result,
|
| 1572 |
+
"orchestration": orchestration_result,
|
| 1573 |
+
"quality": quality_result,
|
| 1574 |
+
"feasibility": feasibility_result,
|
| 1575 |
+
"novelty": novelty_result,
|
| 1576 |
+
"diversity": diversity_result,
|
| 1577 |
+
},
|
| 1578 |
+
}
|
| 1579 |
+
|
| 1580 |
+
|
| 1581 |
+
def aggregate_scores(
|
| 1582 |
+
per_task_scores: dict[str, dict[str, Any]],
|
| 1583 |
+
) -> dict[str, Any]:
|
| 1584 |
+
"""Aggregate per-task scores into an overall submission result.
|
| 1585 |
+
|
| 1586 |
+
Args:
|
| 1587 |
+
per_task_scores: Dict mapping task_id → score_submission_task() result.
|
| 1588 |
+
|
| 1589 |
+
Returns:
|
| 1590 |
+
Dict with: overall_score, component_scores (averaged), taxonomy_scores,
|
| 1591 |
+
tasks_completed, tasks_with_zero.
|
| 1592 |
+
"""
|
| 1593 |
+
if not per_task_scores:
|
| 1594 |
+
return {
|
| 1595 |
+
"overall_score": 0.0,
|
| 1596 |
+
"component_scores": {c: 0.0 for c in DEFAULT_DESIGN_RUBRIC},
|
| 1597 |
+
"taxonomy_scores": {},
|
| 1598 |
+
"tasks_completed": 0,
|
| 1599 |
+
"tasks_total": 0,
|
| 1600 |
+
"tasks_with_zero": 0,
|
| 1601 |
+
}
|
| 1602 |
+
|
| 1603 |
+
totals = {c: 0.0 for c in DEFAULT_DESIGN_RUBRIC}
|
| 1604 |
+
n = len(per_task_scores)
|
| 1605 |
+
tasks_with_zero = 0
|
| 1606 |
+
|
| 1607 |
+
# Taxonomy breakdown
|
| 1608 |
+
taxonomy_scores: dict[str, dict[str, list[float]]] = {}
|
| 1609 |
+
|
| 1610 |
+
for task_id, result in per_task_scores.items():
|
| 1611 |
+
total_score = result["total_score"]
|
| 1612 |
+
if total_score == 0:
|
| 1613 |
+
tasks_with_zero += 1
|
| 1614 |
+
|
| 1615 |
+
for comp, val in result["component_scores"].items():
|
| 1616 |
+
totals[comp] += val
|
| 1617 |
+
|
| 1618 |
+
# Taxonomy mapping
|
| 1619 |
+
cat = get_category(task_id)
|
| 1620 |
+
if cat:
|
| 1621 |
+
tt = cat.task_type.value
|
| 1622 |
+
ctx = cat.context.short
|
| 1623 |
+
taxonomy_scores.setdefault(tt, {}).setdefault(ctx, []).append(total_score)
|
| 1624 |
+
|
| 1625 |
+
# Average components
|
| 1626 |
+
avg_components = {c: round(v / n, 1) for c, v in totals.items()}
|
| 1627 |
+
overall = round(sum(avg_components.values()), 1)
|
| 1628 |
+
|
| 1629 |
+
# Average taxonomy scores
|
| 1630 |
+
taxonomy_avg: dict[str, dict[str, float]] = {}
|
| 1631 |
+
for tt, contexts in taxonomy_scores.items():
|
| 1632 |
+
taxonomy_avg[tt] = {}
|
| 1633 |
+
for ctx, scores in contexts.items():
|
| 1634 |
+
taxonomy_avg[tt][ctx] = round(sum(scores) / len(scores), 1)
|
| 1635 |
+
|
| 1636 |
+
return {
|
| 1637 |
+
"overall_score": overall,
|
| 1638 |
+
"component_scores": avg_components,
|
| 1639 |
+
"taxonomy_scores": taxonomy_avg,
|
| 1640 |
+
"tasks_completed": n,
|
| 1641 |
+
"tasks_total": n,
|
| 1642 |
+
"tasks_with_zero": tasks_with_zero,
|
| 1643 |
+
}
|
eval_tasks.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load hidden benchmark tasks from a private HuggingFace Dataset.
|
| 2 |
+
|
| 3 |
+
Each task row contains:
|
| 4 |
+
- task_id: e.g., "dnb_sig_001"
|
| 5 |
+
- task_json: Full task definition (JSON string)
|
| 6 |
+
- ground_truth: Ground truth thresholds + reference (JSON string)
|
| 7 |
+
- prompt_md: Task prompt in Markdown
|
| 8 |
+
- pdb_data: Base64-encoded PDB file (if needed)
|
| 9 |
+
- pdb_filename: Original PDB filename (e.g., "7n1j.pdb")
|
| 10 |
+
- oracle_sequences: JSON list of oracle sequences (for non-binding tasks)
|
| 11 |
+
|
| 12 |
+
Falls back to local files in development (when BDB_USE_LOCAL=1).
|
| 13 |
+
|
| 14 |
+
HF Dataset: RomeroLab-Duke/biodesignbench-hidden-tasks (private)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import base64
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
from functools import lru_cache
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Any
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# Configuration
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
TASKS_DATASET = os.environ.get(
|
| 34 |
+
"BDB_TASKS_DATASET",
|
| 35 |
+
"RomeroLab-Duke/biodesignbench-hidden-tasks",
|
| 36 |
+
)
|
| 37 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 38 |
+
USE_LOCAL = os.environ.get("BDB_USE_LOCAL", "0") == "1"
|
| 39 |
+
|
| 40 |
+
# Local paths (for development)
|
| 41 |
+
_PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 42 |
+
_TASKS_DIR = _PROJECT_ROOT / "tasks" / "tier2"
|
| 43 |
+
_GT_DIR = _PROJECT_ROOT / "data" / "tier2" / "ground_truth"
|
| 44 |
+
_PROMPTS_DIR = _PROJECT_ROOT / "data" / "tier2" / "prompts"
|
| 45 |
+
_INPUT_DIR = _PROJECT_ROOT / "data" / "tier2" / "input"
|
| 46 |
+
_ORACLE_PATH = _PROJECT_ROOT / "data" / "oracle" / "sequences.json"
|
| 47 |
+
_TOOL_SCHEMAS_PATH = Path(__file__).parent / "mcp_tool_schemas.json"
|
| 48 |
+
|
| 49 |
+
# Public task IDs (for development/testing — not hidden)
|
| 50 |
+
PUBLIC_TASK_IDS = {"dnb_sig_001", "sqo_enz_005", "cpx_sig_001"}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# HF Dataset loading
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@lru_cache(maxsize=1)
|
| 59 |
+
def _load_from_hf() -> dict[str, dict[str, Any]]:
|
| 60 |
+
"""Load all tasks from the private HF Dataset."""
|
| 61 |
+
try:
|
| 62 |
+
from datasets import load_dataset
|
| 63 |
+
|
| 64 |
+
ds = load_dataset(
|
| 65 |
+
TASKS_DATASET,
|
| 66 |
+
split="train",
|
| 67 |
+
token=HF_TOKEN,
|
| 68 |
+
)
|
| 69 |
+
tasks = {}
|
| 70 |
+
for row in ds:
|
| 71 |
+
task_id = row["task_id"]
|
| 72 |
+
tasks[task_id] = {
|
| 73 |
+
"task_id": task_id,
|
| 74 |
+
"task_json": json.loads(row["task_json"]),
|
| 75 |
+
"ground_truth": json.loads(row["ground_truth"]),
|
| 76 |
+
"prompt_md": row["prompt_md"],
|
| 77 |
+
"pdb_data": row.get("pdb_data"),
|
| 78 |
+
"pdb_filename": row.get("pdb_filename"),
|
| 79 |
+
"oracle_sequences": json.loads(row.get("oracle_sequences", "[]")),
|
| 80 |
+
}
|
| 81 |
+
logger.info(f"Loaded {len(tasks)} tasks from HF Dataset")
|
| 82 |
+
return tasks
|
| 83 |
+
except Exception as e:
|
| 84 |
+
logger.error(f"Failed to load tasks from HF: {e}")
|
| 85 |
+
return {}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@lru_cache(maxsize=1)
|
| 89 |
+
def _load_from_local() -> dict[str, dict[str, Any]]:
|
| 90 |
+
"""Load tasks from local project files (development mode)."""
|
| 91 |
+
tasks = {}
|
| 92 |
+
|
| 93 |
+
# Load oracle data
|
| 94 |
+
oracle_data = {}
|
| 95 |
+
if _ORACLE_PATH.exists():
|
| 96 |
+
with open(_ORACLE_PATH) as f:
|
| 97 |
+
oracle_data = json.load(f)
|
| 98 |
+
|
| 99 |
+
# Enumerate task files
|
| 100 |
+
if not _TASKS_DIR.exists():
|
| 101 |
+
logger.warning(f"Tasks directory not found: {_TASKS_DIR}")
|
| 102 |
+
return tasks
|
| 103 |
+
|
| 104 |
+
for task_path in sorted(_TASKS_DIR.glob("*.json")):
|
| 105 |
+
task_id = task_path.stem
|
| 106 |
+
try:
|
| 107 |
+
with open(task_path) as f:
|
| 108 |
+
task_json = json.load(f)
|
| 109 |
+
|
| 110 |
+
# Ground truth
|
| 111 |
+
gt_path = _GT_DIR / f"{task_id}.json"
|
| 112 |
+
ground_truth = {}
|
| 113 |
+
if gt_path.exists():
|
| 114 |
+
with open(gt_path) as f:
|
| 115 |
+
ground_truth = json.load(f)
|
| 116 |
+
|
| 117 |
+
# Prompt
|
| 118 |
+
prompt_path = _PROMPTS_DIR / f"{task_id}.md"
|
| 119 |
+
prompt_md = ""
|
| 120 |
+
if prompt_path.exists():
|
| 121 |
+
prompt_md = prompt_path.read_text()
|
| 122 |
+
|
| 123 |
+
# PDB data
|
| 124 |
+
pdb_data = None
|
| 125 |
+
pdb_filename = None
|
| 126 |
+
input_pdb = task_json.get("input_pdb") or task_json.get("pdb_file")
|
| 127 |
+
if input_pdb:
|
| 128 |
+
pdb_path = _INPUT_DIR / input_pdb
|
| 129 |
+
if pdb_path.exists():
|
| 130 |
+
pdb_data = base64.b64encode(pdb_path.read_bytes()).decode()
|
| 131 |
+
pdb_filename = input_pdb
|
| 132 |
+
|
| 133 |
+
# Oracle sequences
|
| 134 |
+
oracle_entry = oracle_data.get(task_id, {})
|
| 135 |
+
oracle_seqs = oracle_entry.get("sequences", []) if isinstance(oracle_entry, dict) else []
|
| 136 |
+
|
| 137 |
+
tasks[task_id] = {
|
| 138 |
+
"task_id": task_id,
|
| 139 |
+
"task_json": task_json,
|
| 140 |
+
"ground_truth": ground_truth,
|
| 141 |
+
"prompt_md": prompt_md,
|
| 142 |
+
"pdb_data": pdb_data,
|
| 143 |
+
"pdb_filename": pdb_filename,
|
| 144 |
+
"oracle_sequences": oracle_seqs,
|
| 145 |
+
}
|
| 146 |
+
except Exception as e:
|
| 147 |
+
logger.warning(f"Failed to load task {task_id}: {e}")
|
| 148 |
+
|
| 149 |
+
logger.info(f"Loaded {len(tasks)} tasks from local files")
|
| 150 |
+
return tasks
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
# Public API
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def load_all_tasks() -> dict[str, dict[str, Any]]:
|
| 159 |
+
"""Load all benchmark tasks.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Dict mapping task_id → task data dict.
|
| 163 |
+
"""
|
| 164 |
+
if USE_LOCAL:
|
| 165 |
+
return _load_from_local()
|
| 166 |
+
return _load_from_hf()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_task(task_id: str) -> dict[str, Any] | None:
|
| 170 |
+
"""Load a single task by ID."""
|
| 171 |
+
tasks = load_all_tasks()
|
| 172 |
+
return tasks.get(task_id)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_hidden_task_ids() -> list[str]:
|
| 176 |
+
"""Get the list of hidden (non-public) task IDs."""
|
| 177 |
+
tasks = load_all_tasks()
|
| 178 |
+
return sorted(tid for tid in tasks if tid not in PUBLIC_TASK_IDS)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def get_all_task_ids() -> list[str]:
|
| 182 |
+
"""Get all task IDs (public + hidden)."""
|
| 183 |
+
return sorted(load_all_tasks().keys())
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_public_task_ids() -> list[str]:
|
| 187 |
+
"""Get the 3 public task IDs for development."""
|
| 188 |
+
tasks = load_all_tasks()
|
| 189 |
+
return sorted(tid for tid in tasks if tid in PUBLIC_TASK_IDS)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@lru_cache(maxsize=1)
|
| 193 |
+
def load_tool_schemas() -> list[dict[str, Any]]:
|
| 194 |
+
"""Load the 17 MCP tool schemas for task payloads."""
|
| 195 |
+
if _TOOL_SCHEMAS_PATH.exists():
|
| 196 |
+
with open(_TOOL_SCHEMAS_PATH) as f:
|
| 197 |
+
return json.load(f)
|
| 198 |
+
return []
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def build_task_payload(task_id: str) -> dict[str, Any] | None:
|
| 202 |
+
"""Build the payload to send to a submitter's endpoint.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
Dict with: task_id, task_description, available_tools,
|
| 206 |
+
input_files, design_constraints, max_steps, timeout_sec.
|
| 207 |
+
Returns None if task not found.
|
| 208 |
+
"""
|
| 209 |
+
task = get_task(task_id)
|
| 210 |
+
if task is None:
|
| 211 |
+
return None
|
| 212 |
+
|
| 213 |
+
task_json = task["task_json"]
|
| 214 |
+
prompt = task["prompt_md"]
|
| 215 |
+
|
| 216 |
+
# Build input files (base64-encoded PDBs)
|
| 217 |
+
input_files = {}
|
| 218 |
+
if task.get("pdb_data") and task.get("pdb_filename"):
|
| 219 |
+
input_files[task["pdb_filename"]] = task["pdb_data"]
|
| 220 |
+
|
| 221 |
+
# Extract constraints from task JSON
|
| 222 |
+
constraints = task_json.get("design_constraints", {})
|
| 223 |
+
max_designs = task_json.get("max_designs", 10)
|
| 224 |
+
|
| 225 |
+
return {
|
| 226 |
+
"task_id": task_id,
|
| 227 |
+
"task_description": prompt,
|
| 228 |
+
"available_tools": load_tool_schemas(),
|
| 229 |
+
"input_files": input_files,
|
| 230 |
+
"design_constraints": {
|
| 231 |
+
**constraints,
|
| 232 |
+
"max_designs": max_designs,
|
| 233 |
+
},
|
| 234 |
+
"max_steps": 50,
|
| 235 |
+
"timeout_sec": 300,
|
| 236 |
+
}
|
example_server.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reference FastAPI server for BioDesignBench submitters.
|
| 2 |
+
|
| 3 |
+
This example shows how to implement the API endpoint that BioDesignBench
|
| 4 |
+
will call during benchmarking. Replace the mock agent logic with your
|
| 5 |
+
actual LLM agent + MCP tool pipeline.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
pip install fastapi uvicorn
|
| 9 |
+
python example_server.py
|
| 10 |
+
|
| 11 |
+
# Or with uvicorn directly:
|
| 12 |
+
uvicorn example_server:app --host 0.0.0.0 --port 8000
|
| 13 |
+
|
| 14 |
+
Your endpoint will receive POST requests at /api/run with the task payload.
|
| 15 |
+
|
| 16 |
+
Task Payload Format:
|
| 17 |
+
{
|
| 18 |
+
"task_id": "dnb_sig_001",
|
| 19 |
+
"task_description": "Design a de novo binder for...",
|
| 20 |
+
"available_tools": [... 17 tool schemas ...],
|
| 21 |
+
"input_files": {"7n1j.pdb": "<base64>"},
|
| 22 |
+
"design_constraints": {"length_range": [80, 150], "max_designs": 10},
|
| 23 |
+
"max_steps": 50,
|
| 24 |
+
"timeout_sec": 300
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
Expected Response Format:
|
| 28 |
+
{
|
| 29 |
+
"sequences": ["MKKL...", "MFQR..."],
|
| 30 |
+
"run_log": [{"step": 1, "tool": "suggest_hotspots", "success": true}, ...],
|
| 31 |
+
"total_steps": 12,
|
| 32 |
+
"total_time_sec": 142.5,
|
| 33 |
+
"metrics": {}
|
| 34 |
+
}
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from __future__ import annotations
|
| 38 |
+
|
| 39 |
+
import base64
|
| 40 |
+
import random
|
| 41 |
+
import time
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
from typing import Any
|
| 44 |
+
|
| 45 |
+
from fastapi import FastAPI
|
| 46 |
+
from pydantic import BaseModel
|
| 47 |
+
|
| 48 |
+
app = FastAPI(
|
| 49 |
+
title="BioDesignBench Example Agent",
|
| 50 |
+
description="Reference implementation for benchmark submission",
|
| 51 |
+
version="0.1.0",
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
# Request/Response models
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TaskPayload(BaseModel):
|
| 61 |
+
task_id: str
|
| 62 |
+
task_description: str
|
| 63 |
+
available_tools: list[dict[str, Any]] = []
|
| 64 |
+
input_files: dict[str, str] = {} # filename -> base64 data
|
| 65 |
+
design_constraints: dict[str, Any] = {}
|
| 66 |
+
max_steps: int = 50
|
| 67 |
+
timeout_sec: int = 300
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class AgentResponse(BaseModel):
|
| 71 |
+
sequences: list[str]
|
| 72 |
+
run_log: list[dict[str, Any]]
|
| 73 |
+
total_steps: int
|
| 74 |
+
total_time_sec: float
|
| 75 |
+
metrics: dict[str, Any] = {}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# Mock agent (replace with your real agent)
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
# Standard amino acids for mock sequence generation
|
| 83 |
+
_AAS = "ACDEFGHIKLMNPQRSTVWY"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _generate_mock_sequence(length: int) -> str:
|
| 87 |
+
"""Generate a random protein sequence with reasonable composition."""
|
| 88 |
+
# Weight towards common amino acids
|
| 89 |
+
weights = [
|
| 90 |
+
7, 2, 5, 6, 4, 7, 2, 5, 6, 9, # A C D E F G H I K L
|
| 91 |
+
2, 4, 5, 4, 5, 7, 6, 7, 1, 3, # M N P Q R S T V W Y
|
| 92 |
+
]
|
| 93 |
+
return "".join(random.choices(_AAS, weights=weights, k=length))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def mock_agent(payload: TaskPayload) -> AgentResponse:
|
| 97 |
+
"""Mock agent that generates random but valid designs.
|
| 98 |
+
|
| 99 |
+
Replace this with your actual LLM agent + MCP tool pipeline.
|
| 100 |
+
This mock demonstrates the expected response format.
|
| 101 |
+
"""
|
| 102 |
+
start = time.monotonic()
|
| 103 |
+
|
| 104 |
+
# Determine design parameters
|
| 105 |
+
constraints = payload.design_constraints
|
| 106 |
+
length_range = constraints.get("length_range", [80, 150])
|
| 107 |
+
max_designs = constraints.get("max_designs", 10)
|
| 108 |
+
num_designs = min(max_designs, 5) # Generate 5 for this mock
|
| 109 |
+
|
| 110 |
+
# "Decode" input PDB files (in a real agent, you'd use these)
|
| 111 |
+
for filename, b64_data in payload.input_files.items():
|
| 112 |
+
pdb_bytes = base64.b64decode(b64_data)
|
| 113 |
+
# In a real agent: save to temp file and pass to MCP tools
|
| 114 |
+
|
| 115 |
+
# Simulate a multi-step design pipeline
|
| 116 |
+
run_log = [
|
| 117 |
+
{
|
| 118 |
+
"step": 1,
|
| 119 |
+
"tool": "suggest_hotspots",
|
| 120 |
+
"success": True,
|
| 121 |
+
"args_summary": {"target": "from_pdb"},
|
| 122 |
+
"output_summary": "Found 5 hotspot residues",
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"step": 2,
|
| 126 |
+
"tool": "generate_backbone",
|
| 127 |
+
"success": True,
|
| 128 |
+
"args_summary": {"length": length_range[0]},
|
| 129 |
+
"output_summary": f"Generated {num_designs} backbones",
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"step": 3,
|
| 133 |
+
"tool": "optimize_sequence",
|
| 134 |
+
"success": True,
|
| 135 |
+
"args_summary": {"optimization_target": "both"},
|
| 136 |
+
"output_summary": f"Optimized {num_designs} sequences",
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"step": 4,
|
| 140 |
+
"tool": "predict_structure",
|
| 141 |
+
"success": True,
|
| 142 |
+
"args_summary": {"predictor": "esmfold"},
|
| 143 |
+
"output_summary": "Predicted structures for all designs",
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"step": 5,
|
| 147 |
+
"tool": "validate_design",
|
| 148 |
+
"success": True,
|
| 149 |
+
"args_summary": {},
|
| 150 |
+
"output_summary": "Validated all designs",
|
| 151 |
+
},
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
# Generate mock sequences
|
| 155 |
+
min_len, max_len = length_range
|
| 156 |
+
sequences = [
|
| 157 |
+
_generate_mock_sequence(random.randint(min_len, max_len))
|
| 158 |
+
for _ in range(num_designs)
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
elapsed = time.monotonic() - start
|
| 162 |
+
|
| 163 |
+
return AgentResponse(
|
| 164 |
+
sequences=sequences,
|
| 165 |
+
run_log=run_log,
|
| 166 |
+
total_steps=len(run_log),
|
| 167 |
+
total_time_sec=round(elapsed, 2),
|
| 168 |
+
metrics={}, # Agent-reported metrics (optional)
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
# API endpoint
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@app.post("/api/run", response_model=AgentResponse)
|
| 178 |
+
async def run_task(payload: TaskPayload) -> AgentResponse:
|
| 179 |
+
"""Run a single benchmark task.
|
| 180 |
+
|
| 181 |
+
This is the endpoint that BioDesignBench will POST to during benchmarking.
|
| 182 |
+
Replace mock_agent() with your actual agent logic.
|
| 183 |
+
"""
|
| 184 |
+
return mock_agent(payload)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@app.get("/health")
|
| 188 |
+
async def health():
|
| 189 |
+
"""Health check endpoint."""
|
| 190 |
+
return {"status": "ok", "agent": "example-mock"}
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ---------------------------------------------------------------------------
|
| 194 |
+
# Entry point
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
|
| 197 |
+
if __name__ == "__main__":
|
| 198 |
+
import uvicorn
|
| 199 |
+
|
| 200 |
+
print("Starting BioDesignBench example server...")
|
| 201 |
+
print("POST endpoint: http://localhost:8000/api/run")
|
| 202 |
+
print("Health check: http://localhost:8000/health")
|
| 203 |
+
print()
|
| 204 |
+
print("Replace mock_agent() with your real agent logic.")
|
| 205 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
mcp_tool_schemas.json
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"name": "design_binder",
|
| 4 |
+
"description": "Design protein binders for a target protein. Runs RFdiffusion -> ProteinMPNN -> ESMFold pipeline.",
|
| 5 |
+
"parameters": {
|
| 6 |
+
"type": "object",
|
| 7 |
+
"properties": {
|
| 8 |
+
"target_pdb": {
|
| 9 |
+
"type": "string",
|
| 10 |
+
"description": "Path to target protein PDB file"
|
| 11 |
+
},
|
| 12 |
+
"hotspot_residues": {
|
| 13 |
+
"type": "array",
|
| 14 |
+
"items": {
|
| 15 |
+
"type": "string"
|
| 16 |
+
},
|
| 17 |
+
"description": "Target residues for binder interface, e.g. ['A45', 'A46']"
|
| 18 |
+
},
|
| 19 |
+
"num_designs": {
|
| 20 |
+
"type": "integer",
|
| 21 |
+
"description": "Number of designs to generate (default: 10)",
|
| 22 |
+
"default": 10
|
| 23 |
+
},
|
| 24 |
+
"binder_length": {
|
| 25 |
+
"type": "integer",
|
| 26 |
+
"description": "Binder length in residues (default: 80)",
|
| 27 |
+
"default": 80
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"required": [
|
| 31 |
+
"target_pdb",
|
| 32 |
+
"hotspot_residues"
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"name": "analyze_interface",
|
| 38 |
+
"description": "Analyze protein-protein interface: buried surface area, H-bonds, salt bridges, hydrophobic contacts.",
|
| 39 |
+
"parameters": {
|
| 40 |
+
"type": "object",
|
| 41 |
+
"properties": {
|
| 42 |
+
"complex_pdb": {
|
| 43 |
+
"type": "string",
|
| 44 |
+
"description": "Path to complex PDB file"
|
| 45 |
+
},
|
| 46 |
+
"chain_a": {
|
| 47 |
+
"type": "string",
|
| 48 |
+
"description": "Chain ID of first protein"
|
| 49 |
+
},
|
| 50 |
+
"chain_b": {
|
| 51 |
+
"type": "string",
|
| 52 |
+
"description": "Chain ID of second protein"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"required": [
|
| 56 |
+
"complex_pdb",
|
| 57 |
+
"chain_a",
|
| 58 |
+
"chain_b"
|
| 59 |
+
]
|
| 60 |
+
}
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"name": "validate_design",
|
| 64 |
+
"description": "Validate a designed sequence by predicting its structure (ESMFold/AlphaFold2) and computing pLDDT, pTM.",
|
| 65 |
+
"parameters": {
|
| 66 |
+
"type": "object",
|
| 67 |
+
"properties": {
|
| 68 |
+
"sequence": {
|
| 69 |
+
"type": "string",
|
| 70 |
+
"description": "Amino acid sequence to validate"
|
| 71 |
+
},
|
| 72 |
+
"expected_structure": {
|
| 73 |
+
"type": "string",
|
| 74 |
+
"description": "Optional PDB path for RMSD comparison"
|
| 75 |
+
},
|
| 76 |
+
"predictor": {
|
| 77 |
+
"type": "string",
|
| 78 |
+
"enum": [
|
| 79 |
+
"esmfold",
|
| 80 |
+
"alphafold2"
|
| 81 |
+
],
|
| 82 |
+
"default": "esmfold",
|
| 83 |
+
"description": "Structure predictor to use"
|
| 84 |
+
}
|
| 85 |
+
},
|
| 86 |
+
"required": [
|
| 87 |
+
"sequence"
|
| 88 |
+
]
|
| 89 |
+
}
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"name": "optimize_sequence",
|
| 93 |
+
"description": "Optimize binder sequence for improved stability and/or binding affinity.",
|
| 94 |
+
"parameters": {
|
| 95 |
+
"type": "object",
|
| 96 |
+
"properties": {
|
| 97 |
+
"current_sequence": {
|
| 98 |
+
"type": "string",
|
| 99 |
+
"description": "Starting amino acid sequence"
|
| 100 |
+
},
|
| 101 |
+
"target_pdb": {
|
| 102 |
+
"type": "string",
|
| 103 |
+
"description": "Path to target protein PDB"
|
| 104 |
+
},
|
| 105 |
+
"optimization_target": {
|
| 106 |
+
"type": "string",
|
| 107 |
+
"enum": [
|
| 108 |
+
"stability",
|
| 109 |
+
"affinity",
|
| 110 |
+
"both"
|
| 111 |
+
],
|
| 112 |
+
"default": "both"
|
| 113 |
+
},
|
| 114 |
+
"fixed_positions": {
|
| 115 |
+
"type": "array",
|
| 116 |
+
"items": {
|
| 117 |
+
"type": "integer"
|
| 118 |
+
},
|
| 119 |
+
"description": "Positions to keep fixed (1-indexed)"
|
| 120 |
+
}
|
| 121 |
+
},
|
| 122 |
+
"required": [
|
| 123 |
+
"current_sequence",
|
| 124 |
+
"target_pdb"
|
| 125 |
+
]
|
| 126 |
+
}
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"name": "suggest_hotspots",
|
| 130 |
+
"description": "Analyze target protein and suggest binding hotspots using structure, conservation, and literature.",
|
| 131 |
+
"parameters": {
|
| 132 |
+
"type": "object",
|
| 133 |
+
"properties": {
|
| 134 |
+
"target": {
|
| 135 |
+
"type": "string",
|
| 136 |
+
"description": "Protein name, UniProt ID, PDB ID, or local PDB path"
|
| 137 |
+
},
|
| 138 |
+
"chain_id": {
|
| 139 |
+
"type": "string",
|
| 140 |
+
"description": "Chain to analyze (default: first)"
|
| 141 |
+
},
|
| 142 |
+
"criteria": {
|
| 143 |
+
"type": "string",
|
| 144 |
+
"enum": [
|
| 145 |
+
"druggable",
|
| 146 |
+
"exposed",
|
| 147 |
+
"conserved"
|
| 148 |
+
],
|
| 149 |
+
"default": "exposed"
|
| 150 |
+
},
|
| 151 |
+
"include_literature": {
|
| 152 |
+
"type": "boolean",
|
| 153 |
+
"default": false,
|
| 154 |
+
"description": "Search PubMed for known binders"
|
| 155 |
+
}
|
| 156 |
+
},
|
| 157 |
+
"required": [
|
| 158 |
+
"target"
|
| 159 |
+
]
|
| 160 |
+
}
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"name": "get_design_status",
|
| 164 |
+
"description": "Check status of running design jobs.",
|
| 165 |
+
"parameters": {
|
| 166 |
+
"type": "object",
|
| 167 |
+
"properties": {
|
| 168 |
+
"job_id": {
|
| 169 |
+
"type": "string",
|
| 170 |
+
"description": "Job ID from design_binder call"
|
| 171 |
+
}
|
| 172 |
+
},
|
| 173 |
+
"required": [
|
| 174 |
+
"job_id"
|
| 175 |
+
]
|
| 176 |
+
}
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"name": "predict_complex",
|
| 180 |
+
"description": "Predict protein complex structure using AlphaFold2-Multimer.",
|
| 181 |
+
"parameters": {
|
| 182 |
+
"type": "object",
|
| 183 |
+
"properties": {
|
| 184 |
+
"sequences": {
|
| 185 |
+
"type": "array",
|
| 186 |
+
"items": {
|
| 187 |
+
"type": "string"
|
| 188 |
+
},
|
| 189 |
+
"description": "List of sequences, one per chain"
|
| 190 |
+
},
|
| 191 |
+
"chain_names": {
|
| 192 |
+
"type": "array",
|
| 193 |
+
"items": {
|
| 194 |
+
"type": "string"
|
| 195 |
+
},
|
| 196 |
+
"description": "Optional chain identifiers"
|
| 197 |
+
}
|
| 198 |
+
},
|
| 199 |
+
"required": [
|
| 200 |
+
"sequences"
|
| 201 |
+
]
|
| 202 |
+
}
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"name": "predict_structure",
|
| 206 |
+
"description": "Predict the 3D structure of a single protein chain using ESMFold or AlphaFold2. Returns predicted PDB, pLDDT, and pTM scores.",
|
| 207 |
+
"parameters": {
|
| 208 |
+
"type": "object",
|
| 209 |
+
"properties": {
|
| 210 |
+
"sequence": {
|
| 211 |
+
"type": "string",
|
| 212 |
+
"description": "Amino acid sequence to predict structure for"
|
| 213 |
+
},
|
| 214 |
+
"predictor": {
|
| 215 |
+
"type": "string",
|
| 216 |
+
"enum": [
|
| 217 |
+
"esmfold",
|
| 218 |
+
"alphafold2"
|
| 219 |
+
],
|
| 220 |
+
"default": "esmfold",
|
| 221 |
+
"description": "Structure predictor to use"
|
| 222 |
+
}
|
| 223 |
+
},
|
| 224 |
+
"required": [
|
| 225 |
+
"sequence"
|
| 226 |
+
]
|
| 227 |
+
}
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"name": "score_stability",
|
| 231 |
+
"description": "Score protein stability using ESM2 pseudo-log-likelihood. Optionally compute per-mutation effects (delta log-likelihood).",
|
| 232 |
+
"parameters": {
|
| 233 |
+
"type": "object",
|
| 234 |
+
"properties": {
|
| 235 |
+
"sequence": {
|
| 236 |
+
"type": "string",
|
| 237 |
+
"description": "Amino acid sequence to score"
|
| 238 |
+
},
|
| 239 |
+
"mutations": {
|
| 240 |
+
"type": "array",
|
| 241 |
+
"items": {
|
| 242 |
+
"type": "string"
|
| 243 |
+
},
|
| 244 |
+
"description": "Optional mutations in 'X42Y' format for delta scoring"
|
| 245 |
+
},
|
| 246 |
+
"reference_sequence": {
|
| 247 |
+
"type": "string",
|
| 248 |
+
"description": "Optional wild-type sequence for mutation scoring"
|
| 249 |
+
}
|
| 250 |
+
},
|
| 251 |
+
"required": [
|
| 252 |
+
"sequence"
|
| 253 |
+
]
|
| 254 |
+
}
|
| 255 |
+
},
|
| 256 |
+
{
|
| 257 |
+
"name": "energy_minimize",
|
| 258 |
+
"description": "Energy-minimize a protein structure using OpenMM with AMBER14 force field. Returns minimized PDB, energy change, and RMSD from initial structure.",
|
| 259 |
+
"parameters": {
|
| 260 |
+
"type": "object",
|
| 261 |
+
"properties": {
|
| 262 |
+
"pdb_path": {
|
| 263 |
+
"type": "string",
|
| 264 |
+
"description": "Path to input PDB file"
|
| 265 |
+
},
|
| 266 |
+
"force_field": {
|
| 267 |
+
"type": "string",
|
| 268 |
+
"default": "amber14-all.xml",
|
| 269 |
+
"description": "OpenMM force field XML"
|
| 270 |
+
},
|
| 271 |
+
"num_steps": {
|
| 272 |
+
"type": "integer",
|
| 273 |
+
"default": 500,
|
| 274 |
+
"description": "Maximum minimization iterations"
|
| 275 |
+
},
|
| 276 |
+
"solvent": {
|
| 277 |
+
"type": "string",
|
| 278 |
+
"enum": [
|
| 279 |
+
"implicit",
|
| 280 |
+
"none"
|
| 281 |
+
],
|
| 282 |
+
"default": "implicit",
|
| 283 |
+
"description": "Solvent model: implicit (GBn2) or none (vacuum)"
|
| 284 |
+
}
|
| 285 |
+
},
|
| 286 |
+
"required": [
|
| 287 |
+
"pdb_path"
|
| 288 |
+
]
|
| 289 |
+
}
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"name": "generate_backbone",
|
| 293 |
+
"description": "Generate de novo protein backbones using RFdiffusion unconditional generation. No target protein required.",
|
| 294 |
+
"parameters": {
|
| 295 |
+
"type": "object",
|
| 296 |
+
"properties": {
|
| 297 |
+
"length": {
|
| 298 |
+
"type": "integer",
|
| 299 |
+
"description": "Backbone length in residues"
|
| 300 |
+
},
|
| 301 |
+
"num_designs": {
|
| 302 |
+
"type": "integer",
|
| 303 |
+
"default": 10,
|
| 304 |
+
"description": "Number of designs to generate"
|
| 305 |
+
}
|
| 306 |
+
},
|
| 307 |
+
"required": [
|
| 308 |
+
"length"
|
| 309 |
+
]
|
| 310 |
+
}
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"name": "rosetta_score",
|
| 314 |
+
"description": "Score a protein structure using Rosetta energy function (ref2015). Returns total score, per-residue energies, and energy breakdown.",
|
| 315 |
+
"parameters": {
|
| 316 |
+
"type": "object",
|
| 317 |
+
"properties": {
|
| 318 |
+
"pdb_path": {
|
| 319 |
+
"type": "string",
|
| 320 |
+
"description": "Path to input PDB file"
|
| 321 |
+
},
|
| 322 |
+
"score_function": {
|
| 323 |
+
"type": "string",
|
| 324 |
+
"default": "ref2015",
|
| 325 |
+
"description": "Rosetta score function name"
|
| 326 |
+
}
|
| 327 |
+
},
|
| 328 |
+
"required": [
|
| 329 |
+
"pdb_path"
|
| 330 |
+
]
|
| 331 |
+
}
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"name": "rosetta_relax",
|
| 335 |
+
"description": "Relax a protein structure using Rosetta FastRelax. Returns relaxed PDB, energy change, and CA-RMSD.",
|
| 336 |
+
"parameters": {
|
| 337 |
+
"type": "object",
|
| 338 |
+
"properties": {
|
| 339 |
+
"pdb_path": {
|
| 340 |
+
"type": "string",
|
| 341 |
+
"description": "Path to input PDB file"
|
| 342 |
+
},
|
| 343 |
+
"nstruct": {
|
| 344 |
+
"type": "integer",
|
| 345 |
+
"default": 1,
|
| 346 |
+
"description": "Number of relaxation trajectories"
|
| 347 |
+
},
|
| 348 |
+
"score_function": {
|
| 349 |
+
"type": "string",
|
| 350 |
+
"default": "ref2015",
|
| 351 |
+
"description": "Rosetta score function name"
|
| 352 |
+
}
|
| 353 |
+
},
|
| 354 |
+
"required": [
|
| 355 |
+
"pdb_path"
|
| 356 |
+
]
|
| 357 |
+
}
|
| 358 |
+
},
|
| 359 |
+
{
|
| 360 |
+
"name": "rosetta_interface_score",
|
| 361 |
+
"description": "Compute interface energy metrics for a protein complex using Rosetta. Returns dG_separated, dSASA, interface hbonds, and packing stats.",
|
| 362 |
+
"parameters": {
|
| 363 |
+
"type": "object",
|
| 364 |
+
"properties": {
|
| 365 |
+
"pdb_path": {
|
| 366 |
+
"type": "string",
|
| 367 |
+
"description": "Path to complex PDB file"
|
| 368 |
+
},
|
| 369 |
+
"chains": {
|
| 370 |
+
"type": "string",
|
| 371 |
+
"default": "A_B",
|
| 372 |
+
"description": "Chain grouping, e.g. 'A_B' or 'AB_C'"
|
| 373 |
+
},
|
| 374 |
+
"score_function": {
|
| 375 |
+
"type": "string",
|
| 376 |
+
"default": "ref2015",
|
| 377 |
+
"description": "Rosetta score function name"
|
| 378 |
+
}
|
| 379 |
+
},
|
| 380 |
+
"required": [
|
| 381 |
+
"pdb_path"
|
| 382 |
+
]
|
| 383 |
+
}
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"name": "rosetta_design",
|
| 387 |
+
"description": "Fixed-backbone sequence design using Rosetta PackRotamers + MinMover. Composite convenience tool (hidden in benchmark mode).",
|
| 388 |
+
"parameters": {
|
| 389 |
+
"type": "object",
|
| 390 |
+
"properties": {
|
| 391 |
+
"pdb_path": {
|
| 392 |
+
"type": "string",
|
| 393 |
+
"description": "Path to input PDB file"
|
| 394 |
+
},
|
| 395 |
+
"chains": {
|
| 396 |
+
"type": "string",
|
| 397 |
+
"default": "A_B",
|
| 398 |
+
"description": "Chain grouping for interface detection"
|
| 399 |
+
},
|
| 400 |
+
"fixed_positions": {
|
| 401 |
+
"type": "array",
|
| 402 |
+
"items": {
|
| 403 |
+
"type": "integer"
|
| 404 |
+
},
|
| 405 |
+
"description": "1-indexed positions to keep fixed"
|
| 406 |
+
},
|
| 407 |
+
"score_function": {
|
| 408 |
+
"type": "string",
|
| 409 |
+
"default": "ref2015",
|
| 410 |
+
"description": "Rosetta score function name"
|
| 411 |
+
}
|
| 412 |
+
},
|
| 413 |
+
"required": [
|
| 414 |
+
"pdb_path"
|
| 415 |
+
]
|
| 416 |
+
}
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"name": "predict_structure_boltz",
|
| 420 |
+
"description": "Predict protein structure using Boltz (fast alternative to AF2/ESMFold). Returns predicted PDB, pLDDT, and pTM scores.",
|
| 421 |
+
"parameters": {
|
| 422 |
+
"type": "object",
|
| 423 |
+
"properties": {
|
| 424 |
+
"sequence": {
|
| 425 |
+
"type": "string",
|
| 426 |
+
"description": "Amino acid sequence to predict structure for"
|
| 427 |
+
},
|
| 428 |
+
"model": {
|
| 429 |
+
"type": "string",
|
| 430 |
+
"default": "boltz2",
|
| 431 |
+
"description": "Model name (default: boltz2)"
|
| 432 |
+
},
|
| 433 |
+
"num_samples": {
|
| 434 |
+
"type": "integer",
|
| 435 |
+
"default": 1,
|
| 436 |
+
"description": "Number of structure samples"
|
| 437 |
+
}
|
| 438 |
+
},
|
| 439 |
+
"required": [
|
| 440 |
+
"sequence"
|
| 441 |
+
]
|
| 442 |
+
}
|
| 443 |
+
},
|
| 444 |
+
{
|
| 445 |
+
"name": "predict_affinity_boltz",
|
| 446 |
+
"description": "Predict binding affinity for a protein complex using Boltz. Returns affinity score, predicted structure, and confidence metrics.",
|
| 447 |
+
"parameters": {
|
| 448 |
+
"type": "object",
|
| 449 |
+
"properties": {
|
| 450 |
+
"sequences": {
|
| 451 |
+
"type": "array",
|
| 452 |
+
"items": {
|
| 453 |
+
"type": "string"
|
| 454 |
+
},
|
| 455 |
+
"description": "List of amino acid sequences, one per chain"
|
| 456 |
+
},
|
| 457 |
+
"model": {
|
| 458 |
+
"type": "string",
|
| 459 |
+
"default": "boltz2",
|
| 460 |
+
"description": "Model name (default: boltz2)"
|
| 461 |
+
}
|
| 462 |
+
},
|
| 463 |
+
"required": [
|
| 464 |
+
"sequences"
|
| 465 |
+
]
|
| 466 |
+
}
|
| 467 |
+
}
|
| 468 |
+
]
|
requirements.txt
CHANGED
|
@@ -1,3 +1,8 @@
|
|
| 1 |
-
gradio>=
|
| 2 |
pandas
|
| 3 |
plotly
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.6
|
| 2 |
pandas
|
| 3 |
plotly
|
| 4 |
+
httpx>=0.25
|
| 5 |
+
huggingface_hub>=0.20
|
| 6 |
+
datasets>=2.16
|
| 7 |
+
boltz>=0.4
|
| 8 |
+
pyaudioop
|