# "Spaces: Sleeping" — Hugging Face Space status banner captured with the
# source during scraping; not part of the program.
# -----------------------
# Image Generation
# -----------------------
import base64
import io
import json
import logging
import os
import re
import subprocess
import tempfile
import time
import urllib.parse
import uuid
from io import BytesIO
from pathlib import Path

import cv2
import dataframe_image as dfi
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
import seaborn as sns
from google import genai
from google.genai import types
from PIL import Image, ImageDraw, ImageFont

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Inference API auth (used by the default "hf" model branch).
token = os.getenv('HF_API')
headers = {"Authorization": f"Bearer {token}"}
# NOTE(review): read here, but the generation code below uses os.getenv("GEMINI")
# for the Gemini key — confirm which environment variable is the intended one.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
def is_valid_png(file_path):
    """Return True if the file at `file_path` is a structurally valid PNG.

    Two checks are performed: the fixed 8-byte PNG magic signature, then
    a full PIL `verify()` pass over the file contents.

    Args:
        file_path (str): Path to the candidate PNG file.

    Returns:
        bool: True when both checks pass; False otherwise, including when
        the file is missing or unreadable.
    """
    try:
        with open(file_path, "rb") as f:
            # Every PNG begins with this exact 8-byte signature.
            if f.read(8) != b'\x89PNG\r\n\x1a\n':
                return False
        # Reopen with PIL and verify full file integrity (per-chunk
        # checks, not just the header).
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception as e:
        # Report through the logging module instead of bare print so the
        # message reaches the app's configured handlers.
        logging.getLogger(__name__).warning(
            "Invalid PNG file at %s: %s", file_path, e)
        return False
def standardize_and_validate_image(file_path):
    """Validate the image at `file_path`, convert it to RGB, and overwrite
    it in place as a PNG.

    Args:
        file_path (str): Path to the image to standardize.

    Returns:
        bool: True on success; False when the file is invalid, missing,
        or cannot be rewritten.
    """
    try:
        # First pass: verify() cheaply detects truncated/corrupt files.
        with Image.open(file_path) as img:
            img.verify()
        # verify() invalidates the handle, so reopen for the conversion.
        with Image.open(file_path) as img:
            rgb = img.convert("RGB")  # drop alpha channel if present
            # Encode into an in-memory buffer first so a failed encode
            # can never leave a truncated file on disk.
            buffer = io.BytesIO()
            rgb.save(buffer, format="PNG")
        # Only after a successful encode, overwrite the original file.
        with open(file_path, "wb") as f:
            f.write(buffer.getvalue())
        return True
    except Exception as e:
        # Log through the module logger machinery rather than bare print.
        logging.getLogger(__name__).warning(
            "Failed to standardize/validate %s: %s", file_path, e)
        return False
