# depth_anything_multi/handler.py
# Provenance: commit a5255df (verified), "Update handler.py" by rfigueroa.
import os
import torch
import base64
import io
import requests
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
import numpy as np
class EndpointHandler:
    """Inference Endpoints handler for a Depth Anything v2 depth-estimation model."""

    def __init__(self, path=""):
        """Load the depth model/processor and move the model to GPU if available.

        Args:
            path: Model directory passed in by the hosting runtime. Currently
                not honored (see NOTE below); the MODEL_PATH environment
                variable can override the pinned default without changing
                the default behavior.
        """
        # NOTE(review): the incoming `path` argument is intentionally ignored
        # in favor of a pinned model directory (the previous code hard-coded
        # it outright). MODEL_PATH provides an override hook while keeping
        # the same default. TODO: confirm whether `path` should take priority.
        self.model_path = os.environ.get("MODEL_PATH", "depthanything-v2-small")
        print(self.model_path)
        self.image_processor = AutoImageProcessor.from_pretrained(self.model_path)
        self.model = AutoModelForDepthEstimation.from_pretrained(self.model_path)
        # Run on GPU when one is present; otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        # Inference only: freeze dropout / norm statistics.
        self.model.eval()

    def __call__(self, data):
        """Run depth estimation on a single image.

        Args:
            data: Dict with exactly one image source plus optional flags:
                - 'url': URL of the image to download.
                - 'file': Base64-encoded image bytes.
                - 'image': A PIL Image object.
                - 'visualization': If truthy, return a base64 PNG rendering
                  of the depth map instead of raw values (default: False).
                - 'points': List of [x, y] pixel coordinates whose depth
                  values are returned (default: [[0, 0]]).

        Returns:
            If 'visualization' is set: dict with 'visualization' (base64 PNG),
            'min_depth', 'max_depth', and 'format'. Otherwise: {'depths': [...]}
            with one float per requested point.

        Raises:
            ValueError: If no valid image source is provided.
            requests.HTTPError: If the URL download fails.
        """
        image = self._load_image(data)

        # Preprocess and move the input tensors to the model's device.
        inputs = self.image_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # Upsample the prediction back to the original image resolution.
        # PIL's Image.size is (width, height), hence the [::-1] reversal.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],  # (height, width)
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        depth_map = prediction.cpu().numpy()
        depth_min = float(depth_map.min())
        depth_max = float(depth_map.max())
        # Guard against a constant depth map: the previous code divided by
        # (max - min) unconditionally, raising/propagating NaN when it was 0.
        depth_range = depth_max - depth_min
        if depth_range > 0:
            normalized_depth = (depth_map - depth_min) / depth_range
        else:
            normalized_depth = np.zeros_like(depth_map)

        if data.get("visualization", False):
            return self._render_visualization(normalized_depth, depth_min, depth_max)

        # If no points are provided, default to the [0, 0] pixel.
        points = data.get("points", [[0, 0]])
        # Points are [x, y]; numpy arrays index as [row][col] == [y][x].
        # Cast to float so the result is JSON-serializable (numpy scalars
        # are rejected by the endpoint's JSON encoder).
        depths = [float(depth_map[y][x]) for x, y in points]
        return {"depths": depths}

    def _load_image(self, data):
        """Resolve the input image from 'url', 'file' (base64), or 'image'."""
        if "url" in data:
            # Bound the download so a stalled remote server cannot hang the
            # endpoint indefinitely (the original call had no timeout).
            response = requests.get(data["url"], stream=True, timeout=30)
            response.raise_for_status()  # Raise an exception for HTTP errors
            return Image.open(response.raw).convert("RGB")
        if "file" in data:
            image_bytes = base64.b64decode(data["file"])
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if "image" in data:
            # Direct PIL image input; passed through unchanged.
            return data["image"]
        raise ValueError(
            "No valid image input found. Please provide either 'url', 'file'"
            " (base64 encoded image), or 'image' (PIL Image object).")

    @staticmethod
    def _render_visualization(normalized_depth, depth_min, depth_max):
        """Render the normalized depth map as a base64-encoded PNG response."""
        plt.figure(figsize=(10, 10))
        plt.imshow(normalized_depth, cmap='plasma')
        plt.axis('off')
        # Serialize the figure to an in-memory PNG buffer.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        plt.close()
        buf.seek(0)
        # Base64 for easy transmission in a JSON response.
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        return {
            "visualization": img_str,
            "min_depth": depth_min,
            "max_depth": depth_max,
            "format": "base64_png",
        }