# sozo-api / image_gen.py
# Author: rairo
# Last commit: a180781 "Update image_gen.py" (verified)
# Size: 12.5 kB
# -----------------------
# Image Generation
# -----------------------
import os
import re
import time
import tempfile
import requests
import json
from google import genai
from google.genai import types
import io
import base64
import numpy as np
import cv2
import logging
import uuid
import subprocess
from pathlib import Path
import urllib.parse
import pandas as pd
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import base64
import os
import uuid
import matplotlib
import matplotlib.pyplot as plt
from io import BytesIO
import dataframe_image as dfi
import uuid
from PIL import ImageFont, ImageDraw, Image
import seaborn as sns
# Configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Credentials are read from the environment once, when the module is imported.
token = os.environ.get('HF_API')  # Hugging Face Inference API token
headers = dict(Authorization=f"Bearer {token}")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
def is_valid_png(file_path):
    """Return True if `file_path` holds a well-formed PNG image.

    The first 8 bytes must equal the PNG signature; only then is the whole
    file handed to Pillow's `verify()`. Any error (missing file, truncated
    data, Pillow failure) is printed and reported as False.
    """
    png_signature = b'\x89PNG\r\n\x1a\n'
    try:
        with open(file_path, "rb") as handle:
            # Cheap signature check before paying for a full decode check.
            if handle.read(len(png_signature)) != png_signature:
                return False
        with Image.open(file_path) as candidate:
            candidate.verify()
    except Exception as exc:
        print(f"Invalid PNG file at {file_path}: {exc}")
        return False
    return True
def standardize_and_validate_image(file_path):
    """Verify the image at `file_path`, re-encode it as an RGB PNG, and overwrite it.

    Returns True on success. On any failure the error is printed and False is
    returned; the file is only rewritten after decoding, conversion and PNG
    encoding have all succeeded.
    """
    try:
        # Pillow requires reopening the file after verify(), hence two opens.
        with Image.open(file_path) as probe:
            probe.verify()
        with Image.open(file_path) as source:
            rgb_image = source.convert("RGB")  # drops any alpha channel
            encoded = io.BytesIO()
            rgb_image.save(encoded, format="PNG")
        with open(file_path, "wb") as out:
            out.write(encoded.getvalue())
    except Exception as exc:
        print(f"Failed to standardize/validate {file_path}: {exc}")
        return False
    return True
# Upper bound, in seconds, for any single outbound HTTP call made below.
# Without a timeout `requests` can block forever on a stalled connection;
# the value is generous because hosted image generation can be slow.
_HTTP_TIMEOUT_SECONDS = 120


def _styled_prompt(prompt_text, style):
    """Return `prompt_text` plus " in <style> style", or unchanged when `style` is empty/None."""
    return f"{prompt_text} in {style} style" if style else prompt_text


def _encode_jpeg_base64(image):
    """Return `image` encoded as JPEG, as base64 text.

    JPEG cannot store alpha or palette data, so images in other modes are
    converted to RGB for encoding only; the caller's image is not modified.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def _fetch_pollinations_bytes(prompt_text):
    """Return raw image bytes from Pollinations' "turbo" model, or None on an HTTP error."""
    # safe="" so a "/" in the prompt is escaped instead of creating a new path segment.
    prompt_encoded = urllib.parse.quote(prompt_text, safe="")
    api_url = f"https://image.pollinations.ai/prompt/{prompt_encoded}?model=turbo"
    response = requests.get(api_url, timeout=_HTTP_TIMEOUT_SECONDS)
    if response.status_code != 200:
        logger.error(f"Pollinations API error: {response.status_code}, {response.text}")
        print(f"Error from image generation API: {response.status_code}")
        return None
    return response.content


