# Loupe / src/predict.py
# (Hugging Face file-page residue: uploader avatar "xxwyyds's picture";
#  last commit "Update src/predict.py", 473ac76, verified)
import hydra
import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime
from src.models.loupe.configuration_loupe import LoupeConfig
from src.models.loupe.modeling_loupe import LoupeModel
from src.models.loupe.image_precessing_loupe import LoupeImageProcessor
from src.lit_model import LitModel
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Module-level setup: everything below runs at import time. It fetches three
# files from the "xxwyyds/Loupe" Hugging Face repo and builds two LitModel
# instances (one configured with the segmentation checkpoint, one with the
# classification checkpoint) that predict() uses as globals.
# ---------------------------------------------------------------------------

# Backbone weights; the returned value is assigned to `backbone_path` in both
# configs below.
ckpt_path = hf_hub_download(
repo_id="xxwyyds/Loupe",
filename="loupe_model/pretrained_weights/pe/PE-Core-L14-336.pt"
)
# Segmentation checkpoint ("model_weigths" is the spelling used in the repo path).
seg_ckpt = hf_hub_download(repo_id="xxwyyds/Loupe", filename="loupe_model/model_weigths/seg/model.safetensors")
# Initialize Hydra with the "../configs" directory; this must happen before
# any hydra.compose() call below.
hydra.initialize(config_path="../configs", version_base=None)
# Compose the "infer" config for the segmentation model and point it at the
# downloaded backbone.
cfg = hydra.compose(config_name="infer")
cfg.model.backbone_path = ckpt_path
# Segmentation checkpoint; local path noted by the author:
# /home/xxw/Loupe/model_weigths/seg/model.safetensors
cfg.ckpt.checkpoint_paths = [seg_ckpt]
# Segmentation pipeline: config -> LoupeModel -> LitModel wrapper, plus the
# image processor that predict() uses for both pre- and post-processing.
loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
loupe = LoupeModel(loupe_config)
model = LitModel(cfg, loupe)
processor = LoupeImageProcessor(loupe_config)
# Classification checkpoint; local path noted by the author:
# /home/xxw/Loupe/model_weigths/cls/model.safetensors
cls_ckpt = hf_hub_download(repo_id="xxwyyds/Loupe", filename="loupe_model/model_weigths/cls/model.safetensors")
# A second, independent composition of the same "infer" config so that the
# classification model gets its own `checkpoint_paths` without touching `cfg`.
cfc = hydra.compose(config_name="infer")
cfc.ckpt.checkpoint_paths = [cls_ckpt]
cfc.model.backbone_path = ckpt_path
# Classification pipeline, built the same way as the segmentation one.
cls_loupe_config = LoupeConfig(stage=cfc.stage.name, **cfc.model)
cls_loupe = LoupeModel(cls_loupe_config)
cls_model = LitModel(cfc, cls_loupe)
# ffhq-7, 24, 48, 51
def predict(image):
    """Predict the segmentation mask and classification probabilities for one image.

    Uses the module-level ``processor``, ``model`` (segmentation checkpoint)
    and ``cls_model`` (classification checkpoint). Which outputs are produced
    is decided by ``cfg.stage.name``: a name containing "seg" enables the
    mask, a name containing "cls" enables the probabilities, and "test"
    enables both.

    Args:
        image: PIL image to analyse. ``image.size`` is reversed and passed as
            the target size for segmentation post-processing.

    Returns:
        A ``(seg, cls_probs)`` tuple.
        ``seg`` is a mode-"L" PIL image that is 255 wherever the
        post-processed segmentation equals 0 and 0 elsewhere, or None when
        the stage does not enable segmentation.
        ``cls_probs`` is the first batch entry of the sigmoid of
        ``cls_logits`` converted with ``tolist()``, or None when the stage
        does not enable classification.
    """
    stage_name = cfg.stage.name
    want_seg = "seg" in stage_name or stage_name == "test"
    want_cls = "cls" in stage_name or stage_name == "test"

    seg, cls_probs = None, None
    inputs = processor([image], return_tensors="pt")

    # NOTE(review): nothing in the visible code switches the models to
    # eval() mode -- confirm LitModel / checkpoint loading takes care of it.
    with torch.no_grad():
        # Only run the forward pass whose output is actually consumed below.
        outputs = model(**inputs) if want_seg else None
        outputs_cls = cls_model(**inputs) if want_cls else None

    if want_seg:
        segmentation = processor.post_process_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]
        # The previous debug `print(seg)` ran before `seg` was assigned and
        # therefore always printed "None"; it has been removed.
        binary = torch.where(segmentation == 0, 255, 0)
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
        seg = Image.fromarray(binary.cpu().numpy().astype(np.uint8)).convert("L")

    if want_cls:
        cls_probs = torch.sigmoid(outputs_cls.cls_logits).tolist()[0]

    return seg, cls_probs
