"""Shap-E text-to-3D preview helpers for a Gradio app.

Covers lazy model loading, post-processing of rendered views into cropped
transparent PNGs, and the event handlers that keep the asset gallery,
current view and prev/next buttons in sync.
"""

import gc
import uuid
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import gradio as gr
import torch
from diffusers import ShapEPipeline
from PIL import Image, ImageChops

MODEL_ID = "openai/shap-e"

# Lazily-created pipeline shared by all requests; see get_pipeline().
pipe = None

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used when running on CUDA.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

ROOT_DIR = Path(__file__).resolve().parent
DATA_DIR = ROOT_DIR / "data"
ASSETS_DIR = DATA_DIR / "assets"
ASSETS_DIR.mkdir(parents=True, exist_ok=True)

EXAMPLES = [
    "A cute stylized robot with a round head",
    "A fantasy treasure chest with gold trim",
    "A small dragon figurine, toy-like, colorful",
    "A low-poly medieval house",
    "A ceramic teapot shaped like an owl",
    "A cartoon submarine with tiny windows",
]


def get_pipeline() -> ShapEPipeline:
    """Return the shared Shap-E pipeline, loading it on first call.

    On CUDA the ``fp16`` weight variant is requested to match ``DTYPE``.
    The global is only assigned after the pipeline has been moved to
    ``DEVICE``, so a failed load/move is retried on the next call instead
    of leaving a half-initialised pipeline cached.
    """
    global pipe
    if pipe is None:
        kwargs: Dict[str, Any] = {"torch_dtype": DTYPE}
        if DEVICE == "cuda":
            kwargs["variant"] = "fp16"
        pipe = ShapEPipeline.from_pretrained(MODEL_ID, **kwargs).to(DEVICE)
    return pipe


def make_white_background_transparent(frame: Image.Image, threshold: int = 245) -> Image.Image:
    """Make a near-white background fully transparent.

    A pixel is treated as background when its R, G and B values are all
    ``>= threshold``; such pixels become ``(255, 255, 255, 0)``. Every other
    pixel keeps its original RGBA value.

    Args:
        frame: Source image in any mode; it is converted to RGBA first.
        threshold: Minimum value each of R, G and B must reach to count as
            background.

    Returns:
        A new RGBA image.
    """
    img = frame.convert("RGBA")
    red, green, blue, _alpha = img.split()
    # Per-band 0/255 masks via a lookup table; the per-pixel minimum
    # (ImageChops.darker) of the three masks is their logical AND.
    lut = [255 if value >= threshold else 0 for value in range(256)]
    mask = ImageChops.darker(
        ImageChops.darker(red.point(lut), green.point(lut)),
        blue.point(lut),
    )
    # The mask is strictly 0 or 255, so this either fully replaces a pixel
    # with transparent white or leaves it untouched.
    img.paste((255, 255, 255, 0), mask=mask)
    return img


def crop_to_nontransparent_content(img: Image.Image, padding: int = 8) -> Image.Image:
    """Trim fully transparent margins around the object.

    Args:
        img: Image with an alpha channel (e.g. RGBA).
        padding: Pixels of margin to keep around the non-transparent
            bounding box, clamped to the image borders.

    Returns:
        The cropped image, or ``img`` unchanged if it is fully transparent.
    """
    bbox = img.getchannel("A").getbbox()
    if bbox is None:
        return img

    left, top, right, bottom = bbox
    return img.crop(
        (
            max(0, left - padding),
            max(0, top - padding),
            min(img.width, right + padding),
            min(img.height, bottom + padding),
        )
    )


