wangjin2000 commited on
Commit
cd0b38f
·
verified ·
1 Parent(s): 752caab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -66,7 +66,7 @@ class ProteinDataset(Dataset):
66
 
67
  # Create labels (tokens for ground truth AAs)
68
  label_seq = protein_seq + peptide_seq
69
- labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)["input_ids"].squeeze()
70
 
71
  # Set non-masked positions in the labels tensor to -100
72
  labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
 
66
 
67
  # Create labels (tokens for ground truth AAs)
68
  label_seq = protein_seq + peptide_seq
69
+ labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)["input_ids"].to(device).squeeze()
70
 
71
  # Set non-masked positions in the labels tensor to -100
72
  labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)