it-montesanto committed on
Commit
4c92074
1 Parent(s): 05b6c47

init repo

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. handler.py +115 -0
  3. requirements.txt +8 -0
  4. test_hf.py +56 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .idea
2
+ .env
3
+ .venv
handler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import base64
4
+ import io
5
+ import requests
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
9
+ import numpy as np
10
+
11
class EndpointHandler:
    """Inference-endpoint handler for a Hugging Face depth-estimation model.

    Loads an ``AutoModelForDepthEstimation`` + ``AutoImageProcessor`` pair and
    serves single-image depth predictions, optionally rendered as a base64 PNG
    visualization via matplotlib.
    """

    def __init__(self, path=""):
        """Load the model/processor from *path* (or the MODEL_PATH env var)."""
        self.model_path = path or os.environ.get("MODEL_PATH", "")
        self.image_processor = AutoImageProcessor.from_pretrained(self.model_path)
        self.model = AutoModelForDepthEstimation.from_pretrained(self.model_path)

        # Move model to GPU if available and freeze it for inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _load_image(data):
        """Resolve the request payload into a PIL image (or image-like object).

        Accepts exactly one of:
          - 'url':   image downloaded over HTTP
          - 'file':  base64-encoded image bytes
          - 'image': a PIL Image object passed through unchanged

        Raises:
            ValueError: if none of the accepted keys is present.
        """
        if "url" in data:
            # Buffer the whole response; Image.open(response.raw) breaks on
            # chunked/compressed transfers. Timeout avoids hanging the worker.
            response = requests.get(data["url"], timeout=30)
            response.raise_for_status()  # Raise an exception for HTTP errors
            return Image.open(io.BytesIO(response.content)).convert("RGB")

        if "file" in data:
            # Decode base64 image
            image_bytes = base64.b64decode(data["file"])
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")

        if "image" in data:
            # Direct PIL image input
            return data["image"]

        raise ValueError("No valid image input found. Please provide either 'url', 'file' (base64 encoded image), or 'image' (PIL Image object).")

    @staticmethod
    def _render_visualization(normalized_depth):
        """Render the normalized depth map to a base64-encoded PNG string."""
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(normalized_depth, cmap='plasma')
            plt.axis('off')
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure; leaking one per request exhausts memory.
            plt.close(fig)
        buf.seek(0)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def __call__(self, data):
        """
        Args:
            data: Input data in the format of a dictionary with either:
                - 'url': URL of the image
                - 'file': Base64 encoded image
                - 'image': PIL Image object
                - 'visualization': Boolean flag to return visualization-friendly format (default: False)
        Returns:
            Dictionary containing the depth map and metadata.
        Raises:
            ValueError: if no image input key is present.
        """
        # NOTE: do not log `data` here — it may contain a full base64 payload.
        image = self._load_image(data)

        # Prepare image for the model
        inputs = self.image_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Perform inference
        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # Interpolate to original size; PIL's .size is (width, height).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],  # (height, width)
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        depth_map = prediction.cpu().numpy()

        # Normalize depth map to the 0-1 range for better visualization.
        # Guard against a constant depth map (max == min), which would
        # otherwise divide by zero and yield NaNs.
        depth_min = float(depth_map.min())
        depth_max = float(depth_map.max())
        depth_range = depth_max - depth_min
        if depth_range > 0:
            normalized_depth = (depth_map - depth_min) / depth_range
        else:
            normalized_depth = np.zeros_like(depth_map)

        if data.get("visualization", False):
            # Visualization-friendly response: base64 PNG of the colored map.
            return {
                "visualization": self._render_visualization(normalized_depth),
                "min_depth": depth_min,
                "max_depth": depth_max,
                "format": "base64_png"
            }

        # Raw depth-map response.
        return {
            "depth": normalized_depth.tolist(),
            "min_depth": depth_min,
            "max_depth": depth_max,
            "shape": list(normalized_depth.shape)
        }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ matplotlib
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+ requests
6
+ pillow
7
+ transformers
8
+ numpy
test_hf.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests
import base64
import io

from PIL import Image

# Endpoint URL provided by Hugging Face Inference Endpoints.
ENDPOINT_URL = "https://qh7glc3xj9iw4tk2.eu-west-1.aws.endpoints.huggingface.cloud"

# Hugging Face API token. NOTE(review): never commit a real token — read it
# from an environment variable instead.
API_TOKEN = "hf_..."

headers = {
    # "Authorization": f"Bearer {API_TOKEN}",
    "Content-Type": "application/json"
}


def _encode_image(path):
    """Return the image at *path* as a base64-encoded JPEG string."""
    # Convert first: JPEG cannot encode RGBA/P-mode images and save() would raise.
    image = Image.open(path).convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def main():
    """Send one depth-estimation request and save the returned visualization."""
    payload = {
        "inputs": {},
        "file": _encode_image("mine.jpeg"),
        "visualization": True
    }

    # Timeout keeps the script from hanging forever on a cold endpoint.
    response = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=120)

    if response.status_code == 200:
        result = response.json()
        if "visualization" in result:
            # Decode and save the visualization PNG returned by the handler.
            vis_bytes = base64.b64decode(result["visualization"])
            with open("depth_visualization.png", "wb") as f:
                f.write(vis_bytes)
            print("Visualización guardada como 'depth_visualization.png'")
        print(f"Profundidad mínima: {result.get('min_depth')}")
        print(f"Profundidad máxima: {result.get('max_depth')}")
    else:
        print(f"Error: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    main()