Phauglin commited on
Commit
b89861e
·
verified ·
1 Parent(s): 9dfef73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -30
app.py CHANGED
@@ -91,7 +91,8 @@ example_images = [
91
  ]
92
 
93
 
94
- def process_classification(img):
 
95
  print("Starting prediction...")
96
  predicted_class, _, probs = learn.predict(img)
97
  print(f"Prediction complete: {predicted_class}")
@@ -108,26 +109,31 @@ def process_classification(img):
108
  endangerment_status = get_status(predicted_class)
109
  print(f"Status found: {endangerment_status}")
110
 
111
- return classification_results, wiki_url, endangerment_status, predicted_class
112
-
113
- def generate_artistic_image(predicted_class):
114
  print("Sending request to DALL-E...")
115
  try:
116
  client = OpenAI()
117
- result = client.images.generate(
118
- model="gpt-image-1",
119
- prompt=random.choice(prompt_templates).format(flower=predicted_class),
120
- size="1024x1024"
121
- )
122
 
123
- image_base64 = result.data[0].b64_json
124
- image_bytes = base64.b64decode(image_base64)
125
- generated_image = Image.open(io.BytesIO(image_bytes))
126
- return generated_image
 
 
 
 
 
 
 
 
127
 
128
  except Exception as e:
129
  print(f"Error generating image: {e}")
130
- return None
 
 
 
 
131
 
132
  # Function to clear all outputs
133
  def clear_outputs():
@@ -165,24 +171,10 @@ with gr.Blocks() as demo:
165
  outputs=None
166
  )
167
 
168
- # Store the predicted class for image generation
169
- predicted_class = gr.State()
170
-
171
- def process_and_generate(img):
172
- # First get classification results
173
- classification_results, wiki_url, endangerment_status, pred_class = process_classification(img)
174
-
175
- # Return classification results immediately
176
- yield classification_results, None, wiki_url, endangerment_status, pred_class
177
-
178
- # Then generate and return the image
179
- generated_img = generate_artistic_image(pred_class)
180
- yield classification_results, generated_img, wiki_url, endangerment_status, pred_class
181
-
182
  input_image.change(
183
- fn=process_and_generate,
184
  inputs=input_image,
185
- outputs=[label_output, generated_image, wiki_output, endangerment_output, predicted_class]
186
  )
187
 
188
  input_image.clear(
 
91
  ]
92
 
93
 
94
+ # Main function to process the uploaded image
95
+ def process_image(img, generate_image=True):
96
  print("Starting prediction...")
97
  predicted_class, _, probs = learn.predict(img)
98
  print(f"Prediction complete: {predicted_class}")
 
109
  endangerment_status = get_status(predicted_class)
110
  print(f"Status found: {endangerment_status}")
111
 
112
+ # Generate artistic interpretation using DALL-E
 
 
113
  print("Sending request to DALL-E...")
114
  try:
115
  client = OpenAI()
 
 
 
 
 
116
 
117
+ if generate_image:
118
+ result = client.images.generate(
119
+ model="gpt-image-1",
120
+ prompt=random.choice(prompt_templates).format(flower=predicted_class),
121
+ size="1024x1024"
122
+ )
123
+
124
+ image_base64 = result.data[0].b64_json
125
+ image_bytes = base64.b64decode(image_base64)
126
+ generated_image = Image.open(io.BytesIO(image_bytes))
127
+ else:
128
+ generated_image = None
129
 
130
  except Exception as e:
131
  print(f"Error generating image: {e}")
132
+ generated_image = None
133
+
134
+ print("Image retrieved and ready to return")
135
+ return classification_results, generated_image, wiki_url, endangerment_status
136
+
137
 
138
  # Function to clear all outputs
139
  def clear_outputs():
 
171
  outputs=None
172
  )
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  input_image.change(
175
+ fn=lambda img: process_image(img, generate_image=True),
176
  inputs=input_image,
177
+ outputs=[label_output, generated_image, wiki_output, endangerment_output]
178
  )
179
 
180
  input_image.clear(