HuiZhang0812 commited on
Commit
41050f2
·
verified ·
1 Parent(s): 0b66af7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -24
app.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from utils.bbox_visualization import bbox_visualization,scale_boxes
2
  from PIL import Image
3
  import os
@@ -5,6 +10,23 @@ import pandas as pd
5
  from huggingface_hub import login
6
 
7
  hf_token = os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  print("pipeline is loaded.")
9
 
10
  @spaces.GPU
@@ -12,6 +34,7 @@ def process_image_and_text(global_caption, box_detail_phrases_list:pd.DataFrame,
12
 
13
  if randomize_seed:
14
  seed = torch.randint(0, 100, (1,)).item()
 
15
  height = 1024
16
  width = 1024
17
 
@@ -19,7 +42,6 @@ def process_image_and_text(global_caption, box_detail_phrases_list:pd.DataFrame,
19
  box_detail_phrases_list_tmp = [c[0] for c in box_detail_phrases_list_tmp]
20
  boxes = boxes.astype(float).values.tolist()
21
 
22
-
23
  white_image = Image.new('RGB', (width, height), color='rgb(256,256,256)')
24
  show_input = {"boxes":scale_boxes(boxes,width,height),"labels":box_detail_phrases_list_tmp}
25
  bbox_visualization_img = bbox_visualization(white_image,show_input)
@@ -33,6 +55,7 @@ def process_image_and_text(global_caption, box_detail_phrases_list:pd.DataFrame,
33
  bbox_raw=boxes,
34
  height=height,
35
  width=width
 
36
 
37
  return bbox_visualization_img, result_img
38
 
@@ -100,34 +123,12 @@ def get_samples():
100
  with gr.Blocks() as demo:
101
  gr.Markdown("# CreatiLayout: Layout-to-Image generation")
102
  gr.Markdown("""CreatiLayout is a layout-to-image framework for Diffusion Transformer models, offering high-quality and fine-grained controllable generation based on the global description and entity annotations. Users need to provide a global description and the position and description of each entity, as shown in the examples. Please feel free to modify the position and attributes of the entities in the examples (such as size, color, shape, text, portrait, etc.). Here are some inspirations: Iron Man -> Spider Man/Harry Potter/Buzz Lightyear; CreatiLayout -> Hello Friends/Let's Control; drawing board -> round drawing board; Modify the position of the drawing board to (0.4, 0.15, 0.55, 0.35)""")
103
-
104
-
105
-
106
-
107
-
108
-
109
-
110
  with gr.Row():
111
 
112
  with gr.Column():
113
  global_caption = gr.Textbox(lines=2, label="Global Caption")
114
  box_detail_phrases_list = gr.Dataframe(headers=["Region Captions"], label="Region Captions")
115
  boxes = gr.Dataframe(headers=["x1", "y1", "x2", "y2"], label="Region Bounding Boxes (x_min,y_min,x_max,y_max)")
116
-
117
-
118
-
119
-
120
-
121
-
122
-
123
-
124
-
125
-
126
-
127
-
128
-
129
-
130
-
131
  with gr.Accordion("Advanced Settings", open=False):
132
  seed = gr.Slider(0, 100, step=1, label="Seed", value=42)
133
  randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
@@ -137,7 +138,6 @@ with gr.Blocks() as demo:
137
  bbox_visualization_img = gr.Image(type="pil", label="Bounding Box Visualization")
138
 
139
  with gr.Column():
140
-
141
  output_image = gr.Image(type="pil", label="Generated Image")
142
 
143
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from src.models.transformer_sd3_SiamLayout import SiamLayoutSD3Transformer2DModel
5
+ from src.pipeline.pipeline_CreatiLayout import CreatiLayoutSD3Pipeline
6
  from utils.bbox_visualization import bbox_visualization,scale_boxes
7
  from PIL import Image
8
  import os
 
10
  from huggingface_hub import login
11
 
12
  hf_token = os.getenv("HF_TOKEN")
13
+
14
+ if hf_token is None:
15
+ raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret.")
16
+
17
+ login(token=hf_token)
18
+
19
+ model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
20
+ ckpt_path = "HuiZhang0812/CreatiLayout"
21
+
22
+ transformer_additional_kwargs = dict(attention_type="layout",strict=True)
23
+
24
+ transformer = SiamLayoutSD3Transformer2DModel.from_pretrained(
25
+ ckpt_path, subfolder="SiamLayout_SD3", torch_dtype=torch.float16,**transformer_additional_kwargs)
26
+
27
+ pipe = CreatiLayoutSD3Pipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
28
+ pipe = pipe.to("cuda")
29
+
30
  print("pipeline is loaded.")
31
 
32
  @spaces.GPU
 
34
 
35
  if randomize_seed:
36
  seed = torch.randint(0, 100, (1,)).item()
37
+
38
  height = 1024
39
  width = 1024
40
 
 
42
  box_detail_phrases_list_tmp = [c[0] for c in box_detail_phrases_list_tmp]
43
  boxes = boxes.astype(float).values.tolist()
44
 
 
45
  white_image = Image.new('RGB', (width, height), color='rgb(256,256,256)')
46
  show_input = {"boxes":scale_boxes(boxes,width,height),"labels":box_detail_phrases_list_tmp}
47
  bbox_visualization_img = bbox_visualization(white_image,show_input)
 
55
  bbox_raw=boxes,
56
  height=height,
57
  width=width
58
+ ).images[0]
59
 
60
  return bbox_visualization_img, result_img
61
 
 
123
  with gr.Blocks() as demo:
124
  gr.Markdown("# CreatiLayout: Layout-to-Image generation")
125
  gr.Markdown("""CreatiLayout is a layout-to-image framework for Diffusion Transformer models, offering high-quality and fine-grained controllable generation based on the global description and entity annotations. Users need to provide a global description and the position and description of each entity, as shown in the examples. Please feel free to modify the position and attributes of the entities in the examples (such as size, color, shape, text, portrait, etc.). Here are some inspirations: Iron Man -> Spider Man/Harry Potter/Buzz Lightyear; CreatiLayout -> Hello Friends/Let's Control; drawing board -> round drawing board; Modify the position of the drawing board to (0.4, 0.15, 0.55, 0.35)""")
 
 
 
 
 
 
 
126
  with gr.Row():
127
 
128
  with gr.Column():
129
  global_caption = gr.Textbox(lines=2, label="Global Caption")
130
  box_detail_phrases_list = gr.Dataframe(headers=["Region Captions"], label="Region Captions")
131
  boxes = gr.Dataframe(headers=["x1", "y1", "x2", "y2"], label="Region Bounding Boxes (x_min,y_min,x_max,y_max)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  with gr.Accordion("Advanced Settings", open=False):
133
  seed = gr.Slider(0, 100, step=1, label="Seed", value=42)
134
  randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
 
138
  bbox_visualization_img = gr.Image(type="pil", label="Bounding Box Visualization")
139
 
140
  with gr.Column():
 
141
  output_image = gr.Image(type="pil", label="Generated Image")
142
 
143