MogensR committed on
Commit
7a982a5
·
verified ·
1 Parent(s): b899a46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -7
app.py CHANGED
@@ -372,13 +372,108 @@ def generate_ai_background(
372
  seed: Optional[int] = None,
373
  ) -> str:
374
  """Generate AI background using Stable Diffusion."""
375
- # TEMPORARILY DISABLED due to PyTorch/Diffusers compatibility issue
376
- # To fix: pip install --upgrade torch diffusers transformers
377
- raise RuntimeError(
378
- "AI Background temporarily disabled due to PyTorch/Diffusers version compatibility.\n"
379
- "To fix: pip install --upgrade torch diffusers transformers accelerate\n"
380
- "For now, please use Upload Image or Gradients instead."
381
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
  # ==============================================================================
384
  # MAIN PROCESSING PIPELINE
 
372
  seed: Optional[int] = None,
373
  ) -> str:
374
  """Generate AI background using Stable Diffusion."""
375
+ if not TORCH_AVAILABLE:
376
+ raise RuntimeError("PyTorch required for AI background generation")
377
+
378
+ try:
379
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
380
+ except ImportError as e:
381
+ raise RuntimeError(f"Please install diffusers: pip install diffusers transformers accelerate\nError: {e}")
382
+
383
+ device = "cuda" if CUDA_AVAILABLE else "cpu"
384
+ torch_dtype = torch.float16 if CUDA_AVAILABLE else torch.float32
385
+
386
+ # Setup generator
387
+ generator = torch.Generator(device=device)
388
+ if seed is None:
389
+ seed = random.randint(0, 2**31 - 1)
390
+ generator.manual_seed(seed)
391
+
392
+ logger.info(f"Generating {width}x{height} background: '{prompt}' (seed: {seed})")
393
+
394
+ try:
395
+ # Choose pipeline based on whether we have an init image
396
+ if init_image_path and os.path.exists(init_image_path):
397
+ # Image-to-image pipeline
398
+ logger.info("Using img2img pipeline")
399
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
400
+ "runwayml/stable-diffusion-v1-5",
401
+ torch_dtype=torch_dtype,
402
+ safety_checker=None,
403
+ requires_safety_checker=False
404
+ ).to(device)
405
+
406
+ # Enable memory efficient attention if available
407
+ try:
408
+ pipe.enable_attention_slicing()
409
+ if hasattr(pipe, 'enable_model_cpu_offload'):
410
+ pipe.enable_model_cpu_offload()
411
+ except Exception:
412
+ pass
413
+
414
+ # Load and resize init image
415
+ init_image = Image.open(init_image_path).convert("RGB")
416
+ init_image = init_image.resize((width, height), Image.LANCZOS)
417
+
418
+ # Generate
419
+ result = pipe(
420
+ prompt=prompt,
421
+ image=init_image,
422
+ strength=0.6,
423
+ num_inference_steps=num_steps,
424
+ guidance_scale=guidance_scale,
425
+ generator=generator,
426
+ height=height,
427
+ width=width
428
+ ).images[0]
429
+
430
+ else:
431
+ # Text-to-image pipeline
432
+ logger.info("Using txt2img pipeline")
433
+ pipe = StableDiffusionPipeline.from_pretrained(
434
+ "runwayml/stable-diffusion-v1-5",
435
+ torch_dtype=torch_dtype,
436
+ safety_checker=None,
437
+ requires_safety_checker=False
438
+ ).to(device)
439
+
440
+ # Enable memory efficient attention if available
441
+ try:
442
+ pipe.enable_attention_slicing()
443
+ if hasattr(pipe, 'enable_model_cpu_offload'):
444
+ pipe.enable_model_cpu_offload()
445
+ except Exception:
446
+ pass
447
+
448
+ # Generate
449
+ result = pipe(
450
+ prompt=prompt,
451
+ height=height,
452
+ width=width,
453
+ num_inference_steps=num_steps,
454
+ guidance_scale=guidance_scale,
455
+ generator=generator
456
+ ).images[0]
457
+
458
+ # Save result
459
+ output_path = TEMP_DIR / f"ai_bg_{int(time.time())}_{seed:08x}.jpg"
460
+ result.save(output_path, quality=95, optimize=True)
461
+
462
+ # Cleanup GPU memory
463
+ try:
464
+ del pipe
465
+ if TORCH_AVAILABLE and CUDA_AVAILABLE:
466
+ torch.cuda.empty_cache()
467
+ except Exception:
468
+ pass
469
+
470
+ logger.info(f"AI background generated: {output_path}")
471
+ return str(output_path)
472
+
473
+ except Exception as e:
474
+ logger.error(f"AI background generation failed: {e}")
475
+ raise RuntimeError(f"Background generation failed: {e}")
476
+
477
 
478
  # ==============================================================================
479
  # MAIN PROCESSING PIPELINE