from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLAdapterPipeline,
    AutoencoderKL,
    UniPCMultistepScheduler,
    T2IAdapter,
)
import torch
import os
from PIL import Image
from io import BytesIO
import models
from database import SessionLocal
from text_processor import (
    get_resolved_sentences,
    detect_and_translate_to_english,
    get_script_captions,
)
from s3 import upload_image_to_s3
from diffusers.utils import load_image
import random
from controlnet_aux import OpenposeDetector
import numpy as np
import gc
# Global device configuration
dtype = torch.float16

# Global RNG; it is reseeded before every generation so results stay reproducible
generator = torch.Generator()

# Initialize the models at module level so they are only loaded once
| print("Loading VAE...") | |
| vae = AutoencoderKL.from_pretrained( | |
| "madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype, use_safetensors=True | |
| ).to("cuda") | |
| print("Loading base pipeline...") | |
| pipe = StableDiffusionXLPipeline.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| vae=vae, | |
| torch_dtype=dtype, | |
| variant="fp16", | |
| use_safetensors=True, | |
| ).to("cuda") | |
| pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) | |
| pipe.load_lora_weights("safetensors/Storyboard_sketch.safetensors", adapter_name="sketch") | |
| pipe.load_lora_weights("safetensors/anglesv2.safetensors", adapter_name="angles") | |
| pipe.set_adapters(["sketch", "angles"], adapter_weights=[0.5, 0.5]) | |
| pipe.enable_xformers_memory_efficient_attention() | |
| print("Loading OpenPose detector...") | |
| openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") | |
| print("Loading T2I adapter...") | |
| adapter = T2IAdapter.from_pretrained( | |
| "TencentARC/t2i-adapter-openpose-sdxl-1.0", torch_dtype=dtype | |
| ).to("cuda") | |
| print("Loading adapter pipeline...") | |
| posepipe = StableDiffusionXLAdapterPipeline.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| adapter=adapter, | |
| vae=vae, | |
| torch_dtype=dtype, | |
| variant="fp16", | |
| use_safetensors=True, | |
| ).to("cuda") | |
| posepipe.scheduler = UniPCMultistepScheduler.from_config(posepipe.scheduler.config) | |
| posepipe.load_lora_weights( | |
| "safetensors/Storyboard_sketch.safetensors", adapter_name="sketch" | |
| ) | |
| posepipe.load_lora_weights("safetensors/anglesv2.safetensors", adapter_name="angles") | |
| posepipe.set_adapters(["sketch", "angles"], adapter_weights=[0.5, 0.5]) | |
| posepipe.enable_xformers_memory_efficient_attention() | |
| print("All models loaded successfully") | |
def clear_cuda_cache():
    """Clear the CUDA cache and run garbage collection to free memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def get_dimensions(resolution: str) -> tuple[int, int]:
    """Map an aspect-ratio string to (width, height) in pixels.

    Unknown ratios fall back to the square default, e.g.
    get_dimensions("4:3") -> (1024, 1024).
    """
    resolution_map = {
        "16:9": (1024, 576),
        "1:1": (1024, 1024),
        "9:16": (576, 1024),
    }
    return resolution_map.get(resolution, (1024, 1024))
def generate_batch_images(
    story: str, storyboard_id: int, resolution: str = "1:1", isStory: bool = True
):
    # Clear cache before batch generation
    clear_cuda_cache()
    db = SessionLocal()
    try:
        # Stories are split into resolved sentences; scripts into captions
        if isStory:
            prompts = get_resolved_sentences(story)
        else:
            prompts = get_script_captions(story)
        width, height = get_dimensions(resolution)
        for num, prompt in enumerate(prompts):
            # Generate a random seed for each image in the batch
            seed = random.randint(0, 2**32 - 1)
            generator.manual_seed(seed)
            print(f"Generating image {num + 1} with seed {seed}")
            result = pipe(
                prompt=f"Storyboard sketch of {prompt}, black and white, cinematic, high quality",
                negative_prompt="ugly, deformed, disfigured, poor details, bad anatomy, abstract, bad physics",
                guidance_scale=8.5,
                height=height,
                width=width,
                num_inference_steps=30,
                generator=generator,
            )
            image = result.images[0]
            # Serialize to JPEG in memory, then upload to S3
            buf = BytesIO()
            image.save(buf, format="JPEG")
            buf.seek(0)
            s3_url = upload_image_to_s3(
                buf.read(),
                f"image_{num + 1}.jpg",
                folder=f"storyboards/{storyboard_id}",
            )
            db_image = models.Image(
                storyboard_id=storyboard_id,
                image_path=s3_url,
                caption=prompt,
            )
            db.add(db_image)
            db.commit()
            db.refresh(db_image)
            print(f"Image {num + 1} generated successfully")
            # Clear cache after each image
            clear_cuda_cache()
    except Exception as e:
        print(f"Error during image generation: {e}")
        import traceback

        traceback.print_exc()
        db.rollback()
    finally:
        db.close()
def generate_single_image(
    image_id: int,
    caption: str,
    seed: int | None = None,
    resolution: str = "1:1",
    isOpenPose: bool = False,
    pose_img: Image.Image | None = None,
):
    # Clear cache before single image generation
    clear_cuda_cache()
    db = SessionLocal()
    try:
        # Get the existing image record and fail fast if it is missing
        db_image = db.query(models.Image).filter(models.Image.id == image_id).first()
        if not db_image:
            raise ValueError(f"Image with id {image_id} not found.")
        processed_caption = detect_and_translate_to_english(caption)
        width, height = get_dimensions(resolution)
        # Use the provided seed or generate a random one
        current_seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        generator.manual_seed(current_seed)
        print(f"Generating single image with seed {current_seed}")
        if isOpenPose:
            print("Using OpenPose pipeline")
            # Extract the pose skeleton, then reverse the channel order
            # (RGB -> BGR) before handing it to the adapter pipeline; the
            # output resolution follows the conditioning image here, so
            # width/height are not passed
            image = openpose(pose_img, detect_resolution=512, image_resolution=1024)
            image = np.array(image)[:, :, ::-1]
            image = Image.fromarray(np.uint8(image))
            result = posepipe(
                prompt=f"Storyboard sketch of {processed_caption}, black and white, cinematic, high quality",
                negative_prompt="ugly, deformed, disfigured, poor details, bad anatomy, abstract, bad physics",
                image=image,
                adapter_conditioning_scale=1,
                guidance_scale=8.5,
                num_inference_steps=30,
                generator=generator,
            )
        else:
            print("Using standard pipeline")
            result = pipe(
                prompt=f"Storyboard sketch of {processed_caption}, black and white, cinematic, high quality",
                negative_prompt="ugly, deformed, disfigured, poor details, bad anatomy, abstract, bad physics",
                guidance_scale=8.5,
                num_inference_steps=30,
                width=width,
                height=height,
                generator=generator,
            )
        # Save and upload
        image = result.images[0]
        buf = BytesIO()
        image.save(buf, format="JPEG")
        buf.seek(0)
        s3_url = upload_image_to_s3(
            buf.read(),
            f"image_{image_id}.jpg",
            folder=f"storyboards/{db_image.storyboard_id}",
        )
        # Update the image record in place
        db_image.image_path = s3_url
        db_image.caption = caption
        db_image.seed = current_seed
        db.commit()
        db.refresh(db_image)
        print("Single image generated successfully")
        # Clear cache after generation
        clear_cuda_cache()
        return db_image
    except Exception as e:
        print(f"Error during image regeneration: {e}")
        import traceback

        traceback.print_exc()
        db.rollback()
        return None
    finally:
        db.close()
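

# Minimal usage sketch. Assumptions: the database, S3 credentials, and the two
# LoRA .safetensors files referenced above are already configured; the ids,
# story text, and caption below are hypothetical placeholders.
if __name__ == "__main__":
    # Batch mode: one storyboard frame per resolved sentence of the story
    generate_batch_images(
        story="A detective enters a dim warehouse. She finds an open crate.",
        storyboard_id=1,
        resolution="16:9",
        isStory=True,
    )
    # Regenerate a single frame with a fixed seed for reproducibility
    generate_single_image(
        image_id=1,
        caption="The detective inspects the crate",
        seed=42,
        resolution="16:9",
    )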