# Content_Creation/tools/image_generation_tools.py
# Source: SwatGarg, commit d65d363 ("Create image_generation_tools.py"), 487 bytes.
import torch
from diffusers import StableDiffusionPipeline
class ImageGenerationTools:
    """Thin wrapper around a diffusers ``StableDiffusionPipeline`` for text-to-image generation."""

    def __init__(self, model_id="CompVis/stable-diffusion-v1-4", device=None, **pretrained_kwargs):
        """Load the pipeline and move it to a device.

        Args:
            model_id: Model id or local path passed to
                ``StableDiffusionPipeline.from_pretrained``.
            device: Target device string (e.g. ``"cuda"`` or ``"cpu"``). When
                ``None``, ``"cuda"`` is used if ``torch.cuda.is_available()``,
                otherwise ``"cpu"``.
            **pretrained_kwargs: Extra keyword arguments forwarded unchanged to
                ``from_pretrained`` (for example ``torch_dtype``).
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        pipe = StableDiffusionPipeline.from_pretrained(model_id, **pretrained_kwargs)
        self.pipe = pipe.to(device)

    def generate_image(self, prompt, save_path, **pipeline_kwargs):
        """Generate one image for *prompt*, save it to *save_path*, and return it.

        Args:
            prompt: Text prompt passed to the pipeline call.
            save_path: Destination handed to the generated image's ``save`` method.
            **pipeline_kwargs: Extra keyword arguments forwarded to the pipeline
                call (for example ``num_inference_steps`` or ``guidance_scale``).

        Returns:
            The first generated image, i.e. the object whose ``save`` was called.
        """
        output = self.pipe(prompt, **pipeline_kwargs)
        image = self._first_image(output)
        image.save(save_path)
        return image

    @staticmethod
    def _first_image(output):
        """Return the first image from a pipeline result.

        Prefers the ``images`` attribute and falls back to the legacy
        ``"sample"`` key. Indexing ``["sample"]`` alone fails on result
        objects that only expose ``images``.
        """
        images = getattr(output, "images", None)
        if images is None:
            # Legacy result shape: a mapping keyed by "sample".
            images = output["sample"]
        return images[0]