# image_generation.py
"""Generate a sequence of story images with the Gemini image-generation model.

The module reads a JSON list of items, generates one image per item that has
an ``image_prompt``, feeds each successfully generated image back in as the
visual reference for the next one, and writes the list (now carrying an
``image_path`` per item) to an output JSON file.
"""
import io
import json
import mimetypes
import os
import time
import traceback
from typing import Dict, List, Optional

import streamlit as st
from PIL import Image

# CORRECT IMPORTS FOR THE 'google-genai' SDK
# `genai.Client` and the `types.Part` / `types.Blob` / `types.Content` /
# `types.GenerateContentConfig` classes used below all belong to the
# `google-genai` package, so `types` must come from `google.genai`, not from
# the legacy `google.generativeai` package.
from google import genai
from google.genai import types
from google.api_core import exceptions

# Model used for every image request.
_MODEL_NAME = "gemini-2.0-flash-preview-image-generation"

# Pause after each generation request, in seconds.
_REQUEST_DELAY_SECONDS = 2

# Pillow can encode these modes as JPEG directly; anything else (RGBA, P, LA,
# ...) has to be converted before saving.
_JPEG_SAFE_MODES = ("RGB", "L", "CMYK")

# Instruction sent ahead of the scene prompt when a reference image is given.
_CONTEXT_SYSTEM_PROMPT = (
    "You are a master storyboard artist creating a visual story sequence. "
    "IMPORTANT: You MUST generate an image for every request. "
    "Create a visually consistent image that follows the art style and "
    "character design of the provided reference image. "
    "Maintain consistency in: "
    "- Character appearance and clothing "
    "- Art style and color palette "
    "- Lighting and atmosphere "
    "Style: Cinematic, epic fantasy digital painting with rich details and "
    "dramatic lighting. "
    "Generate an image that illustrates the following scene:"
)

# Instruction sent ahead of the scene prompt when there is no reference image.
_OPENING_SYSTEM_PROMPT = (
    "You are a master storyboard artist creating the opening scene of a "
    "visual story. "
    "IMPORTANT: You MUST generate an image for this request. "
    "Create a stunning, cinematic image in an epic fantasy digital painting "
    "style with: "
    "- Rich, detailed artwork "
    "- Dramatic lighting and atmosphere "
    "- High-quality digital painting aesthetic "
    "This is the first scene of the story. "
    "Generate an image that illustrates:"
)

# --- Client Initialization ---
# This section initializes the client using the secrets from your HF Space.
client = None
try:
    api_key = st.secrets.get("GEMINI_API_KEY")
    if api_key:
        client = genai.Client(api_key=api_key)
        print("✅ Google AI client for Gemini initialized successfully.")
    else:
        print("❌ FATAL: GEMINI_API_KEY not found in Streamlit secrets.")
        st.error("GEMINI_API_KEY not configured. Please set it in your Hugging Face Space secrets.")
        st.stop()
except Exception as e:
    print(f"❌ Error initializing Google AI client: {e}")
    st.error(f"An unexpected error occurred during client initialization: {e}")
    st.stop()


# --- Helper Functions ---
def save_binary_file(file_name: str, data: bytes) -> bool:
    """Write ``data`` to ``file_name`` in binary mode.

    Errors are reported on stdout rather than raised, so a failed write never
    aborts the caller.

    Args:
        file_name: Destination path.
        data: Raw bytes to write.

    Returns:
        True if the file was written, False if writing raised an exception.
    """
    try:
        with open(file_name, "wb") as f:
            f.write(data)
    except Exception as e:
        print(f"❌ Error saving file {file_name}: {e}")
        return False
    print(f"✅ Image saved to: {file_name}")
    return True


def pil_image_to_part(image: Image.Image) -> types.Part:
    """Encode a PIL image as JPEG and wrap it in a ``types.Part``.

    Args:
        image: The image to send to the model.

    Returns:
        A ``types.Part`` whose inline data holds the JPEG bytes with MIME
        type ``image/jpeg``.
    """
    # JPEG cannot store an alpha channel or a palette; without this conversion
    # Pillow raises OSError for images in modes such as RGBA or P.
    if image.mode not in _JPEG_SAFE_MODES:
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return types.Part(
        inline_data=types.Blob(mime_type="image/jpeg", data=buffer.getvalue())
    )


