import argparse
import json
import pickle

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from metrics import get_roc_metrics, get_precision_recall_metrics
from TextFluoroscopy.utils_fluoroscopy import compute_embedding, compute_kl_feat, load_model, load_model2


def get_embedding(train_file, valid_file, test_file, kl_path, which_embedding, which_layer, embedding_dim, device):
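    """Load cached embeddings and KL features, then slice out the requested layers.

    Assumed cache layout (inferred from the slicing below): each row of the saved
    ``.pt`` tensor concatenates one pooled embedding per model layer,
    ``[layer_0 | layer_1 | ... | layer_L]``, each block ``embedding_dim`` wide,
    while the pickled KL features hold one score per intermediate layer, so
    ``argmax`` picks the layer whose output distribution diverges most from the
    final one. Rows alternate human-written (even index) and model-generated
    (odd index) texts, which fixes the 0/1 labels.
    """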
    # None keeps every row; set these to ints to subsample for quick experiments.
    train_num = None
    valid_num = None
    test_num = None
    train_name = train_file.split("/")[-1]
    valid_name = valid_file.split("/")[-1]
    train_embeddings = torch.load(f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/{train_name}.pt')[:train_num]
    valid_embeddings = torch.load(f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/{valid_name}.pt')[:valid_num]
    # Rows alternate human (even index -> label 0) and generated (odd -> label 1).
    train_labels = torch.arange(train_embeddings.shape[0]) % 2
    valid_labels = torch.arange(valid_embeddings.shape[0]) % 2

    train_embeddings = train_embeddings.to(device)
    valid_embeddings = valid_embeddings.to(device)
    train_labels = train_labels.to(device)
    valid_labels = valid_labels.to(device)

    with open(f'scripts/TextFluoroscopy/save/{kl_path}/{train_name}.pkl', 'rb') as f:
        train_kl = pickle.load(f)
    train_kl = np.array(train_kl)
    # For each text, pick the intermediate layer with the largest KL divergence.
    idx = train_kl.argmax(axis=1)
    if which_layer == 'max_kl':
        # torch.stack keeps the per-row slices on the same device as the source
        # tensor, so the concatenation with CUDA tensors below stays valid.
        train_embeddings = torch.stack([row[(i + 1) * embedding_dim:(i + 2) * embedding_dim]
                                        for row, i in zip(train_embeddings, idx)]).to(device)
    elif which_layer == 'max_kl_and_last_layer':
        train_embeddings = torch.cat([torch.stack([row[(i + 1) * embedding_dim:(i + 2) * embedding_dim]
                                                   for row, i in zip(train_embeddings, idx)]),
                                      train_embeddings[:, -embedding_dim:]], dim=1).to(device)

    with open(f'scripts/TextFluoroscopy/save/{kl_path}/{valid_name}.pkl', 'rb') as f:
        valid_kl = pickle.load(f)
    valid_kl = np.array(valid_kl)
    idx = valid_kl.argmax(axis=1)
    if which_layer == 'max_kl':
        valid_embeddings = torch.stack([row[(i + 1) * embedding_dim:(i + 2) * embedding_dim]
                                        for row, i in zip(valid_embeddings, idx)]).to(device)
    elif which_layer == 'max_kl_and_last_layer':
        valid_embeddings = torch.cat([torch.stack([row[(i + 1) * embedding_dim:(i + 2) * embedding_dim]
                                                   for row, i in zip(valid_embeddings, idx)]),
                                      valid_embeddings[:, -embedding_dim:]], dim=1).to(device)

    # Slice out the requested layer block(s); the test split is handled below,
    # after its embeddings have been loaded.
    if which_layer == 'first_layer':
        train_embeddings = train_embeddings[:, :embedding_dim].to(device)
        valid_embeddings = valid_embeddings[:, :embedding_dim].to(device)
    elif which_layer == 'last_layer':
        train_embeddings = train_embeddings[:, -embedding_dim:].to(device)
        valid_embeddings = valid_embeddings[:, -embedding_dim:].to(device)
    elif which_layer == 'first_and_last_layers':
        train_embeddings = torch.cat([train_embeddings[:, :embedding_dim], train_embeddings[:, -embedding_dim:]], dim=1).to(device)
        valid_embeddings = torch.cat([valid_embeddings[:, :embedding_dim], valid_embeddings[:, -embedding_dim:]], dim=1).to(device)
    elif which_layer.startswith('layer_'):
        if 'last_layer' not in which_layer and 'later_layer' not in which_layer and 'to' not in which_layer:
            # e.g. 'layer_5': a single layer.
            layer_num = int(which_layer.split('_')[-1])
            train_embeddings = train_embeddings[:, layer_num * embedding_dim:(layer_num + 1) * embedding_dim].to(device)
            valid_embeddings = valid_embeddings[:, layer_num * embedding_dim:(layer_num + 1) * embedding_dim].to(device)
        elif 'last_layer' in which_layer:
            # e.g. 'layer_5_and_last_layer': one layer concatenated with the last.
            layer_num = int(which_layer.split('_')[1])
            train_embeddings = torch.cat([train_embeddings[:, -embedding_dim:], train_embeddings[:, layer_num * embedding_dim:(layer_num + 1) * embedding_dim]], dim=1).to(device)
            valid_embeddings = torch.cat([valid_embeddings[:, -embedding_dim:], valid_embeddings[:, layer_num * embedding_dim:(layer_num + 1) * embedding_dim]], dim=1).to(device)
        elif 'later_layer' in which_layer:
            # e.g. 'layer_5_later_layer': that layer and everything after it.
            layer_num = int(which_layer.split('_')[1])
            train_embeddings = train_embeddings[:, layer_num * embedding_dim:].to(device)
            valid_embeddings = valid_embeddings[:, layer_num * embedding_dim:].to(device)
        elif 'to' in which_layer:
            # e.g. 'layer_3_to_7': an inclusive range of layers.
            layer_num = int(which_layer.split('_')[1])
            layer_num2 = int(which_layer.split('_')[3])
            train_embeddings = train_embeddings[:, layer_num * embedding_dim:(layer_num2 + 1) * embedding_dim].to(device)
            valid_embeddings = valid_embeddings[:, layer_num * embedding_dim:(layer_num2 + 1) * embedding_dim].to(device)

    # Test split: load embeddings, then apply the same layer selection.
    test_name = test_file.split("/")[-1]
    testset_embeddings = torch.load(f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/{test_name}.pt')[:test_num]
    testset_embeddings = testset_embeddings.to(device)
    testset_labels = torch.arange(testset_embeddings.shape[0]) % 2
    with open(f'scripts/TextFluoroscopy/save/{kl_path}/{test_name}.pkl', 'rb') as f:
        kl = pickle.load(f)
    kl = np.array(kl)
    idx = kl.argmax(axis=1)
    if which_layer == 'max_kl':
        testset_embeddings = torch.stack([row[(i + 1) * embedding_dim:(i + 2) * embedding_dim]
                                          for row, i in zip(testset_embeddings, idx)]).to(device)
    elif which_layer == 'max_kl_and_last_layer':
        testset_embeddings = torch.cat([torch.stack([row[(i + 1) * embedding_dim:(i + 2) * embedding_dim]
                                                     for row, i in zip(testset_embeddings, idx)]),
                                        testset_embeddings[:, -embedding_dim:]], dim=1).to(device)
    elif which_layer == 'first_layer':
        testset_embeddings = testset_embeddings[:, :embedding_dim].to(device)
    elif which_layer == 'last_layer':
        testset_embeddings = testset_embeddings[:, -embedding_dim:].to(device)
    elif which_layer == 'first_and_last_layers':
        testset_embeddings = torch.cat([testset_embeddings[:, :embedding_dim], testset_embeddings[:, -embedding_dim:]], dim=1).to(device)
    elif which_layer.startswith('layer_'):
        if 'last_layer' not in which_layer and 'later_layer' not in which_layer and 'to' not in which_layer:
            layer_num = int(which_layer.split('_')[-1])
            testset_embeddings = testset_embeddings[:, layer_num * embedding_dim:(layer_num + 1) * embedding_dim].to(device)
        elif 'last_layer' in which_layer:
            layer_num = int(which_layer.split('_')[1])
            testset_embeddings = torch.cat([testset_embeddings[:, -embedding_dim:], testset_embeddings[:, layer_num * embedding_dim:(layer_num + 1) * embedding_dim]], dim=1).to(device)
        elif 'later_layer' in which_layer:
            layer_num = int(which_layer.split('_')[1])
            testset_embeddings = testset_embeddings[:, layer_num * embedding_dim:].to(device)
        elif 'to' in which_layer:
            layer_num = int(which_layer.split('_')[1])
            layer_num2 = int(which_layer.split('_')[3])
            testset_embeddings = testset_embeddings[:, layer_num * embedding_dim:(layer_num2 + 1) * embedding_dim].to(device)
    return train_embeddings, train_labels, valid_embeddings, valid_labels, testset_embeddings, testset_labels
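
# A tiny illustration of the slicing convention above (dummy sizes): with
# 3 cached layers of width 4 per row, 'last_layer' keeps the final 4 columns.
#   emb = torch.arange(24.).reshape(2, 12)
#   emb[:, -4:]  # columns 8..11 of each row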


def test(model, test_set):
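    """Score the interleaved test set and compute ROC/PR metrics.

    Even rows are treated as human-written and odd rows as model-generated,
    matching the label convention used when the caches were built.
    """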
    model.eval()
    with torch.no_grad():
        outputs = model(test_set)
    # Probability of the "generated" class (index 1) is the detection score.
    probabilities = torch.softmax(outputs, dim=1)[:, 1]
    prediction = probabilities.cpu().numpy()
    real_pred = prediction[0::2].tolist()
    sampled_pred = prediction[1::2].tolist()
    fpr, tpr, roc_auc = get_roc_metrics(real_pred, sampled_pred)
    p, r, pr_auc = get_precision_recall_metrics(real_pred, sampled_pred)
    results = {
        'name': 'fluoroscopy_threshold',
        'predictions': {'real': real_pred, 'samples': sampled_pred},
        'metrics': {'roc_auc': roc_auc, 'fpr': fpr, 'tpr': tpr},
        'pr_metrics': {'pr_auc': pr_auc, 'precision': p, 'recall': r},
        'loss': 1 - pr_auc,
    }
    return results


class BinaryClassifier(nn.Module):
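    """A small MLP head: (Dropout -> Linear -> ReLU) per hidden layer, followed
    by a linear layer that maps to ``num_labels`` logits."""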
    def __init__(self, input_size, hidden_sizes=(1024, 512), num_labels=2, dropout_prob=0.2):
        super(BinaryClassifier, self).__init__()
        self.num_labels = num_labels
        layers = []
        prev_size = input_size
        for hidden_size in hidden_sizes:
            layers.extend([
                nn.Dropout(dropout_prob),
                nn.Linear(prev_size, hidden_size),
                nn.ReLU(),
            ])
            prev_size = hidden_size
        self.dense = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_size, num_labels)

    def forward(self, x):
        x = self.dense(x)
        x = self.classifier(x)
        return x
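

# A minimal shape sketch (sizes are illustrative; 4096 matches the default
# --embedding_dim below):
#   clf = BinaryClassifier(input_size=2 * 4096)
#   logits = clf(torch.randn(8, 2 * 4096))  # -> (8, 2) logits over {human, generated}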


def train(train_embeddings, train_labels, hidden_sizes, learning_rate, droprate, device,
          valid_embeddings=None, valid_labels=None, testset_embeddings=None, testset_labels=None):
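    """Train the classifier, tracking the checkpoint with the best validation AUC.

    Validation and test tensors are optional; when all four are given, the test
    metrics from the best-validation epoch are returned.
    """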
    input_size = train_embeddings.shape[1]
    model = BinaryClassifier(input_size, hidden_sizes=hidden_sizes, dropout_prob=droprate).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    num_epochs = 100
    batch_size = 16
    best_valid_auc = 0.0
    best_test_res = {}
    for epoch in tqdm(range(num_epochs), desc="Training classifier"):
        model.train()
        for i in range(0, len(train_embeddings), batch_size):
            batch_embeddings = train_embeddings[i:i + batch_size]
            batch_labels = train_labels[i:i + batch_size]
            outputs = model(batch_embeddings)
            loss = criterion(outputs, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # `x is not None` avoids the ambiguous element-wise `==` that `in`
        # triggers on tensors.
        if all(x is not None for x in (valid_embeddings, valid_labels, testset_embeddings, testset_labels)):
            model.eval()
            with torch.no_grad():
                outputs = model(valid_embeddings)
            # Score with P(generated) (class 1) so the AUC is oriented the same
            # way as in test().
            predicted = torch.softmax(outputs, dim=1)[:, 1]
            valid_auc = roc_auc_score(valid_labels.cpu().numpy(), predicted.cpu().numpy())
            results = test(model, testset_embeddings)
            if epoch % 10 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Valid AUC: {valid_auc:.4f}, Test AUC: {results['metrics']['roc_auc']:.4f}")
            if valid_auc > best_valid_auc:
                best_valid_auc = valid_auc
                best_test_res = results

    return best_test_res


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_dataset', type=str, default='./exp_main/data/squad_qwen-7b')
    parser.add_argument('--valid_dataset', type=str, default="./exp_main/data/writing_qwen-7b")
    parser.add_argument('--test_dataset', type=str, default="./exp_main/data/xsum_qwen-7b")
    parser.add_argument('--output_file', type=str, default="./exp_main/results/xsum_qwen-7b")
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--droprate', type=float, default=0.4)
    parser.add_argument('--model_name', type=str, default="Alibaba-NLP/gte-Qwen1.5-7B-instruct")
    parser.add_argument('--embedding_dim', type=int, default=4096)
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--which_layer', type=str, default="max_kl")
    parser.add_argument('--which_embedding', type=str, default='gte-qwen_all')
    parser.add_argument('--kl_path', type=str, default='gte-qwen_KL_with_first_and_last_layer')
    parser.add_argument('--cache_dir', type=str, default='../cache')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
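    # Example invocation (script name and dataset paths are illustrative; any
    # of the defaults above can be overridden the same way):
    #   python fluoroscopy_clf.py --train_dataset ./exp_main/data/squad_qwen-7b \
    #       --test_dataset ./exp_main/data/xsum_qwen-7b --which_layer max_kl --lr 0.003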

    torch.manual_seed(args.seed)

    name = 'fluoroscopy'

    learning_rate = args.lr
    droprate = args.droprate
    embedding_dim = args.embedding_dim
    which_layer = args.which_layer
    which_embedding = args.which_embedding
    kl_path = args.kl_path
    max_length = args.max_length
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_name = args.model_name
    cache_dir = args.cache_dir

    # Pass 1: cache the per-layer KL features for every split.
    save_dir = f'scripts/TextFluoroscopy/save/{kl_path}/'
    tokenizer, model = load_model(model_name, cache_dir)
    compute_kl_feat(model, tokenizer, args.train_dataset, save_dir, max_length, device)
    compute_kl_feat(model, tokenizer, args.valid_dataset, save_dir, max_length, device)
    compute_kl_feat(model, tokenizer, args.test_dataset, save_dir, max_length, device)

    # Pass 2: cache the per-layer embeddings for every split.
    save_dir = f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/'
    tokenizer, model = load_model2(model_name, cache_dir)
    compute_embedding(model, tokenizer, args.train_dataset, save_dir, max_length, device)
    compute_embedding(model, tokenizer, args.valid_dataset, save_dir, max_length, device)
    compute_embedding(model, tokenizer, args.test_dataset, save_dir, max_length, device)

    train_X, train_Y, valid_X, valid_Y, test_X, test_Y = get_embedding(args.train_dataset, args.valid_dataset, args.test_dataset,
                                                                       kl_path, which_embedding, which_layer, embedding_dim, device)

    clf_hidden_dim = [1024, 512]
    results = train(train_X, train_Y, clf_hidden_dim, learning_rate, droprate, device, valid_X, valid_Y, test_X, test_Y)

    print(f"Real mean/std: {np.mean(results['predictions']['real']):.2f}/{np.std(results['predictions']['real']):.2f}, "
          f"Samples mean/std: {np.mean(results['predictions']['samples']):.2f}/{np.std(results['predictions']['samples']):.2f}")
    print(f"Criterion {name}_threshold ROC AUC: {results['metrics']['roc_auc']:.4f}, PR AUC: {results['pr_metrics']['pr_auc']:.4f}")

    results_file = f'{args.output_file}.{name}.json'
    with open(results_file, 'w') as fout:
        json.dump(results, fout)
    print(f'Results written into {results_file}')