File size: 2,299 Bytes
9ae5cf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from typing import List
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
OutputParam,
)
from PIL import Image
import google.generativeai as genai
import os
class NanoBanana(ModularPipelineBlocks):
    """Modular pipeline block that generates images via Google's Gemini API.

    Reads a required ``prompt`` (and an optional ``image``) from the pipeline
    state, forwards them to a Gemini generative model, and writes the raw API
    response back to the state as ``output_image``. The user-supplied image, if
    any, is echoed back as ``old_image``.
    """

    def __init__(self, model_id="gemini-2.5-flash-image-preview"):
        """Configure the Gemini client.

        Args:
            model_id: Gemini model identifier to instantiate.

        Raises:
            ValueError: If the ``GEMINI_API_KEY`` environment variable is unset.
        """
        super().__init__()
        api_key = os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError("Must provide an API key for Gemini through the `GEMINI_API_KEY` env variable.")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id)

    @property
    def expected_components(self):
        # No diffusers components are needed; generation is delegated to the remote API.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Image.Image,
                required=False,
                description="Image to use",
            ),
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            ),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "output_image",
                type_hint=str,
                description="Output image",
            ),
            OutputParam(
                "old_image",
                type_hint=str,
                description="Old image (if) provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        old_image = block_state.image
        # FIX: was `block_state.state.prompt`, which raised AttributeError —
        # inputs are exposed directly on the block state, as `image` is above.
        prompt = block_state.prompt
        contents = [prompt]
        if old_image is not None:
            # FIX: was `contents.expand(old_image)`; lists have no `expand`
            # method, so any call with an input image raised AttributeError.
            contents.append(old_image)
        output = self.model.generate_content(contents=contents)
        block_state.output_image = output
        # Echo the user-provided image back (None when absent) — the redundant
        # if/else collapses to one assignment.
        block_state.old_image = old_image
        self.set_block_state(state, block_state)
        return components, state