| import logging |
| import os |
| from dataclasses import replace |
|
|
try:
    import spaces
except ImportError:

    class spaces:
        """Stand-in used when the ``spaces`` package cannot be imported.

        Only ``GPU`` is provided, and it is a no-op: the decorated function
        is returned unchanged and every decorator argument is ignored.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """Return the decorated function untouched.

            Works both as ``@spaces.GPU`` (bare) and as
            ``@spaces.GPU(duration=..., size=...)`` (called with options).
            """
            # Bare usage: the function to decorate arrives as the single
            # positional argument, so hand it straight back.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(func):
                return func

            return decorator
|
|
| import gradio as gr |
| from PIL import Image |
|
|
| from ume_pipeline.config import PipelineConfig |
| from ume_pipeline.pipeline import UnifiedMultimodalEditor |
|
|
# Root logging: INFO and above, rendered as "<LEVEL> <logger name>: <message>".
logging.basicConfig(
    format="%(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
# Logger named after this module; used for all messages emitted by this file.
LOGGER = logging.getLogger(__name__)
|
|
# Maximum number of image versions kept per session (UME_MAX_HISTORY,
# default "4"); values below 2 are raised to 2.
MAX_HISTORY = max(2, int(os.environ.get("UME_MAX_HISTORY", "4")))

# Value handed to the spaces.GPU decorator as `size` (UME_ZERO_GPU_SIZE,
# default "large").
ZERO_GPU_SIZE = os.environ.get("UME_ZERO_GPU_SIZE", "large")

# Value handed to the spaces.GPU decorator as `duration`. Lookup order:
# UME_ZERO_GPU_DURATION, then UME_ZERO_GPU_EDIT_DURATION, then "240".
ZERO_GPU_DURATION = int(
    os.environ.get(
        "UME_ZERO_GPU_DURATION",
        os.environ.get("UME_ZERO_GPU_EDIT_DURATION", "240"),
    )
)
|
|
| |
| |
LOGGER.info("Initializing UnifiedMultimodalEditor...")

# Start from the environment-derived settings, then force these two options
# off regardless of what the environment requested.
config = replace(
    PipelineConfig.from_env(),
    enable_cpu_offload=False,
    compile_flux_transformer=False,
)

# Single module-level editor instance, constructed at import time.
editor = UnifiedMultimodalEditor(config)
|
|
|
|
def get_editor():
    """Return the module-level ``editor`` instance created at import time."""
    shared_editor = editor
    return shared_editor
|
|
|
|
def message_text(value) -> str:
    """Extract plain text from a chat payload of unknown shape.

    Resolution rules, applied in order:

    * ``None`` -> ``""``.
    * ``str`` -> returned unchanged.
    * ``dict`` -> the first of the keys ``"content"``, ``"text"``,
      ``"value"`` that is present is resolved recursively.
    * ``list`` / ``tuple`` -> items are scanned from last to first and the
      first non-empty text wins; a sequence containing no text yields ``""``.
    * any object exposing a ``content`` attribute -> that attribute is
      resolved recursively.
    * anything else -> ``str(value)``.

    Args:
        value: The payload to convert.

    Returns:
        The extracted text, or ``""`` when there is none.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        for key in ("content", "text", "value"):
            if key in value:
                return message_text(value[key])
    if isinstance(value, (list, tuple)):
        for item in reversed(value):
            text = message_text(item)
            if text:
                return text
        # No item carried text. Returning "" here (instead of falling through
        # to str(value)) keeps the container's repr such as "[]" or "[None]"
        # from being shown as message text.
        return ""
    if hasattr(value, "content"):
        return message_text(value.content)
    return str(value)
|
|
|
|
| def message_role(value) -> str | None: |
| if isinstance(value, dict): |
| return value.get("role") |
| if hasattr(value, "role"): |
| return getattr(value, "role") |
| return None |
|
|
|
|
def normalize_chat_history(history) -> list[dict[str, str]]:
    """Convert a chat history into a flat list of role/content dicts.

    Two entry shapes are understood:

    * entries whose role (see ``message_role``) is ``"user"`` or
      ``"assistant"`` are kept as one message, even if their text is empty;
    * list/tuple entries with at least two items are read as a
      ``(user, assistant)`` pair, and each side is kept only when its text is
      non-empty.

    Every other entry is dropped. A falsy ``history`` yields an empty list.
    """
    messages = []
    for entry in history or []:
        role = message_role(entry)
        if role in {"user", "assistant"}:
            messages.append({"role": role, "content": message_text(entry)})
            continue
        if not isinstance(entry, (list, tuple)) or len(entry) < 2:
            continue
        # Pair-style entry: position 0 is the user turn, position 1 the reply.
        for pair_role, raw in (("user", entry[0]), ("assistant", entry[1])):
            content = message_text(raw)
            if content:
                messages.append({"role": pair_role, "content": content})
    return messages
