nas / PFMBench /src /data /esm /utils /forge_context_manager.py
yuccaaa's picture
Add files using upload-large-folder tool
9627ce0 verified
import threading
from collections import deque
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from contextvars import copy_context
from typing import Any, Callable, Dict, List
from tqdm import tqdm
from src.data.esm.sdk.api import ESMProteinError
from src.data.esm.sdk.forge import (
retry_if_specific_error,
skip_retries_var,
)
# Layout string handed to tqdm's ``bar_format`` option by the batch executor.
# One literal per visual field: description + percentage + bar, the
# done/total counter, the timing block, and the free-form postfix.
TQDM_BAR_FORMAT = (
    "{desc:<12}{percentage:3.0f}%|{bar:24}| "
    "{n_fmt}/{total_fmt} "
    "[Elapsed: {elapsed} | Remaining: {remaining}] "
    "{postfix}"
)
class AIMDRateLimiter:
    """Rate limiter with AIMD (Additive Increase/Multiplicative Decrease) control.

    The limiter tracks a single integer, ``concurrency``. Each call to
    :meth:`adjust_concurrency` either adds ``step_up`` (no error seen) or
    integer-divides by ``decrease_factor`` (error seen), always keeping the
    value inside ``[min_concurrency, max_concurrency]``.
    """

    def __init__(
        self,
        initial_concurrency: int = 32,
        min_concurrency: int = 1,
        max_concurrency: int = 512,
        step_up: int = 1,
        decrease_factor: int = 2,
    ):
        """Configure the limiter.

        Args:
            initial_concurrency: Starting value; clamped into
                ``[min_concurrency, max_concurrency]``.
            min_concurrency: Lower bound the value never drops below.
            max_concurrency: Upper bound the value never exceeds.
            step_up: Amount added on every error-free adjustment.
            decrease_factor: Divisor applied (with floor division) on every
                adjustment that saw an error. Defaults to 2 (halving).

        Raises:
            ValueError: If ``min_concurrency > max_concurrency`` or
                ``decrease_factor <= 1``.
        """
        if min_concurrency > max_concurrency:
            raise ValueError(
                "min_concurrency must not be greater than max_concurrency"
            )
        if decrease_factor <= 1:
            raise ValueError("decrease_factor must be greater than 1")

        self.min_concurrency = min_concurrency
        self.max_concurrency = max_concurrency
        self.step_up = step_up
        self.decrease_factor = decrease_factor
        # Start inside the configured bounds so the very first scheduling
        # round already honours them.
        self.concurrency = min(
            max_concurrency, max(min_concurrency, initial_concurrency)
        )
        self._lock = threading.Lock()

    def adjust_concurrency(self, error_seen: bool) -> int:
        """Update concurrency based on if an error is seen.

        Args:
            error_seen: ``True`` to apply the multiplicative decrease,
                ``False`` to apply the additive increase.

        Returns:
            The concurrency value after the adjustment.
        """
        # The read-modify-write is done under the lock so concurrent callers
        # cannot lose an update.
        with self._lock:
            if error_seen:
                self.concurrency = max(
                    self.min_concurrency, self.concurrency // self.decrease_factor
                )
            else:
                self.concurrency = min(
                    self.max_concurrency, self.concurrency + self.step_up
                )
            return self.concurrency
