abhinav0231 commited on
Commit
94047f1
Β·
verified Β·
1 Parent(s): 2a450f6

Update image_generation.py

Browse files
Files changed (1) hide show
  1. image_generation.py +187 -270
image_generation.py CHANGED
@@ -1,270 +1,187 @@
1
- import os
2
- import mimetypes
3
- import json
4
- import base64
5
- import streamlit as st
6
- from google import genai
7
- from google.genai import types
8
- from typing import List, Dict, Optional
9
- from PIL import Image
10
- import io
11
- import pathlib
12
-
13
-
14
- client = None
15
-
16
- try:
17
- api_key = st.secrets.get("GEMINI_API_KEY") or os.getenv("GEMINI_API_KEY")
18
- if api_key:
19
- client = genai.Client(api_key=api_key)
20
- print("βœ… Google AI client for Gemini initialized successfully.")
21
- else:
22
- print("⚠️ Warning: GEMINI_API_KEY not found.")
23
- exit(1)
24
- except Exception as e:
25
- print(f"❌ Error initializing Google AI client: {e}")
26
- exit(1)
27
-
28
- # --- Helper Functions ---
29
-
30
- def save_binary_file(file_name: str, data: bytes):
31
- try:
32
- with open(file_name, "wb") as f:
33
- f.write(data)
34
- print(f"βœ… Image saved to: {file_name}")
35
- except Exception as e:
36
- print(f"❌ Error saving file {file_name}: {e}")
37
-
38
- def pil_image_to_part(image: Image.Image) -> types.Part:
39
- """Convert PIL Image to types.Part for use in content."""
40
- # Convert PIL Image to bytes
41
- img_byte_arr = io.BytesIO()
42
- image.save(img_byte_arr, format='JPEG')
43
- img_bytes = img_byte_arr.getvalue()
44
-
45
- return types.Part.from_bytes(
46
- data=img_bytes,
47
- mime_type='image/jpeg'
48
- )
49
-
50
- def generate_image_with_gemini(
51
- prompt: str,
52
- output_file_base: str,
53
- context_image: Optional[Image.Image] = None
54
- ) -> Optional[str]:
55
- """
56
- Generates an image, optionally using a previous image as context.
57
- """
58
-
59
- if not client:
60
- return None
61
-
62
- print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")
63
-
64
- try:
65
- model = "gemini-2.5-flash-image-preview"
66
-
67
- # Build content parts
68
- content_parts = []
69
-
70
- # explicit system prompt that always demands image generation
71
- if context_image:
72
- system_prompt = """You are a master storyboard artist creating a visual story sequence.
73
-
74
- IMPORTANT: You MUST generate an image for every request.
75
-
76
- Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
77
- - Character appearance and clothing
78
- - Art style and color palette
79
- - Lighting and atmosphere
80
- - Overall visual tone
81
-
82
- Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
83
-
84
- Generate an image that illustrates the following scene:"""
85
- print(" -> Using previous image as context for consistent styling.")
86
- else:
87
- system_prompt = """You are a master storyboard artist creating the opening scene of a visual story.
88
-
89
- IMPORTANT: You MUST generate an image for this request.
90
-
91
- Create a stunning, cinematic image in an epic fantasy digital painting style with:
92
- - Rich, detailed artwork
93
- - Dramatic lighting and atmosphere
94
- - High-quality digital painting aesthetic
95
- - Vivid colors and intricate details
96
-
97
- This is the first scene of the story. Generate an image that illustrates:"""
98
-
99
- # Add system prompt
100
- content_parts.append(types.Part.from_text(text=system_prompt))
101
-
102
- # Add context image if provided
103
- if context_image:
104
- content_parts.append(pil_image_to_part(context_image))
105
-
106
- # Add the actual prompt with explicit image generation instruction
107
- image_instruction = f"""CREATE AN IMAGE NOW:
108
-
109
- {prompt}
110
-
111
- Remember: You must generate a visual image, not text. Create the artwork described above."""
112
-
113
- content_parts.append(types.Part.from_text(text=image_instruction))
114
-
115
- # Create content structure
116
- contents = [
117
- types.Content(
118
- role="user",
119
- parts=content_parts,
120
- ),
121
- ]
122
-
123
- # Configure generation
124
- generate_content_config = types.GenerateContentConfig(
125
- response_modalities=[
126
- "IMAGE",
127
- "TEXT",
128
- ],
129
- )
130
-
131
- # Generate content using streaming
132
- saved_file_path = None
133
- text_responses = []
134
-
135
- for chunk in client.models.generate_content_stream(
136
- model=model,
137
- contents=contents,
138
- config=generate_content_config,
139
- ):
140
- if (
141
- chunk.candidates is None
142
- or chunk.candidates[0].content is None
143
- or chunk.candidates[0].content.parts is None
144
- ):
145
- continue
146
-
147
- # Check for image data
148
- if (chunk.candidates[0].content.parts[0].inline_data and
149
- chunk.candidates[0].content.parts[0].inline_data.data):
150
-
151
- inline_data = chunk.candidates[0].content.parts[0].inline_data
152
- data_buffer = inline_data.data
153
- file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".jpg"
154
- full_file_name = f"{output_file_base}{file_extension}"
155
-
156
- save_binary_file(full_file_name, data_buffer)
157
- saved_file_path = full_file_name
158
- print(f"βœ… Successfully generated and saved image: {full_file_name}")
159
-
160
- # Collect any text responses
161
- elif hasattr(chunk, 'text') and chunk.text:
162
- text_responses.append(chunk.text)
163
-
164
- # If we got text but no image, print the text for debugging
165
- if not saved_file_path and text_responses:
166
- print(f"⚠️ No image generated. API returned text: {' '.join(text_responses)}")
167
-
168
- return saved_file_path
169
-
170
- except Exception as e:
171
- print(f"❌ An error occurred during the Gemini API call: {e}")
172
- import traceback
173
- traceback.print_exc()
174
- return None
175
-
176
- def generate_all_images_from_file(
177
- json_path: str,
178
- output_dir: str = "generated_images",
179
- output_json_path: str = "multimedia_data_with_images.json"
180
- ) -> List[Dict[str, str]]:
181
- """
182
- Reads data from a JSON, generates images sequentially, and saves a new JSON with image paths.
183
- """
184
-
185
- try:
186
- with open(json_path, 'r', encoding='utf-8') as f:
187
- multimedia_data = json.load(f)
188
- except (FileNotFoundError, json.JSONDecodeError) as e:
189
- print(f"❌ Error reading or parsing {json_path}: {e}")
190
- return []
191
-
192
- if not os.path.exists(output_dir):
193
- os.makedirs(output_dir)
194
-
195
- previous_image = None
196
- successful_generations = 0
197
-
198
- for i, item in enumerate(multimedia_data):
199
- print(f"\n{'='*60}")
200
- print(f"Processing item {i+1}/{len(multimedia_data)} - {'FIRST IMAGE' if i == 0 else 'CONTINUATION'}")
201
- print(f"{'='*60}")
202
-
203
- image_prompt = item.get("image_prompt")
204
-
205
- if not image_prompt or "Error:" in image_prompt:
206
- print(f"⚠️ Skipping item {i}: No valid image prompt found")
207
- item["image_path"] = None
208
- continue
209
-
210
- # Show the prompt being used
211
- print(f"πŸ“ Image prompt: {image_prompt[:100]}{'...' if len(image_prompt) > 100 else ''}")
212
-
213
- file_base_path = os.path.join(output_dir, f"image_{i:03d}")
214
- saved_image_path = generate_image_with_gemini(
215
- image_prompt,
216
- file_base_path,
217
- context_image=previous_image
218
- )
219
-
220
- item["image_path"] = saved_image_path
221
-
222
- # Load the generated image for the next iteration
223
- if saved_image_path and os.path.exists(saved_image_path):
224
- try:
225
- previous_image = Image.open(saved_image_path)
226
- successful_generations += 1
227
- print(f"βœ… Loaded image {saved_image_path} as context for next generation.")
228
- except Exception as e:
229
- print(f"⚠️ Could not load image {saved_image_path} for context. Error: {e}")
230
- previous_image = None
231
- else:
232
- previous_image = None
233
- print("❌ No image generated for this item!")
234
-
235
- # Add a delay between requests
236
- import time
237
- time.sleep(2)
238
-
239
- # Save the updated JSON with image paths
240
- try:
241
- with open(output_json_path, 'w', encoding='utf-8') as f:
242
- json.dump(multimedia_data, f, indent=2, ensure_ascii=False)
243
- print(f"\n--- βœ… Image generation finished. Updated data saved to {output_json_path}. ---")
244
- print(f"Successfully generated {successful_generations}/{len(multimedia_data)} images.")
245
-
246
- # Print summary of results
247
- print(f"\nπŸ“Š GENERATION SUMMARY:")
248
- for i, item in enumerate(multimedia_data):
249
- status = "βœ… SUCCESS" if item.get("image_path") else "❌ FAILED"
250
- print(f" Item {i+1}: {status}")
251
-
252
- except Exception as e:
253
- print(f"❌ Error saving updated JSON: {e}")
254
-
255
- return multimedia_data
256
-
257
- # Example usage
258
- # if __name__ == '__main__':
259
- # print("\n--- RUNNING IMAGE GENERATION EXAMPLE ---")
260
-
261
- # json_input_file = "multimedia_data.json"
262
-
263
- # if not os.path.exists(json_input_file):
264
- # print(f"❌ Error: Input file '{json_input_file}' not found.")
265
- # print("Please run chunk.py first to generate it.")
266
- # else:
267
- # final_data = generate_all_images_from_file(json_input_file)
268
-
269
- # if final_data:
270
- # print(f"\n--- βœ… Processing completed for {len(final_data)} items ---")
 
