# TraceDetect-AI / image_module.py
# (repository web-view header: uploaded by nchdlhbctm, "Upload 13 files",
#  commit 646535a, 6.19 kB)
import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from skimage import feature
from scipy.fftpack import dct
from PIL import Image, ImageOps
import streamlit as st
# 引入热力图相关库
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# ==========================================
# 辅助函数:智能读取图像(修正 EXIF 方向)
# ==========================================
def get_pil_image(image_input):
    """Load an image as an RGB ``PIL.Image`` with its EXIF orientation applied.

    Args:
        image_input: A filesystem path (``str`` or path-like object) or a
            seekable binary file-like object such as an uploaded file.
            File-like inputs are rewound to offset 0 first, because the same
            object is read several times during one analysis.

    Returns:
        A fully decoded ``PIL.Image.Image`` in ``'RGB'`` mode whose pixel
        layout matches the orientation stored in the EXIF metadata.
    """
    if hasattr(image_input, "seek"):
        # A previous reader may have left the stream at EOF.
        image_input.seek(0)
    # The context manager releases the handle PIL opens itself for path
    # inputs; a caller-supplied stream is left open for the caller to manage.
    with Image.open(image_input) as img:
        # Apply the hidden EXIF rotation so the pixels have their true
        # orientation. convert() fully decodes the data while the source is
        # still open and returns an independent image.
        return ImageOps.exif_transpose(img).convert('RGB')
def get_cv_image(image_input):
    """Return the image as an OpenCV-style BGR NumPy array.

    Decoding goes through ``get_pil_image`` so the orientation (EXIF rotation
    applied) is exactly the one the PIL-based code paths see.
    """
    rgb_pixels = np.array(get_pil_image(image_input))
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    return bgr_pixels
# ==========================================
# 第一部分:传统物理特征检测 (权重 15%)
# ==========================================
def _lbp_entropy(gray):
    """Shannon entropy (bits) of the uniform LBP code histogram (P=8, R=1)."""
    codes = feature.local_binary_pattern(gray, P=8, R=1, method="uniform")
    # Uniform LBP with P=8 produces codes 0..9, hence 10 unit-width bins.
    hist, _ = np.histogram(codes.ravel(), bins=np.arange(0, 11), density=True)
    return -np.sum(hist * np.log2(hist + 1e-7))


def _dct_high_freq_ratio(gray):
    """Fraction of absolute 2-D DCT mass in coefficients with row and column index >= 10."""
    # Orthonormal DCT along one axis, transpose, then along the other: a 2-D DCT.
    coeffs = dct(dct(gray.T, norm='ortho').T, norm='ortho')
    high_band = np.sum(np.abs(coeffs[10:, 10:]))
    everything = np.sum(np.abs(coeffs))
    return high_band / (everything + 1e-7)


def _fft_symmetry_error(gray):
    """Mean absolute difference between the top-left and point-flipped bottom-right
    quadrants of the centred log-magnitude spectrum.

    NOTE(review): the magnitude spectrum of a real-valued image is point-symmetric
    about the DC term. With these slices the pairing is exactly aligned (result ~0)
    when a dimension is odd, and offset by one when it is even, so the value mostly
    reflects that offset. The 13.8 threshold used by the caller was presumably tuned
    on this exact computation; confirm before changing the slicing. The two slices
    also appear to always have equal shapes, which would leave the resize
    fallback unreachable.
    """
    spectrum = np.fft.fftshift(np.fft.fft2(gray))
    log_mag = 20 * np.log(np.abs(spectrum) + 1)
    rows, cols = log_mag.shape
    mid_r, mid_c = rows // 2, cols // 2
    upper_left = log_mag[0:mid_r, 0:mid_c]
    lower_right = np.flip(log_mag[mid_r + (rows % 2):, mid_c + (cols % 2):])
    if upper_left.shape != lower_right.shape:
        lower_right = cv2.resize(lower_right, (upper_left.shape[1], upper_left.shape[0]))
    return np.mean(np.abs(upper_left - lower_right))


def extract_traditional_features(image_input):
    """Score an image with three hand-crafted heuristics.

    The image is decoded with ``get_cv_image`` and converted to grayscale.
    Each statistic then adds fixed points when it crosses a hard-coded threshold:

    * LBP entropy > 3.6                        -> +0.3
    * DCT high-frequency mass ratio < 0.985    -> +0.1
    * FFT quadrant asymmetry > 13.8            -> +0.1

    Returns:
        The summed points capped at 1.0 (with the current weights the
        maximum reachable score is 0.5).
    """
    img = get_cv_image(image_input)
    # Defensive only: get_cv_image raises rather than returning None.
    if img is None:
        return 0.0
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Evaluated in the original order: LBP, then DCT, then FFT.
    rules = (
        (_lbp_entropy(gray) > 3.6, 0.3),
        (_dct_high_freq_ratio(gray) < 0.985, 0.1),
        (_fft_symmetry_error(gray) > 13.8, 0.1),
    )
    total = sum((points for fired, points in rules if fired), 0.0)
    return min(total, 1.0)
