import os
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from datasets import Audio, load_dataset
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoFeatureExtractor, WhisperModel

from .config import *
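
# `.config` is assumed to provide the settings referenced throughout this
# script: ENABLED_MODELS, MAX_SAMPLES, USE_HALF_PRECISION,
# HALF_PRECISION_DTYPE, AGGRESSIVE_CLEANUP, BATCH_SIZE, TRAINING_STEPS,
# LEARNING_RATE, OUTPUT_DIR, SAVE_PLOTS, and PLOT_DPI. Because of the
# relative import, run this module with `python -m <package>.<module>`.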
model_ids = ENABLED_MODELS

dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
if MAX_SAMPLES is not None:
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
    print(f"Limited dataset to {len(dataset)} samples for testing")

# Whisper feature extractors expect 16 kHz input.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")


def extract_layer_reps_generator(model_id, batch_size=4):
    """
    Extract encoder hidden states with a generator so that only one sample's
    hidden states are in flight at a time. Samples are selected from the
    dataset in chunks of `batch_size` but run through the model one by one.
    Yields (sample_idx, layer_reps) pairs, where layer_reps is the list of all
    layer representations for that sample, including the embedding output.
    """
    model = WhisperModel.from_pretrained(model_id).to(device)
    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
    model.eval()

    for i in tqdm(
        range(0, len(dataset), batch_size), desc=f"Processing {model_id} in batches"
    ):
        batch_end = min(i + batch_size, len(dataset))
        batch_samples = dataset.select(range(i, batch_end))

        for j, sample in enumerate(batch_samples):
            audio = sample["audio"]
            samples = audio["array"]
            sr = audio["sampling_rate"]

            inputs = feat_ext(
                samples, sampling_rate=sr, return_tensors="pt"
            ).input_features.to(device)
            with torch.no_grad():
                outputs = model.encoder(
                    inputs, return_dict=True, output_hidden_states=True
                )

            layer_reps_for_sample = []
            for hs in outputs.hidden_states:
                # Drop the batch dimension: (1, time, dim) -> (time, dim).
                layer_rep = hs.squeeze(0)
                if USE_HALF_PRECISION:
                    layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
                layer_reps_for_sample.append(layer_rep)

            yield i + j, layer_reps_for_sample

            # Free the full encoder outputs before moving to the next sample.
            del outputs, inputs
            if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
                torch.cuda.empty_cache()

    del model, feat_ext
    if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
        torch.cuda.empty_cache()
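
# Usage sketch (illustrative; the checkpoint name is an example):
#
#     gen = extract_layer_reps_generator("openai/whisper-tiny", batch_size=4)
#     idx, reps = next(gen)
#     # reps[k] has shape (time, dim); len(reps) == num_encoder_layers + 1,
#     # since hidden_states includes the embedding output.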


def compute_linear_mse_matrix_temporal_memory_efficient(
    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
):
    """
    For each layer pair (i, j), trains a 1x1 convolution as a linear probe
    from model A's layer i to model B's layer j and records the final MSE.
    Each model's representations are extracted in a single pass and cached;
    memory is kept in check by that single pass, optional half-precision
    storage, and aggressive per-pair cleanup rather than by streaming.
    Returns an MSE matrix of shape (layers_a, layers_b) and all trained probes.
    """
    print(f"Computing alignment between {model_a_id} and {model_b_id}...")

    # Probe one sample per model just to learn the layer counts, then close
    # each generator so the Whisper model it loaded can be released.
    sample_gen_a = extract_layer_reps_generator(model_a_id, batch_size=1)
    _, sample_reps_a = next(sample_gen_a)
    layers_a = len(sample_reps_a)
    sample_gen_a.close()

    sample_gen_b = extract_layer_reps_generator(model_b_id, batch_size=1)
    _, sample_reps_b = next(sample_gen_b)
    layers_b = len(sample_reps_b)
    sample_gen_b.close()

    mse_mat = np.zeros((layers_a, layers_b))
    trained_probes = {}

    pbar = tqdm(total=layers_a * layers_b, desc="Comparing layer pairs")

    gen_a = extract_layer_reps_generator(model_a_id, batch_size=batch_size)
    gen_b = extract_layer_reps_generator(model_b_id, batch_size=batch_size)

    # Cache every layer of every sample up front so each model only runs over
    # the dataset once. The cached tensors live on `device`, so
    # USE_HALF_PRECISION roughly halves this footprint.
    reps_a_dict_all = {}
    for sample_idx, layer_reps in gen_a:
        reps_a_dict_all[sample_idx] = layer_reps

    reps_b_dict_all = {}
    for sample_idx, layer_reps in gen_b:
        reps_b_dict_all[sample_idx] = layer_reps

    for i in range(layers_a):
        for j in range(layers_b):
            # Pull out layer i of model A and layer j of model B per sample.
            reps_a_dict = {}
            for sample_idx, layer_reps in reps_a_dict_all.items():
                if i < len(layer_reps):
                    reps_a_dict[sample_idx] = layer_reps[i]

            reps_b_dict = {}
            for sample_idx, layer_reps in reps_b_dict_all.items():
                if j < len(layer_reps):
                    reps_b_dict[sample_idx] = layer_reps[j]

            # Concatenate all samples along the time axis, sorted by sample
            # index so the two models' frames stay aligned.
            X_list = [reps_a_dict[idx] for idx in sorted(reps_a_dict.keys())]
            Y_list = [reps_b_dict[idx] for idx in sorted(reps_b_dict.keys())]

            X_cat = torch.cat(X_list, dim=0).to(device)
            Y_cat = torch.cat(Y_list, dim=0).to(device)

            dim_a = X_cat.shape[1]
            dim_b = Y_cat.shape[1]

            # Conv1d expects (batch, channels, time): (1, dim, total_time).
            X = X_cat.T.unsqueeze(0)
            Y = Y_cat.T.unsqueeze(0)
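
            # A kernel-size-1 Conv1d with no bias applies the same
            # (dim_b, dim_a) matrix at every time step: probe(X) below equals
            # X_cat @ probe.weight.squeeze(-1).T up to the (1, dim, time)
            # layout, so the probe is a purely linear per-frame map.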
            probe = nn.Conv1d(
                in_channels=dim_a, out_channels=dim_b, kernel_size=1, bias=False
            ).to(device=device, dtype=X.dtype)  # match the cached reps' dtype
            probe.train()

            optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
            loss_fn = nn.MSELoss()

            for step in tqdm(
                range(n_steps), desc=f"Training probe {i}->{j}", leave=False
            ):
                optimizer.zero_grad()
                Y_pred = probe(X)
                loss = loss_fn(Y_pred, Y)
                loss.backward()
                optimizer.step()

            # MSE of the last training step serves as the alignment score.
            final_mse = loss.item()
            mse_mat[i, j] = final_mse
            # Store the weight on CPU so the GPU copy is freed with the probe.
            trained_probes[f"layer_{i}_to_{j}"] = (
                probe.weight.detach().cpu().contiguous()
            )

            # Drop everything pair-specific before the next iteration.
            del (
                X_cat,
                Y_cat,
                X,
                Y,
                probe,
                optimizer,
                reps_a_dict,
                reps_b_dict,
                X_list,
                Y_list,
            )
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            pbar.update(1)
            pbar.set_postfix({"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"})

    pbar.close()
    return mse_mat, trained_probes
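
# The saved probes can be reloaded later, e.g. (a sketch; the file name
# follows the pattern used under __main__ below):
#
#     from safetensors.torch import load_file
#     probes = load_file(f"{OUTPUT_DIR}/tiny-to-base-probes.safetensors")
#     W = probes["layer_0_to_0"]  # shape: (dim_b, dim_a, 1)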


if __name__ == "__main__":
    print("Memory optimization settings:")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Training steps: {TRAINING_STEPS}")
    if USE_HALF_PRECISION:
        dtype_name = (
            "bfloat16" if HALF_PRECISION_DTYPE == torch.bfloat16 else "float16"
        )
        print(f"  Half precision: {USE_HALF_PRECISION} ({dtype_name})")
    else:
        print(f"  Half precision: {USE_HALF_PRECISION}")
    print(f"  Aggressive cleanup: {AGGRESSIVE_CLEANUP}")
    print(f"  Models: {list(model_ids.keys())}")
    print(f"  Dataset size: {len(dataset)} samples")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    model_names = list(model_ids.keys())
    all_pairs = list(combinations(model_names, 2))

    print(
        f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..."
    )

    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
        print(
            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE for whisper-{model_a} vs whisper-{model_b}..."
        )

        mse_mat_temporal, trained_probes = (
            compute_linear_mse_matrix_temporal_memory_efficient(
                model_ids[model_a],
                model_ids[model_b],
                n_steps=TRAINING_STEPS,
                lr=LEARNING_RATE,
                batch_size=BATCH_SIZE,
            )
        )

        # safetensors metadata values must be strings.
        model_save_path = f"{OUTPUT_DIR}/{model_a}-to-{model_b}-probes.safetensors"
        save_file(
            trained_probes,
            model_save_path,
            {
                "from_model": model_a,
                "to_model": model_b,
                "from_layers": str(len(mse_mat_temporal)),
                "to_layers": str(len(mse_mat_temporal[0])),
            },
        )
        print(f"Saved trained probes to: {model_save_path}")

        if SAVE_PLOTS:
            # Plot on a log scale; eps avoids log10(0) for near-zero MSE.
            eps = 1e-10
            log_mse_mat = -np.log10(mse_mat_temporal + eps)
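
            # Example: an MSE of 1e-3 maps to -log10(1e-3 + 1e-10) ≈ 3.0;
            # larger plotted values therefore mean tighter linear alignment.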

            plt.figure(figsize=(8, 6))
            plt.imshow(log_mse_mat, aspect="auto", origin="lower")
            plt.colorbar(label="-log10(MSE)")
            plt.title(
                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
            )
            plt.xlabel(f"whisper-{model_b} layers")
            plt.ylabel(f"whisper-{model_a} layers")
            plt.tight_layout()

            plot_save_path = (
                f"{OUTPUT_DIR}/{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
            )
            plt.savefig(plot_save_path, dpi=PLOT_DPI)
            plt.close()
            print(f"Saved plot to: {plot_save_path}")
| | print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory") |
| | print( |
| | f"Generated {len(all_pairs)} visualization plots and {len(all_pairs)} trained probe models" |
| | ) |
| |
|