Scrapling / scrapling /spiders /session.py
Karim shoair
style: Fix all mypy errors and add type hints to untyped function bodies
31c2447
from asyncio import Lock
from scrapling.spiders.request import Request
from scrapling.engines.static import _ASyncSessionLogic
from scrapling.engines.toolbelt.convertor import Response
from scrapling.core._types import Set, cast, SUPPORTED_HTTP_METHODS
from scrapling.fetchers import AsyncDynamicSession, AsyncStealthySession, FetcherSession
# Union of the session types that a SessionManager can register and drive.
Session = FetcherSession | AsyncDynamicSession | AsyncStealthySession
class SessionManager:
    """Registry of pre-configured session instances, addressed by string ID.

    Sessions are registered with :meth:`add`, entered by :meth:`start` (or on first use
    when registered as lazy), used through :meth:`fetch` and exited by :meth:`close`.
    The manager is also an async context manager that wraps ``start``/``close``.
    """

    def __init__(self) -> None:
        self._sessions: dict[str, Session] = {}
        # ID used by `fetch` when a request carries no session ID of its own.
        self._default_session_id: str | None = None
        self._started: bool = False
        # IDs of sessions that `start` skips; `fetch` enters them on first use.
        self._lazy_sessions: Set[str] = set()
        # Held while entering a lazy session so concurrent fetches enter it only once.
        self._lazy_lock = Lock()

    def add(self, session_id: str, session: Session, *, default: bool = False, lazy: bool = False) -> "SessionManager":
        """Register a session instance.

        :param session_id: Name to reference this session in requests
        :param session: Your pre-configured session instance
        :param default: If True, this becomes the default session
        :param lazy: If True, the session will be started only when a request uses its ID.
        :return: The manager itself, so calls can be chained.
        :raises ValueError: If ``session_id`` is already registered.
        """
        if session_id in self._sessions:
            raise ValueError(f"Session '{session_id}' already registered")

        self._sessions[session_id] = session

        # The first session ever registered becomes the default unless one is chosen explicitly.
        if default or self._default_session_id is None:
            self._default_session_id = session_id

        if lazy:
            self._lazy_sessions.add(session_id)

        return self

    def remove(self, session_id: str) -> None:
        """Remove a session, discarding it.

        :param session_id: ID of session to remove
        :raises KeyError: If ``session_id`` is not registered.
        """
        self.pop(session_id)

    def pop(self, session_id: str) -> Session:
        """Remove a session and return it.

        The session is only unregistered; it is not closed here.

        :param session_id: ID of session to remove
        :return: The session instance that was registered under ``session_id``.
        :raises KeyError: If ``session_id`` is not registered.
        """
        if session_id not in self._sessions:
            raise KeyError(f"Session '{session_id}' not found")

        session = self._sessions.pop(session_id)
        self._lazy_sessions.discard(session_id)

        if self._default_session_id == session_id:
            # Promote the first remaining session (registration order), or None when empty.
            self._default_session_id = next(iter(self._sessions), None)

        return session

    @property
    def default_session_id(self) -> str:
        """ID of the default session.

        :raises RuntimeError: If no default is set (no sessions registered).
        """
        if self._default_session_id is None:
            raise RuntimeError("No sessions registered")
        return self._default_session_id

    @property
    def session_ids(self) -> list[str]:
        """IDs of all registered sessions, in registration order."""
        return list(self._sessions)

    def get(self, session_id: str) -> Session:
        """Return the session registered under ``session_id``.

        :param session_id: ID of the session to look up
        :raises KeyError: If ``session_id`` is not registered; the message lists the known IDs.
        """
        if session_id not in self._sessions:
            available = ", ".join(self._sessions)
            raise KeyError(f"Session '{session_id}' not found. Available: {available}")
        return self._sessions[session_id]

    async def start(self) -> None:
        """Start all non-lazy sessions that aren't already alive.

        Does nothing if the manager is already marked as started.
        """
        if self._started:
            return

        for sid, session in self._sessions.items():
            if sid not in self._lazy_sessions and not session._is_alive:
                await session.__aenter__()

        self._started = True

    async def close(self) -> None:
        """Close all registered sessions.

        Every session gets its ``__aexit__`` called even if an earlier one fails, so a
        single broken session cannot leave the others open. The first error encountered
        is re-raised once all sessions have been attempted.
        """
        first_error: Exception | None = None
        try:
            for session in self._sessions.values():
                try:
                    await session.__aexit__(None, None, None)
                except Exception as exc:
                    # Remember the failure but keep closing the remaining sessions.
                    if first_error is None:
                        first_error = exc
        finally:
            self._started = False

        if first_error is not None:
            raise first_error

    async def fetch(self, request: Request) -> Response:
        """Fetch ``request`` through the session it names (or the default session).

        A lazy session that is not alive yet is entered before the request is made.

        :param request: The request to perform; ``request.sid`` selects the session.
        :return: The response, with ``request`` attached and the request meta merged in.
        :raises RuntimeError: If the request has no session ID and no default is set.
        :raises KeyError: If the session ID is not registered.
        :raises TypeError: If a ``FetcherSession`` does not hold an async client.
        """
        sid = request.sid if request.sid else self.default_session_id
        session = self.get(sid)

        if sid in self._lazy_sessions and not session._is_alive:
            async with self._lazy_lock:
                # Re-check under the lock: another task may have entered it while we waited.
                if not session._is_alive:
                    await session.__aenter__()

        if isinstance(session, FetcherSession):
            client = session._client
            if not isinstance(client, _ASyncSessionLogic):
                # Sync session or other types - shouldn't happen in async context
                raise TypeError(f"Session type {type(client)} not supported for async fetch")

            # Pop "method" from a copy: mutating the request's own kwargs would make any
            # later use of the same request (e.g. a retry) silently fall back to GET.
            session_kwargs = dict(request._session_kwargs)
            method = cast(SUPPORTED_HTTP_METHODS, session_kwargs.pop("method", "GET"))
            response = await client._make_request(method=method, url=request.url, **session_kwargs)
        else:
            response = await session.fetch(url=request.url, **request._session_kwargs)

        response.request = request
        # Merge request meta into response meta (response meta takes priority)
        response.meta = {**request.meta, **response.meta}
        return response

    async def __aenter__(self) -> "SessionManager":
        """Start the sessions and return the manager."""
        await self.start()
        return self

    async def __aexit__(self, *exc: object) -> None:
        """Close all sessions; exception details are ignored and not suppressed."""
        await self.close()

    def __contains__(self, session_id: str) -> bool:
        """Check if a session ID is registered."""
        return session_id in self._sessions

    def __len__(self) -> int:
        """Number of registered sessions."""
        return len(self._sessions)