abhinav0231 commited on
Commit
a296a03
·
verified ·
1 Parent(s): 94047f1

Update image_generation.py

Browse files
Files changed (1) hide show
  1. image_generation.py +57 -64
image_generation.py CHANGED
@@ -3,13 +3,13 @@ import mimetypes
3
  import json
4
  import streamlit as st
5
  import google.generativeai as genai
 
6
  from typing import List, Dict, Optional
7
  from PIL import Image
8
  import io
9
- import time
10
  import traceback
 
11
 
12
- # Configure the client once with the API key
13
  try:
14
  api_key = st.secrets.get("GEMINI_API_KEY") or os.getenv("GEMINI_API_KEY")
15
  if api_key:
@@ -17,15 +17,11 @@ try:
17
  print("✅ Google AI client for Gemini configured successfully.")
18
  else:
19
  print("⚠️ Warning: GEMINI_API_KEY not found.")
20
- # Exit if running in a context where it should be mandatory
21
- # exit(1)
22
  except Exception as e:
23
  print(f"❌ Error configuring Google AI client: {e}")
24
- # exit(1)
25
 
26
- # --- Helper Function ---
27
  def save_binary_file(file_name: str, data: bytes):
28
- """Saves binary data to a file."""
29
  try:
30
  with open(file_name, "wb") as f:
31
  f.write(data)
@@ -33,72 +29,73 @@ def save_binary_file(file_name: str, data: bytes):
33
  except Exception as e:
34
  print(f"❌ Error saving file {file_name}: {e}")
35
 
36
- # --- IMAGE GENERATION FUNCTION ---
 
 
 
 
 
 
 
 
37
  def generate_image_with_gemini(
38
  prompt: str,
39
  output_file_base: str,
40
  context_image: Optional[Image.Image] = None
41
  ) -> Optional[str]:
42
  """
43
- Generates an image using the updated Gemini API syntax, optionally using a
44
- previous image as context.
45
  """
46
  print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")
47
-
48
  try:
49
- # Define the specific image generation model
50
- model = genai.GenerativeModel(model_name="gemini-1.5-flash") # Use a current, valid model
51
-
52
- # --- UPDATED CONTENT STRUCTURE ---
53
- # Build a simple list of content parts. The SDK handles the type.
54
  content_parts = []
55
-
56
  if context_image:
57
- system_prompt = """You are a master storyboard artist creating a visual story sequence.
58
- IMPORTANT: You MUST generate an image for every request.
59
- Create a visually consistent image that follows the art style and character design of the provided reference image. Maintain consistency in:
60
- - Character appearance and clothing
61
- - Art style and color palette
62
- - Lighting and atmosphere
63
- - Overall visual tone
64
- Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.
65
- Generate an image that illustrates the following scene:"""
66
  print(" -> Using previous image as context for consistent styling.")
67
- # Add the system prompt, context image, and user prompt
68
- content_parts = [system_prompt, context_image, f"CREATE AN IMAGE NOW: {prompt}"]
69
  else:
70
- system_prompt = """You are a master storyboard artist creating the opening scene of a visual story.
71
- IMPORTANT: You MUST generate an image for this request.
72
- Create a stunning, cinematic image in an epic fantasy digital painting style with:
73
- - Rich, detailed artwork
74
- - Dramatic lighting and atmosphere
75
- - High-quality digital painting aesthetic
76
- - Vivid colors and intricate details
77
- This is the first scene of the story. Generate an image that illustrates:"""
78
- # Add the system prompt and user prompt
79
- content_parts = [system_prompt, f"CREATE AN IMAGE NOW: {prompt}"]
80
-
81
- # --- UPDATED API CALL ---
82
- # The generate_content method now takes a simple list
83
- response = model.generate_content(content_parts, stream=True)
 
 
 
 
 
84
 
85
  saved_file_path = None
86
  text_responses = []
87
-
88
- for chunk in response:
89
- # The modern SDK provides image data directly in parts
90
- if chunk.parts and chunk.parts[0].file_data:
91
- file_data = chunk.parts[0].file_data
92
- data_buffer = file_data.data
93
- file_extension = mimetypes.guess_extension(file_data.mime_type) or ".jpg"
94
- full_file_name = f"{output_file_base}{file_extension}"
95
- save_binary_file(full_file_name, data_buffer)
96
- saved_file_path = full_file_name
97
- print(f"✅ Successfully generated and saved image: {full_file_name}")
98
-
99
- # Collect any text responses for debugging
100
- if chunk.text:
101
- text_responses.append(chunk.text)
102
 
103
  if not saved_file_path and text_responses:
104
  print(f"⚠️ No image generated. API returned text: {' '.join(text_responses)}")
