# depth-anything-3 / simple_app.py
# (HF Space file header) harshilawign — "Fix API method name: use inference() instead of infer_image()"
# commit 16d14b6
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Simplified Depth Anything 3 App.
Input: ZIP file of images
Output: ZIP file of raw depth images (.npy format)
"""
import os
import shutil
import time
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple
import gradio as gr
import numpy as np
import torch
from PIL import Image
from depth_anything_3.api import DepthAnything3
class SimpleDepthApp:
    """Simple depth prediction app - zip in, zip out.

    Workflow: the user uploads a ZIP of images, each image is run through the
    Depth Anything 3 model, and the raw depth maps are returned as .npy files
    bundled into a single ZIP, together with timing metrics.
    """

    # Input file extensions accepted as images (compared lowercase).
    _IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}

    def __init__(self, model_dir: str = None):
        """Initialize the app.

        Args:
            model_dir: Model directory or hub id. Falls back to the
                DA3_MODEL_DIR environment variable, then to the default
                nested giant/large checkpoint.
        """
        self.model_dir = model_dir or os.environ.get(
            "DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE"
        )
        self.model = None  # loaded lazily on first use (see load_model)
        self.workspace_dir = "workspace/simple_app"
        os.makedirs(self.workspace_dir, exist_ok=True)

    def load_model(self):
        """Load the depth prediction model lazily (once) and return it."""
        if self.model is None:
            print(f"Loading model from {self.model_dir}...")
            self.model = DepthAnything3.from_pretrained(self.model_dir)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = self.model.to(device)
            self.model.eval()
            print(f"Model loaded on {device}")
        return self.model

    @staticmethod
    def _find_images(input_dir: str) -> list:
        """Return sorted image file paths under input_dir, skipping macOS metadata."""
        image_files = []
        for root, _, files in os.walk(input_dir):
            # Skip __MACOSX directories added by macOS zip tools.
            if "__MACOSX" in root:
                continue
            for file in files:
                # Skip AppleDouble resource forks and Finder metadata.
                if file.startswith("._") or file.startswith(".DS_Store"):
                    continue
                if Path(file).suffix.lower() in SimpleDepthApp._IMAGE_EXTENSIONS:
                    image_files.append(os.path.join(root, file))
        return sorted(image_files)

    @staticmethod
    def _zip_directory(src_dir: str, zip_path: str) -> None:
        """Write every file under src_dir into zip_path, paths relative to src_dir."""
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _, files in os.walk(src_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, src_dir)
                    zipf.write(file_path, arcname)

    def process_zip(self, zip_file: str) -> Tuple[Optional[str], dict]:
        """
        Process a zip file of images and return a zip file of depth maps with metrics.

        Args:
            zip_file: Path to uploaded zip file, or None.

        Returns:
            Tuple of (output zip path, metrics dict). The path is None when
            the input is None, contains no images, or processing fails.
        """
        if zip_file is None:
            return None, {}
        # Initialize metrics. BUG FIX: 'images_processed' is seeded here so
        # every early-exit/error path returns a dict with a consistent schema
        # (previously it was only set after the inference loop).
        metrics = {
            'total_images': 0,
            'total_inference_time': 0.0,
            'per_image_times': [],
            'avg_time_per_image': 0.0,
            'file_handling_time': 0.0,
            'images_processed': 0
        }
        # Create a unique per-upload session directory so concurrent requests
        # cannot collide.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        session_dir = os.path.join(self.workspace_dir, f"session_{timestamp}")
        input_dir = os.path.join(session_dir, "input")
        output_dir = os.path.join(session_dir, "output")
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)
        try:
            # Extract uploaded zip (counted as file handling time).
            file_start = time.time()
            print(f"Extracting images from {zip_file}...")
            with zipfile.ZipFile(zip_file, "r") as zip_ref:
                zip_ref.extractall(input_dir)
            image_files = self._find_images(input_dir)
            if not image_files:
                print("No images found in zip file")
                return None, metrics
            metrics['total_images'] = len(image_files)
            print(f"Found {len(image_files)} images")
            # Load model (may be slow on first call). Model load time is
            # deliberately counted as file handling, not inference.
            model = self.load_model()
            file_handling_time = time.time() - file_start
            # Process each image, timing ONLY the forward pass.
            print("Processing images...")
            inference_times = []
            for i, img_path in enumerate(image_files):
                print(f"Processing {i+1}/{len(image_files)}: {Path(img_path).name}")
                try:
                    # Image loading is not counted in inference time.
                    image = Image.open(img_path).convert("RGB")
                    image_np = np.array(image)
                    # Measure ONLY inference time.
                    inference_start = time.time()
                    with torch.no_grad():
                        # API expects a list of images, returns Prediction object
                        prediction = model.inference([image_np])
                        depth = prediction.depth[0]  # first (and only) depth map
                    inference_time = time.time() - inference_start
                    inference_times.append(inference_time)
                    print(f" Inference time: {inference_time:.3f}s")
                    # Save raw depth as .npy (not counted in inference time).
                    output_name = Path(img_path).stem + "_depth.npy"
                    output_path = os.path.join(output_dir, output_name)
                    np.save(output_path, depth)
                except Exception as e:
                    # Best-effort: one bad image must not abort the batch.
                    print(f" ⚠️ Skipping {Path(img_path).name}: {str(e)}")
                    continue
            # Aggregate metrics.
            metrics['per_image_times'] = inference_times
            metrics['total_inference_time'] = sum(inference_times)
            metrics['avg_time_per_image'] = (
                metrics['total_inference_time'] / len(inference_times)
                if inference_times else 0
            )
            # Actually processed (may be less than total if some failed).
            metrics['images_processed'] = len(inference_times)
            # Create output zip (file handling time).
            zip_start = time.time()
            output_zip = os.path.join(session_dir, "depth_output.zip")
            print(f"Creating output zip: {output_zip}")
            self._zip_directory(output_dir, output_zip)
            file_handling_time += time.time() - zip_start
            metrics['file_handling_time'] = file_handling_time
            print(f"Done! Output saved to {output_zip}")
            print(f"\n📊 Metrics:")
            print(f" Total images found: {metrics['total_images']}")
            print(f" Images successfully processed: {metrics['images_processed']}")
            print(f" Total inference time: {metrics['total_inference_time']:.2f}s")
            print(f" Average per image: {metrics['avg_time_per_image']:.3f}s")
            print(f" File handling time: {metrics['file_handling_time']:.2f}s")
            return output_zip, metrics
        except Exception as e:
            # Log the failure; the caller/UI just sees (None, metrics).
            print(f"Error processing zip file: {e}")
            import traceback
            traceback.print_exc()
            return None, metrics
        finally:
            # Cleanup extracted inputs to save disk space. Outputs are kept so
            # the result zip can still be served for download.
            if os.path.exists(input_dir):
                shutil.rmtree(input_dir)

    def create_interface(self):
        """Create and return the Gradio Blocks interface."""
        with gr.Blocks(title="Depth Anything 3 - Simple") as demo:
            gr.Markdown("# Depth Anything 3 - Simplified")
            gr.Markdown(
                "Upload a ZIP file containing images. "
                "Get back a ZIP file with raw depth maps (.npy format)."
            )
            with gr.Row():
                with gr.Column():
                    input_zip = gr.File(
                        label="Upload ZIP file with images",
                        file_types=[".zip"],
                        type="filepath"
                    )
                    process_btn = gr.Button("Process Images", variant="primary")
                with gr.Column():
                    output_zip = gr.File(
                        label="Download depth maps (ZIP)",
                        type="filepath"
                    )
                    status = gr.Markdown("Ready to process images.")
            # Performance Metrics Display
            with gr.Row():
                with gr.Column():
                    metrics_display = gr.Markdown("", label="Performance Metrics")

            def process_with_status(zip_file):
                # Generator callback: Gradio streams every yielded triple
                # (file, status markdown, metrics markdown) to the UI.
                if zip_file is None:
                    # BUG FIX: this was `return None, msg, ""`. In a generator
                    # a returned value is swallowed into StopIteration, so the
                    # error message never reached the UI. Yield it instead.
                    yield None, "❌ Please upload a ZIP file first.", ""
                    return
                status_msg = "⏳ Processing images... This may take a few minutes."
                yield None, status_msg, ""
                result, metrics = self.process_zip(zip_file)
                if result is None:
                    final_status = "❌ Error processing images. Check console for details."
                    metrics_msg = ""
                else:
                    # Check if all images were processed successfully
                    success_rate = metrics['images_processed'] / metrics['total_images'] if metrics['total_images'] > 0 else 0
                    if metrics['images_processed'] < metrics['total_images']:
                        skipped = metrics['total_images'] - metrics['images_processed']
                        final_status = f"✅ Done! Processed {metrics['images_processed']} images. ({skipped} skipped due to errors)"
                    else:
                        final_status = f"✅ Done! All {metrics['images_processed']} images processed successfully."
                    # Format metrics for display
                    throughput = 1/metrics['avg_time_per_image'] if metrics['avg_time_per_image'] > 0 else 0
                    metrics_msg = f"""
