| | |
| | """ |
| | PaddleOCR-VL MLX 最终优化版 - 使用正确的图像预处理 |
| | 目标:达到原版准确度 80-90% |
| | |
| | 作者: AI Assistant |
| | 日期: 2024-12-25 |
| | 版本: v8.0 - 最终优化 |
| | """ |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | from PIL import Image, ImageDraw |
| | import numpy as np |
| | import json |
| | from pathlib import Path |
| | from typing import Optional, List, Tuple |
| | import time |
| | import torch |
| |
|
| | |
| | from mlx_components import ( |
| | RMSNorm, MLP, DecoderLayer |
| | ) |
| |
|
| |
|
def spatial_merge_mlx(x: mx.array, t: int, h: int, w: int, d: int, m1: int = 2, m2: int = 2) -> mx.array:
    """MLX spatial merge: fold every m1 x m2 patch neighbourhood into the
    channel dimension.

    Args:
        x: input of shape (t*h*w, d).
        t, h, w: the image_grid_thw values.
        d: feature dimension (1152).
        m1, m2: merge kernel size (2, 2).

    Returns:
        Array of shape (t * h/m1 * w/m2, m1*m2*d).
    """
    grid_h, grid_w = h // m1, w // m2

    # Expose the merge windows as explicit axes ...
    windows = x.reshape(t, grid_h, m1, grid_w, m2, d)

    # ... move both window axes next to the channel axis ...
    windows = mx.transpose(windows, (0, 1, 3, 2, 4, 5))

    # ... and collapse each window into the channel dimension.
    return windows.reshape(t * grid_h * grid_w, m1 * m2 * d)
| |
|
| |
|
class VisionHeadAttention(nn.Module):
    """Multi-head self-attention block used inside the vision head."""

    def __init__(self, hidden_size: int = 1152):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = 16
        self.head_dim = hidden_size // self.num_heads

        # Fused QKV input projection and the output projection.
        self.in_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        batch, seq, dim = x.shape
        scale = self.head_dim ** 0.5

        # Project once, then split into per-head Q/K/V of shape
        # (batch, num_heads, seq, head_dim).
        projected = self.in_proj(x).reshape(batch, seq, 3, self.num_heads, self.head_dim)
        projected = mx.transpose(projected, (2, 0, 3, 1, 4))
        queries, keys, values = projected[0], projected[1], projected[2]

        # Scaled dot-product attention (no mask: full bidirectional).
        scores = mx.matmul(queries, mx.transpose(keys, (0, 1, 3, 2))) / scale
        weights = mx.softmax(scores, axis=-1)
        context = mx.matmul(weights, values)

        # Merge the heads back and apply the output projection.
        context = mx.transpose(context, (0, 2, 1, 3)).reshape(batch, seq, dim)
        return self.out_proj(context)
| |
|
| |
|
| | import math |
| |
|
def rotate_half(x):
    """Rotate the last dimension by half: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    return mx.concatenate([-x[..., half:], x[..., :half]], axis=-1)
| |
|
def apply_mrope(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    Returns the rotated (q, k) pair using the standard RoPE identity
    x' = x*cos + rotate_half(x)*sin.
    """
    rotated = tuple(
        tensor * cos + rotate_half(tensor) * sin
        for tensor in (q, k)
    )
    return rotated
