"""Single-image inference demo for Loupe.

Downloads the backbone, segmentation and classification checkpoints from the
Hugging Face Hub, builds one model per checkpoint, and exposes helpers that
run both models on an image and render the predicted mask on top of it.

NOTE: importing this module has side effects -- it downloads weights,
initializes hydra and instantiates both models at import time.
"""
import os
from datetime import datetime

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from src.lit_model import LitModel
from src.models.loupe.configuration_loupe import LoupeConfig
from src.models.loupe.image_precessing_loupe import LoupeImageProcessor
from src.models.loupe.modeling_loupe import LoupeModel

# Weights are fetched from the Hub (and cached locally by huggingface_hub).
ckpt_path = hf_hub_download(
    repo_id="xxwyyds/Loupe",
    filename="loupe_model/pretrained_weights/pe/PE-Core-L14-336.pt",
)
seg_ckpt = hf_hub_download(
    repo_id="xxwyyds/Loupe",
    filename="loupe_model/model_weigths/seg/model.safetensors",
)

# Initialize hydra
hydra.initialize(config_path="../configs", version_base=None)

# Segmentation model: "infer" config + backbone + segmentation checkpoint.
cfg = hydra.compose(config_name="infer")
cfg.model.backbone_path = ckpt_path
cfg.ckpt.checkpoint_paths = [seg_ckpt]
loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
loupe = LoupeModel(loupe_config)
model = LitModel(cfg, loupe)
processor = LoupeImageProcessor(loupe_config)

# Classification model: same config and backbone, classification checkpoint.
cls_ckpt = hf_hub_download(
    repo_id="xxwyyds/Loupe",
    filename="loupe_model/model_weigths/cls/model.safetensors",
)
cfc = hydra.compose(config_name="infer")
cfc.ckpt.checkpoint_paths = [cls_ckpt]
cfc.model.backbone_path = ckpt_path
cls_loupe_config = LoupeConfig(stage=cfc.stage.name, **cfc.model)
cls_loupe = LoupeModel(cls_loupe_config)
cls_model = LitModel(cfc, cls_loupe)

# Freshly constructed torch modules are in training mode; switch to eval so
# dropout / normalization layers behave deterministically at inference.
# NOTE(review): assumes LitModel is a torch.nn.Module (e.g. a Lightning
# module) -- confirm against src/lit_model.py.
model.eval()
cls_model.eval()

# Dev note: FFHQ sample ids looked at so far: ffhq-7, 24, 48, 51


def predict(image):
    """Predict segmentation and classification probabilities for a single image.

    Args:
        image: PIL image to analyse.

    Returns:
        A ``(seg, cls_probs)`` tuple. ``seg`` is a mode-"L" PIL image with the
        same size as ``image`` in which pixels whose post-processed label is 0
        are white (255) and all others black, or ``None`` when the configured
        stage produces no segmentation. ``cls_probs`` is the sigmoid of the
        classification logits for the image, or ``None`` when the configured
        stage produces no classification.
    """
    seg, cls_probs = None, None

    stage = cfg.stage.name
    want_seg = "seg" in stage or stage == "test"
    want_cls = "cls" in stage or stage == "test"

    inputs = processor([image], return_tensors="pt")
    with torch.no_grad():
        # Only pay for the forward passes whose output is actually used.
        outputs = model(**inputs) if want_seg else None
        outputs_cls = cls_model(**inputs) if want_cls else None

    if want_seg:
        # PIL reports (width, height); the processor is given (height, width).
        segmentation = processor.post_process_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]
        seg = Image.fromarray(
            torch.where(segmentation == 0, 255, 0).numpy().astype(np.uint8)
        ).convert("L")

    if want_cls:
        cls_probs = torch.sigmoid(outputs_cls.cls_logits).tolist()[0]

    return seg, cls_probs


def visualize_result(image, seg, mask=None, alpha=0.5):
    """Overlay the predicted mask (and optional ground truth) on the image.

    Color scheme:
    - With mask: TP (green), FP (red), FN (blue)
    - Without mask: predicted region (green)

    Args:
        image: Original PIL image; converted to RGBA for compositing.
        seg: Predicted mask as a PIL image; values above 127 (0.5 after
            scaling to [0, 1]) count as predicted positives. Must have the
            same size as ``image``.
        mask: Optional ground-truth mask with the same size as ``seg``,
            thresholded the same way.
        alpha: Opacity in [0, 1] applied to every colored overlay pixel.

    Returns:
        An RGBA PIL image: ``image`` with the colored overlay composited on top.
    """
    seg_np = np.array(seg, dtype=np.float32) / 255.0
    pred = seg_np > 0.5

    overlay = np.zeros((*seg_np.shape, 3))

    if mask is not None:
        mask_np = np.array(mask, dtype=np.float32) / 255.0
        truth = mask_np > 0.5

        overlay[pred & truth] = [0, 1, 0]   # Green for TP
        overlay[pred & ~truth] = [1, 0, 0]  # Red for FP
        overlay[truth & ~pred] = [0, 0, 1]  # Blue for FN
        # FN pixels lie outside the prediction, so the alpha layer must cover
        # the ground truth as well or the blue regions would be invisible.
        visible = pred | truth
    else:
        overlay[pred] = [0, 1, 0]  # Green for predicted
        visible = pred

    overlay_img = Image.fromarray((overlay * 255).astype(np.uint8)).convert("RGBA")
    # A 2-D uint8 array is interpreted as mode "L" without an explicit mode.
    alpha_layer = Image.fromarray((visible * alpha * 255).astype(np.uint8))
    overlay_img.putalpha(alpha_layer)

    base_img = image.convert("RGBA")
    return Image.alpha_composite(base_img, overlay_img)


def process_single_image(image_path, mask_path=None):
    """Run inference on one image file, then display and save the result.

    Args:
        image_path: Path of the image to analyse.
        mask_path: Optional path of a ground-truth mask; when given, the mask
            is shown next to the result and used for TP/FP/FN coloring.

    Returns:
        A ``(result, cls_probs)`` tuple: the rendered RGBA image and the
        classification output of :func:`predict` (``None`` if unavailable).

    Side effects:
        Prints the classification probability, writes the rendered image to
        ``outputs/<name>_result_<timestamp>.png`` and opens a matplotlib window.
    """
    # Load images
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path).convert("L") if mask_path else None

    # Get predictions
    seg, cls_probs = predict(image)
    print(f"Classification probability: {cls_probs[0]:.4f}" if cls_probs else "No cls output")

    # Visualize; a stage without segmentation yields seg=None, in which case
    # there is nothing to overlay and the plain image is shown instead.
    if seg is not None:
        result = visualize_result(image, seg, mask)
    else:
        result = image.convert("RGBA")

    # Display
    plt.figure(figsize=(10, 5))
    if mask is not None:
        plt.subplot(1, 2, 1)
        plt.imshow(mask, cmap="gray")
        plt.title("Ground Truth Mask")
        plt.axis("off")
        plt.subplot(1, 2, 2)
    plt.imshow(result)
    title = "Detection Result (With Mask)" if mask is not None else "Detection Result"
    plt.title(title)
    plt.axis("off")

    # Save result
    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_dir, f"{filename}_result_{timestamp}.png")
    result.save(output_path)
    print(f"Result saved to {output_path}")

    plt.show()
    return result, cls_probs


if __name__ == "__main__":
    # Example usage:
    # Case 1: With mask
    # process_single_image("tampered_image.png", "tampered_mask.png")
    # Case 2: Without mask
    process_single_image("ffhq/ffhq-0001.png")