PaddleOCR-VL-MLX / vision_alignment_check.py
gamhtoi's picture
Upload PaddleOCR-VL-MLX - MLX optimized for Apple Silicon
d48a40f verified
#!/opt/homebrew/bin/python3
"""
Eye of Truth: compare how closely the last-layer hidden states of the PyTorch and MLX vision encoders align.
Fixes the PyTorch 5D input issue.
"""
import torch
import mlx.core as mx
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoProcessor
import sys
from pathlib import Path
# Add the directory containing the final fixed implementation (ultimate_fix) to the import path
sys.path.append('/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion')
from ultimate_fix import FinalOptimizedPaddleOCRMLX
def create_test_image(text="Hi"):
    """Create a 100x100 white RGB image with ``text`` drawn in black at (20, 20).

    The macOS Helvetica font at 40pt is used when it can be loaded; otherwise
    the text is drawn with Pillow's default font.

    Args:
        text: The string to render onto the image.

    Returns:
        The rendered ``PIL.Image.Image``.
    """
    img = Image.new('RGB', (100, 100), color='white')
    draw = ImageDraw.Draw(img)
    try:
        # Only the font load is expected to fail (missing file on non-macOS
        # hosts, or Pillow built without FreeType), so keep the try minimal.
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 40)
    except (OSError, ImportError):
        # font=None makes draw.text fall back to Pillow's default font.
        font = None
    draw.text((20, 20), text, fill='black', font=font)
    return img
def _mlx_vision_hidden_states(model, pixel_values):
    """Run the MLX vision tower step by step and return its final hidden states.

    Args:
        model: The MLX model object; it must expose ``patch_embedding``,
            ``vision_hidden_size``, ``position_embedding``, ``vision_layers``,
            ``vision_norm``, ``vision_head`` and ``post_layernorm``.
        pixel_values: MLX array of patches whose first axis is the patch count
            (the caller passes channels-last patches).

    Returns:
        The MLX array produced after the post layer norm.
    """
    num_patches = pixel_values.shape[0]

    # 1. Patch embedding (original author notes the output as [N, 1, 1, 1152]),
    #    flattened to one vector per patch and given a leading batch axis.
    x = model.patch_embedding(pixel_values)
    x = x.reshape(num_patches, model.vision_hidden_size)
    x = mx.expand_dims(x, 0)

    # 2. Position embedding: sequential embeddings, intended to mirror the
    #    PyTorch side. NOTE(review): 729 is presumably the number of rows in
    #    the position table -- confirm against the checkpoint.
    if num_patches <= 729:
        x = x + model.position_embedding[:num_patches, :]
    else:
        # More patches than table rows: repeat the table and truncate.
        # This wrap-around must stay consistent with the PyTorch behaviour.
        repeats = num_patches // 729 + 1
        tiled = mx.tile(model.position_embedding, (repeats, 1))
        x = x + tiled[:num_patches, :]

    # 3. Vision transformer layers (second argument is passed as None).
    for layer in model.vision_layers:
        x = layer(x, None)

    # 4. Final norm, 5. vision head, 6. post layer norm.
    x = model.vision_norm(x)
    x = model.vision_head(x)
    x = model.post_layernorm(x)
    return x


