Phauglin commited on
Commit
9dfef73
·
verified ·
1 Parent(s): eb2b58f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -24
app.py CHANGED
@@ -91,8 +91,7 @@ example_images = [
91
  ]
92
 
93
 
94
- # Main function to process the uploaded image
95
- def process_image(img, generate_image=True):
96
  print("Starting prediction...")
97
  predicted_class, _, probs = learn.predict(img)
98
  print(f"Prediction complete: {predicted_class}")
@@ -109,33 +108,26 @@ def process_image(img, generate_image=True):
109
  endangerment_status = get_status(predicted_class)
110
  print(f"Status found: {endangerment_status}")
111
 
112
- # Generate artistic interpretation using DALL-E
 
 
113
  print("Sending request to DALL-E...")
114
  try:
115
  client = OpenAI()
 
 
 
 
 
116
 
117
- if generate_image:
118
- result = client.images.generate(
119
- model="gpt-image-1",
120
- prompt=random.choice(prompt_templates).format(flower=predicted_class),
121
- size="1024x1024",
122
- background="transparent",
123
- quality="low"
124
- )
125
-
126
- image_base64 = result.data[0].b64_json
127
- image_bytes = base64.b64decode(image_base64)
128
- generated_image = Image.open(io.BytesIO(image_bytes))
129
- else:
130
- generated_image = None
131
 
132
  except Exception as e:
133
  print(f"Error generating image: {e}")
134
- generated_image = None
135
-
136
- print("Image retrieved and ready to return")
137
- return classification_results, generated_image, wiki_url, endangerment_status
138
-
139
 
140
  # Function to clear all outputs
141
  def clear_outputs():
@@ -173,10 +165,24 @@ with gr.Blocks() as demo:
173
  outputs=None
174
  )
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  input_image.change(
177
- fn=lambda img: process_image(img, generate_image=True),
178
  inputs=input_image,
179
- outputs=[label_output, generated_image, wiki_output, endangerment_output]
180
  )
181
 
182
  input_image.clear(
 
91
  ]
92
 
93
 
94
+ def process_classification(img):
 
95
  print("Starting prediction...")
96
  predicted_class, _, probs = learn.predict(img)
97
  print(f"Prediction complete: {predicted_class}")
 
108
  endangerment_status = get_status(predicted_class)
109
  print(f"Status found: {endangerment_status}")
110
 
111
+ return classification_results, wiki_url, endangerment_status, predicted_class
112
+
113
+ def generate_artistic_image(predicted_class):
114
  print("Sending request to DALL-E...")
115
  try:
116
  client = OpenAI()
117
+ result = client.images.generate(
118
+ model="gpt-image-1",
119
+ prompt=random.choice(prompt_templates).format(flower=predicted_class),
120
+ size="1024x1024"
121
+ )
122
 
123
+ image_base64 = result.data[0].b64_json
124
+ image_bytes = base64.b64decode(image_base64)
125
+ generated_image = Image.open(io.BytesIO(image_bytes))
126
+ return generated_image
 
 
 
 
 
 
 
 
 
 
127
 
128
  except Exception as e:
129
  print(f"Error generating image: {e}")
130
+ return None
 
 
 
 
131
 
132
  # Function to clear all outputs
133
  def clear_outputs():
 
165
  outputs=None
166
  )
167
 
168
+ # Store the predicted class for image generation
169
+ predicted_class = gr.State()
170
+
171
+ def process_and_generate(img):
172
+ # First get classification results
173
+ classification_results, wiki_url, endangerment_status, pred_class = process_classification(img)
174
+
175
+ # Return classification results immediately
176
+ yield classification_results, None, wiki_url, endangerment_status, pred_class
177
+
178
+ # Then generate and return the image
179
+ generated_img = generate_artistic_image(pred_class)
180
+ yield classification_results, generated_img, wiki_url, endangerment_status, pred_class
181
+
182
  input_image.change(
183
+ fn=process_and_generate,
184
  inputs=input_image,
185
+ outputs=[label_output, generated_image, wiki_output, endangerment_output, predicted_class]
186
  )
187
 
188
  input_image.clear(