acmyu commited on
Commit
32da7a4
·
1 Parent(s): 7e8803c
Files changed (1) hide show
  1. libs/film/predict.py +3 -22
libs/film/predict.py CHANGED
@@ -7,12 +7,12 @@ import mediapy
7
  from PIL import Image
8
  import cog
9
 
10
- from eval import interpolator, util
11
 
12
  _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
13
 
14
 
15
- class Predictor(cog.Predictor):
16
  def setup(self):
17
  import tensorflow as tf
18
  print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
@@ -21,26 +21,7 @@ class Predictor(cog.Predictor):
21
  # Batched time.
22
  self.batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
23
 
24
- @cog.input(
25
- "frame1",
26
- type=Path,
27
- help="The first input frame",
28
- )
29
- @cog.input(
30
- "frame2",
31
- type=Path,
32
- help="The second input frame",
33
- )
34
- @cog.input(
35
- "times_to_interpolate",
36
- type=int,
37
- default=1,
38
- min=1,
39
- max=8,
40
- help="Controls the number of times the frame interpolator is invoked If set to 1, the output will be the "
41
- "sub-frame at t=0.5; when set to > 1, the output will be the interpolation video with "
42
- "(2^times_to_interpolate + 1) frames, fps of 30.",
43
- )
44
  def predict(self, frame1, frame2, times_to_interpolate):
45
  INPUT_EXT = ['.png', '.jpg', '.jpeg']
46
  assert os.path.splitext(str(frame1))[-1] in INPUT_EXT and os.path.splitext(str(frame2))[-1] in INPUT_EXT, \
 
7
  from PIL import Image
8
  import cog
9
 
10
+ from .eval import interpolator, util
11
 
12
  _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
13
 
14
 
15
+ class Predictor(cog.BasePredictor):
16
  def setup(self):
17
  import tensorflow as tf
18
  print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
 
21
  # Batched time.
22
  self.batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
23
 
24
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def predict(self, frame1, frame2, times_to_interpolate):
26
  INPUT_EXT = ['.png', '.jpg', '.jpeg']
27
  assert os.path.splitext(str(frame1))[-1] in INPUT_EXT and os.path.splitext(str(frame2))[-1] in INPUT_EXT, \