File size: 3,850 Bytes
9627ce0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | import json
from src.data.esm.sdk.forge import (
ESM3ForgeInferenceClient,
SequenceStructureForgeInferenceClient,
)
class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
    """SequenceStructure client whose transport is a SageMaker endpoint.

    Inherits request building from ``SequenceStructureForgeInferenceClient`` and
    overrides ``_post`` to send requests through the ``sagemaker-runtime``
    ``invoke_endpoint`` API instead of a Forge URL.
    """

    def __init__(self, endpoint_name: str, model: str | None = None):
        """SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
            model: Optional model name. It is passed to the base client and used
                as the ``model`` field for requests that do not carry one.
        """
        # Local import: boto3 is only needed once a client is actually constructed.
        import boto3

        # Dummy URL and token to make SequenceStructureForgeInferenceClient happy.
        super().__init__(url="", model=model, token="dummy")
        self._endpoint_name = endpoint_name
        self._model = model
        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        """Invoke the SageMaker endpoint and return the payload for ``endpoint``.

        The request is wrapped in a Forge-style envelope: top-level ``model``,
        ``request_id``, ``user_id``, ``api_ver`` and ``endpoint`` fields, with the
        request itself stored under the key named by ``endpoint``.

        Args:
            endpoint: Endpoint name. It is used both as the envelope's ``endpoint``
                value and as the key under which the request is sent and the
                response payload is read.
            request: JSON-serializable request dict. It is not modified; a copy
                with ``potential_sequence_of_concern`` and ``model`` set is sent.
            potential_sequence_of_concern: Forwarded in the request under the
                ``potential_sequence_of_concern`` key.

        Returns:
            The value stored under ``endpoint`` in the decoded JSON response.

        Raises:
            RuntimeError: If ``invoke_endpoint`` (or serializing the request)
                fails, or if the response's ``endpoint`` does not match the request.
        """
        # Work on a copy so the caller's dict is not mutated as a side effect.
        request = {
            **request,
            "potential_sequence_of_concern": potential_sequence_of_concern,
            "model": request.get("model", self._model),
        }
        invocations_request = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": request["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: request,
        }
        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(invocations_request),
            )
        except Exception as e:
            raise RuntimeError(f"Failure in {endpoint}: {e}") from e
        data = json.loads(response["Body"].read().decode())
        # Response must match request. Checked explicitly (not with `assert`) so
        # the validation is not stripped when Python runs with -O.
        response_endpoint = data.get("endpoint")
        if response_endpoint != endpoint:
            raise RuntimeError(
                f"Response endpoint is {response_endpoint} but request is {endpoint}"
            )
        # Get the actual responses under the endpoint key.
        return data[endpoint]
class ESM3SageMakerClient(ESM3ForgeInferenceClient):
    """ESM3 client whose transport is a SageMaker endpoint.

    Inherits request building from ``ESM3ForgeInferenceClient`` and overrides
    ``_post`` to send requests through the ``sagemaker-runtime``
    ``invoke_endpoint`` API instead of a Forge URL.
    """

    def __init__(self, endpoint_name: str, model: str):
        """ESM3 client that talks to a SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
            model: Name of the ESM3 model. It is also used as the ``model`` field
                for requests that do not carry one.
        """
        # Local import: boto3 is only needed once a client is actually constructed.
        import boto3

        # Dummy URL and token to make ESM3ForgeInferenceClient happy.
        super().__init__(model=model, url="", token="dummy")
        self._endpoint_name = endpoint_name
        self._model = model
        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        """Invoke the SageMaker endpoint and return the payload for ``endpoint``.

        The request is wrapped in a Forge-style envelope: top-level ``model``,
        ``request_id``, ``user_id``, ``api_ver`` and ``endpoint`` fields, with the
        request itself stored under the key named by ``endpoint``.

        Args:
            endpoint: Endpoint name. It is used both as the envelope's ``endpoint``
                value and as the key under which the request is sent and the
                response payload is read.
            request: JSON-serializable request dict. It is not modified; a copy
                with ``potential_sequence_of_concern`` set (and ``model`` filled in
                from the client when absent) is sent.
            potential_sequence_of_concern: Forwarded in the request under the
                ``potential_sequence_of_concern`` key.

        Returns:
            The value stored under ``endpoint`` in the decoded JSON response.

        Raises:
            RuntimeError: If ``invoke_endpoint`` (or serializing the request)
                fails, or if the response's ``endpoint`` does not match the request.
        """
        # Work on a copy so the caller's dict is not mutated as a side effect.
        # Fall back to the client's model instead of raising KeyError when the
        # request does not specify one.
        request = {
            **request,
            "potential_sequence_of_concern": potential_sequence_of_concern,
            "model": request.get("model", self._model),
        }
        invocations_request = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": request["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: request,
        }
        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(invocations_request),
            )
        except Exception as e:
            # Chain the cause so the original boto/serialization error is preserved.
            raise RuntimeError(f"Failure in {endpoint}: {e}") from e
        data = json.loads(response["Body"].read().decode())
        # Response must match request. Checked explicitly (not with `assert`) so
        # the validation is not stripped when Python runs with -O.
        response_endpoint = data.get("endpoint")
        if response_endpoint != endpoint:
            raise RuntimeError(
                f"Response endpoint is {response_endpoint} but request is {endpoint}"
            )
        # Get the actual responses under the endpoint key.
        return data[endpoint]
|