from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from collections import defaultdict
import torch

# Fall back to CPU when no GPU is available instead of crashing with a
# hard-coded "cuda" device on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Korean NER model (KoELECTRA small, trained on the MODU NER corpus).
MODEL_NAME = "Leo97/KoELECTRA-small-v3-modu-ner"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
model.to(device)
# Pipeline: regions suspected to contain a stamp/seal are passed through
# CRAFT, the text regions are cropped, and TrOCR extracts the text from each
# cropped region. The extracted text is then classified here.
# If the text's class is "person name (PS)", "road/building name (AF)", or
# "address (LC)", return 1 so the region is mosaicked afterwards.
# The NER model splits the text per syllable and returns a class per token.
# Inspecting every token's class makes inference very slow, so we return 1
# as soon as a single PS/AF/LC token is found.
def check_entity(entities):
    """Return 1 if any recognized entity is tagged PS, AF, or LC, else 0.

    Each item in *entities* is a HF token-classification dict; the tag is
    read from its 'entity' key (missing key -> '') and matched
    case-insensitively as a substring (so 'B-PS', 'I-LC', ... all count).
    """
    flagged_tags = ("LC", "PS", "AF")
    return int(
        any(
            tag in item.get("entity", "").upper()
            for item in entities
            for tag in flagged_tags
        )
    )
_NER_PIPELINE = None  # built lazily on first call, then reused


def ner(example):
    """Run NER on *example* and return 1 if it contains a PS/AF/LC entity
    (i.e. the region should be mosaicked), else 0.

    Fixes two defects in the original:
    - the HF pipeline was reconstructed on every call, which dominated
      inference time; it is now created once and cached at module level;
    - a local variable named ``ner`` shadowed this function.
    """
    global _NER_PIPELINE
    if _NER_PIPELINE is None:
        _NER_PIPELINE = pipeline(
            "ner", model=model, tokenizer=tokenizer, device=device
        )
    return check_entity(_NER_PIPELINE(example))
# On hold — unused alternatives kept for reference.
# def find_longest_value_key(input_dict):
#     max_length = 0
#     max_length_keys = []
#     for key, value in input_dict.items():
#         current_length = len(value)
#         if current_length > max_length:
#             max_length = current_length
#             max_length_keys = [key]
#         elif current_length == max_length:
#             max_length_keys.append(key)
#     if len(max_length_keys) == 1:
#         return 0
#     else:
#         return 1
# def find_longest_value_key2(input_dict):
#     if not input_dict:
#         return None
#     max_key = max(input_dict, key=lambda k: len(input_dict[k]))
#     return max_key
# def find_most_frequent_entity(entities):
#     entity_counts = defaultdict(list)
#     for item in entities:
#         split_entity = item['entity'].split('-')
#         entity_type = split_entity[1]
#         entity_counts[entity_type].append(item['score'])
#     number=find_longest_value_key(entity_counts)
#     if number==1:
#         max_entities = []
#         max_score_average = -1
#         for entity, scores in entity_counts.items():
#             score_average = sum(scores) / len(scores)
#             if score_average > max_score_average:
#                 max_entities = [entity]
#                 max_score_average = score_average
#             elif score_average == max_score_average:
#                 max_entities.append(entity)
#         if len(max_entities)>0:
#             return max_entities if len(max_entities) > 1 else max_entities[0]
#         else:
#             return "Do not mosaik"
#     else:
#         A=find_longest_value_key2(entity_counts)
#         return A
# If even a single PS or LC token exists, report PS/LC immediately:
# label=filtering(ner_results)
# if label.find("PS")>-1 or label.find("LC")>-1:
#     return 1
# else:
#     return 0
#print(ner("ํ๊ธธ๋"))
#label=check_label(example)