# PaddleOCR-VL-MLX / mlx_components.py
# Source: gamhtoi — "Upload PaddleOCR-VL-MLX - MLX optimized for Apple Silicon"
# Revision: d48a40f (verified)
#!/opt/homebrew/bin/python3
"""
PaddleOCR-VL MLX implementation — Stage 1: basic building blocks.

This file implements the foundational MLX components of the PaddleOCR-VL model.
Author: AI Assistant
Date: 2024-12-24
"""
import mlx.core as mx
import mlx.nn as nn
import math
from typing import Optional, Tuple
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (MLX implementation).

    Normalizes the last axis by the root-mean-square of its values and
    applies a learned per-feature scale (no mean subtraction, no bias).
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Learned gain, initialized to identity (all ones).
        self.weight = mx.ones((hidden_size,))
        # Guards against division by zero for all-zero feature vectors.
        self.variance_epsilon = eps

    def __call__(self, hidden_states):
        # Mean of squares over the feature axis, kept broadcastable.
        mean_sq = mx.mean(mx.square(hidden_states), axis=-1, keepdims=True)
        normalized = hidden_states * mx.rsqrt(mean_sq + self.variance_epsilon)
        return self.weight * normalized
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table generator.

    Precomputes inverse frequencies at construction; each call produces
    cos/sin tables of shape [seq_len, dim] for the requested length.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base^(-2i/dim) for i in [0, dim/2).
        exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
        self.inv_freq = 1.0 / (base ** exponents)

    def __call__(self, x, seq_len: int):
        # Angle table: outer product of positions and inverse frequencies,
        # shape [seq_len, dim/2]. (`x` is unused; kept for interface parity.)
        positions = mx.arange(seq_len, dtype=mx.float32)
        angles = mx.outer(positions, self.inv_freq)
        # Duplicate along the feature axis so the tables match the head dim.
        table = mx.concatenate([angles, angles], axis=-1)
        return mx.cos(table), mx.sin(table)
def rotate_half(x):
    """Return x with its last-axis halves swapped and the second half negated.

    I.e. [x1, x2] -> [-x2, x1]; the rotation primitive used by RoPE.
    """
    first, second = mx.split(x, 2, axis=-1)
    return mx.concatenate([-second, first], axis=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embedding to query and key tensors.

    Expects q, k of shape [B, H, L, D] and cos/sin of shape [L, D];
    the tables are expanded to [1, 1, L, D] so they broadcast over
    batch and heads. Returns the rotated (q, k) pair.
    """
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    rotated_q = (q * cos_b) + (rotate_half(q) * sin_b)
    rotated_k = (k * cos_b) + (rotate_half(k) * sin_b)
    return rotated_q, rotated_k
class Attention(nn.Module):
    """Multi-head self-attention with optional grouped-query attention (MLX).

    Applies RoPE to queries and keys, supports fewer KV heads than query
    heads (KV states are tiled to match), and accepts an additive
    attention mask.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Fall back to full MHA when no separate KV head count is given.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        # An explicit head_dim wins; otherwise derive it from the model width.
        self.head_dim = (hidden_size // num_heads) if head_dim is None else head_dim
        # Bias-free Q/K/V/O projections.
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        B, L, _ = hidden_states.shape
        # Project and split into heads: [B, L, n_heads, head_dim].
        q = self.q_proj(hidden_states).reshape(B, L, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).reshape(B, L, self.num_kv_heads, self.head_dim)
        v = self.v_proj(hidden_states).reshape(B, L, self.num_kv_heads, self.head_dim)
        # RoPE is applied before the head transpose, so the [L, D] tables
        # broadcast as [1, L, 1, D] over the [B, L, H, D] layout.
        cos, sin = self.rotary_emb(q, L)
        cos = cos[None, :, None, :]
        sin = sin[None, :, None, :]
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)
        # Move heads in front of the sequence axis: [B, H, L, D].
        q = mx.transpose(q, (0, 2, 1, 3))
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))
        # Grouped-query attention: tile each KV head across its query group.
        # [B, KV_H, L, D] -> [B, KV_H, group, L, D] -> [B, H, L, D]
        if self.num_kv_heads != self.num_heads:
            group = self.num_heads // self.num_kv_heads
            k = mx.repeat(k[:, :, None, :, :], group, axis=2)
            k = k.reshape(B, self.num_heads, L, self.head_dim)
            v = mx.repeat(v[:, :, None, :, :], group, axis=2)
            v = v.reshape(B, self.num_heads, L, self.head_dim)
        # Scaled dot-product attention scores: [B, H, L, L].
        scores = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2)))
        scores = scores / math.sqrt(self.head_dim)
        # The mask is additive (e.g. large negatives on disallowed positions).
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = mx.softmax(scores, axis=-1)
        context = mx.matmul(probs, v)
        # Merge heads back to [B, L, H*D] and project to the model width.
        context = mx.transpose(context, (0, 2, 1, 3))
        context = context.reshape(B, L, self.num_heads * self.head_dim)
        return self.o_proj(context)
