Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import sys | |
| sys.path.append(os.path.join(os.path.dirname(__file__))) | |
| import blended_tiling | |
| import numpy | |
| import onnxruntime | |
| import streamlit.file_util | |
| import torch | |
| import torch.cuda | |
| from PIL import Image | |
| from streamlit.runtime.uploaded_file_manager import UploadedFile | |
| from streamlit_image_comparison import image_comparison | |
| from torchvision.transforms import functional as TVTF | |
| from tools import image_tools | |
# * Model handle, filled in lazily by bootstrap_model().
# This assignment runs every time the script executes (Streamlit executes the script on each rerun),
# so the handle starts out as None at the top of every run.
onnx_session: "onnxruntime.InferenceSession | None" = None
# * Streamlit UI / Config
streamlit.set_page_config(layout="wide", page_title="🐲 PXDN Line Extractor v1")
streamlit.title("🐲 PXDN Line Extractor v1")

# * Streamlit Containers / Base Layout
# Row 1: status text and progress bar.
ui_section_status = streamlit.container()
# Row 2: input column on the left, output column on the right.
ui_col1, ui_col2 = streamlit.columns(spec=2, gap="medium")
streamlit.html("<hr>")
# Row 3: input/output comparison.
ui_section_compare = streamlit.container()

# * Streamlit Session
# Nothing yet

# Status row: placeholders that inference writes into while it runs.
ui_status_text = ui_section_status.empty()
ui_progress_bar = ui_section_status.empty()

# Input column: heading and the image uploader.
ui_col1.markdown("### Input Image")
ui_image_input = ui_col1.file_uploader("Upload an image", key="fileupload_image", type=[".png", ".jpg", ".jpeg", ".webp"])  # type: UploadedFile

# Output column: heading, then reserved slots for the result image and its download button.
ui_col2.markdown("### Output Image")
ui_image_output = ui_col2.empty()
ui_image_download = ui_col2.empty()
def fetch_model_to_cache(huggingface_repo: str, file_path: str, access_token: str) -> str:
    """Download ``file_path`` from the Hugging Face repo ``huggingface_repo`` into the local HF cache.

    Args:
        huggingface_repo: Repository id on the Hugging Face Hub.
        file_path: Path of the file inside that repository.
        access_token: Hub token. An empty string is treated as "no explicit token".

    Returns:
        Local filesystem path of the cached file, as returned by ``hf_hub_download``.
    """
    # Local import keeps huggingface_hub out of module import until a model is actually fetched.
    import huggingface_hub

    # huggingface_hub sends any str token verbatim, so "" would go out as a blank credential.
    # None makes it fall back to its own lookup (env var / saved login) or anonymous access.
    return huggingface_hub.hf_hub_download(huggingface_repo, file_path, token=access_token or None)
@streamlit.cache_resource(show_spinner="Loading model...")
def _load_onnx_session(huggingface_repo: str, file_path: str, access_token: str, allow_cuda: bool):
    """Fetch the model file and build an ONNX Runtime session, memoized across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction. A plain module global is
    therefore reset each time, and the session would be rebuilt on every rerun.
    ``st.cache_resource`` keeps one session per distinct argument set instead.
    NOTE(review): the cached session is shared by all browser sessions of this app.
    """
    model_file_path = fetch_model_to_cache(huggingface_repo, file_path, access_token)

    # * Enable CUDA only if allowed, a GPU is visible, AND this onnxruntime build ships the CUDA provider.
    model_providers = ['CPUExecutionProvider']
    cuda_usable = (
        allow_cuda
        and torch.cuda.is_available()
        and 'CUDAExecutionProvider' in onnxruntime.get_available_providers()
    )
    if cuda_usable:
        model_providers.insert(0, 'CUDAExecutionProvider')
    return onnxruntime.InferenceSession(model_file_path, sess_options=None, providers=model_providers)


