| | |
| | """ |
| | PaddleOCR-VL MLX 实现 - 第一阶段:基础组件 |
| | |
| | 这个文件实现了 PaddleOCR-VL 模型的基础 MLX 组件 |
| | 作者: AI Assistant |
| | 日期: 2024-12-24 |
| | """ |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | import math |
| | from typing import Optional, Tuple |
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (MLX implementation).

    Scales the last axis by the inverse root of its mean square, then applies
    a learned per-feature gain. Unlike LayerNorm, no mean is subtracted.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Learned gain, initialized to ones (identity scaling).
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps

    def __call__(self, hidden_states):
        mean_square = mx.mean(mx.square(hidden_states), axis=-1, keepdims=True)
        normalized = hidden_states * mx.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized
| |
|
| |
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table generator.

    Produces cos/sin tables of shape (seq_len, dim); the actual rotation is
    applied by the caller (see apply_rotary_pos_emb / Attention).
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies for the even feature indices: base^(-2i/dim).
        exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
        self.inv_freq = 1.0 / (base ** exponents)

    def __call__(self, x, seq_len: int):
        # Angles = positions x inverse frequencies -> (seq_len, dim/2).
        positions = mx.arange(seq_len, dtype=mx.float32)
        angles = mx.outer(positions, self.inv_freq)

        # Duplicate along the feature axis so the table spans the full dim.
        table = mx.concatenate([angles, angles], axis=-1)
        return mx.cos(table), mx.sin(table)
| |
|
| |
|
def rotate_half(x):
    """Swap the two halves of the last axis, negating the second: (a, b) -> (-b, a)."""
    first_half, second_half = mx.split(x, 2, axis=-1)
    return mx.concatenate([-second_half, first_half], axis=-1)
