# NOTE: header recovered from a web file-viewer scrape (not part of the program).
# Space status: Runtime error
# File size: 868 Bytes
# Revision: c4c644d
import torch
from .tokenizer import tokenizer
def _find_occurrences(text, needle):
    """Yield the start index of every (possibly overlapping) match of needle in text."""
    index = text.find(needle)
    while index != -1:
        yield index
        index = text.find(needle, index + 1)


def insert_markers(sentence, ent1, ent2):
    """Wrap one occurrence of each entity in subject/object marker tokens.

    ``ent1`` is wrapped as ``[Sub] ent1 [/Sub]`` and ``ent2`` as
    ``[Obj] ent2 [/Obj]``. Both spans are located in the *original*
    sentence, so the search for ``ent2`` can never match text inside the
    inserted markers or inside the span already chosen for ``ent1``.

    The earliest occurrence of ``ent1`` is preferred, then the earliest
    occurrence of ``ent2`` that does not overlap it; later occurrences of
    ``ent1`` are tried only if no such pairing exists.

    Args:
        sentence: The text to annotate.
        ent1: Subject entity string.
        ent2: Object entity string.

    Returns:
        The sentence with both entities marked, or ``None`` when either
        entity is empty, is missing from the sentence, or the two entities
        cannot be placed on non-overlapping spans.
    """
    # An empty entity matches everywhere and cannot be marked meaningfully.
    if not ent1 or not ent2:
        return None

    for sub_start in _find_occurrences(sentence, ent1):
        sub_end = sub_start + len(ent1)
        for obj_start in _find_occurrences(sentence, ent2):
            obj_end = obj_start + len(ent2)
            # Skip pairings where the two half-open spans overlap.
            if not (obj_end <= sub_start or sub_end <= obj_start):
                continue

            # Splice left-to-right so indices from the original sentence
            # stay valid while the marked string is assembled.
            spans = sorted(
                [
                    (sub_start, sub_end, "[Sub]", "[/Sub]"),
                    (obj_start, obj_end, "[Obj]", "[/Obj]"),
                ]
            )
            pieces = []
            cursor = 0
            for start, end, open_tag, close_tag in spans:
                pieces.append(sentence[cursor:start])
                pieces.append(f"{open_tag} {sentence[start:end]} {close_tag}")
                cursor = end
            pieces.append(sentence[cursor:])
            return "".join(pieces)

    return None
def encode(sentence, max_length=128):
    """Tokenize a marked sentence and locate its entity start markers.

    The sentence is padded/truncated to ``max_length`` tokens and returned
    as PyTorch tensors. Positions of the ``[Sub]`` and ``[Obj]`` tokens are
    looked up by comparing ``input_ids`` against each marker's token id.

    Args:
        sentence: Text to encode; presumably already annotated with
            ``[Sub]``/``[Obj]`` markers (verify against the caller).
        max_length: Sequence length to pad/truncate to. Defaults to 128,
            the previously hard-coded value.

    Returns:
        A tuple ``(input_ids, attention_mask, sub_pos, obj_pos)``.
        ``sub_pos`` and ``obj_pos`` hold the indices along the second
        dimension of ``input_ids`` where the respective marker id occurs.
        They are empty tensors when a marker is absent, e.g. because
        truncation cut it off.
    """
    enc = tokenizer(
        sentence,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    # NOTE(review): assumes "[Sub]" and "[Obj]" were added to the tokenizer
    # vocabulary; otherwise these ids may resolve to the unknown-token id
    # and match unrelated positions — confirm in the tokenizer module.
    sub_id = tokenizer.convert_tokens_to_ids("[Sub]")
    obj_id = tokenizer.convert_tokens_to_ids("[Obj]")

    # nonzero(as_tuple=True) yields one index tensor per dimension; [1]
    # selects the indices along the second (token) dimension.
    sub_pos = (input_ids == sub_id).nonzero(as_tuple=True)[1]
    obj_pos = (input_ids == obj_id).nonzero(as_tuple=True)[1]

    return input_ids, attention_mask, sub_pos, obj_pos