Sek2810 committed on
Commit
7b0d784
·
verified ·
1 Parent(s): 6d8ca5a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import spaces
6
+ from PIL import Image
7
+ import os
8
+
9
+ from models.transformer_sd3 import SD3Transformer2DModel
10
+ from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
11
+
12
+ from transformers import AutoProcessor, SiglipVisionModel
13
+ from huggingface_hub import hf_hub_download
14
+
15
+
16
# Constants
MAX_SEED = np.iinfo(np.int32).max  # upper bound of the seed slider and of random seeds
MAX_IMAGE_SIZE = 1024  # upper bound of the width/height sliders
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model locations: base checkpoint, image encoder, and the IP-Adapter weights
# file (resolved through hf_hub_download when this module is imported).
model_path = 'stabilityai/stable-diffusion-3.5-large'
image_encoder_path = "google/siglip-so400m-patch14-384"
ipadapter_path = hf_hub_download(repo_id="InstantX/SD3.5-Large-IP-Adapter", filename="ip-adapter.bin")

# Load the project's own transformer class and hand it to the pipeline explicitly.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Place the pipeline on DEVICE instead of a hard-coded "cuda" so the constant
# above is actually honoured and the script can start on hosts without CUDA.
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to(DEVICE)

# Attach the IP-Adapter weights and the image encoder to the pipeline.
pipe.init_ipadapter(
    ip_adapter_path=ipadapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
42
+
43
def resize_img(image, max_size=1024):
    """Scale *image* uniformly so that it fits a ``max_size`` x ``max_size`` box.

    Both sides are multiplied by the same factor, so the aspect ratio is
    kept. The factor may be greater than one, i.e. images smaller than the
    box are enlarged as well.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height), resample)`` method.
        max_size: Side length of the bounding box in pixels; must be positive.

    Returns:
        The value returned by ``image.resize`` for the computed size.

    Raises:
        ValueError: If ``max_size`` is not positive.
    """
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size}")

    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    # int() truncates towards zero, so a very elongated image could end up
    # with a 0-pixel side; keep every side at least one pixel long.
    new_width = max(1, int(width * scaling_factor))
    new_height = max(1, int(height * scaling_factor))
    return image.resize((new_width, new_height), Image.LANCZOS)
49
+
50
@spaces.GPU
def process_image(
    image,
    prompt,
    scale,
    seed,
    randomize_seed,
    width,
    height,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image from a text prompt guided by a reference image.

    Args:
        image: Reference image. A PIL image is used as is; anything else is
            passed through ``Image.fromarray``. ``None`` means "no input".
        prompt: Text prompt forwarded to the pipeline.
        scale: Value forwarded to the pipeline as ``ipadapter_scale``.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is truthy.
        randomize_seed: When truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Requested output width, forwarded to the pipeline.
        height: Requested output height, forwarded to the pipeline.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A ``(result, seed)`` tuple. ``result`` is the first image produced by
        the pipeline, or ``None`` when no reference image was supplied;
        ``seed`` is the integer seed that was (or would have been) used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver numbers as floats; torch.Generator.manual_seed
    # requires an int, so normalise once here.
    seed = int(seed)

    if image is None:
        return None, seed

    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Resize image
    image = resize_img(image)

    # Generate the image
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        clip_image=image,
        prompt=prompt,
        ipadapter_scale=scale,
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    return result, seed
86
+
87
# Stylesheet for the element with id "col-container": centred, width-capped.
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

# Assemble the Gradio interface.
# NOTE(review): SOURCE lost its indentation, so the nesting below is
# reconstructed from the order of the statements — confirm that the
# "Advanced Settings" accordion belongs inside the main column.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# InstantX's SD3.5 IP Adapter")

        with gr.Row():
            # First column: reference image, adapter scale, prompt, trigger.
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                scale = gr.Slider(label="Image Scale", minimum=0.0, maximum=1.0, step=0.1, value=0.7)
                prompt = gr.Text(label="Prompt", max_lines=1, placeholder="Enter your prompt")
                run_button = gr.Button("Generate", variant="primary")

            # Second column: generated image.
            with gr.Column():
                result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

    # Input order matches process_image's positional parameters; the seed
    # component is also an output, so it receives the seed returned by the call.
    run_button.click(
        fn=process_image,
        inputs=[input_image, prompt, scale, seed, randomize_seed, width, height],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()