File size: 2,288 Bytes
af83196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Async utilities for SkyDiscover
"""

import asyncio
import logging
from typing import Any, Callable, List, Optional, Sequence, Tuple

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TaskPool:
    """
    A simple task pool for managing and limiting concurrent tasks.

    Every callable submitted through :meth:`run`, :meth:`create_task` or
    :meth:`gather` is executed while holding one permit of a shared
    ``asyncio.Semaphore`` created with ``max_concurrency`` permits.

    Attributes:
        max_concurrency: Number of permits the semaphore is created with.
        tasks: Tasks created via :meth:`create_task` that have not finished
            yet; each task is dropped from the list when it completes.
    """

    def __init__(self, max_concurrency: int = 10):
        """Store the concurrency limit; the semaphore itself is built lazily.

        Args:
            max_concurrency: Maximum number of callables allowed to run at
                the same time. Validated when the semaphore is first needed.
        """
        self.max_concurrency = max_concurrency
        # Built on first access to the ``semaphore`` property, not here.
        self._semaphore: Optional[asyncio.Semaphore] = None
        self.tasks: List[asyncio.Task] = []

    @property
    def semaphore(self) -> asyncio.Semaphore:
        """Lazy-initialize the semaphore when first needed.

        Raises:
            ValueError: If ``max_concurrency`` is smaller than 1 at the time
                the semaphore is created.
        """
        if self._semaphore is None:
            # A semaphore with zero permits can never be acquired, so every
            # submitted callable would wait forever; fail loudly instead.
            if self.max_concurrency < 1:
                raise ValueError(
                    f"max_concurrency must be at least 1, got {self.max_concurrency}"
                )
            self._semaphore = asyncio.Semaphore(self.max_concurrency)
        return self._semaphore

    def _forget_task(self, task: asyncio.Task) -> None:
        """Drop a finished *task* from ``self.tasks`` if it is still tracked."""
        try:
            self.tasks.remove(task)
        except ValueError:
            # ``tasks`` is a public list and may have been cleared or replaced
            # while the task was running; there is nothing left to untrack.
            pass

    async def run(self, coro: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run a single coroutine function under the concurrency semaphore.

        Args:
            coro: Callable that returns an awaitable when called.
            *args: Positional arguments forwarded to ``coro``.
            **kwargs: Keyword arguments forwarded to ``coro``.

        Returns:
            Whatever awaiting ``coro(*args, **kwargs)`` produces.
        """
        async with self.semaphore:
            return await coro(*args, **kwargs)

    def create_task(self, coro: Callable, *args: Any, **kwargs: Any) -> asyncio.Task:
        """Create, track, and return an ``asyncio.Task`` bounded by the pool.

        The task wraps :meth:`run`, is appended to ``self.tasks`` and is
        removed from that list again once it is done.

        Args:
            coro: Callable that returns an awaitable when called.
            *args: Positional arguments forwarded to ``coro``.
            **kwargs: Keyword arguments forwarded to ``coro``.

        Returns:
            The newly created task.
        """
        task = asyncio.create_task(self.run(coro, *args, **kwargs))
        self.tasks.append(task)
        task.add_done_callback(self._forget_task)
        return task

    async def gather(
        self,
        coros: Sequence[Callable],
        args_list: Sequence[Tuple[Any, ...]] = (),
        kwargs_list: Sequence[dict] = (),
        return_exceptions: bool = False,
    ) -> List[Any]:
        """Run *coros* concurrently (bounded by the semaphore), return results in order.

        Args:
            coros: Callables to invoke, one task per entry.
            args_list: Positional-argument tuples, one per callable. When
                empty, every callable is invoked without positional arguments.
            kwargs_list: Keyword-argument dicts, one per callable. When empty,
                every callable is invoked without keyword arguments.
            return_exceptions: Passed through to ``asyncio.gather``.

        Returns:
            The results of ``asyncio.gather`` over the created tasks, in the
            same order as ``coros``.

        Raises:
            ValueError: If a non-empty ``args_list`` or ``kwargs_list`` does
                not have the same length as ``coros``.
        """
        n = len(coros)
        # An empty argument sequence means "no arguments for any callable".
        _args = args_list if args_list else [() for _ in range(n)]
        _kwargs = kwargs_list if kwargs_list else [{} for _ in range(n)]

        # Validate before creating any task so a mismatch starts no work.
        if len(_args) != n:
            raise ValueError(f"args_list length ({len(_args)}) must match coros length ({n})")
        if len(_kwargs) != n:
            raise ValueError(f"kwargs_list length ({len(_kwargs)}) must match coros length ({n})")

        tasks = [
            self.create_task(coro, *args, **kwargs)
            for coro, args, kwargs in zip(coros, _args, _kwargs)
        ]
        return await asyncio.gather(*tasks, return_exceptions=return_exceptions)