| |
|
| |
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query/key tensors with precomputed cos/sin tables.

    Assumes q and k are laid out as (batch, heads, seq, head_dim) and
    cos/sin as (seq, head_dim); the tables broadcast over batch and heads.
    """
    cos_b = mx.expand_dims(mx.expand_dims(cos, 0), 0)
    sin_b = mx.expand_dims(mx.expand_dims(sin, 0), 0)
    return (
        q * cos_b + rotate_half(q) * sin_b,
        k * cos_b + rotate_half(k) * sin_b,
    )
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention with optional grouped-query KV heads (MLX).

    When num_kv_heads < num_heads, key/value heads are tiled so every query
    head attends over a full-size KV set (grouped-query attention).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Default to full multi-head attention when no KV head count is given.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.head_dim = hidden_size // num_heads if head_dim is None else head_dim

        # Bias-free projections, matching common LLM attention layouts.
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        batch, seq_len, _ = hidden_states.shape

        # Project and split into heads: (B, S, H*D) -> (B, H, S, D).
        q = self.q_proj(hidden_states).reshape(batch, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).reshape(batch, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(hidden_states).reshape(batch, seq_len, self.num_kv_heads, self.head_dim)
        q = mx.transpose(q, (0, 2, 1, 3))
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

        # Rotary position encoding; tables broadcast over batch and head axes.
        cos, sin = self.rotary_emb(q, seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Grouped-query attention: tile each KV head to cover its query group.
        if self.num_kv_heads != self.num_heads:
            repeats = self.num_heads // self.num_kv_heads
            k = mx.repeat(k, repeats, axis=1)
            v = mx.repeat(v, repeats, axis=1)

        # Scaled dot-product attention.
        scores = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # Additive mask (e.g. -inf at disallowed positions).
            scores = scores + attention_mask
        probs = mx.softmax(scores, axis=-1)
        context = mx.matmul(probs, v)

        # Merge heads back: (B, H, S, D) -> (B, S, H*D), then project out.
        context = mx.transpose(context, (0, 2, 1, 3))
        context = context.reshape(batch, seq_len, self.num_heads * self.head_dim)
        return self.o_proj(context)
| |
|
| |
|
class MLP(nn.Module):
    """Gated feed-forward network (GLU-style) used in transformer blocks.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    Args:
        hidden_size: Model embedding width.
        intermediate_size: Inner (expanded) width of the gated projection.
        activation: Gate activation name, one of ``"gelu"`` or ``"silu"``.

    Raises:
        ValueError: If ``activation`` is not a supported name. (The previous
            implementation silently fell back to GELU, which hid typos such
            as ``"slu"`` and could change model behavior unnoticed.)
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        activation: str = "gelu",
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        if activation == "gelu":
            self.act_fn = nn.gelu
        elif activation == "silu":
            self.act_fn = nn.silu
        else:
            # Fail loudly instead of silently defaulting to GELU.
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'gelu' or 'silu'"
            )

    def __call__(self, x):
        # Gated projection: activation applies to the gate branch only.
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
| |
|
| |
|
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention then gated MLP, each with a residual."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        intermediate_size: int,
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.self_attn = Attention(hidden_size, num_heads, num_kv_heads, head_dim)
        self.mlp = MLP(hidden_size, intermediate_size)
        self.input_layernorm = RMSNorm(hidden_size)
        self.post_attention_layernorm = RMSNorm(hidden_size)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        # Attention sub-block: normalize, attend, add residual.
        attn_out = self.self_attn(self.input_layernorm(hidden_states), attention_mask)
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: normalize, transform, add residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
| |
|
| |
|
class VisionEncoder(nn.Module):
    """Simplified ViT-style vision encoder.

    Patchifies the image with a strided convolution, adds learned position
    embeddings, and runs a stack of transformer blocks.

    Note: MLX ``nn.Conv2d`` uses the channels-last (NHWC) layout, so
    ``pixel_values`` must be shaped (batch, height, width, 3) and the conv
    output is (batch, H/ps, W/ps, hidden_size).

    Args:
        image_size: Input image side length in pixels (square images assumed).
        patch_size: Side length of each square patch; must divide image_size.
        hidden_size: Transformer embedding width.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_patches = (image_size // patch_size) ** 2

        # Non-overlapping patch embedding: one conv tap per patch.
        self.patch_embed = nn.Conv2d(
            in_channels=3,
            out_channels=hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Learned absolute position embeddings, one per patch.
        self.position_embedding = mx.zeros((1, self.num_patches, hidden_size))

        self.layers = [
            DecoderLayer(hidden_size, num_heads, hidden_size * 4)
            for _ in range(num_layers)
        ]

        self.norm = RMSNorm(hidden_size)

    def __call__(self, pixel_values: mx.array) -> mx.array:
        batch_size = pixel_values.shape[0]

        # (B, H, W, 3) -> (B, H/ps, W/ps, hidden) under MLX's NHWC convention.
        x = self.patch_embed(pixel_values)

        # Flatten the spatial grid into a patch sequence: (B, num_patches, hidden).
        # BUGFIX: the previous reshape/transpose treated the conv output as
        # NCHW (B, C, H, W), which scrambles features under MLX's NHWC layout.
        x = x.reshape(batch_size, -1, self.hidden_size)

        # Add position information.
        x = x + self.position_embedding

        # Transformer stack.
        for layer in self.layers:
            x = layer(x)

        # Final normalization.
        return self.norm(x)
| |
|
| |
|
| | |
| | if __name__ == "__main__": |
| | print("🧪 测试 MLX 基础组件...") |
| | |
| | |
| | print("\n1. 测试 RMSNorm...") |
| | norm = RMSNorm(hidden_size=768) |
| | x = mx.random.normal((2, 10, 768)) |
| | out = norm(x) |
| | print(f" 输入形状: {x.shape}, 输出形状: {out.shape}") |
| | |
| | |
| | print("\n2. 测试 Attention...") |
| | attn = Attention(hidden_size=768, num_heads=12) |
| | x = mx.random.normal((2, 10, 768)) |
| | out = attn(x) |
| | print(f" 输入形状: {x.shape}, 输出形状: {out.shape}") |
| | |
| | |
| | print("\n3. 测试 MLP...") |
| | mlp = MLP(hidden_size=768, intermediate_size=3072) |
| | x = mx.random.normal((2, 10, 768)) |
| | out = mlp(x) |
| | print(f" 输入形状: {x.shape}, 输出形状: {out.shape}") |
| | |
| | |
| | print("\n4. 测试 DecoderLayer...") |
| | layer = DecoderLayer(hidden_size=768, num_heads=12, intermediate_size=3072) |
| | x = mx.random.normal((2, 10, 768)) |
| | out = layer(x) |
| | print(f" 输入形状: {x.shape}, 输出形状: {out.shape}") |
| | |
| | print("\n✅ 所有基础组件测试通过!") |
| |
|