| | import argparse |
| | import time |
| | from typing import List |
| |
|
| | import model |
| | import numpy as np |
| | import mlx.core as mx |
| | from transformers import AutoModel, AutoTokenizer |
| |
|
| |
|
def run_torch(bert_model: str, batch: List[str]):
    """Run the Hugging Face (PyTorch) reference model on a batch of texts.

    Loads the tokenizer and model named by ``bert_model``, tokenizes ``batch``
    as one padded batch of PyTorch tensors, runs a single forward pass and
    prints load/inference timings, output shapes and a small sample of values.

    Args:
        bert_model: Identifier handed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        batch: Texts that are tokenized and run together.

    Returns:
        A ``(last_hidden_state, pooler_output)`` tuple of NumPy arrays taken
        from the model's forward result.
    """
    # Function-scope import: this file does not import torch at module level,
    # and it is only needed for the no_grad() context below.
    import torch

    print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    load_time = time.time() - start_time
    print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")

    print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)

    print("[PyTorch] Running model inference")
    inference_start = time.time()
    # Inference only: disabling autograd avoids building a graph that is never
    # used, which would otherwise cost memory and inflate the measured time.
    with torch.no_grad():
        torch_forward = torch_model(**torch_tokens)
    inference_time = time.time() - inference_start
    print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")

    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()

    print(f"[PyTorch] Output shape: {torch_output.shape}")
    print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")

    print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
    print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")

    return torch_output, torch_pooled
| |
|
| |
|
def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
    """Run the MLX implementation on a batch of texts via ``model.run``.

    Times the single ``model.run(bert_model, mlx_model, batch)`` call, converts
    its two results to NumPy arrays, and prints the elapsed time, the array
    shapes and a small sample of values.

    Args:
        bert_model: Forwarded unchanged to ``model.run`` as its first argument.
        mlx_model: Forwarded to ``model.run``; reported in the log line as the
            source of the weights.
        batch: Texts forwarded to ``model.run``.

    Returns:
        A ``(output, pooled)`` tuple of NumPy arrays built from the two values
        returned by ``model.run``.
    """
    print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")

    # The measured interval covers everything model.run does; this function
    # cannot separate loading from inference.
    t0 = time.time()
    hidden, pooled = model.run(bert_model, mlx_model, batch)
    elapsed = time.time() - t0
    print(f"[MLX] Model loaded and run in {elapsed:.2f} seconds")

    # Convert both results to NumPy so they can be compared with other arrays.
    hidden_np, pooled_np = (np.array(result) for result in (hidden, pooled))

    print(f"[MLX] Output shape: {hidden_np.shape}")
    print(f"[MLX] Pooled output shape: {pooled_np.shape}")

    first_token_sample = hidden_np[0, 0, :5]
    pooled_sample = pooled_np[0, :5]
    print(f"[MLX] Sample of output (first token, first 5 values): {first_token_sample}")
    print(f"[MLX] Sample of pooled output (first 5 values): {pooled_sample}")

    return hidden_np, pooled_np
| |
|
| |
|
def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled):
    """Compare PyTorch and MLX results and report whether they agree.

    Prints the shape checks, the max/mean absolute differences, a side-by-side
    sample from the first output token, and the ``np.allclose`` verdicts.

    Args:
        torch_output: PyTorch output array; indexed as ``[0, 0, i]`` for the
            sample, so it is expected to be 3-dimensional.
        torch_pooled: PyTorch pooled output array.
        mlx_output: MLX output array to compare against ``torch_output``.
        mlx_pooled: MLX pooled output array to compare against ``torch_pooled``.

    Returns:
        ``True`` when both pairs have identical shapes and are element-wise
        close with ``rtol=1e-4`` and ``atol=1e-4``; ``False`` otherwise.
    """
    print("\n[Comparison] Comparing PyTorch and MLX outputs")

    output_shapes_match = torch_output.shape == mlx_output.shape
    pooled_shapes_match = torch_pooled.shape == mlx_pooled.shape
    print(f"[Comparison] Shape match - Output: {output_shapes_match}")
    print(f"[Comparison] Shape match - Pooled: {pooled_shapes_match}")

    # Mismatched shapes can never be equivalent. Subtracting them would either
    # raise or silently broadcast, so report the failure and stop here.
    if not (output_shapes_match and pooled_shapes_match):
        print("[Comparison] Shapes differ; skipping element-wise comparison")
        return False

    # Compute each absolute-difference array once and reuse it below.
    output_abs_diff = np.abs(torch_output - mlx_output)
    pooled_abs_diff = np.abs(torch_pooled - mlx_pooled)

    output_max_diff = np.max(output_abs_diff)
    output_mean_diff = np.mean(output_abs_diff)
    pooled_max_diff = np.max(pooled_abs_diff)
    pooled_mean_diff = np.mean(pooled_abs_diff)

    print(f"[Comparison] Output - Max absolute difference: {output_max_diff:.6f}")
    print(f"[Comparison] Output - Mean absolute difference: {output_mean_diff:.6f}")
    print(f"[Comparison] Pooled - Max absolute difference: {pooled_max_diff:.6f}")
    print(f"[Comparison] Pooled - Mean absolute difference: {pooled_mean_diff:.6f}")

    print("\n[Comparison] Detailed comparison of first 5 values from first output token:")
    # Clamp to the last dimension so arrays narrower than 5 do not raise.
    sample_count = min(5, torch_output.shape[-1])
    for i in range(sample_count):
        torch_val = torch_output[0, 0, i]
        mlx_val = mlx_output[0, 0, i]
        diff = output_abs_diff[0, 0, i]
        print(f"Index {i}: PyTorch={torch_val:.6f}, MLX={mlx_val:.6f}, Diff={diff:.6f}")

    outputs_close = np.allclose(torch_output, mlx_output, rtol=1e-4, atol=1e-4)
    pooled_close = np.allclose(torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-4)

    print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
    print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")

    return bool(outputs_close and pooled_close)
| |
|
| |
|
def main() -> int:
    """Parse CLI arguments, run both implementations and compare them.

    Runs ``run_torch`` and ``run_mlx`` on the same texts, then passes their
    results to ``compare_outputs`` and prints a pass/fail banner.

    Returns:
        ``0`` when ``compare_outputs`` reports a match, ``1`` otherwise, so the
        result can be used as the process exit status.
    """
    parser = argparse.ArgumentParser(
        description="Run a BERT-like model for a batch of text and compare PyTorch and MLX outputs."
    )
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        nargs="+",
        default=["This is an example of BERT working in MLX."],
        help="A batch of texts to process. Multiple texts should be separated by spaces.",
    )
    # NOTE: --verbose is accepted for CLI compatibility, but nothing in this
    # entry point reads args.verbose.
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed information about the model execution.",
    )

    args = parser.parse_args()

    print(f"Testing BERT model: {args.bert_model}")
    print(f"MLX weights: {args.mlx_model}")
    print(f"Input text: {args.text}")

    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
    mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)

    all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)

    if all_match:
        print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
        return 0

    print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")
    return 1


if __name__ == "__main__":
    # A non-zero exit status on failure lets shells and CI detect it.
    raise SystemExit(main())
| |
|