File size: 3,075 Bytes
2480131
 
 
bfc02f3
2480131
 
 
 
 
 
 
 
 
 
 
bfc02f3
 
 
 
2480131
 
 
bfc02f3
2480131
bfc02f3
 
 
 
 
 
 
 
 
 
 
2480131
bfc02f3
2480131
bfc02f3
 
2480131
 
 
 
bfc02f3
 
 
 
 
 
2480131
 
 
 
 
 
 
bfc02f3
 
 
 
 
 
2480131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfc02f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from PIL import Image
import os
import torch
from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, LCMScheduler
import time

# Hugging Face Hub repository id of the diffusion checkpoint; presumably the
# value callers pass to GenerationSession — confirm at the call site.
model_id: str = "simianluo/lcm_dreamshaper_v7"

class GenerationSession:
    """Iterative image session built on one diffusion checkpoint.

    The first call to ``Generate`` (or the first after ``reset``) creates an
    image from text with the text-to-image pipeline. Every later call feeds the
    previous result back through the image-to-image pipeline with the new
    prompt, so the picture evolves prompt by prompt.
    """

    def __init__(self, model_id):
        """Load both pipelines for ``model_id`` on CUDA if available, else CPU.

        Args:
            model_id: Repository id or local path handed to
                ``DiffusionPipeline.from_pretrained``.
        """
        self.model_id = model_id
        self.txt2img_pipeline = None
        self.img2img_pipeline = None
        # Result of the most recent generation: the pipeline's ``.images`` list.
        self.current_image = None
        # Prompt that produced ``current_image``.
        self.current_prompt = None

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision only on the GPU; the CPU path stays in float32.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        self._initialize_pipelines()

    def _initialize_pipelines(self):
        """Load the text-to-image pipeline and derive the image-to-image one from it."""
        print(f"Initializing pipelines on device: {self.device}...")

        self.txt2img_pipeline = DiffusionPipeline.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            safety_checker=None,
        )

        # Swap in LCMScheduler, reusing the checkpoint's own scheduler config;
        # the generation methods below sample with only a handful of steps.
        self.txt2img_pipeline.scheduler = LCMScheduler.from_config(
            self.txt2img_pipeline.scheduler.config
        )
        self.txt2img_pipeline.to(self.device)

        # Memory-saving options: compute attention and VAE decoding in slices.
        self.txt2img_pipeline.enable_attention_slicing()
        self.txt2img_pipeline.enable_vae_slicing()

        print("Text 2 image pipeline loaded.")

        # from_pipe reuses the already-loaded components instead of loading the
        # weights a second time.
        self.img2img_pipeline = AutoPipelineForImage2Image.from_pipe(self.txt2img_pipeline)
        print("Image 2 image pipeline loaded (shared weights).")

    def GeneratingBaseImage(
        self,
        prompt: str,
        negative_prompt: str = "Blurry, low quality, static and distorted image",
        num_inference_steps: int = 4,
        guidance_scale: float = 1.0,
        height: int = 512,
        width: int = 512,
    ) -> "list[Image.Image]":
        """Generate a new image from ``prompt`` alone.

        Args:
            prompt: Text description of the desired image.
            negative_prompt: Forwarded to the pipeline. NOTE(review): with
                ``guidance_scale=1.0`` classifier-free guidance is off, so this
                is most likely ignored — confirm against the installed
                diffusers pipeline.
            num_inference_steps: Number of sampling steps.
            guidance_scale: Forwarded to the pipeline.
            height: Output height in pixels.
            width: Output width in pixels.

        Returns:
            The pipeline's ``.images`` list. Presumably it holds one PIL image,
            since a single prompt is passed.
        """
        start = time.perf_counter()
        images = self.txt2img_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        ).images
        print(f"Text to image generated in [{time.perf_counter() - start:.2f}s]")
        return images

    def GeneratingVariationImage(
        self,
        prompt: str,
        reference_image: Image.Image,
        strength: float = 0.5,
        negative_prompt: str = "Blurry, low quality, static and distorted image",
        num_inference_steps: int = 4,
        guidance_scale: float = 1.0,
    ) -> "list[Image.Image]":
        """Generate a variation of ``reference_image`` guided by ``prompt``.

        Args:
            prompt: Text description steering the variation.
            reference_image: Image to start from.
            strength: How far the result may depart from ``reference_image``.
                It is forwarded to the pipeline, which expects a value in [0, 1].
            negative_prompt: Forwarded to the pipeline. NOTE(review): likely
                ignored at ``guidance_scale=1.0`` — confirm.
            num_inference_steps: Number of sampling steps requested.
            guidance_scale: Forwarded to the pipeline.

        Returns:
            The pipeline's ``.images`` list.
        """
        start = time.perf_counter()
        images = self.img2img_pipeline(
            prompt=prompt,
            image=reference_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
        ).images
        print(f"Image to image generated in [{time.perf_counter() - start:.2f}s]")
        return images

    def Generate(self, new_prompt: str, strength: float = 0.5):
        """Advance the session by one prompt.

        With no current image this generates from text. Otherwise it refines
        the current image with ``new_prompt`` at the given ``strength``.

        Args:
            new_prompt: Prompt for this step.
            strength: Image-to-image strength, used only after the first step.

        Returns:
            The new ``.images`` list, which also becomes ``self.current_image``.
        """
        if self.current_image is None:
            self.current_image = self.GeneratingBaseImage(new_prompt)
        else:
            # current_image normally holds the pipeline's ``.images`` list; feed
            # back its first image rather than the list. A bare image assigned
            # by a caller is still accepted as-is.
            previous = self.current_image
            reference = previous[0] if isinstance(previous, list) else previous
            self.current_image = self.GeneratingVariationImage(new_prompt, reference, strength)

        self.current_prompt = new_prompt
        return self.current_image

    def reset(self):
        """Forget the current image and prompt so the next ``Generate`` starts from text."""
        self.current_image = None
        self.current_prompt = None
        print("Session reset. Ready for new generation.")