import os
import time
from typing import List, Optional, Union

import torch
from diffusers import AutoPipelineForImage2Image, DiffusionPipeline, LCMScheduler
from PIL import Image

# Model identifier handed to `DiffusionPipeline.from_pretrained`.
model_id = "simianluo/lcm_dreamshaper_v7"

# Negative prompt used by both generation methods when the caller gives none.
DEFAULT_NEGATIVE_PROMPT = "Blurry, low quality, static and distorted image"


class GenerationSession:
    """Stateful text-to-image / image-to-image generation session.

    The first call to :meth:`Generate` produces an image from the prompt
    alone; each later call feeds the previous result back in as the
    reference for an image-to-image pass, until :meth:`reset` is called.
    """

    # Sampling settings shared by both pipelines. Override on a subclass or
    # an instance to change them without touching the method bodies.
    NUM_INFERENCE_STEPS = 4
    GUIDANCE_SCALE = 1.0
    IMAGE_SIZE = 512  # height and width of text-to-image output, in pixels

    def __init__(self, model_id):
        """Pick a device/dtype and load both pipelines for ``model_id``."""
        self.model_id = model_id
        self.txt2img_pipeline = None
        self.img2img_pipeline = None
        # Output of the most recent generation, or None before the first
        # call / after reset().
        self.current_image: Optional[List[Image.Image]] = None
        self.current_prompt: Optional[str] = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision only on CUDA; full precision on CPU.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self._initialize_pipelines()

    def _initialize_pipelines(self):
        """Load the text-to-image pipeline and derive the img2img one from it."""
        print(f"Initializing pipelines on device: {self.device}...")
        self.txt2img_pipeline = DiffusionPipeline.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            safety_checker=None,
        )
        self.txt2img_pipeline.scheduler = LCMScheduler.from_config(
            self.txt2img_pipeline.scheduler.config
        )
        self.txt2img_pipeline.to(self.device)
        self.txt2img_pipeline.enable_attention_slicing()
        self.txt2img_pipeline.enable_vae_slicing()
        print("Text 2 image pipeline loaded.")

        # Built from the already-loaded pipeline rather than loaded again.
        self.img2img_pipeline = AutoPipelineForImage2Image.from_pipe(
            self.txt2img_pipeline
        )
        print("Image 2 image pipeline loaded (shared weights).")

    def GeneratingBaseImage(
        self,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    ) -> List[Image.Image]:
        """Run the text-to-image pipeline for ``prompt``.

        Returns the pipeline output's ``.images`` collection unchanged.
        NOTE(review): that is a list rather than a single image; callers
        presumably index ``[0]`` -- confirm before changing the return shape.
        """
        start = time.perf_counter()
        images = self.txt2img_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            guidance_scale=self.GUIDANCE_SCALE,
            height=self.IMAGE_SIZE,
            width=self.IMAGE_SIZE,
        ).images
        print(f"Text to image generated in [{time.perf_counter() - start:.2f}s]")
        return images

    def GeneratingVariationImage(
        self,
        prompt: str,
        reference_image: Union[Image.Image, List[Image.Image]],
        strength: float = 0.5,
        negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    ) -> List[Image.Image]:
        """Run the image-to-image pipeline on ``reference_image``.

        ``strength`` is forwarded to the pipeline as-is. Returns the pipeline
        output's ``.images`` collection unchanged (see the note on
        :meth:`GeneratingBaseImage`).
        """
        start = time.perf_counter()
        images = self.img2img_pipeline(
            prompt=prompt,
            image=reference_image,
            strength=strength,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            guidance_scale=self.GUIDANCE_SCALE,
            negative_prompt=negative_prompt,
        ).images
        print(f"Image to image generated in [{time.perf_counter() - start:.2f}s]")
        return images

    def Generate(self, new_prompt: str, strength: float = 0.5) -> List[Image.Image]:
        """Generate the next image of the session and remember it.

        With no stored image, a base image is generated from ``new_prompt``
        and ``strength`` is ignored; otherwise the stored image is used as
        the reference for a variation. The result and prompt are stored on
        the session and the result is returned.
        """
        if self.current_image is None:
            self.current_image = self.GeneratingBaseImage(new_prompt)
        else:
            self.current_image = self.GeneratingVariationImage(
                new_prompt, self.current_image, strength
            )
        self.current_prompt = new_prompt
        return self.current_image

    def reset(self):
        """Forget the stored image and prompt so the next call starts fresh."""
        self.current_image = None
        self.current_prompt = None
        print("Session reset. Ready for new generation.")