File size: 7,518 Bytes
838846e
 
a71d6ca
838846e
 
a71d6ca
838846e
 
a71d6ca
 
838846e
a71d6ca
 
838846e
a71d6ca
838846e
 
a71d6ca
 
838846e
 
 
 
 
 
a71d6ca
 
 
 
 
 
 
838846e
 
a71d6ca
 
 
838846e
 
a71d6ca
838846e
 
 
 
a71d6ca
838846e
a71d6ca
838846e
a71d6ca
838846e
 
a71d6ca
 
 
 
 
 
 
 
838846e
 
 
 
a71d6ca
838846e
a71d6ca
838846e
 
a71d6ca
838846e
 
 
 
 
a71d6ca
838846e
 
a71d6ca
838846e
a71d6ca
838846e
 
a71d6ca
 
 
838846e
 
a71d6ca
 
 
838846e
 
a71d6ca
838846e
a71d6ca
 
838846e
a71d6ca
838846e
 
a71d6ca
 
838846e
a71d6ca
 
838846e
a71d6ca
 
838846e
 
a71d6ca
 
 
838846e
 
a71d6ca
 
 
 
 
 
838846e
 
 
a71d6ca
 
 
 
 
838846e
 
a71d6ca
838846e
a71d6ca
 
 
 
 
838846e
a71d6ca
838846e
 
a71d6ca
 
 
deca831
a71d6ca
 
 
838846e
a71d6ca
 
 
 
 
838846e
a71d6ca
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# argparse is removed
import os
import time
# huggingface_hub.notebook_login is usually not needed in a Space
# from huggingface_hub import notebook_login 

# IMPORTANT: models.py must be in the same directory in your Space repository
# Ensure Generator class is defined in models.py
try:
    # Assuming models.py is in the same directory
    from models import Generator
except ImportError as import_error:
    # ImportError is also raised when an import *inside* models.py fails, so
    # show the underlying message instead of only guessing at the cause.
    print("ERROR: 'models.py' not found or 'Generator' class not defined within it.")
    print(f"Details: {import_error}")
    print("Please ensure models.py is in your Space's root directory.")
    # The bare `exit` builtin is injected by the `site` module and is not
    # guaranteed to exist (e.g. `python -S`, embedded interpreters);
    # SystemExit is always available.
    raise SystemExit(1)


# --- Global Variables / Setup ---
generator_model = None  # Set by load_generator_model() once loading succeeds
device = None           # torch.device chosen by load_generator_model()

# --- Model Configuration (Define these directly for the Space) ---
# These should match the training parameters of your model checkpoint
CHECKPOINT_PATH = "checkpoints/generator_epoch_15.pth" # <--- IMPORTANT: Update if your path is different
MODEL_SCALE_FACTOR = 4                               # <--- IMPORTANT: Update if your model scale is different
MODEL_NUM_FEATURES = 64                              # <--- IMPORTANT: Update if your model features is different
MODEL_NUM_BLOCKS = 16                                # <--- IMPORTANT: Update if your model blocks is different


def load_generator_model(checkpoint_path, scale, num_features, num_blocks):
    """Load the super-resolution Generator from a checkpoint file.

    Picks CUDA when available (CPU otherwise), builds a ``Generator`` with the
    given hyper-parameters, loads the state dict stored at ``checkpoint_path``
    and, on success, publishes the model through the module-level
    ``generator_model`` global. The chosen device is stored in ``device``.

    Args:
        checkpoint_path: Path to the checkpoint file passed to ``torch.load``.
        scale: Forwarded to ``Generator(scale_factor=...)``.
        num_features: Forwarded to ``Generator(num_features=...)``.
        num_blocks: Forwarded to ``Generator(num_res_blocks=...)``.

    Returns:
        True if the model was loaded; False if the checkpoint is missing or
        building/loading the model raised. Errors are printed, not raised.
    """
    global generator_model, device

    print("--- Loading Model ---")
    # Use GPU if available, otherwise fallback to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Check the file first so a bad path fails fast, before the model is built.
    if not os.path.exists(checkpoint_path):
        print(f"ERROR: Checkpoint file not found at: '{checkpoint_path}'")
        return False

    try:
        # Instantiate the model with the defined parameters
        model = Generator(scale_factor=scale, num_features=num_features, num_res_blocks=num_blocks).to(device)
        model.eval()  # Set model to evaluation mode

        print(f"Loading state dictionary from: {checkpoint_path}")
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. If this is a plain state dict, consider
        # weights_only=True (available in torch >= 1.13).
        state_dict = torch.load(checkpoint_path, map_location=device)

        # Checkpoints saved from nn.DataParallel prefix every key with
        # "module.". Strip only that *leading* prefix: a substring replace
        # would also mangle keys such as "submodule.weight".
        prefix = "module."
        if any(key.startswith(prefix) for key in state_dict):
            print("Removing 'module.' prefix from state_dict keys...")
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

        # Load the state dictionary
        model.load_state_dict(state_dict)
        print("State dictionary loaded successfully.")

    except Exception as e:
        print(f"ERROR loading or instantiating model: {e}")
        return False  # Return False if loading fails

    generator_model = model
    print(f"Generator model (Scale x{scale}) loaded and ready on {device}.")
    print("--------------------")
    return True  # Return True on success


