SwatGarg commited on
Commit
d65d363
·
verified ·
1 Parent(s): 55d30d0

Create image_generation_tools.py

Browse files
Files changed (1) hide show
  1. tools/image_generation_tools.py +14 -0
tools/image_generation_tools.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline
3
+
4
+ class ImageGenerationTools:
5
+
6
+ def __init__(self, model_id="CompVis/stable-diffusion-v1-4"):
7
+ self.pipe = StableDiffusionPipeline.from_pretrained(model_id)
8
+ self.pipe = self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ def generate_image(self, prompt, save_path):
11
+ # Generate the image
12
+ image = self.pipe(prompt)["sample"][0]
13
+ # Save the image
14
+ image.save(save_path)