"""Gradio app that upscales images with a pretrained super-resolution Generator.

The Generator architecture is imported from ``models.py`` (which must sit in the
same directory as this file); its weights are loaded from ``CHECKPOINT_PATH``
when the script is run directly.
"""

import os
import sys
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# IMPORTANT: models.py must be in the same directory in your Space repository
# and must define the Generator class.
try:
    from models import Generator
except ImportError as exc:
    print("ERROR: 'models.py' not found or 'Generator' class not defined within it.")
    print("Please ensure models.py is in your Space's root directory.")
    # The ImportError may also come from an import *inside* models.py; show it.
    print(f"Details: {exc}")
    sys.exit(1)

# --- Global state (populated by load_generator_model) ---
generator_model = None  # Generator in eval mode once loading succeeds
device = None  # torch.device chosen by load_generator_model

# --- Model configuration ---
# These must match the training parameters of the model checkpoint.
CHECKPOINT_PATH = "checkpoints/generator_epoch_15.pth"  # <--- IMPORTANT: Update if your path is different
MODEL_SCALE_FACTOR = 4  # <--- IMPORTANT: Update if your model scale is different
MODEL_NUM_FEATURES = 64  # <--- IMPORTANT: Update if your model features is different
MODEL_NUM_BLOCKS = 16  # <--- IMPORTANT: Update if your model blocks is different

# nn.DataParallel stores the wrapped network as `.module`, so a state dict saved
# from a wrapped model has this prefix on every parameter name.
_DATA_PARALLEL_PREFIX = "module."


def _resolve_checkpoint_path(checkpoint_path):
    """Return an existing path for *checkpoint_path*, or None if it cannot be found.

    The path is tried as given (relative to the current working directory)
    first, then relative to this script's directory, so the app also works
    when launched from another working directory.
    """
    if os.path.exists(checkpoint_path):
        return checkpoint_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    script_relative = os.path.join(script_dir, checkpoint_path)
    if os.path.exists(script_relative):
        return script_relative
    return None


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading 'module.' removed from each key.

    Only a leading prefix is stripped; 'module.' occurring elsewhere in a key
    (e.g. 'submodule.conv.weight') is left untouched.
    """
    prefix_len = len(_DATA_PARALLEL_PREFIX)
    return {
        (key[prefix_len:] if key.startswith(_DATA_PARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def load_generator_model(checkpoint_path, scale, num_features, num_blocks):
    """Build the Generator, load its weights, and store it in ``generator_model``.

    Args:
        checkpoint_path: Path to a checkpoint file holding the Generator's state dict.
        scale: Upscaling factor passed to ``Generator(scale_factor=...)``.
        num_features: Value passed to ``Generator(num_features=...)``.
        num_blocks: Value passed to ``Generator(num_res_blocks=...)``.

    Returns:
        True on success. False if the checkpoint is missing or anything fails
        while building/loading the model; an error is printed and
        ``generator_model`` is left unchanged in that case.
    """
    global generator_model, device

    print("--- Loading Model ---")
    # Use GPU if available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        model = Generator(scale_factor=scale, num_features=num_features, num_res_blocks=num_blocks).to(device)
        model.eval()  # inference only: fixes dropout/batch-norm behaviour

        resolved_path = _resolve_checkpoint_path(checkpoint_path)
        if resolved_path is None:
            print(f"ERROR: Checkpoint file not found at: '{checkpoint_path}'")
            return False

        print(f"Loading state dictionary from: {resolved_path}")
        # NOTE(review): torch.load unpickles the file (restricted unpickler only
        # when weights_only=True, the default from torch 2.6); only load trusted
        # checkpoints.
        state_dict = torch.load(resolved_path, map_location=device)

        if any(key.startswith(_DATA_PARALLEL_PREFIX) for key in state_dict):
            print("Removing 'module.' prefix from state_dict keys...")
            state_dict = _strip_data_parallel_prefix(state_dict)

        model.load_state_dict(state_dict)
        print("State dictionary loaded successfully.")
    except Exception as e:
        # Startup boundary: report the failure and let the caller decide.
        print(f"ERROR loading or instantiating model: {e}")
        return False

    generator_model = model
    print(f"Generator model (Scale x{scale}) loaded and ready on {device}.")
    print("--------------------")
    return True


def upscale_image(input_image: Image.Image):
    """Upscale *input_image* with the loaded Generator.

    Args:
        input_image: PIL image supplied by the Gradio input, or None when
            nothing was uploaded.

    Returns:
        The upscaled PIL image (its size is whatever the Generator outputs), or
        None if no input image was provided.

    Raises:
        gr.Error: If the model is not loaded or inference fails; Gradio shows
            the message to the user.
    """
    if generator_model is None:
        raise gr.Error("Generator model not loaded. Application is not ready.")
    if input_image is None:
        return None

    # ToTensor keeps the image's native channel layout (4 channels for RGBA
    # PNGs, 1 for grayscale/palette images), so normalise to 3-channel RGB.
    # NOTE(review): assumes the Generator expects RGB input — confirm against models.py.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    # --- Preprocessing: PIL -> float tensor in [0, 1], add batch dimension ---
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # --- Inference ---
    with torch.no_grad():
        try:
            print(f"Upscaling image with scale factor: {MODEL_SCALE_FACTOR}")
            output_tensor = generator_model(input_tensor)
            # Drop the batch dimension, move to CPU and clamp to the valid
            # [0, 1] range expected by ToPILImage.
            output_tensor = output_tensor.squeeze(0).cpu().clamp(0, 1)

            # --- Postprocessing: tensor -> PIL at the upscaled resolution ---
            output_image = transforms.ToPILImage()(output_tensor)
            print("Upscaling complete.")
            return output_image
        except RuntimeError as e:
            # e.g. shape mismatches or out-of-memory on very large inputs.
            print(f"ERROR during inference: {e}")
            raise gr.Error(f"Inference failed: {e}. Please check the input image.") from e
        except Exception as e:
            print(f"An unexpected error occurred during upscaling: {e}")
            raise gr.Error(f"An unexpected error occurred: {e}") from e


# --- Gradio Interface ---
if __name__ == "__main__":
    print("Attempting to load model for Gradio interface...")
    if load_generator_model(CHECKPOINT_PATH, MODEL_SCALE_FACTOR, MODEL_NUM_FEATURES, MODEL_NUM_BLOCKS):
        print("Model loaded successfully. Starting Gradio interface.")

        iface = gr.Interface(
            fn=upscale_image,
            inputs=[gr.Image(type="pil", label="Input Image (Low Resolution)")],
            outputs=gr.Image(type="pil", label=f"Output Image (Upscaled x{MODEL_SCALE_FACTOR})"),
            title=f"OxO Image Super-Resolution (x{MODEL_SCALE_FACTOR})",
            description=(
                f"Upload a PNG to upscale it by a factor of {MODEL_SCALE_FACTOR}. "
                "Repaired images may seem dark and colorless, "
                "you can adjust the contrast to make it look better."
            ),
            allow_flagging='never',  # Disable flagging
            cache_examples=False,  # Do not cache examples
            live=False,  # Process only when submit is clicked (better for heavy tasks)
        )

        # Containerized hosts (e.g. Hugging Face Spaces) need the app bound to
        # all interfaces on the expected port; share=True is not required there.
        iface.launch(server_name="0.0.0.0", server_port=7860)
    else:
        print("Failed to load the model. The Gradio interface will not start.")
        # Exit non-zero (as the import-failure path does) so the hosting
        # environment reports the failure instead of a clean shutdown.
        sys.exit(1)