abhinav0231 committed on
Commit
177b6e5
·
verified ·
1 Parent(s): 6340efd

Update image_generation.py

Browse files
Files changed (1) hide show
  1. image_generation.py +14 -28
image_generation.py CHANGED
@@ -1,25 +1,19 @@
1
- # image_generation.py
2
-
3
  import os
4
  import mimetypes
5
  import json
6
- import base64
7
  import streamlit as st
8
  import io
9
  import time
10
  import traceback
11
  from PIL import Image
12
  from typing import List, Dict, Optional
13
-
14
- # CORRECT: Import from the modern 'google-genai' SDK
15
  from google import genai
16
- from google.genai import types
17
  from google.api_core import exceptions
18
 
19
  # --- Client Initialization ---
20
  client = None
21
  try:
22
- # Ensure the API key is set in Hugging Face Space secrets
23
  api_key = st.secrets.get("GEMINI_API_KEY")
24
  if api_key:
25
  client = genai.Client(api_key=api_key)
@@ -34,7 +28,6 @@ except Exception as e:
34
  st.stop()
35
 
36
  # --- Helper Functions ---
37
-
38
  def save_binary_file(file_name: str, data: bytes):
39
  """Saves binary data to a file."""
40
  try:
@@ -47,10 +40,9 @@ def save_binary_file(file_name: str, data: bytes):
47
  def pil_image_to_part(image: Image.Image) -> types.Part:
48
  """Converts a PIL Image to a genai.types.Part object."""
49
  img_byte_arr = io.BytesIO()
50
- # Save image to an in-memory byte stream
51
  image.save(img_byte_arr, format='JPEG')
52
  img_bytes = img_byte_arr.getvalue()
