Spaces:
Running
Running
| import hydra | |
| import os | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from datetime import datetime | |
| from src.models.loupe.configuration_loupe import LoupeConfig | |
| from src.models.loupe.modeling_loupe import LoupeModel | |
| from src.models.loupe.image_precessing_loupe import LoupeImageProcessor | |
| from src.lit_model import LitModel | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository that hosts the backbone and both task checkpoints.
_HF_REPO_ID = "xxwyyds/Loupe"


def _compose_infer_cfg(backbone_path, checkpoint_path):
    """Compose the ``infer`` Hydra config and point it at the given weight files.

    Args:
        backbone_path: Local path stored in ``model.backbone_path``.
        checkpoint_path: Local path stored (as a one-element list) in
            ``ckpt.checkpoint_paths``.

    Returns:
        The composed config object with both fields overridden.
    """
    composed = hydra.compose(config_name="infer")
    composed.model.backbone_path = backbone_path
    composed.ckpt.checkpoint_paths = [checkpoint_path]
    return composed


# Fetch the shared backbone weights and the segmentation checkpoint.
ckpt_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="loupe_model/pretrained_weights/pe/PE-Core-L14-336.pt",
)
seg_ckpt = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="loupe_model/model_weigths/seg/model.safetensors",
)

# Hydra must be initialised once before any config can be composed.
hydra.initialize(config_path="../configs", version_base=None)

# Segmentation model plus the image processor shared by both models.
# NOTE(review): presumably LitModel loads the weights listed in
# cfg.ckpt.checkpoint_paths -- confirm in src/lit_model.py.
cfg = _compose_infer_cfg(ckpt_path, seg_ckpt)
loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
loupe = LoupeModel(loupe_config)
model = LitModel(cfg, loupe)
processor = LoupeImageProcessor(loupe_config)

# Classification model: same backbone, different task checkpoint.
cls_ckpt = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="loupe_model/model_weigths/cls/model.safetensors",
)
cfc = _compose_infer_cfg(ckpt_path, cls_ckpt)
cls_loupe_config = LoupeConfig(stage=cfc.stage.name, **cfc.model)
cls_loupe = LoupeModel(cls_loupe_config)
cls_model = LitModel(cfc, cls_loupe)
| # ffhq-7, 24, 48, 51 | |
def predict(image):
    """Predict the segmentation mask and classification probabilities for one image.

    Uses the module-level ``processor``, ``model`` (segmentation) and
    ``cls_model`` (classification). Which outputs are produced is decided by
    ``cfg.stage.name``: a stage containing ``"seg"`` enables segmentation, one
    containing ``"cls"`` enables classification, and ``"test"`` enables both.

    Args:
        image: PIL image. Its ``size`` (width, height) is reversed to give the
            (height, width) target size for segmentation post-processing.

    Returns:
        A ``(seg, cls_probs)`` tuple:
          * ``seg`` -- mode ``"L"`` PIL image that is 255 where the
            post-processed segmentation equals 0 and 0 elsewhere, or ``None``
            when the stage does not include segmentation.
          * ``cls_probs`` -- sigmoid of ``cls_logits`` for the first (only)
            batch item, or ``None`` when the stage does not include
            classification.
    """
    stage = cfg.stage.name
    want_seg = "seg" in stage or stage == "test"
    want_cls = "cls" in stage or stage == "test"

    seg, cls_probs = None, None
    inputs = processor([image], return_tensors="pt")

    with torch.no_grad():
        # Only run the forward pass whose output is actually consumed.
        if want_seg:
            outputs = model(**inputs)
            segmentation = processor.post_process_segmentation(
                outputs, target_sizes=[image.size[::-1]]
            )[0]
            # Label 0 is rendered white (255); every other label is black.
            binary = torch.where(segmentation == 0, 255, 0)
            # .cpu() is a no-op on CPU tensors and keeps .numpy() valid if the
            # model is ever moved to an accelerator.
            seg = Image.fromarray(
                binary.cpu().numpy().astype(np.uint8)
            ).convert("L")

        if want_cls:
            outputs_cls = cls_model(**inputs)
            cls_probs = torch.sigmoid(outputs_cls.cls_logits).tolist()[0]

    return seg, cls_probs
