qbhf2's picture
Update app.py
e53b43c verified
import os
# Pin huggingface-hub at startup, before gradio and the model code imported
# below load it.
# NOTE(review): the pinned version presumably works around an API that newer
# huggingface-hub releases removed but a dependency still uses -- confirm
# before bumping it.
# `python -m pip` targets the interpreter running this app (a bare `pip` on
# PATH may belong to another one). A single `install ==X` replaces any other
# installed version, is a no-op when 0.25.2 is already present, and leaves the
# current version in place if the install fails (a separate uninstall step
# would leave the package missing).
import subprocess
import sys

_hub_pin = subprocess.run(
    [sys.executable, "-m", "pip", "install", "huggingface-hub==0.25.2"],
    check=False,  # best effort: a failed pin should not prevent the app from starting
)
if _hub_pin.returncode != 0:
    print(f"warning: pinning huggingface-hub==0.25.2 failed (pip exit code {_hub_pin.returncode})")
import gradio as gr
from PIL import Image
import numpy as np
import cv2
from zipfile import ZipFile
# Image-processing and model-pipeline imports
import gradio as gr
from PIL import Image
import numpy as np
import os
import shutil
import yaml
from pathlib import Path
from fabric_diffusion import FabricDiffusionPipeline
from delete_bg import remove_background
import capture
from pytorch_lightning import Trainer
import random
import torch
# Directory that receives the downloadable .zip archives produced by the app;
# created at import time so later writes into it cannot fail on a missing folder.
ZIP_FOLDER = "./ZIPS"
Path(ZIP_FOLDER).mkdir(parents=True, exist_ok=True)
def zip_folder(folder_path, output_zip):
    """Archive every file under *folder_path* into the zip file *output_zip*.

    Member names are paths relative to *folder_path*, so the folder itself is
    not a top-level entry in the archive. Only files are stored; empty
    directories do not appear. Directories are walked and files added in
    sorted order, so the same tree always yields the same member order.

    Args:
        folder_path: Directory to archive, walked recursively.
        output_zip: Path of the archive to create; an existing file is
            overwritten. It may live inside *folder_path*; it is then skipped
            rather than written into itself.
    """
    archive_abspath = os.path.abspath(output_zip)
    with ZipFile(output_zip, 'w') as zipf:
        for root, dirs, files in os.walk(folder_path):
            # Sorting in place also fixes the order os.walk descends into subdirectories.
            dirs.sort()
            for file in sorted(files):
                file_path = os.path.join(root, file)
                # The archive file already exists (ZipFile created it above), so
                # when it lives inside folder_path the walk would pick it up.
                if os.path.abspath(file_path) == archive_abspath:
                    continue
                zipf.write(file_path, os.path.relpath(file_path, folder_path))
def set_deterministic(seed=42):
    """Seed every random number generator the pipeline draws from.

    Seeds, in this order, Python's ``random`` module, NumPy's global legacy
    RNG, torch's CPU generator and the generators of all CUDA devices. It does
    not change cuDNN / deterministic-algorithm settings.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def load_config(config_path):
    """Parse the YAML file at *config_path* and return its contents.

    Uses ``yaml.safe_load``, so only plain YAML types are constructed (no
    arbitrary Python objects). An empty file yields ``None``.

    Args:
        config_path: Path (str or path-like) of the YAML file.

    Returns:
        The parsed document, typically a ``dict``.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # would mis-decode non-ASCII text in the config file.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
def run_flatten_texture(pipeline, input_image_path, output_path, n_samples=3):
    """Generate flattened-texture samples from one image and save them as PNGs.

    The image is loaded with ``pipeline.load_patch_data`` and passed to
    ``pipeline.flatten_texture``. Each returned image is written to
    ``<output_path>/<stem>_gen_<i>.png``, where ``<stem>`` is the input file
    name without its extension and ``i`` counts from 0.

    Args:
        pipeline: Object exposing ``load_patch_data`` and ``flatten_texture``
            (a ``FabricDiffusionPipeline`` in this app).
        input_image_path: Path of the source image.
        output_path: Destination directory; created if missing.
        n_samples: Number of samples requested from ``flatten_texture``.
    """
    os.makedirs(output_path, exist_ok=True)
    file_name = os.path.basename(input_image_path)
    texture_name, _extension = os.path.splitext(file_name)

    patch = pipeline.load_patch_data(input_image_path)
    samples = pipeline.flatten_texture(patch, n_samples=n_samples)
    for i, sample in enumerate(samples):
        destination = os.path.join(output_path, f'{texture_name}_gen_{i}.png')
        sample.save(destination)
