Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| # argparse is removed | |
| import os | |
| import time | |
| # huggingface_hub.notebook_login is usually not needed in a Space | |
| # from huggingface_hub import notebook_login | |
# IMPORTANT: models.py must be in the same directory in your Space repository
# and must define the Generator class used below.
try:
    from models import Generator
except ImportError as exc:
    # The ImportError may also originate inside models.py (for example a missing
    # third-party dependency), so surface the underlying message instead of
    # only guessing at the cause.
    print("ERROR: 'models.py' not found or 'Generator' class not defined within it.")
    print("Please ensure models.py is in your Space's root directory.")
    print(f"Underlying error: {exc}")
    # Use SystemExit rather than the site-module `exit()` helper, which is meant
    # for the interactive interpreter and is absent when Python runs with -S.
    raise SystemExit(1)
# --- Global state (filled in by load_generator_model) ---
# generator_model: the loaded Generator instance, None until loading succeeds.
# device: the torch.device selected when the model is loaded.
generator_model, device = None, None

# --- Model configuration ---
# Every value below must agree with the hyper-parameters the checkpoint was
# trained with; edit them if your checkpoint differs.
CHECKPOINT_PATH = "checkpoints/generator_epoch_15.pth"  # weights file to load
MODEL_SCALE_FACTOR = 4   # upscaling factor, passed as Generator(scale_factor=...)
MODEL_NUM_FEATURES = 64  # passed as Generator(num_features=...)
MODEL_NUM_BLOCKS = 16    # passed as Generator(num_res_blocks=...)
def load_generator_model(checkpoint_path, scale, num_features, num_blocks):
    """Load the generator model from a checkpoint.

    On success the model (in eval mode, on the selected device) is stored in
    the module-level ``generator_model`` and ``True`` is returned. On failure an
    error is printed, ``generator_model`` is left unchanged and ``False`` is
    returned. The module-level ``device`` is set in either case.

    Args:
        checkpoint_path: Path to the checkpoint file holding the generator's
            state dict. A relative path is tried against the current working
            directory first, then against the directory containing this file.
        scale: Passed to ``Generator(scale_factor=...)``.
        num_features: Passed to ``Generator(num_features=...)``.
        num_blocks: Passed to ``Generator(num_res_blocks=...)``.

    Returns:
        True if the model was built and its weights loaded, otherwise False.
    """
    global generator_model, device
    print("--- Loading Model ---")
    # Use GPU if available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Resolve the checkpoint before building the model so a missing file fails fast.
    resolved_path = checkpoint_path
    script_file = globals().get("__file__")  # absent in some interactive sessions
    if script_file and not os.path.isabs(resolved_path) and not os.path.exists(resolved_path):
        # In case the app is launched from a different working directory, also
        # look for the checkpoint relative to this source file.
        alternative_path = os.path.join(
            os.path.dirname(os.path.abspath(script_file)), resolved_path
        )
        if os.path.exists(alternative_path):
            print(f"Found checkpoint at alternative path: {alternative_path}")
            resolved_path = alternative_path
    if not os.path.exists(resolved_path):
        print(f"ERROR: Checkpoint file not found at: '{checkpoint_path}'")
        return False

    try:
        # Instantiate the model with the defined parameters.
        model = Generator(
            scale_factor=scale, num_features=num_features, num_res_blocks=num_blocks
        ).to(device)
        model.eval()  # Set model to evaluation mode

        print(f"Loading state dictionary from: {resolved_path}")
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True if the installed torch supports it).
        state_dict = torch.load(resolved_path, map_location=device)

        # Handle models saved from a DataParallel wrapper, whose keys carry a
        # leading "module." prefix. Strip only that leading prefix: replacing
        # "module." anywhere in a key would also corrupt keys such as
        # "feature_module.conv.weight".
        prefix = "module."
        if any(k.startswith(prefix) for k in state_dict):
            print("Removing 'module.' prefix from state_dict keys...")
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

        model.load_state_dict(state_dict)
        print("State dictionary loaded successfully.")
    except Exception as e:
        # Top-level boundary: report the problem and signal failure to the caller.
        print(f"ERROR loading or instantiating model: {e}")
        return False

    generator_model = model
    print(f"Generator model (Scale x{scale}) loaded and ready on {device}.")
    print("--------------------")
    return True
