Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -149,6 +149,10 @@ models_rbm = core.Models(
|
|
| 149 |
models_rbm.generator.eval().requires_grad_(False)
|
| 150 |
|
| 151 |
def infer(style_description, ref_style_file, caption):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
clear_gpu_cache() # Clear cache before inference
|
| 153 |
|
| 154 |
height=1024
|
|
@@ -181,10 +185,10 @@ def infer(style_description, ref_style_file, caption):
|
|
| 181 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
| 182 |
|
| 183 |
if low_vram:
|
| 184 |
-
#
|
| 185 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 186 |
|
| 187 |
-
# Stage C reverse process
|
| 188 |
with torch.cuda.amp.autocast(): # Use mixed precision
|
| 189 |
sampling_c = extras.gdf.sample(
|
| 190 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
|
@@ -202,7 +206,10 @@ def infer(style_description, ref_style_file, caption):
|
|
| 202 |
|
| 203 |
clear_gpu_cache() # Clear cache between stages
|
| 204 |
|
| 205 |
-
#
|
|
|
|
|
|
|
|
|
|
| 206 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 207 |
conditions_b['effnet'] = sampled_c
|
| 208 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
|
@@ -215,6 +222,7 @@ def infer(style_description, ref_style_file, caption):
|
|
| 215 |
sampled_b = sampled_b
|
| 216 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 217 |
|
|
|
|
| 218 |
sampled = torch.cat([
|
| 219 |
torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
|
| 220 |
sampled.cpu(),
|
|
@@ -234,6 +242,7 @@ def infer(style_description, ref_style_file, caption):
|
|
| 234 |
|
| 235 |
return output_file # Return the path to the saved image
|
| 236 |
|
|
|
|
| 237 |
import gradio as gr
|
| 238 |
|
| 239 |
gr.Interface(
|
|
|
|
| 149 |
models_rbm.generator.eval().requires_grad_(False)
|
| 150 |
|
| 151 |
def infer(style_description, ref_style_file, caption):
|
| 152 |
+
# Ensure all models are moved back to the correct device
|
| 153 |
+
core_b.generator.to(device)
|
| 154 |
+
models_rbm.generator.to(device)
|
| 155 |
+
|
| 156 |
clear_gpu_cache() # Clear cache before inference
|
| 157 |
|
| 158 |
height=1024
|
|
|
|
| 185 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
| 186 |
|
| 187 |
if low_vram:
|
| 188 |
+
# Offload non-essential models to CPU for memory savings
|
| 189 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
| 190 |
|
| 191 |
+
# Stage C reverse process
|
| 192 |
with torch.cuda.amp.autocast(): # Use mixed precision
|
| 193 |
sampling_c = extras.gdf.sample(
|
| 194 |
models_rbm.generator, conditions, stage_c_latent_shape,
|
|
|
|
| 206 |
|
| 207 |
clear_gpu_cache() # Clear cache between stages
|
| 208 |
|
| 209 |
+
# Ensure all models are on the right device again
|
| 210 |
+
models_b.generator.to(device)
|
| 211 |
+
|
| 212 |
+
# Stage B reverse process
|
| 213 |
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 214 |
conditions_b['effnet'] = sampled_c
|
| 215 |
unconditions_b['effnet'] = torch.zeros_like(sampled_c)
|
|
|
|
| 222 |
sampled_b = sampled_b
|
| 223 |
sampled = models_b.stage_a.decode(sampled_b).float()
|
| 224 |
|
| 225 |
+
# Post-process and save the image
|
| 226 |
sampled = torch.cat([
|
| 227 |
torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
|
| 228 |
sampled.cpu(),
|
|
|
|
| 242 |
|
| 243 |
return output_file # Return the path to the saved image
|
| 244 |
|
| 245 |
+
|
| 246 |
import gradio as gr
|
| 247 |
|
| 248 |
gr.Interface(
|