davidaf3 commited on
Commit
382cc6e
·
1 Parent(s): eb443a0

Fixed error in pipeline

Browse files
Files changed (1) hide show
  1. pipeline.py +4 -3
pipeline.py CHANGED
@@ -8,11 +8,12 @@ import tensorflow as tf
8
 
9
  class PreTrainedPipeline():
10
  def __init__(self, path=""):
11
- crop_size = (224, 224)
12
  self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
13
- self.model = FCNutr(self.nutr_names, crop_size, 4096, 3, False)
 
14
  self.model.compile()
15
- self.model(tf.zeros((1, crop_size[0], crop_size[1], 3)))
16
  self.model.load_weights(os.path.join(path, "fcnutr.h5"))
17
 
18
  def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
 
8
 
9
  class PreTrainedPipeline():
10
  def __init__(self, path=""):
11
+ self.crop_size = (224, 224)
12
  self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
13
+ self.img_size = 256
14
+ self.model = FCNutr(self.nutr_names, self.crop_size, 4096, 3, False)
15
  self.model.compile()
16
+ self.model(tf.zeros((1, self.crop_size[0], self.crop_size[1], 3)))
17
  self.model.load_weights(os.path.join(path, "fcnutr.h5"))
18
 
19
  def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]: