|
|
|
|
|
""" |
|
|
PaddleOCR-VL MLX 最终优化版 - 使用正确的图像预处理 |
|
|
目标:达到原版准确度 80-90% |
|
|
|
|
|
作者: AI Assistant |
|
|
日期: 2024-12-25 |
|
|
版本: v8.0 - 最终优化 |
|
|
""" |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx.nn as nn |
|
|
from PIL import Image, ImageDraw |
|
|
import numpy as np |
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Optional, List, Tuple |
|
|
import time |
|
|
import torch |
|
|
|
|
|
|
|
|
from mlx_components import ( |
|
|
RMSNorm, MLP, DecoderLayer |
|
|
) |
|
|
|
|
|
|
|
|
class VisionHeadAttention(nn.Module):
    """Multi-head self-attention layer used by the vision head.

    A single fused linear layer (``in_proj``) produces queries, keys and
    values; ``out_proj`` mixes the concatenated heads back to ``hidden_size``.

    Args:
        hidden_size: Width of the input and output features.
        num_heads: Number of attention heads. Defaults to 16, the value that
            was previously hard-coded.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size`` evenly (the per-head reshape would fail later).
    """

    def __init__(self, hidden_size: int = 1152, num_heads: int = 16):
        super().__init__()
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Fused Q/K/V projection followed by the output projection.
        self.in_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply unmasked self-attention to ``x`` of shape (B, L, D)."""
        B, L, D = x.shape

        # (B, L, 3*D) -> (B, L, 3, heads, head_dim) -> (3, B, heads, L, head_dim)
        qkv = self.in_proj(x)
        qkv = qkv.reshape(B, L, 3, self.num_heads, self.head_dim)
        qkv = mx.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention over the full sequence (no mask).
        attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) / (self.head_dim ** 0.5)
        attn = mx.softmax(attn, axis=-1)
        out = mx.matmul(attn, v)

        # (B, heads, L, head_dim) -> (B, L, heads, head_dim) -> (B, L, D)
        out = mx.transpose(out, (0, 2, 1, 3))
        out = out.reshape(B, L, D)
        return self.out_proj(out)
|
|
|
|
|
|
|
|
class VisionHead(nn.Module):
    """Vision head: self-attention, layer norm, then an MLP, with residuals.

    Args:
        hidden_size: Feature width shared by every sub-layer.

    Attributes set here:
        attention: ``VisionHeadAttention`` over the token sequence.
        layernorm: ``nn.LayerNorm`` applied after the attention residual.
        mlp: ``MLP`` (imported from ``mlx_components``) built with an inner
            size of 4304.
        probe: Zero tensor of shape (1, 1, hidden_size). It is stored (and
            can be overwritten by the weight loader) but is not read in
            ``__call__``.
    """

    def __init__(self, hidden_size: int = 1152):
        super().__init__()
        self.attention = VisionHeadAttention(hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size)
        self.mlp = MLP(hidden_size, 4304)
        self.probe = mx.zeros((1, 1, hidden_size))

    def __call__(self, x: mx.array) -> mx.array:
        """Return the head output for ``x``; the shape is unchanged."""
        # Attention branch with skip connection.
        attended = self.attention(x)
        normed = self.layernorm(x + attended)

        # MLP branch; the skip connection uses the normalized activations.
        expanded = self.mlp(normed)
        return normed + expanded
|
|
|
|
|
|
|
|
class FinalOptimizedPaddleOCRMLX: |
|
|
"""最终优化版 PaddleOCR MLX - 使用正确的图像预处理""" |
|
|
|
|
|
def __init__(self, model_dir: str):
    """Build a ready-to-use OCR pipeline from the files in ``model_dir``.

    The steps run in a fixed order: config, tokenizer, processor, model
    construction, and finally weight loading (which needs the model).

    Args:
        model_dir: Directory that the config and weight loaders read from.
    """
    self.model_dir = Path(model_dir)
    print("🚀 初始化最终优化版 PaddleOCR MLX...", f"📂 模型目录: {model_dir}", sep="\n")

    # Configuration must be available before the model is created.
    self.config = self._load_config()

    # Text tokenizer and image/text processor (either may end up as None).
    self.tokenizer = self._load_tokenizer()
    self.processor = self._load_processor()

    # Instantiate the network, then fill in its parameters.
    self.model = self._create_model()
    self._load_all_weights()

    print("✅ 初始化完成!")
|
|
|
|
|
def _load_config(self) -> dict:
    """Read ``config.json`` from the model directory.

    Returns:
        The parsed JSON document as a dict.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    config_path = self.model_dir / "config.json"
    # JSON is UTF-8 by specification; being explicit avoids decoding errors
    # on platforms whose default text encoding is something else.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    print("✅ 配置加载完成")
    return config
