File size: 6,896 Bytes
0432337
 
94047f1
 
 
 
 
a296a03
6340efd
 
 
0432337
 
6340efd
177b6e5
6340efd
94047f1
6340efd
0432337
6340efd
94047f1
6340efd
94047f1
6340efd
 
94047f1
6340efd
 
 
94047f1
6340efd
0432337
6340efd
 
 
0432337
94047f1
6340efd
94047f1
 
 
 
 
 
 
b61fa02
6340efd
b61fa02
6340efd
 
0432337
b61fa02
94047f1
 
 
 
 
0432337
b61fa02
6340efd
b61fa02
 
94047f1
b61fa02
94047f1
6340efd
 
755e872
94047f1
6340efd
 
 
 
 
 
 
177b6e5
94047f1
6340efd
 
 
 
 
 
0432337
 
177b6e5
b61fa02
 
 
0432337
 
b61fa02
6340efd
a296a03
b61fa02
6340efd
a296a03
94047f1
6340efd
 
b61fa02
 
6340efd
 
 
 
 
b61fa02
 
6340efd
 
b61fa02
 
 
6340efd
b61fa02
0432337
6340efd
 
 
0432337
b61fa02
 
94047f1
6340efd
 
177b6e5
 
94047f1
0432337
94047f1
 
 
0432337
 
94047f1
 
 
 
 
0432337
94047f1
6340efd
 
94047f1
3aa60f4
6340efd
94047f1
6340efd
94047f1
6340efd
0432337
94047f1
 
 
 
6340efd
0432337
6340efd
 
94047f1
 
0432337
94047f1
 
6340efd
94047f1
177b6e5
94047f1
 
0432337
177b6e5
 
0432337
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# image_generation.py

import os
import mimetypes
import json
import streamlit as st
import io
import time
import traceback
from PIL import Image
from typing import List, Dict, Optional

# Imports for the 'google-genai' SDK. Both the client (`genai.Client`) and the
# request types used in this file (`types.Part`, `types.Blob`, `types.Content`,
# `types.GenerateContentConfig`) live under the `google.genai` namespace, so
# `types` must come from `google.genai`, not from the separate legacy
# `google.generativeai` package.
from google import genai
from google.genai import types
from google.api_core import exceptions

# --- Client Initialization ---
# Builds the module-level Gemini `client` from the GEMINI_API_KEY entry in the
# Streamlit secrets (set through the Hugging Face Space secrets). If the key
# is missing, or creating the client fails, the problem is printed, shown in
# the Streamlit UI, and the script run is halted with st.stop().
client = None
_client_init_error = None
try:
    api_key = st.secrets.get("GEMINI_API_KEY")
    if api_key:
        client = genai.Client(api_key=api_key)
except Exception as e:
    # Keep the exception for the report below; `e` is unbound after this block.
    _client_init_error = e

# Reporting and st.stop() happen outside the try so that whatever st.stop()
# does to halt the run can never be intercepted by the handler above.
if _client_init_error is not None:
    print(f"❌ Error initializing Google AI client: {_client_init_error}")
    st.error(f"An unexpected error occurred during client initialization: {_client_init_error}")
    st.stop()
elif client is None:
    print("❌ FATAL: GEMINI_API_KEY not found in Streamlit secrets.")
    st.error("GEMINI_API_KEY not configured. Please set it in your Hugging Face Space secrets.")
    st.stop()
else:
    print("✅ Google AI client for Gemini initialized successfully.")

# --- Helper Functions ---

def save_binary_file(file_name: str, data: bytes) -> bool:
    """Write ``data`` to ``file_name`` in binary mode.

    Errors are reported on stdout instead of being raised, so one failed
    write does not abort a longer batch. The return value lets callers tell
    a successful write from a failed one.

    Args:
        file_name: Destination path; an existing file is overwritten.
        data: Raw bytes to write.

    Returns:
        True if the file was written, False if any error occurred.
    """
    try:
        with open(file_name, "wb") as f:
            f.write(data)
    except Exception as e:
        # Deliberately best-effort: report and signal failure to the caller.
        print(f"❌ Error saving file {file_name}: {e}")
        return False

    print(f"✅ Image saved to: {file_name}")
    return True

