Spaces:
Runtime error
Runtime error
| import base64 | |
| import os | |
| import uuid | |
| import time | |
| import logging | |
| import google.genai as genai | |
| from google.genai import types | |
| import gradio as gr | |
| from PIL import Image | |
| import io | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file into os.environ when such a file exists.
load_dotenv()

# Log at DEBUG and above, both to gemini_debug.log and to the console.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('gemini_debug.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Fail at import time when the key is missing or empty: generate() needs it
# to construct the Gemini client.
if not os.getenv("GEMINI_API_KEY"):
    raise ValueError("GEMINI_API_KEY environment variable is not set")
def save_binary_file(data, mime_type):
    """Write raw bytes to a uniquely named file in this script's directory.

    The file is named ``output_<YYYYmmdd_HHMMSS>_<8 hex chars>.<ext>``, where
    ``<ext>`` is everything after the last ``/`` in *mime_type* (``png`` for
    ``image/png``).

    Args:
        data: Bytes to write.
        mime_type: MIME type string used to derive the file extension.

    Returns:
        Path of the written file.
    """
    extension = mime_type.rsplit('/', 1)[-1]
    stamp = time.strftime('%Y%m%d_%H%M%S')
    token = uuid.uuid4().hex[:8]  # short random suffix avoids same-second collisions
    target = os.path.join(os.path.dirname(__file__), f"output_{stamp}_{token}.{extension}")
    with open(target, "wb") as handle:
        handle.write(data)
    return target
