# NOTE: the "Spaces: Sleeping" lines below this file's original header were
# UI residue from the Hugging Face Spaces page this script was copied from;
# converted to a comment so the file parses as Python.
import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms

# You'll need to have the DiTo codebase available
import models
class DiToInference:
    """Run image reconstruction through a trained DiTo autoencoder.

    Pipeline: preprocess -> encode -> decode -> diffusion render, plus
    helpers to save single reconstructions, side-by-side comparisons and
    whole-folder batches.
    """

    # Edge length (pixels) used by the resize/center-crop preprocessing.
    IMAGE_SIZE = 256

    def __init__(self, checkpoint_path, device='cuda'):
        """Load the DiTo model from *checkpoint_path* onto *device*.

        The checkpoint must contain 'config' (convertible by OmegaConf,
        with a 'model' entry consumable by ``models.make``) and the
        weights under ``ckpt['model']['sd']``.
        """
        self.device = device

        print(f"Loading checkpoint from {checkpoint_path}")
        # NOTE(security): torch.load unpickles arbitrary Python objects —
        # only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location='cpu')

        # Rebuild the model exactly as it was configured at training time.
        self.config = OmegaConf.create(ckpt['config'])
        self.model = models.make(self.config['model'])
        self.model.load_state_dict(ckpt['model']['sd'])
        self.model = self.model.to(device)
        self.model.eval()

        # Preprocess: 256x256 resize/crop, then map RGB from [0, 1] to
        # [-1, 1] via mean/std 0.5 — presumably the normalization used at
        # training time; confirm against the training config if results
        # look off.
        self.transform = transforms.Compose([
            transforms.Resize(self.IMAGE_SIZE),
            transforms.CenterCrop(self.IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        print("Model loaded successfully!")

    def reconstruct_image(self, image_path, debug=True):
        """Return the model's reconstruction of the image at *image_path*.

        Returns a float tensor of shape (1, 3, 256, 256) with values in
        [0, 1]. When *debug* is true, the resized/cropped input is also
        written to 'debug_1_resized_cropped.png' for visual inspection.
        """
        image = Image.open(image_path).convert('RGB')

        if debug:
            # Save what the network actually sees, pre-normalization.
            debug_transform = transforms.Compose([
                transforms.Resize(self.IMAGE_SIZE),
                transforms.CenterCrop(self.IMAGE_SIZE),
            ])
            debug_transform(image).save('debug_1_resized_cropped.png')
            print("Saved debug_1_resized_cropped.png")

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Encode to latent; decode to features (identity in DiTo).
            z = self.model.encode(image_tensor)
            z_dec = self.model.decode(z)
            print('z_dec.shape:', z_dec.shape)

            # Based on the training code, coord and scale are dummy
            # zero-filled grids.
            b, c, h, w = image_tensor.shape
            coord = torch.zeros(b, 2, h, w, device=self.device)
            scale = torch.zeros(b, 2, h, w, device=self.device)

            # Render pixels with the diffusion decoder.
            reconstructed = self.model.render(z_dec, coord, scale)

        # Map the model output from [-1, 1] back to [0, 1].
        return (reconstructed * 0.5 + 0.5).clamp(0, 1)

    def save_reconstruction(self, image_path, output_path):
        """Reconstruct *image_path* and write the result to *output_path*."""
        reconstructed = self.reconstruct_image(image_path)
        to_pil = transforms.ToPILImage()
        to_pil(reconstructed.squeeze(0).cpu()).save(output_path)
        print(f"Saved reconstruction to {output_path}")

    def compare_reconstruction(self, image_path, output_path):
        """Write original and reconstruction side by side to *output_path*."""
        reconstructed = self.reconstruct_image(image_path)

        # Re-load the original with the same geometry but no normalization,
        # so it is already in [0, 1] like the reconstruction.
        original = Image.open(image_path).convert('RGB')
        original = transforms.Compose([
            transforms.Resize(self.IMAGE_SIZE),
            transforms.CenterCrop(self.IMAGE_SIZE),
            transforms.ToTensor()
        ])(original).unsqueeze(0)

        # Concatenate along width (dim=3): original left, reconstruction right.
        comparison = torch.cat([original, reconstructed.cpu()], dim=3)
        to_pil = transforms.ToPILImage()
        to_pil(comparison.squeeze(0)).save(output_path)
        print(f"Saved comparison to {output_path}")

    def batch_reconstruct(self, image_folder, output_folder, max_images=None):
        """Reconstruct every image in *image_folder* into *output_folder*.

        Accepts .png/.jpg/.jpeg in any letter case; files are processed in
        sorted order for deterministic runs. *max_images* caps how many are
        processed (None = no limit).
        """
        image_folder = Path(image_folder)
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)

        # One directory pass with a case-insensitive suffix check: also
        # picks up '.PNG'/'.JPG' files the original per-pattern globs
        # missed, and sorting removes filesystem-dependent ordering.
        image_paths = sorted(
            p for p in image_folder.iterdir()
            if p.is_file() and p.suffix.lower() in {'.png', '.jpg', '.jpeg'}
        )
        # Explicit None check: max_images=0 means "no images", not "all".
        if max_images is not None:
            image_paths = image_paths[:max_images]

        print(f"Processing {len(image_paths)} images...")
        for img_path in image_paths:
            output_path = output_folder / f"recon_{img_path.name}"
            self.save_reconstruction(str(img_path), str(output_path))
        print("Batch reconstruction complete!")
def main():
    """CLI entry point: parse arguments and run the requested mode."""
    parser = argparse.ArgumentParser(description='DiTo Image Reconstruction')
    # (flag, options) pairs registered in a single pass.
    cli_options = [
        ('--checkpoint', dict(type=str, required=True,
                              help='Path to DiTo checkpoint')),
        ('--input', dict(type=str, required=True,
                         help='Input image path or folder')),
        ('--output', dict(type=str, required=True,
                          help='Output path')),
        ('--compare', dict(action='store_true',
                           help='Save comparison with original')),
        ('--batch', dict(action='store_true',
                         help='Process entire folder')),
        ('--device', dict(type=str, default='cuda',
                          help='Device to use (cuda/cpu)')),
        ('--max_images', dict(type=int, default=None,
                              help='Maximum images to process in batch mode')),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Load the model once, then dispatch.
    dito = DiToInference(args.checkpoint, device=args.device)

    # Batch mode takes precedence over compare; plain reconstruction is
    # the default.
    if args.batch:
        dito.batch_reconstruct(args.input, args.output, args.max_images)
    elif args.compare:
        dito.compare_reconstruction(args.input, args.output)
    else:
        dito.save_reconstruction(args.input, args.output)
| # Example usage function for direct Python use | |
def reconstruct_single_image(checkpoint_path, image_path, output_path):
    """Convenience one-liner for direct Python use: load the checkpoint
    and save the reconstruction of a single image."""
    DiToInference(checkpoint_path).save_reconstruction(image_path, output_path)
# Run the CLI only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
| # Usage examples: | |
| # 1. Single image reconstruction: | |
| # python dito_inference.py --checkpoint ckpt-best.pth --input image.jpg --output recon.jpg | |
| # | |
| # 2. Single image with comparison: | |
| # python dito_inference.py --checkpoint ckpt-best.pth --input image.jpg --output compare.jpg --compare | |
| # | |
| # 3. Batch processing: | |
| # python dito_inference.py --checkpoint ckpt-best.pth --input input_folder/ --output output_folder/ --batch | |
| # | |
| # 4. Direct Python usage: | |
| # reconstruct_single_image('ckpt-best.pth', 'input.jpg', 'output.jpg') |