Spaces:
Sleeping
Sleeping
| # image_generation.py | |
| import os | |
| import mimetypes | |
| import json | |
| import streamlit as st | |
| import io | |
| import time | |
| import traceback | |
| from PIL import Image | |
| from typing import List, Dict, Optional | |
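| # Imports for the 'google-genai' SDK. NOTE(review): that SDK exposes its types as `google.genai.types`, while the `types` import two lines below comes from `google.generativeai` (the older SDK) - confirm which one is intended. | |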
| from google import genai | |
| from google.generativeai import types | |
| from google.api_core import exceptions | |
# --- Client Initialization ---
# Builds the module-level Gemini client from the GEMINI_API_KEY entry of the
# Streamlit secrets (set in the HF Space). When the key is missing, or anything
# goes wrong while creating the client, the problem is printed, shown in the
# Streamlit UI, and the script is stopped.
client = None
try:
    api_key = st.secrets.get("GEMINI_API_KEY")
    if not api_key:
        # No key configured: report on the console and in the UI, then stop.
        print("β FATAL: GEMINI_API_KEY not found in Streamlit secrets.")
        st.error("GEMINI_API_KEY not configured. Please set it in your Hugging Face Space secrets.")
        st.stop()
    else:
        client = genai.Client(api_key=api_key)
        print("β Google AI client for Gemini initialized successfully.")
except Exception as exc:
    # Any failure while reading secrets or constructing the client ends up here.
    print(f"β Error initializing Google AI client: {exc}")
    st.error(f"An unexpected error occurred during client initialization: {exc}")
    st.stop()
| # --- Helper Functions --- | |
def save_binary_file(file_name: str, data: bytes) -> bool:
    """Write ``data`` to ``file_name`` in binary mode, best-effort.

    Errors are reported on stdout rather than raised, so a failed write never
    aborts the caller. The return value lets the caller tell the two outcomes
    apart.

    Args:
        file_name: Destination path; an existing file is overwritten.
        data: Raw bytes to write.

    Returns:
        True if the file was written, False if any error occurred.
    """
    try:
        with open(file_name, "wb") as f:
            f.write(data)
    except Exception as e:
        # Deliberately broad: saving is best-effort and must not crash the
        # caller; the failure is surfaced through the message and the result.
        print(f"β Error saving file {file_name}: {e}")
        return False
    print(f"β Image saved to: {file_name}")
    return True
def pil_image_to_part(image: Image.Image) -> types.Part:
    """Encode a PIL image as JPEG and wrap it in a ``types.Part``.

    Args:
        image: The image to encode. Images in a mode JPEG cannot store
            (for example RGBA or palette images) are converted to RGB
            first; the caller's image object is not modified.

    Returns:
        A ``types.Part`` whose ``inline_data`` is a ``types.Blob`` holding
        the JPEG bytes with MIME type ``image/jpeg``.
    """
    # NOTE(review): `Part(inline_data=Blob(...))` matches google.genai.types;
    # the file imports `types` from google.generativeai - confirm the import.

    # Pillow raises OSError when asked to write modes such as RGBA, LA or P
    # as JPEG, so normalise anything outside the documented JPEG modes.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return types.Part(
        inline_data=types.Blob(mime_type="image/jpeg", data=buffer.getvalue())
    )
def generate_image_with_gemini(
    prompt: str,
    output_file_base: str,
    context_image: Optional[Image.Image] = None
) -> Optional[str]:
    """Generate an image for ``prompt`` with Gemini and save it to disk.

    The request is made of a storyboard-style system prompt, the optional
    ``context_image`` (used as a visual reference), and the scene prompt. The
    response is streamed; every part carrying inline binary data is written to
    ``output_file_base`` plus an extension guessed from its MIME type
    (``.jpg`` when none can be guessed). If several such parts arrive, the
    path of the last one that was written successfully is returned.

    Args:
        prompt: Scene description to illustrate.
        output_file_base: Output path without file extension.
        context_image: Optional previous image passed to the model as a
            reference for visual consistency.

    Returns:
        The path of the saved image, or None when the client is not
        initialized, no image came back, the image could not be written, or
        the API call raised.
    """
    if not client:
        print("β Gemini client not initialized.")
        return None
    print(f"--- π¨ Generating image for prompt: '{prompt[:70]}...' ---")
    # Explicit None test: the argument is optional, not a truthiness flag.
    has_context = context_image is not None
    try:
        model_name = "gemini-2.0-flash-preview-image-generation"
        if has_context:
            system_prompt = """You are a master storyboard artist creating a visual story sequence.
IMPORTANT: You MUST generate an image for every request. Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
- Character appearance and clothing
- Art style and color palette
- Lighting and atmosphere
Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
Generate an image that illustrates the following scene:"""
            print(" -> Using previous image as context.")
        else:
            system_prompt = """You are a master storyboard artist creating the opening scene of a visual story.
IMPORTANT: You MUST generate an image for this request. Create a stunning, cinematic image in an epic fantasy digital painting style with:
- Rich, detailed artwork
- Dramatic lighting and atmosphere
- High-quality digital painting aesthetic
This is the first scene of the story. Generate an image that illustrates:"""

        # Part order matters to the request: instructions, reference image
        # (if any), then the concrete scene to draw.
        # NOTE(review): Part/Content/GenerateContentConfig match
        # google.genai.types; the file imports `types` from
        # google.generativeai - confirm the import.
        content_parts = [types.Part(text=system_prompt)]
        if has_context:
            content_parts.append(pil_image_to_part(context_image))
        image_instruction = f"CREATE AN IMAGE NOW:\n{prompt}\nRemember: You must generate a visual image."
        content_parts.append(types.Part(text=image_instruction))

        contents = [types.Content(role="user", parts=content_parts)]
        generate_content_config = types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        )
        stream = client.models.generate_content_stream(
            model=model_name,
            contents=contents,
            config=generate_content_config,
        )

        saved_file_path = None
        for chunk in stream:
            candidates = chunk.candidates
            if not candidates or not candidates[0].content or not candidates[0].content.parts:
                continue
            for part in candidates[0].content.parts:
                inline_data = part.inline_data
                if not (inline_data and inline_data.data):
                    continue  # text-only part, nothing to save
                # A missing MIME type falls back to ".jpg" instead of raising.
                mime_type = inline_data.mime_type or ""
                file_extension = mimetypes.guess_extension(mime_type) or ".jpg"
                full_file_name = f"{output_file_base}{file_extension}"
                save_result = save_binary_file(full_file_name, inline_data.data)
                # save_binary_file reports errors by printing, not raising, so
                # confirm the write before advertising the path: honour an
                # explicit False result and check the file is really there.
                if save_result is False or not os.path.isfile(full_file_name):
                    continue
                saved_file_path = full_file_name

        if saved_file_path:
            print(f"β Successfully generated and saved image: {saved_file_path}")
        else:
            print("β οΈ No image was returned from the API.")
        return saved_file_path
    except exceptions.InvalidArgument as e:
        # NOTE(review): these are google.api_core exceptions; confirm the
        # client actually raises them. The generic handler below catches
        # everything else either way.
        print(f"β API Invalid Argument Error: {e}")
        traceback.print_exc()
        return None
    except Exception as e:
        print(f"β An unexpected error occurred during the Gemini API call: {e}")
        traceback.print_exc()
        return None
def generate_all_images_from_file(json_path: str, output_dir: str, output_json_path: str):
    """Generate an image for every item of a JSON file and record the paths.

    The JSON document is iterated item by item. Each item with a non-empty
    ``"image_prompt"`` gets an image generated into ``output_dir`` as
    ``image_NNN.<ext>`` (NNN is the item index), using the previously
    generated image as visual context. Each processed item receives an
    ``"image_path"`` key (the saved path, or None), and the updated data is
    written to ``output_json_path``.

    Args:
        json_path: Path of the input JSON file.
        output_dir: Directory for the generated images; created if missing.
        output_json_path: Path the updated JSON is written to.

    Returns:
        None. If the input file is missing or is not valid JSON, an error is
        printed and nothing is written.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            multimedia_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"β Error reading or parsing {json_path}: {e}")
        return

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    previous_image = None
    successful_generations = 0
    total_items = len(multimedia_data)
    for i, item in enumerate(multimedia_data):
        print(f"\n{'='*60}\nProcessing item {i+1}/{total_items}\n{'='*60}")
        image_prompt = item.get("image_prompt")
        if not image_prompt:
            # Nothing to draw; the current context image is kept for the
            # next item that does have a prompt.
            item["image_path"] = None
            continue

        file_base_path = os.path.join(output_dir, f"image_{i:03d}")
        saved_image_path = generate_image_with_gemini(
            image_prompt, file_base_path, context_image=previous_image
        )
        item["image_path"] = saved_image_path

        # A failed generation or an unreadable file resets the context.
        previous_image = None
        if saved_image_path:
            try:
                # Image.open is lazy and keeps the file open; copying inside
                # the `with` decodes the pixels and releases the handle.
                with Image.open(saved_image_path) as opened:
                    previous_image = opened.copy()
                successful_generations += 1
            except Exception as e:
                print(f"Warning: could not reload {saved_image_path} as context image: {e}")

        # Pause between API calls.
        time.sleep(2)

    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(multimedia_data, f, indent=2, ensure_ascii=False)
    print(f"\n--- β Finished. Generated {successful_generations}/{total_items} images. ---")