| """场景分类:规则驱动,无需模型""" | |
| import math | |
| from enum import Enum | |
| from dataclasses import dataclass | |
| import numpy as np | |
| from PIL import Image | |
class Scene(str, Enum):
    """Scene categories an image can be assigned to.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # PNG file with a wide aspect ratio (w/h > 1.4).
    SCREENSHOT = "screenshot"
    # Tall aspect ratio close to A4 (h/w > 1.3).
    DOCUMENT = "document"
    # JPEG file with low entropy (< 5).
    POSTER = "poster"
    # No rule matched; receives the strictest handling.
    UNKNOWN = "unknown"
# OCR confidence threshold per scene type. UNKNOWN shares the highest
# value in the table (0.85, same as SCREENSHOT).
OCR_THRESHOLD = {
    Scene.SCREENSHOT: 0.85,
    Scene.DOCUMENT: 0.75,
    Scene.POSTER: 0.65,
    Scene.UNKNOWN: 0.85,
}
@dataclass
class SceneResult:
    """Result of scene classification.

    The ``@dataclass`` decorator generates the keyword-accepting
    ``__init__`` (plus ``__repr__`` and ``__eq__``); without it the class
    cannot be constructed with ``scene=...`` / ``ocr_threshold=...``.

    Attributes:
        scene: The scene category that was detected.
        ocr_threshold: OCR confidence threshold associated with that scene.
    """

    # Quoted forward reference: the annotation is not evaluated when the
    # class body runs, so it does not depend on definition order.
    scene: "Scene"
    ocr_threshold: float
def _image_entropy(img: Image.Image) -> float:
    """Return the Shannon entropy, in bits, of the image's grayscale histogram.

    The image is converted to mode "L" and the entropy is computed over the
    resulting histogram bins. Empty bins contribute nothing, so an image
    whose histogram is all zeros yields 0.0.
    """
    histogram = img.convert("L").histogram()
    pixel_total = sum(histogram)

    bits = 0.0
    for bin_count in histogram:
        if bin_count <= 0:
            # Empty bin: skip it (log2(0) is undefined).
            continue
        probability = bin_count / pixel_total
        bits -= probability * math.log2(probability)
    return bits
def classify_scene(img: Image.Image, fmt: str) -> SceneResult:
    """Classify an image into a scene type using fixed rules.

    Rules are evaluated in order and the first match wins:

    1. PNG with width/height > 1.4           -> SCREENSHOT
    2. height/width > 1.3 (any format)       -> DOCUMENT
    3. JPEG/JPG with grayscale entropy < 5   -> POSTER
    4. anything else                         -> UNKNOWN

    Args:
        img: PIL image. Its ``size`` is always read; pixel data is only
            examined when the entropy rule is reached.
        fmt: File format name such as 'PNG' or 'JPEG' (case-insensitive).
            ``None`` or an empty string is tolerated and simply never
            matches the format-specific rules.

    Returns:
        A SceneResult holding the detected scene and the matching entry
        from OCR_THRESHOLD.
    """
    w, h = img.size
    # A missing format (e.g. None) must not crash on .upper().
    fmt = (fmt or "").upper()

    if w <= 0 or h <= 0:
        # Degenerate image: the aspect-ratio rules below would divide by
        # zero, so fall back to the catch-all scene.
        scene = Scene.UNKNOWN
    elif fmt == "PNG" and w / h > 1.4:
        scene = Scene.SCREENSHOT
    elif h / w > 1.3:
        scene = Scene.DOCUMENT
    elif fmt in ("JPEG", "JPG") and _image_entropy(img) < 5:
        # Entropy is computed last because it is the only rule that has
        # to look at pixel data.
        scene = Scene.POSTER
    else:
        scene = Scene.UNKNOWN

    return SceneResult(scene=scene, ocr_threshold=OCR_THRESHOLD[scene])