def visualize_result(image, seg, mask=None, alpha=0.5):
    """Overlay the detection result on the original image.

    Colour scheme:
    - With mask: TP (green), FP (red), FN (blue)
    - Without mask: predicted region (green)

    Args:
        image: PIL image used as the background; converted to RGBA.
        seg: Predicted mask (PIL image or array-like) on a 0-255 scale;
            values above half scale count as predicted.
        mask: Optional ground-truth mask on the same 0-255 scale and with the
            same height/width as ``seg``.
        alpha: Opacity in [0, 1] applied to every coloured pixel.

    Returns:
        An RGBA PIL image: ``image`` with the coloured overlay composited on
        top.

    Raises:
        ValueError: If ``mask`` is given and its shape differs from ``seg``.
    """
    # Binarise on the normalised [0, 1] scale.
    predicted = (np.array(seg, dtype=np.float32) / 255.0) > 0.5

    overlay = np.zeros((*predicted.shape, 3), dtype=np.uint8)
    if mask is not None:
        truth = (np.array(mask, dtype=np.float32) / 255.0) > 0.5
        if truth.shape != predicted.shape:
            raise ValueError(
                f"mask shape {truth.shape} does not match "
                f"segmentation shape {predicted.shape}"
            )
        overlay[predicted & truth] = (0, 255, 0)    # Green for TP
        overlay[predicted & ~truth] = (255, 0, 0)   # Red for FP
        overlay[truth & ~predicted] = (0, 0, 255)   # Blue for FN
        # FN pixels lie outside the prediction, so the opacity mask must
        # include the ground truth as well; otherwise the blue FN regions
        # are painted but stay fully transparent.
        highlighted = predicted | truth
    else:
        overlay[predicted] = (0, 255, 0)            # Green for predicted
        highlighted = predicted

    # A 2-D uint8 array is interpreted as mode "L" without an explicit mode.
    alpha_layer = Image.fromarray((highlighted * alpha * 255).astype(np.uint8))

    # Composite the semi-transparent overlay onto the original image.
    overlay_img = Image.fromarray(overlay).convert("RGBA")
    overlay_img.putalpha(alpha_layer)
    return Image.alpha_composite(image.convert("RGBA"), overlay_img)
def process_single_image(image_path, mask_path=None):
    """Run prediction on one image file, then show and save the overlay.

    Args:
        image_path: Path of the image to analyse; it is loaded as RGB.
        mask_path: Optional path of a ground-truth mask; it is loaded as
            single-channel ("L"). When given, the mask is shown next to the
            result and the overlay distinguishes TP/FP/FN.

    Returns:
        A ``(result, cls_probs)`` tuple: the overlay image returned by
        ``visualize_result`` and the classification probabilities returned by
        ``predict`` (None when classification is not enabled).

    Raises:
        RuntimeError: If ``predict`` produced no segmentation, because there
            would be nothing to visualise.

    Side effects:
        Prints the classification probability and the output path, writes
        ``outputs/<image name>_result_<timestamp>.png`` and opens a
        matplotlib window.
    """
    # Load images
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path).convert("L") if mask_path else None

    # Get predictions
    seg, cls_probs = predict(image)
    print(f"Classification probability: {cls_probs[0]:.4f}" if cls_probs else "No cls output")
    if seg is None:
        # Fail with a clear message instead of an obscure error from the
        # array handling inside visualize_result.
        raise RuntimeError("predict() returned no segmentation to visualize")

    # Visualize
    result = visualize_result(image, seg, mask)

    # One panel per image actually shown, so that the no-mask case does not
    # leave an empty half-figure.
    panels = []
    if mask is not None:
        panels.append((mask, "Ground Truth Mask", "gray"))
        panels.append((result, "Detection Result (With Mask)", None))
    else:
        panels.append((result, "Detection Result", None))

    fig, axes = plt.subplots(1, len(panels), figsize=(10, 5), squeeze=False)
    for ax, (picture, title, cmap) in zip(axes[0], panels):
        ax.imshow(picture, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    # Save result
    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_dir, f"{filename}_result_{timestamp}.png")
    result.save(output_path)
    print(f"Result saved to {output_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return result, cls_probs
if __name__ == "__main__":
    # Script entry point: run the pipeline on a sample image.
    #
    # To compare against a ground-truth mask, pass it as the second argument:
    #   process_single_image("tampered_image.png", "tampered_mask.png")
    #
    # Default run: no mask, only the predicted region is highlighted.
    process_single_image(image_path="ffhq/ffhq-0001.png", mask_path=None)