# SparrowTale / image_generation.py
# abhinav0231 - "Update image_generation.py"
# Commit 0432337 (verified)
# image_generation.py
import os
import mimetypes
import json
import streamlit as st
import io
import time
import traceback
from PIL import Image
from typing import List, Dict, Optional
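# Imports for the Gemini client ('google-genai' SDK).
# NOTE(review): `genai.Client` belongs to the google-genai SDK, but `types` below is
# imported from the legacy `google.generativeai` package - confirm whether this
# should be `from google.genai import types`.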
from google import genai
from google.generativeai import types
from google.api_core import exceptions
# --- Client Initialization ---
# Build the shared Gemini client from the GEMINI_API_KEY entry of the Streamlit
# secrets (configured in the Hugging Face Space settings).
# If the key is missing, or anything goes wrong while creating the client, the
# problem is logged to stdout, shown in the Streamlit UI, and the script run is
# halted via st.stop().
client = None
try:
    api_key = st.secrets.get("GEMINI_API_KEY")
    if api_key:
        client = genai.Client(api_key=api_key)
        print("✅ Google AI client for Gemini initialized successfully.")
    else:
        print("❌ FATAL: GEMINI_API_KEY not found in Streamlit secrets.")
        st.error("GEMINI_API_KEY not configured. Please set it in your Hugging Face Space secrets.")
        st.stop()
except Exception as e:
    # Top-level boundary: report any unexpected failure instead of crashing the app.
    print(f"❌ Error initializing Google AI client: {e}")
    st.error(f"An unexpected error occurred during client initialization: {e}")
    st.stop()
# --- Helper Functions ---
def save_binary_file(file_name: str, data: bytes) -> bool:
    """Write ``data`` to ``file_name`` in binary mode (best effort).

    Failures are reported on stdout rather than raised, so a single bad write
    does not abort the caller.

    Args:
        file_name: Destination path; an existing file is overwritten.
        data: Raw bytes to write.

    Returns:
        True when the file was written, False when opening or writing failed.
    """
    try:
        with open(file_name, "wb") as f:
            f.write(data)
    except (OSError, TypeError, ValueError) as e:
        # OSError: filesystem problems; TypeError: ``data`` is not bytes-like;
        # ValueError: malformed path (e.g. embedded NUL character).
        print(f"❌ Error saving file {file_name}: {e}")
        return False
    print(f"✅ Image saved to: {file_name}")
    return True
def pil_image_to_part(image: Image.Image) -> genai.types.Part:
    """Encode a PIL image as JPEG and wrap it in a Gemini ``Part``.

    Args:
        image: The image to embed in a request. It is not modified; a converted
            copy is used when the mode cannot be written as JPEG.

    Returns:
        A ``Part`` whose ``inline_data`` holds the JPEG bytes with MIME type
        ``image/jpeg``.
    """
    # JPEG has no alpha channel or palette, so saving e.g. an RGBA or P image
    # (typical for PNG files) would raise. Normalise those modes to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    # Use the types module of the same SDK as `genai.Client`; the module-level
    # `types` name is imported from the legacy `google.generativeai` package.
    sdk_types = genai.types
    return sdk_types.Part(
        inline_data=sdk_types.Blob(mime_type="image/jpeg", data=buffer.getvalue())
    )
# Instruction text sent ahead of the scene prompt when a previous image is
# supplied as a style/character reference.
_CONTINUATION_SYSTEM_PROMPT = """You are a master storyboard artist creating a visual story sequence.
IMPORTANT: You MUST generate an image for every request. Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
- Character appearance and clothing
- Art style and color palette
- Lighting and atmosphere
Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
Generate an image that illustrates the following scene:"""

# Instruction text sent ahead of the scene prompt for the first image of a story.
_OPENING_SYSTEM_PROMPT = """You are a master storyboard artist creating the opening scene of a visual story.
IMPORTANT: You MUST generate an image for this request. Create a stunning, cinematic image in an epic fantasy digital painting style with:
- Rich, detailed artwork
- Dramatic lighting and atmosphere
- High-quality digital painting aesthetic
This is the first scene of the story. Generate an image that illustrates:"""


def generate_image_with_gemini(
    prompt: str,
    output_file_base: str,
    context_image: Optional[Image.Image] = None
) -> Optional[str]:
    """Generate one image with the Gemini API and save it to disk.

    The request consists of a fixed instruction text, the optional reference
    image, and the scene prompt. The response is streamed; every inline image
    part is written to ``output_file_base`` plus an extension guessed from the
    part's MIME type (``.jpg`` when it cannot be guessed).

    Args:
        prompt: Description of the scene to illustrate.
        output_file_base: Output path without file extension.
        context_image: Optional previous image sent along as a visual reference.

    Returns:
        The path of the last image that was actually written, or None when the
        client is not initialised, the API returned no image, the image could
        not be written, or any error occurred (errors are printed, not raised).
    """
    if not client:
        print("❌ Gemini client not initialized.")
        return None
    print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")
    try:
        # Build request objects with the types module of the same SDK as
        # `genai.Client`; the module-level `types` name is imported from the
        # legacy `google.generativeai` package.
        sdk_types = genai.types
        model_name = "gemini-2.0-flash-preview-image-generation"
        has_context = context_image is not None
        if has_context:
            system_prompt = _CONTINUATION_SYSTEM_PROMPT
            print(" -> Using previous image as context.")
        else:
            system_prompt = _OPENING_SYSTEM_PROMPT
        content_parts = [sdk_types.Part(text=system_prompt)]
        if has_context:
            content_parts.append(pil_image_to_part(context_image))
        image_instruction = f"CREATE AN IMAGE NOW:\n{prompt}\nRemember: You must generate a visual image."
        content_parts.append(sdk_types.Part(text=image_instruction))
        contents = [sdk_types.Content(role="user", parts=content_parts)]
        generate_content_config = sdk_types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        )
        stream = client.models.generate_content_stream(
            model=model_name,
            contents=contents,
            config=generate_content_config,
        )
        saved_file_path = None
        for chunk in stream:
            candidates = chunk.candidates
            if not candidates or not candidates[0].content or not candidates[0].content.parts:
                continue
            for part in candidates[0].content.parts:
                inline_data = part.inline_data
                if not (inline_data and inline_data.data):
                    continue
                # A missing MIME type must not raise; fall back to ".jpg".
                file_extension = mimetypes.guess_extension(inline_data.mime_type or "") or ".jpg"
                full_file_name = f"{output_file_base}{file_extension}"
                save_binary_file(full_file_name, inline_data.data)
                # save_binary_file reports failures only on stdout, so confirm
                # the bytes really reached the disk before claiming success.
                try:
                    written = os.path.getsize(full_file_name) == len(inline_data.data)
                except OSError:
                    written = False
                if written:
                    saved_file_path = full_file_name
                else:
                    print(f"⚠️ Image data could not be written to {full_file_name}; ignoring it.")
        if saved_file_path:
            print(f"✅ Successfully generated and saved image: {saved_file_path}")
        else:
            print("⚠️ No image was returned from the API.")
        return saved_file_path
    except exceptions.InvalidArgument as e:
        print(f"❌ API Invalid Argument Error: {e}")
        traceback.print_exc()
        return None
    except Exception as e:
        # Best-effort boundary: one failed scene must not abort the whole batch.
        print(f"❌ An unexpected error occurred during the Gemini API call: {e}")
        traceback.print_exc()
        return None
def generate_all_images_from_file(json_path: str, output_dir: str, output_json_path: str):
    """Generate an image for every entry of a JSON storyboard file.

    The input JSON is iterated item by item. For each item that has a non-empty
    ``"image_prompt"``, an image is generated into ``output_dir`` (named
    ``image_NNN`` by item index) and the resulting path, or None on failure, is
    stored under ``"image_path"``. Items without a prompt get
    ``"image_path": None``. Each successfully reloaded image is passed as the
    context image for the next generation. The updated data is finally written
    to ``output_json_path``.

    Args:
        json_path: Path of the input JSON file.
        output_dir: Directory for the generated images; created if missing.
        output_json_path: Path the updated JSON is written to.

    Returns:
        None. If the input file is missing or not valid JSON, an error is
        printed and nothing is written.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            multimedia_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"❌ Error reading or parsing {json_path}: {e}")
        return
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    previous_image = None
    successful_generations = 0
    total_items = len(multimedia_data)
    for i, item in enumerate(multimedia_data):
        print(f"\n{'='*60}\nProcessing item {i+1}/{total_items}\n{'='*60}")
        image_prompt = item.get("image_prompt")
        if not image_prompt:
            item["image_path"] = None
            continue
        file_base_path = os.path.join(output_dir, f"image_{i:03d}")
        saved_image_path = generate_image_with_gemini(
            image_prompt, file_base_path, context_image=previous_image
        )
        item["image_path"] = saved_image_path
        if saved_image_path:
            try:
                # Image.open keeps the file open until the pixel data is read;
                # copy inside a context manager so the handle is released
                # instead of leaking one descriptor per generated image.
                with Image.open(saved_image_path) as opened_image:
                    previous_image = opened_image.copy()
                successful_generations += 1
            except Exception as e:
                # Best effort: an unreadable file only costs the next scene its
                # visual context, so report it and continue.
                print(f"⚠️ Could not reload {saved_image_path} as context for the next image: {e}")
                previous_image = None
        else:
            previous_image = None
        # Pause between API calls.
        time.sleep(2)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(multimedia_data, f, indent=2, ensure_ascii=False)
    print(f"\n--- ✅ Finished. Generated {successful_generations}/{total_items} images. ---")