---
language:
- ko
library_name: transformers
pipeline_tag: token-classification
---
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoConfig

# Korean spacing restoration: the model tags each character with a
# BIOES-style label; a space is inserted after every "E" (end-of-word) tag.
tokenizer = AutoTokenizer.from_pretrained("fiveflow/roberta-base-spacing")
roberta = AutoModelForTokenClassification.from_pretrained("fiveflow/roberta-base-spacing")
roberta.eval()  # inference only — disable dropout

org_text = "ํ์์ค๋ฆฝ๊ณผESG๊ฒฝ์์๋ํ์ฌํ์ ์๊ตฌํ๋".replace(" ", "")  # remove any existing spaces
label = ["UNK", "PAD", "O", "B", "I", "E", "S"]

# Tokenize character by character so each input character maps to exactly
# one token position (encode() wraps with special tokens; [1] keeps only
# the character's own token id).
token_list = [tokenizer.cls_token_id]
for char in org_text:
    token_list.append(tokenizer.encode(char)[1])
token_list.append(tokenizer.eos_token_id)

tkd = torch.tensor(token_list).unsqueeze(0)  # add batch dimension
with torch.no_grad():  # no gradients needed for inference
    output = roberta(tkd).logits
_, pred_idx = torch.max(output, dim=2)  # per-position predicted label index

# Rebuild the sentence, splitting on the "E" tag; [1:-1] drops the
# predictions for the CLS and EOS positions.
pred_sent = ""
for char_idx, spc_idx in enumerate(pred_idx.squeeze()[1:-1]):
    if label[spc_idx] == "E":
        pred_sent += org_text[char_idx] + " "
    else:
        pred_sent += org_text[char_idx]
print(pred_sent.strip())
# 'ํ์์ค๋ฆฝ๊ณผ ESG ๊ฒฝ์์ ๋ํ ์ฌํ์ ์๊ตฌ ํ๋'