import torch

from .tokenizer import tokenizer


def _splice(text, start, end, tag):
    """Wrap text[start:end] in [tag] … [/tag] markers and return the result."""
    return text[:start] + f"[{tag}] {text[start:end]} [/{tag}]" + text[end:]


def insert_markers(sentence, ent1, ent2):
    """Mark the first occurrence of ent1 with [Sub] … [/Sub] and of ent2 with
    [Obj] … [/Obj].

    Returns the marked sentence, or None when either entity is absent.

    Bug fix: markers are now spliced in at the offsets the entities occupy in
    the ORIGINAL sentence (later span first, so the earlier offset stays
    valid).  The previous sequential str.replace searched the already-marked
    string, so ent2 could match inside the inserted "[Sub] … [/Sub]" text
    (e.g. ent2 == "Sub", or ent2 a substring of ent1), corrupting the markup.
    """
    i1 = sentence.find(ent1)
    i2 = sentence.find(ent2)
    if i1 == -1 or i2 == -1:
        return None

    sub_end = i1 + len(ent1)
    obj_end = i2 + len(ent2)

    if i2 >= sub_end or obj_end <= i1:
        # Disjoint spans: insert the later marker pair first so the earlier
        # span's offsets are unaffected by the inserted text.
        if i1 > i2:
            marked = _splice(sentence, i1, sub_end, "Sub")
            marked = _splice(marked, i2, obj_end, "Obj")
        else:
            marked = _splice(sentence, i2, obj_end, "Obj")
            marked = _splice(marked, i1, sub_end, "Sub")
        return marked

    # Overlapping entity spans are ambiguous; preserve the legacy
    # sequential-replace behavior for that case.
    marked = sentence.replace(ent1, f"[Sub] {ent1} [/Sub]", 1)
    marked = marked.replace(ent2, f"[Obj] {ent2} [/Obj]", 1)
    return marked


def encode(sentence):
    """Tokenize a marker-annotated sentence and locate the entity markers.

    Returns (input_ids, attention_mask, sub_pos, obj_pos) where sub_pos /
    obj_pos are the sequence indices of the [Sub] / [Obj] tokens.

    NOTE(review): assumes "[Sub]" and "[Obj]" were added to the tokenizer as
    special tokens — otherwise convert_tokens_to_ids returns the unk id and
    the positions are meaningless; verify against the tokenizer setup.
    NOTE(review): if a marker is truncated away by max_length=128, the
    returned position tensor is empty — callers should handle that case.
    """
    enc = tokenizer(
        sentence,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    sub_id = tokenizer.convert_tokens_to_ids("[Sub]")
    obj_id = tokenizer.convert_tokens_to_ids("[Obj]")

    # nonzero(...)[1] takes the sequence dimension (dim 0 is the batch of 1).
    sub_pos = (input_ids == sub_id).nonzero(as_tuple=True)[1]
    obj_pos = (input_ids == obj_id).nonzero(as_tuple=True)[1]

    return input_ids, attention_mask, sub_pos, obj_pos