Spaces:
Running
Running
| from google.genai import types | |
| import os | |
| from PIL import Image | |
| from io import BytesIO | |
| from datetime import datetime | |
| import logging | |
| import asyncio | |
| import gradio as gr | |
| from config import settings | |
| from services.google import GoogleClientFactory | |
| from agent.utils import with_retries | |
logger = logging.getLogger(__name__)

# Safety configuration attached to every Gemini request in this module: each of
# the four harm categories below gets the "BLOCK_NONE" threshold (block none).
# The request functions still handle responses that arrive without an image.
safety_settings = [
    types.SafetySetting(category=harm_category, threshold="BLOCK_NONE")
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
async def generate_image(prompt: str) -> tuple[str, str] | tuple[None, None]:
    """
    Generate an image with Google's Gemini model and save it to generated/images.

    Args:
        prompt (str): The text prompt to generate the image from.

    Returns:
        tuple[str, str] | tuple[None, None]: ``(filepath, prompt)`` where
        ``filepath`` is the saved PNG, or ``(None, None)`` if the response held
        no image (a Gradio warning is shown) or any error occurred (logged).
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Generating image with prompt: {prompt}")
    try:
        async with GoogleClientFactory.image() as client:
            response = await with_retries(
                lambda: client.models.generate_content(
                    model="gemini-2.0-flash-preview-image-generation",
                    contents=prompt,
                    config=types.GenerateContentConfig(
                        response_modalities=["TEXT", "IMAGE"],
                        safety_settings=safety_settings,
                    ),
                )
            )
            # A blocked/filtered response may carry no candidates, or a candidate
            # whose content/parts are None. Treat all of these as "no image"
            # instead of crashing into the generic error handler below.
            # (getattr also tolerates a None response, in case with_retries
            # can return one — unverified.)
            candidates = getattr(response, "candidates", None) or []
            content = candidates[0].content if candidates else None
            parts = (content.parts if content is not None else None) or []
            for part in parts:
                if part.inline_data is None:
                    continue
                # Microseconds keep near-simultaneous generations from
                # overwriting each other's files.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                filepath = os.path.join(output_dir, f"gemini_{timestamp}.png")
                image = Image.open(BytesIO(part.inline_data.data))
                # Decoding/encoding and the disk write are blocking; run them
                # off the event loop.
                await asyncio.to_thread(image.save, filepath, "PNG")
                logger.info(f"Image saved to: {filepath}")
                return filepath, prompt
        # Reached only when no part contained inline image data.
        gr.Warning("Image was censored by Google!")
        logger.error("No image was generated in the response.")
        return None, None
    except Exception as e:
        # Best-effort by design: report and signal failure, but keep the traceback.
        logger.exception(f"Error generating image: {e}")
        return None, None
async def modify_image(
    image_path: str, modification_prompt: str
) -> tuple[str, str] | tuple[None, None]:
    """
    Modify an existing image with Google's Gemini model based on a text prompt.

    Args:
        image_path (str): Path to the existing image file.
        modification_prompt (str): The text prompt describing how to modify the image.

    Returns:
        tuple[str, str] | tuple[None, None]: ``(filepath, modification_prompt)``
        where ``filepath`` is the saved PNG, or ``(None, None)`` if the input file
        is missing, the response held no image (a Gradio warning is shown), or
        any error occurred (logged). All failure paths use the same shape, so
        callers can always unpack the result.
    """
    # Ensure the generated/images directory exists
    output_dir = "generated/images"
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Modifying current scene image with prompt: {modification_prompt}")
    # Validate the input path up front. A None/empty path (e.g. forwarded from a
    # failed generate_image call) would otherwise raise TypeError from os.path.
    if not image_path or not os.path.isfile(image_path):
        logger.error(f"Error: Image file not found at {image_path}")
        return None, None
    try:
        async with GoogleClientFactory.image() as client:
            # Image.open is lazy and keeps the file open. The with-block
            # guarantees the handle is closed once the request has been made,
            # and load() reads the pixels off the event loop.
            with Image.open(image_path) as input_image:
                await asyncio.to_thread(input_image.load)
                # Make the API call with both text and image
                response = await with_retries(
                    lambda: client.models.generate_content(
                        model="gemini-2.0-flash-preview-image-generation",
                        contents=[modification_prompt, input_image],
                        config=types.GenerateContentConfig(
                            response_modalities=["TEXT", "IMAGE"],
                            safety_settings=safety_settings,
                        ),
                    ),
                )
            # A blocked/filtered response may carry no candidates, or a candidate
            # whose content/parts are None. Treat all of these as "no image".
            candidates = getattr(response, "candidates", None) or []
            content = candidates[0].content if candidates else None
            parts = (content.parts if content is not None else None) or []
            for part in parts:
                if part.inline_data is None:
                    continue
                # Microseconds keep near-simultaneous edits from overwriting
                # each other's files.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                filepath = os.path.join(output_dir, f"gemini_modified_{timestamp}.png")
                modified_image = Image.open(BytesIO(part.inline_data.data))
                await asyncio.to_thread(modified_image.save, filepath, "PNG")
                logger.info(f"Modified image saved to: {filepath}")
                return filepath, modification_prompt
        # Reached only when no part contained inline image data.
        gr.Warning("Updated image was censored by Google!")
        logger.error("No modified image was generated in the response.")
        return None, None
    except Exception as e:
        # Best-effort by design: report and signal failure, but keep the traceback.
        logger.exception(f"Error modifying image: {e}")
        return None, None
if __name__ == "__main__":
    # Example usage. generate_image/modify_image are coroutines; calling them
    # without an event loop only creates a never-awaited coroutine object, so
    # drive them with asyncio.run.
    # Make the module's logger.info output visible when run as a script.
    logging.basicConfig(level=logging.INFO)

    async def _demo() -> None:
        sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
        generated_image_path, _ = await generate_image(sample_prompt)
        if generated_image_path:
            print(f"Successfully generated image: {generated_image_path}")
            # Example modification (uncomment to try):
            # modification_prompt = "Now the house is destroyed, and the jawas are running away"
            # modified_image_path, _ = await modify_image(generated_image_path, modification_prompt)
            # if modified_image_path:
            #     print(f"Successfully modified image: {modified_image_path}")

    asyncio.run(_demo())