|
|
|
|
|
def _load_tokenizer(self):
    """Load the Hugging Face tokenizer, or return ``None`` on any failure.

    The tokenizer is read from a fixed absolute path, not from
    ``self.model_dir``. Any exception (including a missing ``transformers``
    package) is reported on stdout and turned into a ``None`` result.
    """
    # NOTE(review): machine-specific path and trust_remote_code=True (runs
    # code shipped with the checkpoint) -- confirm both are intended.
    source_dir = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
    try:
        from transformers import AutoTokenizer

        loaded = AutoTokenizer.from_pretrained(source_dir, trust_remote_code=True)
        print(f"✅ Tokenizer 加载完成 (词汇表: {len(loaded)})")
    except Exception as err:
        print(f"⚠️ Tokenizer 加载失败: {err}")
        return None
    return loaded
|
|
|
|
|
def _load_processor(self):
    """Load the Hugging Face processor, or return ``None`` on any failure.

    The processor is read from a fixed absolute path, not from
    ``self.model_dir``. Any exception (including a missing ``transformers``
    package) is reported on stdout and turned into a ``None`` result;
    ``preprocess_image`` refuses to run when the processor is ``None``.
    """
    # NOTE(review): machine-specific path and trust_remote_code=True (runs
    # code shipped with the checkpoint) -- confirm both are intended.
    source_dir = "/Users/gt/.lmstudio/hub/models/paddleocr-vl"
    try:
        from transformers import AutoProcessor

        loaded = AutoProcessor.from_pretrained(source_dir, trust_remote_code=True)
        print("✅ Processor 加载完成 ⭐ 关键改进")
    except Exception as err:
        print(f"⚠️ Processor 加载失败: {err}")
        return None
    return loaded
