|
|
|
|
|
from typing import List |
|
|
from diffusers.modular_pipelines import ( |
|
|
PipelineState, |
|
|
ModularPipelineBlocks, |
|
|
InputParam, |
|
|
OutputParam, |
|
|
) |
|
|
from PIL import Image |
|
|
import google.generativeai as genai |
|
|
import os |
|
|
|
|
|
|
|
|
class NanoBanana(ModularPipelineBlocks):
    """Modular pipeline block that generates or edits an image with Gemini.

    The block sends ``prompt`` (plus the optional input ``image``) to a Gemini
    model through ``google.generativeai``. It stores the first image in the
    model's response as ``output_image``. The input image, if any, is passed
    through unchanged as ``old_image``.

    Requires the ``GEMINI_API_KEY`` environment variable to be set.
    """

    def __init__(self, model_id="gemini-2.5-flash-image-preview"):
        """Configure the Gemini client and create the generative model.

        Args:
            model_id: Name of the Gemini model to call.

        Raises:
            ValueError: If ``GEMINI_API_KEY`` is unset or empty.
        """
        super().__init__()

        api_key = os.getenv("GEMINI_API_KEY")
        # Treat an empty string the same as "unset" so the problem surfaces
        # here rather than at the first API call.
        if not api_key:
            raise ValueError("Must provide an API key for Gemini through the `GEMINI_API_KEY` env variable.")

        # NOTE(review): genai.configure is a module-level call, so this
        # presumably configures the SDK for the whole process -- verify if
        # multiple keys are ever needed.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id)

    @property
    def expected_components(self):
        # The work is done by a remote API; no local components are needed.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Image.Image,
                required=False,
                description="Image to use",
            ),
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            ),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        # Both outputs are PIL images (old_image may be None when no input
        # image was given); the previous `str` hints were incorrect.
        return [
            OutputParam(
                "output_image",
                type_hint=Image.Image,
                description="Output image",
            ),
            OutputParam(
                "old_image",
                type_hint=Image.Image,
                description="Old image (if) provided by the user",
            ),
        ]

    @staticmethod
    def _extract_image(response):
        """Return the first inline image in a Gemini response as a PIL image.

        Args:
            response: The object returned by ``GenerativeModel.generate_content``.

        Returns:
            A decoded ``PIL.Image.Image``.

        Raises:
            ValueError: If no candidate part carries inline image data. Any
                text parts the model returned are included in the message,
                since they usually explain why no image was produced.
        """
        import io

        texts = []
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            for part in getattr(content, "parts", None) or []:
                blob = getattr(part, "inline_data", None)
                data = getattr(blob, "data", None)
                if data:
                    image = Image.open(io.BytesIO(data))
                    # Decode eagerly so corrupt data fails here, not downstream.
                    image.load()
                    return image
                text = getattr(part, "text", None)
                if text:
                    texts.append(text)

        detail = " ".join(texts) or "no text returned"
        raise ValueError(f"Gemini response did not contain an image ({detail}).")

    def __call__(self, components, state: PipelineState) -> tuple:
        """Send the prompt (and optional image) to Gemini and record the result.

        Reads ``image`` and ``prompt`` from the block state. It then writes
        ``output_image`` (the generated PIL image) and ``old_image`` (the
        input image, or None) back into it.

        Returns:
            The ``(components, state)`` pair.

        Raises:
            ValueError: If the model response contains no image.
        """
        block_state = self.get_block_state(state)

        old_image = block_state.image
        prompt = block_state.prompt

        # Prompt first; the optional input image is appended as another part.
        contents = [prompt]
        if old_image is not None:
            contents.append(old_image)

        response = self.model.generate_content(contents=contents)

        block_state.output_image = self._extract_image(response)
        # Pass the input image through unchanged (None when not provided).
        block_state.old_image = old_image

        self.set_block_state(state, block_state)
        return components, state