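"""Sanity checks for a PyTorch -> Flax port of a transformers model.

Compares the logits, loss, and per-layer gradient norms produced by a PyTorch
model against those of its Flax counterpart loaded from the same checkpoint.
"""
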
import random
import sys
import tempfile

import jax
import numpy as np
import optax
import torch
from flax.training.common_utils import onehot
from flax.traverse_util import flatten_dict


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the given shape with values in [0, vocab_size)."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return np.array(values, dtype=np.int32).reshape(shape)


def random_attention_mask(shape, rng=None):
    """Creates a random 0/1 attention mask of the given shape."""
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)

    # Ensure that at least the last token is attended to, so no row is all zeros.
    attn_mask[:, -1] = 1
    return attn_mask


def floats_tensor(shape, scale=1.0, rng=None):
    """Creates a random float32 tensor with values in [0, scale)."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return np.array(values, dtype=np.float32).reshape(shape)


def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """Shifts input ids one token to the right, prepending the decoder start token."""
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    # Replace the -100 loss-masking sentinel with the pad token id.
    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids


def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2):
    diff = np.abs(a - b).max()
    if diff < tol:
        print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")


def assert_dict_equal(a: dict, b: dict, tol: float = 1e-2):
    """Compares per-layer gradient norms, reporting absolute and relative differences."""
    if a.keys() != b.keys():
        print("❌ Dictionary keys for PyTorch and Flax do not match")
    results_fail = []
    results_correct = []

    results_fail_rel = []
    results_correct_rel = []
    for k in a:
        ak_norm = np.linalg.norm(a[k])
        bk_norm = np.linalg.norm(b[k])
        diff = np.abs(ak_norm - bk_norm)
        diff_rel = diff / np.abs(ak_norm)
        if diff < tol:
            results_correct.append(f"✅ Layer {k} diff is {diff} (< {tol}).")
        else:
            results_fail.append(f"❌ Layer {k} has PT grad norm {bk_norm} and Flax grad norm {ak_norm}.")
        if diff_rel < tol:
            results_correct_rel.append(f"✅ Layer {k} rel diff is {diff_rel} (< {tol}).")
        else:
            results_fail_rel.append(f"❌ Layer {k} has PT grad norm {bk_norm} and Flax grad norm {ak_norm}.")
    return results_fail_rel, results_correct_rel, results_fail, results_correct


def compare_grads(model_id, pt_architecture):
    """Compares logits, loss, and gradients between the PyTorch and Flax versions of a model."""
    transformers_module = __import__("transformers", fromlist=[pt_architecture])

    model_cls = getattr(transformers_module, pt_architecture)
    flax_model_cls = getattr(transformers_module, "Flax" + pt_architecture)

    pt_model, model_info = model_cls.from_pretrained(model_id, output_loading_info=True)

    if len(model_info["missing_keys"]) > 0:
        raise ValueError(f"{model_id} with {pt_architecture} has missing keys: {model_info['missing_keys']}")

    # Convert the PyTorch checkpoint on the fly so both models start from identical weights.
    fx_model = flax_model_cls.from_pretrained(model_id, from_pt=True)

    batch_size = 2
    seq_len = 64

    input_ids = ids_tensor([batch_size, seq_len], fx_model.config.vocab_size)
    label_ids = ids_tensor([batch_size, seq_len], fx_model.config.vocab_size)
    attention_mask = random_attention_mask([batch_size, seq_len])

    fx_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
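
    # For encoder-decoder models, decoder inputs are the labels shifted one position
    # to the right, mirroring what the PyTorch model builds internally from `labels`.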
|
|
    if pt_model.config.is_encoder_decoder:
        decoder_input_ids = shift_tokens_right(
            input_ids=label_ids,
            pad_token_id=fx_model.config.pad_token_id,
            decoder_start_token_id=fx_model.config.decoder_start_token_id,
        )
        fx_inputs["decoder_input_ids"] = decoder_input_ids

    # Passing `labels` makes the PyTorch model compute its loss internally.
    pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}
    pt_inputs["labels"] = torch.tensor(label_ids.tolist())

    fx_outputs = fx_model(**fx_inputs)
    fx_logits = fx_outputs.logits

    pt_outputs = pt_model(**pt_inputs)
    pt_logits = pt_outputs.logits
    pt_loss = pt_outputs.loss

    print("--------------------------Checking logits match--------------------------")
    print(f"Flax logits shape: {fx_logits.shape}, PyTorch logits shape: {pt_logits.shape}")
    assert_almost_equals(fx_logits, pt_logits.detach().numpy())
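
    # Flax models do not accept `labels`, so the loss is computed by hand below:
    # one-hot targets with optax softmax cross-entropy, which should match the
    # cross-entropy loss PyTorch computed internally above.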
|
|
    def fx_train_step(fx_model, batch):
        # Pop the labels up front (on a copy) so `compute_loss` stays side-effect free.
        batch = dict(batch)
        label_ids = batch.pop("label_ids")

        def compute_loss(params):
            logits = fx_model(**batch, params=params).logits
            vocab_size = logits.shape[-1]
            targets = onehot(label_ids, vocab_size)
            loss = optax.softmax_cross_entropy(logits, targets)
            return loss.mean()

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(fx_model.params)
        return loss, grad
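
    # `jax.value_and_grad` returns the scalar loss together with gradients laid out
    # in the same pytree structure as `fx_model.params`.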
|
|
    fx_inputs["label_ids"] = label_ids
    fx_loss, fx_grad = fx_train_step(fx_model, fx_inputs)

    print("--------------------------Checking losses match--------------------------")
    print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}")
    assert_almost_equals(fx_loss, pt_loss.detach().numpy())

    # Populate `.grad` on every PyTorch parameter that participated in the loss.
    pt_loss.backward()

    # Collect PyTorch gradients by parameter name; skip parameters that never
    # received a gradient (e.g. unused weights), which would otherwise be None.
    pt_grad_dict = {k: v.grad for k, v in pt_model.named_parameters() if v.grad is not None}
    missing_grads = [k for k in pt_model.state_dict().keys() if k not in pt_grad_dict]

    missing_keys, unexpected_keys = pt_model.load_state_dict(pt_grad_dict, strict=False)

    assert missing_grads == missing_keys, f"Error with either grads {missing_keys} or keys {unexpected_keys}"
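
    # Round-trip trick: the PyTorch model now holds its *gradients* as weights, so
    # saving it and reloading with `from_pt=True` reuses the transformers PT->Flax
    # weight converter to map the gradients into the Flax parameter tree.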
|
|
    with tempfile.TemporaryDirectory() as tmpdirname:
        pt_model.save_pretrained(tmpdirname)
        pt_grad_model_to_fx = flax_model_cls.from_pretrained(tmpdirname, from_pt=True)

    pt_grad_to_fx = pt_grad_model_to_fx.params
    fx_grad = flatten_dict(fx_grad)
    pt_grad_to_fx = flatten_dict(pt_grad_to_fx)

    print("--------------------------Checking gradients match--------------------------")
    results_fail_rel, results_correct_rel, results_fail, results_correct = assert_dict_equal(fx_grad, pt_grad_to_fx)

    if len(results_fail) == 0:
        print("✅ All grads pass")
    else:
        print("\n".join(results_fail))

    print("--------------------------Checking rel gradients match--------------------------")

    if len(results_fail_rel) == 0:
        print("✅ All rel grads pass")
    else:
        print("\n".join(results_fail_rel))


def main():
    # Expects two CLI arguments: a Hugging Face Hub model id and the name of its
    # PyTorch architecture class; the Flax class is derived by prefixing "Flax".
    model_id = sys.argv[1]
    pt_architecture_name = sys.argv[2]
    compare_grads(model_id, pt_architecture_name)


if __name__ == "__main__":
    main()
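

# Example invocation (hypothetical script name; any checkpoint/class pair with a
# Flax counterpart should work):
#   python compare_pt_flax_grads.py google-bert/bert-base-cased BertForMaskedLM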
|
|