# Relation-extraction preprocessing: wrap subject/object entities in
# [Sub]/[Obj] marker tokens and encode sentences for a torch model.
import torch

from .tokenizer import tokenizer
def insert_markers(sentence, ent1, ent2):
    """Wrap the first occurrence of each entity in marker tokens.

    ``ent1`` is wrapped as ``[Sub] ent1 [/Sub]`` and ``ent2`` as
    ``[Obj] ent2 [/Obj]``. Returns the marked sentence, or ``None`` when
    either entity is absent from ``sentence``.

    Bug fix vs. the original: both entities are located by index in the
    ORIGINAL sentence instead of chaining ``str.replace`` on the already
    marked string. The chained form corrupted the markup whenever the
    first match for ``ent2`` fell inside ``ent1``'s span or inside the
    freshly inserted ``[Sub]``/``[/Sub]`` marker text itself.
    """
    start1 = sentence.find(ent1)
    start2 = sentence.find(ent2)
    if start1 == -1 or start2 == -1:
        return None
    end1 = start1 + len(ent1)

    # If ent2's first occurrence overlaps ent1's span, prefer a later,
    # non-overlapping occurrence so the two marker spans never nest.
    if start2 < end1 and start1 < start2 + len(ent2):
        alt = sentence.find(ent2, end1)
        if alt != -1:
            start2 = alt
    end2 = start2 + len(ent2)

    # Insert the rightmost span first so the earlier span's indices
    # (computed against the original sentence) remain valid.
    spans = sorted(
        [(start1, end1, "[Sub]", "[/Sub]"), (start2, end2, "[Obj]", "[/Obj]")],
        key=lambda s: s[0],
        reverse=True,
    )
    marked = sentence
    for start, end, open_tok, close_tok in spans:
        marked = (
            f"{marked[:start]}{open_tok} {marked[start:end]} "
            f"{close_tok}{marked[end:]}"
        )
    return marked
def encode(sentence):
    """Encode *sentence* and locate the entity-marker token positions.

    Tokenizes to a fixed length of 128 (padded and truncated) and returns
    ``(input_ids, attention_mask, sub_pos, obj_pos)``, where ``sub_pos`` /
    ``obj_pos`` are the sequence indices of the ``[Sub]`` / ``[Obj]``
    marker tokens.

    NOTE(review): assumes "[Sub]" and "[Obj]" were added to the tokenizer
    vocabulary (otherwise ``convert_tokens_to_ids`` maps them to the
    unknown-token id) and that truncation at 128 does not cut a marker
    off — either case yields empty/meaningless position tensors. Confirm
    against the tokenizer setup.
    """
    batch = tokenizer(
        sentence,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids = batch["input_ids"]
    mask = batch["attention_mask"]

    sub_token_id = tokenizer.convert_tokens_to_ids("[Sub]")
    obj_token_id = tokenizer.convert_tokens_to_ids("[Obj]")

    # nonzero(as_tuple=True) yields (batch_indices, seq_indices); the
    # sequence dimension is element [1].
    sub_positions = (ids == sub_token_id).nonzero(as_tuple=True)[1]
    obj_positions = (ids == obj_token_id).nonzero(as_tuple=True)[1]
    return ids, mask, sub_positions, obj_positions