Spaces:
Sleeping
Sleeping
| import gc | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import gradio as gr | |
| import torch | |
| from diffusers import ShapEPipeline | |
| from PIL import Image | |
# Hugging Face Hub repository id passed to ShapEPipeline.from_pretrained().
MODEL_ID = "openai/shap-e"

# Lazily created pipeline singleton; get_pipeline() fills it in on first use.
pipe = None

# Run on the GPU in half precision when torch reports CUDA, otherwise on the
# CPU in full precision.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Saved frames live under <directory of this file>/data/assets; the directory
# is created at import time if it does not exist yet.
ROOT_DIR = Path(__file__).resolve().parent
DATA_DIR = ROOT_DIR / "data"
ASSETS_DIR = DATA_DIR / "assets"
ASSETS_DIR.mkdir(parents=True, exist_ok=True)

# Sample text prompts (presumably offered as examples in the UI -- the code
# that consumes them is not part of this section).
EXAMPLES = [
    "A cute stylized robot with a round head",
    "A fantasy treasure chest with gold trim",
    "A small dragon figurine, toy-like, colorful",
    "A low-poly medieval house",
    "A ceramic teapot shaped like an owl",
    "A cartoon submarine with tiny windows",
]
def get_pipeline():
    """Return the shared Shap-E pipeline, loading it on the first call.

    The pipeline is cached in the module-level ``pipe`` variable, so later
    calls return the same object without loading anything again. On CUDA the
    ``fp16`` weight variant is requested in addition to the half-precision
    dtype.
    """
    global pipe
    if pipe is not None:
        return pipe

    load_options = {"torch_dtype": DTYPE}
    if DEVICE == "cuda":
        load_options["variant"] = "fp16"

    pipe = ShapEPipeline.from_pretrained(MODEL_ID, **load_options)
    pipe = pipe.to(DEVICE)
    return pipe
def make_white_background_transparent(
    frame: Image.Image, threshold: int = 245
) -> Image.Image:
    """Turn the near-white background of *frame* transparent.

    The frame is converted to RGBA first. A pixel is treated as background
    when its R, G and B values are all >= ``threshold``; such pixels are
    replaced by fully transparent white ``(255, 255, 255, 0)``. Every other
    pixel keeps its colour and alpha.

    Returns the RGBA image produced by the conversion, with the new pixel
    data written into it.
    """
    rgba = frame.convert("RGBA")
    pixels = [
        # min() >= threshold is the same as "all three channels >= threshold".
        (255, 255, 255, 0) if min(r, g, b) >= threshold else (r, g, b, a)
        for r, g, b, a in rgba.getdata()
    ]
    rgba.putdata(pixels)
    return rgba
def crop_to_nontransparent_content(img: Image.Image, padding: int = 8) -> Image.Image:
    """Crop away the transparent margins surrounding the object.

    The crop box is the bounding box of the alpha channel, grown by
    ``padding`` pixels on every side and clamped to the image bounds. When
    the alpha channel has no bounding box (``getbbox()`` returns ``None``),
    *img* is returned unchanged.
    """
    bounds = img.getchannel("A").getbbox()
    if bounds is None:
        return img

    x0, y0, x1, y1 = bounds
    crop_box = (
        max(0, x0 - padding),
        max(0, y0 - padding),
        min(img.width, x1 + padding),
        min(img.height, y1 + padding),
    )
    return img.crop(crop_box)
def save_frames_to_files(frames, prompt: str) -> List[str]:
    """Post-process rendered frames and write them out as PNG files.

    A new ``asset_<8 hex chars>`` directory is created under ``ASSETS_DIR``.
    Each frame is converted to RGBA, has its near-white background made
    transparent (threshold 245), is cropped to its content (padding 8) and is
    saved as ``view_NNN.png`` in that directory.

    Args:
        frames: Iterable of image objects exposing ``convert``.
        prompt: Accepted for interface compatibility; not used here.

    Returns:
        The saved file paths as strings, in frame order.
    """
    target_dir = ASSETS_DIR / f"asset_{uuid.uuid4().hex[:8]}"
    target_dir.mkdir(parents=True, exist_ok=True)

    saved_paths = []
    for number, frame in enumerate(frames):
        processed = crop_to_nontransparent_content(
            make_white_background_transparent(frame.convert("RGBA"), threshold=245),
            padding=8,
        )
        destination = target_dir / f"view_{number:03d}.png"
        processed.save(destination)
        saved_paths.append(str(destination))
    return saved_paths
