CyberRohith commited on
Commit
bfc02f3
·
verified ·
1 Parent(s): bc82176

Update model_loading.py

Browse files
Files changed (1) hide show
  1. model_loading.py +33 -36
model_loading.py CHANGED
@@ -1,8 +1,7 @@
1
  from PIL import Image
2
  import os
3
- import ollama
4
  import torch
5
- from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, LCMScheduler, AutoPipelineForText2Image
6
  import time
7
 
8
  model_id = "simianluo/lcm_dreamshaper_v7"
@@ -14,42 +13,41 @@ class GenerationSession:
14
  self.img2img_pipeline = None
15
  self.current_image = None
16
  self.current_prompt = None
 
 
 
 
17
  self._initialize_pipelines()
18
 
19
  def _initialize_pipelines(self):
20
- print("initializing pipelines...")
21
-
22
- self.txt2img_pipeline = DiffusionPipeline.from_pretrained(
23
- model_id,
24
- torch_dtype = torch.float16,
25
- safety_checker = None
26
- )
27
-
28
- self.txt2img_pipeline.scheduler = LCMScheduler.from_config(self.txt2img_pipeline.scheduler.config)
29
- self.txt2img_pipeline.to("cuda")
30
- self.txt2img_pipeline.enable_attention_slicing()
31
- self.txt2img_pipeline.enable_vae_slicing()
32
 
33
- # self.txt2img_pipeline.unet = torch.compile(
34
- # self.txt2img_pipeline.unet,
35
- # mode = "reduce-overhead",
36
- # fullgraph = True
37
- #)
38
- print("Text 2 image pipeline loaded and compiled.")
 
 
 
 
 
39
 
 
40
 
41
- self.img2img_pipeline = AutoPipelineForImage2Image.from_pipe(self.txt2img_pipeline)
42
- print("Image 2 image pipeline loaded (shared weights).")
43
 
44
  def GeneratingBaseImage(self, prompt: str, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
45
  start = time.time()
46
  image = self.txt2img_pipeline(
47
- prompt = prompt,
48
- negative_prompt= negative_prompt,
49
- num_inference_steps = 4,
50
- guidance_scale = 1.0,
51
- height = 512,
52
- width = 512
53
  ).images
54
  print(f"Text to image generated in [{time.time() - start:.2f}s]")
55
  return image
@@ -57,12 +55,12 @@ class GenerationSession:
57
  def GeneratingVariationImage(self, prompt: str, reference_image: Image.Image, strength: float = 0.5, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
58
  start = time.time()
59
  image = self.img2img_pipeline(
60
- prompt = prompt,
61
- image = reference_image,
62
- strength = strength,
63
- num_inference_steps = 4,
64
- guidance_scale = 1.0,
65
- negative_prompt = negative_prompt
66
  ).images
67
  print(f"Image to image generated in [{time.time() - start:.2f}s]")
68
  return image
@@ -79,5 +77,4 @@ class GenerationSession:
79
  def reset(self):
80
  self.current_image = None
81
  self.current_prompt = None
82
- print("Session reset. Ready for new generation.")
83
-
 
1
  from PIL import Image
2
  import os
 
3
  import torch
4
+ from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, LCMScheduler
5
  import time
6
 
7
  model_id = "simianluo/lcm_dreamshaper_v7"
 
13
  self.img2img_pipeline = None
14
  self.current_image = None
15
  self.current_prompt = None
16
+
17
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ self.dtype = torch.float16 if self.device == "cuda" else torch.float32
19
+
20
  self._initialize_pipelines()
21
 
22
  def _initialize_pipelines(self):
23
+ print(f"Initializing pipelines on device: {self.device}...")
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ self.txt2img_pipeline = DiffusionPipeline.from_pretrained(
26
+ self.model_id,
27
+ torch_dtype=self.dtype,
28
+ safety_checker=None
29
+ )
30
+
31
+ self.txt2img_pipeline.scheduler = LCMScheduler.from_config(self.txt2img_pipeline.scheduler.config)
32
+ self.txt2img_pipeline.to(self.device)
33
+
34
+ self.txt2img_pipeline.enable_attention_slicing()
35
+ self.txt2img_pipeline.enable_vae_slicing()
36
 
37
+ print("Text 2 image pipeline loaded.")
38
 
39
+ self.img2img_pipeline = AutoPipelineForImage2Image.from_pipe(self.txt2img_pipeline)
40
+ print("Image 2 image pipeline loaded (shared weights).")
41
 
42
  def GeneratingBaseImage(self, prompt: str, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
43
  start = time.time()
44
  image = self.txt2img_pipeline(
45
+ prompt=prompt,
46
+ negative_prompt=negative_prompt,
47
+ num_inference_steps=4,
48
+ guidance_scale=1.0,
49
+ height=512,
50
+ width=512
51
  ).images
52
  print(f"Text to image generated in [{time.time() - start:.2f}s]")
53
  return image
 
55
  def GeneratingVariationImage(self, prompt: str, reference_image: Image.Image, strength: float = 0.5, negative_prompt: str = "Blurry, low quality, static and distorted image") -> str:
56
  start = time.time()
57
  image = self.img2img_pipeline(
58
+ prompt=prompt,
59
+ image=reference_image,
60
+ strength=strength,
61
+ num_inference_steps=4,
62
+ guidance_scale=1.0,
63
+ negative_prompt=negative_prompt
64
  ).images
65
  print(f"Image to image generated in [{time.time() - start:.2f}s]")
66
  return image
 
77
  def reset(self):
78
  self.current_image = None
79
  self.current_prompt = None
80
+ print("Session reset. Ready for new generation.")