# File size: 3,850 Bytes
# 9627ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json

import boto3

from src.data.esm.sdk.forge import (
    ESM3ForgeInferenceClient,
    SequenceStructureForgeInferenceClient,
)


class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
    """Folding / inverse-folding client that sends requests to a SageMaker endpoint."""

    def __init__(self, endpoint_name: str, model: str | None = None):
        """SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
            model: Optional model name; forwarded to the Forge base client.
        """
        # Dummy URL and token to make SequenceStructureForgeInferenceClient happy;
        # requests are sent through SageMaker (see ``_post``), not to ``url``.
        super().__init__(url="", model=model, token="dummy")

        self._endpoint_name = endpoint_name

        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        """Invoke the SageMaker endpoint with ``request`` and return the unwrapped payload.

        Args:
            endpoint: Logical endpoint name. It is sent as the top-level ``"endpoint"``
                field and is also the key that wraps the request and the response.
            request: JSON-serializable request dict. It is not modified; a shallow
                copy is sent.
            potential_sequence_of_concern: Value stored under
                ``"potential_sequence_of_concern"`` in the wrapped request.

        Returns:
            The value stored under ``endpoint`` in the decoded JSON response.

        Raises:
            RuntimeError: If ``invoke_endpoint`` fails, or the response's
                ``"endpoint"`` field does not match the requested endpoint.
        """
        # Copy so the caller's dict is left untouched.
        payload = dict(request)
        payload["potential_sequence_of_concern"] = potential_sequence_of_concern
        # Always send an explicit "model" key (None when the caller did not set one).
        payload["model"] = payload.get("model", None)
        invocations_request = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": payload["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: payload,
        }

        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(invocations_request),
            )
        except Exception as e:
            raise RuntimeError(f"Failure in {endpoint}: {e}") from e

        data = json.loads(response["Body"].read().decode())

        # Response must match request. Use an explicit check rather than ``assert``,
        # which is removed when Python runs with -O.
        response_endpoint = data.get("endpoint")
        if response_endpoint != endpoint:
            raise RuntimeError(
                f"Response endpoint is {response_endpoint} but request is {endpoint}"
            )

        # Get the actual responses under the endpoint key.
        return data[endpoint]


class ESM3SageMakerClient(ESM3ForgeInferenceClient):
    """ESM3 client that sends requests to a SageMaker endpoint."""

    def __init__(self, endpoint_name: str, model: str):
        """ESM3 client that talks to a SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
            model: Name of the ESM3 model.
        """
        # Dummy URL and token to make ESM3ForgeInferenceClient happy;
        # requests are sent through SageMaker (see ``_post``), not to ``url``.
        super().__init__(model=model, url="", token="dummy")

        self._endpoint_name = endpoint_name
        # Kept locally so ``_post`` can fall back to it when a request omits "model".
        self._model = model

        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        """Invoke the SageMaker endpoint with ``request`` and return the unwrapped payload.

        Args:
            endpoint: Logical endpoint name. It is sent as the top-level ``"endpoint"``
                field and is also the key that wraps the request and the response.
            request: JSON-serializable request dict. It is not modified; a shallow
                copy is sent. If it has no ``"model"`` key, the client's model is used.
            potential_sequence_of_concern: Value stored under
                ``"potential_sequence_of_concern"`` in the wrapped request.

        Returns:
            The value stored under ``endpoint`` in the decoded JSON response.

        Raises:
            RuntimeError: If ``invoke_endpoint`` fails, or the response's
                ``"endpoint"`` field does not match the requested endpoint.
        """
        # Copy so the caller's dict is left untouched.
        payload = dict(request)
        payload["potential_sequence_of_concern"] = potential_sequence_of_concern
        # Previously a missing "model" raised KeyError; fall back to the
        # model this client was constructed with instead.
        payload.setdefault("model", self._model)

        invocations_request = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": payload["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: payload,
        }

        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(invocations_request),
            )
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise RuntimeError(f"Failure in {endpoint}: {e}") from e

        data = json.loads(response["Body"].read().decode())

        # Response must match request. Use an explicit check rather than ``assert``,
        # which is removed when Python runs with -O.
        response_endpoint = data.get("endpoint")
        if response_endpoint != endpoint:
            raise RuntimeError(
                f"Response endpoint is {response_endpoint} but request is {endpoint}"
            )

        # Get the actual responses under the endpoint key.
        return data[endpoint]