#!/opt/homebrew/bin/python3
"""
PaddleOCR-VL MLX implementation - Stage 1: basic components

This file implements the basic MLX building blocks of the PaddleOCR-VL model.

Author: AI Assistant
Date: 2024-12-24
"""

import math
from typing import Optional

import mlx.core as mx
import mlx.nn as nn


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (MLX version)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps

    def __call__(self, hidden_states):
        variance = mx.mean(mx.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE)."""

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Precompute the inverse frequencies. The leading underscore keeps the
        # array out of the module's trainable parameters in MLX.
        self._inv_freq = 1.0 / (
            self.base ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )

    def __call__(self, x, seq_len: int):
        # Position indices
        t = mx.arange(seq_len, dtype=mx.float32)
        # Angle for each (position, frequency) pair
        freqs = mx.outer(t, self._inv_freq)
        # Duplicate the angles so cos/sin cover the full head dimension
        emb = mx.concatenate([freqs, freqs], axis=-1)
        return mx.cos(emb), mx.sin(emb)


def rotate_half(x):
    """Rotate the second half of the last dimension into the first."""
    x1, x2 = mx.split(x, 2, axis=-1)
    return mx.concatenate([-x2, x1], axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors."""
    # q, k shape: [B, H, L, D]; cos, sin shape: [L, D].
    # Reshape to [1, 1, L, D] so they broadcast over batch and heads.
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
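
# Illustrative only, not part of the original model code: a minimal sanity
# check for the RoPE helpers above. RoPE rotates pairs of channels by a
# position-dependent angle, so it must preserve each head vector's norm.
# The helper name `_rope_sanity_check` is an assumption for demonstration.
def _rope_sanity_check():
    rope = RotaryEmbedding(dim=64)
    q = mx.random.normal((1, 4, 8, 64))  # [B, H, L, D]
    k = mx.random.normal((1, 4, 8, 64))
    cos, sin = rope(q, seq_len=8)
    q_rot, _ = apply_rotary_pos_emb(q, k, cos, sin)
    # Per-head-vector norms should match up to floating-point error
    assert mx.allclose(
        mx.linalg.norm(q, axis=-1), mx.linalg.norm(q_rot, axis=-1), atol=1e-4
    ).item()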
class Attention(nn.Module):
    """Multi-head attention (MLX version)."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        # Use the explicitly specified head_dim, or derive it from hidden_size
        self.head_dim = head_dim if head_dim is not None else (hidden_size // num_heads)

        # Q, K, V projections
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        batch_size, seq_length, _ = hidden_states.shape

        # Project to Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape into heads
        query_states = query_states.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        key_states = key_states.reshape(batch_size, seq_length, self.num_kv_heads, self.head_dim)
        value_states = value_states.reshape(batch_size, seq_length, self.num_kv_heads, self.head_dim)

        # Apply rotary position embeddings (before transposing).
        # cos, sin shape: [seq_length, head_dim]; reshape to
        # [1, seq_length, 1, head_dim] for broadcasting.
        cos, sin = self.rotary_emb(query_states, seq_length)
        cos = cos[None, :, None, :]
        sin = sin[None, :, None, :]
        query_states = (query_states * cos) + (rotate_half(query_states) * sin)
        key_states = (key_states * cos) + (rotate_half(key_states) * sin)

        # Transpose for attention computation
        query_states = mx.transpose(query_states, (0, 2, 1, 3))  # [B, H, L, D]
        key_states = mx.transpose(key_states, (0, 2, 1, 3))  # [B, KV_H, L, D]
        value_states = mx.transpose(value_states, (0, 2, 1, 3))  # [B, KV_H, L, D]

        # Grouped Query Attention: repeat key and value states
        if self.num_kv_heads != self.num_heads:
            # How many query heads share each KV head
            num_repeats = self.num_heads // self.num_kv_heads
            # [B, KV_H, L, D] -> [B, KV_H, num_repeats, L, D] -> [B, H, L, D]
            key_states = mx.repeat(key_states[:, :, None, :, :], num_repeats, axis=2)
            key_states = key_states.reshape(batch_size, self.num_heads, seq_length, self.head_dim)
            value_states = mx.repeat(value_states[:, :, None, :, :], num_repeats, axis=2)
            value_states = value_states.reshape(batch_size, self.num_heads, seq_length, self.head_dim)

        # Scaled dot-product attention scores
        attn_weights = mx.matmul(query_states, mx.transpose(key_states, (0, 1, 3, 2)))
        attn_weights = attn_weights / math.sqrt(self.head_dim)

        # Apply the (additive) attention mask
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax over the key dimension
        attn_weights = mx.softmax(attn_weights, axis=-1)

        # Weighted sum of values
        attn_output = mx.matmul(attn_weights, value_states)

        # Reshape back to [B, L, hidden]
        attn_output = mx.transpose(attn_output, (0, 2, 1, 3))
        attn_output = attn_output.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

        # Output projection
        return self.o_proj(attn_output)


class MLP(nn.Module):
    """Feed-forward network (MLX version)."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        activation: str = "gelu",
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        # Activation function (unknown names fall back to GELU)
        if activation == "gelu":
            self.act_fn = nn.gelu
        elif activation == "silu":
            self.act_fn = nn.silu
        else:
            self.act_fn = nn.gelu

    def __call__(self, x):
        # SwiGLU-style gated MLP
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class DecoderLayer(nn.Module):
    """Transformer decoder layer."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        intermediate_size: int,
        num_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.self_attn = Attention(hidden_size, num_heads, num_kv_heads, head_dim)
        self.mlp = MLP(hidden_size, intermediate_size)
        self.input_layernorm = RMSNorm(hidden_size)
        self.post_attention_layernorm = RMSNorm(hidden_size)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        # Self-attention with pre-norm and residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask)
        hidden_states = residual + hidden_states

        # MLP with pre-norm and residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
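
# Illustrative only: `attention_mask` above is an *additive* mask (it is
# summed with the pre-softmax scores), so a causal mask places a large
# negative value above the diagonal. This helper is an assumption of how
# such a mask would be built; it is not part of the original model code.
#
# Usage: out = layer(x, attention_mask=_make_causal_mask(x.shape[1]))
def _make_causal_mask(seq_length: int, dtype=mx.float32) -> mx.array:
    # Future positions (upper triangle) get -inf, everything else 0
    mask = mx.triu(mx.full((seq_length, seq_length), float("-inf"), dtype=dtype), k=1)
    # Shape [1, 1, L, L], broadcastable over [B, H, L, L] attention scores
    return mask[None, None, :, :]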
class VisionEncoder(nn.Module):
    """Vision encoder - simplified version (reuses DecoderLayer, including RoPE)."""

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 14,
        hidden_size: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Patch embedding. Note: MLX convolutions use channels-last (NHWC) layout.
        self.patch_embed = nn.Conv2d(
            in_channels=3,
            out_channels=hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Position embedding
        self.position_embedding = mx.zeros((1, self.num_patches, hidden_size))

        # Transformer layers
        self.layers = [
            DecoderLayer(hidden_size, num_heads, hidden_size * 4)
            for _ in range(num_layers)
        ]

        self.norm = RMSNorm(hidden_size)

    def __call__(self, pixel_values: mx.array) -> mx.array:
        # pixel_values: [B, C, H, W] (PyTorch-style). MLX's Conv2d expects
        # channels-last input, so convert to [B, H, W, C] first.
        batch_size = pixel_values.shape[0]
        pixel_values = mx.transpose(pixel_values, (0, 2, 3, 1))

        # Patch embedding
        x = self.patch_embed(pixel_values)  # [B, H/P, W/P, hidden_size]

        # Flatten the spatial grid into a sequence
        x = x.reshape(batch_size, -1, x.shape[-1])  # [B, num_patches, hidden_size]

        # Add position embeddings
        x = x + self.position_embedding

        # Transformer layers
        for layer in self.layers:
            x = layer(x)

        # Final normalization
        return self.norm(x)


# Test code
if __name__ == "__main__":
    print("🧪 Testing MLX basic components...")

    # Test RMSNorm
    print("\n1. Testing RMSNorm...")
    norm = RMSNorm(hidden_size=768)
    x = mx.random.normal((2, 10, 768))
    out = norm(x)
    print(f"   input shape: {x.shape}, output shape: {out.shape}")

    # Test Attention
    print("\n2. Testing Attention...")
    attn = Attention(hidden_size=768, num_heads=12)
    x = mx.random.normal((2, 10, 768))
    out = attn(x)
    print(f"   input shape: {x.shape}, output shape: {out.shape}")

    # Test MLP
    print("\n3. Testing MLP...")
    mlp = MLP(hidden_size=768, intermediate_size=3072)
    x = mx.random.normal((2, 10, 768))
    out = mlp(x)
    print(f"   input shape: {x.shape}, output shape: {out.shape}")

    # Test DecoderLayer
    print("\n4. Testing DecoderLayer...")
    layer = DecoderLayer(hidden_size=768, num_heads=12, intermediate_size=3072)
    x = mx.random.normal((2, 10, 768))
    out = layer(x)
    print(f"   input shape: {x.shape}, output shape: {out.shape}")

    print("\n✅ All basic component tests passed!")
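
    # Illustrative extra check, not in the original test suite: a small
    # VisionEncoder smoke test. The reduced configuration below is an
    # assumption chosen purely to keep the forward pass cheap; the class
    # defaults differ (image_size=224, num_layers=24, hidden_size=1024).
    print("\nExtra: testing VisionEncoder (reduced config)...")
    encoder = VisionEncoder(
        image_size=56, patch_size=14, hidden_size=64, num_layers=2, num_heads=4
    )
    pixels = mx.random.normal((1, 3, 56, 56))  # [B, C, H, W]
    feats = encoder(pixels)
    # Expect [1, (56 // 14) ** 2, 64] == [1, 16, 64]
    print(f"   input shape: {pixels.shape}, output shape: {feats.shape}")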