Update model_loading.py
Browse files- model_loading.py +33 -36
model_loading.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
from PIL import Image
|
| 2 |
import os
|
| 3 |
-
import ollama
|
| 4 |
import torch
|
| 5 |
-
from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, LCMScheduler
|
| 6 |
import time
|
| 7 |
|
| 8 |
model_id = "simianluo/lcm_dreamshaper_v7"
|
|
@@ -14,42 +13,41 @@ class GenerationSession:
|
|
| 14 |
self.img2img_pipeline = None
|
| 15 |
self.current_image = None
|
| 16 |
self.current_prompt = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
self._initialize_pipelines()
|
| 18 |
|
| 19 |
def _initialize_pipelines(self):
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
self.txt2img_pipeline = DiffusionPipeline.from_pretrained(
|
| 23 |
-
model_id,
|
| 24 |
-
torch_dtype = torch.float16,
|
| 25 |
-
safety_checker = None
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
self.txt2img_pipeline.scheduler = LCMScheduler.from_config(self.txt2img_pipeline.scheduler.config)
|
| 29 |
-
self.txt2img_pipeline.to("cuda")
|
| 30 |
-
self.txt2img_pipeline.enable_attention_slicing()
|
| 31 |
-
self.txt2img_pipeline.enable_vae_slicing()
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
|
| 44 |
def GeneratingBaseImage(self, prompt: str, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
|
| 45 |
start = time.time()
|
| 46 |
image = self.txt2img_pipeline(
|
| 47 |
-
prompt
|
| 48 |
-
negative_prompt=
|
| 49 |
-
num_inference_steps
|
| 50 |
-
guidance_scale
|
| 51 |
-
height
|
| 52 |
-
width
|
| 53 |
).images
|
| 54 |
print(f"Text to image generated in [{time.time() - start:.2f}s]")
|
| 55 |
return image
|
|
@@ -57,12 +55,12 @@ class GenerationSession:
|
|
| 57 |
def GeneratingVariationImage(self, prompt: str, reference_image: Image.Image, strength: float = 0.5, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
|
| 58 |
start = time.time()
|
| 59 |
image = self.img2img_pipeline(
|
| 60 |
-
prompt
|
| 61 |
-
image
|
| 62 |
-
strength
|
| 63 |
-
num_inference_steps
|
| 64 |
-
guidance_scale
|
| 65 |
-
negative_prompt
|
| 66 |
).images
|
| 67 |
print(f"Image to image generated in [{time.time() - start:.2f}s]")
|
| 68 |
return image
|
|
@@ -79,5 +77,4 @@ class GenerationSession:
|
|
| 79 |
def reset(self):
|
| 80 |
self.current_image = None
|
| 81 |
self.current_prompt = None
|
| 82 |
-
print("Session reset. Ready for new generation.")
|
| 83 |
-
|
|
|
|
| 1 |
from PIL import Image
|
| 2 |
import os
|
|
|
|
| 3 |
import torch
|
| 4 |
+
from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, LCMScheduler
|
| 5 |
import time
|
| 6 |
|
| 7 |
model_id = "simianluo/lcm_dreamshaper_v7"
|
|
|
|
| 13 |
self.img2img_pipeline = None
|
| 14 |
self.current_image = None
|
| 15 |
self.current_prompt = None
|
| 16 |
+
|
| 17 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
+
self.dtype = torch.float16 if self.device == "cuda" else torch.float32
|
| 19 |
+
|
| 20 |
self._initialize_pipelines()
|
| 21 |
|
| 22 |
def _initialize_pipelines(self):
|
| 23 |
+
print(f"Initializing pipelines on device: {self.device}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
self.txt2img_pipeline = DiffusionPipeline.from_pretrained(
|
| 26 |
+
self.model_id,
|
| 27 |
+
torch_dtype=self.dtype,
|
| 28 |
+
safety_checker=None
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
self.txt2img_pipeline.scheduler = LCMScheduler.from_config(self.txt2img_pipeline.scheduler.config)
|
| 32 |
+
self.txt2img_pipeline.to(self.device)
|
| 33 |
+
|
| 34 |
+
self.txt2img_pipeline.enable_attention_slicing()
|
| 35 |
+
self.txt2img_pipeline.enable_vae_slicing()
|
| 36 |
|
| 37 |
+
print("Text 2 image pipeline loaded.")
|
| 38 |
|
| 39 |
+
self.img2img_pipeline = AutoPipelineForImage2Image.from_pipe(self.txt2img_pipeline)
|
| 40 |
+
print("Image 2 image pipeline loaded (shared weights).")
|
| 41 |
|
| 42 |
def GeneratingBaseImage(self, prompt: str, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
|
| 43 |
start = time.time()
|
| 44 |
image = self.txt2img_pipeline(
|
| 45 |
+
prompt=prompt,
|
| 46 |
+
negative_prompt=negative_prompt,
|
| 47 |
+
num_inference_steps=4,
|
| 48 |
+
guidance_scale=1.0,
|
| 49 |
+
height=512,
|
| 50 |
+
width=512
|
| 51 |
).images
|
| 52 |
print(f"Text to image generated in [{time.time() - start:.2f}s]")
|
| 53 |
return image
|
|
|
|
| 55 |
def GeneratingVariationImage(self, prompt: str, reference_image: Image.Image, strength: float = 0.5, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
|
| 56 |
start = time.time()
|
| 57 |
image = self.img2img_pipeline(
|
| 58 |
+
prompt=prompt,
|
| 59 |
+
image=reference_image,
|
| 60 |
+
strength=strength,
|
| 61 |
+
num_inference_steps=4,
|
| 62 |
+
guidance_scale=1.0,
|
| 63 |
+
negative_prompt=negative_prompt
|
| 64 |
).images
|
| 65 |
print(f"Image to image generated in [{time.time() - start:.2f}s]")
|
| 66 |
return image
|
|
|
|
| 77 |
def reset(self):
|
| 78 |
self.current_image = None
|
| 79 |
self.current_prompt = None
|
| 80 |
+
print("Session reset. Ready for new generation.")
|
|
|