class ForgeBatchExecutor:
    """Context manager for managing concurrent calls with rate limiting.

    While the context is active ``skip_retries_var`` is set to ``True``; the
    previous value is restored on exit. Retrying is then handled by
    :meth:`execute_batch`, which re-queues failed tasks itself and lets an
    :class:`AIMDRateLimiter` decide how many calls may be in flight.
    """

    def __init__(self, max_attempts: int = 10):
        """Create the rate limiter and the worker pool.

        Args:
            max_attempts: Maximum number of times one task is attempted
                (the first try included) before its error is kept as the
                final result.
        """
        self.rate_limiter = AIMDRateLimiter()
        self.max_attempts = max_attempts
        # The pool is sized for the limiter's ceiling; the number of calls
        # actually in flight is throttled inside execute_batch.
        self._executor = ThreadPoolExecutor(
            max_workers=self.rate_limiter.max_concurrency
        )
        self._skip_retries_token = None

    def __enter__(self):
        """Set ``skip_retries_var`` to ``True`` and return the executor."""
        self._skip_retries_token = skip_retries_var.set(True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore ``skip_retries_var`` and shut the worker pool down.

        The pool is shut down even if restoring the context variable raises,
        and the token is cleared first so it is never passed to ``reset``
        twice.
        """
        token, self._skip_retries_token = self._skip_retries_token, None
        try:
            if token is not None:
                skip_retries_var.reset(token)
        finally:
            if self._executor:
                self._executor.shutdown(wait=True)

    def _validate_inputs(self, inputs: Dict[str, Any]) -> int:
        """Validate input lengths and return the number of tasks.

        Only values that are ``list`` instances are treated as per-task
        inputs; every other value is passed unchanged to each call.

        Args:
            inputs: Keyword arguments destined for the user function.

        Returns:
            The common length of the list-valued arguments, or ``1`` when
            there are none.

        Raises:
            ValueError: If the list-valued arguments differ in length.
        """
        input_lengths = {len(v) for v in inputs.values() if isinstance(v, list)}
        if len(input_lengths) > 1:
            raise ValueError("All list-valued arguments must have the same length")
        # At most one distinct length remains; with no list arguments at all
        # the call is a single task.
        return input_lengths.pop() if input_lengths else 1

    def execute_batch(self, user_func: Callable, **kwargs: Any) -> List[Any]:
        """Call the endpoint with batched inputs, managing concurrency and retries.

        Args:
            user_func: Callable invoked once per task with keyword arguments.
            **kwargs: Arguments for ``user_func``. List values are split so
                that task ``i`` receives element ``i``; all other values are
                passed as-is to every task.

        Returns:
            A list with one entry per task, in input order. Each entry is the
            value returned by ``user_func`` or, if the task ultimately
            failed, the exception from its last attempt.

        Raises:
            ValueError: If list-valued arguments differ in length.
        """
        num_tasks = self._validate_inputs(kwargs)

        # Initialize task queue with (task_index, attempt) tuples.
        task_queue = deque((i, 1) for i in range(num_tasks))
        results: List[Any] = [None] * num_tasks
        running_futures = {}

        success_count = 0
        fail_count = 0
        retry_count = 0

        with tqdm(
            total=num_tasks, desc="Processing", bar_format=TQDM_BAR_FORMAT, unit="task"
        ) as pbar:
            while task_queue or running_futures:
                # Top up in-flight calls to the limiter's current allowance.
                current_limit = self.rate_limiter.concurrency
                while task_queue and len(running_futures) < current_limit:
                    idx, attempt = task_queue.popleft()
                    call_kwargs = {
                        k: v[idx] if isinstance(v, list) else v
                        for k, v in kwargs.items()
                    }
                    # Run the call inside a copy of the submitting thread's
                    # context so context variables set here are visible in
                    # the worker thread.
                    ctx = copy_context()
                    future = self._executor.submit(ctx.run, user_func, **call_kwargs)
                    running_futures[future] = (idx, attempt)

                # Wake up on the first completion, or after one second at the
                # latest so the loop can re-read the concurrency limit.
                done, _ = wait(
                    running_futures.keys(), return_when=FIRST_COMPLETED, timeout=1
                )

                error_seen = False
                for future in done:
                    idx, attempt = running_futures.pop(future)
                    try:
                        result = future.result()
                        # Errors returned as values are routed through the
                        # same handling as raised ones.
                        if isinstance(result, ESMProteinError):
                            raise result
                        results[idx] = result
                        success_count += 1
                        pbar.update(1)
                    except Exception as e:
                        # Only scale concurrency if hit rate limit errors.
                        # Checked before the retry decision so a 429 on a
                        # task's final attempt still backs the limiter off.
                        if isinstance(e, ESMProteinError) and e.error_code == 429:
                            error_seen = True
                        if retry_if_specific_error(e) and attempt < self.max_attempts:
                            task_queue.append((idx, attempt + 1))
                            retry_count += 1
                        else:
                            # Out of attempts or not retryable: the exception
                            # itself becomes this task's result.
                            results[idx] = e  # type: ignore
                            fail_count += 1
                        pbar.update(0)

                # NOTE(review): this also runs when wait() timed out with no
                # completions, so concurrency steps up once per idle second.
                # Confirm that ramp-up without a success signal is intended.
                self.rate_limiter.adjust_concurrency(error_seen)
                pbar.set_postfix_str(
                    f"Success={success_count} Fail={fail_count} Retry={retry_count}"
                )

        return results