def _generate_gemini_image(prompt_text, style):
    """Generate an image with Gemini; return (PIL.Image, base64) or (None, None).

    Every failure in this backend, including exceptions from the SDK, is
    logged and reported as (None, None) rather than propagated.
    """
    try:
        g_api_key = os.getenv("GEMINI")
        if not g_api_key:
            logger.error("GEMINI environment variable not set")
            print("Google Gemini API key is missing. Please set the GEMINI environment variable.")
            return None, None
        client = genai.Client(api_key=g_api_key)
        enhanced_prompt = f"image of {_styled_prompt(prompt_text, style)}, high quality, detailed illustration"
        response = client.models.generate_content(
            model="models/gemini-2.0-flash-exp",
            contents=enhanced_prompt,
            config=types.GenerateContentConfig(response_modalities=['Text', 'Image']),
        )
        # Guard against responses with no candidates or no content (e.g. blocked output).
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []
        # The response may mix text and image parts; return the first image.
        for part in parts:
            if part.inline_data is not None:
                image = Image.open(io.BytesIO(part.inline_data.data))
                return image, _encode_jpeg_base64(image)
        logger.error("No image was found in the Gemini API response")
        print("Gemini API didn't return an image")
        return None, None
    except Exception as e:
        logger.error(f"Gemini API error: {str(e)}")
        print(f"Error from Gemini image generation: {str(e)}")
        return None, None


def _fetch_pexels_bytes(prompt_text):
    """Return bytes of the top Pexels search hit for `prompt_text`, or None on failure.

    Gemini first condenses the prompt into a short search phrase. Exceptions
    from that call or from the HTTP requests propagate to the caller.
    """
    pexels_api_key = os.getenv("PEXELS_API_KEY")
    if not pexels_api_key:
        logger.error("PEXELS_API_KEY not found in environment variables")
        print("Pexels API key is missing. Please set the PEXELS_API_KEY environment variable.")
        return None
    client = genai.Client(api_key=os.getenv("GEMINI"))
    keyword_response = client.models.generate_content(
        model="models/gemini-2.0-flash-lite",
        contents=f"Generate a keywords or a phrase, (maximum 3 words) to search for an appropriate image on the pexels API based on this prompt: {prompt_text} return only the keywords and nothing else",
    )
    # The model text may be None or padded with whitespace/newlines; fall back to
    # the original prompt rather than searching for "None" or an empty query.
    pexel_prompt = (keyword_response.text or "").strip() or prompt_text
    # Pass the query via `params` so requests URL-encodes spaces, "&", "#", etc.
    response = requests.get(
        "https://api.pexels.com/v1/search",
        headers={"Authorization": pexels_api_key},
        params={"query": pexel_prompt, "per_page": 1},
        timeout=_HTTP_TIMEOUT_SECONDS,
    )
    if response.status_code != 200:
        logger.error(f"Pexels API error: {response.status_code}, {response.text} from {pexel_prompt}")
        print(f"Error from Pexels API: {response.status_code}")
        return None
    photos = response.json().get("photos", [])
    if not photos:
        logger.error("No photos found for the given prompt on Pexels")
        print("No photos found on Pexels for this prompt.")
        return None
    # Prefer the "large2x" rendition of the first hit; fall back to the original upload.
    sources = photos[0].get("src") or {}
    image_url = sources.get("large2x") or sources.get("original")
    if not image_url:
        logger.error("No suitable image URL found in Pexels photo object")
        return None
    img_resp = requests.get(image_url, timeout=_HTTP_TIMEOUT_SECONDS)
    if img_resp.status_code != 200:
        logger.error(f"Failed to download Pexels image from {image_url}")
        return None
    return img_resp.content


def _fetch_hf_bytes(prompt_text, style):
    """Return raw image bytes from Hugging Face's FLUX.1-dev endpoint, or None on an HTTP error.

    Authenticates with the module-level `headers` (Bearer token read from HF_API).
    """
    enhanced_prompt = f"{_styled_prompt(prompt_text, style)}, high quality, detailed illustration"
    model_id = "black-forest-labs/FLUX.1-dev"
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": enhanced_prompt},
        timeout=_HTTP_TIMEOUT_SECONDS,
    )
    if response.status_code != 200:
        logger.error(f"Hugging Face API error: {response.status_code}, {response.text}")
        print(f"Error from image generation API: {response.status_code}")
        return None
    return response.content


