rairo committed on
Commit
a6314a0
·
verified ·
1 Parent(s): b47e518

Update image_gen.py

Browse files
Files changed (1) hide show
  1. image_gen.py +209 -0
image_gen.py CHANGED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -----------------------
2
+ # Image Generation
3
+ # -----------------------
4
+
5
+
6
+ import os
7
+ import re
8
+ import time
9
+ import tempfile
10
+ import requests
11
+ import json
12
+ from google import genai
13
+ from google.genai import types
14
+ import io
15
+ import base64
16
+ import numpy as np
17
+ import cv2
18
+ import logging
19
+ import uuid
20
+ import subprocess
21
+ from pathlib import Path
22
+ import urllib.parse
23
+ import pandas as pd
24
+ import plotly.graph_objects as go
25
+ import matplotlib.pyplot as plt
26
+ import base64
27
+ import os
28
+ import uuid
29
+ import matplotlib
30
+ import matplotlib.pyplot as plt
31
+ from io import BytesIO
32
+ import dataframe_image as dfi
33
+ import uuid
34
+ from PIL import ImageFont, ImageDraw, Image
35
+ import seaborn as sns
36
+
37
+
38
+ logging.basicConfig(level=logging.INFO)
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+
43
def is_valid_png(file_path):
    """Report whether the file at `file_path` is a readable, intact PNG.

    The file must start with the 8-byte PNG signature and must also pass
    Pillow's `verify()` integrity check. Any failure (missing file, wrong
    signature, corrupt data) is printed and reported as False.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        bool: True if both checks pass, False otherwise.
    """
    png_signature = b'\x89PNG\r\n\x1a\n'
    try:
        # Cheap rejection first: compare the leading 8 bytes to the signature.
        with open(file_path, "rb") as handle:
            if handle.read(8) != png_signature:
                return False

        # Full check: let Pillow parse the file and verify its integrity.
        with Image.open(file_path) as candidate:
            candidate.verify()
    except Exception as e:
        print(f"Invalid PNG file at {file_path}: {e}")
        return False
    return True
59
+
60
+
61
def standardize_and_validate_image(file_path):
    """Check the image at `file_path` and rewrite it in place as an RGB PNG.

    The image is verified, reopened, converted to RGB (dropping any alpha
    channel), encoded as PNG into memory, and only then written back over the
    original file. Any failure is printed and reported as False.

    Args:
        file_path: Path of the image to validate and overwrite.

    Returns:
        bool: True if the file was rewritten, False if any step failed.
    """
    try:
        # Integrity check on its own handle; the image is reopened afterwards
        # for the actual conversion.
        with Image.open(file_path) as probe:
            probe.verify()

        with Image.open(file_path) as source:
            rgb_image = source.convert("RGB")

        # Encode fully in memory first so the file is only touched once the
        # PNG data has been produced successfully.
        png_buffer = io.BytesIO()
        rgb_image.save(png_buffer, format="PNG")
        encoded = png_buffer.getvalue()

        with open(file_path, "wb") as destination:
            destination.write(encoded)
    except Exception as e:
        print(f"Failed to standardize/validate {file_path}: {e}")
        return False
    return True
85
+
86
# Upper bound (seconds) for a single HTTP image request so a stalled backend
# cannot block the caller forever.
_IMAGE_REQUEST_TIMEOUT = 120


def _image_to_jpeg_base64(image):
    """Encode a PIL image as JPEG and return the bytes as a Base64 string."""
    # JPEG cannot store alpha/palette modes (e.g. RGBA, P); convert a copy for
    # encoding only, leaving the caller's image untouched.
    if image.mode not in ("RGB", "L", "CMYK"):
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def _fetch_image_bytes(method, url, service_name, **request_kwargs):
    """Call an HTTP image endpoint; return the body bytes, or None on a non-200 reply."""
    response = requests.request(
        method, url, timeout=_IMAGE_REQUEST_TIMEOUT, **request_kwargs
    )
    if response.status_code != 200:
        logger.error(
            "%s API error: %s, %s", service_name, response.status_code, response.text
        )
        return None
    return response.content


def _generate_with_gemini(prompt_text, style):
    """Request an image from Gemini; return (image, base64) or (None, None)."""
    # "GEMINI" is the variable the original code read; "GEMINI_API_KEY" is the
    # name its messages referred to, so both are accepted.
    g_api_key = os.getenv("GEMINI") or os.getenv("GEMINI_API_KEY")
    if not g_api_key:
        logger.error(
            "Gemini API key not found; set the GEMINI (or GEMINI_API_KEY) "
            "environment variable"
        )
        return None, None

    client = genai.Client(api_key=g_api_key)
    enhanced_prompt = (
        f"image of {prompt_text} in {style} style, high quality, detailed illustration"
    )
    response = client.models.generate_content(
        model="models/gemini-2.0-flash-exp",
        contents=enhanced_prompt,
        config=types.GenerateContentConfig(response_modalities=['Text', 'Image']),
    )

    # The reply may mix text and image parts; use the first part carrying data.
    for part in response.candidates[0].content.parts:
        if part.inline_data is not None:
            image = Image.open(io.BytesIO(part.inline_data.data))
            return image, _image_to_jpeg_base64(image)

    logger.error("No image was found in the Gemini API response")
    return None, None


