Ubuntu commited on
Commit
4540774
·
1 Parent(s): 0c82c11
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +1 -1
  2. app.py +1 -1
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -22,7 +22,7 @@ class ImageClassificationCollator:
22
 
23
  def __call__(self, batch):
24
  encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
25
- encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
26
  return encodings
27
 
28
  class Classifier(pl.LightningModule):
 
22
 
23
  def __call__(self, batch):
24
  encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
25
+ encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.float)
26
  return encodings
27
 
28
  class Classifier(pl.LightningModule):
app.py CHANGED
@@ -22,7 +22,7 @@ class ImageClassificationCollator:
22
 
23
  def __call__(self, batch):
24
  encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
25
- encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
26
  return encodings
27
 
28
  class Classifier(pl.LightningModule):
 
22
 
23
  def __call__(self, batch):
24
  encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
25
+ encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.float)
26
  return encodings
27
 
28
  class Classifier(pl.LightningModule):