import gradio as gr
from models import *
from huggingface_hub import hf_hub_download
import os
from config import *
# Explicit imports for names used below; `models` is expected to re-export them as well.
import torch
from transformers import AutoTokenizer
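# NOTE (assumptions about the starred imports above):
#   models.Classifier           - the classification model wrapper instantiated below
#   config.base_model           - Hugging Face model id used for the tokenizer/backbone
#   config.entity_property_pair - list of aspect (entity#property) categories
#   config.entity2str           - maps each category to a natural-language prompt
#   config.tf_id_to_name        - maps the entity model's class ids to "True"/"False"
#   config.polarity_id_to_name  - maps the sentiment model's class ids to polarity labels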

ENTITY_REPO_ID = 'vaivTA/absa_v2_entity'
ENTITY_FILENAME = "entity_model.pt"

SENTIMENT_REPO_ID = 'vaivTA/absa_v2_sentiment'
SENTIMENT_FILENAME = "sentiment_model.pt"

print("downloading model...")
sen_model_file = hf_hub_download(repo_id=SENTIMENT_REPO_ID, filename=SENTIMENT_FILENAME)
entity_model_file = hf_hub_download(repo_id=ENTITY_REPO_ID, filename=ENTITY_FILENAME)

# `base_model` comes from config; both classifiers share its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Polarity (sentiment) classifier.
sen_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
sen_model.load_state_dict(torch.load(sen_model_file, map_location=torch.device('cpu')))
sen_model.eval()  # inference only; disable dropout etc.

# Aspect-category (entity) detector.
entity_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
entity_model.load_state_dict(torch.load(entity_model_file, map_location=torch.device('cpu')))
entity_model.eval()  # inference only; disable dropout etc.


def infer(test_sentence):
    """For each aspect category, detect whether the review mentions it and,
    if so, classify the sentiment polarity toward it."""
    form = test_sentence
    annotation = []

    if len(form) > 500:
        return "Sentence is too long (500-character limit)."
    
    for pair in entity_property_pair:
        # Encode the review and the aspect-category prompt as a sentence pair.
        form_ = form + "[SEP]"
        pair_ = entity2str[pair] + "[SEP]"

        tokenized_data = tokenizer(form_, pair_, padding='max_length', max_length=512, truncation=True)

        input_ids = torch.tensor([tokenized_data['input_ids']])
        attention_mask = torch.tensor([tokenized_data['attention_mask']])

        # Mark the tokens of the second segment (the aspect prompt); token id 2 is
        # assumed to be the separator, and the two segments are split by a double separator.
        first_sep = tokenized_data['input_ids'].index(2)
        last_sep = tokenized_data['input_ids'][first_sep + 2:].index(2) + (first_sep + 2)
        mask = [0] * len(tokenized_data['input_ids'])
        for i in range(first_sep + 2, last_sep):
            mask[i] = 1
        mask = torch.tensor([mask])
                
        # Stage 1: does the review mention this aspect category at all?
        with torch.no_grad():
            ce_logits = entity_model(input_ids, attention_mask, mask)
        ce_predictions = torch.argmax(ce_logits, dim=-1)
        ce_result = tf_id_to_name[ce_predictions[0]]

        # Stage 2: if so, classify the sentiment toward that aspect.
        if ce_result == 'True':
            with torch.no_grad():
                pc_logits = sen_model(input_ids, attention_mask, mask)
            pc_predictions = torch.argmax(pc_logits, dim=-1)
            pc_result = polarity_id_to_name[pc_predictions[0]]

            annotation.append(f"{pair} - {pc_result}")
            
    result = '\n'.join(annotation)
    return result

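# Optional local sanity check (a sketch; not part of the hosted demo), e.g.:
#     print(infer("์•„ํŒŒํŠธ๋Š” ์˜ค๋ž˜๋˜์—ˆ์ง€๋งŒ ๋™๋„ค๊ฐ€ ์กฐ์šฉํ•˜๊ณ  ์พŒ์ ํ•ฉ๋‹ˆ๋‹ค."))
# Each output line has the form "<aspect category> - <polarity>".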
    
# Gradio UI: one review sentence in, one "<aspect> - <polarity>" line per detected aspect out.
# The `article` text shows a sample Korean apartment review that can be pasted into the box.
demo = gr.Interface(fn=infer,
                    inputs=gr.Textbox(type="text", label="Input Sentence"),
                    outputs=gr.Textbox(type="text", label="Result Sentence"),
                    article="**๋ฆฌ๋ทฐ ์˜ˆ์‹œ** : ์•„ํŒŒํŠธ๋Š” ์˜ค๋ž˜๋˜์—ˆ์ง€๋งŒ ๋™๋„ค๊ฐ€ ์กฐ์šฉํ•˜๊ณ  ์พŒ์ ํ•˜์—ฌ ์‚ด๊ธฐ์—๋Š” ์•„์ฃผ ์ข‹์Šต๋‹ˆ๋‹ค. ํฐ ๋งˆํŠธ๊ฐ€ ์ฃผ๋ณ€์— ์—†๋Š” ๋‹จ์ ์ด ์ž‡์ง€๋งŒ ์ด์ดŒ์—ญ์ด ๋งค์šฐ ๊ฐ€๊น๊ณ  ์ƒํ™œ๊ถŒ ๋‚ด์— ๋ง›์ž‡๋Š” ์‹๋‹น๊ณผ ์ปคํ”ผ์ˆ–์ด ์ฆ๋น„ํ•ฉ๋‹ˆ๋‹ค ใ…Žใ…Ž")

demo.launch(share=True)