Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Utility functions for Depth Anything 3 Gradio app. | |
| This module contains helper functions for data processing, visualization, | |
| and file operations. | |
| """ | |
| import gc | |
| import json | |
| import os | |
| import shutil | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
def create_depth_visualization(depth: np.ndarray) -> Optional[np.ndarray]:
    """
    Create a colored depth visualization.

    Values are linearly rescaled between the smallest strictly positive depth
    and the maximum depth, clipped to [0, 1], and mapped through matplotlib's
    viridis colormap. Non-positive values are excluded from the lower bound
    and therefore clip to the low end of the colormap.

    Args:
        depth: Depth array (anything ``np.asarray`` accepts). For a 2-D
            (H, W) input the result has shape (H, W, 3).

    Returns:
        uint8 RGB visualization, or None if ``depth`` is None, empty, or has
        no usable range (max <= smallest positive value).
    """
    if depth is None:
        return None
    depth = np.asarray(depth)
    # An empty array has no min/max; depth.max() would raise instead of
    # signalling "nothing to show".
    if depth.size == 0:
        return None

    positive = depth > 0
    # Use the smallest *positive* depth as the lower bound so zero/invalid
    # pixels do not stretch the color range.
    depth_min = depth[positive].min() if positive.any() else 0
    depth_max = depth.max()
    if depth_max <= depth_min:
        return None

    depth_norm = np.clip((depth - depth_min) / (depth_max - depth_min), 0, 1)

    # Imported here, after the early returns, so matplotlib is only loaded
    # when a visualization is actually produced.
    import matplotlib.cm as cm

    # The colormap returns RGBA floats in [0, 1] on a new trailing axis; drop
    # alpha with ``...`` so this works for inputs of any rank, not just 2-D.
    depth_colored = cm.viridis(depth_norm)[..., :3]
    return (depth_colored * 255).astype(np.uint8)
def save_to_gallery_func(
    target_dir: str, processed_data: Dict[int, Dict[str, Any]], gallery_name: Optional[str] = None
) -> Tuple[bool, str]:
    """
    Save the current reconstruction results to the gallery directory.

    The gallery root comes from the ``DA3_GALLERY_DIR`` environment variable
    (default ``workspace/gallery``). A new sub-folder is created there. Any of
    ``scene.glb``, ``scene.jpg``, ``depth_vis/`` and ``images/`` present in
    ``target_dir`` is copied into it, and then ``metadata.json`` is written.

    Args:
        target_dir: Source directory containing reconstruction results
        processed_data: Processed data dictionary; only its length is stored
            (as ``num_images``) in the metadata
        gallery_name: Name for the gallery folder. Must be a single path
            component (no separators, not "." or ".."). If None or blank, a
            timestamped ``reconstruction_YYYYmmdd_HHMMSS`` name is used.

    Returns:
        Tuple of (success, message). On failure, a gallery folder created by
        this call is removed again so the same name can be retried.
    """
    # Set only once this call has created the folder, so cleanup below never
    # touches a pre-existing directory.
    gallery_path: Optional[str] = None
    try:
        # Get gallery directory from environment variable or use default
        gallery_dir = os.environ.get("DA3_GALLERY_DIR", "workspace/gallery")
        os.makedirs(gallery_dir, exist_ok=True)

        # Use provided name or create a unique name
        if gallery_name is None or gallery_name.strip() == "":
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            gallery_name = f"reconstruction_{timestamp}"

        # The name is user-supplied: accept only a single path component so it
        # cannot escape gallery_dir (e.g. "../x" or an absolute path).
        if gallery_name in (".", "..") or os.path.basename(gallery_name) != gallery_name:
            return False, f"Save failed: invalid folder name '{gallery_name}'"

        candidate = os.path.join(gallery_dir, gallery_name)
        try:
            # mkdir fails if the name is taken, making check-and-create atomic.
            os.mkdir(candidate)
        except FileExistsError:
            return False, f"Save failed: folder '{gallery_name}' already exists"
        gallery_path = candidate

        # Every artifact is optional; copy only what the reconstruction produced.
        for filename in ("scene.glb", "scene.jpg"):
            source = os.path.join(target_dir, filename)
            if os.path.exists(source):
                shutil.copy2(source, os.path.join(gallery_path, filename))
        for dirname in ("depth_vis", "images"):
            source = os.path.join(target_dir, dirname)
            if os.path.exists(source):
                shutil.copytree(source, os.path.join(gallery_path, dirname))

        # Save metadata
        metadata = {
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "num_images": len(processed_data) if processed_data else 0,
            "gallery_name": gallery_name,
        }
        with open(os.path.join(gallery_path, "metadata.json"), "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        print(f"Saved reconstruction to gallery: {gallery_path}")
        return True, f"Save successful: saved to {gallery_path}"
    except Exception as e:
        print(f"Error saving to gallery: {e}")
        # Don't leave a half-populated folder behind: it would make a retry
        # with the same name fail with "already exists".
        if gallery_path is not None:
            shutil.rmtree(gallery_path, ignore_errors=True)
        return False, f"Save failed: {str(e)}"
def get_scene_info(examples_dir: str) -> List[Dict[str, Any]]:
    """
    Get information about scenes in the examples directory.

    Each immediate sub-directory of ``examples_dir`` that contains at least
    one image file (.jpg, .jpeg, .png, .bmp, .tiff, .tif; extension matched
    case-insensitively) becomes one scene entry.

    Args:
        examples_dir: Path to examples directory

    Returns:
        List of scene information dictionaries, ordered by folder name, with
        keys ``name``, ``path``, ``thumbnail`` (first image by sorted path),
        ``num_images`` and ``image_files`` (sorted full paths). Returns an
        empty list if ``examples_dir`` does not exist.
    """
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}
    scenes: List[Dict[str, Any]] = []
    if not os.path.exists(examples_dir):
        return scenes

    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if not os.path.isdir(scene_path):
            continue

        # Single listing with a case-insensitive extension check. Globbing
        # "*.jpg" and "*.JPG" separately double-counts files on
        # case-insensitive filesystems and misses mixed case like ".Jpg".
        # Dot-files are skipped, matching glob's "*" behaviour.
        image_files = sorted(
            os.path.join(scene_path, name)
            for name in os.listdir(scene_path)
            if not name.startswith(".")
            and os.path.splitext(name)[1].lower() in image_extensions
            and os.path.isfile(os.path.join(scene_path, name))
        )
        if not image_files:
            continue

        scenes.append(
            {
                "name": scene_folder,
                "path": scene_path,
                "thumbnail": image_files[0],
                "num_images": len(image_files),
                "image_files": image_files,
            }
        )
    return scenes
def cleanup_memory() -> None:
    """
    Run Python's garbage collector, then, if CUDA is available, release
    PyTorch's unused cached GPU memory via ``torch.cuda.empty_cache()``.
    """
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if not cuda_present:
        return
    torch.cuda.empty_cache()
def get_logo_base64(logo_path: str = "examples/WAI-Logo/wai_logo.png") -> Optional[str]:
    """
    Convert WAI logo to base64 for embedding in HTML.

    Args:
        logo_path: Path to the logo image. Relative paths resolve against the
            current working directory. Defaults to the original hard-coded
            WAI logo path.

    Returns:
        A ``data:`` URI containing the base64-encoded image, or None if the
        file cannot be read (missing, a directory, permission denied, ...).
    """
    import base64
    import mimetypes

    try:
        with open(logo_path, "rb") as img_file:
            img_data = img_file.read()
    except OSError:
        # The logo is optional decoration: any read failure degrades to
        # "no logo" (None) instead of raising.
        return None

    # Derive the MIME type from the extension; fall back to PNG, which is the
    # format of the default logo path.
    mime_type = mimetypes.guess_type(logo_path)[0] or "image/png"
    base64_str = base64.b64encode(img_data).decode()
    return f"data:{mime_type};base64,{base64_str}"