| | |
| | """ |
| | 对比 PyTorch 原始模型和 MLX 实现 |
| | 逐层对比输出,找出差异 |
| | |
| | 目标:找到准确度低的根本原因 |
| | """ |
| |
|
| | import torch |
| | import mlx.core as mx |
| | import numpy as np |
| | from PIL import Image, ImageDraw |
| | from transformers import AutoModel, AutoProcessor |
| | import sys |
| |
|
| | sys.path.append('/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion') |
| | from final_fixed import FinalOptimizedPaddleOCRMLX |
| |
|
| |
|
def create_test_image():
    """Render a small synthetic test image and write it to disk.

    Draws the text "Hello World" in black on a white 400x200 RGB canvas
    and saves the result to ``/tmp/test_compare.png``.

    Returns:
        tuple: ``(path, image)`` -- the path the image was saved to and the
        in-memory PIL image object.
    """
    output_path = "/tmp/test_compare.png"

    canvas = Image.new('RGB', (400, 200), color='white')
    ImageDraw.Draw(canvas).text((50, 80), "Hello World", fill='black')
    canvas.save(output_path)

    return output_path, canvas
| |
|
| |
|
def test_pytorch_model(image_path, img):
    """Run the reference PyTorch model and capture its intermediate outputs.

    Loads the model and processor from a hard-coded local directory,
    preprocesses ``img`` together with a fixed prompt, prints the mean and
    standard deviation of the vision-encoder output and of the ``mlp_AR``
    output, and finally generates a few tokens without sampling.

    Args:
        image_path: Path of the test image. Not used by this function; it is
            kept so the existing two-argument call sites keep working.
        img: Image object passed to the processor as ``images=``.

    Returns:
        dict: ``'vision_features'`` (``last_hidden_state`` of the vision
        model), ``'vision_embeds'`` (the ``mlp_AR`` output, or its first
        element when a list is returned) and ``'inputs'`` (the processor
        output).
    """
    print("\n" + "="*60)
    print("🔍 测试 PyTorch 原始模型")
    print("="*60)

    model_path = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"

    # Load the model in float32 and switch it to inference mode.
    print("\n📂 加载模型...")
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.float32
    )
    model.eval()

    processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    print("✅ 模型加载完成")

    # Preprocess the image and the prompt into model inputs.
    print("\n🔄 处理图像...")
    prompt = "Extract all text from this image."
    inputs = processor(
        images=img,
        text=prompt,
        return_tensors="pt"
    )

    print(f"pixel_values shape: {inputs['pixel_values'].shape}")
    print(f"input_ids shape: {inputs['input_ids'].shape}")
    print(f"image_grid_thw: {inputs['image_grid_thw']}")

    print("\n🔍 获取中间层输出...")

    with torch.no_grad():
        # Stage 1: raw vision encoder output.
        vision_outputs = model.vision_model(inputs['pixel_values'])
        vision_features = vision_outputs.last_hidden_state
        print(f"\n1. Vision encoder output: {vision_features.shape}")
        print(f" 均值: {vision_features.mean():.6f}")
        print(f" 标准差: {vision_features.std():.6f}")

        # Stage 2: projection of the vision features through mlp_AR.
        vision_embeds = model.mlp_AR(vision_features, inputs['image_grid_thw'].tolist())
        if isinstance(vision_embeds, list):
            vision_embeds = vision_embeds[0]
        print(f"\n2. mlp_AR output: {vision_embeds.shape}")
        print(f" 均值: {vision_embeds.mean():.6f}")
        print(f" 标准差: {vision_embeds.std():.6f}")

        # Stage 3: token generation. max_new_tokens / do_sample are
        # generation options and the line below reports token IDs, so this
        # has to go through generate() rather than a plain forward call
        # (whose first output is not a sequence of token IDs).
        # NOTE(review): assumes the class loaded via trust_remote_code
        # exposes generate() -- confirm for this checkpoint.
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            pixel_values=inputs['pixel_values'],
            image_grid_thw=inputs['image_grid_thw'],
            max_new_tokens=10,
            do_sample=False
        )

    print(f"\n3. 生成的 token IDs: {outputs[0][:20].tolist()}")

    return {
        'vision_features': vision_features,
        'vision_embeds': vision_embeds,
        'inputs': inputs
    }
