Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import fasttext | |
| from transformers import AutoModelForSequenceClassification | |
| from transformers import AutoTokenizer | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
# Mapping between classifier output indices and label names. Both dicts are
# passed to the English and Korean models at construction time, and
# builder() uses id2label to name the two softmax probabilities it returns.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
# Page title / description strings (only referenced by the commented-out
# gr.Interface demo further down).
# NOTE(review): the Korean text in these literals appears mis-encoded (UTF-8
# bytes shown through a Thai code page); the damage is lossy, so restore the
# text from the original source rather than trying to repair it here.
title = "์ํ ๋ฆฌ๋ทฐ ์ ์ ํ๋ณ๊ธฐ"
# NOTE(review): `""Default""` inside this literal is parsed as implicit
# concatenation of adjacent string literals, so the quotes around Default do
# not appear in the runtime value. Use \"Default\" if quotes are intended.
description = "์ํํ์ ์ ๋ ฅํ์ฌ ๊ธ์ ์ ์ธ์ง ๋ถ์ ์ ์ธ์ง๋ฅผ ๋ถ๋ฅํ๋ ํ๋ก๊ทธ๋จ์ ๋๋ค. \
ํ๊ตญ์ด ๋ฒ์ ๊ณผ ์์ด ๋ฒ์ ์ค์์ ์ ํํ ์ ์์ต๋๋ค. \
ํ๊ตญ์ด์ธ์ง ์์ด์ธ์ง ํ๋จํ๊ณ ์์ธกํด์ฃผ๋ ""Default""๋ผ๋ ๋ฒ์ ๋ ์ ๊ณตํฉ๋๋ค."
class LanguageIdentification:
    """Thin wrapper around a fastText language-identification model.

    Args:
        model_path: path of the fastText model file. Defaults to the
            ``lid.176.ftz`` file expected in the working directory.
    """

    def __init__(self, model_path="./lid.176.ftz"):
        self.model = fasttext.load_model(model_path)

    def predict_lang(self, text, k=200):
        """Return fastText's ``(labels, probabilities)`` prediction for ``text``.

        Args:
            text: input text. It may contain newlines; they are replaced by
                spaces because fastText's ``predict`` only accepts a single
                line and raises ``ValueError`` otherwise.
            k: maximum number of languages to return. The default of 200 is
                larger than the 176 languages the lid.176 model covers (per the
                fastText docs), i.e. "return every language reported".

        Returns:
            A ``(labels, probabilities)`` pair, where labels look like
            ``'__label__en'``.
        """
        single_line = text.replace("\n", " ")
        return self.model.predict(single_line, k=k)
# Module-level instance: the fastText model is loaded once at import time and
# reused by builder() for automatic language detection.
LANGUAGE = LanguageIdentification()
def tokenized_data(tokenizer, inputs, max_length=64):
    """Tokenize text for one of the sentiment classifiers.

    Args:
        tokenizer: a Hugging Face tokenizer (e.g. ``eng_tokenizer`` or
            ``kor_tokenizer``).
        inputs: a single string, or an iterable of strings to encode as one
            batch.
        max_length: every sequence is truncated / padded to exactly this many
            tokens, so results for different inputs can be stacked.

    Returns:
        A ``BatchEncoding`` of PyTorch tensors (``input_ids``,
        ``attention_mask``, ...) with one row per input string.
    """
    batch = [inputs] if isinstance(inputs, str) else list(inputs)
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``batch_encode_plus``; the encoding arguments are unchanged.
    return tokenizer(
        batch,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    )
