Spaces:
Build error
Build error
File size: 7,518 Bytes
838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca 838846e a71d6ca deca831 a71d6ca 838846e a71d6ca 838846e a71d6ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# argparse is removed
import os
import time
# huggingface_hub.notebook_login is usually not needed in a Space
# from huggingface_hub import notebook_login
# IMPORTANT: models.py must be in the same directory in your Space repository
# Ensure Generator class is defined in models.py
# Generator is defined in the project-local models.py, which must sit next to
# this script in the Space repository.
try:
    from models import Generator
except ImportError as exc:
    # ImportError is also raised when models.py exists but one of ITS imports
    # fails, so surface the underlying cause instead of guessing.
    print("ERROR: 'models.py' not found or 'Generator' class not defined within it.")
    print("Please ensure models.py is in your Space's root directory.")
    print(f"Underlying error: {exc}")
    # SystemExit instead of the site-provided exit() helper, which is not
    # guaranteed to be available (e.g. under `python -S`).
    raise SystemExit(1) from exc
# --- Module-level state, filled in by load_generator_model() ---
# The loaded Generator; remains None unless loading succeeds.
generator_model = None
# torch.device used for inference; selected inside load_generator_model().
device = None

# --- Model configuration ---
# These values must match the hyperparameters the checkpoint was trained with;
# edit them here if your checkpoint differs.
CHECKPOINT_PATH: str = "checkpoints/generator_epoch_15.pth"  # weights file to load
MODEL_SCALE_FACTOR: int = 4  # passed to Generator(scale_factor=...)
MODEL_NUM_FEATURES: int = 64  # passed to Generator(num_features=...)
MODEL_NUM_BLOCKS: int = 16  # passed to Generator(num_res_blocks=...)
def load_generator_model(checkpoint_path, scale, num_features, num_blocks):
    """Load the generator model from a checkpoint.

    Sets the module-level ``device`` (CUDA when available, otherwise CPU)
    before attempting anything else, so it is set even when loading fails.
    On success, stores the model, in eval mode on ``device``, in the
    module-level ``generator_model``.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that holds
            the Generator's state dict.
        scale: Passed to ``Generator(scale_factor=...)``.
        num_features: Passed to ``Generator(num_features=...)``.
        num_blocks: Passed to ``Generator(num_res_blocks=...)``.

    Returns:
        True if the model was built and its weights loaded, False otherwise.
        Failures are reported with ``print`` rather than raised.
    """
    global generator_model, device
    print("--- Loading Model ---")
    # Use GPU if available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        # Check the file first so a missing checkpoint fails fast, before the
        # model is allocated on the device.
        if not os.path.exists(checkpoint_path):
            print(f"ERROR: Checkpoint file not found at: '{checkpoint_path}'")
            return False

        model = Generator(scale_factor=scale, num_features=num_features, num_res_blocks=num_blocks).to(device)
        model.eval()  # inference only

        print(f"Loading state dictionary from: {checkpoint_path}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=device)

        # Weights saved from an nn.DataParallel wrapper have keys of the form
        # 'module.<name>'. Strip only that LEADING prefix: a plain
        # str.replace would also corrupt inner names that merely contain
        # 'module.' (e.g. 'submodule.conv.weight' -> 'subconv.weight').
        prefix = "module."
        if any(key.startswith(prefix) for key in state_dict):
            print("Removing 'module.' prefix from state_dict keys...")
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

        model.load_state_dict(state_dict)
        print("State dictionary loaded successfully.")
    except Exception as e:
        # Boundary handler: report any failure (constructor arguments, unreadable
        # file, key/shape mismatch) and let the caller decide what to do.
        print(f"ERROR loading or instantiating model: {e}")
        return False

    generator_model = model
    print(f"Generator model (Scale x{scale}) loaded and ready on {device}.")
    print("--------------------")
    return True
def upscale_image(input_image: Image.Image):
    """Upscale an image with the globally loaded generator model.

    Args:
        input_image: PIL image from the Gradio ``gr.Image(type="pil")`` input,
            or None when the user submitted nothing.

    Returns:
        The generator's output converted to a PIL image, or None when
        ``input_image`` is None. The output size is whatever the generator
        produces (presumably the input size times ``MODEL_SCALE_FACTOR``).

    Raises:
        gr.Error: If the model is not loaded or preprocessing/inference fails;
            Gradio displays the message to the user.
    """
    if generator_model is None:
        # Should not happen when load_generator_model() succeeded before launch.
        raise gr.Error("Generator model not loaded. Application is not ready.")
    if input_image is None:
        return None

    with torch.no_grad():  # inference only: no autograd bookkeeping
        try:
            # PNG uploads are frequently RGBA, palette ("P") or grayscale ("L");
            # ToTensor would then produce 4 or 1 channels and the forward pass
            # fails. Normalize to 3-channel RGB (alpha is discarded).
            # NOTE(review): assumes the Generator was trained on RGB input — confirm in models.py.
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")

            # PIL -> float tensor in [0, 1], plus a leading batch dimension.
            input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

            print(f"Upscaling image with scale factor: {MODEL_SCALE_FACTOR}")
            output_tensor = generator_model(input_tensor)

            # Drop the batch dimension, move to CPU and clamp to the valid
            # [0, 1] range before converting back to a PIL image. The result
            # keeps the model's output size; it must NOT be resized back to
            # the input size.
            output_tensor = output_tensor.squeeze(0).cpu().clamp(0, 1)
            output_image = transforms.ToPILImage()(output_tensor)
            print("Upscaling complete.")
            return output_image
        except RuntimeError as e:
            # torch reports shape mismatches and out-of-memory as RuntimeError.
            print(f"ERROR during inference: {e}")
            raise gr.Error(f"Inference failed: {e}. Please check the input image.") from e
        except Exception as e:
            print(f"An unexpected error occurred during upscaling: {e}")
            raise gr.Error(f"An unexpected error occurred: {e}") from e
# --- Gradio Interface ---
if __name__ == "__main__":
    print("Attempting to load model for Gradio interface...")
    if not load_generator_model(CHECKPOINT_PATH, MODEL_SCALE_FACTOR, MODEL_NUM_FEATURES, MODEL_NUM_BLOCKS):
        print("Failed to load the model. The Gradio interface will not start.")
        # Exit with a non-zero status so the hosting environment (e.g. the
        # Space's container) sees a failure rather than a clean exit.
        raise SystemExit(1)

    print("Model loaded successfully. Starting Gradio interface.")
    iface = gr.Interface(
        fn=upscale_image,
        inputs=[gr.Image(type="pil", label="Input Image (Low Resolution)")],
        outputs=gr.Image(type="pil", label=f"Output Image (Upscaled x{MODEL_SCALE_FACTOR})"),
        title=f"OxO Image Super-Resolution (x{MODEL_SCALE_FACTOR})",
        description=f"Upload a PNG to upscale it by a factor of {MODEL_SCALE_FACTOR}. Repaired images may seem dark and colorless, you can adjust the contrast to make it look better.",
        # NOTE(review): Gradio 5 appears to replace allow_flagging with
        # flagging_mode="never" — confirm against the Gradio version the Space pins.
        allow_flagging='never',
        cache_examples=False,  # no examples= are configured, so nothing to cache
        live=False,  # run only when Submit is clicked (inference is heavy)
    )
    # Bind on all interfaces, port 7860, so the app is reachable from outside
    # a container; share=True is not needed when hosted on Spaces.
    iface.launch(server_name="0.0.0.0", server_port=7860)