File size: 4,743 Bytes
4ba40aa
 
 
 
 
 
 
 
 
 
 
 
 
 
a5255df
 
4ba40aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637eff3
 
 
4ba40aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637eff3
4ba40aa
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import torch
import base64
import io
import requests
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
import numpy as np


class EndpointHandler:
    """Inference handler for a Depth-Anything-V2 depth-estimation model.

    Accepts an image (URL, base64, or PIL) and returns either per-point
    depth values or a base64-encoded PNG visualization of the depth map.
    """

    def __init__(self, path=""):
        # Resolve the model location: explicit path > MODEL_PATH env var > default id.
        # (Previously the `path` argument was ignored and the id was hard-coded.)
        self.model_path = path or os.environ.get("MODEL_PATH", "depthanything-v2-small")
        print(self.model_path)
        self.image_processor = AutoImageProcessor.from_pretrained(self.model_path)
        self.model = AutoModelForDepthEstimation.from_pretrained(self.model_path)

        # Move model to GPU if available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # Inference only — disable dropout/batchnorm updates.
        self.model.eval()

    def _load_image(self, data):
        """Decode the input image from 'url', 'file' (base64), or 'image' (PIL).

        Raises:
            ValueError: if none of the supported keys is present.
        """
        if "url" in data:
            # Fetch the full body; response.raw with stream=True is unreliable
            # (may be transfer-encoded or partially consumed). Timeout avoids
            # hanging the endpoint on a dead URL.
            response = requests.get(data["url"], timeout=30)
            response.raise_for_status()  # Raise an exception for HTTP errors
            return Image.open(io.BytesIO(response.content)).convert("RGB")

        if "file" in data:
            # Base64-encoded image bytes.
            image_bytes = base64.b64decode(data["file"])
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")

        if "image" in data:
            # Direct PIL image input.
            return data["image"]

        raise ValueError(
            "No valid image input found. Please provide either 'url', 'file' (base64 encoded image), or 'image' (PIL Image object).")

    def __call__(self, data):
        """
        Args:
            data: Input data in the format of a dictionary with either:
                - 'url': URL of the image
                - 'file': Base64 encoded image
                - 'image': PIL Image object
                - 'visualization': Boolean flag to return visualization-friendly format (default: False)
                - 'points': List of [x, y] points to return depth values for (default: [[0, 0]])
        Returns:
            Dictionary containing either per-point depths or a base64 PNG
            visualization with min/max depth metadata.
        """
        image = self._load_image(data)

        # Prepare image for the model.
        inputs = self.image_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Perform inference.
        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # Interpolate the (batch, H, W) prediction back to the original
        # image resolution; PIL size is (width, height), interpolate wants
        # (height, width).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        depth_map = prediction.cpu().numpy()

        # Normalize to [0, 1] for visualization; guard against a constant
        # depth map, which would otherwise divide by zero.
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        depth_range = depth_max - depth_min
        if depth_range > 0:
            normalized_depth = (depth_map - depth_min) / depth_range
        else:
            normalized_depth = np.zeros_like(depth_map)

        visualization = data.get("visualization", False)

        # Points to sample; defaults to the top-left pixel [0, 0].
        points = data.get("points", [[0, 0]])

        if visualization:
            # Render the normalized depth map as a plasma-colormapped PNG.
            plt.figure(figsize=(10, 10))
            plt.imshow(normalized_depth, cmap='plasma')
            plt.axis('off')

            # Save the figure to an in-memory buffer and close it to avoid
            # leaking matplotlib figures across requests.
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            plt.close()
            buf.seek(0)

            # Base64 for easy transmission in a JSON response.
            img_str = base64.b64encode(buf.getvalue()).decode('utf-8')

            result = {
                "visualization": img_str,
                "min_depth": float(depth_min),
                "max_depth": float(depth_max),
                "format": "base64_png"
            }
        else:
            # Points are [x, y]; numpy arrays index as [row, col] = [y, x].
            # Cast to float so the result is JSON-serializable (numpy
            # float32 scalars are not).
            depths = [float(depth_map[y][x]) for x, y in points]
            result = {
                "depths": depths
            }

        return result