UMEVision / app.py
akrao9's picture
Upload 19 files
495d05b verified
import logging
import os
from dataclasses import replace
try:
    import spaces
except ImportError:
    class spaces:  # noqa: N801 - mirrors the module name so call sites stay unchanged
        """Local stand-in used when the Hugging Face ``spaces`` package is missing.

        Only ``GPU`` is provided, as a no-op decorator, so ``@spaces.GPU``
        usages keep working outside a ZeroGPU Space.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Handles the bare form (``@spaces.GPU``), the direct-call form
            (``spaces.GPU(func, duration=...)``), and the configured form
            (``@spaces.GPU(duration=..., size=...)``). In every case the
            decorated function is returned unchanged.
            """
            # Bare / direct-call usage: the function itself is the first argument.
            if args and callable(args[0]):
                return args[0]

            def decorator(func):
                return func

            return decorator
import gradio as gr
from PIL import Image
from ume_pipeline.config import PipelineConfig
from ume_pipeline.pipeline import UnifiedMultimodalEditor
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
LOGGER = logging.getLogger(__name__)

# Maximum number of images kept per session, floored at 2 so that trimming can
# always keep the original upload plus the latest result.
MAX_HISTORY = max(int(os.getenv("UME_MAX_HISTORY", "4")), 2)

# ZeroGPU allocation size passed to @spaces.GPU.
ZERO_GPU_SIZE = os.getenv("UME_ZERO_GPU_SIZE", "large")

# GPU duration for @spaces.GPU: UME_ZERO_GPU_DURATION if set, otherwise
# UME_ZERO_GPU_EDIT_DURATION, otherwise 240.
_duration_fallback = os.getenv("UME_ZERO_GPU_EDIT_DURATION", "240")
ZERO_GPU_DURATION = int(os.getenv("UME_ZERO_GPU_DURATION", _duration_fallback))
# ZeroGPU supports root-level CUDA model placement through CUDA emulation.
# The models are therefore built once at import time and kept global.
# torch.compile is avoided because ZeroGPU does not support it. CPU offload is
# also switched off (presumably so weights stay resident on the GPU; confirm).
LOGGER.info("Initializing UnifiedMultimodalEditor...")
_base_config = PipelineConfig.from_env()
config = replace(
    _base_config,
    enable_cpu_offload=False,
    compile_flux_transformer=False,
)
editor = UnifiedMultimodalEditor(config)
def get_editor():
    """Return the shared ``UnifiedMultimodalEditor`` created when the module loads."""
    shared_editor = editor
    return shared_editor
