Spaces:
Running
Running
removed useless comments
Browse files
app.py
CHANGED
|
@@ -52,8 +52,8 @@ def get_image_tags(file_name, text_prompt, model="gemini-2.0-flash-exp"):
|
|
| 52 |
temperature=0.5, # Lower temperature might give more focused tags
|
| 53 |
top_p=0.95,
|
| 54 |
top_k=40,
|
| 55 |
-
max_output_tokens=1024,
|
| 56 |
-
response_modalities=["text"],
|
| 57 |
response_mime_type="text/plain", # Expect plain text
|
| 58 |
)
|
| 59 |
|
|
@@ -77,10 +77,8 @@ def get_image_tags(file_name, text_prompt, model="gemini-2.0-flash-exp"):
|
|
| 77 |
|
| 78 |
except Exception as e:
|
| 79 |
print(f"Error during tagging API call: {e}")
|
| 80 |
-
# Return an error message if tagging fails
|
| 81 |
return f"Error generating tags: {e}"
|
| 82 |
finally:
|
| 83 |
-
# Clean up uploaded files from the tagging call
|
| 84 |
for file in uploaded_files:
|
| 85 |
try:
|
| 86 |
client.files.delete(name=file.name)
|
|
@@ -91,20 +89,15 @@ def get_image_tags(file_name, text_prompt, model="gemini-2.0-flash-exp"):
|
|
| 91 |
|
| 92 |
# Function for the main image processing call
|
| 93 |
def generate(text, file_name, model="gemini-2.0-flash-exp"):
|
| 94 |
-
"""
|
| 95 |
-
Sends the image and prompt to Gemini and processes the streamed response.
|
| 96 |
-
This function is used for the main user request (editing, analysis, etc.).
|
| 97 |
-
"""
|
| 98 |
api_key = os.environ.get("geminigoogle")
|
| 99 |
if not api_key:
|
| 100 |
raise ValueError("GEMINI_API_KEY environment variable (geminigoogle) not set.")
|
| 101 |
|
| 102 |
client = genai.Client(api_key=api_key)
|
| 103 |
-
uploaded_files = []
|
| 104 |
-
temp_output_image_path = None
|
| 105 |
|
| 106 |
try:
|
| 107 |
-
# Upload the file for the main generation call
|
| 108 |
uploaded_files = [client.files.upload(file=file_name)]
|
| 109 |
print(f"Uploaded file for generation: {uploaded_files[0].uri}")
|
| 110 |
|
|
@@ -134,8 +127,6 @@ def generate(text, file_name, model="gemini-2.0-flash-exp"):
|
|
| 134 |
text_response = ""
|
| 135 |
image_path = None
|
| 136 |
|
| 137 |
-
# Use NamedTemporaryFile with delete=False because we need to return the path
|
| 138 |
-
# We will handle deletion explicitly later.
|
| 139 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
| 140 |
temp_output_image_path = tmp.name
|
| 141 |
|
|
@@ -149,44 +140,32 @@ def generate(text, file_name, model="gemini-2.0-flash-exp"):
|
|
| 149 |
if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
|
| 150 |
continue
|
| 151 |
|
| 152 |
-
# Process each part in the chunk
|
| 153 |
for part in chunk.candidates[0].content.parts:
|
| 154 |
# Check for text parts
|
| 155 |
text_part = getattr(part, "text", "")
|
| 156 |
if text_part:
|
| 157 |
text_response += text_part
|
| 158 |
|
| 159 |
-
# Check for inline image data
|
| 160 |
if part.inline_data:
|
| 161 |
print(f"Received image data with mime type {part.inline_data.mime_type}. Saving to {temp_output_image_path}")
|
| 162 |
save_binary_file(temp_output_image_path, part.inline_data.data)
|
| 163 |
image_path = temp_output_image_path # Set the output image path
|
| 164 |
-
# Note: If the model sends multiple images, this will only save the last one received in a part.
|
| 165 |
-
# For typical use cases where one image is expected, this is fine.
|
| 166 |
-
# If multiple images could be in different parts of the *same* chunk,
|
| 167 |
-
# you'd need more complex handling (e.g., saving each to a separate file).
|
| 168 |
-
# If the model sends an image and *then* more text, the loop continues.
|
| 169 |
-
# We set image_path here and let the loop finish collecting text.
|
| 170 |
|
| 171 |
print("Generation stream finished.")
|
| 172 |
-
# The loop finishes after processing all parts of all chunks.
|
| 173 |
|
| 174 |
-
# Check if an image was actually saved, otherwise set image_path to None
|
| 175 |
if not image_path or not os.path.exists(image_path) or os.path.getsize(image_path) == 0:
|
| 176 |
print("No valid image data was received or saved.")
|
| 177 |
-
image_path = None
|
| 178 |
|
| 179 |
-
return image_path, text_response.strip()
|
| 180 |
|
| 181 |
except Exception as e:
|
| 182 |
print(f"Error during main generation API call: {e}")
|
| 183 |
-
# Ensure temporary files created before the error are cleaned up
|
| 184 |
if temp_output_image_path and os.path.exists(temp_output_image_path):
|
| 185 |
os.remove(temp_output_image_path)
|
| 186 |
raise e # Re-raise the exception after cleanup
|
| 187 |
|
| 188 |
finally:
|
| 189 |
-
# Clean up uploaded files from the generation call
|
| 190 |
for file in uploaded_files:
|
| 191 |
try:
|
| 192 |
client.files.delete(name=file.name)
|
|
@@ -202,12 +181,7 @@ def process_image_and_prompt(composite_pil, prompt):
|
|
| 202 |
# 1. Save the input PIL image to a temporary file
|
| 203 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
| 204 |
composite_path = tmp.name
|
| 205 |
-
# Ensure image is saved in a format compatible with Gemini, convert if necessary
|
| 206 |
if composite_pil.mode == "RGBA":
|
| 207 |
-
# Convert RGBA to RGB if necessary, as some models prefer RGB
|
| 208 |
-
# Or handle alpha channel depending on model capabilities.
|
| 209 |
-
# For simplicity here, saving as PNG should preserve alpha,
|
| 210 |
-
# but Gemini might interpret it differently. Let's save as PNG.
|
| 211 |
composite_pil.save(composite_path, format="PNG")
|
| 212 |
else:
|
| 213 |
composite_pil.save(composite_path, format="PNG") # Save as PNG by default
|
|
@@ -221,12 +195,11 @@ def process_image_and_prompt(composite_pil, prompt):
|
|
| 221 |
tag_json_string = get_image_tags(file_name, tagging_prompt, model=model)
|
| 222 |
|
| 223 |
# 3. Call generate for the main image processing based on the user prompt
|
| 224 |
-
# This function returns the path to a generated image (if any) and text response
|
| 225 |
output_image_path, main_text_response = generate(text=prompt, file_name=file_name, model=model)
|
| 226 |
|
| 227 |
# 4. Combine the tag JSON string and the main text response
|
| 228 |
# Format the output clearly
|
| 229 |
-
final_text_output = f"
|
| 230 |
|
| 231 |
# 5. Prepare the image output for the Gradio gallery
|
| 232 |
result_img = None
|
|
@@ -265,9 +238,6 @@ def process_image_and_prompt(composite_pil, prompt):
|
|
| 265 |
except Exception as cleanup_e:
|
| 266 |
print(f"Error deleting input temporary file {composite_path}: {cleanup_e}")
|
| 267 |
|
| 268 |
-
# Clean up the temporary output image file created by generate()
|
| 269 |
-
# Note: generate() might have already deleted the *uploaded* file via API,
|
| 270 |
-
# but this handles the local file saved from inline_data.
|
| 271 |
if output_image_path and os.path.exists(output_image_path):
|
| 272 |
try:
|
| 273 |
os.remove(output_image_path)
|
|
|
|
| 52 |
temperature=0.5, # Lower temperature might give more focused tags
|
| 53 |
top_p=0.95,
|
| 54 |
top_k=40,
|
| 55 |
+
max_output_tokens=1024,
|
| 56 |
+
response_modalities=["text"],
|
| 57 |
response_mime_type="text/plain", # Expect plain text
|
| 58 |
)
|
| 59 |
|
|
|
|
| 77 |
|
| 78 |
except Exception as e:
|
| 79 |
print(f"Error during tagging API call: {e}")
|
|
|
|
| 80 |
return f"Error generating tags: {e}"
|
| 81 |
finally:
|
|
|
|
| 82 |
for file in uploaded_files:
|
| 83 |
try:
|
| 84 |
client.files.delete(name=file.name)
|
|
|
|
| 89 |
|
| 90 |
# Function for the main image processing call
|
| 91 |
def generate(text, file_name, model="gemini-2.0-flash-exp"):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
api_key = os.environ.get("geminigoogle")
|
| 93 |
if not api_key:
|
| 94 |
raise ValueError("GEMINI_API_KEY environment variable (geminigoogle) not set.")
|
| 95 |
|
| 96 |
client = genai.Client(api_key=api_key)
|
| 97 |
+
uploaded_files = []
|
| 98 |
+
temp_output_image_path = None
|
| 99 |
|
| 100 |
try:
|
|
|
|
| 101 |
uploaded_files = [client.files.upload(file=file_name)]
|
| 102 |
print(f"Uploaded file for generation: {uploaded_files[0].uri}")
|
| 103 |
|
|
|
|
| 127 |
text_response = ""
|
| 128 |
image_path = None
|
| 129 |
|
|
|
|
|
|
|
| 130 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
| 131 |
temp_output_image_path = tmp.name
|
| 132 |
|
|
|
|
| 140 |
if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
|
| 141 |
continue
|
| 142 |
|
|
|
|
| 143 |
for part in chunk.candidates[0].content.parts:
|
| 144 |
# Check for text parts
|
| 145 |
text_part = getattr(part, "text", "")
|
| 146 |
if text_part:
|
| 147 |
text_response += text_part
|
| 148 |
|
|
|
|
| 149 |
if part.inline_data:
|
| 150 |
print(f"Received image data with mime type {part.inline_data.mime_type}. Saving to {temp_output_image_path}")
|
| 151 |
save_binary_file(temp_output_image_path, part.inline_data.data)
|
| 152 |
image_path = temp_output_image_path # Set the output image path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
print("Generation stream finished.")
|
|
|
|
| 155 |
|
|
|
|
| 156 |
if not image_path or not os.path.exists(image_path) or os.path.getsize(image_path) == 0:
|
| 157 |
print("No valid image data was received or saved.")
|
| 158 |
+
image_path = None
|
| 159 |
|
| 160 |
+
return image_path, text_response.strip()
|
| 161 |
|
| 162 |
except Exception as e:
|
| 163 |
print(f"Error during main generation API call: {e}")
|
|
|
|
| 164 |
if temp_output_image_path and os.path.exists(temp_output_image_path):
|
| 165 |
os.remove(temp_output_image_path)
|
| 166 |
raise e # Re-raise the exception after cleanup
|
| 167 |
|
| 168 |
finally:
|
|
|
|
| 169 |
for file in uploaded_files:
|
| 170 |
try:
|
| 171 |
client.files.delete(name=file.name)
|
|
|
|
| 181 |
# 1. Save the input PIL image to a temporary file
|
| 182 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
| 183 |
composite_path = tmp.name
|
|
|
|
| 184 |
if composite_pil.mode == "RGBA":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
composite_pil.save(composite_path, format="PNG")
|
| 186 |
else:
|
| 187 |
composite_pil.save(composite_path, format="PNG") # Save as PNG by default
|
|
|
|
| 195 |
tag_json_string = get_image_tags(file_name, tagging_prompt, model=model)
|
| 196 |
|
| 197 |
# 3. Call generate for the main image processing based on the user prompt
|
|
|
|
| 198 |
output_image_path, main_text_response = generate(text=prompt, file_name=file_name, model=model)
|
| 199 |
|
| 200 |
# 4. Combine the tag JSON string and the main text response
|
| 201 |
# Format the output clearly
|
| 202 |
+
final_text_output = f"{tag_json_string},{main_text_response}"
|
| 203 |
|
| 204 |
# 5. Prepare the image output for the Gradio gallery
|
| 205 |
result_img = None
|
|
|
|
| 238 |
except Exception as cleanup_e:
|
| 239 |
print(f"Error deleting input temporary file {composite_path}: {cleanup_e}")
|
| 240 |
|
|
|
|
|
|
|
|
|
|
| 241 |
if output_image_path and os.path.exists(output_image_path):
|
| 242 |
try:
|
| 243 |
os.remove(output_image_path)
|