# ==========================================
# 第二部分:深度模型缓存与提取
# ==========================================
@st.cache_resource
def load_deep_image_model(weights_path='mobilenet_finetuned.pth'):
    """Build the 2-class MobileNetV2 detector and load its fine-tuned weights.

    Wrapped in ``st.cache_resource`` so Streamlit builds the model once and
    reuses it across reruns (``weights_path`` is part of the cache key).

    Args:
        weights_path: Path to the fine-tuned ``state_dict`` file. The default
            is the original relative path, resolved against the current
            working directory.

    Returns:
        ``(model, device)``: the model in eval mode on ``device``, which is
        CUDA when available and CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Architecture only: load_state_dict (strict by default) overwrites every
    # parameter, so pretrained ImageNet weights would be wasted downloads.
    model = models.mobilenet_v2(weights=None)
    # Swap the 1000-way ImageNet head for the 2-way head the checkpoint expects.
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, 2)
    # A state_dict only needs tensors and plain containers, so restrict
    # unpickling with weights_only=True. torch < 1.13 lacks the argument.
    try:
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    except TypeError:
        state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model, device
def extract_deep_features(image_input, model, device):
    """Return the model's softmax probability for class index 0.

    The image is decoded with ``get_pil_image``, resized to 224x224 (aspect
    ratio not preserved) and normalized with the ImageNet mean/std before a
    single gradient-free forward pass.

    NOTE(review): class index 0 is treated as "fake" throughout this module;
    confirm against the label order used when the weights were trained.
    """
    source = get_pil_image(image_input)
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batch = preprocess(source).unsqueeze(0).to(device)
    with torch.no_grad():
        class_probs = torch.softmax(model(batch), dim=1)
    return class_probs[0, 0].item()
# ==========================================
# 第三部分:多模态加权融合引擎
# ==========================================
def analyze_image(image_input):
    """Fuse the hand-crafted and deep-model scores into one probability.

    The final probability is a fixed linear blend:
    ``0.15 * traditional_score + 0.85 * deep_score``.

    Returns:
        A dict with keys ``'traditional_score'``, ``'deep_score'`` and
        ``'final_probability'``.
    """
    handcrafted = extract_traditional_features(image_input)
    net, dev = load_deep_image_model()
    learned = extract_deep_features(image_input, net, dev)
    blended = handcrafted * 0.15 + learned * 0.85
    return dict(
        traditional_score=handcrafted,
        deep_score=learned,
        final_probability=blended,
    )
# ==========================================
# 第四部分:XAI 热力图渲染引擎
# ==========================================
def generate_image_heatmap(image_input, model, device):
    """Overlay a Grad-CAM heatmap for class index 0 on the original image.

    Grad-CAM is computed on the 224x224 model input at ``model.features[-1]``.
    The map is then resized (interpolated, not lossless) back to the original
    width and height so the overlay lines up with the full-resolution image.
    Overlay opacity follows the activation: ``alpha = cam ** 1.5 * 0.65``.

    Args:
        image_input: Path or file-like object accepted by ``get_pil_image``.
        model: Classifier exposing ``model.features`` (e.g. the MobileNetV2
            returned by ``load_deep_image_model``).
        device: Torch device the model lives on.

    Returns:
        An RGB ``uint8`` array at the original resolution. If Grad-CAM fails,
        the error is logged and the unmodified image is returned so the UI
        can still display something.
    """
    import logging

    # Full-resolution original; only the copy fed to the model is resized.
    original_pil_img = get_pil_image(image_input)
    orig_width, orig_height = original_pil_img.size
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform(original_pil_img).unsqueeze(0).to(device)
    target_layers = [model.features[-1]]
    try:
        grayscale_cam_224 = None
        with GradCAM(model=model, target_layers=target_layers) as cam:
            targets = [ClassifierOutputTarget(0)]
            # Raw 224x224 activation map in [0, 1] for the single batch item.
            grayscale_cam_224 = cam(input_tensor=input_tensor, targets=targets)[0, :]
        # NOTE(review): the GradCAM context manager may swallow some errors
        # raised inside it (observed for IndexError in some pytorch_grad_cam
        # versions); make sure a map was actually produced.
        if grayscale_cam_224 is None:
            raise RuntimeError("Grad-CAM did not produce an activation map")
        # Stretch back to the original size; this undoes the non-uniform
        # 224x224 squash so hot spots align with the original pixels.
        grayscale_cam_resized = cv2.resize(grayscale_cam_224, (orig_width, orig_height))
        heatmap_color = cv2.applyColorMap(np.uint8(255 * grayscale_cam_resized), cv2.COLORMAP_JET)
        # applyColorMap emits BGR; convert to RGB to match the PIL image.
        heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
        heatmap_color = np.float32(heatmap_color) / 255.0
        orig_img_array = np.array(original_pil_img, dtype=np.float32) / 255.0
        # Per-pixel opacity: the 1.5 exponent fades weak activations, and the
        # 0.65 cap keeps the original visible even under the hottest regions.
        alpha = grayscale_cam_resized[..., np.newaxis]
        alpha = np.power(alpha, 1.5) * 0.65
        visualization = orig_img_array * (1 - alpha) + heatmap_color * alpha
        return np.uint8(255 * visualization)
    except Exception:
        # Best-effort visualization: fall back to the plain image so the page
        # still renders, but record why instead of failing silently.
        logging.getLogger(__name__).exception(
            "Grad-CAM heatmap generation failed; returning the original image"
        )
        return np.array(original_pil_img)