## 📊 Performance Metrics
| Metric | Value |
|--------|-------|
| **Total Images Found** | {metrics['total_images']} |
| **Successfully Processed** | {metrics['images_processed']} ({success_rate*100:.0f}%) |
| **Total Inference Time** | {metrics['total_inference_time']:.2f}s |
| **Average Time per Image** | {metrics['avg_time_per_image']:.3f}s |
| **Throughput** | {throughput:.2f} images/second |
| **File Handling Time** | {metrics['file_handling_time']:.2f}s |
**Note:** Inference time excludes ZIP extraction, image loading, and file saving - shows only pure model performance.
"""
                yield result, final_status, metrics_msg

            process_btn.click(
                fn=process_with_status,
                inputs=[input_zip],
                outputs=[output_zip, status, metrics_display]
            )
            gr.Markdown("---")
            gr.Markdown(
                "**Note:** Depth maps are saved as NumPy arrays (.npy files). "
                "Load them with `numpy.load('filename.npy')` in Python."
            )
        return demo

    def launch(self, **kwargs):
        """Build the UI and launch the Gradio server (queue enabled so the
        generator callback can stream status updates)."""
        demo = self.create_interface()
        demo.queue().launch(**kwargs)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Simple Depth Anything 3 App")
    # BUG FIX: the default used to be hard-coded here, which always overrode
    # the DA3_MODEL_DIR env-var fallback inside SimpleDepthApp.__init__ and
    # made it dead code. Defaulting to None restores the fallback chain:
    # --model-dir > $DA3_MODEL_DIR > built-in default.
    parser.add_argument(
        "--model-dir",
        default=None,
        help="Path to model directory (defaults to $DA3_MODEL_DIR, then "
             "depth-anything/DA3NESTED-GIANT-LARGE)"
    )
    parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Host address"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port number"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public link"
    )
    args = parser.parse_args()
    # Resolve the same fallback chain SimpleDepthApp.__init__ applies so the
    # startup banner reports the model that will actually be loaded.
    model_dir = args.model_dir or os.environ.get(
        "DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE"
    )
    print("🚀 Starting Simple Depth Anything 3 App...")
    print(f"📦 Model: {model_dir}")
    print(f"🌐 Server: {args.host}:{args.port}")
    app = SimpleDepthApp(model_dir=model_dir)
    app.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )