Multimodal_ai_agent / image_gen.py
hari7261's picture
Update image_gen.py
f31f8ac verified
raw
history blame contribute delete
757 Bytes
import torch
from diffusers import StableDiffusionPipeline
# Hub repository id handed to StableDiffusionPipeline.from_pretrained in get_pipe().
MODEL_ID: str = "stabilityai/sd-turbo"
# Module-level pipeline cache: None until get_pipe() first loads the pipeline.
_pipe: "StableDiffusionPipeline | None" = None
def get_pipe():
    """Return the shared Stable Diffusion pipeline, loading it on first use.

    The pipeline for ``MODEL_ID`` is created once and cached in the
    module-level ``_pipe``; later calls return the same object. When CUDA is
    available it is loaded in float16 and moved to ``"cuda"``; otherwise it
    is loaded in float32 and kept on ``"cpu"``.

    The cache is assigned only after the pipeline is fully configured. If
    loading or the device move raises, the next call retries from scratch
    instead of returning a half-initialized pipeline.

    NOTE(review): no lock guards the lazy initialization, so concurrent
    first calls from several threads may each load the model.
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    # Probe CUDA once so dtype and device are always chosen consistently.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # float16 only when a GPU is present; the CPU path runs in float32.
    dtype = torch.float16 if use_cuda else torch.float32

    # Build into a local so a failure below never leaves a partial cache.
    pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
    pipe = pipe.to(device)
    # sd-turbo does not use safety checker
    if hasattr(pipe, "safety_checker"):
        pipe.safety_checker = None

    # Publish only the fully-initialized pipeline.
    _pipe = pipe
    return _pipe
def generate_image(
    prompt: str,
    *,
    num_inference_steps: int = 4,
    guidance_scale: float = 0.0,
    seed: "int | None" = None,
):
    """Generate one image for *prompt* using the shared pipeline from get_pipe().

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps. Defaults to 4, the
            value this function always used.
        guidance_scale: Guidance scale forwarded to the pipeline. Defaults to
            0.0, the value this function always used.
        seed: Optional integer seed. When given, a ``torch.Generator`` on the
            pipeline's device is seeded with it for reproducible sampling (up
            to backend nondeterminism). When ``None``, no generator is passed,
            matching the previous unseeded behavior.

    Returns:
        The first entry of the pipeline output's ``images``. This is
        presumably a ``PIL.Image.Image`` under the pipeline's default output
        type; confirm with callers.
    """
    pipe = get_pipe()

    generator = None
    if seed is not None:
        # Put the generator on the pipeline's device so it matches the latents.
        generator = torch.Generator(device=pipe.device).manual_seed(seed)

    result = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]