gbyuvd committed on
Commit
097a367
·
verified ·
1 Parent(s): d7be213

Tensor handling fix

Browse files
Files changed (2) hide show
  1. CHANGELOG +2 -1
  2. FastChemTokenizer.py +6 -6
CHANGELOG CHANGED
@@ -1,4 +1,5 @@
1
  [20 Sept 2025]
2
  - Add basic SELFIES tokenizer function
3
  - Upload both core and tailed SELFIES vocab
4
- - Update README to include SELFIES evals
 
 
1
  [20 Sept 2025]
2
  - Add basic SELFIES tokenizer function
3
  - Upload both core and tailed SELFIES vocab
4
+ - Update README to include SELFIES evals
5
+ - Handle both tensor and non-tensor items properly
FastChemTokenizer.py CHANGED
@@ -271,9 +271,9 @@ class FastChemTokenizer:
271
 
272
  if kwargs.get("return_tensors") == "pt":
273
  def to_tensor_list(lst):
274
- # Use torch.tensor for safety avoids "copy construct from tensor" warning
275
- return [torch.tensor(item, dtype=torch.long) for item in lst]
276
-
277
  batched = {
278
  k: torch.nn.utils.rnn.pad_sequence(
279
  to_tensor_list(v),
@@ -570,9 +570,9 @@ class FastChemTokenizerSelfies:
570
 
571
  if kwargs.get("return_tensors") == "pt":
572
  def to_tensor_list(lst):
573
- # Use torch.tensor for safety avoids "copy construct from tensor" warning
574
- return [torch.tensor(item, dtype=torch.long) for item in lst]
575
-
576
  batched = {
577
  k: torch.nn.utils.rnn.pad_sequence(
578
  to_tensor_list(v),
 
271
 
272
  if kwargs.get("return_tensors") == "pt":
273
  def to_tensor_list(lst):
274
+ # Fixed: Handle both tensor and non-tensor items properly
275
+ return [item.clone().detach() if isinstance(item, torch.Tensor)
276
+ else torch.tensor(item, dtype=torch.long) for item in lst]
277
  batched = {
278
  k: torch.nn.utils.rnn.pad_sequence(
279
  to_tensor_list(v),
 
570
 
571
  if kwargs.get("return_tensors") == "pt":
572
  def to_tensor_list(lst):
573
+ # Fixed: Handle both tensor and non-tensor items properly
574
+ return [item.clone().detach() if isinstance(item, torch.Tensor)
575
+ else torch.tensor(item, dtype=torch.long) for item in lst]
576
  batched = {
577
  k: torch.nn.utils.rnn.pad_sequence(
578
  to_tensor_list(v),