# AI_generation / model.py
# Last change: "Update model.py" (commit 4e9fdd8) by Prashanthsrn.
import logging

import torch
from diffusers import StableDiffusionPipeline
def load_model(model_name, *, model_id="CompVis/stable-diffusion-v1-4"):
    """Load the text-to-image pipeline registered under *model_name*.

    Only ``"Stable Diffusion"`` is supported. When CUDA is available the
    pipeline is loaded in float16 and moved to the GPU. Otherwise it is loaded
    with its default dtype and kept on the CPU.

    Args:
        model_name: Display name of the model to load.
        model_id: Hugging Face Hub repository id (or local path) passed to
            ``StableDiffusionPipeline.from_pretrained``. Keyword-only; the
            default matches the previously hard-coded repository.

    Returns:
        The loaded pipeline, or ``None`` if *model_name* is unsupported or
        loading failed. Both cases are logged rather than raised.
    """
    logger = logging.getLogger(__name__)

    # Unknown names used to fall through and return None with no message at all.
    if model_name != "Stable Diffusion":
        logger.warning("Unsupported model name: %r", model_name)
        return None

    try:
        if torch.cuda.is_available():
            # float16 is only requested for the GPU path; the CPU path keeps the
            # pipeline's default dtype.
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id, torch_dtype=torch.float16
            )
            return pipe.to("cuda")
        pipe = StableDiffusionPipeline.from_pretrained(model_id)
        return pipe.to("cpu")
    except Exception as e:
        # Best-effort contract: callers check for None. Still, record the full
        # traceback instead of only printing the message.
        logger.exception("Error loading model: %s", e)
        return None
def generate_image(
    model,
    prompt,
    model_name,
    *,
    num_inference_steps=10,
    height=256,
    width=256,
):
    """Generate one image for *prompt* with a pipeline from ``load_model``.

    Args:
        model: Callable pipeline as returned by ``load_model``. It may be
            ``None`` if loading failed.
        prompt: Text prompt passed positionally to the pipeline.
        model_name: Display name the pipeline was loaded under. Only
            ``"Stable Diffusion"`` is supported.
        num_inference_steps: Number of denoising steps. The default of 10 is
            deliberately low to trade quality for speed.
        height: Output height in pixels. The default of 256 is reduced for
            faster generation.
        width: Output width in pixels. The default of 256 is reduced for
            faster generation.

    Returns:
        The first entry of the pipeline result's ``images``. Returns ``None``
        if the name is unsupported, no model is given, or generation fails;
        each case is logged.
    """
    logger = logging.getLogger(__name__)

    if model_name != "Stable Diffusion":
        logger.warning("Unsupported model name: %r", model_name)
        return None
    if model is None:
        # Typically means load_model failed. Report that directly instead of
        # surfacing "'NoneType' object is not callable".
        logger.error("Error generating image: no model loaded")
        return None

    try:
        result = model(
            prompt,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
        )
        return result.images[0]
    except Exception as e:
        logger.exception("Error generating image: %s", e)
        return None