| |
|
def get_3d_rope_index(input_ids, image_grid_thw, image_token_id=100295):
    """Build (t, h, w) position-id arrays for 3D multimodal RoPE.

    Image placeholder tokens receive their 2D coordinates on the
    (2x2-merged) vision grid; text tokens receive a single shared
    position, identical on all three axes, counting up from
    max(h, w) + 1.  Without an image, every axis falls back to the
    plain 1D sequence position.

    Args:
        input_ids: (batch, seq_len) token ids.
        image_grid_thw: (num_images, 3) grid sizes, or None.
        image_token_id: id of the image placeholder token.

    Returns:
        Three (batch, seq_len) mx.array position-id tensors (t, h, w).
    """
    batch_size, seq_len = input_ids.shape

    t_ids = np.zeros((batch_size, seq_len), dtype=np.int32)
    h_ids = np.zeros((batch_size, seq_len), dtype=np.int32)
    w_ids = np.zeros((batch_size, seq_len), dtype=np.int32)

    for b in range(batch_size):
        ids = np.array(input_ids[b])
        img_pos = np.where(ids == image_token_id)[0]

        if len(img_pos) > 0 and image_grid_thw is not None:
            # Grid after the 2x2 spatial merge.  Only the first image's
            # grid is read — assumes one image per sequence (TODO confirm).
            t = int(image_grid_thw[0, 0])
            h = int(image_grid_thw[0, 1]) // 2
            w = int(image_grid_thw[0, 2]) // 2

            # Row-major (h, w) coordinates for each image position.
            hh, ww = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
            hh = hh.flatten()
            ww = ww.flatten()

            limit = min(len(img_pos), len(hh))
            h_ids[b, img_pos[:limit]] = hh[:limit]
            w_ids[b, img_pos[:limit]] = ww[:limit]

            # Text tokens continue after the largest grid coordinate.
            # NOTE(review): the `i > 0` guard leaves position 0 (BOS) at
            # (0, 0, 0), and ANY non-image token after it — even one
            # preceding the image — gets a post-image position.  This
            # matches the [BOS][image...][text] layout built elsewhere;
            # confirm before using other prompt layouts.
            max_pos = max(h, w)
            curr = max_pos + 1
            for i in range(seq_len):
                if ids[i] != image_token_id and i > 0:
                    t_ids[b, i] = curr
                    h_ids[b, i] = curr
                    w_ids[b, i] = curr
                    curr += 1
        else:
            # Text-only: ordinary 1D positions on all three axes.
            seq = np.arange(seq_len)
            t_ids[b] = seq
            h_ids[b] = seq
            w_ids[b] = seq

    return mx.array(t_ids), mx.array(h_ids), mx.array(w_ids)
class VisionHead(nn.Module):
    """Vision head: a residual attention block followed by a
    layer-normed residual MLP block."""

    def __init__(self, hidden_size: int = 1152):
        super().__init__()
        self.attention = VisionHeadAttention(hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size)
        self.mlp = MLP(hidden_size, 4304)
        # Learned probe token (populated from the checkpoint; not used
        # by this forward pass).
        self.probe = mx.zeros((1, 1, hidden_size))

    def __call__(self, x: mx.array) -> mx.array:
        # Attention block with a residual connection.
        x = x + self.attention(x)

        # Normalize once, then the MLP block with its residual.
        normed = self.layernorm(x)
        return normed + self.mlp(normed)