def generate_image(prompt_text, style, model="hf"):
    """
    Generate an image from a text prompt using one of the following:
    - Hugging Face's FLUX.1-dev (default, also used for unrecognized `model` values)
    - Pollinations Turbo
    - Google's Gemini
    - Pexels API (for a real photo instead of AI; Gemini shortens the prompt
      into search keywords first)

    Args:
        prompt_text (str): The text prompt for image generation or search.
        style (str or None): Style hint appended to the prompt for the HF and
            Gemini backends; omitted when empty/None. Ignored by Pollinations and Pexels.
        model (str): Which backend to use
            ("hf", "pollinations_turbo", "gemini", or "pexels").

    Returns:
        tuple:
            - (PIL.Image, base64_jpeg_string) on success;
            - (None, None) when a backend reports a failure (missing API key,
              non-200 response, no result);
            - (grey 1024x1024 placeholder PIL.Image, None) when an unexpected
              exception escapes a backend.
        In every failure case the second element is None.
    """
    try:
        if model == "gemini":
            # The Gemini backend yields a decoded image directly and handles its own errors.
            return _generate_gemini_image(prompt_text, style)
        if model == "pollinations_turbo":
            image_bytes = _fetch_pollinations_bytes(prompt_text)
        elif model == "pexels":
            image_bytes = _fetch_pexels_bytes(prompt_text)
        else:
            image_bytes = _fetch_hf_bytes(prompt_text, style)
        if image_bytes is None:
            return None, None
        image = Image.open(io.BytesIO(image_bytes))
        return image, _encode_jpeg_base64(image)
    except Exception as e:
        print(f"Error generating image: {e}")
        logger.error(f"Image generation error: {str(e)}")
        # Return a placeholder image in case of failure
        return Image.new('RGB', (1024, 1024), color=(200, 200, 200)), None
def generate_image_with_retry(prompt_text, style, model="hf", max_retries=3):
    """
    Call generate_image up to `max_retries` times, backing off between attempts.

    generate_image reports failure by returning a None base64 string (either
    (None, None) or a placeholder image) instead of raising, so an attempt
    counts as failed whenever the second tuple element is None. Before every
    retry the function sleeps 2**attempt seconds (2 s, 4 s, ...).

    Args:
        prompt_text (str): The text prompt for image generation or search.
        style (str or None): Style hint forwarded to generate_image.
        model (str): Backend forwarded to generate_image
            ("hf", "pollinations_turbo", "gemini", or "pexels").
        max_retries (int): Maximum number of attempts.

    Returns:
        tuple: The first successful (image, base64) pair; otherwise the last
        attempt's result, or (None, None) if no attempt was made.

    Raises:
        Exception: Re-raised if generate_image raises on the final attempt.
    """
    result = (None, None)
    for attempt in range(max_retries):
        if attempt > 0:
            time.sleep(2 ** attempt)
        try:
            result = generate_image(prompt_text, style, model=model)
        except Exception as e:
            logger.error(f"Attempt {attempt+1} failed: {e}")
            if attempt == max_retries - 1:
                raise
            continue
        if result[1] is not None:
            return result
        logger.warning(f"Attempt {attempt+1} returned no image")
    return result
# edit image function
def edit_section_image(image_url: str, gemini_prompt: str):
    """
    Download the image at `image_url` and have Google Gemini edit it per `gemini_prompt`.

    Args:
        image_url: URL of the existing image, fetched with an HTTP GET.
        gemini_prompt: Editing instruction sent to Gemini together with the image.

    Returns:
        PIL.Image.Image or None: The edited image. None is returned (after
        logging) when the GEMINI key is missing, the download fails, Gemini's
        response contains no image part, or any exception is raised.
    """
    try:
        # Check the credential first so a misconfiguration doesn't cost a download.
        g_api_key = os.getenv("GEMINI")
        if not g_api_key:
            logger.error("GEMINI environment variable not set")
            print("Google Gemini API key is missing. Please set the GEMINI environment variable.")
            return None

        # 1) Download the original image; the timeout keeps a stalled server from hanging us.
        resp = requests.get(image_url, timeout=60)
        if resp.status_code != 200:
            logger.error(f"Failed to download image from {image_url}: HTTP {resp.status_code}")
            return None
        original_image = Image.open(io.BytesIO(resp.content))

        # 2) Send [instruction, image] and request both text and image modalities back.
        client = genai.Client(api_key=g_api_key)
        response = client.models.generate_content(
            model="models/gemini-2.0-flash-exp",
            contents=[gemini_prompt, original_image],
            config=types.GenerateContentConfig(
                response_modalities=['Text', 'Image']
            ),
        )

        # 3) Guard against missing candidates/content, then keep the first image part
        #    (the response may also contain text parts).
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        for part in (content.parts if content is not None else None) or []:
            if part.inline_data is not None:
                return Image.open(io.BytesIO(part.inline_data.data))

        logger.error("No edited image found in Gemini response")
        return None
    except Exception as e:
        logger.error(f"Error editing section image: {e}")
        return None