| | |
| | from diffusers import DDPMPipeline |
| | import torch |
| | import numpy as np |
| | import gradio as gr |
| | from torchvision.utils import make_grid |
| | import torchvision.transforms as transforms |
| | from PIL import Image |
| | import logging |
| |
|
# Module-level setup: configure logging, probe for optional acceleration
# support, download the pretrained pipeline, and place it on the best device.
logging.basicConfig(level=logging.INFO)

# `accelerate` is optional — diffusers uses it for faster/low-memory loading
# of large models when present, but the pipeline still works without it.
try:
    from accelerate import Accelerator
    logging.info("Accelerate library found.")
except ImportError:
    logging.warning("Accelerate library not found; it's recommended for large models.")

# Hugging Face Hub repository holding the pretrained DDPM weights.
MODEL_ID = "ahmetyaylalioglu/textile_diffusion_ddpm"
logging.info(f"Loading model from {MODEL_ID}...")
# NOTE: downloads weights on first run — network access required at import time.
pipeline = DDPMPipeline.from_pretrained(MODEL_ID)

# Prefer GPU when available; fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
pipeline.to(device)
# Force float32 (in case the checkpoint was saved in half precision) and
# switch the UNet to inference mode (disables dropout / batch-norm updates).
pipeline.unet.to(torch.float32)
pipeline.unet.eval()
| |
|
def generate_images(seed, num_images):
    """Sample `num_images` textile patterns from the DDPM and tile them.

    Parameters
    ----------
    seed : str | int
        Random seed; parsed with int(), so Gradio may pass it as text.
    num_images : str | int
        Requested batch size; clamped to the range [1, 16].

    Returns
    -------
    PIL.Image.Image | None
        A single grid image of all samples, or None if generation failed
        (the error is logged with its traceback).
    """
    try:
        seed = int(seed)
        # Clamp to [1, 16]: the original `min(..., 16)` had no lower bound,
        # so "0" or a negative entry reached the pipeline as an invalid batch.
        num_images = max(1, min(int(num_images), 16))
        logging.info(f"Generating {num_images} images with seed {seed}")

        # Use a device-local generator so sampling is reproducible on both
        # CPU and CUDA, rather than relying on the global RNG state.
        generator = torch.Generator(device=device).manual_seed(seed)
        np.random.seed(seed)

        imgs = pipeline(batch_size=num_images, generator=generator).images
        # Tile the PIL samples into one tensor grid, at most 4 per row.
        grid = make_grid([transforms.ToTensor()(img) for img in imgs],
                         nrow=min(4, num_images))
        return transforms.ToPILImage()(grid)
    except Exception:
        # logging.exception keeps the full traceback (logging.error with the
        # message alone discarded it); Gradio renders None as an empty output.
        logging.exception("generate_images error")
        return None
| |
|
# Build the Gradio UI: two free-text inputs (seed and batch size, both
# parsed inside generate_images) mapped to a single image output.
demo_inputs = [
    gr.Textbox(label="Random Seed", value=str(np.random.randint(0, 1000))),
    gr.Textbox(label="Number of Images", value="4"),
]

interface = gr.Interface(
    fn=generate_images,
    inputs=demo_inputs,
    outputs="image",
    title="Textile Diffusion (DDPM)",
    description="Generate textile patterns via a DDPM model.",
)
| |
|
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public URL in addition to the
    # local server — useful for demos, but exposes the app to the internet.
    interface.launch(share=True)
| |
|