def pil_image_to_part(image: Image.Image) -> types.Part:
    """Encode a PIL image as JPEG and wrap it in an inline-data ``types.Part``.

    Pillow's JPEG writer raises ``OSError`` for modes it cannot store (for
    example RGBA or palette images, which is what opening a PNG often
    yields). In that case the image is converted to RGB and encoded again;
    any alpha channel is dropped by that conversion.

    Args:
        image: The image to encode. It is not modified.

    Returns:
        A ``types.Part`` whose ``inline_data`` is a ``types.Blob`` with MIME
        type ``image/jpeg`` and the encoded bytes.
    """
    buffer = io.BytesIO()
    try:
        image.save(buffer, format="JPEG")
    except OSError:
        # Start from a clean buffer: the failed attempt may have written bytes.
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")

    return types.Part(
        inline_data=types.Blob(mime_type="image/jpeg", data=buffer.getvalue())
    )

def generate_image_with_gemini(
    prompt: str,
    output_file_base: str,
    context_image: Optional[Image.Image] = None,
    model_name: str = "gemini-2.0-flash-preview-image-generation",
) -> Optional[str]:
    """Generate one image with the Gemini API and save it to disk.

    The request consists of a storyboard-artist instruction, the optional
    reference image, and the scene prompt. The response is streamed; every
    part carrying inline binary data is written to
    ``output_file_base + <extension>``, where the extension is derived from
    the part's MIME type (``.jpg`` when it cannot be determined).

    Uses the module-level ``client`` and the ``save_binary_file`` helper.

    Args:
        prompt: Scene description to illustrate.
        output_file_base: Output path without extension.
        context_image: Previous image to use as a style/character reference,
            or None for an opening scene.
        model_name: Gemini model to call.

    Returns:
        The path of the saved image, or None if the client is not
        initialized, the API returned no image, the image could not be
        written, or an error occurred (errors are printed, not raised).
    """
    if not client:
        print("❌ Gemini client not initialized.")
        return None

    print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")

    try:
        content_parts = []

        # Compare against None explicitly rather than relying on the
        # truthiness of an image object.
        if context_image is not None:
            system_prompt = """You are a master storyboard artist creating a visual story sequence.
            IMPORTANT: You MUST generate an image for every request. Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
            - Character appearance and clothing
            - Art style and color palette
            - Lighting and atmosphere
            Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
            Generate an image that illustrates the following scene:"""
            print(" -> Using previous image as context.")
        else:
            system_prompt = """You are a master storyboard artist creating the opening scene of a visual story.
            IMPORTANT: You MUST generate an image for this request. Create a stunning, cinematic image in an epic fantasy digital painting style with:
            - Rich, detailed artwork
            - Dramatic lighting and atmosphere
            - High-quality digital painting aesthetic
            This is the first scene of the story. Generate an image that illustrates:"""

        content_parts.append(types.Part(text=system_prompt))

        if context_image is not None:
            content_parts.append(pil_image_to_part(context_image))

        image_instruction = f"CREATE AN IMAGE NOW:\n{prompt}\nRemember: You must generate a visual image."
        content_parts.append(types.Part(text=image_instruction))

        contents = [types.Content(role="user", parts=content_parts)]

        generate_content_config = types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        )

        stream = client.models.generate_content_stream(
            model=model_name,
            contents=contents,
            config=generate_content_config,
        )

        saved_file_path = None
        for chunk in stream:
            # Skip chunks that carry no content parts.
            if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
                continue

            for part in chunk.candidates[0].content.parts:
                inline_data = part.inline_data
                if not inline_data or not inline_data.data:
                    continue

                # guess_extension() fails on None, so fall back to "" (which
                # yields no guess and therefore the ".jpg" default).
                file_extension = mimetypes.guess_extension(inline_data.mime_type or "") or ".jpg"
                full_file_name = f"{output_file_base}{file_extension}"

                # Only report the path if the bytes really reached the disk.
                # `is False` keeps this working with a helper that returns
                # None; the isfile() check covers that case.
                write_result = save_binary_file(full_file_name, inline_data.data)
                if write_result is False or not os.path.isfile(full_file_name):
                    print(f"⚠️ Image data was received but could not be written to {full_file_name}.")
                    continue

                saved_file_path = full_file_name

        if saved_file_path:
            print(f"✅ Successfully generated and saved image: {saved_file_path}")
        else:
            print("⚠️ No image was returned from the API.")

        return saved_file_path

    except exceptions.InvalidArgument as e:
        print(f"❌ API Invalid Argument Error: {e}")
        traceback.print_exc()
        return None
    except Exception as e:
        # Best-effort boundary: a failed generation must not abort a batch.
        print(f"❌ An unexpected error occurred during the Gemini API call: {e}")
        traceback.print_exc()
        return None

