diesdasjunge commited on
Commit
5e76241
·
verified ·
1 Parent(s): 3af9471

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ import gradio as gr
6
+
7
+ # Model configuration
8
+ model_path = "path/to/your/checkpoint.safetensors" # Update this with your checkpoint path
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load the model
12
+ pipe = StableDiffusionPipeline.from_pretrained(
13
+ "runwayml/stable-diffusion-v1-5",
14
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
15
+ safety_checker=None
16
+ )
17
+ pipe.to(device)
18
+
19
+ # If you have a custom checkpoint, load it
20
+ if os.path.exists(model_path):
21
+ pipe.unet.load_state_dict(torch.load(model_path))
22
+
23
+ def generate_image(prompt, negative_prompt, num_steps, guidance_scale, width, height, seed):
24
+ """
25
+ Generate an image using Stable Diffusion
26
+ """
27
+ if seed == -1:
28
+ seed = int.from_bytes(os.urandom(2), "big")
29
+ generator = torch.Generator(device=device).manual_seed(seed)
30
+
31
+ with autocast(device):
32
+ image = pipe(
33
+ prompt=prompt,
34
+ negative_prompt=negative_prompt,
35
+ num_inference_steps=num_steps,
36
+ guidance_scale=guidance_scale,
37
+ width=width,
38
+ height=height,
39
+ generator=generator
40
+ ).images[0]
41
+
42
+ return image, seed
43
+
44
+ # Create Gradio interface
45
+ with gr.Blocks() as demo:
46
+ gr.Markdown("# Stable Diffusion 1.5 Custom Model")
47
+
48
+ with gr.Row():
49
+ with gr.Column():
50
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
51
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here...")
52
+
53
+ with gr.Row():
54
+ num_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Number of Steps")
55
+ guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, step=0.5, label="Guidance Scale")
56
+
57
+ with gr.Row():
58
+ width = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Width")
59
+ height = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Height")
60
+
61
+ seed = gr.Number(label="Seed (-1 for random)", value=-1)
62
+ generate_btn = gr.Button("Generate Image")
63
+
64
+ with gr.Column():
65
+ output_image = gr.Image(label="Generated Image")
66
+ used_seed = gr.Number(label="Used Seed")
67
+
68
+ generate_btn.click(
69
+ fn=generate_image,
70
+ inputs=[prompt, negative_prompt, num_steps, guidance_scale, width, height, seed],
71
+ outputs=[output_image, used_seed]
72
+ )
73
+
74
+ # Launch app locally
75
+ if __name__ == "__main__":
76
+ demo.launch()