Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -55,6 +55,7 @@ def models_to(model, device="cpu", excepts=None):
|
|
| 55 |
attr_value.to(device)
|
| 56 |
|
| 57 |
torch.cuda.empty_cache()
|
|
|
|
| 58 |
|
| 59 |
# Stage C model configuration
|
| 60 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
|
@@ -214,19 +215,24 @@ def infer(ref_style_file, style_description, caption, progress):
|
|
| 214 |
# Remove the batch dimension and keep only the generated image
|
| 215 |
sampled = sampled[1] # This selects the generated image, discarding the reference style image
|
| 216 |
|
|
|
|
|
|
|
|
|
|
| 217 |
# Ensure the tensor is in [C, H, W] format
|
| 218 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
| 219 |
sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
|
| 220 |
-
|
| 221 |
else:
|
| 222 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
| 223 |
|
| 224 |
progress(1.0, "Inference complete")
|
| 225 |
-
return output_file # Return the path to the saved image
|
|
|
|
| 226 |
|
| 227 |
finally:
|
| 228 |
# Clear CUDA cache
|
| 229 |
torch.cuda.empty_cache()
|
|
|
|
| 230 |
|
| 231 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
| 232 |
global models_rbm, models_b, device, sam_model
|
|
@@ -348,6 +354,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
|
|
| 348 |
finally:
|
| 349 |
# Clear CUDA cache
|
| 350 |
torch.cuda.empty_cache()
|
|
|
|
| 351 |
|
| 352 |
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
|
| 353 |
result = None
|
|
|
|
| 55 |
attr_value.to(device)
|
| 56 |
|
| 57 |
torch.cuda.empty_cache()
|
| 58 |
+
gc.collect()
|
| 59 |
|
| 60 |
# Stage C model configuration
|
| 61 |
config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
|
|
|
|
| 215 |
# Remove the batch dimension and keep only the generated image
|
| 216 |
sampled = sampled[1] # This selects the generated image, discarding the reference style image
|
| 217 |
|
| 218 |
+
# Ensure the tensor values are in the correct range
|
| 219 |
+
sampled = torch.clamp(sampled, 0, 1)
|
| 220 |
+
|
| 221 |
# Ensure the tensor is in [C, H, W] format
|
| 222 |
if sampled.dim() == 3 and sampled.shape[0] == 3:
|
| 223 |
sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
|
| 224 |
+
# sampled_image.save(output_file) # Save the image as a PNG
|
| 225 |
else:
|
| 226 |
raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
|
| 227 |
|
| 228 |
progress(1.0, "Inference complete")
|
| 229 |
+
#return output_file # Return the path to the saved image
|
| 230 |
+
return sampled_image
|
| 231 |
|
| 232 |
finally:
|
| 233 |
# Clear CUDA cache
|
| 234 |
torch.cuda.empty_cache()
|
| 235 |
+
gc.collect()
|
| 236 |
|
| 237 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
| 238 |
global models_rbm, models_b, device, sam_model
|
|
|
|
| 354 |
finally:
|
| 355 |
# Clear CUDA cache
|
| 356 |
torch.cuda.empty_cache()
|
| 357 |
+
gc.collect()
|
| 358 |
|
| 359 |
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
|
| 360 |
result = None
|