abhinav0231 committed on
Commit
b61fa02
·
verified ·
1 Parent(s): bb2489e

Update image_generation.py

Browse files
Files changed (1) hide show
  1. image_generation.py +86 -27
image_generation.py CHANGED
@@ -4,6 +4,8 @@ import json
4
  import streamlit as st
5
  import google.generativeai as genai
6
  import google.api_core.exceptions
 
 
7
  from typing import List, Dict, Optional
8
  from PIL import Image
9
  import io
@@ -31,50 +33,107 @@ def save_binary_file(file_name: str, data: bytes):
31
  print(f"❌ Error saving file {file_name}: {e}")
32
 
33
  # --- IMAGE GENERATION FUNCTION ---
 
 
 
 
 
 
 
 
34
  def generate_image_with_gemini(
35
  prompt: str,
36
  output_file_base: str,
37
  context_image: Optional[Image.Image] = None
38
  ) -> Optional[str]:
39
- """
40
- Generates an image and now specifically handles ResourceExhausted errors.
41
- """
42
  print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")
 
43
  try:
44
- model = genai.GenerativeModel(model_name="gemini-2.0-flash-preview-image-generation")
45
 
46
- content_parts = []
47
  if context_image:
48
- content_parts.extend([prompt, context_image])
 
 
 
 
 
 
 
 
49
  else:
50
- content_parts.append(prompt)
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- response = model.generate_content(
53
- contents=content_parts,
54
- stream=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  )
56
 
57
  saved_file_path = None
58
- for chunk in response:
59
- if chunk.parts and chunk.parts[0].inline_data:
60
- part = chunk.parts[0]
61
- data = part.inline_data.data
62
- mime_type = part.inline_data.mime_type
63
- file_extension = mimetypes.guess_extension(mime_type) or ".jpg"
64
- full_file_name = f"{output_file_base}{file_extension}"
65
- save_binary_file(full_file_name, data)
66
- saved_file_path = full_file_name
67
- print(f"✅ Successfully generated and saved image: {full_file_name}")
68
- return saved_file_path
69
-
70
- return None # Return None if no image was generated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- except google.api_core.exceptions.ResourceExhausted as e:
73
- # --- CATCH and signal the rate limit error ---
74
- print(f"🔴 RATE LIMIT EXCEEDED. The script will wait and retry.")
75
- return "RATE_LIMIT_EXCEEDED"
76
  except Exception as e:
77
  print(f"❌ An error occurred during the Gemini API call: {e}")
 
78
  traceback.print_exc()
79
  return None
80
 
 
4
  import streamlit as st
5
  import google.generativeai as genai
6
  import google.api_core.exceptions
7
+ from google import genai
8
+ from google.genai import types
9
  from typing import List, Dict, Optional
10
  from PIL import Image
11
  import io
 
33
  print(f"❌ Error saving file {file_name}: {e}")
34
 
35
  # --- IMAGE GENERATION FUNCTION ---
36
def pil_image_to_part(image: Image.Image) -> types.Part:
    """Encode a PIL image as JPEG bytes and wrap them in a ``types.Part``.

    Args:
        image: The image to send to the model. It is not modified.

    Returns:
        A ``types.Part`` built with ``Part.from_bytes`` and MIME type
        ``image/jpeg``.
    """
    # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...); saving such
    # an image raises OSError. convert() returns a new image, so the caller's
    # object is left untouched.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return types.Part.from_bytes(
        data=buffer.getvalue(),
        mime_type="image/jpeg",
    )
43
+
44
  def generate_image_with_gemini(
45
  prompt: str,
46
  output_file_base: str,
47
  context_image: Optional[Image.Image] = None
48
  ) -> Optional[str]:
49
+ if not client:
50
+ return None
51
+
52
  print(f"--- 🎨 Generating image for prompt: '{prompt[:70]}...' ---")
53
+
54
  try:
55
+ model = "gemini-2.0-flash-preview-image-generation"
56
 
57
+ # Build contents
58
  if context_image:
59
+ system_prompt = (
60
+ "You are a master storyboard artist creating a visual story sequence.\n"
61
+ "IMPORTANT: You MUST generate an image for every request.\n"
62
+ "Create a visually consistent image that follows the art style and character design of the provided reference image.\n"
63
+ "Maintain consistency in character appearance, art style, color palette, and lighting.\n"
64
+ "Style: Cinematic, epic fantasy digital painting with rich details and dramatic lighting.\n"
65
+ "Generate an image that illustrates the following scene:"
66
+ )
67
+ print(" -> Using previous image as context for consistent styling.")
68
  else:
69
+ system_prompt = (
70
+ "You are a master storyboard artist creating the opening scene of a visual story.\n"
71
+ "IMPORTANT: You MUST generate an image for this request.\n"
72
+ "Create a stunning, cinematic image in an epic fantasy digital painting style with rich, detailed artwork and dramatic lighting.\n"
73
+ "This is the first scene of the story. Generate an image that illustrates:"
74
+ )
75
+
76
+ content_parts = [
77
+ types.Part.from_text(text=system_prompt)
78
+ ]
79
+ if context_image:
80
+ content_parts.append(pil_image_to_part(context_image))
81
 
82
+ image_instruction = f"""CREATE AN IMAGE NOW:
83
+
84
+ {prompt}
85
+
86
+ Remember: You must generate a visual image, not text. Create the artwork described above."""
87
+ content_parts.append(types.Part.from_text(text=image_instruction))
88
+
89
+ contents = [
90
+ types.Content(
91
+ role="user",
92
+ parts=content_parts,
93
+ )
94
+ ]
95
+
96
+ # CRITICAL: request both IMAGE and TEXT
97
+ generate_content_config = types.GenerateContentConfig(
98
+ response_modalities=["IMAGE", "TEXT"]
99
  )
100
 
101
  saved_file_path = None
102
+ text_responses = []
103
+
104
+ # Stream the response
105
+ for chunk in client.models.generate_content_stream(
106
+ model=model,
107
+ contents=contents,
108
+ config=generate_content_config,
109
+ ):
110
+ cand = getattr(chunk, "candidates", None)
111
+ if not cand or not cand.content or not cand.content.parts:
112
+ continue
113
+
114
+ for part in cand.content.parts:
115
+ # Image bytes
116
+ if getattr(part, "inline_data", None) and getattr(part.inline_data, "data", None):
117
+ inline_data = part.inline_data
118
+ data_buffer = inline_data.data
119
+ file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".jpg"
120
+ full_file_name = f"{output_file_base}{file_extension}"
121
+ save_binary_file(full_file_name, data_buffer)
122
+ saved_file_path = full_file_name
123
+ print(f"✅ Successfully generated and saved image: {full_file_name}")
124
+
125
+ # Text side-channel
126
+ if getattr(part, "text", None):
127
+ text_responses.append(part.text)
128
+
129
+ if not saved_file_path and text_responses:
130
+ print(f"⚠️ No image generated. API returned text: {' '.join(text_responses)}")
131
+
132
+ return saved_file_path
133
 
 
 
 
 
134
  except Exception as e:
135
  print(f"❌ An error occurred during the Gemini API call: {e}")
136
+ import traceback
137
  traceback.print_exc()
138
  return None
139