def save_frames_to_files(frames: Iterable[Image.Image], prompt: str) -> List[str]:
    """Post-process rendered views and write them to a new asset folder.

    Each frame has its near-white background made transparent, is cropped
    to its visible content, and is saved as ``view_NNN.png`` under
    ``ASSETS_DIR / "asset_<8 hex chars>"``.

    Args:
        frames: Rendered views of one object.
        prompt: Currently unused; kept for interface compatibility.

    Returns:
        The saved file paths as strings, in frame order.
    """
    asset_dir = ASSETS_DIR / f"asset_{uuid.uuid4().hex[:8]}"
    asset_dir.mkdir(parents=True, exist_ok=True)

    frame_paths: List[str] = []
    for i, frame in enumerate(frames):
        # make_white_background_transparent already converts to RGBA.
        img = make_white_background_transparent(frame, threshold=245)
        img = crop_to_nontransparent_content(img, padding=8)
        frame_path = asset_dir / f"view_{i:03d}.png"
        img.save(frame_path)
        frame_paths.append(str(frame_path))
    return frame_paths


def make_asset(prompt: str, frame_paths: List[str]) -> Dict[str, Any]:
    """Build the asset record stored in the session's asset list.

    ``selected_index`` is the view currently shown for the asset and
    starts at the first frame.
    """
    return {
        "prompt": prompt,
        "frame_paths": frame_paths,
        "selected_index": 0,
    }


def _empty_view():
    """Return (view, label, prev-button update, next-button update) for "nothing selected"."""
    return None, "No asset selected.", gr.update(interactive=False), gr.update(interactive=False)


def _visible_asset_indices(saved_assets: Optional[List[Dict[str, Any]]]) -> List[int]:
    """Return the indices of assets that have at least one frame.

    The gallery shows exactly these assets, in this order, so gallery
    position ``k`` corresponds to ``saved_assets[result[k]]``.
    """
    return [i for i, asset in enumerate(saved_assets or []) if asset.get("frame_paths")]


def _clamped_view_index(asset: Dict[str, Any]) -> int:
    """Return the asset's ``selected_index`` clamped into its frame range (0 if no frames)."""
    last = len(asset.get("frame_paths", [])) - 1
    return max(0, min(int(asset.get("selected_index", 0)), last))


def _asset_at(
    saved_assets: Optional[List[Dict[str, Any]]],
    selected_asset_index: Optional[int],
) -> Optional[Dict[str, Any]]:
    """Return the asset at ``selected_asset_index``, or None if it is unset or out of range."""
    assets = saved_assets or []
    if selected_asset_index is None or not 0 <= selected_asset_index < len(assets):
        return None
    return assets[selected_asset_index]