def bootstrap_model():
    """Populate the module-level ``onnx_session`` for this script run if it is not already set.

    Configuration comes from environment variables:
    HF_REPO_NAME / HF_FILE_PATH (required), HF_TOKEN (optional), and ALLOW_CUDA
    (one of true/yes/1/y, case-insensitive, to allow the CUDA provider).

    Raises:
        RuntimeError: if HF_REPO_NAME or HF_FILE_PATH is unset or empty.
    """
    global onnx_session
    if onnx_session is not None:
        return

    # Environment-level configuration
    huggingface_repo = os.getenv("HF_REPO_NAME", "")
    file_path = os.getenv("HF_FILE_PATH", "")
    if not huggingface_repo or not file_path:
        raise RuntimeError("HF_REPO_NAME and HF_FILE_PATH must be set to locate the ONNX model.")
    access_token = os.getenv("HF_TOKEN", "")
    allow_cuda = os.getenv("ALLOW_CUDA", "false").lower() in {'true', 'yes', '1', 'y'}

    onnx_session = _load_onnx_session(huggingface_repo, file_path, access_token, allow_cuda)
def evaluate_tiled(image_pt: torch.Tensor, tile_size: int = 128, batch_size: int = 1) -> Image.Image:
    """Run the global ``onnx_session`` over ``image_pt`` tile by tile and reassemble the result.

    Steps:
    1. Pad with ``image_tools.pad_to_divisible``.
    2. Split into ``tile_size`` tiles with ``blended_tiling.TilingModule`` (tile_overlap 0.18).
    3. Feed the tiles through the model ``batch_size`` at a time, updating the status and
       progress placeholders.
    4. Blend the outputs back together with ``rebuild_with_masks``.
    5. Crop away the padding, clamp to [0, 1] and convert to a PIL image.

    Args:
        image_pt: 3-dim image tensor. ``shape[1]`` / ``shape[2]`` are read as height / width,
            so presumably (C, H, W) — confirm against ``image_tools.prepare_image_for_inference``.
        tile_size: Edge length of the square tiles fed to the model.
        batch_size: Tiles per ``onnx_session.run`` call; values below 1 are treated as 1.
            If the tile count is not a multiple of ``batch_size``, the last batch is
            zero-padded to full size (fixed-batch models reject short batches). The outputs
            for the padding tiles are discarded.

    Returns:
        The processed image, cropped back to the input's height and width.
    """
    batch_size = max(1, batch_size)
    orig_h, orig_w = image_pt.shape[1], image_pt.shape[2]

    # ? Padding
    image_pt_padded, place_x, place_y = image_tools.pad_to_divisible(image_pt, tile_size)
    _, im_h_padded, im_w_padded = image_pt_padded.shape

    # ? Tiling
    image_tiler = blended_tiling.TilingModule(
        tile_size=tile_size, tile_overlap=[0.18, 0.18], base_size=(im_w_padded, im_h_padded)
    ).eval()
    # * Add batch dim for the tiler which expects (1, C, H, W)
    image_tiles = image_tiler.split_into_tiles(image_pt_padded.unsqueeze(0)).numpy()
    num_tiles = image_tiles.shape[0]

    # ? Pull the input and output names from the model so we're not hardcoding them.
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name

    # ? Inference ==================================================================================================
    # Ceiling division: flooring here would silently skip the last (num_tiles % batch_size)
    # tiles and leave rebuild_with_masks with too few tiles.
    max_evals = -(-num_tiles // batch_size)
    complete_tiles = []
    ui_status_text.markdown("### Processing...")
    active_progress = ui_progress_bar.progress(0, "Progress")
    try:
        for i in range(max_evals):
            tile_batch = image_tiles[i * batch_size:(i + 1) * batch_size]
            real_count = len(tile_batch)
            if real_count < batch_size:
                filler = numpy.zeros((batch_size - real_count, *tile_batch.shape[1:]), dtype=tile_batch.dtype)
                tile_batch = numpy.concatenate([tile_batch, filler], axis=0)
            active_progress.progress(round((i + 1) / max_evals, 2))
            output_batch = onnx_session.run([output_name], {input_name: tile_batch})[0]
            # Keep only outputs that correspond to real tiles, not the zero filler.
            complete_tiles.extend(output_batch[:real_count])
    finally:
        # Clear the status UI even if inference raises, so "Processing..." does not linger.
        ui_status_text.empty()
        ui_progress_bar.empty()
    # ? /Inference

    # ? Rehydrate the tiles into a full image.
    complete_tiles_tensor = torch.from_numpy(numpy.stack(complete_tiles))
    complete_image = image_tiler.rebuild_with_masks(complete_tiles_tensor)

    # ? Unpad with an unconditional crop. It is a no-op when nothing was padded, and it still
    # trims padding placed only after the image (place_x == place_y == 0), which a guard
    # on the offsets would skip.
    complete_image = complete_image[:, :, place_y:place_y + orig_h, place_x:place_x + orig_w]

    # ? Clamp and convert to PIL.
    complete_image = complete_image.squeeze(0).clamp(0, 1.0)
    return TVTF.to_pil_image(complete_image)
def streamlit_to_pil_image(streamlit_file: UploadedFile) -> Image.Image:
    """Open an uploaded file as a PIL image.

    ``getvalue()`` returns the whole upload regardless of the stream position. ``read()``
    only returns the bytes after the current position, so it yields nothing if the
    stream was already consumed.

    Raises:
        PIL.UnidentifiedImageError (a subclass of OSError): the bytes are not a recognised image format.
    """
    return Image.open(io.BytesIO(streamlit_file.getvalue()))
def pil_to_buffered_png(image: Image.Image, compress_level: int = 3) -> io.BytesIO:
    """Encode ``image`` as PNG into an in-memory buffer rewound to position 0.

    Args:
        image: Image to encode.
        compress_level: Pillow's PNG zlib compression level (0 = none, 9 = smallest/slowest).

    Returns:
        A BytesIO holding the PNG bytes, ready to be read or streamed from the start.
    """
    buffer = io.BytesIO()
    # Pillow's PNG writer reads `compress_level`; a `compression` keyword is a TIFF option
    # and is silently ignored for PNG.
    image.save(buffer, format="PNG", compress_level=compress_level)
    buffer.seek(0)
    return buffer
# ! Image Inference
def _fixed_dim(dim, fallback: int) -> int:
    """Return ``dim`` if it is a concrete positive int, otherwise ``fallback``.

    ONNX Runtime reports symbolic/dynamic input dimensions as strings or ``None`` rather
    than ints. Those cannot be used directly as a batch or tile size.
    """
    return dim if isinstance(dim, int) and dim > 0 else fallback


if ui_image_input is not None and ui_image_input.name is not None:
    ui_status_text.empty()
    ui_progress_bar.empty()

    # Decode first so an unreadable upload is reported without loading the model.
    try:
        input_image = streamlit_to_pil_image(ui_image_input)
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        ui_status_text.error(f"Could not read the uploaded image: {err}")
        streamlit.stop()

    bootstrap_model()

    # The model's first input is unpacked as (B, C, H, W); batch and tile size come from it.
    b, _, h, _ = onnx_session.get_inputs()[0].shape
    target_batch_size = _fixed_dim(b, 1)
    # Tiles are square, so H doubles as the tile edge (an export with H != W is unsupported).
    # A dynamic H falls back to evaluate_tiled's default tile size of 128.
    target_tile_size = _fixed_dim(h, 128)

    loaded_image_pt = image_tools.prepare_image_for_inference(input_image)
    finished_image = evaluate_tiled(loaded_image_pt, tile_size=target_tile_size, batch_size=target_batch_size)

    # Fill the output-column placeholders reserved in the layout section.
    ui_image_output.image(finished_image, use_container_width=True, caption="Output Image")
    complete_file_name = f"{ui_image_input.name.rsplit('.', 1)[0]}_output.png"
    ui_image_download.download_button("Download Image", pil_to_buffered_png(finished_image), complete_file_name, type="primary")

    with ui_section_compare:
        image_comparison(img1=input_image, img2=finished_image, make_responsive=True, label1="Input Image", label2="Output Image", width=1024)