Men1scus committed on
Commit
aad20a1
·
1 Parent(s): 1a908ba

fix: Refactor generator initialization and enhance mixed precision handling in process_sr function

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -329,7 +329,7 @@ def process_sr(
329
  ])
330
 
331
  seed_everything(seed)
332
- generator = torch.Generator(device='cuda:1')
333
  generator.manual_seed(seed)
334
 
335
  validation_prompt = f"{user_prompt} {positive_prompt}"
@@ -355,12 +355,19 @@ def process_sr(
355
  else:
356
  pipeline = pipeline_dit4sr_f
357
 
 
 
 
 
 
 
358
  try:
359
- image = pipeline(
360
- prompt=validation_prompt, control_image=input_image, num_inference_steps=num_inference_steps, generator=generator, height=height, width=width,
361
- guidance_scale=cfg_scale, negative_prompt=negative_prompt, start_point=args.start_point, latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
362
- args=args,
363
- ).images[0]
 
364
 
365
  if True: # alpha<1.0:
366
  image = adain_color_fix(image, input_image)
@@ -368,7 +375,7 @@ def process_sr(
368
  if resize_flag:
369
  image = image.resize((ori_width * rscale, ori_height * rscale))
370
  except Exception as e:
371
- print(e)
372
  image = Image.new(mode="RGB", size=(512, 512))
373
  images.append(np.array(image))
374
  return images
 
329
  ])
330
 
331
  seed_everything(seed)
332
+ generator = torch.Generator(device=dit4sr_device)
333
  generator.manual_seed(seed)
334
 
335
  validation_prompt = f"{user_prompt} {positive_prompt}"
 
355
  else:
356
  pipeline = pipeline_dit4sr_f
357
 
358
+ weight_dtype = torch.float32
359
+ if args.mixed_precision == "fp16":
360
+ weight_dtype = torch.float16
361
+ elif args.mixed_precision == "bf16":
362
+ weight_dtype = torch.bfloat16
363
+
364
  try:
365
+ with torch.autocast(device_type='cuda', dtype=weight_dtype, enabled=(args.mixed_precision != "no")):
366
+ image = pipeline(
367
+ prompt=validation_prompt, control_image=input_image, num_inference_steps=num_inference_steps, generator=generator, height=height, width=width,
368
+ guidance_scale=cfg_scale, negative_prompt=negative_prompt, start_point=args.start_point, latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
369
+ args=args,
370
+ ).images[0]
371
 
372
  if True: # alpha<1.0:
373
  image = adain_color_fix(image, input_image)
 
375
  if resize_flag:
376
  image = image.resize((ori_width * rscale, ori_height * rscale))
377
  except Exception as e:
378
+ print(f"Error during inference: {e}")
379
  image = Image.new(mode="RGB", size=(512, 512))
380
  images.append(np.array(image))
381
  return images