Spaces:
Paused
Paused
| # Copyright 2022 Google LLC | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """A library for instantiating the model for training frame interpolation. | |
| All models are expected to use three inputs: input image batches 'x0' and 'x1' | |
| and 'time', the fractional time where the output should be generated. | |
| The models are expected to output the prediction as a dictionary that contains | |
| at least the predicted image batch as 'image' plus optional data for debug, | |
| analysis or custom losses. | |
| """ | |
| import gin.tf | |
| from ..models.film_net import interpolator as film_net_interpolator | |
| from ..models.film_net import options as film_net_options | |
| import tensorflow as tf | |
def create_model(name: str) -> tf.keras.Model:
    """Builds the frame interpolation model selected by `name`.

    Args:
      name: Identifier of the model to build. 'film_net' is the only value
        handled here.

    Returns:
      The Keras model returned by the private factory for the chosen name.

    Raises:
      ValueError: If `name` is not 'film_net'.
    """
    # Guard clause: reject unsupported names first so the supported path
    # reads straight through.
    if name != 'film_net':
        raise ValueError(f'Model {name} not implemented.')
    return _create_film_net_model()  # pylint: disable=no-value-for-parameter
def _create_film_net_model() -> tf.keras.Model:
    """Assembles the film_net interpolator from symbolic Keras inputs.

    Three `tf.keras.Input` placeholders are declared: image batches 'x0' and
    'x1' with unspecified height/width and 3 channels, and 'time' with a
    single value per example. All are float32 with an unspecified batch size.

    Returns:
      The model built by `film_net_interpolator.create_model` from those
      inputs and a default-constructed `film_net_options.Options`.
    """
    # Options are gin-configured in the Options class directly.
    film_options = film_net_options.Options()

    # Settings shared by every input placeholder.
    common_kwargs = dict(batch_size=None, dtype=tf.float32)
    # Height and width are left dynamic; only the channel count is fixed.
    image_shape = (None, None, 3)

    image_0 = tf.keras.Input(shape=image_shape, name='x0', **common_kwargs)
    image_1 = tf.keras.Input(shape=image_shape, name='x1', **common_kwargs)
    frac_time = tf.keras.Input(shape=(1,), name='time', **common_kwargs)

    return film_net_interpolator.create_model(
        image_0, image_1, frac_time, film_options)