Spaces:
Sleeping
Sleeping
| import base64 | |
| import os | |
| import gradio as gr | |
| from google import genai | |
| from google.genai import types | |
| from google.genai.types import HarmBlockThreshold | |
| from PIL import Image | |
| from io import BytesIO | |
| import tempfile | |
| from dotenv import load_dotenv | |
# Populate os.environ from a .env file if python-dotenv can locate one.
# override=False (the library default) keeps variables already set in the
# real process environment, so platform-provided secrets take precedence.
load_dotenv(override=False)
# Harm categories configured explicitly on every request (see _build_config).
_SAFETY_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_CIVIC_INTEGRITY",
)


def _get_api_key():
    """Return the API key from GEMINI_API_KEY, falling back to HF_TOKEN (None/empty if unset)."""
    # NOTE(review): HF_TOKEN normally holds a Hugging Face token, not a Gemini key; this
    # fallback only works if the deployment stores the Gemini key under that name.
    return os.environ.get("GEMINI_API_KEY") or os.environ.get("HF_TOKEN")


def _build_config():
    """Return the GenerateContentConfig used for every request (image + text output)."""
    # NOTE(review): every listed harm category is set to BLOCK_NONE; confirm this is
    # intended for an app that edits user-uploaded photos of people.
    return types.GenerateContentConfig(
        temperature=0.2,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category=category,
                threshold=HarmBlockThreshold.BLOCK_NONE,
            )
            for category in _SAFETY_CATEGORIES
        ],
        response_mime_type="text/plain",
    )


def _decode_inline_image(data):
    """Turn an inline_data payload (raw bytes or a base64 string) into a loaded PIL image."""
    raw = data if isinstance(data, bytes) else base64.b64decode(data)
    image = Image.open(BytesIO(raw))
    # Image.open is lazy; decode now so errors surface here rather than in the UI later.
    image.load()
    return image


def _extract_output(response):
    """Pull the generated image and user-facing text out of a generate_content response.

    Args:
        response: The object returned by client.models.generate_content.

    Returns:
        (image, text): the last successfully decoded inline image (or None) and the
        concatenated model text plus finish-reason / safety diagnostics.
    """
    candidates = getattr(response, "candidates", None) if response else None
    if not candidates:
        text = "The model did not generate a valid response. Please try again with a different image."
        feedback = getattr(response, "prompt_feedback", None)
        if feedback:
            text += f"\nPrompt feedback: {feedback.block_reason}"
        return None, text

    candidate = candidates[0]
    output_image = None
    pieces = []

    content = getattr(candidate, "content", None)
    # parts may be None on some responses; treat that as "no parts" instead of crashing.
    parts = (content.parts or []) if content else []
    for part in parts:
        if part.text is not None:
            pieces.append(part.text + "\n")
        elif part.inline_data is not None:
            try:
                output_image = _decode_inline_image(part.inline_data.data)
            except Exception as img_error:
                pieces.append(f"Error processing image: {str(img_error)}\n")
                pieces.append(f"Error type: {type(img_error).__name__}\n")

    # google-genai's Candidate exposes the stop reason as `finish_reason`
    # (there is no `finish_details` attribute).
    finish_reason = getattr(candidate, "finish_reason", None)
    if finish_reason:
        pieces.append(f"\nFinish reason: {finish_reason}\n")

    safety_ratings = getattr(candidate, "safety_ratings", None)
    if safety_ratings:
        pieces.append("\nSafety ratings:\n")
        pieces.extend(
            f"- {rating.category}: {rating.probability}\n" for rating in safety_ratings
        )

    if output_image is None:
        pieces.append("\nNo image was returned by the model.\n")

    return output_image, "".join(pieces)


def _delete_uploaded_file(client, uploaded):
    """Best-effort removal of the uploaded input from Gemini file storage; never raises."""
    if client is None or uploaded is None:
        return
    try:
        client.files.delete(name=uploaded.name)
    except Exception as cleanup_error:
        # A failed remote cleanup must not mask the generation result.
        print(f"Could not delete uploaded file {uploaded.name}: {cleanup_error}")


def generate_image(input_image, partner_type):
    """
    Generate an image with a girlfriend or boyfriend added to the input image.

    Args:
        input_image: The uploaded PIL image (None when nothing was uploaded).
        partner_type: Either "Girlfriend" or "Boyfriend".

    Returns:
        A (image, message) tuple: the generated PIL image (or None on failure) and
        any model text, diagnostics, or error description to show the user.
    """
    if input_image is None:
        return None, "Please upload an image first."
    if not partner_type:
        return None, "Please choose a partner type."

    api_key = _get_api_key()
    if not api_key:
        return None, "GEMINI_API_KEY not found in environment variables."

    # delete=False so the path stays valid after this handle closes; it is removed in
    # the finally block below on every exit path.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
        input_image_path = temp_file.name

    client = None
    uploaded = None
    try:
        client = genai.Client(api_key=api_key)
        # JPEG cannot store alpha or palette data, so normalise to RGB before saving.
        input_image.convert("RGB").save(input_image_path, format="JPEG")
        uploaded = client.files.upload(file=input_image_path)

        prompt = (
            f"add a {partner_type.lower()} beside the person in the picture. "
            "Do not change the environment or background."
        )
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=uploaded.uri,
                        mime_type=uploaded.mime_type,
                    ),
                    types.Part.from_text(text=prompt),
                ],
            ),
        ]
        generate_content_config = _build_config()

        try:
            response = client.models.generate_content(
                model="models/gemini-2.0-flash-exp",
                contents=contents,
                config=generate_content_config,
            )
            return _extract_output(response)
        except Exception as api_error:
            return None, f"API Error: {str(api_error)}\n\nDetails: {type(api_error).__name__}"
    except Exception as e:
        error_details = f"Error: {str(e)}\n\nType: {type(e).__name__}"
        print(f"Exception occurred: {error_details}")
        return None, error_details
    finally:
        if os.path.exists(input_image_path):
            os.unlink(input_image_path)
        _delete_uploaded_file(client, uploaded)
# Create the Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the app.

    Layout: three intro Markdown lines, then a row with an input column (photo,
    partner choice, Generate button) and an output column (result image, response
    text). The button is wired to generate_image.

    Returns:
        The constructed gr.Blocks app (not yet launched).
    """
    intro_lines = (
        "# Single No More",
        "Stop annoying family members pestering you for being single!",
        "Upload your photo and choose to add a girlfriend or boyfriend!",
    )

    with gr.Blocks(title="Single No More") as demo:
        for markdown_text in intro_lines:
            gr.Markdown(markdown_text)

        with gr.Row():
            # Input column.
            with gr.Column():
                photo = gr.Image(label="Your Photo", type="pil", image_mode="RGB")
                choice = gr.Radio(
                    ["Girlfriend", "Boyfriend"],
                    label="Choose Partner Type",
                    value="Girlfriend",
                )
                generate_button = gr.Button("Generate")
            # Output column.
            with gr.Column():
                result_image = gr.Image(label="Result", type="pil")
                result_text = gr.Textbox(label="Response", lines=3)

        generate_button.click(
            fn=generate_image,
            inputs=[photo, choice],
            outputs=[result_image, result_text],
        )

    return demo
if __name__ == "__main__":
    # Serve on every network interface (0.0.0.0) at port 7860.
    create_interface().launch(server_name="0.0.0.0", server_port=7860)