|
|
--- |
|
|
library_name: transformers |
|
|
tags: [] |
|
|
license: other |
|
|
--- |
|
|
|
|
|
## How to use |
|
|
|
|
|
### 1. modelとtokenizerの呼び出し |
|
|
``` |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True) |
|
|
model = AutoModel.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True) |
|
|
``` |
|
|
|
|
|
### 2. modelのoutput |
|
|
``` |
|
|
text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。' |
|
|
model.eval() |
|
|
a = tokenizer(text,truncation =True,max_length=512, entities=["船井電機・ホールディングス株式会社"], entity_spans=[(0,4)],return_tensors='pt') |
|
|
outputs = model(**a) |
|
|
mlm_logits = outputs.logits |
|
|
tep_logits = outputs.topic_entity_logits |
|
|
mep_logits = outputs.entity_logits |
|
|
print(mlm_logits.shape) >> torch.Size([1, 49, 32770]) |
|
|
print(tep_logits.shape) >> torch.Size([1, 20972]) |
|
|
print(mep_logits.shape) >> torch.Size([1, 1, 20972]) |
|
|
``` |
|
|
- modelのencode結果は,logits, topic_entity_logits, entity_logitsを属性として持ちます. |
|
|
- logitsは通常のBERTなどの言語モデルと同様の扱い方です. |
|
|
- topic_entity_logits(文章における各enitityの関連度) と entity_logits(entityの埋め込み表現)に関しては,このモデル固有のものであり,以下に扱い方を解説します. |
|
|
|
|
|
### 3. topic_entity_logits(文章における各enitityの関連度を取得) |
|
|
``` |
|
|
tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True) |
|
|
model = AutoModelForPreTraining.from_pretrained("uzabase/UBKE-LUKE", output_hidden_states=True, trust_remote_code=True) |
|
|
|
|
|
text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。' |
|
|
|
|
|
model.eval() |
|
|
a = tokenizer(text,truncation =True,max_length=512, return_tensors='pt') |
|
|
outputs = model(**a) |
|
|
tep_logits = outputs.topic_entity_logits |
|
|
print("tep_logits shape: ", tep_logits.shape) # >> torch.Size([1, 20972]) each dimentions correspond to entities |
|
|
|
|
|
ent = { tokenizer.entity_vocab[i]:i for i in tokenizer.entity_vocab} |
|
|
|
|
|
print("Entity Recognition Results:") |
|
|
topk_logits, topk_entity_ids = tep_logits.topk(10, dim=1) |
|
|
for logit, entity_id in zip(topk_logits[0].tolist(), topk_entity_ids[0].tolist()): |
|
|
print("\t", ent[entity_id], logit) |
|
|
>>> |
|
|
Entity Recognition Results: |
|
|
船井電機・ホールディングス株式会社 1.8898193836212158 |
|
|
セイノーホールディングス 1.668973684310913 |
|
|
東洋電機 1.658090353012085 |
|
|
横河電機 1.6363312005996704 |
|
|
船井総研ホールディングス 1.618525743484497 |
|
|
西菱電機 1.587844967842102 |
|
|
フォスター電機 1.5436134338378906 |
|
|
東洋電機製造 1.493951678276062 |
|
|
ヒロセ電機 1.458113193511963 |
|
|
サクサ 1.4461733102798462 |
|
|
``` |
|
|
- modelのencode結果は,topic_entity_logits属性を持ちます. |
|
|
- topic_entity_logitsは, torch.Size([batch_size, entity_size])のtoroch.tensorです. |
|
|
- 各次元のlogit値は,入力文章における各entityの関連度を表現しています. |
|
|
|
|
|
### 4. entity_logits(entityの埋め込み表現) |
|
|
- entityの一覧は,tokenizerがentity_vocabに辞書形式で持ちます. |
|
|
``` |
|
|
tokenizer.entity_vocab # => {"": 0, ... ,"AGC": 48, .... |
|
|
tokenizer.entity_vocab["味の素"] # => 8469(味の素のentity_id) |
|
|
``` |
|
|
- entity_spans及びentitties引数をtokenizerに渡し,tokenをencodeすることで,entityの埋め込み表現を得ます. |
|
|
``` |
|
|
model.eval() |
|
|
tokens = tokenizer("味の素", entities=["味の素"], entity_spans=[(0, 3)], truncation=True, max_length=512, return_tensors="pt") |
|
|
print(tokens["entity_ids"]) # => tensor([[8469]]) |
|
|
with torch.no_grad(): |
|
|
outputs = model(**tokens) |
|
|
outputs.entity_logits.shape # 味の素のentity_vector |
|
|
``` |
|
|
- entityの埋め込み表現の内積(やコサイン類似度)を計算することで,entity同士の類似度を計算可能です. |
|
|
``` |
|
|
def encode(entity_text): |
|
|
model.eval() |
|
|
tokens = tokenizer(entity_text, entities=[entity_text], entity_spans=[(0, len(entity_text))], |
|
|
truncation=True, max_length=512, return_tensors="pt") |
|
|
with torch.no_grad(): |
|
|
outputs = model(**tokens) |
|
|
return outputs.entity_logits[0][0] |
|
|
azinomoto = encode("味の素") |
|
|
nisshin = encode("日清食品ホールディングス") |
|
|
kameda = encode("亀田製菓") |
|
|
sony = encode("ソニーホールディングス") |
|
|
print(azinomoto @ nisshin) # => tensor(24834.6836) |
|
|
print(azinomoto @ kameda) # => tensor(17547.6895) |
|
|
print(azinomoto @ sony) # => tensor(8699.2871) |
|
|
``` |
|
|
|
|
|
## Licenses |
|
|
|
|
|
The model parameters `model.safetensors` is licensed under CC BY-NC. |
|
|
モデルの重みファイル `model.safetensors` はCC BY-NCライセンスで利用可能です。 |
|
|
|
|
|
Other files are subject to the same license as [LUKE](https://github.com/studio-ousia/luke) itself. |
|
|
その他のファイルは[LUKE](https://github.com/studio-ousia/luke)自体と同じライセンスが適用されます。 |
|
|
|
|
|
|
|
|
## Reference |
|
|
|
|
|
* 開発の背景などについては[ブログ](https://tech.uzabase.com/entry/2024/12/24/173942)を参照してください |
|
|
* もしUBKE-LUKEの活用に興味をお持ちの方は ub-research@uzabase.com までご連絡ください |
|
|
|