|
|
|
|
def is_edit_instruction(text) -> bool:
    """Tell whether *text* opens with one of the edit command words.

    The message is reduced to text with ``message_text``, lower-cased, and
    its first whitespace-delimited word is compared against the command
    words. Matching whole words (rather than raw prefixes) keeps messages
    such as "address on the sign?" or "makeup looks odd" from being treated
    as "add" / "make" commands.

    Args:
        text: The user message in any shape ``message_text`` accepts.

    Returns:
        ``True`` if the first word is an edit command, ``False`` otherwise
        (including for empty messages).
    """
    edit_keywords = {"change", "make", "add", "remove", "turn", "replace"}
    words = message_text(text).lower().split(maxsplit=1)
    if not words:
        return False
    # Drop punctuation glued to the command word ("add: a hat", "remove, ...").
    first_word = words[0].rstrip(".,:;!?")
    return first_word in edit_keywords
|
|
|
|
def normalize_image(image: Image.Image) -> Image.Image:
    """Return an RGB copy of *image* resized to the configured dimensions.

    The target size is ``(config.width, config.height)`` and resizing uses
    LANCZOS resampling.
    """
    rgb_image = image.convert("RGB")
    target_size = (config.width, config.height)
    return rgb_image.resize(target_size, Image.Resampling.LANCZOS)
|
|
|
|
def trim_history(
    history: "list[Image.Image]",
    max_history: "int | None" = None,
) -> "list[Image.Image]":
    """Cap *history* at a maximum length while keeping the first entry.

    When the list is longer than the limit, the result holds the first entry
    followed by the most recent ``limit - 1`` entries.

    Args:
        history: Image versions, oldest first.
        max_history: Maximum number of entries to keep. Defaults to the
            module-level ``MAX_HISTORY`` when omitted.

    Returns:
        ``history`` itself when it already fits, otherwise a new trimmed list.

    Raises:
        ValueError: If the effective limit is smaller than 1.
    """
    limit = MAX_HISTORY if max_history is None else max_history
    if limit < 1:
        raise ValueError("max_history must be at least 1")
    if len(history) <= limit:
        return history
    # A limit of 1 leaves room for the first entry only; slicing with
    # history[-0:] would return the whole list instead of nothing.
    newest = history[-(limit - 1):] if limit > 1 else []
    return [history[0], *newest]
|
|
|
|
def _report_progress(progress, fraction, desc):
    """Forward a progress update when a tracker was supplied; else do nothing."""
    # Compare against None explicitly: whether a tracker object is truthy is
    # unrelated to whether the caller supplied one.
    if progress is not None:
        progress(fraction, desc=desc)


@spaces.GPU(duration=ZERO_GPU_DURATION, size=ZERO_GPU_SIZE)
def chat_interface(message, image_state: dict, progress: gr.Progress | None = None):
    """Handle one user message as either an image edit or a chat turn.

    This is a generator. Each yielded item is a ``(text, image_state)``
    tuple: an interim status line first, then the final answer or an error
    message.

    Args:
        message: The user message in any shape ``message_text`` accepts.
        image_state: Dict whose ``"history"`` entry lists the image versions,
            oldest first. It is never modified in place; after a successful
            edit a new dict with the extended history is yielded instead.
        progress: Optional progress tracker, called as
            ``progress(fraction, desc=...)``.

    Yields:
        ``(text, image_state)`` tuples as described above. If ``image_state``
        is empty or has no history, a single "upload an image" notice is
        yielded and the generator ends.
    """
    message = message_text(message)
    if not image_state or not image_state.get("history"):
        yield "Please upload an image first.", image_state
        return

    img_history = list(image_state["history"])
    current_image = normalize_image(img_history[-1])
    ed = get_editor()

    if is_edit_instruction(message):
        _report_progress(progress, 0.1, "Initializing Editor...")
        yield "Processing edit...", image_state
        try:
            _report_progress(progress, 0.3, "Perception & Localization...")
            output = ed.run(current_image, message)

            _report_progress(progress, 0.9, "Finalizing Image...")
            img_history.append(normalize_image(output.image))
            # Build a fresh state dict so the caller's object stays untouched.
            image_state = {**image_state, "history": trim_history(img_history)}
            _report_progress(progress, 1.0, "Done!")
            yield f"Edit complete! (Target concept: {output.perception.target_concept})", image_state
        except Exception as e:
            # Top-level boundary for the edit path: log the traceback and
            # surface the failure to the user instead of crashing the handler.
            LOGGER.exception("Error during editing")
            gr.Warning(f"Editing failed: {e}")
            yield f"Error during editing: {e}", image_state
    else:
        _report_progress(progress, 0.2, "Thinking...")
        yield "Thinking...", image_state
        try:
            response = ed.brain.chat(current_image, message)
            _report_progress(progress, 1.0, "Done!")
            yield response, image_state
        except Exception as e:
            # Same boundary handling for the chat path.
            LOGGER.exception("Error during chat")
            gr.Warning(f"Chat failed: {e}")
            yield f"Error during chat: {e}", image_state
