# Page-header residue from the Spaces file view (status: Running; file size: 5,377 bytes).
import hydra
import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime
from src.models.loupe.configuration_loupe import LoupeConfig
from src.models.loupe.modeling_loupe import LoupeModel
from src.models.loupe.image_precessing_loupe import LoupeImageProcessor
from src.lit_model import LitModel
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Module-level setup: everything below runs at import time. It downloads
# weights, initializes Hydra, and builds the two models plus the image
# processor that `predict` uses as module-level globals.
# ---------------------------------------------------------------------------

# Fetch the backbone weights from the "xxwyyds/Loupe" Hub repo; the returned
# value is used as `backbone_path` for both configurations below.
ckpt_path = hf_hub_download(
    repo_id="xxwyyds/Loupe",
    filename="loupe_model/pretrained_weights/pe/PE-Core-L14-336.pt"
)
# Fetch the segmentation checkpoint from the same repo.
seg_ckpt = hf_hub_download(repo_id="xxwyyds/Loupe", filename="loupe_model/model_weigths/seg/model.safetensors")
# Initialize hydra (global, one-time initialization for this process).
hydra.initialize(config_path="../configs", version_base=None)
# Load model configuration from the "infer" config and point it at the
# downloaded backbone and segmentation checkpoint.
cfg = hydra.compose(config_name="infer")
cfg.model.backbone_path = ckpt_path
# Original author note (local path of the same file):
# seg:/home/xxw/Loupe/model_weigths/seg/model.safetensors
cfg.ckpt.checkpoint_paths = [seg_ckpt]
# Build the segmentation model; `cfg.model` is expanded into LoupeConfig
# keyword arguments.
loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
loupe = LoupeModel(loupe_config)
# NOTE(review): presumably LitModel loads `cfg.ckpt.checkpoint_paths` into the
# wrapped model — confirm in src/lit_model. Neither model is explicitly put
# into eval mode here; verify that LitModel handles this.
model = LitModel(cfg, loupe)
# The processor is built from the segmentation config only; `predict` feeds
# its output to both `model` and `cls_model`.
processor = LoupeImageProcessor(loupe_config)
# Original author note (local path of the same file):
# cls:/home/xxw/Loupe/model_weigths/cls/model.safetensors
cls_ckpt = hf_hub_download(repo_id="xxwyyds/Loupe", filename="loupe_model/model_weigths/cls/model.safetensors")
# Second composition of the same "infer" config, this time pointed at the
# classification checkpoint. (`cfc` is the classification counterpart of `cfg`.)
cfc = hydra.compose(config_name="infer")
cfc.ckpt.checkpoint_paths = [cls_ckpt]
cfc.model.backbone_path = ckpt_path
# Build the classification model with the same backbone weights.
cls_loupe_config = LoupeConfig(stage=cfc.stage.name, **cfc.model)
cls_loupe = LoupeModel(cls_loupe_config)
cls_model = LitModel(cfc, cls_loupe)
# ffhq-7, 24, 48, 51
def predict(image):
    """Predict segmentation and classification probabilities for a single image.

    Uses the module-level ``processor``, ``model`` (segmentation) and
    ``cls_model`` (classification); which outputs are produced is decided by
    ``cfg.stage.name``.

    Args:
        image: Input image. It must expose a PIL-style ``size`` attribute;
            the reversed ``size`` is passed as the post-processing target
            size.

    Returns:
        A ``(seg, cls_probs)`` tuple:

        * ``seg`` -- an 8-bit grayscale PIL image that is 255 wherever the
          post-processed segmentation equals 0 and 0 elsewhere, or ``None``
          when the stage name neither contains ``"seg"`` nor equals
          ``"test"``.
        * ``cls_probs`` -- the first batch entry of
          ``sigmoid(cls_logits).tolist()``, or ``None`` when the stage name
          neither contains ``"cls"`` nor equals ``"test"``.
    """
    seg, cls_probs = None, None
    inputs = processor([image], return_tensors="pt")

    # Both models share the same preprocessed inputs; no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
        outputs_cls = cls_model(**inputs)

    stage = cfg.stage.name
    if "seg" in stage or stage == "test":
        segmentation = processor.post_process_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]
        # Label 0 is rendered white (255), everything else black.
        # NOTE(review): presumably label 0 marks the forged region -- confirm
        # against the processor's post-processing.
        seg = Image.fromarray(
            torch.where(segmentation == 0, 255, 0).numpy().astype(np.uint8)
        ).convert("L")
    if "cls" in stage or stage == "test":
        cls_probs = torch.sigmoid(outputs_cls.cls_logits).tolist()[0]
    return seg, cls_probs
