ahmetyaylalioglu's picture
Update app.py
0fef7b3 verified
# app.py
from diffusers import DDPMPipeline
import torch
import numpy as np
import gradio as gr
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from PIL import Image
import logging
# Configure the root logger so INFO-level (and above) records are emitted.
logging.basicConfig(level=logging.INFO)

# Probe for the optional 'accelerate' package; its absence is reported as a
# warning rather than treated as an error.
try:
    from accelerate import Accelerator  # noqa: F401 -- imported only to test availability
except ImportError:
    logging.warning("Accelerate library not found; it's recommended for large models.")
else:
    logging.info("Accelerate library found.")
# Model identifier handed to DDPMPipeline.from_pretrained.
MODEL_ID = "ahmetyaylalioglu/textile_diffusion_ddpm"

logging.info(f"Loading model from {MODEL_ID}...")
pipeline = DDPMPipeline.from_pretrained(MODEL_ID)

# Run on CUDA when a GPU is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logging.info(f"Using device: {device}")

# Move the pipeline to the chosen device, then force the UNet weights to
# float32 (on either device) and switch it to eval mode. nn.Module.to() and
# nn.Module.eval() both return the module itself, so the calls can be chained.
pipeline.to(device)
pipeline.unet.to(torch.float32).eval()
def generate_images(seed, num_images):
    """Sample images from the DDPM pipeline and tile them into one grid image.

    Args:
        seed: Any value accepted by ``int()`` (the UI passes textbox strings).
            It seeds a dedicated ``torch.Generator`` for this call only, so the
            same seed reproduces the same samples without touching the global
            torch/numpy RNG state shared by concurrent requests.
        num_images: Any value accepted by ``int()``. Clamped to 1..16: the
            upper bound limits generation time/memory, and the lower bound
            stops 0 or negative counts from reaching the pipeline.

    Returns:
        A single PIL image laying the samples out with at most 4 per row, or
        ``None`` if input parsing or generation fails. Failures are logged
        with their full traceback.
    """
    try:
        seed = int(seed)
        num_images = max(1, min(int(num_images), 16))
        logging.info(f"Generating {num_images} images with seed {seed}")

        # Per-call generator on the pipeline's device; the pipeline threads it
        # through both the initial noise and each scheduler step.
        generator = torch.Generator(device=pipeline.device).manual_seed(seed)
        imgs = pipeline(batch_size=num_images, generator=generator).images

        to_tensor = transforms.ToTensor()
        grid = make_grid([to_tensor(img) for img in imgs], nrow=min(4, num_images))
        return transforms.ToPILImage()(grid)
    except Exception as e:
        # Best-effort UI handler: log with traceback and show no image.
        logging.exception(f"generate_images error: {e}")
        return None
# Gradio UI: two text inputs (seed, image count) feed generate_images, and its
# returned PIL grid is displayed in an image output.
interface = gr.Interface(
    fn=generate_images,
    inputs=[
        # A callable default is re-evaluated by Gradio on every page load, so
        # each visitor gets a fresh random seed instead of a single value drawn
        # once when the module was imported.
        gr.Textbox(label="Random Seed", value=lambda: str(np.random.randint(0, 1000))),
        gr.Textbox(label="Number of Images", value="4"),
    ],
    outputs="image",
    title="Textile Diffusion (DDPM)",
    description="Generate textile patterns via a DDPM model.",
)
def _main():
    """Start the Gradio server for ``interface``, requesting a public share link."""
    # NOTE(review): share links may be unsupported or ignored on hosted
    # platforms such as Hugging Face Spaces — confirm for the deployment target.
    interface.launch(share=True)


if __name__ == "__main__":
    _main()