"""Gradio front-end for the Unified Multimodal Editor (UME).

Users upload an image, then either chat about it or issue edit instructions.
Edited images are kept in a bounded per-session history so changes can be
reverted one at a time or all at once.
"""

import logging
import os
import re
from dataclasses import replace

try:
    import spaces
except ImportError:
    # Local fallback when the Hugging Face ``spaces`` package is unavailable:
    # ``spaces.GPU`` becomes a no-op decorator.
    class spaces:  # noqa: N801 - stands in for the module of the same name
        @staticmethod
        def GPU(*args, **kwargs):  # noqa: N802 - mirrors spaces.GPU
            # Support both bare ``@spaces.GPU`` and ``@spaces.GPU(...)`` usage.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def decorator(func):
                return func

            return decorator

import gradio as gr
from PIL import Image

from ume_pipeline.config import PipelineConfig
from ume_pipeline.pipeline import UnifiedMultimodalEditor

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
LOGGER = logging.getLogger(__name__)

# Maximum number of images kept per session (at least 2; the original upload is
# always kept, see trim_history).
MAX_HISTORY = max(2, int(os.getenv("UME_MAX_HISTORY", "4")))
ZERO_GPU_SIZE = os.getenv("UME_ZERO_GPU_SIZE", "large")
# UME_ZERO_GPU_EDIT_DURATION is honoured as a fallback name for the same setting.
ZERO_GPU_DURATION = int(
    os.getenv("UME_ZERO_GPU_DURATION", os.getenv("UME_ZERO_GPU_EDIT_DURATION", "240"))
)

# Messages whose first word is one of these verbs are routed to the editor;
# everything else goes to the chat model. The trailing ``\b`` stops words such
# as "address" or "makeup" from being mistaken for "add"/"make".
_EDIT_INSTRUCTION_RE = re.compile(r"(?:change|make|add|remove|turn|replace)\b")

# ZeroGPU supports root-level CUDA model placement through CUDA emulation.
# Keep the models global, but avoid torch.compile because ZeroGPU does not support it.
LOGGER.info("Initializing UnifiedMultimodalEditor...")
config = replace(
    PipelineConfig.from_env(),
    compile_flux_transformer=False,
    enable_cpu_offload=False,
)
editor = UnifiedMultimodalEditor(config)


def get_editor():
    """Return the process-wide editor instance created at import time."""
    return editor