def organize_images_into_structure(source_folder, new_folder):
    """Copy each image in *source_folder* into its own ``<stem>/outputs/`` folder.

    Only regular files directly inside *source_folder* are considered; files
    whose names end in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are
    copied, and everything else, including subdirectories, is ignored. For
    example ``foo.png`` is copied to ``<new_folder>/foo/outputs/foo.png``.
    Existing destination files are overwritten.

    Args:
        source_folder: Directory scanned (non-recursively) for images.
        new_folder: Root of the per-image folder structure; created if missing.
    """
    os.makedirs(new_folder, exist_ok=True)
    image_suffixes = ('.png', '.jpg', '.jpeg')
    for entry in os.listdir(source_folder):
        src = os.path.join(source_folder, entry)
        if not os.path.isfile(src):
            continue
        if not entry.lower().endswith(image_suffixes):
            continue
        stem = os.path.splitext(entry)[0]
        dest_dir = os.path.join(new_folder, stem, "outputs")
        os.makedirs(dest_dir, exist_ok=True)
        shutil.copy(src, os.path.join(dest_dir, entry))
# Create the directory for saved original/cropped images if it doesn't exist.
# exist_ok=True replaces a separate exists() check, which could race with
# another process creating the directory in between.
os.makedirs("saved_images", exist_ok=True)
def decode_rgba(image: Image.Image) -> Image.Image:
    """Crop *image* to the bounding box of its fully opaque pixels, as RGB.

    The box spans every pixel whose alpha equals 255. Inside the box, RGB
    values are kept as-is (no alpha compositing), and the alpha channel is
    dropped.

    Args:
        image: Image to crop. A PIL image in a mode other than RGBA (e.g. RGB,
            LA, P) is converted to RGBA first. An image without transparency
            therefore counts as fully opaque and is returned uncropped. An
            array-like of shape (H, W, >=4) is still accepted.

    Returns:
        A new RGB ``PIL.Image`` holding the cropped region.

    Raises:
        ValueError: If the data has no alpha channel, or no pixel is fully
            opaque (e.g. everything was erased in the editor), so there is
            nothing to crop to.
    """
    if isinstance(image, Image.Image) and image.mode != "RGBA":
        image = image.convert("RGBA")
    pixels = np.asarray(image)
    if pixels.ndim != 3 or pixels.shape[2] < 4:
        raise ValueError(
            f"decode_rgba expects an image with an alpha channel, got array shape {pixels.shape}"
        )

    opaque = np.argwhere(pixels[:, :, 3] == 255)  # (row, col) of every fully opaque pixel
    if opaque.size == 0:
        raise ValueError("decode_rgba: image has no fully opaque pixels to crop to")
    y_min, x_min = opaque.min(axis=0)
    y_max, x_max = opaque.max(axis=0) + 1  # slice ends are exclusive; include the last row/column

    # The :3 channel slice is a strided view; copy it into a contiguous buffer.
    rgb = np.ascontiguousarray(pixels[y_min:y_max, x_min:x_max, :3])
    # A uint8 (H, W, 3) array is inferred as mode "RGB"; the explicit mode
    # argument of fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(rgb)
def process_image(orig_image, image_data):
    """Save the original upload and its editor crop, returning both with their paths.

    Args:
        orig_image: The original image; it is written with ``.save`` (a
            ``PIL.Image`` in this app).
        image_data: Image-editor value whose ``'composite'`` entry holds the
            edited RGBA image. It is cropped to its fully opaque region by
            ``decode_rgba``.

    Returns:
        Tuple ``(orig_image, original_filename, cropped_image,
        cropped_filename, timestamp)``. The filenames point into
        ``saved_images/`` as ``original_<timestamp>.png`` and
        ``cropped_<timestamp>.png``.
    """
    import datetime

    cropped_image = image_data['composite']

    # Microsecond resolution: with whole-second timestamps, two requests in the
    # same second got identical filenames and the second overwrote the first.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    os.makedirs("saved_images", exist_ok=True)  # tolerate the folder being removed at runtime
    original_filename = os.path.join("saved_images", f"original_{timestamp}.png")
    cropped_filename = os.path.join("saved_images", f"cropped_{timestamp}.png")

    orig_image.save(original_filename)
    decoded_cropped_image = decode_rgba(cropped_image)
    decoded_cropped_image.save(cropped_filename)
    return orig_image, original_filename, decoded_cropped_image, cropped_filename, timestamp