| |
|
| |
|
def test_mlx_model(image_path):
    """Run the MLX port and print statistics for its intermediate activations.

    Builds ``FinalOptimizedPaddleOCRMLX`` from a hard-coded directory,
    preprocesses the image with a fixed prompt, then manually walks the
    patch embedding, the position-embedding addition and vision layers
    0-2, printing the mean and standard deviation after each step. That
    manual walk is for inspection only: the embeddings that are returned
    come from a separate ``encode_image`` call. A short generation run is
    printed at the end.

    Args:
        image_path: Path of the image handed to ``preprocess_image``.

    Returns:
        dict: ``{'vision_embeds': ...}`` holding the ``encode_image`` result.
    """

    def _show_stats(tensor):
        # Mean first, then standard deviation, one line each.
        print(f" 均值: {float(mx.mean(tensor)):.6f}")
        print(f" 标准差: {float(mx.std(tensor)):.6f}")

    separator = "=" * 60
    print(f"\n{separator}")
    print("🔍 测试 MLX 实现")
    print(separator)

    ocr = FinalOptimizedPaddleOCRMLX(
        "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion"
    )

    pixel_values, input_ids, image_grid_thw = ocr.preprocess_image(
        image_path, "Extract all text from this image."
    )

    print(f"\npixel_values shape: {pixel_values.shape}")
    print(f"input_ids shape: {input_ids.shape}")
    print(f"image_grid_thw: {image_grid_thw.tolist()}")

    print("\n🔍 获取中间层输出...")

    # Step 1: patch embedding, flattened to (num_patches, hidden) and given
    # a leading axis of size 1.
    num_patches = pixel_values.shape[0]
    patches = ocr.model.patch_embedding(pixel_values)
    x = mx.expand_dims(
        patches.reshape(num_patches, ocr.model.vision_hidden_size), 0
    )
    print(f"\n1. Patch embedding: {x.shape}")
    _show_stats(x)

    # Step 2: add the first num_patches rows of the position embedding.
    x = x + ocr.model.position_embedding[:num_patches, :]
    print(f"\n2. After position embedding: {x.shape}")
    _show_stats(x)

    # Step 3: feed the activations through vision layers 0, 1 and 2.
    # NOTE(review): the second argument is passed as None; its meaning is
    # defined by the layer implementation -- confirm.
    for layer_idx in (0, 1, 2):
        x = ocr.model.vision_layers[layer_idx](x, None)
        print(f"\n3.{layer_idx+1}. After vision layer {layer_idx}: {x.shape}")
        _show_stats(x)

    # Step 4: the full image-encoding path, independent of the walk above.
    vision_embeds = ocr.model.encode_image(pixel_values, image_grid_thw)
    print(f"\n4. Final vision embeds: {vision_embeds.shape}")
    _show_stats(vision_embeds)

    # Step 5: short generation run.
    result = ocr.generate(
        pixel_values,
        input_ids,
        image_grid_thw,
        max_tokens=10,
        repetition_penalty=2.0,
    )
    print(f"\n5. 生成结果: {result}")

    return {'vision_embeds': vision_embeds}
| |
|
| |
|
def compare_outputs(pytorch_results, mlx_results):
    """Print element-wise difference statistics between the two vision embeddings.

    Both ``'vision_embeds'`` entries are converted to NumPy arrays. When the
    shapes agree, the maximum and mean absolute difference and the relative
    error (mean difference over mean absolute reference value) are printed,
    followed by a verdict based on the maximum difference (< 0.1, < 1.0,
    otherwise). When the shapes differ, only a mismatch message is printed.

    Args:
        pytorch_results: Mapping whose ``'vision_embeds'`` entry is a torch
            tensor; it serves as the reference.
        mlx_results: Mapping whose ``'vision_embeds'`` entry can be converted
            with ``np.array``.

    Returns:
        None. All results are reported on stdout.
    """
    print("\n" + "="*60)
    print("📊 对比 PyTorch vs MLX")
    print("="*60)

    # detach()/cpu() keep the conversion working for tensors that track
    # gradients or live on an accelerator; a bare .numpy() raises for those.
    pt_embeds = pytorch_results['vision_embeds'].detach().cpu().numpy()
    mlx_embeds = np.array(mlx_results['vision_embeds'])

    print("\nVision Embeds 对比:")
    print(f"PyTorch shape: {pt_embeds.shape}")
    print(f"MLX shape: {mlx_embeds.shape}")

    if pt_embeds.shape != mlx_embeds.shape:
        print("❌ 形状不匹配!")
        return

    diff = np.abs(pt_embeds - mlx_embeds)
    max_diff = diff.max()
    mean_diff = diff.mean()

    print("\n差异统计:")
    print(f" 最大差异: {max_diff:.6f}")
    print(f" 平均差异: {mean_diff:.6f}")

    # An all-zero reference would make the ratio a division by zero
    # (nan/inf plus a NumPy warning), so report it as unavailable instead.
    reference_scale = np.abs(pt_embeds).mean()
    if reference_scale > 0:
        print(f" 相对误差: {(mean_diff / reference_scale * 100):.2f}%")
    else:
        print(" 相对误差: N/A")

    if max_diff < 0.1:
        print("✅ 输出非常接近!")
    elif max_diff < 1.0:
        print("⚠️ 输出有一些差异")
    else:
        print("❌ 输出差异很大!")
| |
|
| |
|
def main():
    """Drive the comparison: build the test image, run both back-ends, compare.

    Any exception raised along the way is caught here: its message is
    printed, the traceback is printed via ``traceback.print_exc()``, and
    nothing is re-raised.
    """
    rule = "=" * 60

    # Opening banner.
    for banner_line in (
        "\n" + rule,
        "🔬 PyTorch vs MLX 详细对比",
        rule,
        "目标: 找出准确度低的根本原因",
        rule,
    ):
        print(banner_line)

    try:
        print("\n📋 创建测试图像...")
        image_path, img = create_test_image()
        print(f"✅ 测试图像: {image_path}")

        # Reference run first, then the MLX port, then the comparison.
        pytorch_results = test_pytorch_model(image_path, img)
        mlx_results = test_mlx_model(image_path)
        compare_outputs(pytorch_results, mlx_results)

        print("\n" + rule)
        print("✅ 对比完成")
        print(rule)
    except Exception as exc:
        print(f"\n❌ 错误: {exc}")
        import traceback
        traceback.print_exc()
| |
|
| |
|
# Run the comparison only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
| |
|