# THU-IAR's picture
# Upload 198 files
# 2d06dcc verified
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import WEIGHTS_NAME, CONFIG_NAME
def mask_tokens(inputs, tokenizer, special_tokens_mask=None, mlm_probability=0.15):
    """
    Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

    Args:
        inputs: Integer tensor of token ids (assumed to be (batch, seq_len) -- TODO confirm
            against the caller). It is modified IN PLACE and also returned.
        tokenizer: Object providing ``get_special_tokens_mask``, ``convert_tokens_to_ids``,
            ``mask_token`` and ``__len__``.
        special_tokens_mask: Optional tensor that is truthy at positions that must never be
            masked. When ``None`` it is derived from ``tokenizer`` row by row.
        mlm_probability: Probability with which each eligible position is selected.

    Returns:
        Tuple ``(inputs, labels)``. ``labels`` holds the original id at every selected
        position and -100 everywhere else.
    """
    # Every helper tensor is created on the same device as `inputs`, so the
    # boolean indexing below also works when `inputs` does not live on the CPU.
    device = inputs.device
    labels = inputs.clone()

    probability_matrix = torch.full(labels.shape, mlm_probability, device=device)
    if special_tokens_mask is None:
        special_tokens_mask = [
            tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
            for val in labels.tolist()
        ]
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool, device=device)
    else:
        special_tokens_mask = special_tokens_mask.bool().to(device)

    # Special tokens are never selected.
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    # Positions whose id is 0 are never selected either (presumably 0 is the
    # padding id -- confirm against the tokenizer in use).
    probability_matrix.masked_fill_(inputs == 0, 0.0)

    masked_indices = torch.bernoulli(probability_matrix).bool()
    # -100 marks positions that carry no prediction target.
    labels[~masked_indices] = -100

    # 80% of the selected positions are replaced by the mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # Half of the remaining selected positions (10% overall) get a random id.
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long, device=device)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the selected positions (10% overall) keep their original id.
    return inputs, labels
def save_npy(npy_file, path, file_name):
    """Store the array ``npy_file`` with ``np.save`` at ``path/file_name``."""
    np.save(os.path.join(path, file_name), npy_file)
def load_npy(path, file_name):
    """Return whatever ``np.load`` reads from ``path/file_name``."""
    return np.load(os.path.join(path, file_name))
def save_model(model, model_dir):
    """Write the model's weights, and its config when it has one, into ``model_dir``.

    Args:
        model: Object exposing ``state_dict()``. If it has a ``module`` attribute
            (e.g. a parallel wrapper), that inner object is saved instead.
        model_dir: Target directory; created if it does not exist yet.

    The weights go to ``model_dir/WEIGHTS_NAME``. When the saved object has a
    ``config`` attribute, ``config.to_json_string()`` is written to
    ``model_dir/CONFIG_NAME``.
    """
    # Unwrap so the state-dict keys do not carry the wrapper's prefix.
    target = model.module if hasattr(model, 'module') else model

    # An empty string means "current directory", which os.makedirs rejects.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    model_file = os.path.join(model_dir, WEIGHTS_NAME)
    model_config_file = os.path.join(model_dir, CONFIG_NAME)

    torch.save(target.state_dict(), model_file)
    if hasattr(target, 'config'):
        with open(model_config_file, "w", encoding="utf-8") as f:
            f.write(target.config.to_json_string())
def restore_model(model, model_dir):
    """Load weights from ``model_dir/pytorch_model.bin`` into ``model`` and return it.

    Args:
        model: Module whose parameters are overwritten in place.
        model_dir: Directory containing ``pytorch_model.bin`` (presumably the file
            ``save_model`` writes via ``WEIGHTS_NAME`` -- confirm they match).

    Returns:
        The same ``model`` object that was passed in.
    """
    output_model_file = os.path.join(model_dir, 'pytorch_model.bin')

    # Deserialize onto the CPU so a checkpoint written from a GPU can still be
    # restored on a host without one; load_state_dict copies each tensor onto
    # the device of the matching parameter.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(output_model_file, map_location='cpu')

    # strict=False: keys that are missing from or unexpected in the checkpoint
    # do not raise.
    model.load_state_dict(state_dict, strict=False)
    return model