# Gradio example inputs: the same five randomly chosen rows of examples.csv,
# once as English reviews (column 0) and once as Korean reviews (column 1).
# Each entry is [Lang, Text], matching the two inputs wired into gr.Examples.
df = pd.read_csv('examples.csv', sep='\t', index_col='Unnamed: 0')
# A private RandomState(100) draws exactly the same indices as
# np.random.seed(100) followed by np.random.choice(...), without reseeding
# numpy's global RNG as an import-time side effect. Rows are drawn from the
# first 50 (or from the whole file if it is shorter).
idx = np.random.RandomState(100).choice(min(50, len(df)), size=5, replace=False)
eng_examples = [['Eng', df.iloc[i, 0]] for i in idx]
kor_examples = [['Kor', df.iloc[i, 1]] for i in idx]
examples = eng_examples + kor_examples
def _load_classifier(model_name, step):
    """Build a 2-label sequence classifier and load fine-tuned weights into it.

    The checkpoint is read from ``"<model_name>-<step>.pt"`` in the working
    directory, with any ``/`` in ``model_name`` replaced by ``_``.

    Args:
        model_name: Hugging Face model id used for the tokenizer and the base
            architecture.
        step: training step that identifies the checkpoint file.

    Returns:
        ``(tokenizer, model, checkpoint_file_name)``; the model is in eval mode.
    """
    import warnings

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    file_name = "{}-{}.pt".format(model_name.replace('/', '_'), step)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=2, id2label=id2label, label2id=label2id
    )
    # map_location='cpu' lets a checkpoint saved during GPU training be
    # deserialized on a CPU-only host instead of raising at load time.
    state_dict = torch.load(file_name, map_location='cpu')
    # Remove position_ids from state_dict as it's not needed in newer transformers versions
    state_dict = {k: v for k, v in state_dict.items() if 'position_ids' not in k}
    result = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates key drift, but any mismatch means part of the
    # model (e.g. the classification head) silently keeps its random
    # initialisation, so surface it instead of ignoring it.
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            "{}: missing keys {}, unexpected keys {}".format(
                file_name, result.missing_keys, result.unexpected_keys
            )
        )
    model.eval()
    # The loaded state_dict goes out of scope here, so its tensors (a full
    # second copy of the weights) are not kept alive at module level.
    return tokenizer, model, file_name


# English sentiment model.
eng_model_name = "roberta-base"
eng_step = 1900
eng_tokenizer, eng_model, eng_file_name = _load_classifier(eng_model_name, eng_step)

# Korean sentiment model.
kor_model_name = "klue/roberta-small"
kor_step = 2400
kor_tokenizer, kor_model, kor_file_name = _load_classifier(kor_model_name, kor_step)
def _extract_logits(outputs):
    """Return the logits tensor from a model output (plain tuple or ModelOutput)."""
    if isinstance(outputs, tuple):
        return outputs[0]
    return outputs.logits


def _positive_tag(p_positive):
    """Map a POSITIVE probability to a highlight tag ('+++' ... '---') or None."""
    if p_positive > 0.99:
        return '+++'
    if p_positive > 0.9:
        return '++'
    if p_positive > 0.8:
        return '+'
    if p_positive < 0.01:
        return '---'
    if p_positive < 0.1:
        return '--'
    if p_positive < 0.2:
        return '-'
    return None


def _detect_language(text):
    """Detect English vs Korean with fastText.

    Returns:
        ``(lang, percent_kor, percent_eng)``, where ``lang`` is 'Kor' or 'Eng'
        and the percentages are the two fastText probabilities renormalised to
        sum to 1. Both are 0 when fastText reports neither language; in that
        case, and on a tie, 'Eng' is chosen.
    """
    # fastText only accepts single-line input.
    labels, probs = LANGUAGE.predict_lang(text.replace('\n', ' '))
    p_eng = p_kor = 0.0
    # A label may be absent from the result, so default to probability 0.
    if '__label__en' in labels:
        p_eng = float(probs[labels.index('__label__en')])
    if '__label__ko' in labels:
        p_kor = float(probs[labels.index('__label__ko')])
    total = p_kor + p_eng
    if total > 0:
        percent_kor, percent_eng = p_kor / total, p_eng / total
    else:
        percent_kor = percent_eng = 0
    # Pick the more probable language. Both labels are normally present, so
    # "whichever label appears" cannot decide.
    lang = 'Kor' if p_kor > p_eng else 'Eng'
    return lang, percent_kor, percent_eng