def web_main(orig_image, image_data):
    """Run the PBR pipeline on an edited crop and package the results as a zip.

    Steps:
      1. Save the original image and the crop (``process_image``).
      2. Seed all RNGs.
      3. Flatten the crop's texture with ``FabricDiffusionPipeline`` into
         ``save_fd_dir``.
      4. Lay the generated images out as ``<save_mp_dir>/<name>/outputs/<file>``.
      5. Run the ``capture`` inference module over ``save_mp_dir`` with a
         Lightning ``Trainer``.
      6. Zip the folder of the first generated sample, together with copies of
         the original and cropped inputs.

    Settings come from the ``hyperparameters`` section of ``config_demo.yaml``.

    Args:
        orig_image: Original uploaded image (forwarded to ``process_image``).
        image_data: Image-editor value with a ``'composite'`` entry.

    Returns:
        ``(orig_image, cropped_image, zip_path)``, where ``zip_path`` is the
        archive written inside ``ZIP_FOLDER``.

    Raises:
        FileNotFoundError: If the first sample's folder is missing from
            ``save_mp_dir`` after prediction.
    """
    orig_image, original_filename, cropped_image, cropped_filename, timestamp = process_image(orig_image, image_data)
    set_deterministic(seed=42)

    # Read every setting up front so a missing key fails before the slow model steps.
    hparams = load_config("config_demo.yaml")["hyperparameters"]
    device = hparams["device"]
    texture_checkpoint = hparams["fb_checkpoint"]
    print_checkpoint = hparams.get("print_checkpoint", None)
    save_fd_dir = hparams["save_fd_dir"]
    save_mp_dir = hparams["save_mp_dir"]
    n_samples = hparams["n_samples"]
    checkpoint_name = hparams["checkpoint_name"]

    # NOTE(review): the pipeline is constructed on every request (presumably
    # reloading model weights each time). Consider caching it at module level.
    pipeline = FabricDiffusionPipeline(device, texture_checkpoint, print_checkpoint=print_checkpoint)
    os.makedirs(save_fd_dir, exist_ok=True)
    run_flatten_texture(pipeline, cropped_filename, output_path=save_fd_dir, n_samples=n_samples)
    # NOTE(review): save_fd_dir / save_mp_dir are shared across requests and
    # never cleared. Every previously generated image is re-copied here, and
    # prediction presumably re-runs over all of them on each call.
    organize_images_into_structure(save_fd_dir, save_mp_dir)

    data = capture.get_data(predict_dir=Path(save_mp_dir), predict_ds='sd')
    module = capture.get_inference_module(pt=checkpoint_name)
    # NOTE(review): accelerator is hard-coded to 'gpu' regardless of the configured device.
    decomp = Trainer(default_root_dir=Path(save_mp_dir), accelerator='gpu', devices=1, precision=16)
    decomp.predict(module, data)

    # run_flatten_texture names samples "<stem>_gen_<i>.png", the crop's stem is
    # "cropped_<timestamp>", and organize_images_into_structure gives each sample
    # its own folder -- so this is the folder of the first sample only.
    folder_path = os.path.join(save_mp_dir, f"cropped_{timestamp}_gen_0")
    if not os.path.isdir(folder_path):
        raise FileNotFoundError(f"Expected output folder not found: {folder_path}")

    # Move files from this nested directory up into the sample folder.
    # NOTE(review): presumably where the capture prediction writes its outputs;
    # the path components look tied to a specific checkpoint -- confirm.
    target_path = os.path.join(folder_path, "weights", "mask", "an_object_with_azertyuiop_texture",
                               "checkpoint-800", "outputs")
    if os.path.exists(target_path):
        for file_name in os.listdir(target_path):
            file_path = os.path.join(target_path, file_name)
            if os.path.isfile(file_path):
                shutil.move(file_path, folder_path)
    print(f"FOLDER: {folder_path}")

    # Include the user's inputs in the archive next to the generated images.
    outputs_dir = os.path.join(folder_path, "outputs")
    for source in (original_filename, cropped_filename):
        shutil.copyfile(source, os.path.join(outputs_dir, os.path.basename(source)))

    zip_path = os.path.join(ZIP_FOLDER, f"{os.path.basename(folder_path)}.zip")
    zip_folder(folder_path, zip_path)
    return orig_image, cropped_image, zip_path
