Varhal committed on
Commit
e230053
·
verified ·
1 Parent(s): 2c1f8ae

Removed useless comments; also simplified the combined text output to "{tag_json_string},{main_text_response}"

Browse files
Files changed (1) hide show
  1. app.py +7 -37
app.py CHANGED
@@ -52,8 +52,8 @@ def get_image_tags(file_name, text_prompt, model="gemini-2.0-flash-exp"):
52
  temperature=0.5, # Lower temperature might give more focused tags
53
  top_p=0.95,
54
  top_k=40,
55
- max_output_tokens=1024, # Tags shouldn't need many tokens
56
- response_modalities=["text"], # Explicitly ask for text
57
  response_mime_type="text/plain", # Expect plain text
58
  )
59
 
@@ -77,10 +77,8 @@ def get_image_tags(file_name, text_prompt, model="gemini-2.0-flash-exp"):
77
 
78
  except Exception as e:
79
  print(f"Error during tagging API call: {e}")
80
- # Return an error message if tagging fails
81
  return f"Error generating tags: {e}"
82
  finally:
83
- # Clean up uploaded files from the tagging call
84
  for file in uploaded_files:
85
  try:
86
  client.files.delete(name=file.name)
@@ -91,20 +89,15 @@ def get_image_tags(file_name, text_prompt, model="gemini-2.0-flash-exp"):
91
 
92
  # Function for the main image processing call
93
  def generate(text, file_name, model="gemini-2.0-flash-exp"):
94
- """
95
- Sends the image and prompt to Gemini and processes the streamed response.
96
- This function is used for the main user request (editing, analysis, etc.).
97
- """
98
  api_key = os.environ.get("geminigoogle")
99
  if not api_key:
100
  raise ValueError("GEMINI_API_KEY environment variable (geminigoogle) not set.")
101
 
102
  client = genai.Client(api_key=api_key)
103
- uploaded_files = [] # Keep track of uploaded files for cleanup
104
- temp_output_image_path = None # Keep track of generated temp image for cleanup
105
 
106
  try:
107
- # Upload the file for the main generation call
108
  uploaded_files = [client.files.upload(file=file_name)]
109
  print(f"Uploaded file for generation: {uploaded_files[0].uri}")
110
 
@@ -134,8 +127,6 @@ def generate(text, file_name, model="gemini-2.0-flash-exp"):
134
  text_response = ""
135
  image_path = None
136
 
137
- # Use NamedTemporaryFile with delete=False because we need to return the path
138
- # We will handle deletion explicitly later.
139
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
140
  temp_output_image_path = tmp.name
141
 
@@ -149,44 +140,32 @@ def generate(text, file_name, model="gemini-2.0-flash-exp"):
149
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
150
  continue
151
 
152
- # Process each part in the chunk
153
  for part in chunk.candidates[0].content.parts:
154
  # Check for text parts
155
  text_part = getattr(part, "text", "")
156
  if text_part:
157
  text_response += text_part
158
 
159
- # Check for inline image data
160
  if part.inline_data:
161
  print(f"Received image data with mime type {part.inline_data.mime_type}. Saving to {temp_output_image_path}")
162
  save_binary_file(temp_output_image_path, part.inline_data.data)
163
  image_path = temp_output_image_path # Set the output image path
164
- # Note: If the model sends multiple images, this will only save the last one received in a part.
165
- # For typical use cases where one image is expected, this is fine.
166
- # If multiple images could be in different parts of the *same* chunk,
167
- # you'd need more complex handling (e.g., saving each to a separate file).
168
- # If the model sends an image and *then* more text, the loop continues.
169
- # We set image_path here and let the loop finish collecting text.
170
 
171
  print("Generation stream finished.")
172
- # The loop finishes after processing all parts of all chunks.
173
 
174
- # Check if an image was actually saved, otherwise set image_path to None
175
  if not image_path or not os.path.exists(image_path) or os.path.getsize(image_path) == 0:
176
  print("No valid image data was received or saved.")
177
- image_path = None # Ensure image_path is None if no image data was received/saved
178
 
179
- return image_path, text_response.strip() # Return the path to the saved image (or None) and the collected text
180
 
181
  except Exception as e:
182
  print(f"Error during main generation API call: {e}")
183
- # Ensure temporary files created before the error are cleaned up
184
  if temp_output_image_path and os.path.exists(temp_output_image_path):
185
  os.remove(temp_output_image_path)
186
  raise e # Re-raise the exception after cleanup
187
 
188
  finally:
189
- # Clean up uploaded files from the generation call
190
  for file in uploaded_files:
191
  try:
192
  client.files.delete(name=file.name)
@@ -202,12 +181,7 @@ def process_image_and_prompt(composite_pil, prompt):
202
  # 1. Save the input PIL image to a temporary file
203
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
204
  composite_path = tmp.name
205
- # Ensure image is saved in a format compatible with Gemini, convert if necessary
206
  if composite_pil.mode == "RGBA":
207
- # Convert RGBA to RGB if necessary, as some models prefer RGB
208
- # Or handle alpha channel depending on model capabilities.
209
- # For simplicity here, saving as PNG should preserve alpha,
210
- # but Gemini might interpret it differently. Let's save as PNG.
211
  composite_pil.save(composite_path, format="PNG")
