File size: 3,107 Bytes
9fca7c3
d12f2bb
15652b4
 
 
279af2f
15652b4
 
 
efbbf0f
b8c2f1b
15652b4
 
 
efbbf0f
 
 
 
 
 
 
15652b4
 
 
 
 
 
 
 
 
efbbf0f
 
 
 
19d53bf
15652b4
 
efbbf0f
15652b4
d12f2bb
9fca7c3
15652b4
efbbf0f
 
 
 
 
 
 
 
15652b4
 
 
d12f2bb
 
9fca7c3
15652b4
efbbf0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427cb66
d12f2bb
efbbf0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from google import genai
from PIL import Image
from io import BytesIO
import base64

# Configure the Gemini API client with the hardcoded API key
GOOGLE_API_KEY = "AIzaSyDL5Rilo7ptJpUOZdY6wy8PJYUcVcnDADs"
client = genai.Client(api_key=GOOGLE_API_KEY)
GEMINI_MODEL_NAME = 'gemini-2.5-flash-image-preview'

def process_image(image, prompt):
    try:
        # Prepare the content for the Gemini API
        contents = []
        if image:
            # Convert Gradio image (PIL Image) to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            contents.append({
                "parts": [
                    {"text": prompt},
                    {
                        "inline_data": {
                            "mime_type": "image/png",
                            "data": img_base64
                        }
                    }
                ]
            })
        else:
            # Text-to-image generation
            contents.append({"parts": [{"text": prompt}]})

        # Call the Gemini API
        response = client.models.generate_content(
            model=GEMINI_MODEL_NAME,
            contents=contents
        )

        # Process the response
        for candidate in response.candidates:
            for part in candidate.content.parts:
                if hasattr(part, 'inline_data') and part.inline_data:
                    # Decode the generated image
                    img_data = base64.b64decode(part.inline_data.data)
                    return Image.open(BytesIO(img_data))
                elif part.text:
                    return f"Text response: {part.text}"
        
        return "No image or text returned by the model."
    
    except Exception as e:
        return f"Error: {str(e)}"

# Create the Gradio interface
css = '''
.grid-container img {object-fit: contain}
.grid-container {display: grid; grid-template-columns: 1fr}
'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
    gr.HTML('''
    <img src='https://huggingface.co/spaces/multimodalart/nano-banana/resolve/main/nano_banana_pros_light.png' style='margin: 0 auto; max-width: 500px' />
    <h3 style='text-align:center'>Nano Banana: Gemini 2.5 Flash Image Preview</h3>
    ''')

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image (Optional)", file_types=["image"])
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="e.g., 'Add a nano-banana to the image in a fancy restaurant setting' or 'Generate a cat eating a nano-banana'"
            )
            generate_button = gr.Button("Generate", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image", type="pil")

    # Event handler
    generate_button.click(
        fn=process_image,
        inputs=[image_input, prompt_input],
        outputs=[output_image]
    )

if __name__ == "__main__":
    demo.launch()