"""Inference endpoint handler for monocular depth estimation.

Wraps a Hugging Face ``AutoModelForDepthEstimation`` checkpoint behind a
callable handler that accepts an image (URL, base64 payload, or PIL image)
and returns either a single pixel's depth value or a base64-encoded PNG
visualization of the full depth map.
"""

import base64
import io
import os

import matplotlib
matplotlib.use("Agg")  # headless backend: inference servers have no display
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation


class EndpointHandler:
    """Custom inference handler for depth-estimation checkpoints."""

    def __init__(self, path=""):
        """Load the image processor and model, and move the model to GPU if available.

        Args:
            path: Local path or hub id of the checkpoint. Falls back to the
                ``MODEL_PATH`` environment variable when empty.
        """
        self.model_path = path or os.environ.get("MODEL_PATH", "")
        self.image_processor = AutoImageProcessor.from_pretrained(self.model_path)
        self.model = AutoModelForDepthEstimation.from_pretrained(self.model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        # Inference only: disable dropout / batch-norm updates.
        self.model.eval()

    def __call__(self, data):
        """Run depth estimation on a single image.

        Args:
            data: Dict with one image source key:
                - 'url': URL of the image
                - 'file': base64-encoded image bytes
                - 'image': PIL Image object
                Optional keys:
                - 'visualization': bool; when True, return a base64 PNG
                  rendering of the depth map (default False)
                - 'x', 'y': int pixel coordinates for the single-value
                  response (default (0, 0))

        Returns:
            Either a visualization payload (``visualization``, ``min_depth``,
            ``max_depth``, ``format``) or a dict with the depth at (x, y)
            under both ``depth`` and the legacy ``deph`` key.

        Raises:
            ValueError: If no image source is present or (x, y) is out of bounds.
        """
        image = self._load_image(data)

        # Preprocess and move tensors to the model's device.
        inputs = self.image_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # Upsample the prediction back to the input resolution.
        # PIL's Image.size is (width, height); interpolate wants (height, width).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        depth_map = prediction.cpu().numpy()
        depth_min = float(depth_map.min())
        depth_max = float(depth_map.max())
        depth_range = depth_max - depth_min
        # Guard against a constant depth map (would divide by zero).
        if depth_range > 0:
            normalized_depth = (depth_map - depth_min) / depth_range
        else:
            normalized_depth = np.zeros_like(depth_map)

        if data.get("visualization", False):
            return self._render_visualization(normalized_depth, depth_min, depth_max)

        # Default response: depth at the requested pixel, (0, 0) if absent.
        x = int(data.get("x", 0))
        y = int(data.get("y", 0))
        height, width = depth_map.shape
        if not (0 <= x < width and 0 <= y < height):
            raise ValueError(
                f"Pixel ({x}, {y}) is outside the image bounds {width}x{height}."
            )
        value = float(depth_map[y][x])  # cast: numpy scalars are not JSON-serializable
        return {
            "depth": value,
            # Misspelled key preserved for backward compatibility with
            # existing clients; prefer "depth".
            "deph": value,
        }

    @staticmethod
    def _load_image(data):
        """Resolve the input payload to an RGB PIL image.

        Accepts 'url', 'file' (base64), or 'image' (PIL) keys, checked in
        that order; raises ValueError when none is present.
        """
        if "url" in data:
            response = requests.get(data["url"], stream=True)
            response.raise_for_status()  # surface HTTP errors to the caller
            return Image.open(response.raw).convert("RGB")
        if "file" in data:
            image_bytes = base64.b64decode(data["file"])
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if "image" in data:
            return data["image"]
        raise ValueError(
            "No valid image input found. Please provide either 'url', "
            "'file' (base64 encoded image), or 'image' (PIL Image object)."
        )

    @staticmethod
    def _render_visualization(normalized_depth, depth_min, depth_max):
        """Render the normalized depth map as a base64-encoded PNG.

        Uses the 'plasma' colormap; the figure is closed after saving so
        repeated calls do not leak matplotlib state.
        """
        plt.figure(figsize=(10, 10))
        plt.imshow(normalized_depth, cmap='plasma')
        plt.axis('off')

        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        plt.close()
        buf.seek(0)

        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        return {
            "visualization": img_str,
            "min_depth": depth_min,
            "max_depth": depth_max,
            "format": "base64_png",
        }