def message_text(value) -> str:
    """Extract displayable text from a chat message in any of its shapes.

    Handles plain strings, dicts (first of the ``content``/``text``/``value``
    keys present), lists/tuples of parts (the last non-empty part wins), and
    objects exposing a ``content`` attribute. ``None`` and sequences with no
    text yield ``""``; anything else falls back to ``str(value)``.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        for key in ("content", "text", "value"):
            if key in value:
                return message_text(value.get(key))
    if isinstance(value, (list, tuple)):
        for item in reversed(value):
            text = message_text(item)
            if text:
                return text
        # No part carried text: report empty instead of the list's repr ("[]").
        return ""
    if hasattr(value, "content"):
        return message_text(value.content)
    return "" if value is None else str(value)


def message_role(value) -> str | None:
    """Return the ``role`` of a message dict or message object, else ``None``."""
    if isinstance(value, dict):
        return value.get("role")
    if hasattr(value, "role"):
        return value.role
    return None


def normalize_chat_history(history) -> list[dict[str, str]]:
    """Convert chatbot history into a list of ``{"role", "content"}`` dicts.

    Accepts messages-style entries (dicts/objects whose role is ``user`` or
    ``assistant``) and legacy ``(user, assistant)`` pairs, whose halves are
    dropped when they carry no text. Entries of any other role or shape are
    skipped.
    """
    normalized = []
    for item in history or []:
        role = message_role(item)
        if role in {"user", "assistant"}:
            normalized.append({"role": role, "content": message_text(item)})
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            user_text = message_text(item[0])
            assistant_text = message_text(item[1])
            if user_text:
                normalized.append({"role": "user", "content": user_text})
            if assistant_text:
                normalized.append({"role": "assistant", "content": assistant_text})
    return normalized


def is_edit_instruction(text) -> bool:
    """Return True when *text* starts with an edit verb (change, make, add, ...).

    Matching is case-insensitive, ignores surrounding whitespace, and requires a
    whole word: "add a hat" is an edit, "address the mug" is not.
    """
    return _EDIT_INSTRUCTION_RE.match(message_text(text).lower().strip()) is not None


def normalize_image(image: Image.Image) -> Image.Image:
    """Return *image* converted to RGB and resized to the configured width/height.

    The resize targets exactly ``(config.width, config.height)``, so the source
    aspect ratio is not preserved.
    """
    return image.convert("RGB").resize((config.width, config.height), Image.Resampling.LANCZOS)


def trim_history(history: list[Image.Image]) -> list[Image.Image]:
    """Bound *history* to ``MAX_HISTORY`` images, always keeping the original.

    Keeps the first image plus the most recent ``MAX_HISTORY - 1`` entries, so
    intermediate edits are discarded once the limit is exceeded.
    """
    if len(history) <= MAX_HISTORY:
        return history
    return [history[0], *history[-(MAX_HISTORY - 1):]]


def _current_image(state):
    """Return the newest image in a session state dict, or ``None`` if there is none."""
    return state["history"][-1] if state and state.get("history") else None


@spaces.GPU(duration=ZERO_GPU_DURATION, size=ZERO_GPU_SIZE)
def chat_interface(message, image_state: dict, progress: gr.Progress = None):
    """Answer a chat message or apply an edit to the session's current image.

    Generator yielding ``(response_text, image_state)`` pairs: an interim
    status first, then the final result. Edit instructions (see
    ``is_edit_instruction``) run ``editor.run`` and append the result to the
    image history; other messages are answered by ``editor.brain.chat`` about
    the current image. Failures are logged, surfaced via ``gr.Warning`` and
    reported as the response text, with the state left unchanged.
    """
    message = message_text(message)
    if not image_state or not image_state.get("history"):
        yield "Please upload an image first.", image_state
        return

    img_history = list(image_state["history"])
    current_image = normalize_image(img_history[-1])
    ed = get_editor()

    if is_edit_instruction(message):
        if progress:
            progress(0.1, desc="Initializing Editor...")
        yield "Processing edit...", image_state
        try:
            if progress:
                progress(0.3, desc="Perception & Localization...")
            output = ed.run(current_image, message)
            if progress:
                progress(0.9, desc="Finalizing Image...")
            new_history = trim_history([*img_history, normalize_image(output.image)])
            # Build the reply before committing the new state, so a failure here
            # cannot leave the image changed while an error is reported.
            reply = f"Edit complete! (Target concept: {output.perception.target_concept})"
        except Exception as e:
            LOGGER.exception("Error during editing")
            gr.Warning(f"Editing failed: {e}")
            yield f"Error during editing: {e}", image_state
            return
        if progress:
            progress(1.0, desc="Done!")
        # Emit a fresh state dict instead of mutating the one Gradio passed in.
        yield reply, {**image_state, "history": new_history}
        return

    if progress:
        progress(0.2, desc="Thinking...")
    yield "Thinking...", image_state
    try:
        response = ed.brain.chat(current_image, message)
    except Exception as e:
        LOGGER.exception("Error during chat")
        gr.Warning(f"Chat failed: {e}")
        yield f"Error during chat: {e}", image_state
        return
    if progress:
        progress(1.0, desc="Done!")
    yield response, image_state


with gr.Blocks(title="Unified Multimodal Editor") as demo:
    gr.Markdown("# Unified Multimodal Editor (UME)")
    gr.Markdown("Upload an image, ask for a description, or give instructions to edit it (e.g., 'change the red mug to blue').")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Original Image")
            current_image_display = gr.Image(type="pil", label="Current Image", interactive=False)
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(label="Type a command (e.g., 'describe the image' or 'change the red mug to blue')")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                revert_btn = gr.Button("Revert Last Change")
                revert_all_btn = gr.Button("Revert All Changes")

    # The state holds the history of images: index 0 is the uploaded original,
    # the last entry is the image currently shown and edited.
    image_state = gr.State({"history": []})

    def handle_upload(img):
        """Reset the image history and the conversation for a newly uploaded image."""
        if img is None:
            return {"history": []}, None, []
        normalized = normalize_image(img)
        return {"history": [normalized]}, normalized, []

    image_input.upload(
        handle_upload,
        inputs=[image_input],
        outputs=[image_state, current_image_display, chatbot],
    )

    def user(user_message, history):
        """Append the submitted text as a user turn and clear the textbox.

        Blank input leaves the conversation unchanged.
        """
        history = normalize_chat_history(history)
        user_message = message_text(user_message).strip()
        if not user_message:
            return "", history
        return "", history + [{"role": "user", "content": user_message}]

    # ``gr.Progress()`` as a default is Gradio's hook for injecting a progress tracker.
    def bot(history, state, progress=gr.Progress()):
        """Stream the assistant reply to the latest user turn.

        Yields ``(history, state, current_image)`` after each update from
        ``chat_interface``: the assistant turn is appended once and then
        replaced as interim statuses give way to the final response. When the
        last turn is not from the user, only the current view is re-emitted.
        """
        history = normalize_chat_history(history)
        if not history or history[-1].get("role") != "user":
            yield history, state, _current_image(state)
            return

        user_message = history[-1]["content"]
        for response, new_state in chat_interface(user_message, state, progress):
            if history[-1].get("role") == "user":
                history = [*history, {"role": "assistant", "content": response}]
            else:
                # Replace the assistant turn with a new dict rather than mutating
                # one that was already yielded to Gradio.
                history = [*history[:-1], {"role": "assistant", "content": response}]
            yield history, new_state, _current_image(new_state)

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, image_state], [chatbot, image_state, current_image_display]
    )
    submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, image_state], [chatbot, image_state, current_image_display]
    )

    def revert_action(action, history, state):
        """Undo edits and record the action as a user/assistant exchange.

        *action* is ``"revert changes"`` (drop the newest image) or
        ``"revert all changes"`` (return to the original upload). It is also
        shown verbatim as the user's turn. The original image is never removed.
        Returns ``(history, state, current_image)``.
        """
        history = normalize_chat_history(history)
        state = state or {"history": []}
        img_history = list(state.get("history", []))

        if not img_history:
            response = "Please upload an image first."
        elif action == "revert all changes" and len(img_history) > 1:
            img_history = [img_history[0]]
            response = "Reverted all changes. Back to the original image."
        elif action == "revert changes" and len(img_history) > 1:
            img_history.pop()
            response = "Reverted the last change."
        else:
            response = "No changes to revert. This is the original image."

        new_state = {"history": img_history}
        history = history + [
            {"role": "user", "content": action},
            {"role": "assistant", "content": response},
        ]
        return history, new_state, _current_image(new_state)

    def revert_last_change(history, state):
        """Drop the most recent edit, if any."""
        return revert_action("revert changes", history, state)

    def revert_all_changes(history, state):
        """Return to the originally uploaded image."""
        return revert_action("revert all changes", history, state)

    revert_btn.click(revert_last_change, [chatbot, image_state], [chatbot, image_state, current_image_display])
    revert_all_btn.click(revert_all_changes, [chatbot, image_state], [chatbot, image_state, current_image_display])


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1).launch(
        share=os.getenv("GRADIO_SHARE", "0") == "1",
    )