import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# argparse is removed
import os
import time
# huggingface_hub.notebook_login is usually not needed in a Space
# from huggingface_hub import notebook_login
# IMPORTANT: models.py must be in the same directory in your Space repository
# Ensure Generator class is defined in models.py
try:
    # Assuming models.py is in the same directory
    from models import Generator
except ImportError:
    print("ERROR: 'models.py' not found or 'Generator' class not defined within it.")
    print("Please ensure models.py is in your Space's root directory.")
    exit(1)
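# For reference, this app assumes Generator exposes roughly the interface below,
# inferred from how it is constructed and called later in this file (a sketch,
# not the actual definition in models.py):
#
#   class Generator(torch.nn.Module):
#       def __init__(self, scale_factor, num_features, num_res_blocks): ...
#       def forward(self, x):  # x: (N, 3, H, W), values in [0, 1]
#           ...                # returns (N, 3, H*scale_factor, W*scale_factor)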
# --- Global Variables / Setup ---
generator_model = None
device = None
# --- Model Configuration (Define these directly for the Space) ---
# These should match the training parameters of your model checkpoint
CHECKPOINT_PATH = "checkpoints/generator_epoch_15.pth"  # <--- IMPORTANT: update if your path is different
MODEL_SCALE_FACTOR = 4   # <--- IMPORTANT: update if your model's scale factor is different
MODEL_NUM_FEATURES = 64  # <--- IMPORTANT: update if your model's feature count is different
MODEL_NUM_BLOCKS = 16    # <--- IMPORTANT: update if your model's residual block count is different
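# Optionally, these constants could be overridden via environment variables so
# the Space can be reconfigured without code edits (an illustrative sketch; the
# variable names below are hypothetical, not an existing convention):
# CHECKPOINT_PATH = os.environ.get("SR_CHECKPOINT_PATH", CHECKPOINT_PATH)
# MODEL_SCALE_FACTOR = int(os.environ.get("SR_SCALE_FACTOR", MODEL_SCALE_FACTOR))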
def load_generator_model(checkpoint_path, scale, num_features, num_blocks):
    """Loads the generator model from a checkpoint."""
    global generator_model, device
    print("--- Loading Model ---")
    # Use GPU if available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    try:
        # Instantiate the model with the defined parameters
        model = Generator(scale_factor=scale, num_features=num_features, num_res_blocks=num_blocks).to(device)
        model.eval()  # Set model to evaluation mode

        # Check that the checkpoint exists before trying to load it
        if not os.path.exists(checkpoint_path):
            print(f"ERROR: Checkpoint file not found at: '{checkpoint_path}'")
            # If needed, a path relative to app.py could be tried instead, e.g.:
            # alternative_path = os.path.join(os.path.dirname(__file__), checkpoint_path)
            # if os.path.exists(alternative_path):
            #     checkpoint_path = alternative_path
            # else:
            #     return False
            return False  # Just return False if not found at the specified path

        print(f"Loading state dictionary from: {checkpoint_path}")
        state_dict = torch.load(checkpoint_path, map_location=device)

        # Handle checkpoints saved from a DataParallel-wrapped model, whose keys
        # carry a 'module.' prefix (e.g. 'module.conv1.weight' -> 'conv1.weight')
        if any(k.startswith('module.') for k in state_dict.keys()):
            print("Removing 'module.' prefix from state_dict keys...")
            state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}

        # Load the state dictionary
        model.load_state_dict(state_dict)
        print("State dictionary loaded successfully.")
    except Exception as e:
        print(f"ERROR loading or instantiating model: {e}")
        return False  # Return False if loading fails

    generator_model = model
    print(f"Generator model (Scale x{scale}) loaded and ready on {device}.")
    print("--------------------")
    return True  # Return True on success
def upscale_image(input_image: Image.Image):
    """Upscales an input image using the loaded generator model."""
    if generator_model is None:
        # This should not happen if load_generator_model succeeded,
        # but it provides a fallback error message.
        raise gr.Error("Generator model not loaded. Application is not ready.")
    if input_image is None:
        return None  # Return nothing if no input image is provided

    # --- Preprocessing ---
    # Gradio provides a PIL Image; ensure it has 3 channels (PNG uploads may be
    # RGBA or grayscale, which would not match a 3-channel model input)
    input_image = input_image.convert("RGB")
    # Convert to a PyTorch tensor, add a batch dimension, and move to the device
    transform_to_tensor = transforms.ToTensor()
    input_tensor = transform_to_tensor(input_image).unsqueeze(0).to(device)

    # --- Inference ---
    with torch.no_grad():  # Disable gradient calculation for inference
        try:
            print(f"Upscaling image with scale factor: {MODEL_SCALE_FACTOR}")
            # Perform the super-resolution inference
            output_tensor = generator_model(input_tensor)
            # Remove the batch dimension, move to CPU, and clamp values to [0, 1]
            output_tensor = output_tensor.squeeze(0).cpu().clamp(0, 1)

            # --- Postprocessing ---
            # Convert the output tensor back to a PIL Image; this automatically
            # produces an image with the upscaled dimensions
            output_image = transforms.ToPILImage()(output_tensor)
            # The line below from the original code was incorrect for SR,
            # as it resized the upscaled image BACK to the original size:
            # output_image = output_image.resize((original_width, original_height), resample=Image.BICUBIC)
            print("Upscaling complete.")
            return output_image  # Return the correctly upscaled PIL Image
        except RuntimeError as e:
            print(f"ERROR during inference: {e}")
            # Surface a user-friendly error message via Gradio
            raise gr.Error(f"Inference failed: {e}. Please check the input image.")
        except Exception as e:
            print(f"An unexpected error occurred during upscaling: {e}")
            raise gr.Error(f"An unexpected error occurred: {e}")
# --- Gradio Interface ---
if __name__ == "__main__":
    # --- Load the model with the defined parameters ---
    print("Attempting to load model for Gradio interface...")
    if load_generator_model(CHECKPOINT_PATH, MODEL_SCALE_FACTOR, MODEL_NUM_FEATURES, MODEL_NUM_BLOCKS):
        print("Model loaded successfully. Starting Gradio interface.")
        # Define the Gradio interface
        iface = gr.Interface(
            fn=upscale_image,
            inputs=[gr.Image(type="pil", label="Input Image (Low Resolution)")],
            outputs=gr.Image(type="pil", label=f"Output Image (Upscaled x{MODEL_SCALE_FACTOR})"),
            title=f"OxO Image Super-Resolution (x{MODEL_SCALE_FACTOR})",
            description=(
                f"Upload a PNG to upscale it by a factor of {MODEL_SCALE_FACTOR}. "
                "Upscaled images may appear dark or desaturated; adjusting the "
                "contrast afterwards can improve the result."
            ),
            allow_flagging='never',  # Disable flagging
            cache_examples=False,    # Do not cache examples
            live=False,              # Process only on submit (better for heavy tasks)
        )
        # Launch the Gradio app.
        # On Hugging Face Spaces, share=True is not needed; binding to
        # 0.0.0.0:7860 is the convention for containerized environments.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    else:
        # If model loading failed, print an error; the interface will not start.
        print("Failed to load the model. The Gradio interface will not start.")
        # Optionally, raise an exception here if the Space should error out:
        # raise RuntimeError("Model loading failed.")
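
# For a quick local smoke test without the web UI (a sketch; assumes a file
# named "test.png" exists alongside app.py):
# if load_generator_model(CHECKPOINT_PATH, MODEL_SCALE_FACTOR, MODEL_NUM_FEATURES, MODEL_NUM_BLOCKS):
#     result = upscale_image(Image.open("test.png"))
#     print("Output size:", result.size)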