emanuelaboros commited on
Commit
290a993
·
1 Parent(s): ec2aa9d

testin the trick

Browse files
Files changed (1) hide show
  1. modeling_stacked.py +8 -1
modeling_stacked.py CHANGED
@@ -65,7 +65,14 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
65
  print(
66
  f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
67
  )
68
- print(self.model_floret(input_ids))
 
 
 
 
 
 
 
69
  # if input_ids is not None:
70
  # tokenizer = kwargs.get("tokenizer")
71
  # texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
 
65
  print(
66
  f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
67
  )
68
+ if isinstance(input_ids, str):
69
+ # If the input is a single string, make it a list for floret
70
+ texts = [input_ids]
71
+ elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
72
+ texts = input_ids
73
+ else:
74
+ raise ValueError(f"Unexpected input type: {type(input_ids)}")
75
+ # print(self.model_floret(input_ids))
76
  # if input_ids is not None:
77
  # tokenizer = kwargs.get("tokenizer")
78
  # texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)