from collections.abc import AsyncGenerator

from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer


class BaseSampler:
    """Base sampler that wraps a backend-specific inference engine.

    Args:
        args: Sampling arguments, including which sample backend to use.
        model_args: Model arguments.
        model: Loaded Hugging Face model.
        renderer: Renderer used to format messages for the model.
    """

    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        # Select the inference engine according to the configured sample backend.
        if args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(args, model_args, model, renderer)
        else:
            raise ValueError(f"Unknown sample backend: {args.sample_backend}")

    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously for a single conversation.

        Args:
            messages: List of conversation messages.
            tools: Tool definitions as a string, or None if no tools are provided.

        Yields:
            Generated tokens as the engine produces them.
        """
        async for token in self.engine.generate(messages, tools):
            yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Run batch inference over a dataset.

        Args:
            dataset: Torch dataset of inference inputs.

        Returns:
            List of generated samples.
        """
        return await self.engine.batch_infer(dataset)
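
# Usage sketch (illustrative only): this assumes SampleArguments, ModelArguments,
# a loaded HFModel, and a Renderer have already been constructed by the surrounding
# setup code; the exact construction of those objects is not shown here.
#
#     sampler = BaseSampler(args, model_args, model, renderer)
#
#     # Stream tokens for a single conversation (inside an async context).
#     async for token in sampler.generate(messages):
#         print(token, end="", flush=True)
#
#     # Or run batched inference over a prepared dataset.
#     samples = await sampler.batch_infer(dataset)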