"""
Eye of Truth: compare how well the last-layer hidden states of the PyTorch and MLX vision encoders align.
Fixes the PyTorch 5D input issue.
"""
| |
|
| | import torch |
| | import mlx.core as mx |
| | import numpy as np |
| | from PIL import Image, ImageDraw, ImageFont |
| | from transformers import AutoModel, AutoProcessor |
| | import sys |
| | from pathlib import Path |
| |
|
| | |
| | sys.path.append('/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion') |
| | from ultimate_fix import FinalOptimizedPaddleOCRMLX |
| |
|
def create_test_image(text="Hi"):
    """Render ``text`` in black on a 100x100 white RGB image.

    The Helvetica system font at size 40 is used when it can be loaded.
    If loading fails, the text is drawn without an explicit font, so the
    function still works on machines where that font file is absent.

    Args:
        text: String drawn at pixel position (20, 20).

    Returns:
        The PIL image that was drawn on.
    """
    img = Image.new('RGB', (100, 100), color='white')
    draw = ImageDraw.Draw(img)
    try:
        # Only the font load is expected to fail, so keep the try body minimal.
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 40)
    except (OSError, ImportError):
        # Font file missing/unreadable, or Pillow built without FreeType:
        # font=None makes draw.text use its default font.
        font = None
    draw.text((20, 20), text, fill='black', font=font)
    return img
| |
|
def _mlx_vision_hidden(model, pixel_values):
    """Run the MLX vision stack step by step and return its final output.

    The patch embeddings are reshaped to (1, num_patches, vision_hidden_size),
    position embeddings are added, and the result is passed through every
    entry of ``model.vision_layers`` followed by ``vision_norm``,
    ``vision_head`` and ``post_layernorm``.

    Args:
        model: MLX model object exposing the attributes used below.
        pixel_values: MLX array whose first axis indexes patches.

    Returns:
        The output of ``model.post_layernorm``.
    """
    num_p = pixel_values.shape[0]
    # NOTE(review): 729 is presumably the row count of
    # model.position_embedding -- confirm against the converted weights.
    num_pos = 729

    x = model.patch_embedding(pixel_values)
    x = x.reshape(num_p, model.vision_hidden_size)
    x = mx.expand_dims(x, 0)

    if num_p <= num_pos:
        x = x + model.position_embedding[:num_p, :]
    else:
        # More patches than table rows: repeat the table and cut to length.
        pos_emb_repeated = mx.tile(
            model.position_embedding, (num_p // num_pos + 1, 1)
        )
        x = x + pos_emb_repeated[:num_p, :]

    for layer in model.vision_layers:
        # Second argument is passed as None (presumably "no mask" -- confirm).
        x = layer(x, None)

    x = model.vision_norm(x)
    x = model.vision_head(x)
    x = model.post_layernorm(x)
    return x


def _alignment_metrics(pt_flat, mx_flat):
    """Return (cosine similarity, MAE, max abs diff) of two equal-length 1-D arrays."""
    abs_diff = np.abs(pt_flat - mx_flat)
    denom = np.linalg.norm(pt_flat) * np.linalg.norm(mx_flat)
    if denom == 0:
        # Cosine similarity is undefined for a zero vector; report 0.0
        # (treated as "not aligned") instead of dividing by zero.
        cos_sim = 0.0
    else:
        cos_sim = np.dot(pt_flat, mx_flat) / denom
    return cos_sim, np.mean(abs_diff), np.max(abs_diff)


def main():
    """Compare PyTorch and MLX vision-encoder outputs on one synthetic image.

    Builds a test image, preprocesses it with the Hugging Face processor,
    runs the PyTorch model's ``visual`` module and a manual pass through the
    MLX model's vision components, then prints cosine similarity, mean
    absolute error and maximum absolute difference between the two outputs.
    Model locations are hard-coded local paths.
    """
    print("🚀 启动视觉编码器对齐检查 (V2)...")
    model_path = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
    mlx_path = "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion"

    # --- Shared input: one image, preprocessed once for both backends ---
    image = create_test_image()
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    inputs = processor(images=image, text="OCR", return_tensors="pt")
    pixel_values_pt = inputs['pixel_values']
    grid_thw = inputs['image_grid_thw'].tolist()
    num_patches = pixel_values_pt.shape[0]

    print(f"📸 输入图像 Patch 数量: {num_patches}")
    print(f"📊 Grid THW: {grid_thw}")

    # --- PyTorch reference ---
    print("\n🔍 正在获取 PyTorch Vision Hidden States...")
    torch_model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )
    torch_model.eval()

    with torch.no_grad():
        # Add a leading axis to the pixel values (the "5D input" fix) and
        # keep both tensors on the model's device.
        pixel_values_5d = pixel_values_pt.unsqueeze(0).to(torch_model.device)
        pos_ids_pt = torch.arange(num_patches).unsqueeze(0).to(torch_model.device)

        pt_vision_out = torch_model.visual(
            pixel_values_5d,
            position_ids=pos_ids_pt,
            image_grid_thw=grid_thw
        )
        # The module may return either an output object or a bare tensor.
        pt_tensor = getattr(pt_vision_out, "last_hidden_state", pt_vision_out)
        pt_hidden_states = pt_tensor.detach().cpu().numpy()

    print(f"✅ PyTorch Hidden States 获取成功: {pt_hidden_states.shape}")

    # --- MLX candidate ---
    print("\n🔍 正在获取 MLX Vision Hidden States...")
    mlx_ocr = FinalOptimizedPaddleOCRMLX(mlx_path)

    # Move axis 1 to the end (presumably channels-first -> channels-last
    # for the MLX patch embedding -- confirm).
    pixel_values_mx = mx.array(pixel_values_pt.numpy().transpose(0, 2, 3, 1))

    mx_hidden_states = np.array(_mlx_vision_hidden(mlx_ocr.model, pixel_values_mx))
    print(f"✅ MLX Hidden States 获取成功: {mx_hidden_states.shape}")

    # --- Statistics ---
    print("\n" + "="*60)
    print("📊 统计分析结果")
    print("="*60)

    pt_flat = pt_hidden_states.flatten()
    mx_flat = mx_hidden_states.flatten()

    if pt_flat.shape != mx_flat.shape:
        print(f"❌ 警告: 维度不匹配! PT: {pt_flat.shape}, MX: {mx_flat.shape}")
        # Best effort: compare only the common prefix. The metrics below are
        # then indicative at most, since elements may no longer correspond.
        min_size = min(pt_flat.size, mx_flat.size)
        pt_flat = pt_flat[:min_size]
        mx_flat = mx_flat[:min_size]

    cos_sim, mae, max_diff = _alignment_metrics(pt_flat, mx_flat)

    # NOTE(review): the message states a 0.999 target while the top verdict
    # below requires 0.9995 -- confirm which threshold is intended.
    print(f"✨ 余弦相似度: {cos_sim:.8f} (目标 > 0.999)")
    print(f"📏 平均绝对误差 (MAE): {mae:.8f}")
    print(f"🚫 最大绝对差异: {max_diff:.8f}")

    if cos_sim > 0.9995:
        print("\n🏆 结论: 视觉编码器完全对齐!问题 100% 锁定在 LLM/RoPE 层面。")
    elif cos_sim > 0.99:
        print("\n✅ 结论: 视觉编码器基本对齐。")
    else:
        print("\n❌ 结论: 视觉编码器不齐!这才是 OCR 失效的根源。")
    print("="*60)
| |
|
# Script entry point: run the alignment check only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
| |
|