dftest1 / src /features /noiseprint_wrapper.py
akcanca's picture
Create noiseprint_wrapper.py
5496367 verified
"""
Wrapper for Noiseprint that allows QF override to handle Gradio PIL Images properly.
"""
import numpy as np
import torch
from torchvision import transforms
from src.features.noiseprint.Noiseprint import load_noiseprint_model
from src.features.noiseprint.utilityRead import imread2f_pil
def getNoiseprint_with_qf(image_path, qf_override=None):
    """
    Compute the noiseprint of an image, optionally forcing the JPEG quality factor.

    The image is read with ``imread2f_pil(image_path, channel=1)``. The quality
    factor (QF) used to select the model is ``qf_override`` when given;
    otherwise it is read from the file with ``jpeg_qtableinv`` and falls back
    to 101 if that detection raises.

    Images with more than ``largeLimit`` pixels are processed in tiles of
    ``slide`` x ``slide`` pixels, each padded by ``overlap`` pixels of
    surrounding context that is cropped away again before the tile result is
    written into the output. Smaller images go through the network in one pass.

    Args:
        image_path: Image source, passed unchanged to ``imread2f_pil`` and
            (when no override is given) to ``jpeg_qtableinv``.
        qf_override: Optional QF value to use instead of detecting from file.

    Returns:
        (img, noiseprint) tuple. ``img`` is the array returned by
        ``imread2f_pil``; ``noiseprint`` is a NumPy array built from the
        network output.
    """
    # NOTE(review): the code below indexes img as a 2-D (rows, cols) array;
    # presumably channel=1 yields a single-channel float image -- confirm
    # against imread2f_pil.
    img, _mode = imread2f_pil(image_path, channel=1)

    slide = 1024          # tile edge length in pixels
    largeLimit = 1050000  # pixel count above which the image is tiled
    overlap = 34          # context margin added around each tile
    transform = transforms.ToTensor()

    # Use override QF if provided, otherwise detect from file.
    if qf_override is not None:
        QF = qf_override
    else:
        from src.features.noiseprint.utilityRead import jpeg_qtableinv
        try:
            QF = jpeg_qtableinv(image_path)
        except Exception:
            # Best-effort detection: any failure falls back to QF 101.
            # `Exception` (not a bare except) so that KeyboardInterrupt and
            # SystemExit still propagate.
            QF = 101

    net = load_noiseprint_model(QF)
    # The model is not moved inside this function, so look its device up once.
    device = next(net.parameters()).device

    def _predict(patch):
        # Run one 2-D patch through the network and return channel 0 of
        # batch item 0 as a NumPy array.
        tensor_image = transform(patch)
        tensor_image = tensor_image.reshape(
            1, 1, tensor_image.shape[1], tensor_image.shape[2]
        )
        return net(tensor_image.to(device))[0][0].cpu().numpy()

    height, width = img.shape[0], img.shape[1]

    with torch.no_grad():
        if height * width <= largeLimit:
            return img, _predict(img)

        print(' %dx%d large %3d' % (height, width, QF))
        res = np.zeros((height, width), np.float32)

        for index0 in range(0, height, slide):
            row_start = max(index0 - overlap, 0)
            row_end = min(index0 + slide + overlap, height)
            for index1 in range(0, width, slide):
                col_start = max(index1 - overlap, 0)
                col_end = min(index1 + slide + overlap, width)

                resB = _predict(img[row_start:row_end, col_start:col_end])

                # Tiles that do not touch the top/left image border carry
                # `overlap` rows/columns of leading context; drop them.
                if index0 > 0:
                    resB = resB[overlap:, :]
                if index1 > 0:
                    resB = resB[:, overlap:]
                # Drop trailing context; slicing clamps to the actual size.
                resB = resB[:slide, :slide]

                # The destination slice is clamped to the image bounds by
                # NumPy slicing semantics.
                res[index0:index0 + slide, index1:index1 + slide] = resB

    return img, res