# TESTT1 / generate.py
# History note: "Update generate.py" by tejani (commit 9b812d3, verified)
import os
import argparse
import cv2
import numpy as np
import torch
from utils.imgops import esrgan_launcher_split_merge, crop_seamless
# Define model paths.
# Pre-trained 1x (no upscaling) CX-Lite generator checkpoints; presumably ESRGAN-style
# RRDB weights — see load_model(), which builds RRDB(nb=15, gc=32) for them.
NORMAL_MAP_MODEL = 'utils/models/1x_NormalMapGenerator-CX-Lite_200000_G.pth'  # outputs the normal map
OTHER_MAP_MODEL = 'utils/models/1x_FrankenMapGenerator-CX-Lite_215000_G.pth'  # outputs roughness/displacement channels
def process(img, model, device=None):
    """
    Run a single image through the model and return the generated map.

    Args:
        img: Input image as a BGR numpy array (as produced by cv2.imread).
        model: PyTorch model applied to the image.
        device: Torch device for inference; defaults to CPU when None.

    Returns:
        Model output as an HWC numpy float array with values rounded into 0-255.
    """
    target = torch.device('cpu') if device is None else device
    # cv2 delivers BGR; the network expects RGB scaled to [0, 1].
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) * 1.0 / 255
    # HWC -> CHW, add a batch dimension, and move to the inference device.
    batch = torch.from_numpy(np.transpose(rgb, (2, 0, 1))).float().unsqueeze(0)
    batch = batch.to(target)
    with torch.no_grad():
        result = model(batch).data.squeeze().float().cpu().clamp_(0, 1).numpy()
    # Back to HWC layout and 0-255 scale for saving with cv2.imwrite.
    return (np.transpose(result, (1, 2, 0)) * 255.0).round()
def load_model(model_path, device=None):
    """
    Load a pre-trained RRDB generator from a checkpoint file.

    Args:
        model_path: Path to the model weights file (.pth).
        device: Torch device to load onto (defaults to CPU if None).

    Returns:
        The loaded PyTorch model in eval mode, moved to `device`.
    """
    if device is None:
        device = torch.device('cpu')  # Default to CPU if no device provided
    from utils.architecture.architecture import RRDB  # Corrected import
    model = RRDB(nb=15, gc=32)
    # map_location is required so checkpoints saved from a CUDA process still
    # load on CPU-only machines (torch.load would otherwise try to restore
    # the tensors onto their original CUDA device and crash).
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    return model.to(device)
def main():
    """CLI entry point: parse arguments, load the two generators, and process input images."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Generate material maps from diffuse textures.")
    parser.add_argument('--input', type=str, default='input', help='Input directory or image path')
    parser.add_argument('--output', type=str, default='output', help='Output directory')
    parser.add_argument('--tile_size', type=int, default=512, help='Tile size for processing large images')
    parser.add_argument('--seamless', action='store_true', help='Enable seamless tiling with wrap')
    parser.add_argument('--mirror', action='store_true', help='Enable seamless tiling with mirror')
    parser.add_argument('--replicate', action='store_true', help='Enable seamless tiling with replicate')
    parser.add_argument('--ishiiruka', action='store_true', help='Output in Ishiiruka format')
    parser.add_argument('--ishiiruka_texture_encoder', action='store_true', help='Output in Ishiiruka texture encoder format')
    parser.add_argument('--cpu', action='store_true', help='Force CPU usage')
    args = parser.parse_args()
    # Fall back to CPU when CUDA is unavailable instead of crashing later in
    # model.to(device) on CUDA-less machines.
    use_cuda = not args.cpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    # Load models: index 0 -> normal map, index 1 -> roughness/displacement.
    models = [
        load_model(NORMAL_MAP_MODEL, device),
        load_model(OTHER_MAP_MODEL, device),
    ]
    # Ensure output directory exists (exist_ok avoids a race with concurrent runs).
    os.makedirs(args.output, exist_ok=True)
    # Process every supported image in a directory, or the single given file.
    if os.path.isdir(args.input):
        for filename in os.listdir(args.input):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tga')):
                img_path = os.path.join(args.input, filename)
                base = os.path.splitext(filename)[0]
                process_image(img_path, base, args, models, device)
    else:
        base = os.path.splitext(os.path.basename(args.input))[0]
        process_image(args.input, base, args, models, device)
def process_image(img_path, base, args, models, device):
    """Process a single image and write its material maps into args.output.

    Args:
        img_path: Path of the image to read.
        base: Output filename stem (input name without extension).
        args: Parsed CLI namespace (tile_size, seamless/mirror/replicate,
              ishiiruka flags, output directory).
        models: [normal-map model, roughness/displacement model].
        device: Torch device used for inference.

    Raises:
        ValueError: If the image cannot be read.
    """
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail loudly instead of crashing later on img.shape.
        raise ValueError(f'Could not read image: {img_path}')
    # Optional 16px border padding so tiled textures stay seamless at the edges.
    if args.seamless:
        img = cv2.copyMakeBorder(img, 16, 16, 16, 16, cv2.BORDER_WRAP)
    elif args.mirror:
        img = cv2.copyMakeBorder(img, 16, 16, 16, 16, cv2.BORDER_REFLECT_101)
    elif args.replicate:
        img = cv2.copyMakeBorder(img, 16, 16, 16, 16, cv2.BORDER_REPLICATE)
    img_height, img_width = img.shape[:2]
    # Split oversized images into tiles to bound memory use.
    do_split = img_height > args.tile_size or img_width > args.tile_size
    if do_split:
        rlts = esrgan_launcher_split_merge(img, lambda x, m: process(x, m, device), models, scale_factor=1, tile_size=args.tile_size)
    else:
        rlts = [process(img, model, device) for model in models]
    if args.seamless or args.mirror or args.replicate:
        # Remove the padding added above.
        rlts = [crop_seamless(rlt) for rlt in rlts]
    normal_map = rlts[0]
    # Second model packs maps into channels: 1 = roughness, 0 = displacement
    # (BGR channel order from the cv2 pipeline — presumably green/blue; verify
    # against the FrankenMapGenerator model's documented channel layout).
    roughness = rlts[1][:, :, 1]
    displacement = rlts[1][:, :, 0]
    # Save outputs
    if args.ishiiruka_texture_encoder:
        # Pack everything into one BGRA "material" texture for the encoder.
        r = 255 - roughness  # encoder expects inverted roughness (glossiness)
        g = normal_map[:, :, 1]
        b = displacement
        a = normal_map[:, :, 2]
        output = cv2.merge((b, g, r, a))
        cv2.imwrite(os.path.join(args.output, f'{base}.mat.png'), output)
    else:
        normal_name = f'{base}_Normal.png'
        cv2.imwrite(os.path.join(args.output, normal_name), normal_map)
        rough_name = f'{base}_Roughness.png'
        # Ishiiruka format stores inverted roughness.
        rough_img = 255 - roughness if args.ishiiruka else roughness
        cv2.imwrite(os.path.join(args.output, rough_name), rough_img)
        displ_name = f'{base}_Displacement.png'
        cv2.imwrite(os.path.join(args.output, displ_name), displacement)
# Run the CLI only when executed directly (not on import).
if __name__ == "__main__":
    main()