| |
|
| |
|
class FinalOptimizedPaddleOCRMLX:
    """Final optimized PaddleOCR-VL pipeline on MLX.

    Combines the original HF tokenizer/processor (for byte-exact image
    preprocessing) with an MLX re-implementation of the vision tower,
    the mlp_AR projector and the language model, plus .npz checkpoint
    loading and a simple autoregressive decoding loop.
    """

    def __init__(self, model_dir: str):
        self.model_dir = Path(model_dir)
        print("🚀 初始化最终优化版 PaddleOCR MLX...")
        print(f"📂 模型目录: {model_dir}")

        # Hyper-parameters from config.json.
        self.config = self._load_config()

        # Original HF tokenizer/processor (best effort; may be None).
        self.tokenizer = self._load_tokenizer()
        self.processor = self._load_processor()

        # Build the module graph, then populate it with weights.
        self.model = self._create_model()
        self._load_all_weights()

        print("✅ 初始化完成!")

    def _load_config(self) -> dict:
        """Load and return the model's config.json as a dict."""
        config_path = self.model_dir / "config.json"
        with open(config_path, 'r') as f:
            config = json.load(f)
        print(f"✅ 配置加载完成")
        return config

    def _load_tokenizer(self):
        """Load the HF tokenizer; returns None on failure (best effort)."""
        try:
            from transformers import AutoTokenizer
            original_model_path = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
            tokenizer = AutoTokenizer.from_pretrained(
                original_model_path,
                trust_remote_code=True
            )
            print(f"✅ Tokenizer 加载完成 (词汇表: {len(tokenizer)})")
            return tokenizer
        except Exception as e:
            print(f"⚠️ Tokenizer 加载失败: {e}")
            return None

    def _load_processor(self):
        """Load the HF image/text processor; returns None on failure."""
        try:
            from transformers import AutoProcessor
            original_model_path = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
            processor = AutoProcessor.from_pretrained(
                original_model_path,
                trust_remote_code=True
            )
            print(f"✅ Processor 加载完成 ⭐ 关键改进")
            return processor
        except Exception as e:
            print(f"⚠️ Processor 加载失败: {e}")
            return None

    def _create_model(self):
        """Create the full vision + language model as a nested MLX module."""
        print("🔄 创建完整模型...")

        class OptimizedModel(nn.Module):
            """MLX vision tower + mlp_AR projector + language model."""

            def __init__(self, config):
                super().__init__()
                self.config = config

                # Language-model hyper-parameters.
                self.hidden_size = config.get('hidden_size', 1024)
                self.vocab_size = config.get('vocab_size', 103424)
                self.intermediate_size = config.get('intermediate_size', 3072)
                self.num_attention_heads = config.get('num_attention_heads', 16)
                self.num_kv_heads = config.get('num_key_value_heads', 2)
                self.num_hidden_layers = config.get('num_hidden_layers', 18)
                self.head_dim = config.get('head_dim', 128)

                # Vision-tower hyper-parameters.
                vision_config = config.get('vision_config', {})
                self.vision_hidden_size = vision_config.get('hidden_size', 1152)
                self.vision_num_layers = 27

                # 14x14 patch embedding (MLX Conv2d takes NHWC input).
                self.patch_embedding = nn.Conv2d(
                    in_channels=3,
                    out_channels=self.vision_hidden_size,
                    kernel_size=14,
                    stride=14,
                    bias=True
                )

                # Learned position table for up to 729 (27x27) patches.
                self.position_embedding = mx.zeros((729, self.vision_hidden_size))

                # Packing position table (loaded from the checkpoint;
                # not consumed by this implementation).
                self.packing_position_embedding = mx.zeros((32768, self.vision_hidden_size))

                # Vision transformer encoder stack.
                self.vision_layers = [
                    DecoderLayer(
                        hidden_size=self.vision_hidden_size,
                        num_heads=16,
                        intermediate_size=4304,
                        num_kv_heads=16,
                        head_dim=72,
                    )
                    for _ in range(self.vision_num_layers)
                ]

                self.vision_norm = RMSNorm(self.vision_hidden_size)

                self.vision_head = VisionHead(self.vision_hidden_size)

                self.post_layernorm = nn.LayerNorm(self.vision_hidden_size)

                # mlp_AR projector: after the 2x2 spatial merge the vision
                # features are 4 * 1152 = 4608 wide.
                self.vision_pre_norm = nn.LayerNorm(self.vision_hidden_size)
                self.vision_linear_1 = nn.Linear(4608, 4608, bias=True)
                self.vision_linear_2 = nn.Linear(4608, self.hidden_size, bias=True)

                # Language model.
                self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)

                self.layers = [
                    DecoderLayer(
                        hidden_size=self.hidden_size,
                        num_heads=self.num_attention_heads,
                        intermediate_size=self.intermediate_size,
                        num_kv_heads=self.num_kv_heads,
                        head_dim=self.head_dim,
                    )
                    for _ in range(self.num_hidden_layers)
                ]

                self.norm = RMSNorm(self.hidden_size, eps=config.get('rms_norm_eps', 1e-6))

                self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)

            def encode_image(self, pixel_values: mx.array, image_grid_thw: mx.array) -> mx.array:
                """Encode preprocessed patches into LM-space embeddings.

                Args:
                    pixel_values: (num_patches, H, W, C) NHWC patch tensor.
                    image_grid_thw: (1, 3) grid of (t, h, w) patch counts.

                Returns:
                    (1, t*(h//2)*(w//2), hidden_size) vision embeddings.
                """
                num_patches, H, W, C = pixel_values.shape

                t = int(image_grid_thw[0, 0])
                h = int(image_grid_thw[0, 1])
                w = int(image_grid_thw[0, 2])
                print(f" Grid info: t={t}, h={h}, w={w}, num_patches={num_patches}")

                # Each 14x14 patch collapses to one feature vector.
                x = self.patch_embedding(pixel_values)
                x = x.reshape(num_patches, self.vision_hidden_size)
                x = mx.expand_dims(x, 0)

                # Learned position embeddings, tiled when the image has
                # more patches than the 729-entry table.
                if num_patches <= 729:
                    x = x + self.position_embedding[:num_patches, :]
                else:
                    pos_emb_repeated = mx.tile(self.position_embedding, (num_patches // 729 + 1, 1))
                    x = x + pos_emb_repeated[:num_patches, :]

                # Vision transformer stack (no attention mask).
                for layer in self.vision_layers:
                    x = layer(x, None)

                x = self.vision_norm(x)

                x = self.vision_head(x)

                x = self.post_layernorm(x)

                # Projector: pre-norm, drop batch dim, 2x2 spatial merge,
                # then a 2-layer GELU MLP into LM space.
                x = self.vision_pre_norm(x)
                x = x[0]

                x = spatial_merge_mlx(x, t, h, w, self.vision_hidden_size, m1=2, m2=2)
                print(f" After spatial merge: {x.shape}")

                x = self.vision_linear_1(x)
                x = nn.gelu(x)
                x = self.vision_linear_2(x)

                # Restore the batch dimension.
                x = mx.expand_dims(x, 0)

                return x

            def forward(self, input_ids: mx.array, vision_embeds: Optional[mx.array] = None) -> mx.array:
                """Run the LM, splicing vision embeddings into the
                image-placeholder token positions.

                Returns logits of shape (batch, seq_len, vocab_size).
                """
                text_embeds = self.embed_tokens(input_ids)

                if vision_embeds is not None:
                    image_token_id = 100295
                    batch_size = input_ids.shape[0]

                    # Positions occupied by image placeholder tokens.
                    mask = (input_ids == image_token_id)

                    vision_embeds_flat = vision_embeds[0]

                    # Scatter the vision embeddings into the placeholder
                    # slots (done in numpy for per-position writes).
                    batch_size, seq_len, hidden_size = text_embeds.shape
                    mask_np = np.array(mask)
                    text_embeds_np = np.array(text_embeds)

                    for i in range(batch_size):
                        img_positions = np.where(mask_np[i])[0]
                        for j, pos in enumerate(img_positions):
                            if j < vision_embeds_flat.shape[0]:
                                text_embeds_np[i, pos, :] = np.array(vision_embeds_flat[j])

                    text_embeds = mx.array(text_embeds_np)

                    hidden_states = text_embeds
                else:
                    hidden_states = text_embeds

                # Decoder stack.  NOTE(review): no mask or position
                # encoding is passed — confirm DecoderLayer applies
                # causal masking/RoPE internally.
                for layer in self.layers:
                    hidden_states = layer(hidden_states, None)

                hidden_states = self.norm(hidden_states)
                logits = self.lm_head(hidden_states)

                return logits

        model = OptimizedModel(self.config)
        print("✅ 优化模型创建完成")
        return model

    def _load_all_weights(self):
        """Copy every tensor from the .npz checkpoint into the model."""
        print("\n" + "="*60)
        print("🔄 加载所有权重...")
        print("="*60)

        weights_path = self.model_dir / "paddleocr_vl_mlx.npz"
        weights = mx.load(str(weights_path))
        print(f"\n📂 加载了 {len(weights)} 个权重张量")

        loaded_count = 0

        try:
            # --- Vision encoder ------------------------------------------------
            print(f"\n📸 加载视觉编码器权重...")

            # Conv weight: checkpoint OIHW -> MLX Conv2d OHWI.
            if 'visual.vision_model.embeddings.patch_embedding.weight' in weights:
                w = weights['visual.vision_model.embeddings.patch_embedding.weight']
                w_transposed = mx.transpose(w, (0, 2, 3, 1))
                self.model.patch_embedding.weight = w_transposed
                loaded_count += 1
            if 'visual.vision_model.embeddings.patch_embedding.bias' in weights:
                self.model.patch_embedding.bias = weights['visual.vision_model.embeddings.patch_embedding.bias']
                loaded_count += 1

            if 'visual.vision_model.embeddings.position_embedding.weight' in weights:
                self.model.position_embedding = weights['visual.vision_model.embeddings.position_embedding.weight']
                loaded_count += 1

            if 'visual.vision_model.embeddings.packing_position_embedding.weight' in weights:
                self.model.packing_position_embedding = weights['visual.vision_model.embeddings.packing_position_embedding.weight']
                loaded_count += 1
                print(f"✅ 加载 packing_position_embedding")

            # Per-layer attention, MLP and norm weights.
            for i in range(27):
                layer = self.model.vision_layers[i]
                prefix = f'visual.vision_model.encoder.layers.{i}'

                for proj_name in ['q_proj', 'k_proj', 'v_proj']:
                    w_key = f'{prefix}.self_attn.{proj_name}.weight'
                    b_key = f'{prefix}.self_attn.{proj_name}.bias'
                    if w_key in weights:
                        proj = getattr(layer.self_attn, proj_name)
                        proj.weight = weights[w_key]
                        if b_key in weights:
                            proj.bias = weights[b_key]
                        loaded_count += 1

                w_key = f'{prefix}.self_attn.out_proj.weight'
                b_key = f'{prefix}.self_attn.out_proj.bias'
                if w_key in weights:
                    layer.self_attn.o_proj.weight = weights[w_key]
                    if b_key in weights:
                        layer.self_attn.o_proj.bias = weights[b_key]
                    loaded_count += 1

                # Checkpoint fc1/fc2 map onto gate_proj/down_proj here.
                if f'{prefix}.mlp.fc1.weight' in weights:
                    layer.mlp.gate_proj.weight = weights[f'{prefix}.mlp.fc1.weight']
                    loaded_count += 1
                if f'{prefix}.mlp.fc2.weight' in weights:
                    layer.mlp.down_proj.weight = weights[f'{prefix}.mlp.fc2.weight']
                    loaded_count += 1

                for norm_name, model_norm in [('layer_norm1', 'input_layernorm'), ('layer_norm2', 'post_attention_layernorm')]:
                    if f'{prefix}.{norm_name}.weight' in weights:
                        getattr(layer, model_norm).weight = weights[f'{prefix}.{norm_name}.weight']
                        loaded_count += 1

            print(f"✅ 视觉编码器权重加载完成 (27 层)")

            # --- Vision head ---------------------------------------------------
            print(f"\n🎯 加载 Vision Head 权重...")

            if 'visual.vision_model.head.attention.in_proj_weight' in weights:
                self.model.vision_head.attention.in_proj.weight = weights['visual.vision_model.head.attention.in_proj_weight']
                loaded_count += 1
            if 'visual.vision_model.head.attention.in_proj_bias' in weights:
                self.model.vision_head.attention.in_proj.bias = weights['visual.vision_model.head.attention.in_proj_bias']
                loaded_count += 1
            if 'visual.vision_model.head.attention.out_proj.weight' in weights:
                self.model.vision_head.attention.out_proj.weight = weights['visual.vision_model.head.attention.out_proj.weight']
                loaded_count += 1
            if 'visual.vision_model.head.attention.out_proj.bias' in weights:
                self.model.vision_head.attention.out_proj.bias = weights['visual.vision_model.head.attention.out_proj.bias']
                loaded_count += 1

            if 'visual.vision_model.head.layernorm.weight' in weights:
                self.model.vision_head.layernorm.weight = weights['visual.vision_model.head.layernorm.weight']
                loaded_count += 1
            if 'visual.vision_model.head.layernorm.bias' in weights:
                self.model.vision_head.layernorm.bias = weights['visual.vision_model.head.layernorm.bias']
                loaded_count += 1

            if 'visual.vision_model.head.mlp.fc1.weight' in weights:
                self.model.vision_head.mlp.gate_proj.weight = weights['visual.vision_model.head.mlp.fc1.weight']
                loaded_count += 1
            if 'visual.vision_model.head.mlp.fc1.bias' in weights:
                self.model.vision_head.mlp.gate_proj.bias = weights['visual.vision_model.head.mlp.fc1.bias']
                loaded_count += 1
            if 'visual.vision_model.head.mlp.fc2.weight' in weights:
                self.model.vision_head.mlp.down_proj.weight = weights['visual.vision_model.head.mlp.fc2.weight']
                loaded_count += 1
            if 'visual.vision_model.head.mlp.fc2.bias' in weights:
                self.model.vision_head.mlp.down_proj.bias = weights['visual.vision_model.head.mlp.fc2.bias']
                loaded_count += 1

            if 'visual.vision_model.head.probe' in weights:
                self.model.vision_head.probe = weights['visual.vision_model.head.probe']
                loaded_count += 1

            print(f"✅ Vision Head 权重加载完成 (11 个)")

            # --- Post layernorm ------------------------------------------------
            print(f"\n🎯 加载 Post LayerNorm 权重...")
            if 'visual.vision_model.post_layernorm.weight' in weights:
                self.model.post_layernorm.weight = weights['visual.vision_model.post_layernorm.weight']
                loaded_count += 1
            if 'visual.vision_model.post_layernorm.bias' in weights:
                self.model.post_layernorm.bias = weights['visual.vision_model.post_layernorm.bias']
                loaded_count += 1
            print(f"✅ Post LayerNorm 权重加载完成 (2 个)")

            # --- Projector (mlp_AR) --------------------------------------------
            print(f"\n🔗 加载视觉投影层 (mlp_AR)...")
            mlp_ar_loaded = 0

            if 'mlp_AR.pre_norm.weight' in weights:
                self.model.vision_pre_norm.weight = weights['mlp_AR.pre_norm.weight']
                mlp_ar_loaded += 1
            if 'mlp_AR.pre_norm.bias' in weights:
                self.model.vision_pre_norm.bias = weights['mlp_AR.pre_norm.bias']
                mlp_ar_loaded += 1
            if 'mlp_AR.linear_1.weight' in weights:
                self.model.vision_linear_1.weight = weights['mlp_AR.linear_1.weight']
                mlp_ar_loaded += 1
            if 'mlp_AR.linear_1.bias' in weights:
                self.model.vision_linear_1.bias = weights['mlp_AR.linear_1.bias']
                mlp_ar_loaded += 1
            if 'mlp_AR.linear_2.weight' in weights:
                self.model.vision_linear_2.weight = weights['mlp_AR.linear_2.weight']
                mlp_ar_loaded += 1
            if 'mlp_AR.linear_2.bias' in weights:
                self.model.vision_linear_2.bias = weights['mlp_AR.linear_2.bias']
                mlp_ar_loaded += 1

            print(f"✅ 视觉投影层加载完成 ({mlp_ar_loaded}/6 个)")
            loaded_count += mlp_ar_loaded

            # --- Language model ------------------------------------------------
            print(f"\n📝 加载语言模型权重...")

            if 'model.embed_tokens.weight' in weights:
                self.model.embed_tokens.weight = weights['model.embed_tokens.weight']
                loaded_count += 1

            for i in range(18):
                layer = self.model.layers[i]
                prefix = f'model.layers.{i}'

                for proj in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
                    if f'{prefix}.self_attn.{proj}.weight' in weights:
                        getattr(layer.self_attn, proj).weight = weights[f'{prefix}.self_attn.{proj}.weight']
                        loaded_count += 1

                for proj in ['gate_proj', 'up_proj', 'down_proj']:
                    if f'{prefix}.mlp.{proj}.weight' in weights:
                        getattr(layer.mlp, proj).weight = weights[f'{prefix}.mlp.{proj}.weight']
                        loaded_count += 1

                for norm in ['input_layernorm', 'post_attention_layernorm']:
                    if f'{prefix}.{norm}.weight' in weights:
                        getattr(layer, norm).weight = weights[f'{prefix}.{norm}.weight']
                        loaded_count += 1

            if 'model.norm.weight' in weights:
                self.model.norm.weight = weights['model.norm.weight']
                loaded_count += 1

            if 'lm_head.weight' in weights:
                self.model.lm_head.weight = weights['lm_head.weight']
                loaded_count += 1

            print(f"✅ 语言模型权重加载完成")
            print(f"\n✅ 总共成功加载 {loaded_count} 个权重")

        except Exception as e:
            # Best-effort load: report and continue with whatever loaded.
            print(f"\n❌ 权重加载失败: {e}")
            import traceback
            traceback.print_exc()

        print("\n" + "="*60)
        print(f"📊 权重加载完成: {loaded_count}/620")
        print("="*60)

    def preprocess_image(self, image_path: str, prompt: str = "Extract all text from this image.") -> Tuple[mx.array, mx.array, mx.array, int]:
        """Preprocess an image + prompt with the original HF processor.

        Returns:
            (pixel_values NHWC, input_ids, image_grid_thw, num_image_features).

        Raises:
            ValueError: if the processor failed to load.
        """
        if self.processor is None:
            raise ValueError("Processor not loaded!")

        image = Image.open(image_path).convert('RGB')

        inputs = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt"
        )

        # torch NCHW -> MLX NHWC.
        pixel_values_torch = inputs['pixel_values']
        pixel_values_np = pixel_values_torch.numpy()
        pixel_values_np = np.transpose(pixel_values_np, (0, 2, 3, 1))
        pixel_values = mx.array(pixel_values_np)

        image_grid_thw = mx.array(inputs['image_grid_thw'].numpy())

        # One LM placeholder token per 2x2-merged vision patch.
        t, h, w = int(image_grid_thw[0, 0]), int(image_grid_thw[0, 1]), int(image_grid_thw[0, 2])
        num_image_features = t * (h // 2) * (w // 2)

        image_token_id = 100295

        if self.tokenizer:
            text_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        else:
            text_tokens = [1, 2, 3, 4, 5]

        # Layout: [BOS][image placeholders ...][prompt tokens].
        # NOTE(review): BOS id and image token id are hard-coded and the
        # processor's own input_ids are not used — confirm this matches
        # the original model's chat template.
        bos_token_id = 1
        input_ids_list = [bos_token_id] + [image_token_id] * num_image_features + text_tokens
        input_ids = mx.array([input_ids_list])

        print(f"✅ 图像预处理完成 (最终修复):")
        print(f" pixel_values: {pixel_values.shape}")
        print(f" num_image_features: {num_image_features}")
        print(f" input_ids: {input_ids.shape}")
        print(f" image_grid_thw: {image_grid_thw.tolist()}")

        return pixel_values, input_ids, image_grid_thw, num_image_features

    def generate(
        self,
        pixel_values: mx.array,
        input_ids: mx.array,
        image_grid_thw: mx.array,
        max_tokens: int = 100,
        temperature: float = 0.0,
        repetition_penalty: float = 2.0,
    ) -> str:
        """Autoregressively generate text conditioned on the image.

        temperature == 0 means greedy decoding; otherwise samples from
        the temperature-scaled logits.
        """
        print(f"\n🔮 开始生成...")

        start = time.time()
        vision_embeds = self.model.encode_image(pixel_values, image_grid_thw)
        print(f"✅ 图像编码: {vision_embeds.shape} ({time.time()-start:.2f}s)")

        print(f"\n🔄 自回归生成 (max_tokens={max_tokens}, repetition_penalty={repetition_penalty})...")
        start = time.time()

        output_ids = []
        current_ids = input_ids
        eos_token_id = self.tokenizer.eos_token_id if self.tokenizer else 2

        for i in range(max_tokens):
            # NOTE: full-sequence forward each step (no KV cache) — slow
            # but simple.
            logits = self.model.forward(current_ids, vision_embeds)
            next_token_logits = logits[:, -1, :]

            # CTRL-style repetition penalty: shrink positive logits and
            # amplify negative ones.  (Unconditional division — the old
            # behavior — would *raise* the probability of repeated
            # tokens whose logits are negative.)
            if repetition_penalty != 1.0 and len(output_ids) > 0:
                logits_np = np.array(next_token_logits)
                for token_id in set(output_ids):
                    score = logits_np[0, token_id]
                    if score > 0:
                        logits_np[0, token_id] = score / repetition_penalty
                    else:
                        logits_np[0, token_id] = score * repetition_penalty
                next_token_logits = mx.array(logits_np)

            if temperature == 0:
                next_token = mx.argmax(next_token_logits, axis=-1)
            else:
                # mx.random.categorical expects (unnormalized) logits —
                # do NOT softmax first; feeding probabilities would skew
                # the sampling distribution.
                next_token = mx.random.categorical(next_token_logits / temperature)

            next_token_id = int(next_token[0])

            if next_token_id == eos_token_id:
                print(f" 遇到 EOS token,停止生成")
                break

            output_ids.append(next_token_id)
            current_ids = mx.concatenate([current_ids, mx.array([[next_token_id]])], axis=1)

            if (i + 1) % 20 == 0:
                print(f" 生成了 {i + 1} tokens...")

        elapsed = time.time() - start
        print(f"✅ 生成完成: {len(output_ids)} tokens ({elapsed:.2f}s, {len(output_ids)/elapsed:.1f} tokens/s)")

        if self.tokenizer:
            result_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
        else:
            result_text = f"[Token IDs: {output_ids[:10]}...]"

        return result_text

    def ocr(
        self,
        image_path: str,
        prompt: str = "Extract all text from this image.",
        max_tokens: int = 100,
        repetition_penalty: float = 2.0,
    ) -> dict:
        """End-to-end OCR: preprocess, encode, generate.

        Returns a dict with 'text', 'elapsed_time' and 'status'.
        """
        print("\n" + "="*60)
        print("🚀 执行最终优化版 OCR")
        print("="*60)

        total_start = time.time()

        pixel_values, input_ids, image_grid_thw, num_image_features = self.preprocess_image(image_path, prompt)

        result_text = self.generate(pixel_values, input_ids, image_grid_thw, max_tokens, repetition_penalty=repetition_penalty)

        total_time = time.time() - total_start

        print(f"\n✅ OCR 完成 (总耗时: {total_time:.2f}s)")
        print("="*60)

        return {
            'text': result_text,
            'elapsed_time': total_time,
            'status': 'success'
        }
| |
|
| |
|
def main():
    """Entry point: build the pipeline and OCR a synthetic test image."""
    banner = "=" * 60

    print("\n" + banner)
    print("🎯 PaddleOCR MLX 最终优化版测试")
    print(banner)
    print(f"目标: 达到原版准确度 80-90%")
    print(f"关键改进: 使用正确的图像预处理 ⭐")
    print(banner)

    model_dir = "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion"

    try:
        pipeline = FinalOptimizedPaddleOCRMLX(model_dir)

        # Render a small white canvas with black text as the input.
        print("\n📋 创建测试图像...")
        canvas = Image.new('RGB', (400, 200), color='white')
        ImageDraw.Draw(canvas).text((50, 80), "Hello World", fill='black')
        test_path = "/tmp/test_final_mlx.png"
        canvas.save(test_path)
        print(f"✅ 测试图像: {test_path}")

        result = pipeline.ocr(test_path, max_tokens=50, repetition_penalty=2.0)

        print(f"\n📝 OCR 结果:")
        print(f"{'='*60}")
        print(result['text'])
        print(f"{'='*60}")
        print(f"耗时: {result['elapsed_time']:.2f}s")

        print(f"\n🎉 最终优化版测试完成!")

    except Exception as e:
        print(f"\n❌ 错误: {e}")
        import traceback
        traceback.print_exc()
| |
|
| |
|
# Script entry point.  Note: main() only relies on names defined above
# this guard, so definitions appearing later in the file are safe.
if __name__ == "__main__":
    main()
| |
|
def apply_multimodal_rotary_pos_emb(q, k, t_ids, h_ids, w_ids, head_dim):
    """Apply 3D (temporal/height/width) rotary embeddings to q and k.

    The first 64 rotary dimensions are split 16/24/24 across the t/h/w
    axes; any remaining head dimensions (head_dim - 64) are padded with
    cos=1, sin=0, i.e. left unrotated.
    NOTE(review): the 16/24/24 split and theta=500000 are hard-coded —
    confirm they match the checkpoint's rope configuration.
    """

    # Per-axis rotary dimensions (they sum to 64) and RoPE base.
    m_t, m_h, m_w = 16, 24, 24
    theta = 500000.0

    def get_cos_sin(ids, dim):
        # Standard RoPE inverse frequencies for `dim` rotary dims.
        inv_freq = 1.0 / (theta ** (mx.arange(0, dim, 2) / dim))

        # Outer product of positions and frequencies, duplicated so the
        # table spans the full `dim` (matching rotate_half's split).
        freqs = mx.matmul(ids[..., None].astype(mx.float32), inv_freq[None, :])
        emb = mx.concatenate([freqs, freqs], axis=-1)
        return mx.cos(emb), mx.sin(emb)

    # Per-axis cos/sin tables from the 3D position ids.
    c_t, s_t = get_cos_sin(t_ids, m_t)
    c_h, s_h = get_cos_sin(h_ids, m_h)
    c_w, s_w = get_cos_sin(w_ids, m_w)

    # Concatenate the axis tables along the rotary dimension.
    cos = mx.concatenate([c_t, c_h, c_w], axis=-1)
    sin = mx.concatenate([s_t, s_h, s_w], axis=-1)

    # Identity-pad the unrotated tail of the head dimension.
    rem = head_dim - 64
    if rem > 0:
        c_rem = mx.ones((*cos.shape[:-1], rem))
        s_rem = mx.zeros((*sin.shape[:-1], rem))
        cos = mx.concatenate([cos, c_rem], axis=-1)
        sin = mx.concatenate([sin, s_rem], axis=-1)

    # Broadcast over the heads axis: (batch, 1, seq, head_dim).
    cos = cos[:, None, :, :]
    sin = sin[:, None, :, :]

    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)

    return q_rot, k_rot
| |
|