# nano-banana-modular / nano_banana.py
# Uploaded by sayakpaul (HF Staff).
# Commit 9ae5cf2 (verified): "Create nano_banana.py"
# File size: 2.3 kB
from typing import List
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
OutputParam,
)
from PIL import Image
import google.generativeai as genai
import os
class NanoBanana(ModularPipelineBlocks):
    """Pipeline block that sends a prompt, and optionally an image, to a Gemini model.

    The block reads ``prompt`` and ``image`` from the pipeline state, calls
    ``generate_content`` on the configured Gemini model and writes the result
    back to the state as ``output_image``. The input image (or ``None``) is
    passed through as ``old_image``.
    """

    def __init__(self, model_id="gemini-2.5-flash-image-preview"):
        """Configure the Gemini client and create the model handle.

        Args:
            model_id: Name of the Gemini model passed to ``genai.GenerativeModel``.

        Raises:
            ValueError: If the ``GEMINI_API_KEY`` environment variable is not set.
        """
        super().__init__()
        api_key = os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError("Must provide an API key for Gemini through the `GEMINI_API_KEY` env variable.")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id)

    @property
    def expected_components(self):
        """Return the components this block needs; it declares none."""
        return []

    @property
    def inputs(self) -> List[InputParam]:
        """Return the user-facing inputs: an optional ``image`` and a required ``prompt``."""
        return [
            InputParam(
                "image",
                type_hint=Image.Image,
                required=False,
                description="Image to use",
            ),
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            ),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """Return the intermediate inputs; this block declares none."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Return the values this block writes to the state: ``output_image`` and ``old_image``."""
        return [
            OutputParam(
                "output_image",
                # NOTE(review): `__call__` stores the raw return value of
                # `generate_content` here, which is presumably not a `str` -
                # confirm the intended type before tightening this hint.
                type_hint=str,
                description="Output image",
            ),
            OutputParam(
                "old_image",
                # This is the `image` input passed through unchanged, so it
                # shares that input's declared type (or is None).
                type_hint=Image.Image,
                description="Old image (if) provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Run the Gemini request and record the results on the pipeline state.

        Args:
            components: Pipeline components; returned untouched.
            state: Pipeline state holding ``prompt`` and, optionally, ``image``.

        Returns:
            The ``(components, state)`` pair, with ``output_image`` and
            ``old_image`` set on the state. Note that the return annotation
            names only ``PipelineState`` although a pair is returned.
        """
        block_state = self.get_block_state(state)

        old_image = block_state.image
        # Inputs are exposed directly on the block state, like `image` above.
        prompt = block_state.prompt

        contents = [prompt]
        if old_image is not None:
            # `append`, not `extend`: the image is one content item and must
            # not be iterated.
            contents.append(old_image)

        output = self.model.generate_content(contents=contents)
        block_state.output_image = output
        # Pass the input image through; this is None when none was supplied.
        block_state.old_image = old_image

        self.set_block_state(state, block_state)
        return components, state