subnet32-llm-detector / scripts /detect_fluoroscopy.py
ThaoTran7's picture
incomplete commit
485127c
import torch
import torch.nn as nn
from tqdm import tqdm
import pickle
import numpy as np
import argparse
from metrics import get_roc_metrics, get_precision_recall_metrics
from sklearn.metrics import roc_auc_score
from TextFluoroscopy.utils_fluoroscopy import compute_embedding, compute_kl_feat, load_model, load_model2
import json
def get_embedding(train_file, valid_file, test_file, kl_path, which_embedding, which_layer, device):
train_num=None
valid_num=None
test_num=None
train_name = train_file.split("/")[-1]
valid_name = valid_file.split("/")[-1]
train_embeddings = torch.load(f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/{train_name}.pt')[:train_num]
valid_embeddings = torch.load(f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/{valid_name}.pt')[:valid_num]
train_labels = torch.arange(train_embeddings.shape[0]) % 2
valid_labels = torch.arange(valid_embeddings.shape[0]) % 2
train_embeddings = train_embeddings.to(device)
valid_embeddings = valid_embeddings.to(device)
train_labels = train_labels.to(device)
valid_labels = valid_labels.to(device)
with open(f'scripts/TextFluoroscopy/save/{kl_path}/{train_name}.pkl', 'rb') as f:
train_kl = pickle.load(f)
train_kl = np.array(train_kl)
idx = train_kl.argmax(axis=1)
if which_layer == 'max_kl':
train_embeddings = torch.tensor([row[(i+1)*embedding_dim:(i+2)*embedding_dim].tolist() for row ,i in zip(train_embeddings, idx)]).to(device)
if which_layer == 'max_kl_and_last_layer':
train_embeddings = torch.cat([torch.tensor([row[(i+1)*embedding_dim:(i+2)*embedding_dim].tolist() for row ,i in zip(train_embeddings, idx)]),
train_embeddings[:,-embedding_dim:]], dim=1).to(device)
with open(f'scripts/TextFluoroscopy/save/{kl_path}/{valid_name}.pkl', 'rb') as f:
valid_kl = pickle.load(f)
valid_kl = np.array(valid_kl)
idx = valid_kl.argmax(axis=1)
if which_layer == 'max_kl':
valid_embeddings = torch.tensor([row[(i+1)*embedding_dim:(i+2)*embedding_dim].tolist() for row,i in zip(valid_embeddings, idx)]).to(device)
if which_layer == 'max_kl_and_last_layer':
valid_embeddings = torch.cat([torch.tensor([row[(i+1)*embedding_dim:(i+2)*embedding_dim].tolist() for row, i in zip(valid_embeddings, idx)]),
valid_embeddings[:,-embedding_dim:]], dim=1).to(device)
if which_layer == 'first_layer':
train_embeddings = train_embeddings[:,:embedding_dim].to(device)
valid_embeddings = valid_embeddings[:,:embedding_dim].to(device)
test_embeddings = test_embeddings[:,:embedding_dim].to(device)
elif which_layer == 'last_layer':
train_embeddings = train_embeddings[:,-embedding_dim:].to(device)
valid_embeddings = valid_embeddings[:,-embedding_dim:].to(device)
test_embeddings = test_embeddings[:,-embedding_dim:].to(device)
elif which_layer == 'first_and_last_layers':
train_embeddings = torch.cat([train_embeddings[:,:embedding_dim],train_embeddings[:,-embedding_dim:]], dim=1).to(device)
valid_embeddings = torch.cat([valid_embeddings[:,:embedding_dim],valid_embeddings[:,-embedding_dim:]], dim=1).to(device)
test_embeddings = torch.cat([test_embeddings[:,:embedding_dim],test_embeddings[:,-embedding_dim:]], dim=1).to(device)
elif which_layer.startswith('layer_'):
if 'last_layer' not in which_layer and 'later_layer' not in which_layer and 'to' not in which_layer:
layer_num = int(which_layer.split('_')[-1])
train_embeddings = train_embeddings[:,(layer_num)*embedding_dim:(layer_num+1)*embedding_dim].to(device)
valid_embeddings = valid_embeddings[:,(layer_num)*embedding_dim:(layer_num+1)*embedding_dim].to(device)
elif 'last_layer' in which_layer:
layer_num = int(which_layer.split('_')[1])
train_embeddings = torch.cat([train_embeddings[:,-embedding_dim:],train_embeddings[:,(layer_num)*embedding_dim:(layer_num+1)*embedding_dim]], dim=1).to(device)
valid_embeddings = torch.cat([valid_embeddings[:,-embedding_dim:],valid_embeddings[:,(layer_num)*embedding_dim:(layer_num+1)*embedding_dim]], dim=1).to(device)
elif 'later_layer' in which_layer:
layer_num = int(which_layer.split('_')[1])
train_embeddings = train_embeddings[:,(layer_num)*embedding_dim:].to(device)
valid_embeddings = valid_embeddings[:,(layer_num)*embedding_dim:].to(device)
elif 'to' in which_layer:
layer_num = int(which_layer.split('_')[1])
layer_num2 = int(which_layer.split('_')[3])
train_embeddings = train_embeddings[:,(layer_num)*embedding_dim:(layer_num2+1)*embedding_dim].to(device)
valid_embeddings = valid_embeddings[:,(layer_num)*embedding_dim:(layer_num2+1)*embedding_dim].to(device)
test_name = test_file.split("/")[-1]
testset_embeddings = torch.load(f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/{test_name}.pt')[:test_num]
testset_embeddings = testset_embeddings.to(device)
testset_labels = torch.arange(testset_embeddings.shape[0]) % 2
with open(f'scripts/TextFluoroscopy/save/{kl_path}/{test_name}.pkl', 'rb') as f:
kl = pickle.load(f)
kl = np.array(kl)
idx = kl.argmax(axis=1)
if which_layer == 'max_kl':
testset_embeddings = torch.tensor([row[(i+1)*embedding_dim:(i+2)*embedding_dim].tolist() for row, i in zip(testset_embeddings, idx)]).to(device)
elif which_layer == 'max_kl_and_last_layer':
testset_embeddings = torch.cat([torch.tensor([row[(i+1)*embedding_dim:(i+2)*embedding_dim].tolist() for row, i in zip(testset_embeddings, idx)]),
testset_embeddings[:,-embedding_dim:]],dim=1).to(device)
elif which_layer == 'first_layer':
testset_embeddings = testset_embeddings[:, :embedding_dim].to(device)
elif which_layer == 'last_layer':
testset_embeddings = testset_embeddings[:, -embedding_dim:].to(device)
elif which_layer == 'first_and_last_layers':
testset_embeddings = torch.cat([testset_embeddings[:, :embedding_dim], testset_embeddings[:, -embedding_dim:]], dim=1).to(device)
elif which_layer.startswith('layer_'):
if 'last_layer' not in which_layer and 'later_layer' not in which_layer and 'to' not in which_layer:
layer_num = int(which_layer.split('_'))
testset_embeddings = testset_embeddings[:, (layer_num)*embedding_dim:(layer_num+1)*embedding_dim].to(device)
elif 'last_layer' in which_layer:
layer_num = int(which_layer.split('_')[1])
testset_embeddings = torch.cat([testset_embeddings[:, -embedding_dim:], testset_embeddings[:, (layer_num)*embedding_dim:(layer_num+1)*embedding_dim]], dim=1).to(device)
elif 'later_layer' in which_layer:
layer_num = int(which_layer.split('_')[1])
testset_embeddings = testset_embeddings[:, (layer_num)*embedding_dim:].to(device)
elif 'to' in which_layer:
layer_num = int(which_layer.split('_')[1])
layer_num2 = int(which_layer.split('_')[3])
testset_embeddings = testset_embeddings[:, (layer_num)*embedding_dim:(layer_num2+1)*embedding_dim].to(device)
return train_embeddings, train_labels, valid_embeddings, valid_labels, testset_embeddings, testset_labels
def test(model, test_set):
with torch.no_grad():
outputs = model(test_set)
probabilities = torch.softmax(outputs, dim=1)[:, 1]
prediction = probabilities.cpu().numpy()
real_pred = prediction[0::2].tolist()
sampled_pred = prediction[1::2].tolist()
fpr, tpr, roc_auc = get_roc_metrics(real_pred, sampled_pred)
p, r, pr_auc = get_precision_recall_metrics(real_pred, sampled_pred)
results = {
'name': f'fluoroscopy_threshold',
'predictions': {'real': real_pred, 'samples': sampled_pred},
'metrics': {'roc_auc': roc_auc, 'fpr': fpr, 'tpr': tpr},
'pr_metrics': {'pr_auc': pr_auc, 'precision': p, 'recall': r},
'loss': 1 - pr_auc
}
return results
class BinaryClassifier(nn.Module):
def __init__(self, input_size, hidden_sizes=[1024, 512], num_labels=2, dropout_prob=0.2):
super(BinaryClassifier, self).__init__()
self.num_labels = num_labels
layers = []
prev_size = input_size
for hidden_size in hidden_sizes:
layers.extend([
nn.Dropout(dropout_prob),
nn.Linear(prev_size, hidden_size),
# nn.Tanh(),
nn.ReLU(),
])
prev_size = hidden_size
self.dense = nn.Sequential(*layers)
self.classifier = nn.Linear(prev_size, num_labels)
def forward(self, x):
x = self.dense(x)
x = self.classifier(x)
return x
def train(train_embeddings, train_labels, hidden_sizes, learning_rate, droprate, device,
valid_embeddings=None, valid_labels=None, testset_embeddings=None, testset_labels=None):
input_size = train_embeddings.shape[1]
model = BinaryClassifier(input_size, hidden_sizes=hidden_sizes, dropout_prob=droprate).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
num_epochs = 100
batch_size = 16
best_valid_acc = 0.0
best_test_res = {}
for epoch in tqdm(range(num_epochs), desc="Training classifer"):
for i in range(0, len(train_embeddings), batch_size):
model.train()
batch_embeddings = train_embeddings[i:i+batch_size]
batch_labels = train_labels[i:i+batch_size]
outputs = model(batch_embeddings)
loss = criterion(outputs, batch_labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if None not in [valid_embeddings, valid_labels, testset_embeddings, testset_labels]:
model.eval()
with torch.no_grad():
outputs = model(valid_embeddings)
# _, predicted = torch.max(outputs.data, 1)
# accuracy = (predicted == valid_labels).sum().item() / len(valid_labels)
predicted = torch.softmax(outputs.data, 1)[:, 0]
accuracy = roc_auc_score(valid_labels.cpu().numpy(), predicted.cpu().numpy())
results = test(model, testset_embeddings)
if epoch % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Valid AUC: {accuracy:.4f}, Test AUC: {results['metrics']['roc_auc']:.4f}")
if accuracy > best_valid_acc:
best_valid_acc = accuracy
best_test_res = results
return best_test_res
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--train_dataset', type=str, default='./exp_main/data/squad_qwen-7b')
parser.add_argument('--valid_dataset', type=str, default="./exp_main/data/writing_qwen-7b")
parser.add_argument('--test_dataset', type=str, default="./exp_main/data/xsum_qwen-7b")
parser.add_argument('--output_file', type=str, default="./exp_main/results/xsum_qwen-7b")
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--droprate', type=float, default=0.4)
parser.add_argument('--model_name', type=str, default="Alibaba-NLP/gte-Qwen1.5-7B-instruct")
parser.add_argument('--embedding_dim', type=int, default=4096)
parser.add_argument('--max_length', type=int, default=512)
parser.add_argument('--which_layer', type=str, default="max_kl")
parser.add_argument('--which_embedding', type=str, default='gte-qwen_all')
parser.add_argument('--kl_path', type=str, default='gte-qwen_KL_with_first_and_last_layer')
parser.add_argument('--cache_dir', type=str, default='../cache')
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()
torch.manual_seed(args.seed)
name = 'fluoroscopy'
learning_rate=args.lr
droprate=args.droprate
embedding_dim=args.embedding_dim
which_layer=args.which_layer
which_embedding=args.which_embedding
kl_path=args.kl_path
max_length=args.max_length
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = args.model_name
cache_dir = args.cache_dir
# cache_dir = '/root/autodl-tmp/cache'
save_dir = f'scripts/TextFluoroscopy/save/{kl_path}/'
tokenizer, model = load_model(model_name, cache_dir)
compute_kl_feat(model, tokenizer, args.train_dataset, save_dir, max_length, device)
compute_kl_feat(model, tokenizer, args.valid_dataset, save_dir, max_length, device)
compute_kl_feat(model, tokenizer, args.test_dataset, save_dir, max_length, device)
save_dir = f'scripts/TextFluoroscopy/save/{which_embedding}_embedding/save_embedding/'
tokenizer, model = load_model2(model_name, cache_dir)
compute_embedding(model, tokenizer, args.train_dataset, save_dir, max_length, device)
compute_embedding(model, tokenizer, args.valid_dataset, save_dir, max_length, device)
compute_embedding(model, tokenizer, args.test_dataset, save_dir, max_length, device)
train_X, train_Y, valid_X, valid_Y, test_X, test_Y = get_embedding(args.train_dataset, args.valid_dataset, args.test_dataset,
kl_path, which_embedding, which_layer, device)
clf_hidden_dim = [1024, 512]
results = train(train_X, train_Y, clf_hidden_dim, learning_rate, droprate, device, valid_X, valid_Y, test_X, test_Y)
print(f"Real mean/std: {np.mean(results['predictions']['real']):.2f}/{np.std(results['predictions']['real']):.2f}, Samples mean/std: {np.mean(results['predictions']['real']):.2f}/{np.std(results['predictions']['samples']):.2f}")
print(f"Criterion {name}_threshold ROC AUC: {results['metrics']['roc_auc']:.4f}, PR AUC: {results['pr_metrics']['pr_auc']:.4f}")
results_file = f'{args.output_file}.{name}.json'
with open(results_file, 'w') as fout:
json.dump(results, fout)
print(f'Results written into {results_file}')