File size: 5,392 Bytes
3c71742
 
 
a8be56a
3c71742
 
fb9a1de
bf4aa9b
 
fb9a1de
 
3c71742
fb9a1de
 
 
bf4aa9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d397f88
3d3d7f4
 
 
 
d397f88
3d3d7f4
 
 
 
 
 
 
 
d397f88
3d3d7f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8be56a
 
 
 
 
 
 
 
 
d2840b5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
---
library_name: transformers
tags: []
license: other
---

## How to use

### 1. modelとtokenizerの呼び出し
```
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
model = AutoModel.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
```

### 2. modelのoutput
```
text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。'
model.eval()
a = tokenizer(text,truncation =True,max_length=512, entities=["船井電機・ホールディングス株式会社"], entity_spans=[(0,4)],return_tensors='pt')
outputs = model(**a)
mlm_logits = outputs.logits
tep_logits = outputs.topic_entity_logits
mep_logits = outputs.entity_logits
print(mlm_logits.shape) >> torch.Size([1, 49, 32770])
print(tep_logits.shape) >> torch.Size([1, 20972])
print(mep_logits.shape) >> torch.Size([1, 1, 20972]) 
```
- modelのencode結果は,logits, topic_entity_logits, entity_logitsを属性として持ちます.
- logitsは通常のBERTなどの言語モデルと同様の扱い方です.
- topic_entity_logits(文章における各enitityの関連度) と entity_logits(entityの埋め込み表現)に関しては,このモデル固有のものであり,以下に扱い方を解説します.

### 3. topic_entity_logits(文章における各enitityの関連度を取得)
```
tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
model = AutoModelForPreTraining.from_pretrained("uzabase/UBKE-LUKE", output_hidden_states=True, trust_remote_code=True)

text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。'

model.eval()
a = tokenizer(text,truncation =True,max_length=512, return_tensors='pt')
outputs = model(**a)
tep_logits = outputs.topic_entity_logits
print("tep_logits shape: ", tep_logits.shape) # >> torch.Size([1, 20972]) each dimentions correspond to entities

ent = { tokenizer.entity_vocab[i]:i for i in tokenizer.entity_vocab}

print("Entity Recognition Results:")
topk_logits, topk_entity_ids = tep_logits.topk(10, dim=1)
for logit, entity_id in zip(topk_logits[0].tolist(), topk_entity_ids[0].tolist()):
    print("\t", ent[entity_id], logit)
>>>
Entity Recognition Results:
	 船井電機・ホールディングス株式会社 1.8898193836212158
	 セイノーホールディングス 1.668973684310913
	 東洋電機 1.658090353012085
	 横河電機 1.6363312005996704
	 船井総研ホールディングス 1.618525743484497
	 西菱電機 1.587844967842102
	 フォスター電機 1.5436134338378906
	 東洋電機製造 1.493951678276062
	 ヒロセ電機 1.458113193511963
	 サクサ 1.4461733102798462
```
- modelのencode結果は,topic_entity_logits属性を持ちます.
- topic_entity_logitsは, torch.Size([batch_size, entity_size])のtoroch.tensorです.
- 各次元のlogit値は,入力文章における各entityの関連度を表現しています.

### 4. entity_logits(entityの埋め込み表現)
- entityの一覧は,tokenizerがentity_vocabに辞書形式で持ちます.
```
tokenizer.entity_vocab # => {"": 0, ... ,"AGC": 48, ....
tokenizer.entity_vocab["味の素"] # => 8469(味の素のentity_id) 
```
- entity_spans及びentitties引数をtokenizerに渡し,tokenをencodeすることで,entityの埋め込み表現を得ます.
```
model.eval()
tokens = tokenizer("味の素", entities=["味の素"], entity_spans=[(0, 3)], truncation=True, max_length=512, return_tensors="pt")
print(tokens["entity_ids"]) # => tensor([[8469]])
with torch.no_grad():
    outputs = model(**tokens)
outputs.entity_logits.shape # 味の素のentity_vector
```
- entityの埋め込み表現の内積(やコサイン類似度)を計算することで,entity同士の類似度を計算可能です.
```
def encode(entity_text):
    model.eval()
    tokens = tokenizer(entity_text, entities=[entity_text], entity_spans=[(0, len(entity_text))],
                       truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokens)
    return outputs.entity_logits[0][0]
azinomoto = encode("味の素")
nisshin = encode("日清食品ホールディングス")
kameda = encode("亀田製菓")
sony = encode("ソニーホールディングス")
print(azinomoto @ nisshin) # => tensor(24834.6836)
print(azinomoto @ kameda) # => tensor(17547.6895)
print(azinomoto @ sony) # => tensor(8699.2871)
```

## Licenses

The model parameters `model.safetensors` is licensed under CC BY-NC.
モデルの重みファイル `model.safetensors` はCC BY-NCライセンスで利用可能です。

Other files are subject to the same license as [LUKE](https://github.com/studio-ousia/luke) itself.
その他のファイルは[LUKE](https://github.com/studio-ousia/luke)自体と同じライセンスが適用されます。


## Reference

* 開発の背景などについては[ブログ](https://tech.uzabase.com/entry/2024/12/24/173942)を参照してください
* もしUBKE-LUKEの活用に興味をお持ちの方は ub-research@uzabase.com までご連絡ください