def generate_image(prompt_text, style, model="hf"):
    """
    Generate an image from a text prompt using Hugging Face, Pollinations Turbo,
    or Google's Gemini API.

    Errors are reported through the module logger; this function does not raise
    for backend failures.

    Args:
        prompt_text (str): The text prompt for image generation.
        style (str or None): The style of the image (used for HF and Gemini models).
        model (str): Which model to use ("hf" for Hugging Face,
            "pollinations_turbo" for Pollinations Turbo, or "gemini" for
            Google's Gemini). Any other value falls back to Hugging Face.

    Returns:
        tuple: (PIL.Image, Base64 JPEG string) on success.
            (None, None) when the backend reports an error or returns no image.
            (grey 1024x1024 placeholder image, None) when an unexpected
            exception occurs. A None second element always means failure.
    """
    try:
        if model == "gemini":
            try:
                return _generate_with_gemini(prompt_text, style)
            except Exception as e:
                logger.error("Gemini API error: %s", e)
                return None, None

        if model == "pollinations_turbo":
            # safe="" also escapes "/", which would otherwise split the prompt
            # into several URL path segments.
            prompt_encoded = urllib.parse.quote(prompt_text, safe="")
            api_url = f"https://image.pollinations.ai/prompt/{prompt_encoded}?model=turbo"
            image_bytes = _fetch_image_bytes("GET", api_url, "Pollinations")
        else:  # Default to Hugging Face model
            enhanced_prompt = (
                f"{prompt_text} in {style} style, high quality, detailed illustration"
            )
            model_id = "black-forest-labs/FLUX.1-dev"
            api_url = f"https://api-inference.huggingface.co/models/{model_id}"
            # NOTE(review): the original referenced an undefined `headers`
            # name. The token variable names below follow the Hugging Face
            # convention - confirm they match the deployment's secrets.
            hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
            if hf_token:
                headers = {"Authorization": f"Bearer {hf_token}"}
            else:
                logger.warning(
                    "No Hugging Face token found in HF_TOKEN or "
                    "HUGGINGFACEHUB_API_TOKEN; sending an unauthenticated request"
                )
                headers = {}
            image_bytes = _fetch_image_bytes(
                "POST",
                api_url,
                "Hugging Face",
                headers=headers,
                json={"inputs": enhanced_prompt},
            )

        if image_bytes is None:
            return None, None

        image = Image.open(io.BytesIO(image_bytes))
        return image, _image_to_jpeg_base64(image)

    except Exception as e:
        logger.exception("Image generation error: %s", e)

    # Return a placeholder image in case of failure
    return Image.new('RGB', (1024, 1024), color=(200, 200, 200)), None
188
+
189
def generate_image_with_retry(prompt_text, style, model="hf", max_retries=3):
    """
    Attempt to generate an image using generate_image, retrying up to max_retries times.

    generate_image signals most failures through its return value (a None
    Base64 string) rather than by raising, so an attempt counts as failed both
    when it raises and when it returns no Base64 data. Attempts after the first
    are preceded by an exponential back-off of 2 ** attempt seconds.

    Args:
        prompt_text (str): The text prompt for image generation.
        style (str or None): The style of the image (ignored for Pollinations Turbo).
        model (str): Which model to use ("hf", "pollinations_turbo" or "gemini").
        max_retries (int): Maximum number of attempts.

    Returns:
        tuple: The generated image and its Base64 string. If every attempt
            fails without raising, the result of the last attempt is returned
            (its Base64 element is None). (None, None) if max_retries <= 0.

    Raises:
        Exception: Re-raises whatever generate_image raised on the final attempt.
    """
    last_result = (None, None)
    for attempt in range(max_retries):
        if attempt > 0:
            time.sleep(2 ** attempt)
        try:
            last_result = generate_image(prompt_text, style, model=model)
        except Exception as e:
            logger.error("Attempt %d failed: %s", attempt + 1, e)
            if attempt == max_retries - 1:
                raise
            continue

        # A non-None Base64 payload is the only success signal generate_image gives.
        if last_result[1] is not None:
            return last_result
        logger.warning(
            "Attempt %d of %d produced no image data", attempt + 1, max_retries
        )
    return last_result