import os
import torch
import base64
import io
import requests
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
import numpy as np
class EndpointHandler:
    """Inference endpoint for monocular depth estimation.

    Wraps a Hugging Face ``AutoModelForDepthEstimation`` checkpoint and
    exposes a single ``__call__`` entry point that accepts an image (URL,
    base64 payload, or PIL object) and returns either a rendered depth-map
    visualization or the depth value at a requested pixel.
    """

    def __init__(self, path=""):
        """Load the image processor and model, and move the model to GPU if available.

        Args:
            path: Local path or hub id of the checkpoint. Falls back to the
                MODEL_PATH environment variable when empty.
        """
        self.model_path = path or os.environ.get("MODEL_PATH", "")
        self.image_processor = AutoImageProcessor.from_pretrained(self.model_path)
        self.model = AutoModelForDepthEstimation.from_pretrained(self.model_path)
        # Inference-only handler: pin the device once and switch to eval mode.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _load_image(data):
        """Resolve the input image from 'url', 'file' (base64), or 'image' (PIL).

        Raises:
            ValueError: if none of the supported keys is present.
            requests.HTTPError: if the URL download fails.
        """
        if "url" in data:
            response = requests.get(data["url"], stream=True)
            response.raise_for_status()  # surface HTTP errors early
            return Image.open(response.raw).convert("RGB")
        if "file" in data:
            image_bytes = base64.b64decode(data["file"])
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if "image" in data:
            return data["image"]
        raise ValueError(
            "No valid image input found. Please provide either 'url', 'file' "
            "(base64 encoded image), or 'image' (PIL Image object)."
        )

    def __call__(self, data):
        """Run depth estimation on one image.

        Args:
            data: Dictionary with one of:
                - 'url': URL of the image
                - 'file': Base64 encoded image
                - 'image': PIL Image object
                plus optional:
                - 'visualization': bool, return a rendered PNG instead of a
                  pixel value (default: False)
                - 'x', 'y': int pixel coordinates to sample (default: 0, 0)

        Returns:
            Dictionary containing either a base64-encoded PNG visualization
            or the depth value at pixel (y, x), plus depth-range metadata.
        """
        image = self._load_image(data)

        # Preprocess and move tensors to the model's device.
        inputs = self.image_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # Upsample the raw prediction back to the input resolution.
        # PIL's .size is (width, height); interpolate wants (height, width).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        depth_map = prediction.cpu().numpy()
        depth_min = float(depth_map.min())
        depth_max = float(depth_map.max())
        # Guard against a constant depth map: the original code divided by
        # (max - min) unconditionally, which is a ZeroDivisionError/NaN trap.
        depth_range = depth_max - depth_min
        if depth_range > 0:
            normalized_depth = (depth_map - depth_min) / depth_range
        else:
            normalized_depth = np.zeros_like(depth_map)

        if data.get("visualization", False):
            # Render the normalized map as a color-mapped PNG and return it
            # base64-encoded for easy transmission.
            plt.figure(figsize=(10, 10))
            plt.imshow(normalized_depth, cmap='plasma')
            plt.axis('off')
            buf = io.BytesIO()
            try:
                plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            finally:
                # Always release the figure, even if savefig raises,
                # so repeated requests do not leak matplotlib state.
                plt.close()
            buf.seek(0)
            img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
            return {
                "visualization": img_str,
                "min_depth": depth_min,
                "max_depth": depth_max,
                "format": "base64_png",
            }

        # Pixel query: default to (0, 0) when no coordinates are given.
        x = int(data.get('x', 0))
        y = int(data.get('y', 0))
        # float() makes the numpy scalar JSON-serializable.
        pixel_depth = float(depth_map[y][x])
        return {
            "depth": pixel_depth,
            # BUG-COMPAT: earlier versions returned the value under the
            # misspelled key 'deph'; keep it so existing clients don't break.
            "deph": pixel_depth,
            "min_depth": depth_min,
            "max_depth": depth_max,
        }