53
- # CORRECT: This works with 'google-genai'
54
  return types.Part.from_bytes(
55
  data=img_bytes,
56
  mime_type='image/jpeg'
@@ -62,8 +54,7 @@ def generate_image_with_gemini(
62
  context_image: Optional[Image.Image] = None
63
  ) -> Optional[str]:
64
  """
65
- Generates an image using the Gemini API, optionally with a context image.
66
- This function is now fully compatible with the 'google-genai' SDK.
67
  """
68
  if not client:
69
  print("❌ Gemini client not initialized.")
@@ -84,7 +75,7 @@ def generate_image_with_gemini(
84
  - Lighting and atmosphere
85
  Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
86
  Generate an image that illustrates the following scene:"""
87
- print(" -> Using previous image as context for consistent styling.")
88
  else:
89
  system_prompt = """You are a master storyboard artist creating the opening scene of a visual story.
90
  IMPORTANT: You MUST generate an image for this request. Create a stunning, cinematic image in an epic fantasy digital painting style with:
@@ -93,8 +84,8 @@ def generate_image_with_gemini(
93
  - High-quality digital painting aesthetic
94
  This is the first scene of the story. Generate an image that illustrates:"""
95
 
96
- # Build the request content
97
  content_parts.append(types.Part.from_text(system_prompt))
 
98
  if context_image:
99
  content_parts.append(pil_image_to_part(context_image))
100
 
@@ -103,12 +94,10 @@ def generate_image_with_gemini(
103
 
104
  contents = [types.Content(role="user", parts=content_parts)]
105
 
106
- # The model requires both IMAGE and TEXT modalities.
107
  generate_content_config = types.GenerateContentConfig(
108
  response_modalities=["IMAGE", "TEXT"],
109
  )
110
 
111
- # Generate content using the streaming API
112
  stream = client.models.generate_content_stream(
113
  model=model_name,
114
  contents=contents,
@@ -121,23 +110,18 @@ def generate_image_with_gemini(
121
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
122
  continue
123
 
124
- # Iterate through all parts in the chunk
125
  for part in chunk.candidates[0].content.parts:
126
  if part.inline_data and part.inline_data.data:
127
- # This part is an image
128
  inline_data = part.inline_data
129
  file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".jpg"
130
  full_file_name = f"{output_file_base}{file_extension}"
131
  save_binary_file(full_file_name, inline_data.data)
132
  saved_file_path = full_file_name
133
  elif part.text:
134
- # This part is text
135
  text_responses.append(part.text)
136
 
137
  if saved_file_path:
138
  print(f"βœ… Successfully generated and saved image: {saved_file_path}")
139
- if text_responses:
140
- print(f" -> Accompanied text: {''.join(text_responses)}")
141
  elif text_responses:
142
  print(f"⚠️ No image generated. API returned text only: {' '.join(text_responses)}")
143
  else:
@@ -147,11 +131,15 @@ def generate_image_with_gemini(
147
 
148
  except exceptions.InvalidArgument as e:
149
  print(f"❌ API Invalid Argument Error: {e}")
150
- print(" -> This often means the model or request parameters are wrong. Check model name and modalities.")
 
 
 
 
151
  traceback.print_exc()
152
  return None
153
  except Exception as e:
154
- print(f"❌ An unexpected error occurred during the Gemini API call: {e}")
155
  traceback.print_exc()
156
  return None
157
 
@@ -160,9 +148,6 @@ def generate_all_images_from_file(
160
  output_dir: str = "generated_images",
161
  output_json_path: str = "multimedia_data_with_images.json"
162
  ) -> List[Dict]:
163
- """
164
- Reads data from a JSON, generates images sequentially, and saves a new JSON with image paths.
165
- """
166
  try:
167
  with open(json_path, 'r', encoding='utf-8') as f:
168
  multimedia_data = json.load(f)
@@ -202,12 +187,13 @@ def generate_all_images_from_file(
202
  print(f"βœ… Loaded image {saved_image_path} as context for the next generation.")
203
  except Exception as e:
204
  print(f"⚠️ Could not load image {saved_image_path} for context. Error: {e}")
205
- previous_image = None # Don't use a corrupted image as context
206
  else:
207
  print("❌ No image was generated for this item. Context will be reset.")
208
  previous_image = None
209
 
210
- # Save the final JSON with all image paths
 
211
  try:
212
  with open(output_json_path, 'w', encoding='utf-8') as f:
213
  json.dump(multimedia_data, f, indent=2, ensure_ascii=False)
 
 
 
1
  import os
2
  import mimetypes
3
  import json
 
4
  import streamlit as st
5
  import io
6
  import time
7
  import traceback
8
  from PIL import Image
9
  from typing import List, Dict, Optional
 
 
10
  from google import genai
11
+ from google.generativeai import types
12
  from google.api_core import exceptions
13
 
14
  # --- Client Initialization ---
15
  client = None
16
  try:
 
17
  api_key = st.secrets.get("GEMINI_API_KEY")
18
  if api_key:
19
  client = genai.Client(api_key=api_key)
 
28
  st.stop()
29
 
30
  # --- Helper Functions ---
 
31
  def save_binary_file(file_name: str, data: bytes):
32
  """Saves binary data to a file."""
33
  try:
 
40
  def pil_image_to_part(image: Image.Image) -> types.Part:
41
  """Converts a PIL Image to a genai.types.Part object."""
42
  img_byte_arr = io.BytesIO()
 
43
  image.save(img_byte_arr, format='JPEG')
44
  img_bytes = img_byte_arr.getvalue()
45
+ # This call was correct and remains the same.
46
  return types.Part.from_bytes(
47
  data=img_bytes,
48
  mime_type='image/jpeg'
 
54
  context_image: Optional[Image.Image] = None
55
  ) -> Optional[str]:
56
  """
57
+ Generates an image using the Gemini API, now with all corrections.
 
58
  """
59
  if not client:
60
  print("❌ Gemini client not initialized.")
 
75
  - Lighting and atmosphere
76
  Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
77
  Generate an image that illustrates the following scene:"""
78
+ print(" -> Using previous image as context.")
79
  else:
80
  system_prompt = """You are a master storyboard artist creating the opening scene of a visual story.
81
  IMPORTANT: You MUST generate an image for this request. Create a stunning, cinematic image in an epic fantasy digital painting style with:
 
84
  - High-quality digital painting aesthetic
85
  This is the first scene of the story. Generate an image that illustrates:"""
86
 
 
87
  content_parts.append(types.Part.from_text(system_prompt))
88
+
89
  if context_image:
90
  content_parts.append(pil_image_to_part(context_image))
91
 
 
94
 
95
  contents = [types.Content(role="user", parts=content_parts)]
96
 
 
97
  generate_content_config = types.GenerateContentConfig(
98
  response_modalities=["IMAGE", "TEXT"],
99
  )
100
 
 
101
  stream = client.models.generate_content_stream(
102
  model=model_name,
103
  contents=contents,
 
110
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
111
  continue
112
 
 
113
  for part in chunk.candidates[0].content.parts:
114
  if part.inline_data and part.inline_data.data:
 
115
  inline_data = part.inline_data
116
  file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".jpg"
117
  full_file_name = f"{output_file_base}{file_extension}"
118
  save_binary_file(full_file_name, inline_data.data)
119
  saved_file_path = full_file_name
120
  elif part.text:
 
121
  text_responses.append(part.text)
122
 
123
  if saved_file_path:
124
  print(f"βœ… Successfully generated and saved image: {saved_file_path}")
 
 
125
  elif text_responses:
126
  print(f"⚠️ No image generated. API returned text only: {' '.join(text_responses)}")
127
  else:
 
131
 
132
  except exceptions.InvalidArgument as e:
133
  print(f"❌ API Invalid Argument Error: {e}")
134
+ traceback.print_exc()
135
+ return None
136
+ except TypeError as e:
137
+ print(f"❌ TypeError during API call: {e}")
138
+ print(" -> This often indicates an incorrect way of creating API objects like 'Part'. Please double-check SDK documentation.")
139
  traceback.print_exc()
140
  return None
141
  except Exception as e:
142
+ print(f"❌ An unexpected error occurred: {e}")
143
  traceback.print_exc()
144
  return None
145
 
 
148
  output_dir: str = "generated_images",
149
  output_json_path: str = "multimedia_data_with_images.json"
150
  ) -> List[Dict]:
 
 
 
151
  try:
152
  with open(json_path, 'r', encoding='utf-8') as f:
153
  multimedia_data = json.load(f)
 
187
  print(f"βœ… Loaded image {saved_image_path} as context for the next generation.")
188
  except Exception as e:
189
  print(f"⚠️ Could not load image {saved_image_path} for context. Error: {e}")
190
+ previous_image = None
191
  else:
192
  print("❌ No image was generated for this item. Context will be reset.")
193
  previous_image = None
194
 
195
+ time.sleep(2)
196
+
197
  try:
198
  with open(output_json_path, 'w', encoding='utf-8') as f:
199
  json.dump(multimedia_data, f, indent=2, ensure_ascii=False)