# learnable-speech/flowae/image_dito_inference.py
import argparse
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms
from omegaconf import OmegaConf

# You'll need to have the DiTo codebase available on the import path
import models


class DiToInference:
    def __init__(self, checkpoint_path, device='cuda'):
        """Initialize DiTo model from checkpoint"""
        self.device = device

        # Load checkpoint on CPU first, then move the model to the target device
        print(f"Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location='cpu')

        # Extract the training config stored in the checkpoint
        self.config = OmegaConf.create(ckpt['config'])

        # Create the model from the config and load its weights
        self.model = models.make(self.config['model'])
        self.model.load_state_dict(ckpt['model']['sd'])

        # Move to device and set to eval mode
        self.model = self.model.to(device)
        self.model.eval()

        # Image preprocessing: resize/crop to 256x256, normalize to [-1, 1]
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        print("Model loaded successfully!")

    def reconstruct_image(self, image_path, debug=True):
        """Reconstruct a single image"""
        # Load and preprocess image
        image = Image.open(image_path).convert('RGB')
        if debug:
            # Save the resized/cropped input so preprocessing can be inspected
            debug_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(256),
            ])
            debug_image = debug_transform(image)
            debug_image.save('debug_1_resized_cropped.png')
            print("Saved debug_1_resized_cropped.png")

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Step 1: Encode to latent
            z = self.model.encode(image_tensor)

            # Step 2: Decode to features (in DiTo this is an identity mapping)
            z_dec = self.model.decode(z)
            print('z_dec.shape:', z_dec.shape)

            # Step 3: Prepare coordinate grids
            # Based on the training code, coord and scale are dummy values
            b, c, h, w = image_tensor.shape
            coord = torch.zeros(b, 2, h, w, device=self.device)
            scale = torch.zeros(b, 2, h, w, device=self.device)

            # Step 4: Render the output image with the diffusion decoder
            reconstructed = self.model.render(z_dec, coord, scale)

            # Denormalize from [-1, 1] to [0, 1]
            reconstructed = (reconstructed * 0.5 + 0.5).clamp(0, 1)

        return reconstructed
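
    def encode_image(self, image_path):
        """Return the latent code for an image without rendering it.

        Illustrative helper, not part of the original script: it reuses the
        preprocessing and ``model.encode`` call from ``reconstruct_image`` so
        the latent can be inspected directly.
        """
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            z = self.model.encode(image_tensor)
        return z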

    def save_reconstruction(self, image_path, output_path):
        """Reconstruct and save image"""
        reconstructed = self.reconstruct_image(image_path)

        # Convert to PIL and save
        to_pil = transforms.ToPILImage()
        reconstructed_pil = to_pil(reconstructed.squeeze(0).cpu())
        reconstructed_pil.save(output_path)
        print(f"Saved reconstruction to {output_path}")

    def compare_reconstruction(self, image_path, output_path):
        """Save original and reconstruction side by side"""
        # Get reconstruction
        reconstructed = self.reconstruct_image(image_path)

        # Load the original at the same resolution, without normalization
        original = Image.open(image_path).convert('RGB')
        original = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor()
        ])(original).unsqueeze(0)

        # Concatenate side by side along the width dimension
        comparison = torch.cat([original, reconstructed.cpu()], dim=3)

        # Save
        to_pil = transforms.ToPILImage()
        comparison_pil = to_pil(comparison.squeeze(0))
        comparison_pil.save(output_path)
        print(f"Saved comparison to {output_path}")

    def batch_reconstruct(self, image_folder, output_folder, max_images=None):
        """Reconstruct all images in a folder"""
        image_folder = Path(image_folder)
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)

        # Collect images (sorted for a deterministic processing order)
        image_paths = sorted(
            list(image_folder.glob('*.png')) +
            list(image_folder.glob('*.jpg')) +
            list(image_folder.glob('*.jpeg'))
        )
        if max_images is not None:
            image_paths = image_paths[:max_images]

        print(f"Processing {len(image_paths)} images...")
        for img_path in image_paths:
            output_path = output_folder / f"recon_{img_path.name}"
            self.save_reconstruction(str(img_path), str(output_path))
        print("Batch reconstruction complete!")


def main():
    parser = argparse.ArgumentParser(description='DiTo Image Reconstruction')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to DiTo checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input image path or folder')
    parser.add_argument('--output', type=str, required=True,
                        help='Output path')
    parser.add_argument('--compare', action='store_true',
                        help='Save comparison with original')
    parser.add_argument('--batch', action='store_true',
                        help='Process entire folder')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--max_images', type=int, default=None,
                        help='Maximum images to process in batch mode')
    args = parser.parse_args()

    # Initialize model
    dito = DiToInference(args.checkpoint, device=args.device)

    # Process based on mode
    if args.batch:
        dito.batch_reconstruct(args.input, args.output, args.max_images)
    elif args.compare:
        dito.compare_reconstruction(args.input, args.output)
    else:
        dito.save_reconstruction(args.input, args.output)


# Example usage function for direct Python use
def reconstruct_single_image(checkpoint_path, image_path, output_path):
    """Simple function to reconstruct a single image"""
    dito = DiToInference(checkpoint_path)
    dito.save_reconstruction(image_path, output_path)


if __name__ == "__main__":
    main()

# Usage examples:
# 1. Single image reconstruction:
#    python image_dito_inference.py --checkpoint ckpt-best.pth --input image.jpg --output recon.jpg
#
# 2. Single image with comparison:
#    python image_dito_inference.py --checkpoint ckpt-best.pth --input image.jpg --output compare.jpg --compare
#
# 3. Batch processing:
#    python image_dito_inference.py --checkpoint ckpt-best.pth --input input_folder/ --output output_folder/ --batch
#
# 4. Direct Python usage:
#    reconstruct_single_image('ckpt-best.pth', 'input.jpg', 'output.jpg')
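#
# 5. Quality check with the illustrative reconstruction_psnr helper:
#    print(reconstruction_psnr(DiToInference('ckpt-best.pth'), 'input.jpg'))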