def builder(Lang, Text):
    """Classify a movie review and score each of its words.

    Args:
        Lang: 'Eng', 'Kor', or the auto-detection choice offered by the UI
            dropdown.
        Text: the review text.

    Returns:
        A three-element list for the UI outputs:
        (1) ``{'Kor': p, 'Eng': p}`` language confidence;
        (2) ``{'POSITIVE': p, 'NEGATIVE': p}`` sentiment for the whole text;
        (3) ``[(word, tag), ...]``, tagging each whitespace-separated word with
            how strongly the model rates it on its own.

    Raises:
        ValueError: if ``Lang`` is not one of the supported options.
    """
    percent_kor, percent_eng = 0, 0

    # [ output_1 ] language selection / detection
    if Lang == '์ธ์ด๊ฐ์ง ๊ธฐ๋ฅ ์ฌ์ฉ':
        Lang, percent_kor, percent_eng = _detect_language(Text)
    if Lang == 'Eng':
        model, tokenizer = eng_model, eng_tokenizer
        # A manually chosen language is reported with full confidence.
        if percent_eng == 0:
            percent_eng = 1
    elif Lang == 'Kor':
        model, tokenizer = kor_model, kor_tokenizer
        if percent_kor == 0:
            percent_kor = 1
    else:
        raise ValueError("Unsupported language option: {!r}".format(Lang))

    model.eval()
    softmax = torch.nn.Softmax(dim=1)

    # [ output_2 ] sentiment of the whole review
    inputs = tokenized_data(tokenizer, Text)
    with torch.no_grad():
        logits = _extract_logits(model(input_ids=inputs['input_ids'],
                                       attention_mask=inputs['attention_mask']))
    output = softmax(logits)

    # [ output_3 ] per-word sentiment, computed in one batched forward pass
    words = Text.split()  # split() drops the empty strings split(' ') yields
    output_analysis = []
    if words:
        # Every encoding is padded to the same fixed length, so rows stack.
        encoded = [tokenized_data(tokenizer, word) for word in words]
        input_ids = torch.cat([enc['input_ids'] for enc in encoded])
        attention_mask = torch.cat([enc['attention_mask'] for enc in encoded])
        with torch.no_grad():
            word_probs = softmax(_extract_logits(
                model(input_ids=input_ids, attention_mask=attention_mask)))
        output_analysis = [(word, _positive_tag(p.item()))
                           for word, p in zip(words, word_probs[:, 1])]

    return [{'Kor': percent_kor, 'Eng': percent_eng},
            {id2label[1]: output[0][1].item(), id2label[0]: output[0][0].item()},
            output_analysis]
