abhinav0231 commited on
Commit
0432337
Β·
verified Β·
1 Parent(s): 2b359cb

Update image_generation.py

Browse files
Files changed (1) hide show
  1. image_generation.py +27 -54
image_generation.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import mimetypes
3
  import json
@@ -7,11 +9,14 @@ import time
7
  import traceback
8
  from PIL import Image
9
  from typing import List, Dict, Optional
 
 
10
  from google import genai
11
  from google.generativeai import types
12
  from google.api_core import exceptions
13
 
14
  # --- Client Initialization ---
 
15
  client = None
16
  try:
17
  api_key = st.secrets.get("GEMINI_API_KEY")
@@ -24,10 +29,11 @@ try:
24
  st.stop()
25
  except Exception as e:
26
  print(f"❌ Error initializing Google AI client: {e}")
27
- st.error(f"Error initializing Google AI client: {e}")
28
  st.stop()
29
 
30
  # --- Helper Functions ---
 
31
  def save_binary_file(file_name: str, data: bytes):
32
  """Saves binary data to a file."""
33
  try:
@@ -42,20 +48,14 @@ def pil_image_to_part(image: Image.Image) -> types.Part:
42
  img_byte_arr = io.BytesIO()
43
  image.save(img_byte_arr, format='JPEG')
44
  img_bytes = img_byte_arr.getvalue()
45
- # This call was correct and remains the same.
46
- return types.Part.from_bytes(
47
- data=img_bytes,
48
- mime_type='image/jpeg'
49
- )
50
 
51
  def generate_image_with_gemini(
52
  prompt: str,
53
  output_file_base: str,
54
  context_image: Optional[Image.Image] = None
55
  ) -> Optional[str]:
56
- """
57
- Generates an image using the Gemini API, now with all corrections.
58
- """
59
  if not client:
60
  print("❌ Gemini client not initialized.")
61
  return None