def main():
    """Compare the PyTorch and MLX vision-encoder outputs on a synthetic image.

    Builds a test image, runs it through the Hugging Face processor, obtains
    the vision hidden states from the PyTorch model and from the MLX model,
    then prints cosine similarity, mean absolute error and maximum absolute
    difference between the two, followed by a verdict.
    """
    print("🚀 启动视觉编码器对齐检查 (V2)...")
    model_path = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
    mlx_path = "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion"

    # 1. Prepare the input
    image = create_test_image()
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    # Always go through the processor so both backends see identical patches.
    inputs = processor(images=image, text="OCR", return_tensors="pt")
    pixel_values_pt = inputs['pixel_values']  # author's note: [num_patches, 3, 14, 14]
    grid_thw = inputs['image_grid_thw'].tolist()
    print(f"📸 输入图像 Patch 数量: {pixel_values_pt.shape[0]}")
    print(f"📊 Grid THW: {grid_thw}")

    # 2. PyTorch inference
    print("\n🔍 正在获取 PyTorch Vision Hidden States...")
    torch_model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )
    torch_model.eval()
    with torch.no_grad():
        # The vision module expects a 5D input, so add a leading axis.
        pixel_values_5d = pixel_values_pt.unsqueeze(0)
        num_patches = pixel_values_pt.shape[0]
        # Position ids 0..N-1, placed on the model's device.
        pos_ids_pt = torch.arange(num_patches).unsqueeze(0).to(torch_model.device)
        # Call only the vision part of the model.
        pt_vision_out = torch_model.visual(
            pixel_values_5d,
            position_ids=pos_ids_pt,
            image_grid_thw=grid_thw
        )
        # The output may be a model-output object or a bare tensor.
        if hasattr(pt_vision_out, "last_hidden_state"):
            pt_hidden_states = pt_vision_out.last_hidden_state.detach().cpu().numpy()
        else:
            pt_hidden_states = pt_vision_out.detach().cpu().numpy()
    print(f"✅ PyTorch Hidden States 获取成功: {pt_hidden_states.shape}")

    # 3. MLX inference
    print("\n🔍 正在获取 MLX Vision Hidden States...")
    mlx_ocr = FinalOptimizedPaddleOCRMLX(mlx_path)
    # Convert to channels-last layout for MLX: [num_patches, H, W, 3]
    pixel_values_mx = mx.array(pixel_values_pt.numpy().transpose(0, 2, 3, 1))
    mx_out = _mlx_vision_hidden_states(mlx_ocr.model, pixel_values_mx)
    # NumPy has no bfloat16, so cast to float32 before converting; this is a
    # no-op for float32 outputs and prevents a conversion failure otherwise.
    mx_hidden_states = np.array(mx_out.astype(mx.float32))
    print(f"✅ MLX Hidden States 获取成功: {mx_hidden_states.shape}")

    # 4. Alignment analysis
    print("\n" + "="*60)
    print("📊 统计分析结果")
    print("="*60)
    # Author's note: both outputs are expected to be [1, 784, 1152]; flatten
    # so the comparison does not depend on the exact axis layout.
    pt_flat = pt_hidden_states.flatten()
    mx_flat = mx_hidden_states.flatten()
    if pt_flat.shape != mx_flat.shape:
        print(f"❌ 警告: 维度不匹配! PT: {pt_flat.shape}, MX: {mx_flat.shape}")
        # Restrict the comparison to the common prefix.
        min_size = min(pt_flat.size, mx_flat.size)
        pt_flat = pt_flat[:min_size]
        mx_flat = mx_flat[:min_size]

    # Accumulate in float64: the vectors are long and the pass threshold
    # (0.9995) sits close to 1, so float32 round-off is worth avoiding.
    pt_flat = pt_flat.astype(np.float64)
    mx_flat = mx_flat.astype(np.float64)

    norm_product = np.linalg.norm(pt_flat) * np.linalg.norm(mx_flat)
    if norm_product > 0:
        cos_sim = float(np.dot(pt_flat, mx_flat) / norm_product)
    else:
        # Cosine similarity is undefined for an all-zero vector; report NaN
        # explicitly (it falls through to the "not aligned" verdict below).
        cos_sim = float("nan")
    abs_diff = np.abs(pt_flat - mx_flat)
    mae = np.mean(abs_diff)
    max_diff = np.max(abs_diff)

    print(f"✨ 余弦相似度: {cos_sim:.8f} (目标 > 0.999)")
    print(f"📏 平均绝对误差 (MAE): {mae:.8f}")
    print(f"🚫 最大绝对差异: {max_diff:.8f}")
    if cos_sim > 0.9995:
        print("\n🏆 结论: 视觉编码器完全对齐!问题 100% 锁定在 LLM/RoPE 层面。")
    elif cos_sim > 0.99:
        print("\n✅ 结论: 视觉编码器基本对齐。")
    else:
        print("\n❌ 结论: 视觉编码器不齐!这才是 OCR 失效的根源。")
    print("="*60)
# Run the alignment check only when this file is executed as a script.
if __name__ == "__main__":
    main()