def gallery_items_from_assets(saved_assets: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """Build ``(image_path, caption)`` gallery entries, one per asset that has frames.

    Each entry shows the asset's currently selected view. The caption is
    numbered by the asset's 1-based position in ``saved_assets``, matching
    the "Asset N" label used by current_view_from_selected.
    """
    assets = saved_assets or []
    items: List[Tuple[str, str]] = []
    for i in _visible_asset_indices(assets):
        asset = assets[i]
        path = asset["frame_paths"][_clamped_view_index(asset)]
        items.append((path, f"{i + 1}. {asset.get('prompt', '')}"))
    return items


def current_view_from_selected(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Return the outputs that display the selected asset's current view.

    Returns:
        ``(image_path, label, prev_button_update, next_button_update)``.
        When nothing valid is selected, or the asset has no frames, the
        image is None and both buttons are disabled.
    """
    asset = _asset_at(saved_assets, selected_asset_index)
    frame_paths = asset.get("frame_paths", []) if asset is not None else []
    if not frame_paths:
        return _empty_view()

    idx = _clamped_view_index(asset)
    label = f"Asset {selected_asset_index + 1} · View {idx + 1} / {len(frame_paths)}"
    return frame_paths[idx], label, gr.update(interactive=True), gr.update(interactive=True)


def selected_view_path(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Return the file path of the selected asset's current view, or None if unavailable."""
    asset = _asset_at(saved_assets, selected_asset_index)
    if asset is None or not asset.get("frame_paths", []):
        return None
    return asset["frame_paths"][_clamped_view_index(asset)]


def _render_asset_frames(
    prompt: str,
    steps: int,
    guidance_scale: float,
    frame_size: int,
    seed: int,
) -> List[str]:
    """Run Shap-E for one prompt and save its rendered views; return the file paths.

    Kept separate so the pipeline output goes out of scope on return,
    before the caller runs garbage collection.
    """
    pipeline = get_pipeline()
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    result = pipeline(
        prompt,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        frame_size=int(frame_size),
        generator=generator,
    )
    # images[0] holds the rendered views for our single prompt.
    return save_frames_to_files(result.images[0], prompt)


def generate_and_add_asset(
    prompt: str,
    steps: int,
    guidance_scale: float,
    frame_size: int,
    seed: int,
    saved_assets: List[Dict[str, Any]],
):
    """Generate a new asset from ``prompt``, append it, and select it.

    Raises:
        gr.Error: If the prompt is empty or whitespace only.

    Returns:
        ``(saved_assets, selected_asset_index, gallery_items, current_view,
        view_text, prev_button_update, next_button_update)``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Prompt is empty.")

    saved_assets = saved_assets or []

    try:
        frame_paths = _render_asset_frames(prompt, steps, guidance_scale, frame_size, seed)
    finally:
        # Best-effort memory release, also after a failed generation (e.g. OOM).
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    # Build a new list rather than mutating the incoming state list.
    saved_assets = saved_assets + [make_asset(prompt, frame_paths)]
    selected_asset_index = len(saved_assets) - 1

    gallery_items = gallery_items_from_assets(saved_assets)
    current_view, view_text, prev_btn, next_btn = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return saved_assets, selected_asset_index, gallery_items, current_view, view_text, prev_btn, next_btn


def select_asset(
    saved_assets: List[Dict[str, Any]],
    evt: gr.SelectData,
):
    """Handle a gallery click by selecting the corresponding asset.

    The gallery omits assets without frames, so the clicked gallery
    position is mapped back to its index in ``saved_assets``.

    Returns:
        ``(selected_asset_index, gallery_items, current_view, view_text,
        prev_button_update, next_button_update)``.
    """
    if not saved_assets:
        return (None, [], *_empty_view())

    gallery_items = gallery_items_from_assets(saved_assets)

    if evt is None or evt.index is None:
        return (None, gallery_items, *_empty_view())

    gallery_index = evt.index
    # Some components report the index as a sequence; use its first element.
    if isinstance(gallery_index, (list, tuple)):
        gallery_index = gallery_index[0]
    gallery_index = int(gallery_index)

    visible = _visible_asset_indices(saved_assets)
    if not 0 <= gallery_index < len(visible):
        return (None, gallery_items, *_empty_view())

    idx = visible[gallery_index]
    current_view, view_text, prev_btn, next_btn = current_view_from_selected(saved_assets, idx)
    return idx, gallery_items, current_view, view_text, prev_btn, next_btn


def _step_view(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
    step: int,
):
    """Move the selected asset's view by ``step`` frames, wrapping around.

    The asset dict is updated in place. The list is also returned as the
    first output (presumably written back to the session state).

    Raises:
        gr.Error: If no valid asset is selected or it has no frames.
    """
    asset = _asset_at(saved_assets, selected_asset_index)
    if asset is None:
        raise gr.Error("Select an asset in the gallery first.")

    frame_paths = asset.get("frame_paths", [])
    if not frame_paths:
        raise gr.Error("Selected asset has no frames.")

    asset["selected_index"] = (int(asset.get("selected_index", 0)) + step) % len(frame_paths)

    gallery_items = gallery_items_from_assets(saved_assets)
    current_view, view_text, prev_btn, next_btn = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return saved_assets, gallery_items, current_view, view_text, prev_btn, next_btn


def prev_view(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Show the previous view of the selected asset (wraps to the last frame)."""
    return _step_view(saved_assets, selected_asset_index, -1)


def next_view(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Show the next view of the selected asset (wraps to the first frame)."""
    return _step_view(saved_assets, selected_asset_index, 1)


def clear_saved_assets():
    """Reset the assets list, selection, gallery and view outputs to empty."""
    return ([], None, [], *_empty_view())


def set_prompt(value: str):
    """Pass ``value`` through unchanged (e.g. to copy an example into the prompt box)."""
    return value