def upscale_image(input_image: Image.Image):
    """Upscale ``input_image`` with the loaded generator model.

    Args:
        input_image: PIL image supplied by the ``gr.Image(type="pil")`` input,
            or None when the user submitted nothing.

    Returns:
        The upscaled PIL image produced by the model, or None when
        ``input_image`` is None.

    Raises:
        gr.Error: If the model is not loaded, or if preprocessing or inference
            fails. Gradio shows the message to the user.
    """
    if generator_model is None:
        # This should ideally not happen if load_generator_model succeeded,
        # but provides a fallback error message.
        raise gr.Error("Generator model not loaded. Application is not ready.")

    if input_image is None:
        return None  # Return nothing if no input image is provided

    try:
        # --- Preprocessing ---
        # PNG uploads are often RGBA, palette ("P") or grayscale. ToTensor keeps
        # the source channel count (4 for RGBA, 1 for L/P, raw indices for P),
        # so normalize every upload to 3-channel RGB first.
        # NOTE(review): assumes Generator expects RGB input; confirm in models.py.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        # Add a batch dimension and move to the model's device.
        input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

        # --- Inference ---
        with torch.no_grad():  # Disable gradient calculation for inference
            print(f"Upscaling image with scale factor: {MODEL_SCALE_FACTOR}")
            output_tensor = generator_model(input_tensor)

        # --- Postprocessing ---
        # Drop the batch dimension, move to CPU, and clamp values to [0, 1]
        # so ToPILImage gets a valid range. The PIL image keeps the upscaled
        # dimensions.
        output_tensor = output_tensor.squeeze(0).cpu().clamp(0, 1)
        output_image = transforms.ToPILImage()(output_tensor)

    except RuntimeError as e:
        print(f"ERROR during inference: {e}")
        # Provide a user-friendly error message via Gradio
        raise gr.Error(f"Inference failed: {e}. Please check the input image.")
    except Exception as e:
        print(f"An unexpected error occurred during upscaling: {e}")
        raise gr.Error(f"An unexpected error occurred: {e}")

    print("Upscaling complete.")
    return output_image


# --- Gradio Interface ---
if __name__ == "__main__":
    # --- Load the model with defined parameters ---
    print("Attempting to load model for Gradio interface...")
    if not load_generator_model(CHECKPOINT_PATH, MODEL_SCALE_FACTOR, MODEL_NUM_FEATURES, MODEL_NUM_BLOCKS):
        print("Failed to load the model. The Gradio interface will not start.")
        # Exit with a non-zero status so the hosting platform (e.g. a Space's
        # container) reports a failed start rather than a clean exit with no UI.
        raise SystemExit(1)

    print("Model loaded successfully. Starting Gradio interface.")

    # Define the Gradio Interface
    iface = gr.Interface(
        fn=upscale_image,
        inputs=[gr.Image(type="pil", label="Input Image (Low Resolution)")],
        outputs=gr.Image(type="pil", label=f"Output Image (Upscaled x{MODEL_SCALE_FACTOR})"),
        title=f"OxO Image Super-Resolution (x{MODEL_SCALE_FACTOR})",
        description=f"Upload a PNG to upscale it by a factor of {MODEL_SCALE_FACTOR}. Repaired images may seem dark and colorless, you can adjust the contrast to make it look better.",
        # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
        # confirm against the Gradio version pinned for this Space.
        allow_flagging='never',  # Disable flagging
        cache_examples=False,    # Do not cache examples
        live=False,              # Process only when submit is clicked (better for heavy tasks)
    )

    # Bind to all interfaces on the port that containerized hosts such as
    # Hugging Face Spaces expect; share=True is not needed there.
    iface.launch(server_name="0.0.0.0", server_port=7860)