Spaces:
Running
on
T4
Running
on
T4
| import os | |
| import cv2 | |
| import numpy as np | |
| import onnxruntime as ort | |
| import gradio as gr | |
| from PIL import Image | |
# Path to the ONNX model file inside the Hugging Face Space.
# This is a relative path, so it is resolved against the process's current working directory.
MODEL_PATH = "pretrained/4xGRL.onnx" # Adjust this if the model is stored in a different location
# Preprocessing function for images (similar to original script)
def preprocess_image(img, target_height=180, target_width=320, crop_for_4x=True, downsample_threshold=720):
    """Convert an input image into the fixed-size RGB array fed to the model.

    Args:
        img: A ``PIL.Image.Image`` (any mode) or a NumPy array. PIL images
            are converted to RGB directly. NumPy arrays keep the historical
            contract of this function: 3-channel arrays are treated as
            OpenCV-style BGR (4-channel as BGRA) and are converted to RGB;
            2-D arrays are treated as grayscale.
        target_height: Height of the returned array.
        target_width: Width of the returned array.
        crop_for_4x: If true, crop height and width down to a multiple of 4
            before the final resize.
        downsample_threshold: If the shorter side exceeds this value, the
            image is first scaled down so the shorter side is about this size.

    Returns:
        An array of shape ``(target_height, target_width, 3)`` in RGB channel
        order, with the dtype of the input pixels.
    """
    if isinstance(img, Image.Image):
        # PIL images are already RGB-ordered; swapping channels here (as the
        # old BGR2RGB call did) fed the model red/blue-swapped pixels.
        # convert() also normalises L / RGBA / P modes to three channels.
        img_rgb = np.array(img.convert("RGB"))
    else:
        img = np.asarray(img)
        if img.ndim == 2:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
        else:
            # Raw arrays are assumed to come from OpenCV (BGR by default).
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Downsample if the short side exceeds the threshold
    h, w = img_rgb.shape[:2]
    short_side = min(h, w)
    if short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        img_rgb = cv2.resize(
            img_rgb,
            (int(w / resize_ratio), int(h / resize_ratio)),
            interpolation=cv2.INTER_LINEAR,
        )

    # Crop to match 4x scaling if needed
    if crop_for_4x:
        h, w = img_rgb.shape[:2]
        crop_h = 4 * (h // 4)
        crop_w = 4 * (w // 4)
        # A side shorter than 4 px would be cropped to nothing and make the
        # final resize fail, so leave such a side untouched.
        if crop_h == 0:
            crop_h = h
        if crop_w == 0:
            crop_w = w
        img_rgb = img_rgb[:crop_h, :crop_w, :]

    # Resize the image to the fixed input size (cv2.resize takes (width, height))
    return cv2.resize(img_rgb, (target_width, target_height))
# Inference function to process image using ONNX model
_ORT_SESSIONS = {}  # model path -> loaded onnxruntime.InferenceSession


def _get_session(model_path):
    """Return the ONNX Runtime session for *model_path*, loading it on first use."""
    session = _ORT_SESSIONS.get(model_path)
    if session is None:
        session = ort.InferenceSession(model_path)
        _ORT_SESSIONS[model_path] = session
    return session


def inference(img, model_name="4xGRL"):
    """Run the super-resolution model on an uploaded image.

    Args:
        img: The image supplied by the Gradio input component, or ``None``
            when nothing has been uploaded.
        model_name: Name of the model to use; only ``"4xGRL"`` is supported.

    Returns:
        The model output as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was provided, the model name is unsupported, or
            preprocessing / inference fails. Raising (instead of returning
            the message as a string) lets Gradio display the error; a plain
            string is not a valid value for the image output component.
    """
    if img is None:
        raise gr.Error("An error occurred: no image was provided")
    if model_name != "4xGRL":
        raise gr.Error("An error occurred: Model not supported")

    try:
        # Reuse the loaded session instead of re-reading the model per request.
        ort_session = _get_session(MODEL_PATH)

        # Preprocess the image (resize, crop, etc.)
        img_resized = preprocess_image(img)

        # (H, W, C) -> (1, C, H, W) float32, the layout passed to the model.
        # NOTE(review): pixel values are passed as 0-255 floats and the output
        # is clipped to 0-255 below; confirm this matches the value range the
        # exported model actually expects.
        input_image = np.transpose(img_resized, (2, 0, 1))
        input_image = np.expand_dims(input_image, axis=0).astype(np.float32)

        # Run the model; the first output is taken as the result image.
        ort_inputs = {ort_session.get_inputs()[0].name: input_image}
        output_image = ort_session.run(None, ort_inputs)[0]

        # Drop only the batch axis (a bare squeeze() would also collapse a
        # 1-pixel spatial dimension), then go back to (H, W, C).
        output_image = np.transpose(np.squeeze(output_image, axis=0), (1, 2, 0))
        output_image = np.clip(output_image, 0, 255).astype(np.uint8)

        # Convert output to PIL Image for Gradio
        return Image.fromarray(output_image)
    except Exception as error:
        raise gr.Error(f"An error occurred: {error}") from error
# Gradio interface
def create_interface():
    """Build and return the Gradio Blocks app (headings, upload, button, result)."""
    app = gr.Blocks()
    with app:
        # Page heading and short usage hint
        gr.Markdown("# Anime Super-Resolution using ONNX")
        gr.Markdown("Upload an anime image to enhance it using the 4xGRL model.")

        # Each widget sits in its own row: upload, trigger, result.
        with gr.Row():
            upload = gr.Image(type="pil", label="Upload Image", interactive=True)
        with gr.Row():
            run_button = gr.Button("Process Image")
        with gr.Row():
            result = gr.Image(type="pil", label="Processed Image")

        # Clicking the button sends the uploaded image through `inference`
        # and shows whatever it returns in the result component.
        run_button.click(inference, inputs=upload, outputs=result)
    return app
# Launch the app
# `demo` is still built at module level so it stays importable by name, but the
# server is only started when this file is executed as a script, so importing
# the module no longer launches a web server as a side effect.
demo = create_interface()

if __name__ == "__main__":
    demo.launch(share=True)