File size: 868 Bytes
c4c644d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from .tokenizer import tokenizer

def insert_markers(sentence, ent1, ent2):
    """Wrap the first occurrence of each entity in entity-marker tokens.

    ent1 is wrapped as "[Sub] ent1 [/Sub]" and ent2 as "[Obj] ent2 [/Obj]".

    Args:
        sentence: the raw sentence text.
        ent1: subject entity surface form.
        ent2: object entity surface form.

    Returns:
        The marked sentence, or None when either entity is absent or the
        two first occurrences overlap (one entity is a substring of the
        other at the same position), in which case unambiguous marking
        is impossible.
    """
    sub_start = sentence.find(ent1)
    obj_start = sentence.find(ent2)
    if sub_start == -1 or obj_start == -1:
        return None

    spans = [
        (sub_start, sub_start + len(ent1), "[Sub] ", " [/Sub]"),
        (obj_start, obj_start + len(ent2), "[Obj] ", " [/Obj]"),
    ]

    # Overlapping occurrences previously produced silently-missing or
    # nested markers (str.replace on the already-marked string); treat
    # them as a failure instead of emitting corrupt training text.
    first, second = sorted(spans)
    if first[1] > second[0]:
        return None

    # Insert at the later span first so the earlier offsets stay valid.
    # Working from original-sentence offsets (not replace on the marked
    # string) also prevents ent2 from matching inside an inserted marker.
    marked = sentence
    for start, end, open_tok, close_tok in sorted(spans, reverse=True):
        marked = marked[:start] + open_tok + marked[start:end] + close_tok + marked[end:]

    return marked


def encode(sentence):
    """Tokenize a marker-annotated sentence and locate the entity markers.

    Args:
        sentence: text already containing "[Sub]" / "[Obj]" marker tokens
            (see insert_markers). NOTE(review): assumes these markers are
            registered in the shared tokenizer's vocabulary -- otherwise
            convert_tokens_to_ids yields the unk id and the positions
            below are meaningless; confirm tokenizer setup.

    Returns:
        Tuple of (input_ids, attention_mask, sub_pos, obj_pos). The first
        two are shape (1, 128) tensors (padded/truncated to max_length);
        the last two are the sequence indices where the "[Sub]" / "[Obj]"
        tokens occur (may be empty if a marker was truncated away).
    """
    encoding = tokenizer(
        sentence,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids = encoding["input_ids"]
    mask = encoding["attention_mask"]

    # Map each marker token to its vocab id, then to its position(s)
    # along the sequence axis (dim 1) of the batch-of-one tensor.
    marker_ids = {
        marker: tokenizer.convert_tokens_to_ids(marker)
        for marker in ("[Sub]", "[Obj]")
    }
    positions = {
        marker: (ids == token_id).nonzero(as_tuple=True)[1]
        for marker, token_id in marker_ids.items()
    }

    return ids, mask, positions["[Sub]"], positions["[Obj]"]