# Source: Hugging Face Space by yiningmao — app.py (commit d96e7ad)
import os
import sys
import pickle
import random
import copy
import numpy as np
import gradio as gr
import re
import string
import torch
import torch.nn as nn
from tqdm import tqdm, trange
from collections import OrderedDict
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from utils import Config, Logger, make_log_dir
from modeling import (
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForSequenceClassification_SPV,
AutoModelForSequenceClassification_MIP,
AutoModelForSequenceClassification_SPV_MIP,
)
from run_classifier_dataset_utils import processors, output_modes, compute_metrics
from data_loader import load_train_data, load_train_data_kf, load_test_data, load_sentence_data
from frame_semantic_transformer import FrameSemanticTransformer
# Instantiate the frame-semantic parser once at import time; setup() prepares
# its underlying models (may download weights — module-level side effect).
frame_transformer = FrameSemanticTransformer()
frame_transformer.setup()
# Canonical checkpoint file names looked up inside the trained-model directory.
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
ARGS_NAME = "training_args.bin"
def main():
    """Load a trained metaphor-detection model and launch a Gradio demo.

    Reads the local config, overlays the training-time config saved in the
    model directory (``args.bert_model``), rebuilds and loads the model, then
    serves an interactive UI that highlights metaphorical words and,
    optionally, their semantic frames.
    """
    # read configs
    config = Config(main_conf_path="./")

    # apply system arguments if exist
    # NOTE(review): argv is consumed as alternating "--name value" pairs; an
    # odd number of tokens would raise IndexError — assumes well-formed input.
    argv = sys.argv[1:]
    if len(argv) > 0:
        cmd_arg = OrderedDict()
        argvs = " ".join(sys.argv[1:]).split(" ")
        for i in range(0, len(argvs), 2):
            arg_name, arg_value = argvs[i], argvs[i + 1]
            arg_name = arg_name.strip("-")
            cmd_arg[arg_name] = arg_value
        config.update_params(cmd_arg)

    args = config
    print(args.__dict__)

    # logger
    # The model directory doubles as the log/config directory of the training run.
    log_dir = args.bert_model
    logger = Logger(log_dir)
    # Overlay the saved training-time config, but preserve the few arguments
    # that belong to this inference invocation.
    config = Config(main_conf_path=log_dir)
    old_args = copy.deepcopy(args)
    args.__dict__.update(config.__dict__)
    args.bert_model = old_args.bert_model
    args.do_train = old_args.do_train
    args.data_dir = old_args.data_dir
    args.task_name = old_args.task_name

    # apply system arguments if exist
    # (re-applied so command-line flags also override the training-time config
    # loaded just above)
    argv = sys.argv[1:]
    if len(argv) > 0:
        cmd_arg = OrderedDict()
        argvs = " ".join(sys.argv[1:]).split(" ")
        for i in range(0, len(argvs), 2):
            arg_name, arg_value = argvs[i], argvs[i + 1]
            arg_name = arg_name.strip("-")
            cmd_arg[arg_name] = arg_value
        config.update_params(cmd_arg)
    args.log_dir = log_dir

    # set CUDA devices
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device
    logger.info("device: {} n_gpu: {}".format(device, args.n_gpu))

    # set seed (all RNG sources, for reproducible predictions)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # get dataset and processor
    args.num_labels = 2  # binary task: literal (0) vs. metaphorical (1)

    # build tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    model = load_pretrained_model(args)

    # Load trained model
    model = load_trained_model(args, model, tokenizer)

    def run_one_sentence(sentence, test_frame):
        """Classify each word of ``sentence`` as metaphorical or literal.

        Returns two lists consumable by gr.HighlightedText: (word, label)
        pairs for metaphor highlighting, and frame info when ``test_frame``.
        """
        # Pad punctuation with spaces so each mark becomes its own token...
        sentence = re.sub(r'([^\w\s])', r' \1 ', sentence)
        sentence = ' '.join(sentence.split())
        # ...but keep a naturally-spaced copy for the frame parser.
        frame_sentence = re.sub(r'\s+([^\w\s])', r'\1', sentence)
        model.eval()
        s_batch = load_sentence_data(args, sentence, ['0','1'], tokenizer, 'classification')
        # MIP-style models additionally take a second view of the target word.
        if args.model_type in ["MELBERT_MIP", "MELBERT"]:
            input_ids, input_mask, segment_ids, label_ids, idx, input_ids_2, input_mask_2, segment_ids_2 = s_batch
        else:
            input_ids, input_mask, segment_ids, label_ids, idx = s_batch
        with torch.no_grad():
            # compute loss values
            if args.model_type in ["BERT_BASE", "BERT_SEQ", "MELBERT_SPV"]:
                logits = model(
                    input_ids,
                    target_mask=(segment_ids == 1),
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )
            elif args.model_type in ["MELBERT_MIP", "MELBERT"]:
                logits = model(
                    input_ids,
                    input_ids_2,
                    target_mask=(segment_ids == 1),
                    target_mask_2=segment_ids_2,
                    attention_mask_2=input_mask_2,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                )
        pred = logits.detach().cpu().numpy()
        pred = np.argmax(pred, axis=1)
        sentence_list = sentence.split()
        # idx maps each prediction row back to a word position in the sentence.
        pred_list = [None for _ in range(len(sentence_list))]
        for i,n in enumerate(idx):
            pred_list[n] = 'M' if pred[i] == 1 else None
        frame_list = []
        if test_frame:
            sentence_frame = frame_transformer.detect_frames(frame_sentence)
            for i,n in enumerate(idx):
                if pred[i] == 1:
                    # NOTE(review): str.find returns the FIRST occurrence, so a
                    # repeated word may map to the wrong trigger location.
                    word_loc = frame_sentence.find(sentence_list[n])
                    word_frame = frame_transformer.detect_frames(sentence_list[n])
                    if word_loc in sentence_frame.trigger_locations and 0 in word_frame.trigger_locations:
                        # Contextual frame (word in sentence) vs. literal frame
                        # (word parsed in isolation).
                        frame_list = frame_list + [
                            ('['+sentence_list[n]+']', None),
                            (sentence_frame.frames[sentence_frame.trigger_locations.index(word_loc)].name, 'Contextual'),
                            (word_frame.frames[0].name, 'Literal'),
                            (' \n', None)
                        ]
                    else:
                        frame_list = frame_list + [
                            ('['+sentence_list[n]+']', None),
                        ]
        label_list = [(w, p) for w,p in zip(sentence.split(), pred_list)]
        #print(label_list)
        return label_list, frame_list

    # Gradio UI: textbox + checkbox in, two highlighted-text panes out.
    demo = gr.Interface(
        run_one_sentence,
        [
            gr.Textbox(placeholder="Enter sentence here..."),
            gr.Checkbox(label="Test frame", value=False),
        ],
        [gr.HighlightedText(label='Metaphor Detection'), gr.HighlightedText(label='Frame Extraction')],
        examples=[
            ['while new departments are born and others extended .', True],
            ['Dimples played in his cheeks .', True],
            ['For a whole week they had worked closely together , sharing flasks of coffee and packets of cigarettes and Paula had grown to like the pixieish little man who by his very nature offered her no challenge — and no threat .', True],
            ['Paula looked every inch a model these days , Arlene thought with a touch of proprietorial pride .', True],
            ['The sounds are the same as those of daylight , yet somehow the night magnifies and sharpens the creak of a yielding block , the sigh of air over a shroud , the stretching of a sail , the hiss of water sliding sleek against the hull , the curl of a quarter-wave falling away , and the thump as a wave strikes the cutwater to be sheared into two bright slices of whiteness .', True],
            ['and finally, the debate has sharpened.', True],
            ['But the transformation of a leggy young filly into a sleekly beautiful racehorse had been her doing .', True],
            ["It would change the trajectory of your legal career.", True],
            ["Washington and the media just explodes on you, you just don’t know where you are at the moment", True],
            ["Those statements are deeply concerning.", True],
            ['The dog can smell your fear.', True],
            ['I don\'t know what the hell is going on!', True],
            ['Well, hang on a minute', True],
        ]
    )
    demo.launch(debug=True)