| # demo3 = gr.Interface.load("models/mdj1412/movie_review_score_discriminator_eng", inputs="text", outputs="text", | |
| # title=title, theme="peach", | |
| # allow_flagging="auto", | |
| # description=description, examples=examples) | |
| # demo = gr.Interface(builder, inputs=[gr.inputs.Dropdown(['Default', 'Eng', 'Kor']), gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ ๋ ฅํ์์ค.")], | |
| # outputs=[ gr.Label(num_top_classes=3, label='Lang'), | |
| # gr.Label(num_top_classes=2, label='Result'), | |
| # gr.HighlightedText(label="Analysis", combine_adjacent=False) | |
| # .style(color_map={"+++": "#CF0000", "++": "#FF3232", "+": "#FFD4D4", "---": "#0004FE", "--": "#4C47FF", "-": "#BEBDFF"}) ], | |
| # # outputs='label', | |
| # title=title, description=description, examples=examples) | |
# Gradio Blocks UI. The language dropdown (inputs_1) and review textbox
# (inputs_2) are passed to builder() when the button is clicked. builder()'s
# three return values fill the language label, the sentiment label and the
# word-level highlighted analysis.
# NOTE(review): the Korean literals in this block appear mis-encoded (UTF-8
# shown through a Thai code page); restore them from the original source.
with gr.Blocks() as demo1:
    # Page heading.
    gr.Markdown(
    """
<h1 align="center">
์ํ ๋ฆฌ๋ทฐ ์ ์ ํ๋ณ๊ธฐ
</h1>
""")
    # Short usage description shown under the heading.
    gr.Markdown(
    """
์ํ ๋ฆฌ๋ทฐ๋ฅผ ์ ๋ ฅํ๋ฉด, ๋ฆฌ๋ทฐ๊ฐ ๊ธ์ ์ธ์ง ๋ถ์ ์ธ์ง ํ๋ณํด์ฃผ๋ ๋ชจ๋ธ์ด๋ค. \
์์ด์ ํ๊ธ์ ์ง์ํ๋ฉฐ, ์ธ์ด๋ฅผ ์ง์ ์ ํํ ์๋, ํน์ ๋ชจ๋ธ์ด ์ธ์ด๊ฐ์ง๋ฅผ ์ง์ ํ๋๋ก ํ ์ ์๋ค.
๋ฆฌ๋ทฐ๋ฅผ ์ ๋ ฅํ๋ฉด, (1) ๊ฐ์ง๋ ์ธ์ด, (2) ๊ธ์ ๋ฆฌ๋ทฐ์ผ ํ๋ฅ ๊ณผ ๋ถ์ ๋ฆฌ๋ทฐ์ผ ํ๋ฅ , (3) ์ ๋ ฅ๋ ๋ฆฌ๋ทฐ์ ์ด๋ ๋จ์ด๊ฐ ๊ธ์ /๋ถ์ ๊ฒฐ์ ์ ์ํฅ์ ์ฃผ์๋์ง \
(๊ธ์ ์ผ ๊ฒฝ์ฐ ๋นจ๊ฐ์, ๋ถ์ ์ผ ๊ฒฝ์ฐ ํ๋์)๋ฅผ ํ์ธํ ์ ์๋ค.
""")
    # Collapsible model description.
    # NOTE(review): this text names bert-base-uncased / klue/roberta-base, but
    # the code loads "roberta-base" and "klue/roberta-small"; confirm which is
    # correct and update the text or the code accordingly.
    with gr.Accordion(label="๋ชจ๋ธ์ ๋ํ ์ค๋ช ( ์ฌ๊ธฐ๋ฅผ ํด๋ฆญ ํ์์ค. )", open=False):
        gr.Markdown(
        """
์์ด ๋ชจ๋ธ์ bert-base-uncased ๊ธฐ๋ฐ์ผ๋ก, ์์ด ์ํ ๋ฆฌ๋ทฐ ๋ถ์ ๋ฐ์ดํฐ์ ์ธ SST-2๋ก ํ์ต ๋ฐ ํ๊ฐ๋์๋ค.
ํ๊ธ ๋ชจ๋ธ์ klue/roberta-base ๊ธฐ๋ฐ์ด๋ค. ๊ธฐ์กด ํ๊ธ ์ํ ๋ฆฌ๋ทฐ ๋ถ์ ๋ฐ์ดํฐ์ ์ด ์กด์ฌํ์ง ์์, ๋ค์ด๋ฒ ์ํ์ ๋ฆฌ๋ทฐ๋ฅผ ํฌ๋กค๋งํด์ ์ํ ๋ฆฌ๋ทฐ ๋ถ์ ๋ฐ์ดํฐ์ ์ ์ ์ํ๊ณ , ์ด๋ฅผ ์ด์ฉํ์ฌ ๋ชจ๋ธ์ ํ์ต ๋ฐ ํ๊ฐํ์๋ค.
์์ด ๋ชจ๋ธ์ SST-2์์ 92.8%, ํ๊ธ ๋ชจ๋ธ์ ๋ค์ด๋ฒ ์ํ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ์ ์์ 94%์ ์ ํ๋๋ฅผ ๊ฐ์ง๋ค (test set ๊ธฐ์ค).
์ธ์ด๊ฐ์ง๋ fasttext์ language detector๋ฅผ ์ฌ์ฉํ์๋ค. ๋ฆฌ๋ทฐ์ ๋จ์ด๋ณ ์ํฅ๋ ฅ์, ๋จ์ด ๊ฐ๊ฐ์ ๋ชจ๋ธ์ ๋ฃ์์ ๋ ๊ฒฐ๊ณผ๊ฐ ๊ธ์ ์ผ๋ก ๋์ค๋์ง ๋ถ์ ์ผ๋ก ๋์ค๋์ง๋ฅผ ๋ฐํ์ผ๋ก ์ธก์ ํ์๋ค.
""")
    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column():
            # The first choice must equal the auto-detection string that
            # builder() compares Lang against.
            inputs_1 = gr.Dropdown(choices=['์ธ์ด๊ฐ์ง ๊ธฐ๋ฅ ์ฌ์ฉ', 'Eng', 'Kor'], value='์ธ์ด๊ฐ์ง ๊ธฐ๋ฅ ์ฌ์ฉ', label='Lang')
            inputs_2 = gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ ๋ ฅํ์์ค.", label='Text')
            with gr.Row():
                # Disabled in the original: a second button (btn2) and its
                # click handler further below.
                # btn2 = gr.Button("ํด๋ฆฌ์ด")
                btn = gr.Button("์ ์ถํ๊ธฐ")
        # Right column: the three outputs, in the order builder() returns them.
        with gr.Column():
            output_1 = gr.Label(num_top_classes=3, label='Lang')
            output_2 = gr.Label(num_top_classes=2, label='Result')
            # '+' tags (positive words) are shades of red and '-' tags
            # (negative words) are shades of blue, matching builder()'s tags.
            output_3 = gr.HighlightedText(
                label="Analysis",
                combine_adjacent=False,
                color_map={"+++": "#CF0000", "++": "#FF3232", "+": "#FFD4D4", "---": "#0004FE", "--": "#4C47FF", "-": "#BEBDFF"}
            )
    # btn2.click(fn=fn2, inputs=[None, None], output=[output_1, output_2, output_3])
    btn.click(fn=builder, inputs=[inputs_1, inputs_2], outputs=[output_1, output_2, output_3])
    # Clicking an example only fills the two inputs ([Lang, Text] pairs).
    gr.Examples(examples, inputs=[inputs_1, inputs_2])
# Script entry point: start the Gradio server for the Blocks UI (demo1).
if __name__ == "__main__":
    # Leftovers from the earlier gr.Interface-based demo; kept disabled.
    # print(examples)
    # demo.launch()
    demo1.launch()