#!/opt/homebrew/bin/python3
"""
Eye of Truth: check how closely the last-layer hidden states of the PyTorch
and MLX vision encoders agree.

Includes the fix for the PyTorch 5D input requirement.
"""

import sys
from pathlib import Path

import torch
import mlx.core as mx
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoProcessor

# Make the final fixed MLX implementation importable.
sys.path.append('/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion')
from ultimate_fix import FinalOptimizedPaddleOCRMLX

# Patch count up to which the position-embedding table is sliced directly;
# above it the table is tiled first.  Presumably the number of rows in the
# table (27 x 27) -- TODO confirm against the checkpoint.
_POSITION_TABLE_ROWS = 729


def create_test_image(text="Hi"):
    """Return a 100x100 white RGB image with *text* drawn in black at (20, 20).

    Helvetica at size 40 is used when the system font can be loaded;
    otherwise the text is drawn with Pillow's default font.
    """
    img = Image.new('RGB', (100, 100), color='white')
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 40)
    except (OSError, ImportError):
        # Font file missing/unreadable, or FreeType support not built in:
        # fall back to the default font instead of swallowing every error.
        draw.text((20, 20), text, fill='black')
    else:
        draw.text((20, 20), text, fill='black', font=font)
    return img


def _mlx_vision_hidden_states(model, pixel_values):
    """Run the MLX vision tower step by step and return its final output.

    Args:
        model: object exposing ``patch_embedding``, ``vision_hidden_size``,
            ``position_embedding``, ``vision_layers``, ``vision_norm``,
            ``vision_head`` and ``post_layernorm``.
        pixel_values: MLX array whose first axis indexes patches
            (assumed channels-last, [N, H, W, 3] -- TODO confirm).

    Returns:
        The output of ``post_layernorm`` (assumed [1, N, dim] -- TODO confirm).
    """
    num_p = pixel_values.shape[0]

    # 1. Patch embedding, flattened to one vector per patch, plus batch axis.
    x = model.patch_embedding(pixel_values)
    x = x.reshape(num_p, model.vision_hidden_size)
    x = mx.expand_dims(x, 0)

    # 2. Position embedding: sequential rows, meant to mirror PyTorch.
    if num_p <= _POSITION_TABLE_ROWS:
        x = x + model.position_embedding[:num_p, :]
    else:
        # The repetition scheme here must match PyTorch exactly.
        repeats = num_p // _POSITION_TABLE_ROWS + 1
        pos_emb_repeated = mx.tile(model.position_embedding, (repeats, 1))
        x = x + pos_emb_repeated[:num_p, :]

    # 3. Vision layers (second argument is passed as None, as before).
    for layer in model.vision_layers:
        x = layer(x, None)

    # 4. Final norm -> 5. vision head -> 6. post layer norm.
    x = model.vision_norm(x)
    x = model.vision_head(x)
    x = model.post_layernorm(x)
    return x


def _cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D arrays.

    Returns 0.0 when either vector has zero norm, where the similarity is
    undefined, instead of dividing by zero and producing NaN.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return np.dot(a, b) / denom


def main():
    """Compare PyTorch and MLX vision hidden states on a synthetic image.

    Loads the processor and both models from hard-coded local paths, runs the
    same patches through each vision encoder, and prints cosine similarity,
    mean absolute error and maximum absolute difference with a verdict.
    """
    print("🚀 启动视觉编码器对齐检查 (V2)...")

    model_path = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
    mlx_path = "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion"

    # 1. Prepare the input.
    image = create_test_image()
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # Let the processor do the preprocessing so both backends share patches.
    inputs = processor(images=image, text="OCR", return_tensors="pt")
    pixel_values_pt = inputs['pixel_values']  # expected [num_patches, 3, 14, 14]
    grid_thw = inputs['image_grid_thw'].tolist()

    print(f"📸 输入图像 Patch 数量: {pixel_values_pt.shape[0]}")
    print(f"📊 Grid THW: {grid_thw}")

    # 2. PyTorch inference.
    print("\n🔍 正在获取 PyTorch Vision Hidden States...")
    torch_model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )
    torch_model.eval()

    with torch.no_grad():
        # The vision module requires a 5D input, so add a leading axis.
        pixel_values_5d = pixel_values_pt.unsqueeze(0)
        num_patches = pixel_values_pt.shape[0]

        # Position ids 0..N-1, on the same device as the model.
        pos_ids_pt = torch.arange(num_patches).unsqueeze(0).to(torch_model.device)

        # Call the vision part only.
        pt_vision_out = torch_model.visual(
            pixel_values_5d,
            position_ids=pos_ids_pt,
            image_grid_thw=grid_thw
        )

        # The module may return an output object or a bare tensor.
        if hasattr(pt_vision_out, "last_hidden_state"):
            pt_hidden_states = pt_vision_out.last_hidden_state.detach().cpu().numpy()
        else:
            pt_hidden_states = pt_vision_out.detach().cpu().numpy()

    print(f"✅ PyTorch Hidden States 获取成功: {pt_hidden_states.shape}")

    # 3. MLX inference.
    print("\n🔍 正在获取 MLX Vision Hidden States...")
    mlx_ocr = FinalOptimizedPaddleOCRMLX(mlx_path)

    # Convert to the MLX layout [num_patches, H, W, 3].
    pixel_values_mx = mx.array(pixel_values_pt.numpy().transpose(0, 2, 3, 1))

    mx_hidden_states = np.array(
        _mlx_vision_hidden_states(mlx_ocr.model, pixel_values_mx)
    )
    print(f"✅ MLX Hidden States 获取成功: {mx_hidden_states.shape}")

    # 4. Alignment analysis.
    print("\n" + "="*60)
    print("📊 统计分析结果")
    print("="*60)

    # Both outputs are expected to be [batch, seq, dim]; compare them flat.
    pt_flat = pt_hidden_states.flatten()
    mx_flat = mx_hidden_states.flatten()

    if pt_flat.shape != mx_flat.shape:
        print(f"❌ 警告: 维度不匹配! PT: {pt_flat.shape}, MX: {mx_flat.shape}")
        # Restrict the comparison to the common prefix.
        min_size = min(pt_flat.size, mx_flat.size)
        pt_flat = pt_flat[:min_size]
        mx_flat = mx_flat[:min_size]

    abs_diff = np.abs(pt_flat - mx_flat)
    cos_sim = _cosine_similarity(pt_flat, mx_flat)
    mae = np.mean(abs_diff)
    max_diff = np.max(abs_diff)

    print(f"✨ 余弦相似度: {cos_sim:.8f} (目标 > 0.999)")
    print(f"📏 平均绝对误差 (MAE): {mae:.8f}")
    print(f"🚫 最大绝对差异: {max_diff:.8f}")

    if cos_sim > 0.9995:
        print("\n🏆 结论: 视觉编码器完全对齐!问题 100% 锁定在 LLM/RoPE 层面。")
    elif cos_sim > 0.99:
        print("\n✅ 结论: 视觉编码器基本对齐。")
    else:
        print("\n❌ 结论: 视觉编码器不齐!这才是 OCR 失效的根源。")
    print("="*60)


if __name__ == "__main__":
    main()