TESTT1 / app.py
tejani's picture
Update app.py
009235c verified
import gradio as gr
import os
import sys
import cv2
import numpy as np
import torch
from generate import process, load_model, NORMAL_MAP_MODEL, OTHER_MAP_MODEL
# Make the sibling ``utils`` directory importable by appending it to sys.path.
sys.path.append(os.path.join(os.path.dirname(__file__), 'utils'))
# All inference in this app is pinned to the CPU.
device = torch.device('cpu')
# Load both networks once at start-up. Position matters to generate_maps:
# index 0 is the normal-map model, index 1 is the "other maps" model.
models = [
    load_model(checkpoint, device)
    for checkpoint in (NORMAL_MAP_MODEL, OTHER_MAP_MODEL)
]
normal_map_model, other_map_model = models
# Temporary input and output directories (relative to the working directory).
input_dir = "temp_input"
output_dir = "temp_output"
# exist_ok=True replaces the separate os.path.exists() check, so there is no
# window in which another process can create the directory between the check
# and the makedirs call.
os.makedirs(input_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
def generate_maps(input_image, tile_size=512, seamless=False, mirror=False, replicate=False, ishiiruka=False, ishiiruka_texture_encoder=False):
    """Generate normal, roughness and displacement maps from a diffuse texture.

    Args:
        input_image: Image array in RGB channel order (it is converted with
            ``cv2.COLOR_RGB2BGR`` before processing).
        tile_size: Images taller or wider than this many pixels are processed
            through ``esrgan_launcher_split_merge`` instead of in one pass.
        seamless: Pad the image by 16 px with ``cv2.BORDER_WRAP`` before
            inference and crop the results afterwards.
        mirror: Same as ``seamless`` but with ``cv2.BORDER_REFLECT_101``;
            only used when ``seamless`` is off.
        replicate: Same as ``seamless`` but with ``cv2.BORDER_REPLICATE``;
            only used when ``seamless`` and ``mirror`` are off.
        ishiiruka: Invert the roughness map (``255 - roughness``).
        ishiiruka_texture_encoder: Pack all maps into a single four-channel
            ``output.mat.png`` instead of writing three separate files.

    Returns:
        A three-item list ``[normal, roughness, displacement]`` of RGB arrays,
        one per output component of the interface. In texture-encoder mode
        the first item is an RGB preview of the packed texture and the other
        two are ``None``.

    Raises:
        ValueError: If ``input_image`` is ``None`` (nothing was uploaded).
    """
    if input_image is None:
        raise ValueError("No input image provided; upload a diffuse texture first.")

    # Normalise to int in case the UI delivers the slider value as a float.
    tile_size = int(tile_size)

    # Convert to OpenCV's BGR order and work on the in-memory array. The
    # upload is still written to the input directory as before, but it is no
    # longer re-read from that shared, fixed file name: a failed write or a
    # concurrent request could otherwise hand back None or another image.
    img = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(input_dir, "input.png"), img)

    # Pick at most one padding mode; precedence is seamless > mirror > replicate.
    # Compare against None explicitly: border constants are plain ints.
    if seamless:
        border = cv2.BORDER_WRAP
    elif mirror:
        border = cv2.BORDER_REFLECT_101
    elif replicate:
        border = cv2.BORDER_REPLICATE
    else:
        border = None
    if border is not None:
        img = cv2.copyMakeBorder(img, 16, 16, 16, 16, border)

    # Imported here (not at module top) to keep the original import timing.
    from utils.imgops import esrgan_launcher_split_merge, crop_seamless

    img_height, img_width = img.shape[:2]
    if img_height > tile_size or img_width > tile_size:
        rlts = esrgan_launcher_split_merge(img, lambda x, m: process(x, m, device), models, scale_factor=1, tile_size=tile_size)
    else:
        rlts = [process(img, model, device) for model in models]
    if border is not None:
        # Remove the padding that was added before inference.
        rlts = [crop_seamless(rlt) for rlt in rlts]

    # models[0] yields the normal map; models[1] carries displacement in
    # channel 0 and roughness in channel 1.
    normal_map = rlts[0]
    roughness = rlts[1][:, :, 1]
    displacement = rlts[1][:, :, 0]

    base = "output"
    if ishiiruka_texture_encoder:
        # Packed layout (B, G, R, A) = (displacement, normal channel 1,
        # inverted roughness, normal channel 2).
        r = 255 - roughness
        g = normal_map[:, :, 1]
        b = displacement
        a = normal_map[:, :, 2]
        packed = cv2.merge((b, g, r, a))
        cv2.imwrite(os.path.join(output_dir, f'{base}.mat.png'), packed)
        preview = cv2.cvtColor(packed, cv2.COLOR_BGRA2RGB)
        # The interface declares three outputs, so always return three values;
        # None leaves the roughness/displacement slots empty.
        return [preview, None, None]

    cv2.imwrite(os.path.join(output_dir, f'{base}_Normal.png'), normal_map)
    rough_img = 255 - roughness if ishiiruka else roughness
    cv2.imwrite(os.path.join(output_dir, f'{base}_Roughness.png'), rough_img)
    cv2.imwrite(os.path.join(output_dir, f'{base}_Displacement.png'), displacement)

    return [
        cv2.cvtColor(normal_map, cv2.COLOR_BGR2RGB),
        cv2.cvtColor(rough_img, cv2.COLOR_GRAY2RGB),
        cv2.cvtColor(displacement, cv2.COLOR_GRAY2RGB),
    ]
# Gradio interface
# Input widgets, listed in the positional order of generate_maps' parameters.
_input_components = [
    gr.Image(type="numpy", label="Upload Diffuse Texture"),
    gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="Tile Size"),
    gr.Checkbox(label="Seamless"),
    gr.Checkbox(label="Mirror"),
    gr.Checkbox(label="Replicate"),
    gr.Checkbox(label="Ishiiruka Format"),
    gr.Checkbox(label="Ishiiruka Texture Encoder"),
]
# One image slot per generated map.
_output_components = [
    gr.Image(label=caption)
    for caption in ("Normal Map", "Roughness Map", "Displacement Map")
]
interface = gr.Interface(
    fn=generate_maps,
    inputs=_input_components,
    outputs=_output_components,
    title="Material Map Generator",
    description=(
        "Upload a diffuse texture to generate Normal, Roughness, and Displacement maps using AI. "
        "Adjust settings as needed. "
        "Note: Running on CPU due to lack of GPU support in this environment, "
        "which may result in slower processing."
    ),
)
# Launch the interface
interface.launch()