class MLP(nn.Module):
    """Gated feed-forward network: down(act(gate(x)) * up(x)).

    SwiGLU-shaped MLP; the activation is selectable ("gelu" or "silu"),
    and any unrecognized name falls back to GELU.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        activation: str = "gelu",
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # Only "silu" selects SiLU; every other value (including "gelu")
        # resolves to GELU — same mapping as the original if/elif/else chain.
        self.act_fn = nn.silu if activation == "silu" else nn.gelu

    def __call__(self, x):
        # Gated projection: activation gates the parallel "up" branch.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer.

    Structure: RMSNorm -> self-attention -> residual add, then
    RMSNorm -> gated MLP -> residual add.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        intermediate_size: int,
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.self_attn = Attention(hidden_size, num_heads, num_kv_heads, head_dim)
        self.mlp = MLP(hidden_size, intermediate_size)
        self.input_layernorm = RMSNorm(hidden_size)
        self.post_attention_layernorm = RMSNorm(hidden_size)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        # Attention sub-block with residual connection.
        attn_out = self.self_attn(self.input_layernorm(hidden_states), attention_mask)
        hidden_states = hidden_states + attn_out
        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
class VisionEncoder(nn.Module):
    """Simplified ViT-style vision encoder (MLX).

    Splits the image into non-overlapping patches with a strided conv,
    adds learned absolute position embeddings, and runs a stack of
    transformer layers followed by a final RMSNorm.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # Patch embedding: kernel == stride == patch_size.
        # NOTE: MLX's nn.Conv2d operates on channels-last (NHWC) tensors.
        self.patch_embed = nn.Conv2d(
            in_channels=3,
            out_channels=hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # Learned absolute position embedding, one vector per patch.
        self.position_embedding = mx.zeros((1, self.num_patches, hidden_size))
        # Transformer stack (reuses the decoder layer as a plain encoder block).
        self.layers = [
            DecoderLayer(hidden_size, num_heads, hidden_size * 4)
            for _ in range(num_layers)
        ]
        self.norm = RMSNorm(hidden_size)

    def __call__(self, pixel_values: mx.array) -> mx.array:
        """Encode a batch of images into patch features.

        Args:
            pixel_values: [B, 3, H, W] (channels-first, the originally
                documented interface) or [B, H, W, 3] (MLX-native NHWC).
        Returns:
            Patch features of shape [B, num_patches, hidden_size].
        """
        batch_size = pixel_values.shape[0]
        # BUG FIX: MLX's nn.Conv2d expects NHWC input and returns NHWC
        # output. The original code fed NCHW and then reshaped the conv
        # result as if it were channels-first, which would either fail on
        # the channel check or scramble the patch features.
        if pixel_values.shape[-1] != 3:
            # Channels-first [B, C, H, W] -> channels-last [B, H, W, C].
            pixel_values = mx.transpose(pixel_values, (0, 2, 3, 1))
        x = self.patch_embed(pixel_values)  # [B, H/P, W/P, hidden_size]
        # Flatten the spatial grid into a patch sequence.
        x = x.reshape(batch_size, -1, x.shape[-1])  # [B, num_patches, hidden_size]
        # Add position information, then run the transformer stack.
        x = x + self.position_embedding
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
# Smoke tests: run as a script to shape-check each component on random input.
if __name__ == "__main__":
    print("🧪 测试 MLX 基础组件...")
    # RMSNorm: output shape must equal input shape.
    print("\n1. 测试 RMSNorm...")
    norm = RMSNorm(hidden_size=768)
    x = mx.random.normal((2, 10, 768))
    out = norm(x)
    print(f" 输入形状: {x.shape}, 输出形状: {out.shape}")
    # Attention: self-attention preserves [batch, seq, hidden] shape.
    print("\n2. 测试 Attention...")
    attn = Attention(hidden_size=768, num_heads=12)
    x = mx.random.normal((2, 10, 768))
    out = attn(x)
    print(f" 输入形状: {x.shape}, 输出形状: {out.shape}")
    # MLP: gated feed-forward maps back to the hidden size.
    print("\n3. 测试 MLP...")
    mlp = MLP(hidden_size=768, intermediate_size=3072)
    x = mx.random.normal((2, 10, 768))
    out = mlp(x)
    print(f" 输入形状: {x.shape}, 输出形状: {out.shape}")
    # DecoderLayer: full pre-norm block preserves shape end to end.
    print("\n4. 测试 DecoderLayer...")
    layer = DecoderLayer(hidden_size=768, num_heads=12, intermediate_size=3072)
    x = mx.random.normal((2, 10, 768))
    out = layer(x)
    print(f" 输入形状: {x.shape}, 输出形状: {out.shape}")
    print("\n✅ 所有基础组件测试通过!")