ai-image-edit / app.py
EDM25's picture
Update app.py
5319dd9 verified
import base64
import os
import uuid
import time
import logging
import google.genai as genai
from google.genai import types
import gradio as gr
from PIL import Image
import io
from dotenv import load_dotenv
# Pick up settings from a local .env file, if there is one.
load_dotenv()

# Emit DEBUG-and-above records to both gemini_debug.log and a stream handler.
logging.basicConfig(
    handlers=[
        logging.FileHandler('gemini_debug.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Fail fast at import time when the Gemini API key is missing or empty.
if not os.getenv("GEMINI_API_KEY"):
    raise ValueError("GEMINI_API_KEY environment variable is not set")
def save_binary_file(data, mime_type):
    """Write *data* to a uniquely named file next to this module.

    The file name combines a timestamp with a short random UUID fragment; the
    extension is the MIME subtype (``image/png`` -> ``.png``).

    Args:
        data: Bytes to write.
        mime_type: MIME type of *data*, e.g. ``"image/png"``. Optional
            parameters after ``;`` (``"image/png; charset=binary"``) are
            ignored when deriving the extension.

    Returns:
        The path of the file that was written.
    """
    # Drop MIME parameters and surrounding whitespace so they cannot leak
    # into the file name.
    essence = mime_type.split(';', 1)[0].strip()
    file_extension = essence.split('/')[-1]

    stamp = time.strftime('%Y%m%d_%H%M%S')
    file_name = f"output_{stamp}_{uuid.uuid4().hex[:8]}.{file_extension}"
    file_path = os.path.join(os.path.dirname(__file__), file_name)
    with open(file_path, "wb") as f:
        f.write(data)
    return file_path
def optimize_image(image, max_size=1024, quality=85):
    """Normalize an image to a size-bounded, JPEG-compressed image.

    Steps:
    1. Flatten any transparency onto a white background and convert to RGB.
    2. Downscale so neither side exceeds ``max_size`` (aspect ratio kept).
    3. Round-trip the result through an in-memory JPEG at ``quality``.

    Args:
        image: Source PIL image.
        max_size: Maximum allowed width/height in pixels.
        quality: JPEG quality passed to ``Image.save``.

    Returns:
        A fully loaded PIL image decoded from the compressed JPEG buffer.
    """
    logger.debug(f"Optimizing image. Original size: {image.size}, mode: {image.mode}")

    has_alpha = image.mode in ('RGBA', 'LA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if has_alpha:
        # Route every alpha-carrying mode through RGBA so there is always an
        # alpha band to blend with. Pasting an 'LA' image without a mask
        # copies the grey values and silently discards the transparency.
        rgba = image.convert('RGBA')
        background = Image.new('RGB', rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.split()[3])
        image = background
    elif image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize if the image is too large
    width, height = image.size
    if width > max_size or height > max_size:
        # Clamp the shorter side to 1px: a very elongated image would
        # otherwise round down to 0, which resize() rejects.
        if width > height:
            new_width = max_size
            new_height = max(1, int(max_size * height / width))
        else:
            new_height = max_size
            new_width = max(1, int(max_size * width / height))
        image = image.resize((new_width, new_height), Image.LANCZOS)
        logger.debug(f"Resized image to: {new_width}x{new_height}")

    # Compress into memory, then decode again so callers get the JPEG result.
    output_buffer = io.BytesIO()
    image.save(output_buffer, format='JPEG', quality=quality, optimize=True)
    compressed_size = output_buffer.tell()
    logger.debug(f"Optimized image size: {compressed_size / 1024:.1f} KB")

    output_buffer.seek(0)
    optimized_image = Image.open(output_buffer)
    # Force decoding now so the image does not depend on the buffer later.
    optimized_image.load()
    return optimized_image
def save_temp_image(image):
    """Write an optimized JPEG copy of *image* beside this module for upload.

    Returns the path of the temporary file.
    """
    module_dir = os.path.dirname(__file__)
    unique_tag = uuid.uuid4().hex[:8]
    temp_path = os.path.join(module_dir, f"temp_input_{unique_tag}.jpg")

    # Run the image through optimize_image() before writing it out.
    optimize_image(image).save(temp_path, format="JPEG", quality=90, optimize=True)
    logger.debug(f"Saved optimized image to {temp_path}")

    size_kb = os.path.getsize(temp_path) / 1024
    logger.debug(f"File size: {size_kb:.1f} KB")
    return temp_path
def debug_save_failed_data(data, prefix="failed"):
    """Dump *data* to a ``<prefix>_<random>.bin`` file beside this module.

    Intended for inspecting payloads that could not be processed.
    Returns the path of the file written.
    """
    dump_name = f"{prefix}_{uuid.uuid4().hex[:8]}.bin"
    debug_path = os.path.join(os.path.dirname(__file__), dump_name)
    with open(debug_path, "wb") as dump_file:
        dump_file.write(data)
    logger.debug(f"Saved problematic data to {debug_path}")
    return debug_path
def is_base64_encoded(data):
    """Heuristically detect base64-encoded image data from its first characters.

    Only the first 20 bytes/characters are inspected and compared against the
    base64 spellings of well-known file signatures, so this is a cheap guess,
    not a validation.

    Args:
        data: ``bytes``, ``bytearray`` or ``str`` payload to inspect.

    Returns:
        True if the payload starts with a known base64 image prefix.
    """
    # base64 renderings of the magic bytes of each format.
    base64_prefixes = (
        'iVBOR',   # PNG  (\x89PNG)
        'R0lGOD',  # GIF  (GIF8)
        '/9j/',    # JPEG (\xff\xd8\xff)
        'UklGR',   # WebP (RIFF container)
        'PD94',    # XML  (<?x)
        'PHN2',    # SVG  (<sv)
    )
    sample = data[:20]
    if isinstance(sample, (bytes, bytearray)):
        # Undecodable bytes are dropped; raw binary image data then simply
        # fails to match any prefix.
        sample = sample.decode('utf-8', errors='ignore')
    return sample.startswith(base64_prefixes)
def _decode_inline_payload(data):
    """Return the raw image bytes carried by an inline-data payload.

    The payload may be raw bytes, base64 text delivered as bytes, or a base64
    string. Returns None when a base64 string cannot be decoded; any other
    payload is returned as-is (decoded where possible).
    """
    if isinstance(data, bytes):
        # Log a short preview to help identify the format.
        preview = data[:100].decode('utf-8', errors='ignore')
        logger.debug(f"First 100 chars as string: {preview}")
        logger.debug(f"First 16 bytes: {data[:16].hex()}")
        if is_base64_encoded(data):
            logger.debug("Data appears to be base64 encoded but returned as bytes")
            try:
                data = base64.b64decode(data.decode('utf-8', errors='ignore'))
                logger.debug("Successfully decoded base64 from bytes->string->binary")
            except ValueError as e:
                # Keep the undecoded bytes; the image loader gets the final say.
                logger.error(f"Base64 decoding failed after bytes->string: {str(e)}")
        return data

    if isinstance(data, str):
        logger.debug("Decoding base64 string data")
        try:
            decoded = base64.b64decode(data)
        except ValueError as e:
            logger.error(f"Base64 decoding failed: {str(e)}")
            debug_save_failed_data(data.encode(), "base64_failed")
            return None
        logger.debug(f"Decoded data length: {len(decoded)}")
        return decoded

    return data


def _open_image_bytes(data):
    """Open *data* as a fully loaded PIL image, or return None if every attempt fails."""
    image = None
    try:
        image = Image.open(io.BytesIO(data))
    except Exception as e1:
        logger.warning(f"First attempt failed: {str(e1)}")

    if image is None:
        try:
            image = Image.open(io.BytesIO(data), formats=['PNG', 'JPEG', 'WEBP', 'GIF'])
        except Exception as e2:
            logger.error(f"Second attempt failed: {str(e2)}")

    if image is None:
        # Final attempt: write to disk and reopen from the file.
        temp_img_path = os.path.join(
            os.path.dirname(__file__), f"temp_output_{uuid.uuid4().hex[:8]}.png"
        )
        try:
            with open(temp_img_path, 'wb') as f:
                f.write(data)
            with Image.open(temp_img_path) as from_disk:
                # Read the pixels before the file is closed and deleted.
                from_disk.load()
            image = from_disk
        except Exception as e3:
            logger.error(f"All image loading attempts failed: {str(e3)}")
        finally:
            # Remove the temp file on success and on failure alike.
            if os.path.exists(temp_img_path):
                os.remove(temp_img_path)

    if image is None:
        # Keep the payload only when it could not be decoded, so successful
        # requests do not pile up debug dumps on disk.
        debug_save_failed_data(data, "debug_data")
        return None

    logger.info(f"Successfully opened image: format={image.format}, size={image.size}, mode={image.mode}")
    if image.mode in ('RGBA', 'LA'):
        image = image.convert('RGB')
        logger.debug("Converted image to RGB mode")
    image.load()
    return image


def generate(input_image, prompt_text):
    """Ask Gemini to edit *input_image* according to *prompt_text*.

    The image is written to a temporary JPEG, uploaded through the Files API,
    and sent with the instructions to the model as a streaming request. The
    first chunk that carries inline image data determines the result.

    Args:
        input_image: PIL image to edit. ``None`` (e.g. nothing was uploaded)
            yields ``None``.
        prompt_text: Free-form editing instructions.

    Returns:
        The edited PIL image (RGBA/LA results are converted to RGB), or None
        if no image was returned or it could not be decoded.
    """
    logger.info(f"Starting generate function with prompt: {prompt_text}")
    if input_image is None:
        logger.warning("No input image provided")
        return None

    client = genai.Client(
        api_key=os.environ.get("GEMINI_API_KEY"),
    )
    model = "gemini-2.0-flash-exp"
    # save_temp_image() already runs optimize_image(); optimizing here as well
    # would only add another lossy JPEG pass.
    temp_image_path = save_temp_image(input_image)
    try:
        uploaded_file = client.files.upload(file=temp_image_path)
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=uploaded_file.uri,
                        mime_type=uploaded_file.mime_type,
                    ),
                    types.Part.from_text(text=f"Edit this image according to these instructions: {prompt_text}"),
                ],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            response_modalities=[
                "image",
                "text",
            ],
            safety_settings=[
                types.SafetySetting(
                    category="HARM_CATEGORY_CIVIC_INTEGRITY",
                    threshold="OFF",
                ),
            ],
            response_mime_type="text/plain",
        )
        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            candidates = chunk.candidates
            if not candidates or not candidates[0].content or not candidates[0].content.parts:
                logger.debug("Received empty chunk or missing components")
                continue

            # The image is not necessarily the first part of a chunk (a text
            # part may precede it), so scan all parts.
            inline = next(
                (
                    part.inline_data
                    for part in candidates[0].content.parts
                    if part.inline_data and part.inline_data.data
                ),
                None,
            )
            if inline is None:
                logger.debug("Chunk contained no inline data")
                continue

            try:
                data = inline.data
                logger.debug(f"Received data type: {type(data)}")
                logger.debug(f"Data length: {len(data) if isinstance(data, (str, bytes)) else 'N/A'}")
                logger.debug(f"MIME type: {inline.mime_type}")
                payload = _decode_inline_payload(data)
                if payload is None:
                    return None
                return _open_image_bytes(payload)
            except Exception as e:
                # Boundary with the UI: log the failure and report "no image".
                logger.error(f"Error processing image data: {str(e)}", exc_info=True)
                return None

        logger.warning("No valid image data found in response")
        return None
    finally:
        # Always remove the uploaded temp file, whatever the outcome.
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
            logger.debug(f"Cleaned up temporary file: {temp_image_path}")
def create_interface():
    """Assemble the Gradio Blocks UI: image and instructions in, edited image out."""
    with gr.Blocks(title="Gemini Image Editor") as demo:
        gr.Markdown("# Gemini Image Editor")
        gr.Markdown("Upload an image and provide instructions for how you want Gemini to edit it.")

        with gr.Row():
            # First column: the inputs and the trigger button.
            with gr.Column():
                source_image = gr.Image(label="Upload Image", type="pil")
                instructions = gr.Textbox(
                    label="Editing Instructions",
                    placeholder="Describe how you want the image to be edited...",
                )
                run_button = gr.Button("Generate Edited Image")
            # Second column: the result (image only, no text output).
            with gr.Column():
                edited_image = gr.Image(label="Edited Image")

        # generate() receives (image, instructions); its return value fills
        # the result image.
        run_button.click(
            fn=generate,
            inputs=[source_image, instructions],
            outputs=[edited_image],
        )
    return demo
def main():
    """Build the Gradio interface and launch it."""
    create_interface().launch()


if __name__ == "__main__":
    # Launch the app only when this file is executed as a script.
    main()