| from asyncio import Lock |
|
|
| from scrapling.spiders.request import Request |
| from scrapling.engines.static import _ASyncSessionLogic |
| from scrapling.engines.toolbelt.convertor import Response |
| from scrapling.core._types import Set, cast, SUPPORTED_HTTP_METHODS |
| from scrapling.fetchers import AsyncDynamicSession, AsyncStealthySession, FetcherSession |
|
|
# Union of every session type a SessionManager can hold; used purely as a
# type annotation by the class below.
Session = (
    FetcherSession
    | AsyncDynamicSession
    | AsyncStealthySession
)
|
|
|
|
| class SessionManager: |
| """Manages pre-configured session instances.""" |
|
|
| def __init__(self) -> None: |
| self._sessions: dict[str, Session] = {} |
| self._default_session_id: str | None = None |
| self._started: bool = False |
| self._lazy_sessions: Set[str] = set() |
| self._lazy_lock = Lock() |
|
|
| def add(self, session_id: str, session: Session, *, default: bool = False, lazy: bool = False) -> "SessionManager": |
| """Register a session instance. |
| |
| :param session_id: Name to reference this session in requests |
| :param session: Your pre-configured session instance |
| :param default: If True, this becomes the default session |
| :param lazy: If True, the session will be started only when a request uses its ID. |
| """ |
| if session_id in self._sessions: |
| raise ValueError(f"Session '{session_id}' already registered") |
|
|
| self._sessions[session_id] = session |
|
|
| if default or self._default_session_id is None: |
| self._default_session_id = session_id |
|
|
| if lazy: |
| self._lazy_sessions.add(session_id) |
|
|
| return self |
|
|
| def remove(self, session_id: str) -> None: |
| """Removes a session. |
| |
| :param session_id: ID of session to remove |
| """ |
| _ = self.pop(session_id) |
|
|
| def pop(self, session_id: str) -> Session: |
| """Remove and returns a session. |
| |
| :param session_id: ID of session to remove |
| """ |
| if session_id not in self._sessions: |
| raise KeyError(f"Session '{session_id}' not found") |
|
|
| session = self._sessions.pop(session_id) |
| if session_id in self._lazy_sessions: |
| self._lazy_sessions.remove(session_id) |
|
|
| if session and self._default_session_id == session_id: |
| self._default_session_id = next(iter(self._sessions), None) |
|
|
| return session |
|
|
| @property |
| def default_session_id(self) -> str: |
| if self._default_session_id is None: |
| raise RuntimeError("No sessions registered") |
| return self._default_session_id |
|
|
| @property |
| def session_ids(self) -> list[str]: |
| return list(self._sessions.keys()) |
|
|
| def get(self, session_id: str) -> Session: |
| if session_id not in self._sessions: |
| available = ", ".join(self._sessions.keys()) |
| raise KeyError(f"Session '{session_id}' not found. Available: {available}") |
| return self._sessions[session_id] |
|
|
| async def start(self) -> None: |
| """Start all sessions that aren't already alive.""" |
| if self._started: |
| return |
|
|
| for sid, session in self._sessions.items(): |
| if sid not in self._lazy_sessions and not session._is_alive: |
| await session.__aenter__() |
|
|
| self._started = True |
|
|
| async def close(self) -> None: |
| """Close all registered sessions.""" |
| for session in self._sessions.values(): |
| _ = await session.__aexit__(None, None, None) |
|
|
| self._started = False |
|
|
| async def fetch(self, request: Request) -> Response: |
| sid = request.sid if request.sid else self.default_session_id |
| session = self.get(sid) |
|
|
| if session: |
| if sid in self._lazy_sessions and not session._is_alive: |
| async with self._lazy_lock: |
| if not session._is_alive: |
| await session.__aenter__() |
|
|
| if isinstance(session, FetcherSession): |
| client = session._client |
|
|
| if isinstance(client, _ASyncSessionLogic): |
| response = await client._make_request( |
| method=cast(SUPPORTED_HTTP_METHODS, request._session_kwargs.pop("method", "GET")), |
| url=request.url, |
| **request._session_kwargs, |
| ) |
| else: |
| |
| raise TypeError(f"Session type {type(client)} not supported for async fetch") |
| else: |
| response = await session.fetch(url=request.url, **request._session_kwargs) |
|
|
| response.request = request |
| |
| response.meta = {**request.meta, **response.meta} |
| return response |
| raise RuntimeError("No session found with the request session id") |
|
|
| async def __aenter__(self) -> "SessionManager": |
| await self.start() |
| return self |
|
|
| async def __aexit__(self, *exc) -> None: |
| await self.close() |
|
|
| def __contains__(self, session_id: str) -> bool: |
| """Check if a session ID is registered.""" |
| return session_id in self._sessions |
|
|
| def __len__(self) -> int: |
| """Number of registered sessions.""" |
| return len(self._sessions) |
|
|