| | import json |
| | import base64 |
| | from io import BytesIO |
| |
|
| | from typing import List, Union |
| | from PIL import Image |
| | from urllib.parse import urlparse |
| | from diffusers.modular_pipelines.modular_pipeline_utils import ConfigSpec |
| | from huggingface_hub import InferenceClient |
| | from diffusers import ModularPipeline, ModularPipelineBlocks |
| | from diffusers.modular_pipelines import InputParam, OutputParam, PipelineState |
| | from diffusers.utils import logger |
| |
|
| |
|
# NOTE(review): not referenced anywhere in the visible code. Presumably the
# granularity that frame counts must be a multiple of -- confirm before relying on it.
| | FRAME_MULTIPLE = 4 |
| |
|
| |
|
def _encode_image_to_base64(img: Image.Image, max_size_mb: float = 3.5) -> str:
    """Encode a PIL image as a base64 data URI while respecting a size budget.

    The image is first serialized losslessly as PNG. If that payload is larger
    than ``max_size_mb`` it is re-encoded as JPEG (after a downscale when the
    PNG is more than twice the budget), stepping the JPEG quality down until
    the payload fits.

    Args:
        img: PIL Image object to encode.
        max_size_mb: Budget, in MiB, for the encoded image bytes. The budget is
            checked against the raw bytes *before* base64 expansion, so the
            returned string is roughly 4/3 of this size.

    Returns:
        str: A ``data:image/png;base64,...`` URI when the PNG fits the budget,
        otherwise a ``data:image/jpeg;base64,...`` URI. If even the lowest
        JPEG quality tried is still too large, that smallest attempt is
        returned as a best effort.

    Raises:
        ValueError: If ``max_size_mb`` is not positive.
    """
    if max_size_mb <= 0:
        # A non-positive budget previously failed later with an obscure error
        # (zero-sized resize or a complex scale factor); fail early instead.
        raise ValueError(f"max_size_mb must be positive, got {max_size_mb!r}")

    bytes_per_mb = 1024 * 1024

    buffer = BytesIO()
    img.save(buffer, format="PNG")
    png_bytes = buffer.getvalue()
    size_mb = len(png_bytes) / bytes_per_mb

    if size_mb <= max_size_mb:
        encoded = base64.b64encode(png_bytes).decode("utf-8")
        return f"data:image/png;base64,{encoded}"

    # JPEG cannot store alpha or palette data, so normalise anything that is
    # not already plain RGB or greyscale.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")

    if size_mb > max_size_mb * 2:
        # Byte size grows roughly with pixel count, so scale each side by the
        # square root of the size ratio. Clamp to 1px so extremely thin images
        # do not collapse to a zero dimension, which resize() rejects.
        scale = (max_size_mb / size_mb) ** 0.5
        new_size = (
            max(1, int(img.width * scale)),
            max(1, int(img.height * scale)),
        )
        img = img.resize(new_size, Image.Resampling.LANCZOS)

    # Start at the original quality setting and only degrade further when the
    # result is still over budget.
    jpeg_bytes = b""
    for quality in (85, 70, 55, 40):
        buffer = BytesIO()
        img.save(buffer, format="JPEG", quality=quality, optimize=True)
        jpeg_bytes = buffer.getvalue()
        if len(jpeg_bytes) / bytes_per_mb <= max_size_mb:
            break

    encoded = base64.b64encode(jpeg_bytes).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"
| |
|
| |
|
def image_to_uri(image: Union[str, Image.Image], max_size_mb: float = 3.5) -> str:
    """Turn an image reference into something usable as an image URL.

    Args:
        image: An http(s) URL, a path to a local image file, or a PIL Image.
        max_size_mb: Size budget in MB forwarded to the base64 encoder
            (default 3.5MB).

    Returns:
        str: The input unchanged when it is an http(s) URL with a host;
        otherwise a base64 data URI built from the image contents.
    """
    if not isinstance(image, Image.Image):
        url_parts = urlparse(image)
        is_remote = url_parts.scheme in ("http", "https") and bool(url_parts.netloc)
        if is_remote:
            # Remote images are passed through untouched.
            return image

        # Anything else is treated as a local file path.
        with Image.open(image) as opened:
            return _encode_image_to_base64(opened, max_size_mb)

    return _encode_image_to_base64(image, max_size_mb)