def visualize_result(image, seg, mask=None, alpha=0.5):
    """Overlay the detection result on top of the original image.

    Color scheme:
    - With mask: TP (green), FP (red), FN (blue)
    - Without mask: predicted forgery (green)

    Args:
        image: Original PIL image; it is converted to RGBA before compositing.
            NOTE: it is composited with an overlay of ``seg``'s shape, so it
            is expected to have the same size as ``seg``.
        seg: Predicted mask with values in 0..255; pixels whose value divided
            by 255 exceeds 0.5 count as predicted positives.
        mask: Optional ground-truth mask, thresholded the same way as ``seg``
            and expected to have the same shape.
        alpha: Opacity (0..1) applied to every highlighted pixel.

    Returns:
        An RGBA PIL image: ``image`` with the colored overlay blended in.
    """
    # Binarize the prediction (and the ground truth, when given).
    predicted = np.array(seg, dtype=np.float32) / 255.0 > 0.5

    # RGB overlay in the 0..1 range; untouched pixels stay black but are
    # fully transparent in the alpha layer below.
    overlay = np.zeros((*predicted.shape, 3))
    if mask is not None:
        truth = np.array(mask, dtype=np.float32) / 255.0 > 0.5
        overlay[predicted & truth] = [0, 1, 0]   # Green for TP
        overlay[predicted & ~truth] = [1, 0, 0]  # Red for FP
        overlay[truth & ~predicted] = [0, 0, 1]  # Blue for FN
        # FN pixels lie outside the prediction, so the visible region must be
        # the union of prediction and ground truth; using the prediction
        # alone would leave the blue FN pixels fully transparent.
        highlighted = predicted | truth
    else:
        overlay[predicted] = [0, 1, 0]  # Green for predicted
        highlighted = predicted

    # Build the RGBA overlay: colors from `overlay`, opacity from `highlighted`.
    overlay_img = Image.fromarray((overlay * 255).astype(np.uint8)).convert("RGBA")
    alpha_layer = Image.fromarray((highlighted * alpha * 255).astype(np.uint8), "L")
    overlay_img.putalpha(alpha_layer)

    # Composite with the original image.
    return Image.alpha_composite(image.convert("RGBA"), overlay_img)
def process_single_image(image_path, mask_path=None):
    """Run detection on one image, display the result, and save it to disk.

    Args:
        image_path: Path of the image to analyze; it is loaded as RGB.
        mask_path: Optional path of a ground-truth mask; it is loaded as
            grayscale. When falsy, no mask is used.

    Returns:
        A ``(result, cls_probs)`` tuple: the overlay image produced by
        ``visualize_result`` and the classification output of ``predict``.

    Side effects:
        Prints the classification probability, writes
        ``outputs/<image stem>_result_<timestamp>.png`` (creating ``outputs``
        if needed), and shows a matplotlib figure.
    """
    # Load images
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path).convert("L") if mask_path else None
    has_mask = mask is not None

    # Get predictions
    seg, cls_probs = predict(image)
    print(f"Classification probability: {cls_probs[0]:.4f}" if cls_probs else "No cls output")

    # Visualize
    result = visualize_result(image, seg, mask)

    # Display: two panels (ground truth | result) when a mask is given,
    # otherwise a single panel so no half of the figure is left blank.
    plt.figure(figsize=(10, 5))
    if has_mask:
        plt.subplot(1, 2, 1)
        plt.imshow(mask, cmap='gray')
        plt.title("Ground Truth Mask")
        plt.axis('off')
        plt.subplot(1, 2, 2)
    plt.imshow(result)
    # Explicit flag instead of the truthiness of a PIL image object.
    plt.title("Detection Result (With Mask)" if has_mask else "Detection Result")
    plt.axis('off')

    # Save result (timestamped so repeated runs do not overwrite each other)
    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_dir, f"{filename}_result_{timestamp}.png")
    result.save(output_path)
    print(f"Result saved to {output_path}")

    plt.show()
    return result, cls_probs
if __name__ == "__main__":
    # Example usage:
    # Case 1: With mask
    #   process_single_image("tampered_image.png", "tampered_mask.png")
    # Case 2: Without mask
    process_single_image("ffhq/ffhq-0001.png")