Rahatara commited on
Commit
5355110
·
verified ·
1 Parent(s): 6719c4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -30
app.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
  import requests
9
  import io
10
  import base64
 
11
 
12
  # Set API tokens
13
  os.environ["REPLICATE_API_TOKEN"] = "r8_Brv0MtpmAiqrXrMrziyUXoSHuFV5hqs1Lw4Mo"
@@ -57,7 +58,6 @@ def generate_variations(base_prompt, number_of_variations):
57
  weather = random.choice(weather_conditions)
58
 
59
  # Enhance the base prompt with the GPT model
60
- #enhanced_prompt = ask_rail_defect_question(base_prompt)
61
  enhanced_prompt = base_prompt
62
  full_prompt = f"{enhanced_prompt}, with a {size} defect {location}, observed {weather}."
63
  variations.append(full_prompt)
@@ -70,33 +70,37 @@ def image_to_data_url(image):
70
  return f"data:image/png;base64,{img_str}"
71
 
72
  # Function to inpaint images
73
- def inpaint_defect(image, prompt):
74
  if isinstance(image, np.ndarray):
75
  image = Image.fromarray(image)
76
 
77
  image_data_url = image_to_data_url(image)
78
-
79
- input = {
80
- "image": image_data_url,
81
- "prompt": prompt,
82
- "scheduler": "K_EULER_ANCESTRAL",
83
- "num_outputs": 1,
84
- "guidance_scale": 7.5,
85
- "num_inference_steps": 100,
86
- "image_guidance_scale": 1.5
87
- }
88
-
89
- prediction = rep_client.predictions.create(
90
- version="10e63b0e6361eb23a0374f4d9ee145824d9d09f7a31dcd70803193ebc7121430",
91
- input = input
92
- )
93
- prediction.wait()
94
- if prediction.status == "succeeded":
95
- image_url = prediction.output[0]
96
- response = requests.get(image_url)
97
- image = Image.open(io.BytesIO(response.content))
98
- return image
99
- return None
 
 
 
 
100
 
101
  # Function to generate images from prompts
102
  def generate_images(prompts):
@@ -121,6 +125,17 @@ def process_railway_defects(prompt, number_of_images):
121
  images = generate_images(variations)
122
  return images
123
 
 
 
 
 
 
 
 
 
 
 
 
124
  # UI creation
125
  with gr.Blocks() as app:
126
  with gr.Tabs("Prompt Input"):
@@ -162,18 +177,30 @@ with gr.Blocks() as app:
162
  with gr.Row():
163
  image_input = gr.Image(label="Upload Image")
164
  inpaint_prompt_input = gr.Textbox(label="Defect Description")
 
165
  submit_button_inpaint = gr.Button("Inpaint Defect")
166
- inpainted_image_output = gr.Image()
 
 
 
 
 
167
 
168
- def on_submit_click_inpaint(image, inpaint_prompt):
169
- inpainted_image = inpaint_defect(image, inpaint_prompt)
170
- return inpainted_image
171
 
172
  submit_button_inpaint.click(
173
  fn=on_submit_click_inpaint,
174
- inputs=[image_input, inpaint_prompt_input],
175
  outputs=inpainted_image_output
176
  )
177
 
 
 
 
 
 
 
178
  if __name__ == "__main__":
179
- app.launch()
 
8
  import requests
9
  import io
10
  import base64
11
+ import zipfile
12
 
13
  # Set API tokens
14
  os.environ["REPLICATE_API_TOKEN"] = "r8_Brv0MtpmAiqrXrMrziyUXoSHuFV5hqs1Lw4Mo"
 
58
  weather = random.choice(weather_conditions)
59
 
60
  # Enhance the base prompt with the GPT model
 
61
  enhanced_prompt = base_prompt
62
  full_prompt = f"{enhanced_prompt}, with a {size} defect {location}, observed {weather}."
63
  variations.append(full_prompt)
 
70
  return f"data:image/png;base64,{img_str}"
71
 
72
  # Function to inpaint images
73
+ def inpaint_defect(image, prompt, num_images=1):
74
  if isinstance(image, np.ndarray):
75
  image = Image.fromarray(image)
76
 
77
  image_data_url = image_to_data_url(image)
78
+ images = []
79
+
80
+ for _ in range(num_images):
81
+ input = {
82
+ "input_image": image_data_url,
83
+ "instruction_text": prompt,
84
+ "scheduler": "K_EULER_ANCESTRAL",
85
+ "num_outputs": 1,
86
+ "guidance_scale": 7.5,
87
+ "num_inference_steps": 100,
88
+ "image_guidance_scale": 1.5
89
+ }
90
+
91
+ prediction = rep_client.predictions.create(
92
+ version="10e63b0e6361eb23a0374f4d9ee145824d9d09f7a31dcd70803193ebc7121430",
93
+ input=input
94
+ )
95
+ prediction.wait()
96
+ if prediction.status == "succeeded":
97
+ image_url = prediction.output[0]
98
+ response = requests.get(image_url)
99
+ img = Image.open(io.BytesIO(response.content))
100
+ images.append(img)
101
+ else:
102
+ images.append(None)
103
+ return images
104
 
105
  # Function to generate images from prompts
106
  def generate_images(prompts):
 
125
  images = generate_images(variations)
126
  return images
127
 
128
+ def download_images_as_zip(images):
129
+ zip_buffer = io.BytesIO()
130
+ with zipfile.ZipFile(zip_buffer, 'w') as zf:
131
+ for i, img in enumerate(images):
132
+ img_buffer = io.BytesIO()
133
+ img.save(img_buffer, format='PNG')
134
+ img_buffer.seek(0)
135
+ zf.writestr(f'image_{i + 1}.png', img_buffer.read())
136
+ zip_buffer.seek(0)
137
+ return zip_buffer
138
+
139
  # UI creation
140
  with gr.Blocks() as app:
141
  with gr.Tabs("Prompt Input"):
 
177
  with gr.Row():
178
  image_input = gr.Image(label="Upload Image")
179
  inpaint_prompt_input = gr.Textbox(label="Defect Description")
180
+ number_input_inpaint = gr.Number(label="Number of images to generate", value=1, minimum=1, maximum=10)
181
  submit_button_inpaint = gr.Button("Inpaint Defect")
182
+ inpainted_image_output = gr.Gallery()
183
+ download_button = gr.Button("Download Images as Zip")
184
+
185
+ def on_submit_click_inpaint(image, inpaint_prompt, number_of_images):
186
+ inpainted_images = inpaint_defect(image, inpaint_prompt, num_images=number_of_images)
187
+ return inpainted_images
188
 
189
+ def on_download_click(images):
190
+ zip_buffer = download_images_as_zip(images)
191
+ return zip_buffer
192
 
193
  submit_button_inpaint.click(
194
  fn=on_submit_click_inpaint,
195
+ inputs=[image_input, inpaint_prompt_input, number_input_inpaint],
196
  outputs=inpainted_image_output
197
  )
198
 
199
+ download_button.click(
200
+ fn=on_download_click,
201
+ inputs=inpainted_image_output,
202
+ outputs=gr.File(label="Download Zip")
203
+ )
204
+
205
  if __name__ == "__main__":
206
+ app.launch()