def visualize_result(image, seg, mask=None, alpha=0.5):
    """Overlay the predicted (and optionally ground-truth) regions on an image.

    Colour scheme:
      * With ``mask``: true positives green, false positives red,
        false negatives blue.
      * Without ``mask``: predicted regions green.

    Args:
        image: PIL image used as the background; it must have the same
            width/height as ``seg`` (``Image.alpha_composite`` requires it).
        seg: Predicted mask (PIL ``"L"`` image or 2-D array) with values in
            0..255; pixels above 127.5 count as predicted-positive.
        mask: Optional ground-truth mask in the same value convention. A PIL
            mask whose size differs from ``seg`` is resized with
            nearest-neighbour sampling so it can be compared pixel-wise.
        alpha: Opacity in [0, 1] applied to every coloured pixel.

    Returns:
        RGBA PIL image: ``image`` with the coloured overlay composited on top.
    """
    seg_hit = np.asarray(seg, dtype=np.float32) / 255.0 > 0.5
    overlay = np.zeros((*seg_hit.shape, 3), dtype=np.uint8)

    if mask is not None:
        height, width = seg_hit.shape[:2]
        if isinstance(mask, Image.Image) and mask.size != (width, height):
            # Nearest-neighbour keeps the mask binary while matching shapes.
            mask = mask.resize((width, height), Image.NEAREST)
        mask_hit = np.asarray(mask, dtype=np.float32) / 255.0 > 0.5

        overlay[seg_hit & mask_hit] = (0, 255, 0)    # true positive: green
        overlay[seg_hit & ~mask_hit] = (255, 0, 0)   # false positive: red
        overlay[mask_hit & ~seg_hit] = (0, 0, 255)   # false negative: blue
    else:
        overlay[seg_hit] = (0, 255, 0)               # prediction: green

    # Every coloured pixel must be visible. Deriving the alpha layer from the
    # overlay itself (rather than from the prediction alone) is what makes the
    # false-negative pixels show up.
    coloured = overlay.any(axis=-1)
    alpha_layer = Image.fromarray((coloured * alpha * 255).astype(np.uint8))

    overlay_img = Image.fromarray(overlay).convert("RGBA")
    overlay_img.putalpha(alpha_layer)

    base_img = image.convert("RGBA")
    return Image.alpha_composite(base_img, overlay_img)
def process_single_image(image_path, mask_path=None):
    """Run detection on one image file, then display and save the overlay.

    Args:
        image_path: Path of the image to analyse; it is loaded as RGB.
        mask_path: Optional path of a ground-truth mask; it is loaded as
            single-channel ("L"). When given, the mask is shown next to the
            result and the overlay distinguishes TP/FP/FN.

    Returns:
        ``(result, cls_probs)`` -- the composited overlay image returned by
        ``visualize_result`` and the classification output of ``predict``.

    Side effects:
        Prints the classification probability, writes the overlay to
        ``outputs/<image stem>_result_<timestamp>.png`` (creating the
        directory if needed) and shows a matplotlib figure.
    """
    # Load images
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path).convert("L") if mask_path else None
    has_mask = mask is not None

    # Get predictions
    seg, cls_probs = predict(image)
    print(f"Classification probability: {cls_probs[0]:.4f}" if cls_probs else "No cls output")

    # Visualize
    result = visualize_result(image, seg, mask)

    # Display: two panels when a ground-truth mask exists, otherwise one, so
    # the figure has no empty half.
    n_panels = 2 if has_mask else 1
    fig = plt.figure(figsize=(10, 5))
    if has_mask:
        plt.subplot(1, n_panels, 1)
        plt.imshow(mask, cmap='gray')
        plt.title("Ground Truth Mask")
        plt.axis('off')
    plt.subplot(1, n_panels, n_panels)
    plt.imshow(result)
    # Explicit flag instead of relying on the truthiness of a PIL image.
    plt.title("Detection Result (With Mask)" if has_mask else "Detection Result")
    plt.axis('off')

    # Save result
    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_dir, f"{filename}_result_{timestamp}.png")
    result.save(output_path)
    print(f"Result saved to {output_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures on
    # non-interactive backends.
    plt.close(fig)
    return result, cls_probs
if __name__ == "__main__":
    # Demo entry point; runs only when the file is executed as a script.
    #
    # To compare against a ground-truth mask (TP/FP/FN colouring), call:
    #     process_single_image("tampered_image.png", "tampered_mask.png")
    #
    # Default demo: prediction-only overlay, no mask.
    process_single_image(image_path="ffhq/ffhq-0001.png")