pp-nsfw_Inspector / src /preprocess /scene_classifier.py
philcuriosity1024's picture
Upload folder using huggingface_hub
670cf0c verified
Raw
History Blame Contribute Delete
1.42 kB
"""场景分类:规则驱动,无需模型"""
import math
from enum import Enum
from dataclasses import dataclass
import numpy as np
from PIL import Image
class Scene(str, Enum):
    """Coarse image scene assigned by rule-based classification.

    Members are also ``str`` instances, so each compares equal to its value
    (e.g. ``Scene.POSTER == "poster"``).
    """

    # PNG format and wide: w / h > 1.4
    SCREENSHOT = "screenshot"
    # Tall: h / w > 1.3 (intended to catch A4-like page proportions)
    DOCUMENT = "document"
    # JPEG format with low grayscale entropy (< 5 bits)
    POSTER = "poster"
    # No rule matched; receives the strictest handling
    UNKNOWN = "unknown"


# OCR confidence threshold applied for each scene.
OCR_THRESHOLD = dict(
    (
        (Scene.SCREENSHOT, 0.85),
        (Scene.DOCUMENT, 0.75),
        (Scene.POSTER, 0.65),
        (Scene.UNKNOWN, 0.85),
    )
)
@dataclass
class SceneResult:
    """Outcome of scene classification.

    Attributes are documented inline below.
    """

    scene: Scene  # scene category chosen by the rules
    ocr_threshold: float  # OCR confidence threshold looked up for ``scene``
def _image_entropy(img: Image.Image) -> float:
    """Return the Shannon entropy, in bits, of *img*'s grayscale histogram.

    The image is converted to 8-bit grayscale (mode ``"L"``) and the entropy
    ``-sum(p * log2(p))`` is taken over the non-empty histogram bins only.
    Yields 0.0 when no bin is populated or when a single bin holds every pixel.
    """
    counts = img.convert("L").histogram()
    n_pixels = sum(counts)
    # Empty bins are skipped, so no division happens when n_pixels == 0.
    probabilities = (c / n_pixels for c in counts if c > 0)
    # Subtracting the accumulated sum from 0.0 is bit-for-bit the same as
    # subtracting each term in turn, because IEEE rounding is sign-symmetric.
    return 0.0 - sum((q * math.log2(q) for q in probabilities), 0.0)
def classify_scene(img: Image.Image, fmt: str) -> SceneResult:
    """Classify an image into a coarse scene using fixed rules (no model).

    Rules are checked in order and the first match wins:

    1. ``SCREENSHOT`` -- PNG and wide (w / h > 1.4).
    2. ``DOCUMENT``   -- tall (h / w > 1.3), any format.
    3. ``POSTER``     -- JPEG/JPG with grayscale histogram entropy < 5 bits.
    4. ``UNKNOWN``    -- everything else, including zero-width or
       zero-height images, whose aspect ratio is undefined.

    Args:
        img: PIL image to classify.
        fmt: File format name, e.g. ``'PNG'`` or ``'JPEG'`` (case-insensitive).
            ``None`` (the value of ``Image.format`` for images not read from a
            file) is accepted and treated as an unrecognised format.

    Returns:
        SceneResult holding the scene and its threshold from ``OCR_THRESHOLD``.
    """
    w, h = img.size
    fmt = (fmt or "").upper()

    if w <= 0 or h <= 0:
        # Both aspect-ratio rules below would divide by zero; fall back to
        # UNKNOWN, the scene that receives the strictest handling.
        scene = Scene.UNKNOWN
    elif fmt == "PNG" and w / h > 1.4:
        scene = Scene.SCREENSHOT
    elif h / w > 1.3:
        scene = Scene.DOCUMENT
    elif fmt in ("JPEG", "JPG") and _image_entropy(img) < 5:
        # Entropy requires a grayscale pass, so it is only computed for JPEGs.
        scene = Scene.POSTER
    else:
        scene = Scene.UNKNOWN
    return SceneResult(scene=scene, ocr_threshold=OCR_THRESHOLD[scene])