Spaces:
Running
on
Zero
Running
on
Zero
Men1scus
commited on
Commit
·
aad20a1
1
Parent(s):
1a908ba
fix: Refactor generator initialization and enhance mixed precision handling in process_sr function
Browse files
app.py
CHANGED
|
@@ -329,7 +329,7 @@ def process_sr(
|
|
| 329 |
])
|
| 330 |
|
| 331 |
seed_everything(seed)
|
| 332 |
-
generator = torch.Generator(device=
|
| 333 |
generator.manual_seed(seed)
|
| 334 |
|
| 335 |
validation_prompt = f"{user_prompt} {positive_prompt}"
|
|
@@ -355,12 +355,19 @@ def process_sr(
|
|
| 355 |
else:
|
| 356 |
pipeline = pipeline_dit4sr_f
|
| 357 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
try:
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
|
|
|
| 364 |
|
| 365 |
if True: # alpha<1.0:
|
| 366 |
image = adain_color_fix(image, input_image)
|
|
@@ -368,7 +375,7 @@ def process_sr(
|
|
| 368 |
if resize_flag:
|
| 369 |
image = image.resize((ori_width * rscale, ori_height * rscale))
|
| 370 |
except Exception as e:
|
| 371 |
-
print(e)
|
| 372 |
image = Image.new(mode="RGB", size=(512, 512))
|
| 373 |
images.append(np.array(image))
|
| 374 |
return images
|
|
|
|
| 329 |
])
|
| 330 |
|
| 331 |
seed_everything(seed)
|
| 332 |
+
generator = torch.Generator(device=dit4sr_device)
|
| 333 |
generator.manual_seed(seed)
|
| 334 |
|
| 335 |
validation_prompt = f"{user_prompt} {positive_prompt}"
|
|
|
|
| 355 |
else:
|
| 356 |
pipeline = pipeline_dit4sr_f
|
| 357 |
|
| 358 |
+
weight_dtype = torch.float32
|
| 359 |
+
if args.mixed_precision == "fp16":
|
| 360 |
+
weight_dtype = torch.float16
|
| 361 |
+
elif args.mixed_precision == "bf16":
|
| 362 |
+
weight_dtype = torch.bfloat16
|
| 363 |
+
|
| 364 |
try:
|
| 365 |
+
with torch.autocast(device_type='cuda', dtype=weight_dtype, enabled=(args.mixed_precision != "no")):
|
| 366 |
+
image = pipeline(
|
| 367 |
+
prompt=validation_prompt, control_image=input_image, num_inference_steps=num_inference_steps, generator=generator, height=height, width=width,
|
| 368 |
+
guidance_scale=cfg_scale, negative_prompt=negative_prompt, start_point=args.start_point, latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
|
| 369 |
+
args=args,
|
| 370 |
+
).images[0]
|
| 371 |
|
| 372 |
if True: # alpha<1.0:
|
| 373 |
image = adain_color_fix(image, input_image)
|
|
|
|
| 375 |
if resize_flag:
|
| 376 |
image = image.resize((ori_width * rscale, ori_height * rscale))
|
| 377 |
except Exception as e:
|
| 378 |
+
print(f"Error during inference: {e}")
|
| 379 |
image = Image.new(mode="RGB", size=(512, 512))
|
| 380 |
images.append(np.array(image))
|
| 381 |
return images
|