def upscale_image(input_image: Image.Image):
    """Upscale an input image using the loaded generator model.

    Args:
        input_image: PIL image from the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when nothing was submitted.

    Returns:
        The upscaled PIL image, or ``None`` if no input image was provided.

    Raises:
        gr.Error: If the model is not loaded, or if preprocessing/inference
            fails (Gradio displays the message to the user).
    """
    # generator_model, device and MODEL_SCALE_FACTOR are only read here, so no
    # `global` statement is required.
    if generator_model is None:
        # Should not happen once load_generator_model succeeded; fallback message.
        raise gr.Error("Generator model not loaded. Application is not ready.")
    if input_image is None:
        return None  # Return nothing if no input image is provided

    try:
        # --- Preprocessing ---
        # NOTE(review): assumes the generator expects 3-channel RGB input.
        # gr.Image may already convert uploads to RGB, but RGBA (e.g. PNG with
        # alpha), grayscale or palette images reaching this function would
        # otherwise produce a tensor with the wrong number of channels.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")
        # PIL image -> tensor, add a batch dimension, move to the model's device.
        input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

        # --- Inference ---
        print(f"Upscaling image with scale factor: {MODEL_SCALE_FACTOR}")
        with torch.no_grad():  # Disable gradient calculation for inference
            output_tensor = generator_model(input_tensor)

        # --- Postprocessing ---
        # Drop the batch dimension, move to CPU and clamp values to [0, 1]
        # before converting back to a PIL image at the upscaled size.
        output_tensor = output_tensor.squeeze(0).cpu().clamp(0, 1)
        output_image = transforms.ToPILImage()(output_tensor)
        print("Upscaling complete.")
        return output_image
    except RuntimeError as e:
        print(f"ERROR during inference: {e}")
        # Provide a user-friendly error message via Gradio.
        raise gr.Error(f"Inference failed: {e}. Please check the input image.") from e
    except Exception as e:
        print(f"An unexpected error occurred during upscaling: {e}")
        raise gr.Error(f"An unexpected error occurred: {e}") from e
# --- Gradio Interface ---
if __name__ == "__main__":
    print("Attempting to load model for Gradio interface...")
    if not load_generator_model(
        CHECKPOINT_PATH, MODEL_SCALE_FACTOR, MODEL_NUM_FEATURES, MODEL_NUM_BLOCKS
    ):
        # Exit with a non-zero status (matching the exit status used when
        # models.py cannot be imported) so whatever launched the app can see
        # the failure instead of the process ending "successfully".
        print("Failed to load the model. The Gradio interface will not start.")
        raise SystemExit(1)

    print("Model loaded successfully. Starting Gradio interface.")
    iface = gr.Interface(
        fn=upscale_image,
        inputs=[gr.Image(type="pil", label="Input Image (Low Resolution)")],
        outputs=gr.Image(type="pil", label=f"Output Image (Upscaled x{MODEL_SCALE_FACTOR})"),
        title=f"OxO Image Super-Resolution (x{MODEL_SCALE_FACTOR})",
        description=f"Upload a PNG to upscale it by a factor of {MODEL_SCALE_FACTOR}. Repaired images may seem dark and colorless, you can adjust the contrast to make it look better.",
        # NOTE(review): newer Gradio releases deprecate `allow_flagging` in
        # favour of `flagging_mode`; confirm against the installed version.
        allow_flagging='never',  # Disable flagging
        cache_examples=False,  # Do not cache examples
        live=False,  # Process only when submit is clicked (better for heavy tasks)
    )
    # Bind to all interfaces on port 7860, as containerized hosts typically require.
    iface.launch(server_name="0.0.0.0", server_port=7860)