Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -221,10 +221,16 @@ def infer(style_description, ref_style_file, caption):
|
|
| 221 |
],
|
| 222 |
dim=0)
|
| 223 |
|
| 224 |
-
#
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
clear_gpu_cache() # Clear cache after inference
|
| 229 |
|
| 230 |
return output_file # Return the path to the saved image
|
|
|
|
| 221 |
],
|
| 222 |
dim=0)
|
| 223 |
|
| 224 |
+
# Remove batch dimension if it exists
|
| 225 |
+
if sampled.dim() == 4 and sampled.size(0) == 1:
|
| 226 |
+
sampled = sampled.squeeze(0)
|
| 227 |
+
|
| 228 |
+
# Ensure the tensor is in [C, H, W] format
|
| 229 |
+
if sampled.dim() == 3:
|
| 230 |
+
sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
|
| 231 |
+
sampled_image.save(output_file) # Save the image as a PNG
|
| 232 |
+
else:
|
| 233 |
+
raise ValueError(f"Expected tensor of shape [C, H, W] but got {sampled.shape}")
|
| 234 |
clear_gpu_cache() # Clear cache after inference
|
| 235 |
|
| 236 |
return output_file # Return the path to the saved image
|