boomcheng commited on
Commit
6308e17
·
verified ·
1 Parent(s): e65ee5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -88
app.py CHANGED
@@ -34,52 +34,35 @@ def submit_prompt(prompt):
34
  combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
35
  return combined_list, gr.update(interactive=False) # Make the prompt input non-editable
36
 
37
- # Function to add a new object
38
  def add_object(object_class, bbox):
39
- object_classes_list.append(object_class)
40
- object_bboxes_list.append(bbox)
41
- combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
42
- return combined_list
43
-
44
- # Function to generate images based on added objects
45
- def generate_image(prompt, guidance_scale, num_inference_steps, randomize_seed, seed):
46
- img_width, img_height = 512, 512
47
- r_image = np.zeros((img_height, img_width, 3), dtype=np.uint8)
48
- list_cond_image = []
49
-
50
- for bbox in object_bboxes_list:
51
  x1, y1, x2, y2 = map(int, bbox.split(","))
52
- cond_image = np.zeros_like(r_image, dtype=np.uint8)
53
- cond_image[y1:y2, x1:x2] = 255
54
- list_cond_image.append(Image.fromarray(cond_image).convert('RGB'))
55
-
56
- if randomize_seed or seed is None:
57
- seed = np.random.randint(0, MAX_SEED)
58
-
59
- generator = torch.manual_seed(seed)
60
-
61
- image = pipe(
62
- prompt=prompt,
63
- layo_prompt=object_classes_list,
64
- guess_mode=False,
65
- guidance_scale=guidance_scale,
66
- num_inference_steps=num_inference_steps,
67
- image=list_cond_image,
68
- fuse_type="avg",
69
- width=512,
70
- height=512
71
- ).images[0]
72
-
73
- return image, seed
74
 
75
  # Gradio UI
76
  with gr.Blocks() as demo:
77
  gr.Markdown("# Text-to-Image Generator with Object Addition")
78
 
79
- # Put prompt and submit button in the same row
80
  with gr.Row():
81
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here")
82
- submit_button = gr.Button("Submit Prompt")
83
 
84
  # Always visible DataFrame
85
  objects_display = gr.Dataframe(
@@ -87,56 +70,7 @@ with gr.Blocks() as demo:
87
  value=[]
88
  )
89
 
90
- with gr.Row():
91
- object_class_input = gr.Textbox(label="Object Class", placeholder="Enter object class (e.g., Object_1)")
92
- bbox_input = gr.Textbox(label="Bounding Box (x1,y1,x2,y2)", placeholder="Enter bounding box coordinates")
93
-
94
- add_button = gr.Button("Add Object")
95
-
96
- # Advanced settings in a collapsible accordion
97
- with gr.Accordion("Advanced Settings", open=False):
98
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
99
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
100
-
101
- with gr.Row():
102
- guidance_scale = gr.Slider(
103
- label="Guidance scale",
104
- minimum=0.0,
105
- maximum=10.0,
106
- step=0.1,
107
- value=7.5
108
- )
109
- num_inference_steps = gr.Slider(
110
- label="Number of inference steps",
111
- minimum=1,
112
- maximum=50,
113
- step=1,
114
- value=50
115
- )
116
-
117
- generate_button = gr.Button("Generate Image")
118
- result = gr.Image(label="Generated Image")
119
-
120
- # Submit the prompt and update the display
121
- submit_button.click(
122
- fn=submit_prompt,
123
- inputs=prompt,
124
- outputs=[objects_display, prompt] # Update both the display and prompt input
125
- )
126
-
127
- # Add object and update display
128
- add_button.click(
129
- fn=add_object,
130
- inputs=[object_class_input, bbox_input],
131
- outputs=objects_display
132
- )
133
-
134
- # Generate image based on added objects
135
- generate_button.click(
136
- fn=generate_image,
137
- inputs=[prompt, guidance_scale, num_inference_steps, randomize_seed, seed],
138
- outputs=[result, seed]
139
- )
140
 
141
  if __name__ == "__main__":
142
  demo.launch()
 
34
  combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
35
  return combined_list, gr.update(interactive=False) # Make the prompt input non-editable
36
 
37
+ # Function to add a new object with validation
38
  def add_object(object_class, bbox):
39
+ try:
40
+ # Split and convert bbox string into integers
 
 
 
 
 
 
 
 
 
 
41
  x1, y1, x2, y2 = map(int, bbox.split(","))
42
+
43
+ # Validate the coordinates
44
+ if x2 < x1 or y2 < y1:
45
+ return "Error: x2 cannot be less than x1 and y2 cannot be less than y1.", []
46
+ if x1 < 0 or y1 < 0 or x2 > 512 or y2 > 512:
47
+ return "Error: Coordinates must be between 0 and 512.", []
48
+
49
+ # If validation passes, add to the lists
50
+ object_classes_list.append(object_class)
51
+ object_bboxes_list.append(bbox)
52
+ combined_list = [[cls, bbox] for cls, bbox in zip(object_classes_list, object_bboxes_list)]
53
+ return combined_list
54
+
55
+ except ValueError:
56
+ return "Error: Invalid input format. Use x1,y1,x2,y2.", []
 
 
 
 
 
 
 
57
 
58
  # Gradio UI
59
  with gr.Blocks() as demo:
60
  gr.Markdown("# Text-to-Image Generator with Object Addition")
61
 
62
+ # Put prompt and submit button together in the same input field
63
  with gr.Row():
64
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", interactive=True)
65
+ submit_button = gr.Button("Submit", size="sm") # Use a smaller button size
66
 
67
  # Always visible DataFrame
68
  objects_display = gr.Dataframe(
 
70
  value=[]
71
  )
72
 
73
+ # Add remaining UI components...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  if __name__ == "__main__":
76
  demo.launch()