|
|
from __future__ import annotations |
|
|
import os, io, tempfile, mimetypes |
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal |
|
|
|
|
|
import gradio as gr |
|
|
import PIL |
|
|
import time |
|
|
from PIL import Image as PILImage |
|
|
|
|
|
FilePath = str |
|
|
ImageLike = Union["PIL.Image.Image", Any] |
|
|
|
|
|
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"} |
|
|
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"} |
|
|
|
|
|
def get_state(state): |
|
|
return state if isinstance(state, dict) else state.value |
|
|
|
|
|
def get_list( objs): |
|
|
if objs is None: |
|
|
return [] |
|
|
return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs] |
|
|
|
|
|
def record_last_action(st, last_action): |
|
|
st["last_action"] = last_action |
|
|
st["last_time"] = time.time() |
|
|
class AdvancedMediaGallery: |
|
|
def __init__( |
|
|
self, |
|
|
label: str = "Media", |
|
|
*, |
|
|
media_mode: Literal["image", "video"] = "image", |
|
|
height = None, |
|
|
columns: Union[int, Tuple[int, ...]] = 6, |
|
|
show_label: bool = True, |
|
|
initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None, |
|
|
elem_id: Optional[str] = None, |
|
|
elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",), |
|
|
accept_filter: bool = True, |
|
|
single_image_mode: bool = False, |
|
|
): |
|
|
assert media_mode in ("image", "video") |
|
|
self.label = label |
|
|
self.media_mode = media_mode |
|
|
self.height = height |
|
|
self.columns = columns |
|
|
self.show_label = show_label |
|
|
self.elem_id = elem_id |
|
|
self.elem_classes = list(elem_classes) if elem_classes else None |
|
|
self.accept_filter = accept_filter |
|
|
|
|
|
items = self._normalize_initial(initial or [], media_mode) |
|
|
|
|
|
|
|
|
self.container: Optional[gr.Column] = None |
|
|
self.gallery: Optional[gr.Gallery] = None |
|
|
self.upload_btn: Optional[gr.UploadButton] = None |
|
|
self.btn_remove: Optional[gr.Button] = None |
|
|
self.btn_left: Optional[gr.Button] = None |
|
|
self.btn_right: Optional[gr.Button] = None |
|
|
self.btn_clear: Optional[gr.Button] = None |
|
|
|
|
|
|
|
|
self.state: Optional[gr.State] = None |
|
|
self._initial_state: Dict[str, Any] = { |
|
|
"items": items, |
|
|
"selected": (len(items) - 1) if items else 0, |
|
|
"single": bool(single_image_mode), |
|
|
"mode": self.media_mode, |
|
|
"last_action": "", |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]: |
|
|
out: List[Any] = [] |
|
|
if mode == "image": |
|
|
for it in items: |
|
|
p = self._ensure_image_item(it) |
|
|
if p is not None: |
|
|
out.append(p) |
|
|
else: |
|
|
for it in items: |
|
|
if isinstance(item, tuple): item = item[0] |
|
|
if isinstance(it, str) and self._is_video_path(it): |
|
|
out.append(os.path.abspath(it)) |
|
|
return out |
|
|
|
|
|
def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]: |
|
|
|
|
|
if isinstance(item, tuple): item = item[0] |
|
|
if isinstance(item, str): |
|
|
return os.path.abspath(item) if self._is_image_path(item) else None |
|
|
if PILImage is None: |
|
|
return None |
|
|
try: |
|
|
if isinstance(item, PILImage.Image): |
|
|
img = item |
|
|
else: |
|
|
import numpy as np |
|
|
if isinstance(item, np.ndarray): |
|
|
img = PILImage.fromarray(item) |
|
|
elif hasattr(item, "read"): |
|
|
data = item.read() |
|
|
img = PILImage.open(io.BytesIO(data)).convert("RGBA") |
|
|
else: |
|
|
return None |
|
|
tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) |
|
|
img.save(tmp.name) |
|
|
return tmp.name |
|
|
except Exception: |
|
|
return None |
|
|
|
|
|
@staticmethod |
|
|
def _extract_path(obj: Any) -> Optional[str]: |
|
|
|
|
|
if isinstance(obj, str): |
|
|
return obj |
|
|
try: |
|
|
import pathlib |
|
|
if isinstance(obj, pathlib.Path): |
|
|
return str(obj) |
|
|
except Exception: |
|
|
pass |
|
|
if isinstance(obj, dict): |
|
|
return obj.get("path") or obj.get("name") |
|
|
for attr in ("path", "name"): |
|
|
if hasattr(obj, attr): |
|
|
try: |
|
|
val = getattr(obj, attr) |
|
|
if isinstance(val, str): |
|
|
return val |
|
|
except Exception: |
|
|
pass |
|
|
return None |
|
|
|
|
|
@staticmethod |
|
|
def _is_image_path(p: str) -> bool: |
|
|
ext = os.path.splitext(p)[1].lower() |
|
|
if ext in IMAGE_EXTS: |
|
|
return True |
|
|
mt, _ = mimetypes.guess_type(p) |
|
|
return bool(mt and mt.startswith("image/")) |
|
|
|
|
|
@staticmethod |
|
|
def _is_video_path(p: str) -> bool: |
|
|
ext = os.path.splitext(p)[1].lower() |
|
|
if ext in VIDEO_EXTS: |
|
|
return True |
|
|
mt, _ = mimetypes.guess_type(p) |
|
|
return bool(mt and mt.startswith("video/")) |
|
|
|
|
|
def _filter_items_by_mode(self, items: List[Any]) -> List[Any]: |
|
|
|
|
|
out: List[Any] = [] |
|
|
if self.media_mode == "image": |
|
|
for it in items: |
|
|
p = self._extract_path(it) |
|
|
if p is None: |
|
|
|
|
|
out.append(it) |
|
|
elif self._is_image_path(p): |
|
|
out.append(os.path.abspath(p)) |
|
|
else: |
|
|
for it in items: |
|
|
p = self._extract_path(it) |
|
|
if p is not None and self._is_video_path(p): |
|
|
out.append(os.path.abspath(p)) |
|
|
return out |
|
|
|
|
|
@staticmethod |
|
|
def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]: |
|
|
|
|
|
seen_paths = set() |
|
|
def key(x: Any) -> Optional[str]: |
|
|
if isinstance(x, str): return os.path.abspath(x) |
|
|
try: |
|
|
import pathlib |
|
|
if isinstance(x, pathlib.Path): |
|
|
return os.path.abspath(str(x)) |
|
|
except Exception: |
|
|
pass |
|
|
if isinstance(x, dict): |
|
|
p = x.get("path") or x.get("name") |
|
|
return os.path.abspath(p) if isinstance(p, str) else None |
|
|
for attr in ("path", "name"): |
|
|
if hasattr(x, attr): |
|
|
try: |
|
|
v = getattr(x, attr) |
|
|
return os.path.abspath(v) if isinstance(v, str) else None |
|
|
except Exception: |
|
|
pass |
|
|
return None |
|
|
|
|
|
out: List[Any] = [] |
|
|
for lst in (cur, add): |
|
|
for it in lst: |
|
|
k = key(it) |
|
|
if k is None or k not in seen_paths: |
|
|
out.append(it) |
|
|
if k is not None: |
|
|
seen_paths.add(k) |
|
|
return out |
|
|
|
|
|
@staticmethod |
|
|
def _paths_from_payload(payload: Any) -> List[Any]: |
|
|
|
|
|
if payload is None: |
|
|
return [] |
|
|
if isinstance(payload, (list, tuple, set)): |
|
|
return list(payload) |
|
|
return [payload] |
|
|
|
|
|
|
|
|
|
|
|
def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) : |
|
|
|
|
|
|
|
|
st = get_state(state) |
|
|
last_time = st.get("last_time", None) |
|
|
if last_time is not None and abs(time.time()- last_time)< 0.5: |
|
|
|
|
|
return gr.update(selected_index=st["selected"]), st |
|
|
|
|
|
idx = None |
|
|
if evt is not None and hasattr(evt, "index"): |
|
|
ix = evt.index |
|
|
if isinstance(ix, int): |
|
|
idx = ix |
|
|
elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int): |
|
|
if isinstance(self.columns, int) and len(ix) >= 2: |
|
|
idx = ix[0] * max(1, int(self.columns)) + ix[1] |
|
|
else: |
|
|
idx = ix[0] |
|
|
n = len(get_list(gallery)) |
|
|
sel = idx if (idx is not None and 0 <= idx < n) else None |
|
|
|
|
|
st["selected"] = sel |
|
|
return gr.update(), st |
|
|
|
|
|
def _on_upload(self, value: List[Any], state: Dict[str, Any]) : |
|
|
|
|
|
|
|
|
items_filtered = list(value or []) |
|
|
st = get_state(state) |
|
|
new_items = self._paths_from_payload(items_filtered) |
|
|
st["items"] = new_items |
|
|
new_sel = len(new_items) - 1 |
|
|
st["selected"] = new_sel |
|
|
record_last_action(st,"add") |
|
|
return gr.update(selected_index=new_sel), st |
|
|
|
|
|
def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : |
|
|
|
|
|
|
|
|
items_filtered = list(value or []) |
|
|
st = get_state(state) |
|
|
st["items"] = items_filtered |
|
|
|
|
|
old_sel = st.get("selected", None) |
|
|
if old_sel is None or not (0 <= old_sel < len(items_filtered)): |
|
|
new_sel = (len(items_filtered) - 1) if items_filtered else None |
|
|
else: |
|
|
new_sel = old_sel |
|
|
st["selected"] = new_sel |
|
|
st["last_action"] ="gallery_change" |
|
|
|
|
|
return gr.update(selected_index=new_sel), st |
|
|
|
|
|
def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery): |
|
|
""" |
|
|
Insert added items right AFTER the currently selected index. |
|
|
Keeps the same ordering as chosen in the file picker, dedupes by path, |
|
|
and re-selects the last inserted item. |
|
|
""" |
|
|
|
|
|
|
|
|
new_items = self._paths_from_payload(files_payload) |
|
|
|
|
|
st = get_state(state) |
|
|
cur: List[Any] = get_list(gallery) |
|
|
sel = st.get("selected", None) |
|
|
if sel is None: |
|
|
sel = (len(cur) -1) if len(cur)>0 else 0 |
|
|
single = bool(st.get("single", False)) |
|
|
|
|
|
|
|
|
if not new_items: |
|
|
return gr.update(value=cur, selected_index=st.get("selected")), st |
|
|
|
|
|
|
|
|
if single: |
|
|
st["items"] = [new_items[-1]] |
|
|
st["selected"] = 0 |
|
|
return gr.update(value=st["items"], selected_index=0), st |
|
|
|
|
|
|
|
|
def key_of(it: Any) -> Optional[str]: |
|
|
|
|
|
if hasattr(self, "_extract_path"): |
|
|
p = self._extract_path(it) |
|
|
else: |
|
|
p = it if isinstance(it, str) else None |
|
|
if p is None and isinstance(it, dict): |
|
|
p = it.get("path") or it.get("name") |
|
|
if p is None and hasattr(it, "path"): |
|
|
try: p = getattr(it, "path") |
|
|
except Exception: p = None |
|
|
if p is None and hasattr(it, "name"): |
|
|
try: p = getattr(it, "name") |
|
|
except Exception: p = None |
|
|
return os.path.abspath(p) if isinstance(p, str) else None |
|
|
|
|
|
|
|
|
seen_new = set() |
|
|
incoming: List[Any] = [] |
|
|
for it in new_items: |
|
|
k = key_of(it) |
|
|
if k is None or k not in seen_new: |
|
|
incoming.append(it) |
|
|
if k is not None: |
|
|
seen_new.add(k) |
|
|
|
|
|
insert_pos = min(sel, len(cur) -1) |
|
|
cur_clean = cur |
|
|
|
|
|
merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:] |
|
|
new_sel = insert_pos + len(incoming) |
|
|
|
|
|
st["items"] = merged |
|
|
st["selected"] = new_sel |
|
|
record_last_action(st,"add") |
|
|
|
|
|
return gr.update(value=merged, selected_index=new_sel), st |
|
|
|
|
|
def _on_remove(self, state: Dict[str, Any], gallery) : |
|
|
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) |
|
|
if sel is None or not (0 <= sel < len(items)): |
|
|
return gr.update(value=items, selected_index=st.get("selected")), st |
|
|
items.pop(sel) |
|
|
if not items: |
|
|
st["items"] = []; st["selected"] = None |
|
|
return gr.update(value=[], selected_index=None), st |
|
|
new_sel = min(sel, len(items) - 1) |
|
|
st["items"] = items; st["selected"] = new_sel |
|
|
record_last_action(st,"remove") |
|
|
|
|
|
return gr.update(value=items, selected_index=new_sel), st |
|
|
|
|
|
def _on_move(self, delta: int, state: Dict[str, Any], gallery) : |
|
|
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) |
|
|
if sel is None or not (0 <= sel < len(items)): |
|
|
return gr.update(value=items, selected_index=sel), st |
|
|
j = sel + delta |
|
|
if j < 0 or j >= len(items): |
|
|
return gr.update(value=items, selected_index=sel), st |
|
|
items[sel], items[j] = items[j], items[sel] |
|
|
st["items"] = items; st["selected"] = j |
|
|
record_last_action(st,"move") |
|
|
|
|
|
return gr.update(value=items, selected_index=j), st |
|
|
|
|
|
def _on_clear(self, state: Dict[str, Any]) : |
|
|
st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode} |
|
|
record_last_action(st,"clear") |
|
|
|
|
|
return gr.update(value=[], selected_index=None), st |
|
|
|
|
|
def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) : |
|
|
st = get_state(state); st["single"] = bool(to_single) |
|
|
items: List[Any] = list(st["items"]); sel = st.get("selected", None) |
|
|
if st["single"]: |
|
|
keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None) |
|
|
items = [keep] if keep is not None else [] |
|
|
sel = 0 if items else None |
|
|
st["items"] = items; st["selected"] = sel |
|
|
|
|
|
upload_update = gr.update(file_count=("single" if st["single"] else "multiple")) |
|
|
left_update = gr.update(visible=not st["single"]) |
|
|
right_update = gr.update(visible=not st["single"]) |
|
|
clear_update = gr.update(visible=not st["single"]) |
|
|
gallery_update= gr.update(value=items, selected_index=sel) |
|
|
|
|
|
return upload_update, left_update, right_update, clear_update, gallery_update, st |
|
|
|
|
|
|
|
|
|
|
|
def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False): |
|
|
if parent is not None: |
|
|
with parent: |
|
|
col = self._build_ui(update_form) |
|
|
else: |
|
|
col = self._build_ui(update_form) |
|
|
if not update_form: |
|
|
self._wire_events() |
|
|
return col |
|
|
|
|
|
def _build_ui(self, update = False) -> gr.Column: |
|
|
with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col: |
|
|
self.container = col |
|
|
|
|
|
self.state = gr.State(dict(self._initial_state)) |
|
|
|
|
|
if update: |
|
|
self.gallery = gr.update( |
|
|
value=self._initial_state["items"], |
|
|
selected_index=self._initial_state["selected"], |
|
|
label=self.label, |
|
|
show_label=self.show_label, |
|
|
) |
|
|
else: |
|
|
self.gallery = gr.Gallery( |
|
|
value=self._initial_state["items"], |
|
|
label=self.label, |
|
|
height=self.height, |
|
|
columns=self.columns, |
|
|
show_label=self.show_label, |
|
|
preview= True, |
|
|
|
|
|
file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), |
|
|
selected_index=self._initial_state["selected"], |
|
|
) |
|
|
|
|
|
|
|
|
exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None |
|
|
with gr.Row(equal_height=True, elem_classes=["amg-controls"]): |
|
|
self.upload_btn = gr.UploadButton( |
|
|
"Set" if self._initial_state["single"] else "Add", |
|
|
file_types=exts, |
|
|
file_count=("single" if self._initial_state["single"] else "multiple"), |
|
|
variant="primary", |
|
|
size="sm", |
|
|
min_width=1, |
|
|
) |
|
|
self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1) |
|
|
self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1) |
|
|
self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1) |
|
|
self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1) |
|
|
|
|
|
return col |
|
|
|
|
|
def _wire_events(self): |
|
|
|
|
|
self.gallery.select( |
|
|
self._on_select, |
|
|
inputs=[self.state, self.gallery], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
|
|
|
|
|
|
self.gallery.upload( |
|
|
self._on_upload, |
|
|
inputs=[self.gallery, self.state], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
|
|
|
|
|
|
self.gallery.upload( |
|
|
self._on_gallery_change, |
|
|
inputs=[self.gallery, self.state], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
|
|
|
|
|
|
self.upload_btn.upload( |
|
|
self._on_add, |
|
|
inputs=[self.upload_btn, self.state, self.gallery], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
|
|
|
|
|
|
self.btn_remove.click( |
|
|
self._on_remove, |
|
|
inputs=[self.state, self.gallery], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
|
|
|
|
|
|
self.btn_left.click( |
|
|
lambda st, gallery: self._on_move(-1, st, gallery), |
|
|
inputs=[self.state, self.gallery], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
self.btn_right.click( |
|
|
lambda st, gallery: self._on_move(+1, st, gallery), |
|
|
inputs=[self.state, self.gallery], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
|
|
|
|
|
|
self.btn_clear.click( |
|
|
self._on_clear, |
|
|
inputs=[self.state], |
|
|
outputs=[self.gallery, self.state], |
|
|
trigger_mode="always_last", |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def set_one_image_mode(self, enabled: bool = True): |
|
|
"""Toggle single-image mode at runtime.""" |
|
|
return ( |
|
|
self._on_toggle_single, |
|
|
[gr.State(enabled), self.state], |
|
|
[self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state], |
|
|
) |
|
|
|
|
|
def get_toggable_elements(self): |
|
|
return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|