|
|
|
|
# NOTE(review): the source carried no indentation, so the layout nesting below
# (in particular the button row sitting inside the chat column) is
# reconstructed from context -- confirm against the intended layout.
with gr.Blocks(title="Unified Multimodal Editor") as demo:
    gr.Markdown("# Unified Multimodal Editor (UME)")
    gr.Markdown("Upload an image, ask for a description, or give instructions to edit it (e.g., 'change the red mug to blue').")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Original Image")
            current_image_display = gr.Image(type="pil", label="Current Image", interactive=False)

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(label="Type a command (e.g., 'describe the image' or 'change the red mug to blue')")

            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                revert_btn = gr.Button("Revert Last Change")
                revert_all_btn = gr.Button("Revert All Changes")

    # State value: {"history": [image versions, oldest first]}.
    image_state = gr.State({"history": []})

    def _latest_image(state):
        """Return the newest image stored in *state*, or None if there is none."""
        if state and state.get("history"):
            return state["history"][-1]
        return None

    def handle_upload(img):
        """Start a new session from an upload: fresh state, preview, empty chat."""
        if img is not None:
            prepared = normalize_image(img)
            return {"history": [prepared]}, prepared, []
        return {"history": []}, None, []

    image_input.upload(
        handle_upload,
        inputs=[image_input],
        outputs=[image_state, current_image_display, chatbot],
    )

    def user(user_message, history):
        """Clear the textbox and append the typed message to the chat, if non-blank."""
        messages = normalize_chat_history(history)
        typed = message_text(user_message).strip()
        if typed:
            messages = messages + [{"role": "user", "content": typed}]
        return "", messages

    def bot(history, state, progress=gr.Progress()):
        """Answer the newest user message, streaming chat, state and preview updates."""
        messages = normalize_chat_history(history)
        # Nothing to answer unless the conversation ends with a user turn.
        if not messages or messages[-1].get("role") != "user":
            yield messages, state, _latest_image(state)
            return

        prompt = messages[-1]["content"]

        for reply, updated_state in chat_interface(prompt, state, progress):
            messages = list(messages)
            if messages[-1].get("role") == "user":
                # First update: open the assistant turn.
                messages.append({"role": "assistant", "content": reply})
            else:
                # Later updates overwrite the assistant turn's text.
                messages[-1]["content"] = reply
            yield messages, updated_state, _latest_image(updated_state)

    bot_inputs = [chatbot, image_state]
    bot_outputs = [chatbot, image_state, current_image_display]

    # Pressing Enter in the textbox and clicking Submit run the same chain.
    for trigger in (msg.submit, submit_btn.click):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, bot_inputs, bot_outputs
        )

    def revert_action(action, history, state):
        """Apply a revert command and record it, with its outcome, in the chat."""
        messages = normalize_chat_history(history)
        current = state or {"history": []}
        images = list(current.get("history", []))
        has_edits = len(images) > 1

        if not images:
            reply = "Please upload an image first."
        elif has_edits and action == "revert all changes":
            images = images[:1]
            reply = "Reverted all changes. Back to the original image."
        elif has_edits and action == "revert changes":
            images.pop()
            reply = "Reverted the last change."
        else:
            reply = "No changes to revert. This is the original image."

        messages = [
            *messages,
            {"role": "user", "content": action},
            {"role": "assistant", "content": reply},
        ]
        preview = images[-1] if images else None
        return messages, {"history": images}, preview

    def revert_last_change(history, state):
        """Undo the most recent edit."""
        return revert_action("revert changes", history, state)

    def revert_all_changes(history, state):
        """Go back to the originally uploaded image."""
        return revert_action("revert all changes", history, state)

    revert_btn.click(revert_last_change, bot_inputs, bot_outputs)
    revert_all_btn.click(revert_all_changes, bot_inputs, bot_outputs)
|
|
if __name__ == "__main__":
    # Queue with a default concurrency limit of 1, then start the app.
    queued_app = demo.queue(default_concurrency_limit=1)
    # `share` is enabled only when GRADIO_SHARE is exactly "1".
    share_enabled = os.getenv("GRADIO_SHARE", "0") == "1"
    queued_app.launch(share=share_enabled)
|
|