def generate_image(prompt_text, style, model="hf"):
    """
    Generate an image from a text prompt using one of the following:
      - Hugging Face's FLUX.1-dev ("hf", the default)
      - Pollinations Turbo ("pollinations_turbo")
      - Google's Gemini ("gemini")
      - Pexels API ("pexels" — a real stock photo instead of AI art)

    Args:
        prompt_text (str): The text prompt for image generation or search.
        style (str or None): The style of the image (used for HF and Gemini
            models, ignored for Pollinations and Pexels).
        model (str): Which backend to use
            ("hf", "pollinations_turbo", "gemini", or "pexels").

    Returns:
        tuple: (PIL.Image, base64 JPEG string) on success,
        (None, None) on an API error, or a grey 1024x1024 placeholder
        image with None on an unexpected exception.
    """
    try:
        if model == "pollinations_turbo":
            # URL-encode the prompt and request the "turbo" model explicitly.
            prompt_encoded = urllib.parse.quote(prompt_text)
            api_url = f"https://image.pollinations.ai/prompt/{prompt_encoded}?model=turbo"
            # Bounded timeout so a stalled server cannot hang the caller.
            response = requests.get(api_url, timeout=120)
            if response.status_code != 200:
                logger.error(f"Pollinations API error: {response.status_code}, {response.text}")
                print(f"Error from image generation API: {response.status_code}")
                return None, None
            image_bytes = response.content
        elif model == "gemini":
            # For Google's Gemini model
            try:
                g_api_key = os.getenv("GEMINI")
                if not g_api_key:
                    logger.error("GEMINI_API_KEY not found in environment variables")
                    print("Google Gemini API key is missing. Please set the GEMINI_API_KEY environment variable.")
                    return None, None
                # Initialize Gemini client
                client = genai.Client(api_key=g_api_key)
                # Enhance prompt with style
                enhanced_prompt = f"image of {prompt_text} in {style} style, high quality, detailed illustration"
                response = client.models.generate_content(
                    model="models/gemini-2.0-flash-exp",
                    contents=enhanced_prompt,
                    config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
                )
                # The response interleaves text and image parts; use the
                # first inline image found.
                for part in response.candidates[0].content.parts:
                    if part.inline_data is not None:
                        image = Image.open(BytesIO(part.inline_data.data))
                        # JPEG cannot encode alpha/palette modes; normalize
                        # so the re-encode below cannot fail on RGBA PNGs.
                        if image.mode != "RGB":
                            image = image.convert("RGB")
                        buffered = io.BytesIO()
                        image.save(buffered, format="JPEG")
                        img_str = base64.b64encode(buffered.getvalue()).decode()
                        return image, img_str
                # If no image was found in the response
                logger.error("No image was found in the Gemini API response")
                print("Gemini API didn't return an image")
                return None, None
            except ImportError:
                logger.error("Google Gemini libraries not installed")
                print("Google Gemini libraries not installed. Install with 'pip install google-genai'")
                return None, None
            except Exception as e:
                logger.error(f"Gemini API error: {str(e)}")
                print(f"Error from Gemini image generation: {str(e)}")
                return None, None
        elif model == "pexels":
            pexels_api_key = os.getenv("PEXELS_API_KEY")
            if not pexels_api_key:
                logger.error("PEXELS_API_KEY not found in environment variables")
                print("Pexels API key is missing. Please set the PEXELS_API_KEY environment variable.")
                return None, None
            headers_pexels = {
                "Authorization": pexels_api_key
            }
            # Ask Gemini to distill the prompt into short search keywords,
            # since Pexels is a keyword search engine, not a prompt engine.
            g_api_key = os.getenv("GEMINI")
            client = genai.Client(api_key=g_api_key)
            new_prompt = client.models.generate_content(model="models/gemini-2.0-flash-lite", contents=f"Generate a keywords or a phrase, (maximum 3 words) to search for an appropriate image on the pexels API based on this prompt: {prompt_text} return only the keywords and nothing else")
            pexel_prompt = str(new_prompt.text)
            # URL-encode the keywords — the model may emit spaces/newlines
            # which would otherwise produce a malformed query URL.
            search_query = urllib.parse.quote(pexel_prompt.strip())
            search_url = f"https://api.pexels.com/v1/search?query={search_query}&per_page=1"
            response = requests.get(search_url, headers=headers_pexels, timeout=60)
            if response.status_code != 200:
                logger.error(f"Pexels API error: {response.status_code}, {response.text} from {pexel_prompt}")
                print(f"Error from Pexels API: {response.status_code}")
                return None, None
            data = response.json()
            photos = data.get("photos", [])
            if not photos:
                logger.error("No photos found for the given prompt on Pexels")
                print("No photos found on Pexels for this prompt.")
                return None, None
            # Take the first photo; prefer the large2x rendition, fall
            # back to the original upload.
            photo = photos[0]
            image_url = photo["src"].get("large2x") or photo["src"].get("original")
            if not image_url:
                logger.error("No suitable image URL found in Pexels photo object")
                return None, None
            img_resp = requests.get(image_url, timeout=120)
            if img_resp.status_code != 200:
                logger.error(f"Failed to download Pexels image from {image_url}")
                return None, None
            image_bytes = img_resp.content
        else:
            # Default to Hugging Face model
            enhanced_prompt = f"{prompt_text} in {style} style, high quality, detailed illustration"
            model_id = "black-forest-labs/FLUX.1-dev"
            api_url = f"https://api-inference.huggingface.co/models/{model_id}"
            payload = {"inputs": enhanced_prompt}
            # Generation can be slow on cold models; allow a long timeout
            # but never an unbounded one.
            response = requests.post(api_url, headers=headers, json=payload, timeout=300)
            if response.status_code != 200:
                logger.error(f"Hugging Face API error: {response.status_code}, {response.text}")
                print(f"Error from image generation API: {response.status_code}")
                return None, None
            image_bytes = response.content
        # For HF, Pollinations, or Pexels that return raw image bytes
        # (the Gemini branch always returns above).
        if model != "gemini":
            image = Image.open(io.BytesIO(image_bytes))
            # Normalize to RGB: JPEG cannot encode RGBA/palette images,
            # which some backends return as PNG.
            if image.mode != "RGB":
                image = image.convert("RGB")
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            return image, img_str
    except Exception as e:
        print(f"Error generating image: {e}")
        logger.error(f"Image generation error: {str(e)}")
        # Return a placeholder image in case of failure so callers that
        # embed the result still render something.
        return Image.new('RGB', (1024, 1024), color=(200,200,200)), None
