Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import abc | |
| from typing import Any, Generator, Generic, TypeVar | |
| import pydantic | |
# T: the item type yielded by the generator that StatefulIterator.create_iter returns.
T = TypeVar("T")
# C: the state type returned by StatefulIterator.get_state (also the type
# parameter of IteratorState).
C = TypeVar("C")
class StatefulIterator(Generic[T, C], abc.ABC):
    """Abstract base for iterators that can report their state.

    ``T`` is the item type yielded by the generator from :meth:`create_iter`;
    ``C`` is the type of the state object returned by :meth:`get_state`.

    Both methods are abstract: a subclass must override them before it can be
    instantiated, so an incomplete implementation fails at construction time
    instead of silently returning ``None``.
    """

    @abc.abstractmethod
    def get_state(self) -> C:
        """Return this iterator's state object (type ``C``)."""

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        """Return a generator that yields items of type ``T``."""
class IteratorState(Generic[C]):
    """Base class for state objects from which a StatefulIterator is built."""

    def build(self) -> StatefulIterator[T, C]:
        """Build the iterator for this state.

        The base implementation is a placeholder that returns ``None``;
        subclasses are expected to override it.
        """
        return None
# Iterator state that is also a pydantic model: it inherits from both
# pydantic.BaseModel and IteratorState.
class PydanticIteratorState(pydantic.BaseModel, IteratorState):
    # extra="forbid": per pydantic's ConfigDict documentation, input fields
    # that are not declared on the model are rejected with a validation error.
    model_config = pydantic.ConfigDict(extra="forbid")
def get_state_and_refresh(iterator: StatefulIterator):
    """Capture ``iterator``'s state and rebuild a fresh iterator from it.

    Calling ``get_state()`` on a multiprocess iterator shuts multiprocessing
    down so the state is persisted correctly; the data loader and its Python
    iterator therefore have to be re-created from that state afterwards.

    Args:
        iterator: The stateful iterator to snapshot.

    Returns:
        A ``(state, data_loader, py_iterator)`` tuple: the captured state, the
        object returned by ``state.build()``, and the result of calling
        ``create_iter()`` on that object.
    """
    snapshot = iterator.get_state()
    rebuilt_loader = snapshot.build()
    return snapshot, rebuilt_loader, rebuilt_loader.create_iter()