1
+ import os
2
+ import mimetypes
3
+ import json
4
+ import streamlit as st
5
+ import google.generativeai as genai
6
+ from typing import List, Dict, Optional
7
+ from PIL import Image
8
+ import io
9
+ import time
10
+ import traceback
11
+
12
+ # Configure the client once with the API key
13
+ try:
14
+ api_key = st.secrets.get("GEMINI_API_KEY") or os.getenv("GEMINI_API_KEY")
15
+ if api_key:
16
+ genai.configure(api_key=api_key)
17
+ print("βœ… Google AI client for Gemini configured successfully.")
18
+ else:
19
+ print("⚠️ Warning: GEMINI_API_KEY not found.")
20
+ # Exit if running in a context where it should be mandatory
21
+ # exit(1)
22
+ except Exception as e:
23
+ print(f"❌ Error configuring Google AI client: {e}")
24
+ # exit(1)
25
+
26
+ # --- Helper Function ---
27
+ def save_binary_file(file_name: str, data: bytes):
28
+ """Saves binary data to a file."""
29
+ try:
30
+ with open(file_name, "wb") as f:
31
+ f.write(data)
32
+ print(f"βœ… Image saved to: {file_name}")
33
+ except Exception as e:
34
+ print(f"❌ Error saving file {file_name}: {e}")
35
+
36
+ # --- IMAGE GENERATION FUNCTION ---
37
+ def generate_image_with_gemini(
38
+ prompt: str,
39
+ output_file_base: str,
40
+ context_image: Optional[Image.Image] = None
41
+ ) -> Optional[str]:
42
+ """
43
+ Generates an image using the updated Gemini API syntax, optionally using a
44
+ previous image as context.
45
+ """
46
+ print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")
47
+
48
+ try:
49
+ # Define the specific image generation model
50
+ model = genai.GenerativeModel(model_name="gemini-1.5-flash") # Use a current, valid model
51
+
52
+ # --- UPDATED CONTENT STRUCTURE ---
53
+ # Build a simple list of content parts. The SDK handles the type.
54
+ content_parts = []
55
+
56
+ if context_image:
57
+ system_prompt = """You are a master storyboard artist creating a visual story sequence.
58
+ IMPORTANT: You MUST generate an image for every request.
59
+ Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
60
+ - Character appearance and clothing
61
+ - Art style and color palette
62
+ - Lighting and atmosphere
63
+ - Overall visual tone
64
+ Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
65
+ Generate an image that illustrates the following scene:"""
66
+ print(" -> Using previous image as context for consistent styling.")
67
+ # Add the system prompt, context image, and user prompt
68
+ content_parts = [system_prompt, context_image, f"CREATE AN IMAGE NOW: {prompt}"]
69
+ else:
70
+ system_prompt = """You are a master storyboard artist creating the opening scene of a visual story.
71
+ IMPORTANT: You MUST generate an image for this request.
72
+ Create a stunning, cinematic image in an epic fantasy digital painting style with:
73
+ - Rich, detailed artwork
74
+ - Dramatic lighting and atmosphere
75
+ - High-quality digital painting aesthetic
76
+ - Vivid colors and intricate details
77
+ This is the first scene of the story. Generate an image that illustrates:"""
78
+ # Add the system prompt and user prompt
79
+ content_parts = [system_prompt, f"CREATE AN IMAGE NOW: {prompt}"]
80
+
81
+ # --- UPDATED API CALL ---
82
+ # The generate_content method now takes a simple list
83
+ response = model.generate_content(content_parts, stream=True)
84
+
85
+ saved_file_path = None
86
+ text_responses = []
87
+
88
+ for chunk in response:
89
+ # The modern SDK provides image data directly in parts
90
+ if chunk.parts and chunk.parts[0].file_data:
91
+ file_data = chunk.parts[0].file_data
92
+ data_buffer = file_data.data
93
+ file_extension = mimetypes.guess_extension(file_data.mime_type) or ".jpg"
94
+ full_file_name = f"{output_file_base}{file_extension}"
95
+ save_binary_file(full_file_name, data_buffer)
96
+ saved_file_path = full_file_name
97
+ print(f"βœ… Successfully generated and saved image: {full_file_name}")
98
+
99
+ # Collect any text responses for debugging
100
+ if chunk.text:
101
+ text_responses.append(chunk.text)
102
+
103
+ if not saved_file_path and text_responses:
104
+ print(f"⚠️ No image generated. API returned text: {' '.join(text_responses)}")
105
+
106
+ return saved_file_path
107
+
108
+ except Exception as e:
109
+ print(f"❌ An error occurred during the Gemini API call: {e}")
110
+ traceback.print_exc()
111
+ return None
112
+
113
+ # The rest of your script (generate_all_images_from_file) does not need changes
114
+ # as it correctly calls the function above.
115
+ def generate_all_images_from_file(
116
+ json_path: str,
117
+ output_dir: str = "generated_images",
118
+ output_json_path: str = "multimedia_data_with_images.json"
119
+ ) -> List[Dict[str, str]]:
120
+ """
121
+ Reads data from a JSON, generates images sequentially, and saves a new JSON with image paths.
122
+ """
123
+ try:
124
+ with open(json_path, 'r', encoding='utf-8') as f:
125
+ multimedia_data = json.load(f)
126
+ except (FileNotFoundError, json.JSONDecodeError) as e:
127
+ print(f"❌ Error reading or parsing {json_path}: {e}")
128
+ return []
129
+
130
+ if not os.path.exists(output_dir):
131
+ os.makedirs(output_dir)
132
+
133
+ previous_image = None
134
+ successful_generations = 0
135
+
136
+ for i, item in enumerate(multimedia_data):
137
+ print(f"\n{'='*60}")
138
+ print(f"Processing item {i+1}/{len(multimedia_data)} - {'FIRST IMAGE' if i == 0 else 'CONTINUATION'}")
139
+ print(f"{'='*60}")
140
+
141
+ image_prompt = item.get("image_prompt")
142
+ if not image_prompt or "Error:" in image_prompt:
143
+ print(f"⚠️ Skipping item {i}: No valid image prompt found")
144
+ item["image_path"] = None
145
+ continue
146
+
147
+ print(f"πŸ“ Image prompt: {image_prompt[:100]}{'...' if len(image_prompt) > 100 else ''}")
148
+ file_base_path = os.path.join(output_dir, f"image_{i:03d}")
149
+
150
+ saved_image_path = generate_image_with_gemini(
151
+ image_prompt,
152
+ file_base_path,
153
+ context_image=previous_image
154
+ )
155
+
156
+ item["image_path"] = saved_image_path
157
+
158
+ if saved_image_path and os.path.exists(saved_image_path):
159
+ try:
160
+ previous_image = Image.open(saved_image_path)
161
+ successful_generations += 1
162
+ print(f"βœ… Loaded image {saved_image_path} as context for next generation.")
163
+ except Exception as e:
164
+ print(f"⚠️ Could not load image {saved_image_path} for context. Error: {e}")
165
+ previous_image = None
166
+ else:
167
+ previous_image = None
168
+ print("❌ No image generated for this item!")
169
+
170
+ time.sleep(2)
171
+
172
+ try:
173
+ with open(output_json_path, 'w', encoding='utf-8') as f:
174
+ json.dump(multimedia_data, f, indent=2, ensure_ascii=False)
175
+ print(f"\n--- βœ… Image generation finished. Updated data saved to {output_json_path}. ---")
176
+ print(f"Successfully generated {successful_generations}/{len(multimedia_data)} images.")
177
+
178
+ print(f"\nπŸ“Š GENERATION SUMMARY:")
179
+ for i, item in enumerate(multimedia_data):
180
+ status = "βœ… SUCCESS" if item.get("image_path") else "❌ FAILED"
181
+ print(f" Item {i+1}: {status}")
182
+
183
+ except Exception as e:
184
+ print(f"❌ Error saving updated JSON: {e}")
185
+
186
+ return multimedia_data
187
+