fiveflow's picture
Update README.md
218ff4f
|
raw
history blame
1.18 kB
---
language:
  - ko
library_name: transformers
pipeline_tag: token-classification
---
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoConfig

# Demo: restore inter-word spacing in Korean text with a RoBERTa token classifier.
# Each character gets a BIESO-style tag; a space is inserted after every "E"
# (end-of-word) character.
tokenizer = AutoTokenizer.from_pretrained("fiveflow/roberta-base-spacing")
roberta = AutoModelForTokenClassification.from_pretrained("fiveflow/roberta-base-spacing")
roberta.eval()  # inference only: disable dropout etc.

org_text = "ํƒ„์†Œ์ค‘๋ฆฝ๊ณผESG๊ฒฝ์˜์—๋Œ€ํ•œ์‚ฌํšŒ์ ์š”๊ตฌํ™•๋Œ€".replace(" ", "") # remove spaces
label = ["UNK", "PAD", "O", "B", "I", "E", "S"]

# Tokenize character by character: [CLS] c1 c2 ... cn [EOS]
token_list = [tokenizer.cls_token_id]
for char in org_text:
    # encode(char) yields [CLS, char_id, SEP]; index 1 is the character's id
    token_list.append(tokenizer.encode(char)[1])
token_list.append(tokenizer.eos_token_id)
tkd = torch.tensor(token_list).unsqueeze(0)  # add batch dimension

with torch.no_grad():  # no gradients needed at inference time
    output = roberta(tkd).logits

pred_idx = torch.argmax(output, dim=2)
# Per-character tags, dropping the CLS/EOS positions.
tags = [label[idx] for idx in pred_idx.squeeze()][1:-1]

pred_sent = ""
for char_idx, tag in enumerate(tags):
    # Append a space after each word-final ("E") character.
    pred_sent += org_text[char_idx]
    if tag == "E":
        pred_sent += " "

print(pred_sent.strip())
# 'ํƒ„์†Œ์ค‘๋ฆฝ๊ณผ ESG ๊ฒฝ์˜์— ๋Œ€ํ•œ ์‚ฌํšŒ์  ์š”๊ตฌ ํ™•๋Œ€'