nroggendorff commited on
Commit
ae0c2b0
·
verified ·
1 Parent(s): 55c4eff

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -0
train.py CHANGED
@@ -36,6 +36,9 @@ def caption_batch(batch, processor, model):
36
 
37
  pil_images = []
38
  for image in images:
 
 
 
39
  if not isinstance(image, Image.Image):
40
  image = Image.fromarray(image)
41
  if image.mode != "RGB":
 
36
 
37
  pil_images = []
38
  for image in images:
39
+ if isinstance(image, torch.Tensor):
40
+ image = image.cpu().numpy()
41
+
42
  if not isinstance(image, Image.Image):
43
  image = Image.fromarray(image)
44
  if image.mode != "RGB":