def make_asset(prompt: str, frame_paths: List[str]) -> Dict[str, Any]:
    """Build the record describing one generated asset.

    The record holds the prompt, the list of frame paths (stored as passed,
    not copied) and ``selected_index``, which starts at 0.
    """
    asset: Dict[str, Any] = dict(prompt=prompt, frame_paths=frame_paths)
    asset["selected_index"] = 0
    return asset
def gallery_items_from_assets(saved_assets: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """Build ``(image path, caption)`` pairs for the saved assets.

    Each asset contributes the frame at its ``selected_index`` (clamped into
    the valid range) with the caption ``"<1-based asset number>. <prompt>"``.
    Assets without frame paths are skipped, but the numbering still counts
    them.
    """
    items = []
    for number, asset in enumerate(saved_assets, start=1):
        frame_paths = asset.get("frame_paths", [])
        if frame_paths:
            last = len(frame_paths) - 1
            view = min(max(int(asset.get("selected_index", 0)), 0), last)
            items.append((frame_paths[view], f"{number}. {asset.get('prompt', '')}"))
    return items
def current_view_from_selected(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Describe the view currently shown for the selected asset.

    Returns a 4-tuple ``(frame path, label, prev-button update, next-button
    update)``. When nothing valid is selected -- the index is ``None`` or out
    of range, or the asset has no frames -- the path is ``None``, the label
    is "No asset selected." and both button updates disable interaction.
    Otherwise the path is the frame at the asset's clamped
    ``selected_index`` and both button updates enable interaction.
    """

    def _nothing_selected():
        return (
            None,
            "No asset selected.",
            gr.update(interactive=False),
            gr.update(interactive=False),
        )

    if selected_asset_index is None:
        return _nothing_selected()
    if not 0 <= selected_asset_index < len(saved_assets):
        return _nothing_selected()

    asset = saved_assets[selected_asset_index]
    frame_paths = asset.get("frame_paths", [])
    if not frame_paths:
        return _nothing_selected()

    total = len(frame_paths)
    view = min(max(int(asset.get("selected_index", 0)), 0), total - 1)
    label = f"Asset {selected_asset_index + 1} · View {view + 1} / {total}"
    return frame_paths[view], label, gr.update(interactive=True), gr.update(interactive=True)
def selected_view_path(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Return the file path of the selected asset's current view.

    Returns ``None`` when the index is ``None`` or out of range, or when the
    asset has no frames. The asset's ``selected_index`` is clamped into the
    valid range before it is used.
    """
    if selected_asset_index is None:
        return None
    if not 0 <= selected_asset_index < len(saved_assets):
        return None

    asset = saved_assets[selected_asset_index]
    frame_paths = asset.get("frame_paths", [])
    if not frame_paths:
        return None

    last = len(frame_paths) - 1
    view = min(max(int(asset.get("selected_index", 0)), 0), last)
    return frame_paths[view]
def generate_and_add_asset(
    prompt: str,
    steps: int,
    guidance_scale: float,
    frame_size: int,
    seed: int,
    saved_assets: List[Dict[str, Any]],
):
    """Generate a new asset from *prompt* and append it to the saved assets.

    Args:
        prompt: Text prompt; surrounding whitespace is stripped.
        steps: Number of inference steps (coerced to ``int``).
        guidance_scale: Guidance scale (coerced to ``float``).
        frame_size: Size of the rendered frames (coerced to ``int``).
        seed: Seed for the torch generator (coerced to ``int``).
        saved_assets: Assets saved so far; a falsy value is treated as an
            empty list. The input list itself is not modified.

    Returns:
        A 7-tuple ``(saved assets, selected asset index, gallery items,
        current view path, view label, prev-button update, next-button
        update)`` in which the new asset is the selected one.

    Raises:
        gr.Error: If the prompt is empty after stripping.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Prompt is empty.")
    saved_assets = saved_assets or []

    pipeline = get_pipeline()
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    try:
        result = pipeline(
            prompt,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(steps),
            frame_size=int(frame_size),
            generator=generator,
        )
        frames = result.images[0]
        frame_paths = save_frames_to_files(frames, prompt)
    finally:
        # Release memory even when inference or saving fails (for example on
        # a CUDA out-of-memory error), so that a failed request does not
        # leave cached GPU memory behind for the next one.
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    # Build a new list instead of appending in place.
    saved_assets = saved_assets + [make_asset(prompt, frame_paths)]
    selected_asset_index = len(saved_assets) - 1
    gallery_items = gallery_items_from_assets(saved_assets)
    current_view, view_text, prev_btn, next_btn = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return (
        saved_assets,
        selected_asset_index,
        gallery_items,
        current_view,
        view_text,
        prev_btn,
        next_btn,
    )
def select_asset(
    saved_assets: List[Dict[str, Any]],
    evt: gr.SelectData,
):
    """Handle a gallery selection event.

    Args:
        saved_assets: Assets saved so far.
        evt: Selection event; ``evt.index`` is the position of the clicked
            gallery item (an int, or a list/tuple whose first element is
            used).

    Returns:
        A 6-tuple ``(selected asset index, gallery items, current view path,
        view label, prev-button update, next-button update)``. The selected
        index is ``None`` when there are no assets, when the event carries no
        index, or when the position does not match any gallery item.
    """
    if not saved_assets:
        return None, [], None, "No asset selected.", gr.update(interactive=False), gr.update(interactive=False)

    gallery_items = gallery_items_from_assets(saved_assets)
    if evt is None or evt.index is None:
        return (
            None,
            gallery_items,
            None,
            "No asset selected.",
            gr.update(interactive=False),
            gr.update(interactive=False),
        )

    position = evt.index
    if isinstance(position, (list, tuple)):
        position = position[0]
    position = int(position)

    # gallery_items_from_assets() leaves out assets that have no frames, so a
    # gallery position is not necessarily an index into saved_assets. Map the
    # position back through the list of assets that are actually displayed.
    displayed = [
        index
        for index, asset in enumerate(saved_assets)
        if asset.get("frame_paths", [])
    ]
    asset_index = displayed[position] if 0 <= position < len(displayed) else None

    current_view, view_text, prev_btn, next_btn = current_view_from_selected(
        saved_assets, asset_index
    )
    return asset_index, gallery_items, current_view, view_text, prev_btn, next_btn
def prev_view(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Step the selected asset back to its previous view, wrapping around.

    The asset's ``selected_index`` is updated in place inside
    ``saved_assets``.

    Returns:
        A 6-tuple ``(saved assets, gallery items, current view path, view
        label, prev-button update, next-button update)``.

    Raises:
        gr.Error: If no valid asset is selected or the asset has no frames.
    """
    if selected_asset_index is None or not 0 <= selected_asset_index < len(saved_assets):
        raise gr.Error("Select an asset in the gallery first.")

    asset = saved_assets[selected_asset_index]
    frame_paths = asset.get("frame_paths", [])
    if not frame_paths:
        raise gr.Error("Selected asset has no frames.")

    # Modulo makes stepping back from the first view land on the last one.
    current = int(asset.get("selected_index", 0))
    asset["selected_index"] = (current - 1) % len(frame_paths)

    view, label, prev_btn, next_btn = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return saved_assets, gallery_items_from_assets(saved_assets), view, label, prev_btn, next_btn
def next_view(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Advance the selected asset to its next view, wrapping around.

    The asset's ``selected_index`` is updated in place inside
    ``saved_assets``.

    Returns:
        A 6-tuple ``(saved assets, gallery items, current view path, view
        label, prev-button update, next-button update)``.

    Raises:
        gr.Error: If no valid asset is selected or the asset has no frames.
    """
    if selected_asset_index is None or not 0 <= selected_asset_index < len(saved_assets):
        raise gr.Error("Select an asset in the gallery first.")

    asset = saved_assets[selected_asset_index]
    frame_paths = asset.get("frame_paths", [])
    if not frame_paths:
        raise gr.Error("Selected asset has no frames.")

    # Modulo makes stepping forward from the last view land on the first one.
    current = int(asset.get("selected_index", 0))
    asset["selected_index"] = (current + 1) % len(frame_paths)

    view, label, prev_btn, next_btn = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return saved_assets, gallery_items_from_assets(saved_assets), view, label, prev_btn, next_btn
def clear_saved_assets():
    """Return the values that reset the UI to the "nothing saved" state.

    The 7-tuple follows the same order as the one returned by
    ``generate_and_add_asset``: empty asset list, no selected index, empty
    gallery, no current view, the "No asset selected." label, and two button
    updates that disable interaction.
    """
    prev_btn = gr.update(interactive=False)
    next_btn = gr.update(interactive=False)
    return [], None, [], None, "No asset selected.", prev_btn, next_btn
def set_prompt(value: str):
    """Return *value* unchanged.

    Acts as a pass-through callback (presumably used to copy an example
    prompt into the prompt field -- the wiring is not part of this section).
    """
    return value