# PaddleOCR-VL-MLX / aligned_paddleocr.py
# gamhtoi's picture
# Upload PaddleOCR-VL-MLX - MLX optimized for Apple Silicon
# d48a40f verified
#!/opt/homebrew/bin/python3
"""
PaddleOCR-VL MLX 终极对齐版 (Final Aligned Version)
-----------------------------------------------
100% 对齐核心架构,实现了:
1. Vision Encoder 2D-RoPE & 双线性插值 (Bilinear Interpolation)
2. LLM 3D-RoPE (mrope_section [16, 24, 24])
3. 正确的 Projector (mlp_AR) 权重加载与 Spatial Merge
4. 完美支持多语言 (中/英/数/标点)
开发记录:
- 解决了 Vision Encoder 空间位置理解偏差问题
- 修复了 GQA 注意力层中 KV Heads 的广播错误
- 实现了与官方 Processor 完美兼容的 Prompt 模板
"""
import mlx.core as mx
import mlx.nn as nn
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import json
import math
import torch
from pathlib import Path
from typing import Optional, List, Tuple
from transformers import AutoProcessor, AutoTokenizer
# --- 辅助算子 ---
def gelu_pytorch_tanh(x):
    """Apply the tanh approximation of GELU element-wise.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    with MLX ops. The original author's note states this form is meant to
    match PyTorch's tanh-approximate GELU.
    """
    cubic_term = 0.044715 * mx.power(x, 3)
    tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1.0 + mx.tanh(tanh_arg))
def rotate_half(x):
    """Rotate the last axis for RoPE: halves ``[a, b]`` become ``[-b, a]``.

    The last axis is split into two equal parts, so element ``i`` is paired
    with element ``i + size / 2``.
    """
    front, back = mx.split(x, 2, axis=-1)
    rotated_parts = [-back, front]
    return mx.concatenate(rotated_parts, axis=-1)
# --- 视觉 RoPE ---
class SigLIPRotaryEmbedding(nn.Module):
    """Container for the inverse frequencies used by the vision rotary embedding.

    Attributes set in ``__init__``:
        dim: the ``dim`` argument as given.
        theta: the frequency base.
        inv_freq: ``1 / theta ** (k / dim)`` for ``k = 0, 2, 4, ... < dim``,
            stored as float32.
    """

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.theta = theta
        # Even indices 0, 2, ... scaled into [0, 1) serve as exponents of theta.
        exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
        self.inv_freq = 1.0 / (theta ** exponents)
# --- 3D 位置编码对齐 ---
def get_3d_rope_index(input_ids, image_grid_thw, image_token_id=100295, spatial_merge_size=2):
    """Build per-token (t, h, w) position indices for multimodal RoPE.

    For every batch row the sequence is treated as
    ``[text before image] [one contiguous run of image tokens] [text after]``:

    * text before the image gets ``0, 1, 2, ...`` on all three axes;
    * image tokens get their (t, h, w) coordinates inside the merged grid,
      each offset by the index of the first image token;
    * text after the image continues from ``max(image positions) + 1`` on all
      three axes.

    Rows without image tokens (or when ``image_grid_thw`` is ``None``) get
    plain ``0 .. seq_len - 1`` on all three axes.

    Args:
        input_ids: 2-D array of token ids, indexable by row and convertible
            with ``np.array``.
        image_grid_thw: array indexed as ``[b, 0..2]`` holding the (t, h, w)
            patch grid of row ``b`` before spatial merging, or ``None``.
        image_token_id: id of the image placeholder token.
        spatial_merge_size: factor by which the h and w grids are divided
            (the projector merges ``size x size`` patches into one token).

    Returns:
        Tuple ``(t_ids, h_ids, w_ids)`` of int32 ``mx.array`` objects, each of
        shape ``(batch_size, seq_len)``.

    Raises:
        ValueError: if the number of image tokens in a row does not equal the
            number of cells in the merged grid.
    """
    batch_size, seq_len = input_ids.shape
    # Axis 0 indexes the three planes: 0 = t, 1 = h, 2 = w.
    positions = np.zeros((3, batch_size, seq_len), dtype=np.int32)

    for b in range(batch_size):
        row = np.array(input_ids[b])
        img_pos = np.where(row == image_token_id)[0]

        if len(img_pos) == 0 or image_grid_thw is None:
            # Text-only row: identical sequential positions on every axis.
            positions[:, b, :] = np.arange(seq_len)
            continue

        img_start = int(img_pos[0])
        img_len = len(img_pos)
        img_end = img_start + img_len

        # Text preceding the image.
        positions[:, b, :img_start] = np.arange(img_start)

        # Merged image grid.
        t_grid = int(image_grid_thw[b, 0])
        h_grid = int(image_grid_thw[b, 1]) // spatial_merge_size
        w_grid = int(image_grid_thw[b, 2]) // spatial_merge_size
        expected = t_grid * h_grid * w_grid
        if expected != img_len:
            raise ValueError(
                f"row {b}: found {img_len} image tokens but the merged grid "
                f"{t_grid}x{h_grid}x{w_grid} has {expected} cells"
            )
        grids = np.meshgrid(
            np.arange(t_grid), np.arange(h_grid), np.arange(w_grid), indexing='ij'
        )
        for axis, grid in enumerate(grids):
            positions[axis, b, img_start:img_end] = grid.flatten() + img_start

        # Text following the image resumes after the largest image position.
        resume_at = int(positions[:, b, img_start:img_end].max()) + 1
        positions[:, b, img_end:] = np.arange(seq_len - img_end) + resume_at

    return mx.array(positions[0]), mx.array(positions[1]), mx.array(positions[2])