def generate_image_with_retry(prompt_text, style, model="hf", max_retries=3):
    """
    Attempt to generate an image via generate_image, retrying with
    exponential backoff on failure.

    generate_image traps most of its own errors and signals failure by
    returning (None, None), so this wrapper retries both on raised
    exceptions AND on a None image result (previously only exceptions
    triggered a retry, which in practice almost never fired).

    Args:
        prompt_text (str): The text prompt for image generation.
        style (str or None): The style of the image (ignored by some backends).
        model (str): Backend passed through to generate_image
            ("hf", "pollinations_turbo", "gemini", or "pexels").
        max_retries (int): Maximum number of attempts.

    Returns:
        tuple: (PIL.Image, base64 string) on success, or (None, None)
        if every attempt failed.
    """
    result = (None, None)
    for attempt in range(max_retries):
        try:
            if attempt > 0:
                # Exponential backoff: 2s, 4s, 8s, ...
                time.sleep(2 ** attempt)
            result = generate_image(prompt_text, style, model=model)
            if result[0] is not None:
                return result
            logger.error(f"Attempt {attempt+1} returned no image")
        except Exception as e:
            logger.error(f"Attempt {attempt+1} failed: {e}")
            if attempt == max_retries - 1:
                raise
    return result
# edit image function
def edit_section_image(image_url: str, gemini_prompt: str):
    """
    Download the existing image from `image_url`, use Google Gemini to edit
    it according to `gemini_prompt`, and return the edited PIL.Image.

    Args:
        image_url (str): URL of the image to edit.
        gemini_prompt (str): Natural-language edit instruction for Gemini.

    Returns:
        PIL.Image or None: The edited image, or None on any failure
        (download error, missing API key, or no image in the response).
    """
    try:
        # 1) Download the original image (bounded timeout so a stalled
        #    server cannot hang the request thread indefinitely).
        resp = requests.get(image_url, timeout=60)
        if resp.status_code != 200:
            logger.error(f"Failed to download image from {image_url}")
            return None
        original_image = Image.open(io.BytesIO(resp.content))
        # 2) Initialize the Gemini client
        g_api_key = os.getenv("GEMINI")
        if not g_api_key:
            logger.error("GEMINI_API_KEY not found in environment variables")
            print("Google Gemini API key is missing. Please set the GEMINI_API_KEY environment variable.")
            return None
        client = genai.Client(api_key=g_api_key)
        # 3) Gemini accepts multimodal content as a [text, image] list.
        prompt_with_image = [gemini_prompt, original_image]
        # 4) Ask the model to edit the image.
        response = client.models.generate_content(
            model="models/gemini-2.0-flash-exp",
            contents=prompt_with_image,
            config=types.GenerateContentConfig(
                response_modalities=['Text', 'Image']
            )
        )
        # 5) The response may interleave text and image parts; return the
        #    first inline image found.
        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                return Image.open(io.BytesIO(part.inline_data.data))
        logger.error("No edited image found in Gemini response")
        return None
    except Exception as e:
        logger.error(f"Error editing section image: {e}")
        return None