def load_pretrained_model(args):
    """Build the pretrained encoder plus the classification head for ``args.model_type``.

    Loads the HuggingFace backbone named by ``args.bert_model``, enlarges its
    token-type embedding table, wraps it in the project-specific head class,
    moves it to ``args.device`` and (multi-GPU) wraps it in DataParallel.

    Raises:
        ValueError: if ``args.model_type`` is not one of the supported types.
        (Bug fix: the original fell through with an unbound ``model`` and
        crashed with a confusing NameError instead.)
    """
    # Pretrained Model
    bert = AutoModel.from_pretrained(args.bert_model)
    config = bert.config
    # MelBERT uses up to 4 segment ids, so the token-type embedding table must
    # be re-created with a larger vocabulary and freshly initialized.
    config.type_vocab_size = 4
    # ALBERT factorizes its embeddings, so its embedding width differs from hidden_size.
    embed_dim = config.embedding_size if "albert" in args.bert_model else config.hidden_size
    bert.embeddings.token_type_embeddings = nn.Embedding(config.type_vocab_size, embed_dim)
    bert._init_weights(bert.embeddings.token_type_embeddings)

    # Additional Layers: dispatch table replaces the if-chain; unknown types
    # now fail loudly instead of raising NameError on the unbound `model`.
    head_classes = {
        "BERT_BASE": AutoModelForSequenceClassification,
        "BERT_SEQ": AutoModelForTokenClassification,
        "MELBERT_SPV": AutoModelForSequenceClassification_SPV,
        "MELBERT_MIP": AutoModelForSequenceClassification_MIP,
        "MELBERT": AutoModelForSequenceClassification_SPV_MIP,
    }
    try:
        head_cls = head_classes[args.model_type]
    except KeyError:
        raise ValueError("Unknown model_type: {!r}".format(args.model_type))
    model = head_cls(args=args, Model=bert, config=config, num_labels=args.num_labels)

    model.to(args.device)
    if args.n_gpu > 1 and not args.no_cuda:
        model = torch.nn.DataParallel(model)
    return model
def load_trained_model(args, model, tokenizer):
    """Load the fine-tuned checkpoint from ``args.log_dir`` into ``model``.

    The checkpoint uses the predefined file name, so it could equally be
    restored via ``from_pretrained``. ``tokenizer`` is accepted for interface
    compatibility but not used. Returns the (possibly wrapped) model.
    """
    checkpoint_path = os.path.join(args.log_dir, WEIGHTS_NAME)
    state_dict = torch.load(checkpoint_path, map_location=args.device)
    # Load into the underlying module when the model is wrapped in DataParallel.
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state_dict)
    return model
# Script entry point: build/load the model and serve the Gradio demo.
if __name__ == "__main__":
    main()