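"""Edge-probing experiments over GPT-2 checkpoints.

For each saved checkpoint, phrase spans are embedded at several GPT-2
layers, pooled (mean or max), and a logistic-regression probe is trained
to predict each span's category. Accuracies for every layer and checkpoint
are collected into a single CSV.

Usage (arguments mirror the argparse definition below):
    python <this_script> perturbation_type train_set random_seed \
        paren_model pooling_operation [-np]
"""
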
import sys
sys.path.append("..")

from utils import CHECKPOINT_READ_PATH, PERTURBATIONS, PAREN_MODELS, get_gpt2_tokenizer_with_markers
from gpt2_no_positional_encoding_model import GPT2NoPositionalEncodingLMHeadModel
from transformers import GPT2LMHeadModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from itertools import zip_longest
import torch
import tqdm
import argparse
import pandas as pd
import os

MAX_TRAINING_STEPS = 3000
CHECKPOINTS = list(range(200, MAX_TRAINING_STEPS + 1, 200))
LAYERS = [1, 3, 6, 9, 12, "Avg Last 4"]


def get_layer_embedding(model, token_sequences, indices, layer=None):
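    """Return hidden states for a padded batch, zeroed outside each span.

    Uses hidden states from `layer`, or the average of the last four layers
    when `layer` is None. Relies on the module-level `gpt2_tokenizer` and
    `device` defined in the __main__ block.
    """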
    # Right-pad all sequences in the batch to equal length with the EOS token.
    input_ids = zip(*zip_longest(*token_sequences,
                                 fillvalue=gpt2_tokenizer.eos_token_id))
    input_ids = torch.tensor(list(input_ids)).to(device)

    with torch.no_grad():
        output = model(input_ids)

    # Use one layer's hidden states, or the average of the last four.
    if layer is not None:
        hidden_states = output.hidden_states[layer]
    else:
        hidden_states = sum(output.hidden_states[-4:]) / 4

    # Zero out hidden states outside each example's [start, end) span.
    batch_size, seq_length = input_ids.shape
    mask = torch.zeros((batch_size, seq_length), device=device)
    for i, (start_idx, end_idx) in enumerate(indices):
        mask[i, start_idx:end_idx] = 1

    mask_expanded = mask.unsqueeze(-1).expand(hidden_states.size())
    hidden_states = hidden_states * mask_expanded

    return hidden_states


def max_pooling(tensor, index_tuples):
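    """Max-pool each [start, end) span of `tensor` into one vector per example."""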
    pooled_results = []
    for i, (start, end) in enumerate(index_tuples):
        # Elementwise max over the span's token embeddings.
        embeddings = tensor[i, start:end, :]
        max_pooled = torch.max(embeddings, dim=0)[0]
        pooled_results.append(max_pooled)
    return torch.stack(pooled_results)


def mean_pooling(tensor, index_tuples):
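    """Mean-pool each [start, end) span of `tensor` into one vector per example."""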
    batch_size, seq_len, embedding_size = tensor.shape
    output = torch.empty(batch_size, embedding_size,
                         device=tensor.device, dtype=tensor.dtype)

    for i, (start, end) in enumerate(index_tuples):
        # Average over the span's token embeddings.
        embeddings = tensor[i, start:end, :]
        output[i, :] = torch.mean(embeddings, dim=0)

    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Edge probing',
        description='Edge probing experiments')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('train_set',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=["100M", "10M"],
                        help='BabyLM train set')
    parser.add_argument('random_seed', type=int, help="Random seed")
    parser.add_argument('paren_model',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=list(PAREN_MODELS.keys()) + ["randinit"],
                        help='Parenthesis model')
    parser.add_argument('pooling_operation',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=["mean", "max"],
                        help='Pooling operation to compute on embeddings')
    parser.add_argument('-np', '--no_pos_encodings', action='store_true',
                        help="Probe the GPT-2 variant trained without positional encodings")

    args = parser.parse_args()

    # Select the pooling function.
    if args.pooling_operation == "mean":
        pooling_function = mean_pooling
    elif args.pooling_operation == "max":
        pooling_function = max_pooling
    else:
        raise ValueError("Pooling operation undefined")

    # Tokenizer for GPT-2; pad with the EOS token.
    gpt2_tokenizer = get_gpt2_tokenizer_with_markers([])
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    # Build the checkpoint path prefix for this run.
    no_pos_encodings_underscore = "_no_positional_encodings" if args.no_pos_encodings else ""
    model_name = f"babylm_{args.perturbation_type}_{args.train_set}_{args.paren_model}{no_pos_encodings_underscore}_seed{args.random_seed}"
    model_path = f"{CHECKPOINT_READ_PATH}/babylm_{args.perturbation_type}_{args.train_set}_{args.paren_model}{no_pos_encodings_underscore}/{model_name}/runs/{model_name}/checkpoint-"

    # Load the phrase spans and labels for this perturbation family.
    if "hop" in args.perturbation_type:
        phrase_df = pd.read_csv("phrase_data/hop_phrase_data.csv")
    elif "reverse" in args.perturbation_type:
        phrase_df = pd.read_csv("phrase_data/reverse_phrase_data.csv")
    else:
        raise ValueError("Phrase data not found")

    token_sequences = list(phrase_df["Sentence Tokens"])
    if args.perturbation_type == "reverse_full":
        indices = list(
            zip(phrase_df["Rev Start Index"], phrase_df["Rev End Index"]))
    else:
        indices = list(zip(phrase_df["Start Index"], phrase_df["End Index"]))
    labels = list(phrase_df["Category"])

    BATCH_SIZE = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"

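    # For each checkpoint, probe every layer in LAYERS and record held-out
    # classification accuracy.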
    edge_probing_df = pd.DataFrame(LAYERS, columns=["GPT-2 Layer"])
    for ckpt in CHECKPOINTS:
        # Load this checkpoint (the no-positional-encodings variant when requested).
        if args.no_pos_encodings:
            model = GPT2NoPositionalEncodingLMHeadModel.from_pretrained(
                model_path + str(ckpt), output_hidden_states=True).to(device)
        else:
            model = GPT2LMHeadModel.from_pretrained(
                model_path + str(ckpt), output_hidden_states=True).to(device)

        layer_accuracies = []
        for layer in LAYERS:
            print(f"Checkpoint: {ckpt}, Layer: {layer}")
            print("Computing span embeddings...")

            # Embed every phrase span in batches.
            spans = []
            for i in tqdm.tqdm(list(range(0, len(token_sequences), BATCH_SIZE))):
                # Token sequences are stored as space-separated ID strings.
                tokens_batch = [[int(tok) for tok in seq.split()]
                                for seq in token_sequences[i:i+BATCH_SIZE]]
                if args.perturbation_type == "reverse_full":
                    tokens_batch = [toks[::-1] for toks in tokens_batch]

                index_batch = indices[i:i+BATCH_SIZE]

                if layer == "Avg Last 4":
                    embeddings = get_layer_embedding(
                        model, tokens_batch, index_batch, None)
                else:
                    embeddings = get_layer_embedding(
                        model, tokens_batch, index_batch, layer)
                pooled_results = pooling_function(embeddings, index_batch)
                spans.extend(list(pooled_results))

            # Fit a logistic-regression probe on the pooled span embeddings.
            X = torch.vstack(spans).detach().cpu().numpy()
            y = labels

            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=args.random_seed)

            # max_iter=10 caps optimization early; sklearn may warn about
            # non-convergence.
            clf = LogisticRegression(max_iter=10,
                                     random_state=args.random_seed).fit(X_train, y_train)

            y_test_pred = clf.predict(X_test)
            acc = accuracy_score(y_test, y_test_pred)
            layer_accuracies.append(acc)
            print(f"Accuracy: {acc}")

        edge_probing_df[f"Accuracy (ckpt {ckpt})"] = layer_accuracies

    # Write one CSV of per-layer accuracies for this run configuration.
    nps = '_no_pos_encodings' if args.no_pos_encodings else ''
    directory = f"edge_probing_results/{args.perturbation_type}_{args.train_set}{nps}"
    os.makedirs(directory, exist_ok=True)

    file = directory + \
        f"/{args.paren_model}_{args.pooling_operation}_pooling_seed{args.random_seed}.csv"
    print(f"Writing results to CSV: {file}")
    edge_probing_df.to_csv(file)