def remove_background_main(image):
    """Remove the background of an uploaded print image.

    Steps:
      1. Save the upload as ``saved_images/original_<timestamp>.png``.
      2. Run ``remove_background`` on that file.
      3. Save the result as ``saved_images/nobg_<timestamp>.png``.
      4. Copy both files into ``output_images/nobg_<timestamp>/outputs/``.

    Args:
        image: Uploaded image (a ``PIL.Image``; the UI component uses
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        The background-removed image as a ``PIL.Image``. It is identical to the
        saved ``nobg_*.png`` file.

    Raises:
        gr.Error: If no image was uploaded.
    """
    import datetime

    if image is None:
        # Gradio shows this message to the user instead of an AttributeError traceback.
        raise gr.Error("Please upload an image first.")

    # Microsecond resolution so two requests in the same second don't overwrite each other's files.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    os.makedirs("saved_images", exist_ok=True)
    original_filename = os.path.join("saved_images", f"original_{timestamp}.png")
    image.save(original_filename)

    result_img = remove_background(original_filename)
    # The returned image swaps the R and B channels of remove_background's
    # output. The saved file must come from that same array. cv2.imwrite expects
    # B,G,R,A order, so writing this array with cv2 would store R and B
    # exchanged relative to what the user sees. Saving via PIL keeps the file
    # and the displayed image identical.
    # NOTE(review): assumes the displayed colors are the correct ones, i.e.
    # remove_background returns BGRA (OpenCV order) -- confirm against delete_bg.
    result_pil = Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_RGBA2BGRA))
    output_path = os.path.join("saved_images", f"nobg_{timestamp}.png")
    result_pil.save(output_path)

    folder_path = os.path.join("output_images", f"nobg_{timestamp}")
    outputs_dir = os.path.join(folder_path, "outputs")
    os.makedirs(outputs_dir, exist_ok=True)
    for source in (original_filename, output_path):
        shutil.copyfile(source, os.path.join(outputs_dir, os.path.basename(source)))
    return result_pil
# Gradio UI. Only the "Print extraction" tab is active. The "PBR generator"
# tab and its wiring to web_main are commented out below.
with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        # PBR generator
        # NOTE(review): this tab is disabled, so web_main is not reachable from
        # the UI while it stays commented out.
        # with gr.TabItem("PBR generator") as pbr_tab:
        #     with gr.Row():
        #         orig_image = gr.Image(label="Orig image", type="pil")
        #         image_editor = gr.ImageEditor(type="pil", crop_size="1:1", label="Edit Image: Crop the desired element")
        #     with gr.Row():
        #         output_image = gr.Image(label="Result")
        #         crop_image = gr.Image(label="Crop_image")
        #         zip_file = gr.File(label="Download Zip File")
        #     process_button = gr.Button("Crop Element")
        # Print extraction: upload an image, press the button, see the background-free result.
        with gr.TabItem("Print extraction") as print_tab:
            with gr.Row():
                # type="pil" makes Gradio pass a PIL image to the handler.
                orig_image_print = gr.Image(label="Orig image (cropped print)", type="pil")
            with gr.Row():
                output_image_print = gr.Image(label="Result")
            process_print_button = gr.Button("Extract print")
    # process_button.click(
    #     web_main,
    #     inputs=[orig_image, image_editor],
    #     outputs=[output_image, crop_image, zip_file]
    # )
    # Clicking the button sends the uploaded image to remove_background_main and
    # shows its return value in the "Result" component.
    process_print_button.click(
        remove_background_main,
        inputs=[orig_image_print],
        outputs=[output_image_print]
    )

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()