| |
|
| |
|
class ImageToMatrixGameAction(ModularPipelineBlocks):
    """Pipeline block that asks a vision-language model for a list of actions.

    The block sends the input image plus a text instruction to the chat model
    named by the ``model_id`` config through ``InferenceClient`` and stores the
    parsed reply in the ``actions`` output.
    """

    model_name = "MatrixGameWan"

    @property
    def inputs(self):
        """Values read from the pipeline state by this block."""
        return [
            InputParam("image"),
            InputParam("num_frames"),
            InputParam("prompt"),
        ]

    @property
    def intermediate_outputs(self):
        """Values written back to the pipeline state on success."""
        return [OutputParam("actions")]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        """Config entries used by this block (the chat model identifier)."""
        return [ConfigSpec("model_id", default="Qwen/Qwen2.5-VL-72B-Instruct")]

    @staticmethod
    def _parse_actions(content) -> List[str]:
        """Extract the JSON list of action names from the model's reply.

        Chat models often wrap the requested list in markdown fences or a
        short sentence, so the outermost ``[...]`` span is parsed rather than
        the whole reply.

        Raises:
            ValueError: If the reply has no text, contains no bracketed list,
                is not valid JSON, or is not a list of strings.
        """
        if not isinstance(content, str):
            raise ValueError("model reply contained no text content")

        start = content.find("[")
        end = content.rfind("]")
        if start == -1 or end < start:
            raise ValueError("model reply does not contain a JSON list")

        # json.JSONDecodeError is a ValueError subclass, so malformed JSON
        # surfaces through the same exception type.
        actions = json.loads(content[start : end + 1])
        if not isinstance(actions, list) or not all(
            isinstance(action, str) for action in actions
        ):
            raise ValueError("model reply is not a list of action names")
        return actions

    def __call__(self, components: ModularPipeline, state: PipelineState):
        """Query the model and store the suggested actions in ``state``.

        Args:
            components: Pipeline object; ``components.model_id`` selects the
                chat model.
            state: Pipeline state providing ``image``, ``num_frames`` and
                ``prompt``.

        Returns:
            The ``(components, state)`` pair. On success ``actions`` has been
            written to the state. On any failure while contacting the model
            or parsing its reply, a warning is logged and the state is
            returned without ``actions`` being set by this block.
        """
        client = InferenceClient()
        instructions = """
        You will be provided an image and you have to interpret how you would move inside it
        if the image was in 3D space.

        Here are the available actions you can take:

        Movement Actions: forward, left, right
        Camera Actions: camera_l, camera_r

        You can also combine actions with an _ to create compound actions. e.g. forward_left_camera_l
        Each action is rendered for 12 frames, so make sure the number of actions suggested fits into the total number of frames available: {num_frames}

        e.g ["forward", "forward_left", "camera_l"]

        Here are additional instructions for you to follow:
        {prompt}

        Only respond with the list of actions you have to take and nothing else.
        """
        block_state = self.get_block_state(state)

        image = block_state.image
        prompt = block_state.prompt or ""
        num_frames = block_state.num_frames
        instructions = instructions.format(prompt=prompt, num_frames=num_frames)

        # Only the image encoding, remote call and reply parsing are treated
        # as best-effort; state bookkeeping errors are allowed to propagate.
        try:
            user_message = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": image_to_uri(image)},
                        },
                        {"type": "text", "text": instructions},
                    ],
                }
            ]

            completion = client.chat.completions.create(
                model=components.model_id,
                messages=user_message,
                temperature=0.2,
                max_tokens=1000,
            )
            actions = self._parse_actions(completion.choices[0].message.content)
        except Exception as e:
            # NOTE(review): this block does not generate random actions itself;
            # presumably a downstream block does so when ``actions`` is
            # missing -- confirm against the consuming block.
            logger.warning(
                "Unable to generate actions. Defaulting to random actions (cause: %r)",
                e,
            )
            return components, state

        block_state.actions = actions
        self.set_block_state(state, block_state)

        return components, state
| |
|
| |
|