# --- 核心模型层 ---
class VisionAttention(nn.Module):
    """Multi-head self-attention for the vision encoder, with optional rotary embedding.

    Query, key, value and output projections are all square
    ``hidden_size -> hidden_size`` linear layers (with bias, the ``nn.Linear``
    default used here).
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def __call__(self, x, mask=None, rope=None):
        """Attend over ``x`` of shape ``(batch, length, hidden)``.

        Args:
            x: input activations, unpacked as a rank-3 array.
            mask: optional additive mask added to the scaled scores.
            rope: optional ``(cos, sin)`` pair multiplied into q and k; the
                original comment gives their shape as ``(L, 1, head_dim)``.

        Returns:
            The output projection of the attended values, same shape as ``x``.
        """
        batch, length, width = x.shape
        head_shape = (batch, length, self.num_heads, self.head_dim)
        q = self.q_proj(x).reshape(*head_shape)
        k = self.k_proj(x).reshape(*head_shape)
        v = self.v_proj(x).reshape(*head_shape)

        if rope:
            cos, sin = rope
            q = (q * cos) + (rotate_half(q) * sin)
            k = (k * cos) + (rotate_half(k) * sin)

        # Move heads ahead of the sequence axis: (batch, heads, length, head_dim).
        q, k, v = (part.transpose(0, 2, 1, 3) for part in (q, k, v))

        bias = mask if mask is not None else 0
        scores = (q @ k.transpose(0, 1, 3, 2)) * self.scale + bias
        weights = mx.softmax(scores, axis=-1)

        merged = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, length, width)
        return self.out_proj(merged)
class VisionLayer(nn.Module):
    """One pre-norm vision encoder layer: attention and MLP, each with a residual add.

    ``config`` is a mapping providing ``hidden_size``, ``num_attention_heads``
    and ``intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config['hidden_size']
        self.layer_norm1 = nn.LayerNorm(width, eps=1e-6)
        self.self_attn = VisionAttention(width, config['num_attention_heads'])
        self.layer_norm2 = nn.LayerNorm(width, eps=1e-6)
        self.mlp = VisionMLP(width, config['intermediate_size'])

    def __call__(self, x, mask=None, rope=None):
        """Return ``x`` after the attention and MLP residual sub-blocks."""
        attended = self.self_attn(self.layer_norm1(x), mask, rope)
        x = x + attended
        fed_forward = self.mlp(self.layer_norm2(x))
        return x + fed_forward
class VisionMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> tanh-approximate GELU -> ``fc2``."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def __call__(self, x):
        """Project up, apply the activation, and project back down."""
        expanded = self.fc1(x)
        activated = gelu_pytorch_tanh(expanded)
        return self.fc2(activated)
