File size: 1,008 Bytes
9627ce0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import os
from src.data.esm.sdk.forge import ESM3ForgeInferenceClient
from src.data.esm.utils.forge_context_manager import ForgeBatchExecutor
# Note: do not import ESM3SageMakerClient here, since it requires the AWS SDK.
def client(
    model="esm3-sm-open-v1",
    url="https://forge.evolutionaryscale.ai",
    token=None,
    request_timeout=None,
):
    """Create an ESM3 Forge inference client.

    Args:
        model: Name of the model to use.
        url: URL of a forge server.
        token: User's API token. If omitted (``None``), the ``ESM_API_KEY``
            environment variable is read at call time, falling back to
            ``""`` when it is unset.
        request_timeout: Amount of time to wait for a request to finish.
            Default is wait indefinitely.

    Returns:
        An ``ESM3ForgeInferenceClient`` built from the arguments above.
    """
    # Resolve the token per call instead of in the signature: a default
    # expression is evaluated only once, at import time, so a key exported
    # after this module was imported would otherwise be ignored.
    if token is None:
        token = os.environ.get("ESM_API_KEY", "")
    return ESM3ForgeInferenceClient(model, url, token, request_timeout)
def batch_executor(max_attempts: int = 10):
    """Build a ``ForgeBatchExecutor`` for running batched Forge calls.

    Args:
        max_attempts: Maximum number of attempts to make before giving up.

    Returns:
        A ``ForgeBatchExecutor`` constructed with ``max_attempts``, meant to
        be used as a context manager.

    Example:
        with batch_executor() as executor:
            executor.execute_batch(fn, **kwargs)
    """
    executor = ForgeBatchExecutor(max_attempts)
    return executor
|