"""
直接集成权重加载到模型初始化中

作者: AI Assistant
日期: 2024-12-25
版本: v4.0 - 集成权重加载
"""
|
|
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from PIL import Image

from mlx_components import (
    RMSNorm, Attention, MLP, DecoderLayer, VisionEncoder
)
|
|
|
|
class RealPaddleOCRMLXWithWeights:
    """Real PaddleOCR MLX inference engine with integrated weight loading.

    Builds a simplified decoder-only vision-language model from config.json,
    loads a HuggingFace tokenizer (best-effort), and optionally copies the
    converted PaddleOCR-VL weights (.npz) into the freshly created model.
    """

    def __init__(self, model_dir: str, load_weights: bool = True):
        """Initialize the engine.

        Args:
            model_dir: Directory holding config.json and the converted
                paddleocr_vl_mlx.npz weight file.
            load_weights: When True, load the real weights into the model
                right after it is created.
        """
        self.model_dir = Path(model_dir)
        print("🚀 初始化 PaddleOCR MLX (集成权重加载)...")
        print(f"📂 模型目录: {model_dir}")

        # Config first: the model architecture is derived from it.
        self.config = self._load_config()

        # Tokenizer is optional; generation degrades to dummy ids without it.
        self.tokenizer = self._load_tokenizer()

        self.model = self._create_model()

        if load_weights:
            self._load_real_weights()

        print("✅ 初始化完成!")

    def _load_config(self) -> dict:
        """Load and return the model configuration from config.json."""
        config_path = self.model_dir / "config.json"
        with open(config_path, 'r') as f:
            config = json.load(f)
        print(f"✅ 配置加载完成")
        return config

    def _load_tokenizer(self):
        """Load the HF tokenizer; return None when unavailable.

        NOTE(review): the tokenizer path is hard-coded to a local machine —
        consider reading it from self.model_dir or making it configurable.
        """
        try:
            from transformers import AutoTokenizer
            original_model_path = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
            tokenizer = AutoTokenizer.from_pretrained(
                original_model_path,
                trust_remote_code=True
            )
            print(f"✅ Tokenizer 加载完成 (词汇表: {len(tokenizer)})")
            return tokenizer
        except Exception as e:
            # Best-effort: generation still works (with placeholder ids)
            # when the tokenizer cannot be loaded.
            print(f"⚠️ Tokenizer 加载失败: {e}")
            return None

    def _create_model(self):
        """Create and return the simplified MLX model from the loaded config."""
        print("🔄 创建模型...")

        class SimplifiedModel(nn.Module):
            """Decoder-only LM with a linear patch projection for images."""

            def __init__(self, config):
                super().__init__()
                self.config = config
                self.hidden_size = config.get('hidden_size', 1024)
                self.vocab_size = config.get('vocab_size', 103424)
                self.intermediate_size = config.get('intermediate_size', 3072)
                self.num_attention_heads = config.get('num_attention_heads', 16)
                self.num_kv_heads = config.get('num_key_value_heads', 2)
                self.num_hidden_layers = config.get('num_hidden_layers', 18)
                self.head_dim = config.get('head_dim', 128)

                # Vision front-end: flatten 14x14 RGB patches of a 224x224
                # image and project them linearly into the hidden space.
                self.image_size = 224
                self.patch_size = 14
                self.num_patches = (self.image_size // self.patch_size) ** 2
                patch_dim = self.patch_size * self.patch_size * 3
                self.patch_projection = nn.Linear(patch_dim, self.hidden_size, bias=False)

                self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)

                self.layers = [
                    DecoderLayer(
                        hidden_size=self.hidden_size,
                        num_heads=self.num_attention_heads,
                        intermediate_size=self.intermediate_size,
                        num_kv_heads=self.num_kv_heads,
                        head_dim=self.head_dim,
                    )
                    for _ in range(self.num_hidden_layers)
                ]

                self.norm = RMSNorm(self.hidden_size, eps=config.get('rms_norm_eps', 1e-6))

                self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)

            def encode_image(self, pixel_values: mx.array) -> mx.array:
                """Split a (B, H, W, C) image into patches and embed them.

                Returns:
                    (B, num_patches, hidden_size) array of patch embeddings.
                """
                B, H, W, C = pixel_values.shape
                # Use self.patch_size (was hard-coded 14) so the reshape
                # stays in sync with the projection dimensions above.
                p = self.patch_size
                patches = pixel_values.reshape(B, H // p, p, W // p, p, C)
                patches = mx.transpose(patches, (0, 1, 3, 2, 4, 5))
                patches = patches.reshape(B, self.num_patches, -1)
                return self.patch_projection(patches)

            def forward(
                self,
                input_ids: mx.array,
                vision_embeds: Optional[mx.array] = None,
            ) -> mx.array:
                """Run the decoder; vision embeddings are prepended when given."""
                text_embeds = self.embed_tokens(input_ids)

                if vision_embeds is not None:
                    hidden_states = mx.concatenate([vision_embeds, text_embeds], axis=1)
                else:
                    hidden_states = text_embeds

                # NOTE(review): mask is None here — no causal mask is built;
                # confirm DecoderLayer creates its own mask if one is needed.
                for layer in self.layers:
                    hidden_states = layer(hidden_states, None)

                hidden_states = self.norm(hidden_states)
                return self.lm_head(hidden_states)

        model = SimplifiedModel(self.config)
        print("✅ 模型创建完成")
        return model

    def _load_real_weights(self):
        """Load the converted .npz weights into the model.

        Raises:
            FileNotFoundError: when the weight file is missing (previously
                this surfaced as an opaque mx.load error).
        """
        print("\n" + "="*60)
        print("🔄 加载真实权重...")
        print("="*60)

        weights_path = self.model_dir / "paddleocr_vl_mlx.npz"
        if not weights_path.exists():
            raise FileNotFoundError(f"weight file not found: {weights_path}")

        print(f"\n📂 加载权重文件...")
        weights = mx.load(str(weights_path))
        print(f"✅ 加载了 {len(weights)} 个权重张量")

        print(f"\n🗺️ 创建权重映射...")
        mapping = self._create_weight_mapping()

        print(f"\n🔧 准备权重字典...")
        new_weights = {}
        loaded_count = 0

        for pt_name, mlx_name in mapping.items():
            if pt_name in weights:
                new_weights[mlx_name] = weights[pt_name]
                loaded_count += 1

        print(f"✅ 准备了 {loaded_count} 个权重")

        print(f"\n🔄 更新模型参数...")
        # Module.update expects a nested parameter tree, not a flat dict of
        # dotted names — unflatten first so "layers.N.*" weights actually
        # reach the corresponding DecoderLayer parameters.
        self.model.update(tree_unflatten(list(new_weights.items())))
        print(f"✅ 权重加载完成")

        print("\n" + "="*60)
        print(f"📊 成功加载 {loaded_count}/{len(mapping)} 个权重")
        print("="*60)

    def _create_weight_mapping(self) -> Dict[str, str]:
        """Map PyTorch checkpoint names to this model's parameter names.

        The layer count comes from the config (was hard-coded to 18), so the
        mapping stays correct if the architecture changes.
        """
        mapping = {'model.embed_tokens.weight': 'embed_tokens.weight'}

        num_layers = self.config.get('num_hidden_layers', 18)
        per_layer_suffixes = (
            'self_attn.q_proj.weight',
            'self_attn.k_proj.weight',
            'self_attn.v_proj.weight',
            'self_attn.o_proj.weight',
            'mlp.gate_proj.weight',
            'mlp.up_proj.weight',
            'mlp.down_proj.weight',
            'input_layernorm.weight',
            'post_attention_layernorm.weight',
        )
        for i in range(num_layers):
            for suffix in per_layer_suffixes:
                mapping[f'model.layers.{i}.{suffix}'] = f'layers.{i}.{suffix}'

        mapping['model.norm.weight'] = 'norm.weight'
        mapping['lm_head.weight'] = 'lm_head.weight'

        return mapping

    def preprocess_image(self, image_path: str) -> Tuple[mx.array, tuple]:
        """Load, resize to 224x224, and ImageNet-normalize an image.

        Returns:
            (pixel_values, original_size): a (1, 224, 224, 3) float32 array
            and the original (width, height) of the image.
        """
        image = Image.open(image_path).convert('RGB')
        original_size = image.size
        image = image.resize((224, 224), Image.Resampling.BILINEAR)
        image_array = np.array(image).astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        # Cast back to float32: float64 mean/std would otherwise promote the
        # whole tensor and mismatch the model's float32 weights.
        image_array = ((image_array - mean) / std).astype(np.float32)
        image_array = np.expand_dims(image_array, 0)
        return mx.array(image_array), original_size

    def encode_prompt(self, prompt: str) -> mx.array:
        """Encode the prompt into a (1, seq_len) token-id array.

        Falls back to placeholder ids when no tokenizer is available.
        """
        if self.tokenizer:
            tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
            return mx.array([tokens])
        else:
            return mx.array([[1, 2, 3, 4, 5]])

    def decode_tokens(self, token_ids: List[int]) -> str:
        """Decode token ids to text; show raw ids when no tokenizer exists."""
        if self.tokenizer:
            text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            return text
        else:
            return f"[Token IDs: {token_ids[:10]}...]"

    def generate(
        self,
        pixel_values: mx.array,
        prompt: str = "Extract all text from this image.",
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Autoregressively generate text conditioned on the image.

        Args:
            pixel_values: Preprocessed (1, 224, 224, 3) image tensor.
            prompt: Instruction text prepended to the generation.
            max_tokens: Upper bound on generated tokens.
            temperature: 0 for greedy decoding, >0 for temperature sampling.

        Returns:
            Decoded generated text.
        """
        print(f"\n🔮 开始生成...")

        vision_embeds = self.model.encode_image(pixel_values)
        print(f"✅ 图像编码: {vision_embeds.shape}")

        prompt_ids = self.encode_prompt(prompt)
        print(f"✅ 提示编码: {prompt_ids.shape}")

        output_ids = []
        current_ids = prompt_ids
        eos_token_id = self.tokenizer.eos_token_id if self.tokenizer else 2

        # PERF NOTE: no KV cache — each step re-runs the full sequence.
        for i in range(max_tokens):
            logits = self.model.forward(current_ids, vision_embeds)
            next_token_logits = logits[:, -1, :]

            if temperature == 0:
                next_token = mx.argmax(next_token_logits, axis=-1)
            else:
                # mx.random.categorical expects (unnormalized) logits; the
                # previous code passed softmax probabilities, which skews
                # the sampling distribution.
                next_token = mx.random.categorical(next_token_logits / temperature)

            next_token_id = int(next_token[0])

            if next_token_id == eos_token_id:
                break

            output_ids.append(next_token_id)
            current_ids = mx.concatenate([current_ids, mx.array([[next_token_id]])], axis=1)

            if (i + 1) % 10 == 0:
                print(f" 生成了 {i + 1} tokens...")

        print(f"✅ 生成完成: {len(output_ids)} tokens")

        return self.decode_tokens(output_ids)

    def ocr(
        self,
        image_path: str,
        prompt: str = "Extract all text from this image.",
        max_tokens: int = 512,
    ) -> dict:
        """End-to-end OCR: preprocess, generate, and time the run.

        Returns:
            dict with keys 'text', 'image_size', 'elapsed_time', 'status'.
        """
        print("\n" + "="*60)
        print("🚀 执行 OCR")
        print("="*60)

        start_time = time.time()

        pixel_values, original_size = self.preprocess_image(image_path)

        result_text = self.generate(pixel_values, prompt, max_tokens)

        elapsed = time.time() - start_time

        print(f"\n✅ OCR 完成 (耗时: {elapsed:.2f}s)")

        return {
            'text': result_text,
            'image_size': original_size,
            'elapsed_time': elapsed,
            'status': 'success'
        }
|
|
|
|
| |
| if __name__ == "__main__": |
| print("\n🎉 测试集成权重加载版本") |
| |
| model_dir = "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion" |
| |
| try: |
| ocr = RealPaddleOCRMLXWithWeights(model_dir, load_weights=True) |
| |
| |
| from PIL import ImageDraw |
| img = Image.new('RGB', (400, 200), color='white') |
| draw = ImageDraw.Draw(img) |
| draw.text((50, 80), "Test OCR", fill='black') |
| test_path = "/tmp/test_integrated.png" |
| img.save(test_path) |
| |
| |
| result = ocr.ocr(test_path, max_tokens=20) |
| |
| print(f"\n📝 结果:") |
| print(f"{'='*60}") |
| print(result['text'][:200]) |
| print(f"{'='*60}") |
| |
| print(f"\n🎉 测试成功!") |
| |
| except Exception as e: |
| print(f"\n❌ 错误: {e}") |
| import traceback |
| traceback.print_exc() |
|
|