212
  else:
213
  composite_pil.save(composite_path, format="PNG") # Save as PNG by default
@@ -221,12 +195,11 @@ def process_image_and_prompt(composite_pil, prompt):
221
  tag_json_string = get_image_tags(file_name, tagging_prompt, model=model)
222
 
223
  # 3. Call generate for the main image processing based on the user prompt
224
- # This function returns the path to a generated image (if any) and text response
225
  output_image_path, main_text_response = generate(text=prompt, file_name=file_name, model=model)
226
 
227
  # 4. Combine the tag JSON string and the main text response
228
  # Format the output clearly
229
- final_text_output = f"Original Image Tags (JSON): {tag_json_string}\n\n---\n\nGemini Response:\n{main_text_response}"
230
 
231
  # 5. Prepare the image output for the Gradio gallery
232
  result_img = None
@@ -265,9 +238,6 @@ def process_image_and_prompt(composite_pil, prompt):
265
  except Exception as cleanup_e:
266
  print(f"Error deleting input temporary file {composite_path}: {cleanup_e}")
267
 
268
- # Clean up the temporary output image file created by generate()
269
- # Note: generate() might have already deleted the *uploaded* file via API,
270
- # but this handles the local file saved from inline_data.
271
  if output_image_path and os.path.exists(output_image_path):
272
  try:
273
  os.remove(output_image_path)
 
52
  temperature=0.5, # Lower temperature might give more focused tags
53
  top_p=0.95,
54
  top_k=40,
55
+ max_output_tokens=1024,
56
+ response_modalities=["text"],
57
  response_mime_type="text/plain", # Expect plain text
58
  )
59
 
 
77
 
78
  except Exception as e:
79
  print(f"Error during tagging API call: {e}")
 
80
  return f"Error generating tags: {e}"
81
  finally:
 
82
  for file in uploaded_files:
83
  try:
84
  client.files.delete(name=file.name)
 
89
 
90
  # Function for the main image processing call
91
  def generate(text, file_name, model="gemini-2.0-flash-exp"):
 
 
 
 
92
  api_key = os.environ.get("geminigoogle")
93
  if not api_key:
94
  raise ValueError("GEMINI_API_KEY environment variable (geminigoogle) not set.")
95
 
96
  client = genai.Client(api_key=api_key)
97
+ uploaded_files = []
98
+ temp_output_image_path = None
99
 
100
  try:
 
101
  uploaded_files = [client.files.upload(file=file_name)]
102
  print(f"Uploaded file for generation: {uploaded_files[0].uri}")
103
 
 
127
  text_response = ""
128
  image_path = None
129
 
 
 
130
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
131
  temp_output_image_path = tmp.name
132
 
 
140
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
141
  continue
142
 
 
143
  for part in chunk.candidates[0].content.parts:
144
  # Check for text parts
145
  text_part = getattr(part, "text", "")
146
  if text_part:
147
  text_response += text_part
148
 
 
149
  if part.inline_data:
150
  print(f"Received image data with mime type {part.inline_data.mime_type}. Saving to {temp_output_image_path}")
151
  save_binary_file(temp_output_image_path, part.inline_data.data)
152
  image_path = temp_output_image_path # Set the output image path
 
 
 
 
 
 
153
 
154
  print("Generation stream finished.")
 
155
 
 
156
  if not image_path or not os.path.exists(image_path) or os.path.getsize(image_path) == 0:
157
  print("No valid image data was received or saved.")
158
+ image_path = None
159
 
160
+ return image_path, text_response.strip()
161
 
162
  except Exception as e:
163
  print(f"Error during main generation API call: {e}")
 
164
  if temp_output_image_path and os.path.exists(temp_output_image_path):
165
  os.remove(temp_output_image_path)
166
  raise e # Re-raise the exception after cleanup
167
 
168
  finally:
 
169
  for file in uploaded_files:
170
  try:
171
  client.files.delete(name=file.name)
 
181
  # 1. Save the input PIL image to a temporary file
182
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
183
  composite_path = tmp.name
 
184
  if composite_pil.mode == "RGBA":
 
 
 
 
185
  composite_pil.save(composite_path, format="PNG")
186
  else:
187
  composite_pil.save(composite_path, format="PNG") # Save as PNG by default
 
195
  tag_json_string = get_image_tags(file_name, tagging_prompt, model=model)
196
 
197
  # 3. Call generate for the main image processing based on the user prompt
 
198
  output_image_path, main_text_response = generate(text=prompt, file_name=file_name, model=model)
199
 
200
  # 4. Combine the tag JSON string and the main text response
201
  # Format the output clearly
202
+ final_text_output = f"{tag_json_string},{main_text_response}"
203
 
204
  # 5. Prepare the image output for the Gradio gallery
205
  result_img = None
 
238
  except Exception as cleanup_e:
239
  print(f"Error deleting input temporary file {composite_path}: {cleanup_e}")
240
 
 
 
 
241
  if output_image_path and os.path.exists(output_image_path):
242
  try:
243
  os.remove(output_image_path)