# DeepSeek-OCR-4bit-Quantized / test_benchmark.py
# (Uploaded by WHY2001 via huggingface_hub, commit 0695d0b, verified)
import os
# ================= 🔧 Force single-GPU mode (must run first) =================
# This line must execute BEFORE `import torch`!
# It resolves the dual-GPU conflict: "RuntimeError: ... cuda:1 and cuda:0"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# =============================================================================
import torch
import time
import gc
import shutil
import psutil  # NOTE(review): imported but unused in this file — verify before removing
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
# ===================== ⚙️ User configuration =====================
# 1. Path to the local DeepSeek-OCR model weights.
MODEL_PATH = "/home/nashen/deepseek-ocr/DeepSeek-OCR-master/DeepSeek-OCR-vllm/model/"
# 2. Test image (make sure this file exists next to the script; any renamed image works).
TEST_IMAGE = "./chart_with_line.jpg"
# 3. Root directory where per-mode benchmark results are written.
RESULT_ROOT = "./benchmark_output_chart_4bit"
# =================================================================
class VRAMMonitor:
    """Track CUDA memory usage, reporting values in GiB.

    On hosts without a visible GPU every query returns 0 and resets are no-ops.
    """

    def __init__(self):
        # Cache the device kind once; all methods branch on it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def get_current_mem(self):
        """Return currently allocated CUDA memory in GiB (0 on CPU)."""
        if self.device == "cpu":
            return 0
        return torch.cuda.memory_allocated() / (1024 ** 3)

    def reset_peak(self):
        """Reset the peak-allocation counter (no-op on CPU)."""
        if self.device == "cpu":
            return
        torch.cuda.reset_peak_memory_stats()

    def get_peak_mem(self):
        """Return peak allocated CUDA memory since the last reset, in GiB (0 on CPU)."""
        if self.device == "cpu":
            return 0
        return torch.cuda.max_memory_allocated() / (1024 ** 3)
def run_evaluation(mode="original"):
    """Load the DeepSeek-OCR model, run one OCR inference, and collect metrics.

    Parameters
    ----------
    mode : str
        "quantized" loads the model with a 4-bit NF4 BitsAndBytes config;
        any other value (default "original") loads it in bfloat16.

    Returns
    -------
    dict | None
        Metrics dict with keys ``mode``, ``static_vram`` (GiB after load),
        ``infer_time`` (s), ``peak_vram`` (GiB during inference) and
        ``result_head`` (first 100 chars of the OCR output), or ``None``
        if loading or inference failed (the traceback is printed).
    """
    # Prepare a clean output directory for this mode's results.
    output_dir = os.path.join(RESULT_ROOT, mode)
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    monitor = VRAMMonitor()
    print(f"\n{'='*40}")
    print(f"🚀 正在启动评测: [{mode.upper()}] 模式")
    print(f"📂 结果将保存至: {output_dir}")
    print(f"{'='*40}")
    # 1. Environment cleanup so measurements start from a known baseline.
    gc.collect()
    # Guard: torch.cuda.empty_cache() raises on CPU-only torch builds.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    monitor.reset_peak()
    # 2. Load tokenizer + model.
    print("⏳ [1/3] 正在加载模型...")
    start_load_time = time.time()
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
        # Patch: silence the missing-pad_token warning.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if mode == "quantized":
            # === Quantized configuration (4-bit NF4) ===
            q_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                # Key: keep vision and output layers unquantized to avoid
                # accuracy collapse.
                llm_int8_skip_modules=[
                    # 1. SAM model (processes 4-D image tensors) — skipping it
                    #    fixes a quantization runtime error.
                    "sam_model",
                    "model.sam_model",
                    # 2. Vision backbone.
                    "vision_model",
                    "model.vision_model",
                    # 3. Projector — skipping it fixes broken table separators.
                    "projector",
                    "model.projector",
                    # 4. Base LLM protection.
                    "lm_head",
                    "embed_tokens",
                ],
            )
            model = AutoModel.from_pretrained(
                MODEL_PATH,
                trust_remote_code=True,
                quantization_config=q_config,
                device_map="auto",  # maps onto the single GPU exposed via CUDA_VISIBLE_DEVICES
            )
        else:
            # === Original configuration (BF16) ===
            model = AutoModel.from_pretrained(
                MODEL_PATH,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        # Make sure the KV cache is enabled for generation.
        model.config.use_cache = True
        model.eval()
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"❌ 加载失败: {e}")
        return None
    load_time = time.time() - start_load_time
    static_vram = monitor.get_current_mem()
    print(f" ✅ 加载耗时: {load_time:.2f}s | 静态显存: {static_vram:.2f} GB")
    # 3. Inference test.
    print("⏳ [2/3] 正在运行 OCR 推理...")
    if not os.path.exists(TEST_IMAGE):
        print(f"❌ 错误: 没找到测试图片 {TEST_IMAGE}!")
        return None
    monitor.reset_peak()
    prompt = "<image>\nConvert the image to markdown format. Preserve all table structures and separators carefully."
    try:
        start_infer_time = time.time()
        with torch.no_grad():
            # NOTE(review): `infer` is a remote-code method on this model;
            # assumes it returns the OCR text as a str — confirm against the
            # model's modeling file.
            res = model.infer(
                tokenizer,
                prompt=prompt,
                image_file=TEST_IMAGE,
                base_size=1024,
                image_size=640,
                crop_mode=True,
                # Key argument: prevents an output_path error inside infer().
                output_path=output_dir,
                save_results=True,
            )
        infer_time = time.time() - start_infer_time
        peak_vram = monitor.get_peak_mem()
        print(f" ✅ 推理耗时: {infer_time:.2f}s | 峰值显存: {peak_vram:.2f} GB")
        # 4. Persist the full text result.
        text_save_path = os.path.join(output_dir, "full_result.md")
        with open(text_save_path, "w", encoding="utf-8") as f:
            f.write(res)
        print(f" 💾 结果已保存: {text_save_path}")
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"❌ 推理过程报错: {e}")
        return None
    # 5. Free memory so the next mode starts from a clean slate.
    del model
    del tokenizer
    gc.collect()
    # Guard: same CPU-only safety as above.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return {
        "mode": mode,
        "static_vram": static_vram,
        "infer_time": infer_time,
        "peak_vram": peak_vram,
        "result_head": res[:100].replace('\n', ' ') + "..."
    }