def save_results(args, test_results):
    """Persist predictions and append one row of metrics to the results CSV.

    Args:
        args: Namespace-like object. Reads ``method_output_dir``, ``result_dir``,
            ``results_file_name``, ``dataset``, ``method``, ``backbone``,
            ``known_cls_ratio``, ``labeled_ratio``, ``loss_fct``, ``seed`` and
            ``num_train_epochs``.
        test_results: Dict holding ``'y_pred'`` and ``'y_true'`` plus any metric
            entries. NOTE: the two array entries are removed from this dict in
            place, so the caller's dict no longer contains them afterwards.

    Side effects:
        * writes ``y_pred.npy`` / ``y_true.npy`` into ``args.method_output_dir``;
        * creates or extends ``args.result_dir/args.results_file_name``;
        * prints the full CSV content.

    Raises:
        KeyError: if ``'y_pred'`` or ``'y_true'`` is missing from ``test_results``.
    """
    import datetime

    # Create the output directory up front so np.save cannot fail on a fresh run.
    # (An empty string means "current directory", which os.makedirs rejects.)
    if args.method_output_dir:
        os.makedirs(args.method_output_dir, exist_ok=True)
    np.save(os.path.join(args.method_output_dir, 'y_pred.npy'), test_results['y_pred'])
    np.save(os.path.join(args.method_output_dir, 'y_true.npy'), test_results['y_true'])

    # Drop the arrays so that only the remaining entries end up in the CSV row.
    del test_results['y_pred']
    del test_results['y_true']

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    if args.result_dir:
        os.makedirs(args.result_dir, exist_ok=True)

    created_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    run_info = {
        'dataset': args.dataset,
        'method': args.method,
        'backbone': args.backbone,
        'known_cls_ratio': args.known_cls_ratio,
        'labeled_ratio': args.labeled_ratio,
        'loss': args.loss_fct,
        'seed': args.seed,
        'train_epochs': args.num_train_epochs,
        'created_time': created_time,
    }
    # Metric entries come first, followed by the run description.
    results = dict(test_results, **run_info)

    results_path = os.path.join(args.result_dir, args.results_file_name)
    if not os.path.exists(results_path) or os.path.getsize(results_path) == 0:
        # No usable file yet: this row also defines the header.
        table = pd.DataFrame([list(results.values())], columns=list(results.keys()))
    else:
        previous = pd.read_csv(results_path)
        new_row = pd.DataFrame(results, index=[1])
        table = pd.concat([previous, new_row], ignore_index=True)
    table.to_csv(results_path, index=False)

    # Re-read from disk so the printout shows exactly what was stored.
    data_diagram = pd.read_csv(results_path)
    print('test_results', data_diagram)
def class_count(labels):
    """Return how often each distinct label occurs, ordered by ascending label value.

    Args:
        labels: Array-like of labels.

    Returns:
        List of Python ints. Entry ``i`` is the count of the ``i``-th smallest label
        that is actually present; labels that never occur get NO entry, so the
        list is only aligned with label ids when every id appears at least once.
    """
    # One pass instead of rescanning the whole array once per class.
    _, counts = np.unique(labels, return_counts=True)
    return counts.tolist()
def centroids_cal(model, args, data, train_dataloader, device):
    """Compute one mean feature vector (centroid) per class over ``train_dataloader``.

    Args:
        model: Called as ``model(input_ids, segment_ids, input_mask, feature_ext=True)``;
            it is switched to eval mode first.
        args: Reads ``args.feat_dim`` (width of a centroid row).
        data: Reads ``data.num_labels`` (number of centroid rows).
        train_dataloader: Iterable of 4-tuples of tensors
            ``(input_ids, input_mask, segment_ids, label_ids)``.
        device: Device every batch tensor and the result are moved to.

    Returns:
        Tensor of shape ``(data.num_labels, args.feat_dim)`` on ``device``. Row ``c``
        is the mean feature of all samples labelled ``c``; a class without any
        sample keeps an all-zero row.

    Assumes ``label_ids`` is a 1-D tensor of integer ids in ``[0, num_labels)`` and
    that the model returns one feature row per sample -- TODO confirm with callers.
    """
    model.eval()

    centroids = torch.zeros(data.num_labels, args.feat_dim).to(device)
    # Per-class sample counts indexed by label id. Counting per batch with
    # `minlength` keeps row c aligned with class c even when a class has no
    # samples, and avoids re-concatenating a growing label tensor every step.
    counts = torch.zeros(data.num_labels, dtype=torch.long)

    with torch.no_grad():
        for batch in tqdm(train_dataloader, desc="Calculate centroids"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            features = model(input_ids, segment_ids, input_mask, feature_ext=True)

            counts += torch.bincount(label_ids.cpu(), minlength=data.num_labels)
            # Accumulate sample by sample so the summation order stays fixed.
            for label, feature in zip(label_ids, features):
                centroids[label] += feature

    # clamp(min=1): classes without samples divide by 1 and stay at zero
    # instead of producing NaN rows.
    centroids /= counts.clamp(min=1).float().unsqueeze(1).to(device)
    return centroids
def euclidean_metric(a, b):
    """Return the negated squared Euclidean distance between every row pair.

    Args:
        a: Tensor whose first dimension has size ``n``.
        b: Tensor whose first dimension has size ``m``.

    Returns:
        Tensor of shape ``(n, m)`` where entry ``[i, j]`` equals
        ``-sum((a[i] - b[j]) ** 2)``.
    """
    pair_shape = (a.shape[0], b.shape[0], -1)
    # Line both operands up as (n, m, feature) views so they subtract elementwise.
    lhs = a.unsqueeze(1).expand(*pair_shape)
    rhs = b.unsqueeze(0).expand(*pair_shape)
    squared_diff = (lhs - rhs).pow(2)
    return squared_diff.sum(dim=2).neg()