def optimize_image(image, max_size=1024, quality=85):
    """Flatten, downscale and JPEG-compress a PIL image.

    Steps:
    1. Images with transparency (RGBA, LA, or palette images carrying a
       ``transparency`` entry) are composited onto a white background using
       their alpha channel; any other non-RGB mode is converted to RGB.
    2. If either side exceeds *max_size*, the image is resized with LANCZOS so
       that its longer side equals *max_size*, preserving the aspect ratio.
       Each side is kept at least 1 px so extreme aspect ratios do not make
       ``resize`` fail.
    3. The result is JPEG-encoded in memory at *quality* and decoded again.

    Args:
        image: Source PIL image.
        max_size: Maximum allowed width/height in pixels.
        quality: JPEG quality for the in-memory encode.

    Returns:
        A fully loaded PIL Image decoded from the JPEG bytes.
    """
    logger.debug(f"Optimizing image. Original size: {image.size}, mode: {image.mode}")

    has_alpha = image.mode in ('RGBA', 'LA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if has_alpha:
        # Normalise every transparent mode (LA included) to RGBA so its alpha
        # channel can serve as the paste mask over the white background.
        rgba = image.convert('RGBA')
        background = Image.new('RGB', rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.getchannel('A'))
        image = background
    elif image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize if the image is too large, keeping both sides >= 1 px.
    width, height = image.size
    if width > max_size or height > max_size:
        if width > height:
            new_width = max_size
            new_height = max(1, int(max_size * height / width))
        else:
            new_height = max_size
            new_width = max(1, int(max_size * width / height))
        image = image.resize((new_width, new_height), Image.LANCZOS)
        logger.debug(f"Resized image to: {new_width}x{new_height}")

    # Compress to JPEG in memory; tell() after save is the encoded size.
    output_buffer = io.BytesIO()
    image.save(output_buffer, format='JPEG', quality=quality, optimize=True)
    compressed_size = output_buffer.tell()
    logger.debug(f"Optimized image size: {compressed_size / 1024:.1f} KB")

    # Decode the compressed bytes back into an image; load() forces decoding
    # now instead of lazily on first pixel access.
    output_buffer.seek(0)
    optimized_image = Image.open(output_buffer)
    optimized_image.load()
    return optimized_image
def save_temp_image(image):
    """Optimize *image* and write it to disk as a JPEG for upload.

    The image goes through ``optimize_image`` and is then saved in this
    script's directory as ``temp_input_<8 hex chars>.jpg`` (JPEG, quality 90).
    This function does not delete the file.

    Args:
        image: Source PIL image.

    Returns:
        Path of the written JPEG file.
    """
    prepared = optimize_image(image)
    destination = os.path.join(
        os.path.dirname(__file__), f"temp_input_{uuid.uuid4().hex[:8]}.jpg"
    )
    prepared.save(destination, format="JPEG", quality=90, optimize=True)
    logger.debug(f"Saved optimized image to {destination}")

    size_kb = os.path.getsize(destination) / 1024  # bytes -> KB
    logger.debug(f"File size: {size_kb:.1f} KB")
    return destination
def debug_save_failed_data(data, prefix="failed"):
    """Dump raw bytes to ``<prefix>_<8 hex chars>.bin`` in this script's directory.

    Args:
        data: Bytes to write.
        prefix: File-name prefix describing why the dump was made.

    Returns:
        Path of the dump file.
    """
    dump_name = f"{prefix}_{uuid.uuid4().hex[:8]}.bin"
    dump_path = os.path.join(os.path.dirname(__file__), dump_name)
    with open(dump_path, "wb") as sink:
        sink.write(data)
    logger.debug(f"Saved problematic data to {dump_path}")
    return dump_path
def is_base64_encoded(data):
    """Heuristically tell whether *data* is base64 text of a known file format.

    Only the first 20 characters are inspected, and only for the base64
    signatures of PNG (``iVBOR``), GIF (``R0lGOD``), JPEG (``/9j/``),
    WebP/RIFF (``UklGR``), XML (``PD94``, i.e. ``<?x``) and SVG (``PHN2``,
    i.e. ``<sv``). Raw (non-base64) image bytes do not match these prefixes.

    Args:
        data: A ``str`` or a bytes-like object (``bytes``, ``bytearray``,
            ``memoryview``).

    Returns:
        True if the leading characters match one of the signatures.
    """
    base64_prefixes = ('iVBOR', 'R0lGOD', '/9j/', 'UklGR', 'PD94', 'PHN2')
    if isinstance(data, (bytes, bytearray, memoryview)):
        # Decode only a short sample; undecodable bytes (raw binary) are
        # dropped, so raw image data will not start with a signature.
        sample = bytes(data[:20]).decode('utf-8', errors='ignore')
    else:
        sample = data[:20]
    return sample.startswith(base64_prefixes)
def _image_from_inline_data(inline_data):
    """Decode one inline-data blob from a Gemini response into a PIL image.

    Base64 payloads (a ``str``, or ``bytes`` that look like base64 text
    according to ``is_base64_encoded``) are decoded first; other bytes are
    treated as raw image data.

    Args:
        inline_data: The ``inline_data`` of a response part; its ``data`` and
            ``mime_type`` attributes are read.

    Returns:
        A fully loaded PIL Image (RGBA/LA converted to RGB), or None when the
        payload cannot be decoded. Undecodable payloads are dumped with
        ``debug_save_failed_data`` for later inspection.
    """
    data = inline_data.data
    mime_type = inline_data.mime_type
    logger.debug(f"Received data type: {type(data)}")
    logger.debug(f"Data length: {len(data) if isinstance(data, (str, bytes)) else 'N/A'}")
    logger.debug(f"MIME type: {mime_type}")
    if isinstance(data, bytes):
        # Peek at the payload to help identify its format in the logs.
        logger.debug(f"First 100 chars as string: {data[:100].decode('utf-8', errors='ignore')}")
        logger.debug(f"First 16 bytes: {data[:16].hex()}")

    if isinstance(data, bytes) and is_base64_encoded(data):
        logger.debug("Data appears to be base64 encoded but returned as bytes")
        try:
            data = base64.b64decode(data.decode('utf-8', errors='ignore'))
            logger.debug("Successfully decoded base64 from bytes->string->binary")
        except ValueError as e:  # binascii.Error subclasses ValueError
            # Fall through and try the bytes as a raw image.
            logger.error(f"Base64 decoding failed after bytes->string: {str(e)}")
    elif isinstance(data, str):
        logger.debug("Decoding base64 string data")
        try:
            data = base64.b64decode(data)
            logger.debug(f"Decoded data length: {len(data)}")
        except ValueError as e:
            logger.error(f"Base64 decoding failed: {str(e)}")
            debug_save_failed_data(data.encode(), "base64_failed")
            return None

    try:
        # Image.open already probes every registered format, so retrying with
        # a narrower format list or via a temporary file on the same bytes
        # cannot succeed where this attempt fails.
        output_image = Image.open(io.BytesIO(data))
        output_image.load()
    except (OSError, ValueError) as e:
        logger.error(f"Could not load image data: {str(e)}")
        # Only dump the payload when it failed to load, so successful
        # requests do not accumulate files on disk.
        debug_save_failed_data(data, "image_load_failed")
        return None

    logger.info(f"Successfully opened image: format={output_image.format}, size={output_image.size}, mode={output_image.mode}")
    if output_image.mode in ('RGBA', 'LA'):
        output_image = output_image.convert('RGB')
        logger.debug("Converted image to RGB mode")
    return output_image


def generate(input_image, prompt_text):
    """Ask Gemini to edit *input_image* according to *prompt_text*.

    The image is optimized and written to a temporary JPEG by
    ``save_temp_image``, uploaded with the Gemini Files API, and sent to
    ``gemini-2.0-flash-exp`` together with the editing instructions,
    requesting image and text output. The response stream is scanned for the
    first part carrying inline data, which is decoded and returned.

    The local temporary file and the uploaded remote file are both removed
    before returning; failure to delete the remote file is only logged.

    Args:
        input_image: PIL image from the Gradio upload widget, or None.
        prompt_text: Free-form editing instructions.

    Returns:
        The edited image as a loaded PIL Image, or None if the response held
        no usable image or decoding failed.

    Raises:
        gr.Error: If no input image was provided.
    """
    logger.info(f"Starting generate function with prompt: {prompt_text}")
    if input_image is None:
        # Gradio passes None when nothing was uploaded; show a message in the UI.
        raise gr.Error("Please upload an image to edit.")

    client = genai.Client(
        api_key=os.environ.get("GEMINI_API_KEY"),
    )
    model = "gemini-2.0-flash-exp"

    # save_temp_image already runs optimize_image; optimizing here as well
    # would add another lossy JPEG re-encode for no benefit.
    temp_image_path = save_temp_image(input_image)
    uploaded_file = None
    try:
        uploaded_file = client.files.upload(file=temp_image_path)
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=uploaded_file.uri,
                        mime_type=uploaded_file.mime_type,
                    ),
                    types.Part.from_text(text=f"Edit this image according to these instructions: {prompt_text}"),
                ],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            response_modalities=[
                "image",
                "text",
            ],
            safety_settings=[
                types.SafetySetting(
                    category="HARM_CATEGORY_CIVIC_INTEGRITY",
                    threshold="OFF",
                ),
            ],
            response_mime_type="text/plain",
        )

        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
                logger.debug("Received empty chunk or missing components")
                continue

            parts = chunk.candidates[0].content.parts
            for part in parts:
                if part.text:
                    logger.debug(f"Model text: {part.text}")

            # The image may follow a text part in the same chunk, so search
            # every part rather than only the first one.
            image_part = next((part for part in parts if part.inline_data), None)
            if image_part is None:
                logger.debug("Chunk contained no inline data")
                continue

            try:
                return _image_from_inline_data(image_part.inline_data)
            except Exception as e:
                logger.error(f"Error processing image data: {str(e)}", exc_info=True)
                return None

        logger.warning("No valid image data found in response")
        return None
    finally:
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
            logger.debug(f"Cleaned up temporary file: {temp_image_path}")
        if uploaded_file is not None:
            # Best-effort remote cleanup; never mask the real result.
            try:
                client.files.delete(name=uploaded_file.name)
            except Exception as e:
                logger.warning(f"Could not delete uploaded file {uploaded_file.name}: {e}")
def create_interface():
    """Build the Gradio Blocks UI for the image editor.

    The left column holds the image upload (delivered as a PIL image), the
    instruction textbox and the submit button; the right column shows the
    edited image. Clicking the button calls ``generate``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Gemini Image Editor") as app:
        gr.Markdown("# Gemini Image Editor")
        gr.Markdown("Upload an image and provide instructions for how you want Gemini to edit it.")

        with gr.Row():
            # Inputs: source image, instructions, trigger.
            with gr.Column():
                source_image = gr.Image(label="Upload Image", type="pil")
                instructions = gr.Textbox(
                    label="Editing Instructions",
                    placeholder="Describe how you want the image to be edited...",
                )
                edit_button = gr.Button("Generate Edited Image")
            # Output: the single edited image returned by generate().
            with gr.Column():
                result_image = gr.Image(label="Edited Image")

        edit_button.click(
            fn=generate,
            inputs=[source_image, instructions],
            outputs=[result_image],
        )
    return app
def main():
    """Build the Gradio interface and start serving it."""
    create_interface().launch()


if __name__ == "__main__":
    main()