# ===================== 🏆 Main program =====================
if __name__ == "__main__":
    if not os.path.exists(TEST_IMAGE):
        print(f"⚠️ 请准备一张图片,重命名为 {TEST_IMAGE} 放在脚本同目录下!")
        # Exit non-zero: a missing prerequisite is an error, and `raise
        # SystemExit` does not depend on the site module like `exit()` does.
        raise SystemExit(1)
    print("🏁 开始 A/B 对比测试...")
    # 1. Run the quantized variant first.
    res_quant = run_evaluation("quantized")
    # 2. Then the original BF16 variant.
    res_origin = run_evaluation("original")
    if res_quant and res_origin:
        # Guard the divisions: on a CPU fallback VRAMMonitor reports 0 GiB,
        # which would otherwise raise ZeroDivisionError here.
        origin_peak = res_origin['peak_vram']
        origin_static = res_origin['static_vram']
        mem_saved = (1 - res_quant['peak_vram'] / origin_peak) * 100 if origin_peak else 0.0
        static_saved = (1 - res_quant['static_vram'] / origin_static) * 100 if origin_static else 0.0
        print("\n\n")
        print(f"{'='*30} 📊 最终数据对比报告 {'='*30}")
        print(f"{'指标':<20} | {'原版 (Original)':<18} | {'量化版 (Quantized)':<18} | {'变化'}")
        print("-" * 90)
        print(f"{'静态显存 (Static)':<20} | {res_origin['static_vram']:.2f} GB {'':<9} | {res_quant['static_vram']:.2f} GB {'':<9} | 📉 节省 {static_saved:.1f}%")
        print(f"{'峰值显存 (Peak)':<20} | {res_origin['peak_vram']:.2f} GB {'':<9} | {res_quant['peak_vram']:.2f} GB {'':<9} | 📉 节省 {mem_saved:.1f}%")
        print(f"{'推理时间 (Time)':<20} | {res_origin['infer_time']:.2f} s {'':<10} | {res_quant['infer_time']:.2f} s {'':<10} | {'🟢' if res_quant['infer_time']<res_origin['infer_time'] else '🟡'}{abs(res_origin['infer_time']-res_quant['infer_time']):.2f}s")
        print("-" * 90)