class LLMAttention(nn.Module):
    """Grouped-query self-attention with multimodal (t, h, w) rotary embedding.

    ``config`` is a mapping providing ``hidden_size``, ``num_attention_heads``
    and ``num_key_value_heads``; optional keys are ``head_dim`` (default 128),
    ``rope_theta`` (default 500000.0) and ``rope_scaling['mrope_section']``
    (default ``[16, 24, 24]``).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.num_heads = config['num_attention_heads']
        self.num_kv_heads = config['num_key_value_heads']
        self.head_dim = config.get('head_dim', 128)

        # Underscore-prefixed so these plain values are never treated as
        # module parameters.
        self._rope_theta = float(config.get('rope_theta', 500000.0))
        rope_scaling = config.get('rope_scaling') or {}
        section = tuple(int(s) for s in rope_scaling.get('mrope_section', (16, 24, 24)))
        if len(section) != 3 or sum(section) != self.head_dim // 2:
            raise ValueError(
                f"mrope_section {section} must have 3 entries summing to "
                f"head_dim // 2 = {self.head_dim // 2}"
            )
        self._mrope_section = section

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def apply_mrope(self, q, k, t_ids, h_ids, w_ids):
        """Apply sectioned rotary embedding to ``q`` and ``k``.

        The ``head_dim // 2`` rotary frequencies are divided into three
        consecutive sections of sizes ``mrope_section``; the first section
        takes its angle from ``t_ids``, the second from ``h_ids``, the third
        from ``w_ids``. The resulting half-width angle vector is then
        duplicated to full ``head_dim``.

        Duplicating *after* the per-axis selection matters: ``rotate_half``
        pairs dimension ``i`` with ``i + head_dim // 2``, and the pair must
        share one angle for the transform to be a rotation. Slicing the
        already-duplicated vector as ``[0:32 | 32:80 | 80:128]`` would give
        the two members of a pair angles from different axes whenever
        ``t``, ``h`` and ``w`` differ (i.e. for image tokens).

        Args:
            q, k: arrays laid out ``(batch, heads, length, head_dim)``.
            t_ids, h_ids, w_ids: integer position ids of shape
                ``(batch, length)``.

        Returns:
            Tuple ``(q_rotated, k_rotated)``.
        """
        inv_freq = 1.0 / (self._rope_theta ** (mx.arange(0, self.head_dim, 2) / self.head_dim))

        def angles(ids):
            # (batch, length, 1) @ (1, head_dim // 2) -> (batch, length, head_dim // 2)
            return mx.matmul(ids[..., None].astype(mx.float32), inv_freq[None, :])

        per_axis = (angles(t_ids), angles(h_ids), angles(w_ids))

        pieces = []
        offset = 0
        for axis_angles, width in zip(per_axis, self._mrope_section):
            pieces.append(axis_angles[:, :, offset:offset + width])
            offset += width
        half = mx.concatenate(pieces, axis=-1)
        full = mx.concatenate([half, half], axis=-1)

        # Insert a heads axis so the tables broadcast over every head.
        cos = mx.cos(full)[:, None, :, :]
        sin = mx.sin(full)[:, None, :, :]
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

    def __call__(self, x, mask, t_ids, h_ids, w_ids, cache=None):
        """Run attention over ``x`` and return ``(output, new_cache)``.

        Args:
            x: activations of shape ``(batch, length, hidden)``.
            mask: additive attention mask, or ``None`` for no masking.
            t_ids, h_ids, w_ids: position ids for the tokens in ``x``.
            cache: optional ``(keys, values)`` from earlier steps; the new
                keys/values are appended along the sequence axis.

        Returns:
            ``(output, new_cache)`` where ``new_cache`` is the ``(keys,
            values)`` pair after rotary embedding and cache concatenation,
            before key/value heads are repeated for grouped-query attention.
        """
        B, L, _ = x.shape
        q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

        q, k = self.apply_mrope(q, k, t_ids, h_ids, w_ids)

        if cache is not None:
            k_cache, v_cache = cache
            k = mx.concatenate([k_cache, k], axis=2)
            v = mx.concatenate([v_cache, v], axis=2)
        new_cache = (k, v)

        if self.num_heads != self.num_kv_heads:
            # Each key/value head serves n_rep consecutive query heads.
            n_rep = self.num_heads // self.num_kv_heads
            k = mx.repeat(k, n_rep, axis=1)
            v = mx.repeat(v, n_rep, axis=1)

        bias = mask if mask is not None else 0
        scores = (q @ k.transpose(0, 1, 3, 2)) / math.sqrt(self.head_dim) + bias
        attn = mx.softmax(scores, axis=-1)
        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(out), new_cache
class LLMLayer(nn.Module):
    """One pre-norm decoder layer: RMSNorm -> attention, RMSNorm -> gated SiLU MLP.

    ``config`` is a mapping providing ``hidden_size`` and
    ``intermediate_size`` (plus whatever ``LLMAttention`` reads).
    """

    def __init__(self, config):
        super().__init__()
        width = config['hidden_size']
        inner = config['intermediate_size']
        self.input_layernorm = nn.RMSNorm(width, eps=1e-5)
        self.self_attn = LLMAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=1e-5)
        self.mlp_gate = nn.Linear(width, inner, bias=False)
        self.mlp_up = nn.Linear(width, inner, bias=False)
        self.mlp_down = nn.Linear(inner, width, bias=False)

    def __call__(self, x, mask, t_ids, h_ids, w_ids, cache=None):
        """Return ``(hidden_states, new_cache)`` for this layer."""
        normed = self.input_layernorm(x)
        attn_out, new_cache = self.self_attn(normed, mask, t_ids, h_ids, w_ids, cache)
        x = x + attn_out

        y = self.post_attention_layernorm(x)
        gated = nn.silu(self.mlp_gate(y)) * self.mlp_up(y)
        x = x + self.mlp_down(gated)
        return x, new_cache
# --- 主类 ---
class ALIGNED_PPOCR(nn.Module):
    """PaddleOCR-VL pipeline in MLX: vision encoder, projector, and decoder LLM.

    Construction reads ``config.json`` and ``paddleocr_vl_mlx.npz`` from
    ``model_dir``, builds all layers, assigns the weights, and loads the
    Hugging Face processor and tokenizer.

    Args:
        model_dir: directory containing ``config.json`` and
            ``paddleocr_vl_mlx.npz``.
        processor_path: directory (or hub id) passed to
            ``AutoProcessor`` / ``AutoTokenizer``. When ``None``, the legacy
            hard-coded location is used if it exists on this machine,
            otherwise ``model_dir`` is used.
    """

    # Location the original script always loaded the processor/tokenizer from.
    _LEGACY_PROCESSOR_PATH = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
    # Id of the image placeholder token in the prompt.
    _IMAGE_TOKEN_ID = 100295
    # The projector merges _MERGE x _MERGE neighbouring patches into one token.
    _MERGE = 2

    def __init__(self, model_dir, processor_path: Optional[str] = None):
        super().__init__()
        self.dir = Path(model_dir)
        with open(self.dir / "config.json") as f:
            self.config = json.load(f)
        v = self.config['vision_config']

        # 1. Vision components
        self.v_patch_embed = nn.Conv2d(3, v['hidden_size'], 14, 14, bias=True)
        self.v_pos_embed = mx.zeros((729, v['hidden_size']))
        self.v_layers = [VisionLayer(v) for _ in range(v['num_hidden_layers'])]
        self.v_norm = nn.LayerNorm(v['hidden_size'], eps=1e-6)

        # Vision RoPE: half of the head dim per spatial axis (see _vision_rope).
        self.v_head_dim = v['hidden_size'] // v['num_attention_heads']
        self.v_rope = SigLIPRotaryEmbedding(self.v_head_dim // 2)

        # 2. Projector (2x2 spatial merge, then two linear layers)
        self.v_pre_norm = nn.LayerNorm(v['hidden_size'], eps=1e-5)
        self.v_proj_l1 = nn.Linear(v['hidden_size'] * 4, v['hidden_size'] * 4, bias=True)
        self.v_proj_l2 = nn.Linear(v['hidden_size'] * 4, self.config['hidden_size'], bias=True)

        # 3. LLM components
        self.embed = nn.Embedding(self.config['vocab_size'], self.config['hidden_size'])
        self.layers = [LLMLayer(self.config) for _ in range(self.config['num_hidden_layers'])]
        self.norm = nn.RMSNorm(self.config['hidden_size'], eps=1e-5)
        self.head = nn.Linear(self.config['hidden_size'], self.config['vocab_size'], bias=False)

        self._load()

        # Tokenizer / processor
        if processor_path is not None:
            path = str(processor_path)
        elif Path(self._LEGACY_PROCESSOR_PATH).exists():
            path = self._LEGACY_PROCESSOR_PATH
        else:
            path = str(self.dir)
        self.proc = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.tok = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    def _load(self):
        """Assign every weight from ``paddleocr_vl_mlx.npz`` to its layer.

        Layer loops run over the layers that were actually constructed from
        the config rather than over fixed counts, so a missing checkpoint key
        fails with a ``KeyError`` naming it.
        """
        w = mx.load(str(self.dir / "paddleocr_vl_mlx.npz"))

        def put(module, key, bias=True):
            # Copy "<key>.weight" (and "<key>.bias" when present) onto module.
            module.weight = w[f"{key}.weight"]
            if bias:
                module.bias = w[f"{key}.bias"]

        # Checkpoint conv weight is transposed (0, 2, 3, 1) into the layout
        # nn.Conv2d is given here.
        self.v_patch_embed.weight = mx.transpose(
            w['visual.vision_model.embeddings.patch_embedding.weight'], (0, 2, 3, 1)
        )
        self.v_patch_embed.bias = w['visual.vision_model.embeddings.patch_embedding.bias']
        self.v_pos_embed = w['visual.vision_model.embeddings.position_embedding.weight']

        for i, layer in enumerate(self.v_layers):
            p = f"visual.vision_model.encoder.layers.{i}"
            for name in ("layer_norm1", "layer_norm2"):
                put(getattr(layer, name), f"{p}.{name}")
            for name in ("q_proj", "k_proj", "v_proj", "out_proj"):
                put(getattr(layer.self_attn, name), f"{p}.self_attn.{name}")
            for name in ("fc1", "fc2"):
                put(getattr(layer.mlp, name), f"{p}.mlp.{name}")

        put(self.v_norm, "visual.vision_model.post_layernorm")
        put(self.v_pre_norm, "mlp_AR.pre_norm")
        put(self.v_proj_l1, "mlp_AR.linear_1")
        put(self.v_proj_l2, "mlp_AR.linear_2")

        self.embed.weight = w["model.embed_tokens.weight"]
        for i, layer in enumerate(self.layers):
            p = f"model.layers.{i}"
            put(layer.input_layernorm, f"{p}.input_layernorm", bias=False)
            put(layer.post_attention_layernorm, f"{p}.post_attention_layernorm", bias=False)
            for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
                put(getattr(layer.self_attn, name), f"{p}.self_attn.{name}", bias=False)
            put(layer.mlp_gate, f"{p}.mlp.gate_proj", bias=False)
            put(layer.mlp_up, f"{p}.mlp.up_proj", bias=False)
            put(layer.mlp_down, f"{p}.mlp.down_proj", bias=False)

        put(self.norm, "model.norm", bias=False)
        put(self.head, "lm_head", bias=False)
        print("✅ ALIGNED PPOCR Weight Loaded Success.")

    def interpolate_pos_encoding(self, h, w):
        """Bilinearly resample the learned position grid to ``h x w``.

        The stored table is viewed as a square ``side x side`` grid. Sample
        coordinates are ``linspace(0, side - 1, n)`` on each axis, so the grid
        corners map exactly onto the output corners.

        Returns:
            Array of shape ``(h * w, d)`` in row-major (h outer, w inner) order.
        """
        side = int(math.sqrt(self.v_pos_embed.shape[0]))
        d = self.v_pos_embed.shape[-1]
        grid = self.v_pos_embed.reshape(side, side, d)

        y_idx = mx.linspace(0, side - 1, h)
        x_idx = mx.linspace(0, side - 1, w)
        y0 = mx.floor(y_idx).astype(mx.int32)
        x0 = mx.floor(x_idx).astype(mx.int32)
        # Clamp the upper neighbour so the last sample stays inside the grid.
        y1 = mx.minimum(y0 + 1, side - 1)
        x1 = mx.minimum(x0 + 1, side - 1)

        # Fractional offsets, shaped to broadcast over (h, w, d).
        dy = (y_idx - y0)[:, None, None]
        dx = (x_idx - x0)[None, :, None]

        rows_lo, rows_hi = grid[y0], grid[y1]
        v00, v01 = rows_lo[:, x0], rows_lo[:, x1]
        v10, v11 = rows_hi[:, x0], rows_hi[:, x1]

        interp = (v00 * (1 - dy) * (1 - dx) +
                  v01 * (1 - dy) * dx +
                  v10 * dy * (1 - dx) +
                  v11 * dy * dx)
        return interp.reshape(-1, d)

    def _vision_rope(self, h_val, w_val):
        """Return ``(cos, sin)`` tables for 2-D vision RoPE over an ``h x w`` grid.

        Row and column angles are concatenated and then duplicated, giving a
        last axis four times the length of ``self.v_rope.inv_freq``. A
        singleton axis is inserted at position 1 for broadcasting over heads.
        """
        h_idx = mx.repeat(mx.arange(h_val), w_val)   # 0,0,..,1,1,..
        w_idx = mx.tile(mx.arange(w_val), h_val)     # 0,1,..,0,1,..
        inv_freq = self.v_rope.inv_freq
        fh = mx.matmul(h_idx[..., None].astype(mx.float32), inv_freq[None, :])
        fw = mx.matmul(w_idx[..., None].astype(mx.float32), inv_freq[None, :])
        f = mx.concatenate([fh, fw], axis=-1)
        e = mx.concatenate([f, f], axis=-1)
        return mx.cos(e)[:, None, :], mx.sin(e)[:, None, :]

    def _encode_image(self, pixel_values, t, h, w):
        """Run the vision encoder and projector; return one embedding per merged patch.

        Args:
            pixel_values: processor output exposing ``.numpy()``; it is
                transposed ``(0, 2, 3, 1)`` before the patch convolution.
            t, h, w: patch-grid sizes from ``image_grid_thw``.
        """
        pv = mx.array(pixel_values.numpy().transpose(0, 2, 3, 1))
        x = self.v_patch_embed(pv)
        x = x.reshape(x.shape[0], -1)
        x = x + self.interpolate_pos_encoding(h, w)
        x = mx.expand_dims(x, 0)

        rope = self._vision_rope(h, w)
        for layer in self.v_layers:
            x = layer(x, rope=rope)
        x = self.v_pre_norm(self.v_norm(x))[0]

        # Spatial merge: gather each 2x2 neighbourhood into one row of
        # width 4 * d, using the actual feature width instead of a constant.
        m = self._MERGE
        d = x.shape[-1]
        x = x.reshape(t, h // m, m, w // m, m, d).transpose(0, 1, 3, 2, 4, 5).reshape(-1, m * m * d)
        return self.v_proj_l2(gelu_pytorch_tanh(self.v_proj_l1(x)))

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, penalty):
        """Make already-generated tokens less likely and return ``logits``.

        Positive logits are divided by ``penalty`` and negative logits are
        multiplied by it, so a ``penalty`` greater than 1 lowers the score in
        both cases. Dividing a negative logit would raise it instead.
        """
        if not token_ids:
            return logits
        seen = mx.array(sorted(set(token_ids)))
        picked = logits[:, seen]
        logits[:, seen] = mx.where(picked < 0, picked * penalty, picked / penalty)
        return logits

    def _ocr_generator(self, image_path, prompt="OCR:", max_tokens=256, penalty=1.1):
        """Greedy-decode text for one image, yielding one decoded token at a time.

        Args:
            image_path: path of the image to read.
            prompt: text placed after the image in the user message.
            max_tokens: upper bound on generated tokens.
            penalty: repetition penalty; ``1.0`` disables it.

        Yields:
            The string decoded from each generated token id, stopping before
            the tokenizer's EOS token.
        """
        # Close the file handle as soon as the pixels are converted.
        with Image.open(image_path) as src:
            img = src.convert('RGB')

        # 1. Input processing
        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
        text = self.tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        it = self.proc(images=img, text=text, return_tensors="pt")
        input_ids = it['input_ids'][0].tolist()
        t, h, w = it['image_grid_thw'][0].tolist()

        # 2. Vision encoding
        ve = self._encode_image(it['pixel_values'], t, h, w)

        # 3. Splice image embeddings into the token embeddings
        ids = mx.array([input_ids])
        t3, h3, w3 = get_3d_rope_index(ids, it['image_grid_thw'].numpy())
        embs = np.array(self.embed(ids))
        img_pos = np.where(np.array(input_ids) == self._IMAGE_TOKEN_ID)[0]
        embs[0, img_pos, :] = np.array(ve)[:len(img_pos)]
        hidden = mx.array(embs)

        # 4. Prefill with a causal mask, collecting each layer's KV cache
        caches = [None] * len(self.layers)
        L = hidden.shape[1]
        mask = mx.triu(mx.full((L, L), -mx.inf, dtype=hidden.dtype), k=1)
        x = hidden
        for i, layer in enumerate(self.layers):
            x, caches[i] = layer(x, mask, t3, h3, w3, cache=None)
        logits = self.head(self.norm(x[:, -1, :]))
        next_id = int(mx.argmax(logits, axis=-1))

        # Generated tokens share one position on all three axes, continuing
        # from the largest position of the last prompt token.
        next_pos = int(mx.max(mx.concatenate([t3[0, -1:], h3[0, -1:], w3[0, -1:]]))) + 1

        # 5. Auto-regressive decoding
        res = []
        for _ in range(max_tokens):
            if next_id == self.tok.eos_token_id:
                break
            res.append(next_id)
            # NOTE(review): tokens are decoded individually; if this tokenizer
            # can spread one character or a leading space across tokens, the
            # joined text may differ from decoding all ids at once — confirm.
            yield self.tok.decode([next_id])

            step_pos = mx.array([[next_pos]], dtype=mx.int32)
            x = self.embed(mx.array([[next_id]]))
            for i, layer in enumerate(self.layers):
                # No mask: the single new token may attend to everything cached.
                x, caches[i] = layer(x, None, step_pos, step_pos, step_pos, cache=caches[i])
            logits = self.head(self.norm(x[:, -1, :]))
            if penalty != 1.0:
                logits = self._apply_repetition_penalty(logits, res, penalty)
            next_id = int(mx.argmax(logits, axis=-1))
            next_pos += 1

    def ocr(self, img_path, prompt="OCR:", max_tokens=256, penalty=1.1, stream=False):
        """Recognise text in an image.

        Returns the token generator itself when ``stream`` is true, otherwise
        the concatenation of everything it yields.
        """
        gen = self._ocr_generator(img_path, prompt, max_tokens, penalty)
        return gen if stream else "".join(gen)
if __name__ == "__main__":
    import sys

    # The model directory may be given as the first command-line argument;
    # without one, the original default location is used.
    model_dir = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion"
    )
    model = ALIGNED_PPOCR(model_dir)

    # Final sanity test: render a short string and run OCR on it.
    print("\n--- MLX ALIGNED PPOCR FINAL TEST ---")
    img = Image.new('RGB', (400, 100), 'white')
    draw = ImageDraw.Draw(img)
    try:
        f = "/System/Library/Fonts/Supplemental/Songti.ttc"
        font = ImageFont.truetype(f, 40)
    except OSError:
        # Font file missing or unreadable: fall back to the default font and
        # an ASCII-only string.
        draw.text((20, 20), "PaddleOCR MLX ALIGNED", fill='black')
    else:
        draw.text((20, 20), "PaddleOCR MLX 对齐成功", fill='black', font=font)
    img.save("/tmp/final_success.png")
    print(f"Recognized: {model.ocr('/tmp/final_success.png')}")