Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -66,7 +66,7 @@ class ProteinDataset(Dataset):
|
|
| 66 |
|
| 67 |
# Create labels (tokens for ground truth AAs)
|
| 68 |
label_seq = protein_seq + peptide_seq
|
| 69 |
-
labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)["input_ids"].squeeze()
|
| 70 |
|
| 71 |
# Set non-masked positions in the labels tensor to -100
|
| 72 |
labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
|
|
|
|
| 66 |
|
| 67 |
# Create labels (tokens for ground truth AAs)
|
| 68 |
label_seq = protein_seq + peptide_seq
|
| 69 |
+
labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)["input_ids"].to(device).squeeze()
|
| 70 |
|
| 71 |
# Set non-masked positions in the labels tensor to -100
|
| 72 |
labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
|