| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the CC-by-NC license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from abc import ABC | |
| from torch import nn, Tensor | |
class ModelWrapper(ABC, nn.Module):
    """
    Thin ``nn.Module`` shell around another model.

    By default it simply delegates to the wrapped model; subclasses may
    override :meth:`forward` to insert custom logic around that call.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        # Assigning through nn.Module.__setattr__ registers `model` as a
        # submodule, so its parameters are visible to the wrapper.
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Route the inputs through the wrapped model.

        The default implementation assumes the wrapped model accepts
        :math:`x` and :math:`t` as keyword arguments, together with any
        extra keyword arguments supplied by the caller.

        Typical reasons to override this in a subclass:
            - verify that ``t`` has the dimensions the model expects.
            - add bespoke pre/post-processing around the model call.
            - invoke the wrapped model.

        | given x, t
        | yields the model output for input x at time t, using the extra information in `extras`.

        Args:
            x (Tensor): model input data (batch_size, ...).
            t (Tensor): time (batch_size).
            **extras: further keyword arguments passed straight to the
                wrapped model, e.g. a text condition.

        Returns:
            Tensor: whatever the wrapped model returns.
        """
        # `x` and `t` always come first; `extras` cannot shadow them because
        # those names are already bound to this method's own parameters.
        call_kwargs = {"x": x, "t": t, **extras}
        inner = self.model
        return inner(**call_kwargs)
Xet Storage Details
- Size:
- 1.45 kB
- Xet hash:
- be0b08306ca7167a3a979029eab1aa8a38ffaef54246af98f7bce17a6cfda1dc
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.