|
|
|
|
|
def _create_model(self):
    """Define the vision + language network and return a fresh instance.

    The network class is declared locally and built from ``self.config``.
    Parameters keep their default initialization here; the weight loader
    overwrites them afterwards.
    """
    print("🔄 创建完整模型...")

    class OptimizedModel(nn.Module):
        """Vision encoder, projector and decoder-only language model."""

        def __init__(self, config):
            super().__init__()
            self.config = config

            # Language-model sizes; the fallbacks apply when a key is absent.
            self.hidden_size = config.get('hidden_size', 1024)
            self.vocab_size = config.get('vocab_size', 103424)
            self.intermediate_size = config.get('intermediate_size', 3072)
            self.num_attention_heads = config.get('num_attention_heads', 16)
            self.num_kv_heads = config.get('num_key_value_heads', 2)
            self.num_hidden_layers = config.get('num_hidden_layers', 18)
            self.head_dim = config.get('head_dim', 128)

            # Vision-encoder sizes; the layer count is fixed at 27 here.
            vision_config = config.get('vision_config', {})
            self.vision_hidden_size = vision_config.get('hidden_size', 1152)
            self.vision_num_layers = 27
            vision_width = self.vision_hidden_size

            # One 14x14 convolution step per image patch.
            self.patch_embedding = nn.Conv2d(
                in_channels=3,
                out_channels=vision_width,
                kernel_size=14,
                stride=14,
                bias=True,
            )

            # Table of 729 learned position vectors.
            self.position_embedding = mx.zeros((729, vision_width))

            # Vision transformer stack (reuses the shared DecoderLayer).
            self.vision_layers = [
                DecoderLayer(
                    hidden_size=vision_width,
                    num_heads=16,
                    intermediate_size=4304,
                    num_kv_heads=16,
                    head_dim=72,
                )
                for _ in range(self.vision_num_layers)
            ]

            self.vision_norm = RMSNorm(vision_width)
            self.vision_head = VisionHead(vision_width)
            self.post_layernorm = nn.LayerNorm(vision_width)

            # Projector: four vision tokens are merged (4608 features; this
            # equals 4 * 1152 for the default vision width) and mapped to
            # the language hidden size.
            self.vision_pre_norm = nn.LayerNorm(vision_width)
            self.vision_linear_1 = nn.Linear(4608, 4608, bias=True)
            self.vision_linear_2 = nn.Linear(4608, self.hidden_size, bias=True)

            # Language model: embeddings, decoder stack, final norm, head.
            self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
            self.layers = [
                DecoderLayer(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_attention_heads,
                    intermediate_size=self.intermediate_size,
                    num_kv_heads=self.num_kv_heads,
                    head_dim=self.head_dim,
                )
                for _ in range(self.num_hidden_layers)
            ]
            self.norm = RMSNorm(self.hidden_size, eps=config.get('rms_norm_eps', 1e-6))
            self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)

        def encode_image(self, pixel_values: mx.array) -> mx.array:
            """Turn a stack of image patches into language-space embeddings.

            Args:
                pixel_values: 4-D array whose first axis indexes patches.
                    Presumably (num_patches, 14, 14, 3), since each patch is
                    reshaped to a single feature vector -- TODO confirm.

            Returns:
                Array of shape (1, merged_tokens, hidden_size).
            """
            # Unpacking also enforces that the input is 4-D.
            token_count, _height, _width, _channels = pixel_values.shape
            width = self.vision_hidden_size

            # Patch embedding, flattened to one vector per patch, batch of 1.
            tokens = self.patch_embedding(pixel_values)
            tokens = tokens.reshape(token_count, width)
            tokens = mx.expand_dims(tokens, 0)

            # Position vectors are only added when the table is long enough;
            # longer inputs proceed without them.
            if token_count <= 729:
                tokens = tokens + self.position_embedding[:token_count, :]

            # Second argument is None for every layer (presumably "no mask";
            # verify against DecoderLayer).
            for encoder_layer in self.vision_layers:
                tokens = encoder_layer(tokens, None)

            # Normalization chain and head applied after the encoder stack.
            tokens = self.vision_norm(tokens)
            tokens = self.vision_head(tokens)
            tokens = self.post_layernorm(tokens)
            tokens = self.vision_pre_norm(tokens)

            # Keep at most 256 tokens by sampling evenly spaced positions.
            if token_count > 256:
                keep = mx.linspace(0, token_count - 1, 256).astype(mx.int32)
                tokens = tokens[:, keep, :]
                token_count = 256

            # Zero-pad to a multiple of four, then merge groups of four
            # consecutive tokens into one wider token.
            leftover = token_count % 4
            if leftover:
                pad_size = 4 - leftover
                zeros = mx.zeros((1, pad_size, width))
                tokens = mx.concatenate([tokens, zeros], axis=1)
                token_count += pad_size
            tokens = tokens.reshape(1, token_count // 4, 4 * width)

            # Two-layer projector with GELU in between.
            tokens = self.vision_linear_1(tokens)
            tokens = nn.gelu(tokens)
            return self.vision_linear_2(tokens)

        def forward(self, input_ids: mx.array, vision_embeds: Optional[mx.array] = None) -> mx.array:
            """Run the language model and return logits for every position.

            When ``vision_embeds`` is given, the sequence is laid out as
            [embedding of id 101305, vision_embeds, embedding of id 101306,
            text embeddings]. The two ids are presumably the vision
            start/end markers -- confirm against the tokenizer.
            """
            text_embeds = self.embed_tokens(input_ids)

            if vision_embeds is None:
                hidden_states = text_embeds
            else:
                opening = self.embed_tokens(mx.array([[101305]]))
                closing = self.embed_tokens(mx.array([[101306]]))
                hidden_states = mx.concatenate(
                    [opening, vision_embeds, closing, text_embeds],
                    axis=1,
                )

            # Second argument is None for every layer (presumably the mask;
            # verify how DecoderLayer handles causality).
            for decoder_layer in self.layers:
                hidden_states = decoder_layer(hidden_states, None)

            return self.lm_head(self.norm(hidden_states))

    model = OptimizedModel(self.config)
    print("✅ 优化模型创建完成")
    return model
|
|
|
|
|
def _load_all_weights(self):
    """Copy tensors from ``paddleocr_vl_mlx.npz`` onto ``self.model``.

    Every checkpoint key is optional: a missing key is skipped and the
    corresponding parameter keeps its initial value. Loading is best-effort;
    an exception raised while assigning tensors is printed with its
    traceback and the partial count is still reported.

    Changes from the previous version:
      * layer loops follow the model's real layer lists instead of the
        hard-coded 27 / 18, so a config with a different depth no longer
        raises IndexError;
      * the counter counts every tensor copied (weights and biases), and the
        summary compares it with the number of tensors in the file instead
        of a fixed 620.

    Raises:
        Whatever ``mx.load`` raises when the weight file cannot be read.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("🔄 加载所有权重...")
    print(banner)

    weights_path = self.model_dir / "paddleocr_vl_mlx.npz"
    weights = mx.load(str(weights_path))
    print(f"\n📂 加载了 {len(weights)} 个权重张量")

    def assign(root, dotted_attr, key, transform=None):
        """Set ``root.<dotted_attr>`` from ``weights[key]``; return 1 if copied, else 0."""
        if key not in weights:
            return 0
        # The attribute path is resolved only when the key exists, so an
        # absent key never touches the target module.
        *parents, leaf = dotted_attr.split(".")
        target = root
        for name in parents:
            target = getattr(target, name)
        value = weights[key]
        setattr(target, leaf, value if transform is None else transform(value))
        return 1

    # (attribute path on a vision layer, key suffix in the checkpoint)
    # NOTE(review): fc1/fc2 biases and layer-norm biases are not mapped, as
    # before; whether DecoderLayer can hold them is not visible here.
    vision_layer_map = [
        (f"self_attn.{proj}.{kind}", f"self_attn.{proj}.{kind}")
        for proj in ("q_proj", "k_proj", "v_proj")
        for kind in ("weight", "bias")
    ] + [
        ("self_attn.o_proj.weight", "self_attn.out_proj.weight"),
        ("self_attn.o_proj.bias", "self_attn.out_proj.bias"),
        ("mlp.gate_proj.weight", "mlp.fc1.weight"),
        ("mlp.down_proj.weight", "mlp.fc2.weight"),
        ("input_layernorm.weight", "layer_norm1.weight"),
        ("post_attention_layernorm.weight", "layer_norm2.weight"),
    ]

    # (attribute path on the vision head, key suffix under "...head.")
    head_map = [
        ("attention.in_proj.weight", "attention.in_proj_weight"),
        ("attention.in_proj.bias", "attention.in_proj_bias"),
        ("attention.out_proj.weight", "attention.out_proj.weight"),
        ("attention.out_proj.bias", "attention.out_proj.bias"),
        ("layernorm.weight", "layernorm.weight"),
        ("layernorm.bias", "layernorm.bias"),
        ("mlp.gate_proj.weight", "mlp.fc1.weight"),
        ("mlp.gate_proj.bias", "mlp.fc1.bias"),
        ("mlp.down_proj.weight", "mlp.fc2.weight"),
        ("mlp.down_proj.bias", "mlp.fc2.bias"),
        ("probe", "probe"),
    ]

    # (attribute path on the model, checkpoint key) for the projector
    projector_map = [
        ("vision_pre_norm.weight", "mlp_AR.pre_norm.weight"),
        ("vision_pre_norm.bias", "mlp_AR.pre_norm.bias"),
        ("vision_linear_1.weight", "mlp_AR.linear_1.weight"),
        ("vision_linear_1.bias", "mlp_AR.linear_1.bias"),
        ("vision_linear_2.weight", "mlp_AR.linear_2.weight"),
        ("vision_linear_2.bias", "mlp_AR.linear_2.bias"),
    ]

    # Attribute paths shared by every language-model layer; the checkpoint
    # uses the same names.
    text_layer_attrs = (
        [f"self_attn.{proj}.weight" for proj in ("q_proj", "k_proj", "v_proj", "o_proj")]
        + [f"mlp.{proj}.weight" for proj in ("gate_proj", "up_proj", "down_proj")]
        + [f"{norm}.weight" for norm in ("input_layernorm", "post_attention_layernorm")]
    )

    model = self.model
    loaded_count = 0

    try:
        # ---- vision encoder ------------------------------------------
        print("\n📸 加载视觉编码器权重...")
        embed_prefix = 'visual.vision_model.embeddings'

        # Axis 1 of the stored kernel is moved to the last position before
        # assignment (presumably out,in,kh,kw -> out,kh,kw,in -- confirm).
        loaded_count += assign(
            model, "patch_embedding.weight",
            f'{embed_prefix}.patch_embedding.weight',
            transform=lambda w: mx.transpose(w, (0, 2, 3, 1)),
        )
        loaded_count += assign(
            model, "patch_embedding.bias", f'{embed_prefix}.patch_embedding.bias'
        )
        loaded_count += assign(
            model, "position_embedding", f'{embed_prefix}.position_embedding.weight'
        )

        vision_layers = model.vision_layers
        for i, layer in enumerate(vision_layers):
            prefix = f'visual.vision_model.encoder.layers.{i}'
            for attr_path, key_suffix in vision_layer_map:
                loaded_count += assign(layer, attr_path, f'{prefix}.{key_suffix}')

        print(f"✅ 视觉编码器权重加载完成 ({len(vision_layers)} 层)")

        # ---- vision head ---------------------------------------------
        print("\n🎯 加载 Vision Head 权重...")
        head_loaded = sum(
            assign(model.vision_head, attr_path, f'visual.vision_model.head.{key_suffix}')
            for attr_path, key_suffix in head_map
        )
        loaded_count += head_loaded
        print(f"✅ Vision Head 权重加载完成 ({head_loaded} 个)")

        # ---- post layer norm -----------------------------------------
        print("\n🎯 加载 Post LayerNorm 权重...")
        post_loaded = sum(
            assign(model, f"post_layernorm.{kind}", f'visual.vision_model.post_layernorm.{kind}')
            for kind in ("weight", "bias")
        )
        loaded_count += post_loaded
        print(f"✅ Post LayerNorm 权重加载完成 ({post_loaded} 个)")

        # ---- vision projector (mlp_AR) -------------------------------
        print("\n🔗 加载视觉投影层 (mlp_AR)...")
        mlp_ar_loaded = sum(
            assign(model, attr_path, key) for attr_path, key in projector_map
        )
        print(f"✅ 视觉投影层加载完成 ({mlp_ar_loaded}/{len(projector_map)} 个)")
        loaded_count += mlp_ar_loaded

        # ---- language model ------------------------------------------
        print("\n📝 加载语言模型权重...")
        loaded_count += assign(model, "embed_tokens.weight", 'model.embed_tokens.weight')

        for i, layer in enumerate(model.layers):
            prefix = f'model.layers.{i}'
            for attr_path in text_layer_attrs:
                loaded_count += assign(layer, attr_path, f'{prefix}.{attr_path}')

        loaded_count += assign(model, "norm.weight", 'model.norm.weight')
        loaded_count += assign(model, "lm_head.weight", 'lm_head.weight')

        print("✅ 语言模型权重加载完成")
        print(f"\n✅ 总共成功加载 {loaded_count} 个权重")

    except Exception as e:
        # Best-effort: report the failure and keep whatever was assigned.
        print(f"\n❌ 权重加载失败: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + banner)
    print(f"📊 权重加载完成: {loaded_count}/{len(weights)}")
    print(banner)
|
|
|
|
|
def preprocess_image(self, image_path: str, prompt: str = "Extract all text from this image.") -> Tuple[mx.array, mx.array]:
    """Run the Hugging Face processor on an image and convert to MLX arrays.

    Args:
        image_path: Path of the image file to read.
        prompt: Text passed to the processor together with the image.

    Returns:
        ``(pixel_values, input_ids)`` as MLX arrays. ``pixel_values`` has
        axis 1 of the processor output moved to the last position.

    Raises:
        ValueError: If the processor failed to load during initialization.
    """
    if self.processor is None:
        raise ValueError("Processor not loaded!")

    # Image.open keeps the file handle open until the image is closed;
    # convert() returns an independent in-memory copy, so the handle can be
    # released right away instead of leaking until garbage collection.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    inputs = self.processor(
        images=image,
        text=prompt,
        return_tensors="pt",
    )

    # torch tensor -> numpy -> MLX. The transpose assumes a 4-D tensor
    # (presumably N,C,H,W -> N,H,W,C for the MLX convolution -- TODO confirm
    # against the processor's output layout).
    pixel_values_np = inputs['pixel_values'].numpy()
    pixel_values_np = np.transpose(pixel_values_np, (0, 2, 3, 1))
    pixel_values = mx.array(pixel_values_np)

    input_ids = mx.array(inputs['input_ids'].numpy())

    print("✅ 图像预处理完成:")
    print(f" pixel_values: {pixel_values.shape}")
    print(f" input_ids: {input_ids.shape}")

    return pixel_values, input_ids
|
|
|
|
|
def generate(
    self,
    pixel_values: mx.array,
    input_ids: mx.array,
    max_tokens: int = 100,
    temperature: float = 0.0,
    repetition_penalty: float = 2.0,
) -> str:
    """Encode the image once, then decode text token by token.

    Args:
        pixel_values: Image patches for ``self.model.encode_image``.
        input_ids: Prompt token ids; only the first batch row is decoded.
        max_tokens: Upper bound on the number of generated tokens.
        temperature: ``0`` (or less) selects the arg-max token; a positive
            value samples from the temperature-scaled distribution.
        repetition_penalty: Values above 1 make already generated tokens
            less likely; ``1.0`` disables the penalty.

    Returns:
        The decoded text, or a short placeholder listing the first token ids
        when no tokenizer is available.

    Note:
        The full sequence is re-run through the model at every step (there
        is no key/value cache).
    """
    print("\n🔮 开始生成...")

    # Image encoding happens once, outside the decoding loop.
    start = time.time()
    vision_embeds = self.model.encode_image(pixel_values)
    print(f"✅ 图像编码: {vision_embeds.shape} ({time.time()-start:.2f}s)")

    print(f"\n🔄 自回归生成 (max_tokens={max_tokens}, repetition_penalty={repetition_penalty})...")
    start = time.time()

    output_ids = []
    seen_ids = set()
    current_ids = input_ids
    eos_token_id = self.tokenizer.eos_token_id if self.tokenizer else 2
    use_penalty = repetition_penalty != 1.0

    for i in range(max_tokens):
        logits = self.model.forward(current_ids, vision_embeds)
        # Logits of the last position for the first (only) batch row.
        next_token_logits = logits[0, -1, :]

        if use_penalty and seen_ids:
            # Sign-aware penalty: a positive logit is divided and a negative
            # one multiplied, so a repeated token always becomes less likely.
            # (Dividing a negative logit would have made it MORE likely.)
            seen = mx.array(sorted(seen_ids))
            seen_logits = next_token_logits[seen]
            next_token_logits[seen] = mx.where(
                seen_logits > 0,
                seen_logits / repetition_penalty,
                seen_logits * repetition_penalty,
            )

        if temperature <= 0:
            next_token = mx.argmax(next_token_logits, axis=-1)
        else:
            # mx.random.categorical takes unnormalized logits, not
            # probabilities, so the scaled logits are passed directly.
            next_token = mx.random.categorical(next_token_logits / temperature)

        next_token_id = int(next_token.item())

        if next_token_id == eos_token_id:
            print(" 遇到 EOS token,停止生成")
            break

        output_ids.append(next_token_id)
        seen_ids.add(next_token_id)
        current_ids = mx.concatenate([current_ids, mx.array([[next_token_id]])], axis=1)

        if (i + 1) % 20 == 0:
            print(f" 生成了 {i + 1} tokens...")

    elapsed = time.time() - start
    # Guard against a zero interval (e.g. max_tokens == 0 on a coarse clock).
    tokens_per_second = len(output_ids) / elapsed if elapsed > 0 else 0.0
    print(f"✅ 生成完成: {len(output_ids)} tokens ({elapsed:.2f}s, {tokens_per_second:.1f} tokens/s)")

    if self.tokenizer:
        result_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
    else:
        result_text = f"[Token IDs: {output_ids[:10]}...]"

    return result_text
|
|
|
|
|
def ocr(
    self,
    image_path: str,
    prompt: str = "Extract all text from this image.",
    max_tokens: int = 100,
    repetition_penalty: float = 2.0,
) -> dict:
    """Run preprocessing and generation for one image and time the whole call.

    Args:
        image_path: Image file to recognize.
        prompt: Instruction text forwarded to ``preprocess_image``.
        max_tokens: Generation limit forwarded to ``generate``.
        repetition_penalty: Penalty forwarded to ``generate``.

    Returns:
        Dict with ``'text'`` (the generated string), ``'elapsed_time'``
        (seconds for preprocessing plus generation) and ``'status'``
        (always ``'success'``; failures propagate as exceptions).
    """
    divider = "=" * 60
    print("\n" + divider)
    print("🚀 执行最终优化版 OCR")
    print(divider)

    started_at = time.time()

    # Stage 1: image + prompt -> model inputs.
    pixel_values, input_ids = self.preprocess_image(image_path, prompt)

    # Stage 2: model inputs -> text.
    recognized = self.generate(
        pixel_values,
        input_ids,
        max_tokens,
        repetition_penalty=repetition_penalty,
    )

    total_time = time.time() - started_at

    print(f"\n✅ OCR 完成 (总耗时: {total_time:.2f}s)")
    print(divider)

    return dict(text=recognized, elapsed_time=total_time, status='success')
|
|
|
|
|
|
|
|
def main():
    """Smoke-test the pipeline on a generated "Hello World" image.

    Builds the OCR object from a fixed model directory, writes a small test
    image to a fixed path under /tmp, runs OCR on it and prints the result.
    Any exception is printed with its traceback instead of propagating.
    """
    rule = "=" * 60
    for banner_line in (
        "\n" + rule,
        "🎯 PaddleOCR MLX 最终优化版测试",
        rule,
        "目标: 达到原版准确度 80-90%",
        "关键改进: 使用正确的图像预处理 ⭐",
        rule,
    ):
        print(banner_line)

    model_dir = "/Users/gt/.gemini/antigravity/scratch/paddleocr-mlx-conversion"

    try:
        ocr = FinalOptimizedPaddleOCRMLX(model_dir)

        # Draw black text on a white 400x200 canvas and save it to disk.
        print("\n📋 创建测试图像...")
        canvas = Image.new('RGB', (400, 200), color='white')
        ImageDraw.Draw(canvas).text((50, 80), "Hello World", fill='black')
        test_path = "/tmp/test_final_mlx.png"
        canvas.save(test_path)
        print(f"✅ 测试图像: {test_path}")

        result = ocr.ocr(test_path, max_tokens=50, repetition_penalty=2.0)

        # Report the recognized text and the timing.
        print("\n📝 OCR 结果:")
        print(rule)
        print(result['text'])
        print(rule)
        print(f"耗时: {result['elapsed_time']:.2f}s")

        print("\n🎉 最终优化版测试完成!")

    except Exception as e:
        print(f"\n❌ 错误: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
|
# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|