def generate_all_images_from_file(
    json_path: str,
    output_dir: str,
    output_json_path: str,
    delay_seconds: float = 2.0,
):
    """Generate an image for every item of a JSON file that has a prompt.

    The input must be a JSON array of objects. For each object with a
    non-empty ``"image_prompt"``, ``generate_image_with_gemini`` is called
    with the previously generated image as context, and the resulting path
    (or None) is stored under ``"image_path"``. Objects without a prompt get
    ``"image_path": None``; non-object items are skipped unchanged. The
    updated array is written to ``output_json_path``.

    Args:
        json_path: Path of the input JSON file.
        output_dir: Directory for the generated images; created if missing.
            An empty string means the current directory.
        output_json_path: Where the updated JSON is written.
        delay_seconds: Pause between consecutive generation calls.

    Returns:
        None. A missing/unparseable input file or a non-array JSON root is
        reported on stdout and nothing is written.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            multimedia_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"❌ Error reading or parsing {json_path}: {e}")
        return

    if not isinstance(multimedia_data, list):
        print(
            f"❌ Error: expected a JSON array in {json_path}, "
            f"got {type(multimedia_data).__name__}."
        )
        return

    # os.makedirs("") raises, and an empty output_dir needs no directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    total_items = len(multimedia_data)
    previous_image = None
    successful_generations = 0
    for i, item in enumerate(multimedia_data):
        print(f"\n{'='*60}\nProcessing item {i+1}/{total_items}\n{'='*60}")

        if not isinstance(item, dict):
            print(f"⚠️ Skipping item {i+1}: expected a JSON object, got {type(item).__name__}.")
            continue

        image_prompt = item.get("image_prompt")
        if not image_prompt:
            item["image_path"] = None
            continue

        file_base_path = os.path.join(output_dir, f"image_{i:03d}")
        saved_image_path = generate_image_with_gemini(
            image_prompt, file_base_path, context_image=previous_image
        )
        item["image_path"] = saved_image_path

        previous_image = None
        if saved_image_path:
            # The image was generated and saved, so it counts even if it
            # cannot be reloaded as context below.
            successful_generations += 1
            try:
                # copy() loads the pixels into memory so the file handle can
                # be closed right away instead of staying open.
                with Image.open(saved_image_path) as opened:
                    previous_image = opened.copy()
            except Exception as e:
                print(f"⚠️ Could not reload {saved_image_path} as context for the next image: {e}")

        # Pause between API calls; pointless after the final item.
        if delay_seconds > 0 and i < total_items - 1:
            time.sleep(delay_seconds)

    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(multimedia_data, f, indent=2, ensure_ascii=False)

    print(f"\n--- ✅ Finished. Generated {successful_generations}/{total_items} images. ---")