Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,8 +3,9 @@ import torch
|
|
| 3 |
import gradio as gr
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
|
|
|
| 6 |
from huggingface_hub import snapshot_download
|
| 7 |
-
from test_ccsr_tile import
|
| 8 |
import argparse
|
| 9 |
from accelerate import Accelerator
|
| 10 |
|
|
@@ -48,6 +49,12 @@ def initialize_models():
|
|
| 48 |
# Load pipeline
|
| 49 |
pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# Initialize generator
|
| 52 |
generator = torch.Generator(device=accelerator.device)
|
| 53 |
|
|
@@ -57,7 +64,7 @@ def initialize_models():
|
|
| 57 |
print(f"Error initializing models: {str(e)}")
|
| 58 |
return False
|
| 59 |
|
| 60 |
-
@spaces.GPU
|
| 61 |
def process_image(
|
| 62 |
input_image,
|
| 63 |
prompt="clean, high-resolution, 8k",
|
|
@@ -117,13 +124,6 @@ def process_image(
|
|
| 117 |
validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
|
| 118 |
width, height = validation_image.size
|
| 119 |
|
| 120 |
-
# Move pipeline to GPU and set to eval mode
|
| 121 |
-
pipeline.to(accelerator.device)
|
| 122 |
-
pipeline.unet.eval()
|
| 123 |
-
pipeline.controlnet.eval()
|
| 124 |
-
pipeline.vae.eval()
|
| 125 |
-
pipeline.text_encoder.eval()
|
| 126 |
-
|
| 127 |
# Generate image
|
| 128 |
with torch.no_grad():
|
| 129 |
inference_time, output = pipeline(
|
|
@@ -157,62 +157,30 @@ def process_image(
|
|
| 157 |
if resize_flag:
|
| 158 |
image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
|
| 159 |
|
| 160 |
-
# Move pipeline back to CPU to free up GPU memory
|
| 161 |
-
pipeline.to("cpu")
|
| 162 |
-
torch.cuda.empty_cache()
|
| 163 |
-
|
| 164 |
return image
|
| 165 |
|
| 166 |
except Exception as e:
|
| 167 |
print(f"Error processing image: {str(e)}")
|
| 168 |
return None
|
| 169 |
|
| 170 |
-
#
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
vae_decoder_tile_size=224
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
# Initialize accelerator
|
| 195 |
-
accelerator = Accelerator(
|
| 196 |
-
mixed_precision=args.mixed_precision,
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
# Load pipeline
|
| 200 |
-
pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
|
| 201 |
-
|
| 202 |
-
# Set pipeline to eval mode
|
| 203 |
-
pipeline.unet.eval()
|
| 204 |
-
pipeline.controlnet.eval()
|
| 205 |
-
pipeline.vae.eval()
|
| 206 |
-
pipeline.text_encoder.eval()
|
| 207 |
-
|
| 208 |
-
# Move to CPU initially to save memory
|
| 209 |
-
pipeline.to("cpu")
|
| 210 |
-
|
| 211 |
-
# Initialize generator
|
| 212 |
-
generator = torch.Generator(device=accelerator.device)
|
| 213 |
-
|
| 214 |
-
return True
|
| 215 |
-
|
| 216 |
-
except Exception as e:
|
| 217 |
-
print(f"Error initializing models: {str(e)}")
|
| 218 |
-
return False
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
| 6 |
+
from diffusers import DiffusionPipeline
|
| 7 |
from huggingface_hub import snapshot_download
|
| 8 |
+
from test_ccsr_tile import load_pipeline
|
| 9 |
import argparse
|
| 10 |
from accelerate import Accelerator
|
| 11 |
|
|
|
|
| 49 |
# Load pipeline
|
| 50 |
pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
|
| 51 |
|
| 52 |
+
# Set pipeline to eval mode
|
| 53 |
+
pipeline.unet.eval()
|
| 54 |
+
pipeline.controlnet.eval()
|
| 55 |
+
pipeline.vae.eval()
|
| 56 |
+
pipeline.text_encoder.eval()
|
| 57 |
+
|
| 58 |
# Initialize generator
|
| 59 |
generator = torch.Generator(device=accelerator.device)
|
| 60 |
|
|
|
|
| 64 |
print(f"Error initializing models: {str(e)}")
|
| 65 |
return False
|
| 66 |
|
| 67 |
+
@spaces.GPU(processing_timeout=180) # Increased timeout for longer processing
|
| 68 |
def process_image(
|
| 69 |
input_image,
|
| 70 |
prompt="clean, high-resolution, 8k",
|
|
|
|
| 124 |
validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
|
| 125 |
width, height = validation_image.size
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
# Generate image
|
| 128 |
with torch.no_grad():
|
| 129 |
inference_time, output = pipeline(
|
|
|
|
| 157 |
if resize_flag:
|
| 158 |
image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
return image
|
| 161 |
|
| 162 |
except Exception as e:
|
| 163 |
print(f"Error processing image: {str(e)}")
|
| 164 |
return None
|
| 165 |
|
| 166 |
+
# Create Gradio interface
|
| 167 |
+
demo = gr.Interface(
|
| 168 |
+
fn=process_image,
|
| 169 |
+
inputs=[
|
| 170 |
+
gr.Image(label="Input Image"),
|
| 171 |
+
gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
|
| 172 |
+
gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
|
| 173 |
+
gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
|
| 174 |
+
gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
|
| 175 |
+
gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
|
| 176 |
+
gr.Number(label="Seed", value=42),
|
| 177 |
+
gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
|
| 178 |
+
gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
|
| 179 |
+
],
|
| 180 |
+
outputs=gr.Image(label="Generated Image"),
|
| 181 |
+
title="Controllable Conditional Super-Resolution",
|
| 182 |
+
description="Upload an image to enhance its resolution using CCSR."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|