def generate_image_with_gemini(
    prompt: str,
    output_file_base: str,
    context_image: Optional[Image.Image] = None,
) -> Optional[str]:
    """Generate one image for ``prompt`` and save it to disk.

    The request consists of a system-style instruction, the optional
    reference image, and the scene prompt, all sent as a single user turn.
    Image data found in the streamed response is written to
    ``output_file_base`` plus an extension guessed from the returned MIME
    type (``.jpg`` when it cannot be guessed). If several image parts arrive
    with the same extension, later ones overwrite earlier ones.

    Args:
        prompt: Scene description to illustrate.
        output_file_base: Output path without a file extension.
        context_image: Previous image used as a style/character reference,
            or None for the opening scene.

    Returns:
        The path of the saved image, or None if the client is not
        initialized, the API call failed, no image was returned, or the image
        could not be written to disk.
    """
    if not client:
        print("❌ Gemini client not initialized.")
        return None

    print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")

    try:
        if context_image:
            system_prompt = _CONTEXT_SYSTEM_PROMPT
            print("   -> Using previous image as context.")
        else:
            system_prompt = _OPENING_SYSTEM_PROMPT

        content_parts = [types.Part(text=system_prompt)]
        if context_image:
            content_parts.append(pil_image_to_part(context_image))

        image_instruction = f"CREATE AN IMAGE NOW:\n{prompt}\nRemember: You must generate a visual image."
        content_parts.append(types.Part(text=image_instruction))

        contents = [types.Content(role="user", parts=content_parts)]
        generate_content_config = types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        )

        stream = client.models.generate_content_stream(
            model=_MODEL_NAME,
            contents=contents,
            config=generate_content_config,
        )

        saved_file_path = None
        for chunk in stream:
            candidates = chunk.candidates
            if not candidates or not candidates[0].content or not candidates[0].content.parts:
                continue
            for part in candidates[0].content.parts:
                inline_data = part.inline_data
                if not (inline_data and inline_data.data):
                    continue
                # guess_extension() raises on None, so fall back to "" and let
                # the ".jpg" default apply when the MIME type is missing.
                file_extension = mimetypes.guess_extension(inline_data.mime_type or "") or ".jpg"
                full_file_name = f"{output_file_base}{file_extension}"
                # Only report a path that actually exists on disk.
                if save_binary_file(full_file_name, inline_data.data):
                    saved_file_path = full_file_name

        if saved_file_path:
            print(f"✅ Successfully generated and saved image: {saved_file_path}")
        else:
            print("⚠️ No image was returned from the API.")
        return saved_file_path

    # NOTE(review): this handler targets google.api_core errors; confirm the
    # google-genai client raises them. If it does not, the generic handler
    # below still catches the failure with the same outcome (return None).
    except exceptions.InvalidArgument as e:
        print(f"❌ API Invalid Argument Error: {e}")
        traceback.print_exc()
        return None
    except Exception as e:
        print(f"❌ An unexpected error occurred during the Gemini API call: {e}")
        traceback.print_exc()
        return None


def generate_all_images_from_file(json_path: str, output_dir: str, output_json_path: str):
    """Generate an image for every item of a JSON file and save the results.

    Each item with a non-empty ``image_prompt`` is sent to
    ``generate_image_with_gemini``; the resulting path (or None) is stored
    under the item's ``image_path`` key. A successfully generated image is
    passed as the reference for the next item, and the reference is dropped
    after a failure. Items without a prompt get ``image_path`` set to None
    and leave the current reference untouched.

    Args:
        json_path: Path of the input JSON file (iterated as a sequence of
            dict-like items).
        output_dir: Directory that receives the ``image_NNN`` files; created
            if it does not exist.
        output_json_path: Path the updated item list is written to.

    Returns:
        None. If the input file is missing or is not valid JSON, an error is
        printed and nothing is generated or written.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            multimedia_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"❌ Error reading or parsing {json_path}: {e}")
        return

    os.makedirs(output_dir, exist_ok=True)

    previous_image = None
    successful_generations = 0

    for i, item in enumerate(multimedia_data):
        print(f"\n{'='*60}\nProcessing item {i+1}/{len(multimedia_data)}\n{'='*60}")

        image_prompt = item.get("image_prompt")
        if not image_prompt:
            item["image_path"] = None
            continue

        file_base_path = os.path.join(output_dir, f"image_{i:03d}")
        saved_image_path = generate_image_with_gemini(
            image_prompt, file_base_path, context_image=previous_image
        )
        item["image_path"] = saved_image_path

        previous_image = None
        if saved_image_path:
            try:
                # Image.open() is lazy and keeps the file open; copy the
                # pixels into memory and close the handle right away.
                with Image.open(saved_image_path) as loaded:
                    previous_image = loaded.copy()
                successful_generations += 1
            except Exception as e:
                print(f"❌ Could not reload {saved_image_path} as context image: {e}")
                previous_image = None

        time.sleep(_REQUEST_DELAY_SECONDS)

    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(multimedia_data, f, indent=2, ensure_ascii=False)

    print(f"\n--- ✅ Finished. Generated {successful_generations}/{len(multimedia_data)} images. ---")