@@ -110,16 +107,13 @@ This is the first scene of the story. Generate an image that illustrates:"""
110
  traceback.print_exc()
111
  return None
112
 
113
- # The rest of your script (generate_all_images_from_file) does not need changes
114
- # as it correctly calls the function above.
115
  def generate_all_images_from_file(
116
  json_path: str,
117
  output_dir: str = "generated_images",
118
  output_json_path: str = "multimedia_data_with_images.json"
119
  ) -> List[Dict[str, str]]:
120
- """
121
- Reads data from a JSON, generates images sequentially, and saves a new JSON with image paths.
122
- """
123
  try:
124
  with open(json_path, 'r', encoding='utf-8') as f:
125
  multimedia_data = json.load(f)
@@ -183,5 +177,4 @@ def generate_all_images_from_file(
183
  except Exception as e:
184
  print(f"❌ Error saving updated JSON: {e}")
185
 
186
- return multimedia_data
187
-
 
3
  import json
4
  import streamlit as st
5
  import google.generativeai as genai
6
+ from google.generativeai import types
7
  from typing import List, Dict, Optional
8
  from PIL import Image
9
  import io
 
10
  import traceback
11
+ import time
12
 
 
13
  try:
14
  api_key = st.secrets.get("GEMINI_API_KEY") or os.getenv("GEMINI_API_KEY")
15
  if api_key:
 
17
  print("✅ Google AI client for Gemini configured successfully.")
18
  else:
19
  print("⚠️ Warning: GEMINI_API_KEY not found.")
 
 
20
  except Exception as e:
21
  print(f"❌ Error configuring Google AI client: {e}")
 
22
 
23
+ # --- Helper Functions (Unchanged) ---
24
  def save_binary_file(file_name: str, data: bytes):
 
25
  try:
26
  with open(file_name, "wb") as f:
27
  f.write(data)
 
29
  except Exception as e:
30
  print(f"❌ Error saving file {file_name}: {e}")
31
 
32
+ def pil_image_to_part(image: Image.Image) -> types.Part:
33
+ img_byte_arr = io.BytesIO()
34
+ image.save(img_byte_arr, format='JPEG')
35
+ img_bytes = img_byte_arr.getvalue()
36
+ return types.Part.from_bytes(
37
+ data=img_bytes,
38
+ mime_type='image/jpeg'
39
+ )
40
+
41
  def generate_image_with_gemini(
42
  prompt: str,
43
  output_file_base: str,
44
  context_image: Optional[Image.Image] = None
45
  ) -> Optional[str]:
46
  """
47
+ Generates an image, using your original SDK logic and model.
 
48
  """
49
  print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")
 
50
  try:
51
+ # Your model and content structure are preserved
52
+ model_name = "gemini-2.5-flash-image-preview"
53
+ model = genai.GenerativeModel(model_name)
54
+
 
55
  content_parts = []
 
56
  if context_image:
57
+ system_prompt = """You are a master storyboard artist creating a visual story sequence... (rest of your prompt)"""
 
 
 
 
 
 
 
 
58
  print(" -> Using previous image as context for consistent styling.")
59
+ content_parts.append(types.Part.from_text(text=system_prompt))
60
+ content_parts.append(pil_image_to_part(context_image))
61
  else:
62
+ system_prompt = """You are a master storyboard artist creating the opening scene... (rest of your prompt)"""
63
+ content_parts.append(types.Part.from_text(text=system_prompt))
64
+
65
+ image_instruction = f"""CREATE AN IMAGE NOW:\n{prompt}\nRemember: You must generate a visual image..."""
66
+ content_parts.append(types.Part.from_text(text=image_instruction))
67
+
68
+ contents = [types.Content(role="user", parts=content_parts)]
69
+
70
+ generate_content_config = types.GenerateContentConfig(
71
+ response_modalities=["IMAGE", "TEXT"],
72
+ )
73
+
74
+ # --- CORRECTED API CALL ---
75
+ # The streaming call is made directly on the model object, not a separate client.
76
+ stream = model.generate_content(
77
+ contents=contents,
78
+ generation_config=generate_content_config, # Corrected parameter name
79
+ stream=True
80
+ )
81
 
82
  saved_file_path = None
83
  text_responses = []
84
+
85
+ # Your original response handling logic
86
+ for chunk in stream:
87
+ if (chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts):
88
+ part = chunk.candidates[0].content.parts[0]
89
+ if part.inline_data and part.inline_data.data:
90
+ inline_data = part.inline_data
91
+ data_buffer = inline_data.data
92
+ file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".jpg"
93
+ full_file_name = f"{output_file_base}{file_extension}"
94
+ save_binary_file(full_file_name, data_buffer)
95
+ saved_file_path = full_file_name
96
+ print(f"✅ Successfully generated and saved image: {full_file_name}")
97
+ elif hasattr(chunk, 'text') and chunk.text:
98
+ text_responses.append(chunk.text)
99
 
100
  if not saved_file_path and text_responses:
101
  print(f"⚠️ No image generated. API returned text: {' '.join(text_responses)}")
 
107
  traceback.print_exc()
108
  return None
109
 
110
+
 
111
  def generate_all_images_from_file(
112
  json_path: str,
113
  output_dir: str = "generated_images",
114
  output_json_path: str = "multimedia_data_with_images.json"
115
  ) -> List[Dict[str, str]]:
116
+
 
 
117
  try:
118
  with open(json_path, 'r', encoding='utf-8') as f:
119
  multimedia_data = json.load(f)
 
177
  except Exception as e:
178
  print(f"❌ Error saving updated JSON: {e}")
179
 
180
+ return multimedia_data