@@ -66,7 +66,6 @@ def generate_image_with_gemini(
66
  model_name = "gemini-2.0-flash-preview-image-generation"
67
  content_parts = []
68
 
69
- # Define system prompts
70
  if context_image:
71
  system_prompt = """You are a master storyboard artist creating a visual story sequence.
72
  IMPORTANT: You MUST generate an image for every request. Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
@@ -83,14 +82,14 @@ def generate_image_with_gemini(
83
  - Dramatic lighting and atmosphere
84
  - High-quality digital painting aesthetic
85
  This is the first scene of the story. Generate an image that illustrates:"""
86
-
87
- content_parts.append(types.Part.from_text(system_prompt))
88
 
89
  if context_image:
90
  content_parts.append(pil_image_to_part(context_image))
91
 
92
- image_instruction = f"CREATE AN IMAGE NOW:\n{prompt}\nRemember: You must generate a visual image, not text."
93
- content_parts.append(types.Part.from_text(image_instruction))
94
 
95
  contents = [types.Content(role="user", parts=content_parts)]
96
 
@@ -105,7 +104,6 @@ def generate_image_with_gemini(
105
  )
106
 
107
  saved_file_path = None
108
- text_responses = []
109
  for chunk in stream:
110
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
111
  continue
@@ -117,15 +115,11 @@ def generate_image_with_gemini(
117
  full_file_name = f"{output_file_base}{file_extension}"
118
  save_binary_file(full_file_name, inline_data.data)
119
  saved_file_path = full_file_name
120
- elif part.text:
121
- text_responses.append(part.text)
122
-
123
  if saved_file_path:
124
  print(f"βœ… Successfully generated and saved image: {saved_file_path}")
125
- elif text_responses:
126
- print(f"⚠️ No image generated. API returned text only: {' '.join(text_responses)}")
127
  else:
128
- print("⚠️ No image or text was returned from the API.")
129
 
130
  return saved_file_path
131
 
@@ -133,27 +127,19 @@ def generate_image_with_gemini(
133
  print(f"❌ API Invalid Argument Error: {e}")
134
  traceback.print_exc()
135
  return None
136
- except TypeError as e:
137
- print(f"❌ TypeError during API call: {e}")
138
- print(" -> This often indicates an incorrect way of creating API objects like 'Part'. Please double-check SDK documentation.")
139
- traceback.print_exc()
140
- return None
141
  except Exception as e:
142
- print(f"❌ An unexpected error occurred: {e}")
143
  traceback.print_exc()
144
  return None
145
 
146
- def generate_all_images_from_file(
147
- json_path: str,
148
- output_dir: str = "generated_images",
149
- output_json_path: str = "multimedia_data_with_images.json"
150
- ) -> List[Dict]:
151
  try:
152
  with open(json_path, 'r', encoding='utf-8') as f:
153
  multimedia_data = json.load(f)
154
  except (FileNotFoundError, json.JSONDecodeError) as e:
155
  print(f"❌ Error reading or parsing {json_path}: {e}")
156
- return []
157
 
158
  if not os.path.exists(output_dir):
159
  os.makedirs(output_dir)
@@ -164,42 +150,29 @@ def generate_all_images_from_file(
164
  print(f"\n{'='*60}\nProcessing item {i+1}/{len(multimedia_data)}\n{'='*60}")
165
  image_prompt = item.get("image_prompt")
166
 
167
- if not image_prompt or "Error:" in image_prompt:
168
- print(f"⚠️ Skipping item {i+1}: No valid image prompt found.")
169
  item["image_path"] = None
170
  continue
171
 
172
- print(f"πŸ“ Image prompt: {image_prompt[:100]}{'...' if len(image_prompt) > 100 else ''}")
173
  file_base_path = os.path.join(output_dir, f"image_{i:03d}")
174
-
175
  saved_image_path = generate_image_with_gemini(
176
- image_prompt,
177
- file_base_path,
178
- context_image=previous_image
179
  )
180
 
181
  item["image_path"] = saved_image_path
182
 
183
- if saved_image_path and os.path.exists(saved_image_path):
184
  try:
185
  previous_image = Image.open(saved_image_path)
186
  successful_generations += 1
187
- print(f"βœ… Loaded image {saved_image_path} as context for the next generation.")
188
  except Exception as e:
189
- print(f"⚠️ Could not load image {saved_image_path} for context. Error: {e}")
190
  previous_image = None
191
  else:
192
- print("❌ No image was generated for this item. Context will be reset.")
193
  previous_image = None
194
-
195
  time.sleep(2)
196
 
197
- try:
198
- with open(output_json_path, 'w', encoding='utf-8') as f:
199
- json.dump(multimedia_data, f, indent=2, ensure_ascii=False)
200
- print(f"\n--- βœ… Image generation finished. Updated data saved to {output_json_path}. ---")
201
- print(f"Successfully generated {successful_generations}/{len(multimedia_data)} images.")
202
- except Exception as e:
203
- print(f"❌ Error saving updated JSON: {e}")
204
-
205
- return multimedia_data
 
1
+ # image_generation.py
2
+
3
  import os
4
  import mimetypes
5
  import json
 
9
  import traceback
10
  from PIL import Image
11
  from typing import List, Dict, Optional
12
+
13
+ # CORRECT IMPORTS FOR THE 'google-genai' SDK
14
  from google import genai
15
  from google.generativeai import types
16
  from google.api_core import exceptions
17
 
18
  # --- Client Initialization ---
19
+ # This section initializes the client using the secrets from your HF Space.
20
  client = None
21
  try:
22
  api_key = st.secrets.get("GEMINI_API_KEY")
 
29
  st.stop()
30
  except Exception as e:
31
  print(f"❌ Error initializing Google AI client: {e}")
32
+ st.error(f"An unexpected error occurred during client initialization: {e}")
33
  st.stop()
34
 
35
  # --- Helper Functions ---
36
+
37
  def save_binary_file(file_name: str, data: bytes):
38
  """Saves binary data to a file."""
39
  try:
 
48
  img_byte_arr = io.BytesIO()
49
  image.save(img_byte_arr, format='JPEG')
50
  img_bytes = img_byte_arr.getvalue()
51
+ return types.Part(inline_data=types.Blob(mime_type="image/jpeg", data=img_bytes))
 
 
 
 
52
 
53
  def generate_image_with_gemini(
54
  prompt: str,
55
  output_file_base: str,
56
  context_image: Optional[Image.Image] = None
57
  ) -> Optional[str]:
58
+ """Generates an image using the Gemini API with the corrected SDK calls."""
 
 
59
  if not client:
60
  print("❌ Gemini client not initialized.")
61
  return None
 
66
  model_name = "gemini-2.0-flash-preview-image-generation"
67
  content_parts = []
68
 
 
69
  if context_image:
70
  system_prompt = """You are a master storyboard artist creating a visual story sequence.
71
  IMPORTANT: You MUST generate an image for every request. Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
 
82
  - Dramatic lighting and atmosphere
83
  - High-quality digital painting aesthetic
84
  This is the first scene of the story. Generate an image that illustrates:"""
85
+
86
+ content_parts.append(types.Part(text=system_prompt))
87
 
88
  if context_image:
89
  content_parts.append(pil_image_to_part(context_image))
90
 
91
+ image_instruction = f"CREATE AN IMAGE NOW:\n{prompt}\nRemember: You must generate a visual image."
92
+ content_parts.append(types.Part(text=image_instruction))
93
 
94
  contents = [types.Content(role="user", parts=content_parts)]
95
 
 
104
  )
105
 
106
  saved_file_path = None
 
107
  for chunk in stream:
108
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
109
  continue
 
115
  full_file_name = f"{output_file_base}{file_extension}"
116
  save_binary_file(full_file_name, inline_data.data)
117
  saved_file_path = full_file_name
118
+
 
 
119
  if saved_file_path:
120
  print(f"βœ… Successfully generated and saved image: {saved_file_path}")
 
 
121
  else:
122
+ print("⚠️ No image was returned from the API.")
123
 
124
  return saved_file_path
125
 
 
127
  print(f"❌ API Invalid Argument Error: {e}")
128
  traceback.print_exc()
129
  return None
 
 
 
 
 
130
  except Exception as e:
131
+ print(f"❌ An unexpected error occurred during the Gemini API call: {e}")
132
  traceback.print_exc()
133
  return None
134
 
135
+ def generate_all_images_from_file(json_path: str, output_dir: str, output_json_path: str):
136
+ """Main loop to process a JSON file and generate images."""
 
 
 
137
  try:
138
  with open(json_path, 'r', encoding='utf-8') as f:
139
  multimedia_data = json.load(f)
140
  except (FileNotFoundError, json.JSONDecodeError) as e:
141
  print(f"❌ Error reading or parsing {json_path}: {e}")
142
+ return
143
 
144
  if not os.path.exists(output_dir):
145
  os.makedirs(output_dir)
 
150
  print(f"\n{'='*60}\nProcessing item {i+1}/{len(multimedia_data)}\n{'='*60}")
151
  image_prompt = item.get("image_prompt")
152
 
153
+ if not image_prompt:
 
154
  item["image_path"] = None
155
  continue
156
 
 
157
  file_base_path = os.path.join(output_dir, f"image_{i:03d}")
 
158
  saved_image_path = generate_image_with_gemini(
159
+ image_prompt, file_base_path, context_image=previous_image
 
 
160
  )
161
 
162
  item["image_path"] = saved_image_path
163
 
164
+ if saved_image_path:
165
  try:
166
  previous_image = Image.open(saved_image_path)
167
  successful_generations += 1
 
168
  except Exception as e:
 
169
  previous_image = None
170
  else:
 
171
  previous_image = None
172
+
173
  time.sleep(2)
174
 
175
+ with open(output_json_path, 'w', encoding='utf-8') as f:
176
+ json.dump(multimedia_data, f, indent=2, ensure_ascii=False)
177
+
178
+ print(f"\n--- βœ… Finished. Generated {successful_generations}/{len(multimedia_data)} images. ---")