Commit
·
290a993
1
Parent(s):
ec2aa9d
testing the trick
Browse files — modeling_stacked.py: +8 −1
modeling_stacked.py
CHANGED
|
@@ -65,7 +65,14 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
|
|
| 65 |
print(
|
| 66 |
f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
|
| 67 |
)
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# if input_ids is not None:
|
| 70 |
# tokenizer = kwargs.get("tokenizer")
|
| 71 |
# texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
|
|
|
| 65 |
print(
|
| 66 |
f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
|
| 67 |
)
|
| 68 |
+
if isinstance(input_ids, str):
|
| 69 |
+
# If the input is a single string, make it a list for floret
|
| 70 |
+
texts = [input_ids]
|
| 71 |
+
elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
|
| 72 |
+
texts = input_ids
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError(f"Unexpected input type: {type(input_ids)}")
|
| 75 |
+
# print(self.model_floret(input_ids))
|
| 76 |
# if input_ids is not None:
|
| 77 |
# tokenizer = kwargs.get("tokenizer")
|
| 78 |
# texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|