def message_text(value) -> str:
    """Extract plain text from a Gradio chat value of any supported shape.

    Handled shapes:
      * ``str``: returned unchanged.
      * ``None``: ``""``.
      * ``dict``: the first present key among ``content``, ``text``, and
        ``value`` is unwrapped recursively. A dict with none of those keys
        (for example a file or image content part) carries no text and yields
        ``""``.
      * ``list`` / ``tuple``: searched from the end; the first item with
        non-empty text wins. If no item has text, the result is ``""``.
      * Objects exposing a ``content`` attribute (for example chat message
        objects): ``content`` is unwrapped recursively.
      * Anything else: ``str(value)``.

    Returns:
        The extracted text, or ``""`` when there is none.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        for key in ("content", "text", "value"):
            if key in value:
                return message_text(value.get(key))
        # No text-bearing key: do not fall back to the dict's repr, otherwise a
        # trailing file part would shadow the real text part in a list.
        return ""
    if isinstance(value, (list, tuple)):
        # Newest/last part first.
        for item in reversed(value):
            text = message_text(item)
            if text:
                return text
        return ""
    if hasattr(value, "content"):
        return message_text(getattr(value, "content"))
    return str(value)
def message_role(value) -> str | None:
if isinstance(value, dict):
return value.get("role")
if hasattr(value, "role"):
return getattr(value, "role")
return None
def normalize_chat_history(history) -> list[dict[str, str]]:
    """Flatten Gradio chat history into ``{"role", "content"}`` dicts.

    Two entry shapes are accepted:
      * message-style entries whose role is ``"user"`` or ``"assistant"``;
        their text is extracted and kept even when empty;
      * ``(user, assistant)`` pairs (any list/tuple of length >= 2), which are
        split into up to two messages, skipping sides with no text.

    Any other entry (for example a different role, or a too-short sequence) is
    dropped. ``None`` is treated as an empty history.
    """
    flattened: list[dict[str, str]] = []
    for entry in history or []:
        role = message_role(entry)
        if role in {"user", "assistant"}:
            flattened.append({"role": role, "content": message_text(entry)})
            continue
        if not isinstance(entry, (list, tuple)) or len(entry) < 2:
            continue
        for speaker, part in (("user", entry[0]), ("assistant", entry[1])):
            text = message_text(part)
            if text:
                flattened.append({"role": speaker, "content": text})
    return flattened
# First words that mark a message as an image-edit command.
_EDIT_KEYWORDS = frozenset({"change", "make", "add", "remove", "turn", "replace"})
# Punctuation tolerated around the leading word(s), e.g. "Please, add ...".
_EDGE_PUNCTUATION = ".,!?;:\"'"


def is_edit_instruction(text) -> bool:
    """Return True when *text* reads as an image-edit command.

    The first word (case-insensitive, edge punctuation ignored, an optional
    leading "please" skipped) must be one of ``_EDIT_KEYWORDS``. Whole-word
    matching avoids routing chat such as "address the lighting" or
    "additionally, describe it" to the edit pipeline, which a plain prefix
    check would do.

    Args:
        text: A string, or any chat value understood by ``message_text``.
    """
    raw = text if isinstance(text, str) else message_text(text)
    words = [w.strip(_EDGE_PUNCTUATION) for w in raw.lower().split()]
    if words and words[0] == "please":
        words = words[1:]
    return bool(words) and words[0] in _EDIT_KEYWORDS
def normalize_image(image: Image.Image) -> Image.Image:
    """Return *image* as RGB, resized with LANCZOS to the pipeline's configured size.

    The target is exactly ``(config.width, config.height)``; aspect ratio is
    not preserved.
    """
    target_size = (config.width, config.height)
    rgb_image = image.convert("RGB")
    return rgb_image.resize(target_size, Image.Resampling.LANCZOS)
def trim_history(history: list[Image.Image]) -> list[Image.Image]:
    """Cap the image history at ``MAX_HISTORY`` entries.

    The first entry (the original upload) is always kept, followed by the most
    recent ``MAX_HISTORY - 1`` images. A history already within the limit is
    returned as-is (the same list object).
    """
    if len(history) > MAX_HISTORY:
        recent_count = MAX_HISTORY - 1
        original, recent = history[0], history[-recent_count:]
        return [original, *recent]
    return history
@spaces.GPU(duration=ZERO_GPU_DURATION, size=ZERO_GPU_SIZE)
def chat_interface(message, image_state: dict, progress: gr.Progress = None):
    """Handle one chat turn against the session's current image.

    Yields ``(assistant_text, image_state)`` pairs: first a status message,
    then the final reply (or an error message).

    * Edit instructions (per ``is_edit_instruction``) run ``editor.run`` on the
      latest image. On success, a NEW state dict is yielded whose ``history``
      has the edited image appended and is trimmed via ``trim_history``.
    * Anything else is answered by ``editor.brain.chat`` and the state is
      passed through unchanged.

    The incoming ``image_state`` is never mutated. If any step of an edit
    fails, the original state is yielded back, so a failed edit never leaves
    a half-applied image in the history.

    Args:
        message: The user's message, in any shape ``message_text`` accepts.
        image_state: Session state dict holding ``"history"``, a list of images.
        progress: Optional Gradio progress tracker.
    """
    message = message_text(message)
    if not image_state or not image_state.get("history"):
        yield "Please upload an image first.", image_state
        return

    img_history = list(image_state["history"])
    current_image = normalize_image(img_history[-1])
    ed = get_editor()

    def report(fraction, desc):
        # Progress is optional (None when called without a tracker).
        if progress:
            progress(fraction, desc=desc)

    if is_edit_instruction(message):
        report(0.1, "Initializing Editor...")
        yield "Processing edit...", image_state
        try:
            report(0.3, "Perception & Localization...")
            output = ed.run(current_image, message)
            report(0.9, "Finalizing Image...")
            new_history = trim_history([*img_history, normalize_image(output.image)])
            # Read defensively so a missing field cannot fail an otherwise successful edit.
            perception = getattr(output, "perception", None)
            target_concept = getattr(perception, "target_concept", None)
            report(1.0, "Done!")
        except Exception as e:
            LOGGER.exception("Error during editing")
            gr.Warning(f"Editing failed: {e}")
            yield f"Error during editing: {e}", image_state
            return
        new_state = {**image_state, "history": new_history}
        yield f"Edit complete! (Target concept: {target_concept})", new_state
    else:
        report(0.2, "Thinking...")
        yield "Thinking...", image_state
        try:
            response = ed.brain.chat(current_image, message)
            report(1.0, "Done!")
        except Exception as e:
            LOGGER.exception("Error during chat")
            gr.Warning(f"Chat failed: {e}")
            yield f"Error during chat: {e}", image_state
            return
        yield response, image_state
with gr.Blocks(title="Unified Multimodal Editor") as demo:
    gr.Markdown("# Unified Multimodal Editor (UME)")
    gr.Markdown("Upload an image, ask for a description, or give instructions to edit it (e.g., 'change the red mug to blue').")
    # NOTE(review): component nesting reconstructed from statement order; confirm the intended layout.
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Original Image")
            current_image_display = gr.Image(type="pil", label="Current Image", interactive=False)
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(label="Type a command (e.g., 'describe the image' or 'change the red mug to blue')")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                revert_btn = gr.Button("Revert Last Change")
                revert_all_btn = gr.Button("Revert All Changes")

    # The state holds the history of images; history[0] is the normalized upload.
    image_state = gr.State({"history": []})

    def handle_upload(img):
        """Start a fresh session from *img*; ``None`` resets everything.

        Returns ``(image_state, current_image, chat_history)``.
        """
        if img is None:
            return {"history": []}, None, []
        normalized = normalize_image(img)
        return {"history": [normalized]}, normalized, []

    def user(user_message, history):
        """Append the typed message as a user turn and clear the textbox.

        Blank messages leave the chat history unchanged.
        """
        history = normalize_chat_history(history)
        user_message = message_text(user_message).strip()
        if not user_message:
            return "", history
        return "", history + [{"role": "user", "content": user_message}]

    def bot(history, state, progress=gr.Progress()):
        """Stream the assistant reply to the latest user turn.

        Yields ``(chat_history, image_state, current_image)``. Each streamed
        update builds a new message dict rather than mutating one that was
        already yielded.
        """
        history = normalize_chat_history(history)
        if not history or history[-1].get("role") != "user":
            # Nothing new to answer (e.g. a blank submit): just refresh the display.
            display_img = state["history"][-1] if state and state.get("history") else None
            yield history, state, display_img
            return
        user_message = history[-1]["content"]
        for response, new_state in chat_interface(user_message, state, progress):
            if history[-1].get("role") == "user":
                history = [*history, {"role": "assistant", "content": response}]
            else:
                history = [*history[:-1], {**history[-1], "content": response}]
            display_img = new_state["history"][-1] if new_state and new_state.get("history") else None
            yield history, new_state, display_img

    def revert_action(action, history, state):
        """Undo image edits and log the action in the chat.

        ``action`` is ``"revert changes"`` (drop the latest image) or
        ``"revert all changes"`` (keep only the original). The original image
        itself is never removed. Returns ``(chat_history, image_state,
        current_image)``.
        """
        history = normalize_chat_history(history)
        state = state or {"history": []}
        img_history = list(state.get("history", []))
        if not img_history:
            response = "Please upload an image first."
        elif action == "revert all changes" and len(img_history) > 1:
            img_history = [img_history[0]]
            response = "Reverted all changes. Back to the original image."
        elif action == "revert changes" and len(img_history) > 1:
            img_history.pop()
            response = "Reverted the last change."
        else:
            response = "No changes to revert. This is the original image."
        new_state = {"history": img_history}
        history = history + [
            {"role": "user", "content": action},
            {"role": "assistant", "content": response},
        ]
        display_img = img_history[-1] if img_history else None
        return history, new_state, display_img

    def revert_last_change(history, state):
        return revert_action("revert changes", history, state)

    def revert_all_changes(history, state):
        return revert_action("revert all changes", history, state)

    session_outputs = [image_state, current_image_display, chatbot]
    image_input.upload(handle_upload, inputs=[image_input], outputs=session_outputs)
    # Clearing the upload widget resets the session (handle_upload's None branch).
    image_input.clear(lambda: handle_upload(None), outputs=session_outputs)

    chat_outputs = [chatbot, image_state, current_image_display]
    # Enter in the textbox and the Submit button share the same two-step pipeline.
    for trigger in (msg.submit, submit_btn.click):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, image_state], chat_outputs
        )
    revert_btn.click(revert_last_change, [chatbot, image_state], chat_outputs)
    revert_all_btn.click(revert_all_changes, [chatbot, image_state], chat_outputs)
if __name__ == "__main__":
    # Sharing is opt-in via GRADIO_SHARE=1. One concurrent run per event by
    # default, so requests queue up rather than running in parallel.
    share_enabled = os.getenv("GRADIO_SHARE", "0") == "1"
    queued_demo = demo.queue(default_concurrency_limit=1)
    queued_demo.launch(share=share_enabled)