rezzzq committed on
Commit
cb986c8
·
1 Parent(s): 3eaa887

Add SurfaceAI inference endpoint handler

Browse files
Files changed (2) hide show
  1. handler.py +324 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Inference Endpoint handler for SurfaceAI models.
3
+
4
+ This handler loads all 7 SurfaceAI models and performs hierarchical classification:
5
+ 1. Road type classification
6
+ 2. Surface type classification
7
+ 3. Surface quality regression (model selected based on surface type)
8
+
9
+ Deploy by creating an Inference Endpoint pointing to this repo.
10
+ """
11
+
12
+ import base64
13
+ import io
14
+ import logging
15
+ from pathlib import Path
16
+ from typing import Any, Dict, List
17
+
18
+ import torch
19
+ from huggingface_hub import hf_hub_download
20
+ from PIL import Image
21
+ from torchvision import models, transforms
22
+ from torch import nn, Tensor
23
+
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Constants from original SurfaceAI
28
+ NORM_MEAN = [0.42834484577178955, 0.4461250305175781, 0.4350937306880951]
29
+ NORM_SD = [0.22991590201854706, 0.23555299639701843, 0.26348039507865906]
30
+ CROP_LOWER_MIDDLE_HALF = "lower_middle_half"
31
+ CROP_LOWER_HALF = "lower_half"
32
+
33
+ # Model configuration
34
+ MODEL_CONFIG = {
35
+ "hf_repo": "SurfaceAI/models-moved",
36
+ "models": {
37
+ "road_type": "v1/road_type_v1.pt",
38
+ "surface_type": "v1/surface_type_v1.pt",
39
+ "surface_quality": {
40
+ "asphalt": "v1/surface_quality_asphalt_v1.pt",
41
+ "concrete": "v1/surface_quality_concrete_v1.pt",
42
+ "paving_stones": "v1/surface_quality_paving_stones_v1.pt",
43
+ "sett": "v1/surface_quality_sett_v1.pt",
44
+ "unpaved": "v1/surface_quality_unpaved_v1.pt",
45
+ }
46
+ },
47
+ "transform_surface": {
48
+ "resize": 256,
49
+ "crop": CROP_LOWER_MIDDLE_HALF,
50
+ "normalize": (NORM_MEAN, NORM_SD),
51
+ },
52
+ "transform_road_type": {
53
+ "resize": 256,
54
+ "crop": CROP_LOWER_HALF,
55
+ "normalize": (NORM_MEAN, NORM_SD),
56
+ },
57
+ }
58
+
59
+ # Quality class mapping
60
+ QUALITY_CLASSES = {
61
+ 1: "excellent",
62
+ 2: "good",
63
+ 3: "intermediate",
64
+ 4: "bad",
65
+ 5: "very_bad",
66
+ }
67
+
68
+
69
class CustomEfficientNetV2SLinear(nn.Module):
    """EfficientNetV2-S backbone with a linear head for classification or regression.

    When ``num_classes == 1`` the module acts as a regression model and
    ``get_class_probabilities`` returns the raw flattened outputs; otherwise it
    returns softmax class probabilities.
    """

    def __init__(self, num_classes, avg_pool=1):
        super().__init__()
        backbone = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
        # Pooling to an (avg_pool x avg_pool) grid multiplies the flattened
        # feature count fed into the replacement linear head.
        head_in = backbone.classifier[-1].in_features * (avg_pool * avg_pool)
        backbone.classifier[-1] = nn.Linear(head_in, num_classes, bias=True)

        self.features = backbone.features
        self.avgpool = nn.AdaptiveAvgPool2d(avg_pool)
        self.classifier = backbone.classifier
        self.is_regression = num_classes == 1

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass: backbone features -> adaptive pool -> flatten -> head."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def get_class_probabilities(self, x):
        """Map raw head outputs to probabilities (or raw values for regression)."""
        if self.is_regression:
            return x.flatten()
        return nn.functional.softmax(x, dim=1)
95
+
96
+
97
class EndpointHandler:
    """HuggingFace Inference Endpoint handler for SurfaceAI.

    Loads all 7 SurfaceAI models at construction time and runs a hierarchical
    pipeline per request: road type -> surface type -> surface quality (the
    quality model is selected by the predicted surface type).
    """

    def __init__(self, path: str = ""):
        """
        Initialize handler and load all models.

        Args:
            path: Path to model directory (provided by HF Inference Endpoints;
                currently unused — checkpoints are fetched from the Hub instead)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        self.models = {}
        self.class_mappings = {}
        self._load_all_models()

        # Pre-build transforms so per-request work is inference only
        self.transform_surface = self._build_transform(MODEL_CONFIG["transform_surface"])
        self.transform_road_type = self._build_transform(MODEL_CONFIG["transform_road_type"])

    def _download_model(self, filename: str) -> str:
        """Download a checkpoint from the HuggingFace Hub; return its local path."""
        return hf_hub_download(
            repo_id=MODEL_CONFIG["hf_repo"],
            filename=filename,
        )

    def _load_model(self, model_path: str) -> tuple:
        """Load a single model and return (model, class_to_idx, is_regression)."""
        # NOTE(review): weights_only=False unpickles arbitrary objects. Acceptable
        # here because checkpoints come from the configured SurfaceAI repo, but
        # never point this at untrusted files.
        state = torch.load(model_path, map_location=self.device, weights_only=False)

        is_regression = state["is_regression"]
        class_to_idx = state["class_to_idx"]
        # Regression heads emit a single scalar; classifiers emit one logit per class.
        num_classes = 1 if is_regression else len(class_to_idx)

        model = CustomEfficientNetV2SLinear(num_classes=num_classes)
        model.load_state_dict(state["model_state_dict"])
        model.to(self.device)
        model.eval()

        return model, class_to_idx, is_regression

    def _load_all_models(self):
        """Load all 7 SurfaceAI models (road type, surface type, 5 quality models)."""
        logger.info("Loading SurfaceAI models...")

        # Load road type model
        path = self._download_model(MODEL_CONFIG["models"]["road_type"])
        self.models["road_type"], self.class_mappings["road_type"], _ = self._load_model(path)
        logger.info("Loaded road_type model")

        # Load surface type model
        path = self._download_model(MODEL_CONFIG["models"]["surface_type"])
        self.models["surface_type"], self.class_mappings["surface_type"], _ = self._load_model(path)
        logger.info("Loaded surface_type model")

        # Load quality models for each surface type
        self.models["quality"] = {}
        self.class_mappings["quality"] = {}
        for surface_type, model_file in MODEL_CONFIG["models"]["surface_quality"].items():
            path = self._download_model(model_file)
            model, class_to_idx, _ = self._load_model(path)
            self.models["quality"][surface_type] = model
            self.class_mappings["quality"][surface_type] = class_to_idx
            logger.info(f"Loaded quality model for {surface_type}")

        logger.info("All models loaded successfully")

    @staticmethod
    def _custom_crop(img: Image.Image, crop_style: str) -> Image.Image:
        """Crop image according to style; unknown styles return the image unchanged."""
        im_width, im_height = img.size

        if crop_style == CROP_LOWER_MIDDLE_HALF:
            # Lower half vertically, middle half horizontally.
            top = im_height // 2
            left = im_width // 4
            height = im_height // 2
            width = im_width // 2
        elif crop_style == CROP_LOWER_HALF:
            # Lower half vertically, full width.
            top = im_height // 2
            left = 0
            height = im_height // 2
            width = im_width
        else:
            return img

        return img.crop((left, top, left + width, top + height))

    def _build_transform(self, config: dict) -> transforms.Compose:
        """Build torchvision transform from config: crop -> resize -> tensor -> normalize."""
        transform_list = []

        if config.get("crop"):
            # Bind the crop style eagerly (default-arg binding) so the lambda does
            # not depend on the config dict's later state.
            transform_list.append(
                transforms.Lambda(
                    lambda img, style=config["crop"]: self._custom_crop(img, style)
                )
            )

        if config.get("resize"):
            size = config["resize"]
            if isinstance(size, int):
                size = (size, size)
            transform_list.append(transforms.Resize(size))

        transform_list.append(transforms.ToTensor())

        if config.get("normalize"):
            transform_list.append(transforms.Normalize(*config["normalize"]))

        return transforms.Compose(transform_list)

    def _predict(self, model, data: torch.Tensor, class_to_idx: dict) -> tuple:
        """Run prediction and convert raw outputs to (class names, values).

        Returns:
            tuple: (list of class-name strings, list of values — per-class
            probability lists for classifiers, scalar floats for regressors)
        """
        with torch.no_grad():
            outputs = model(data)
            values = model.get_class_probabilities(outputs)

        idx_to_class = {i: cls for cls, i in class_to_idx.items()}

        if len(values.shape) < 2:
            # Regression output: round to the nearest class index, clamped to the
            # valid index range (bounds hoisted out of the loop).
            lo = min(class_to_idx.values())
            hi = max(class_to_idx.values())
            classes = [
                idx_to_class[min(max(int(v.round().item()), lo), hi)]
                for v in values
            ]
        else:
            # Classification output: argmax over softmax probabilities.
            classes = [idx_to_class[idx.item()] for idx in torch.argmax(values, dim=1)]

        return classes, values.tolist()

    def _process_image(self, image: Image.Image) -> dict:
        """Run the full hierarchical pipeline on one image; return a result dict."""
        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Road type prediction
        road_data = self.transform_road_type(image).unsqueeze(0).to(self.device)
        road_classes, road_values = self._predict(
            self.models["road_type"],
            road_data,
            self.class_mappings["road_type"]
        )

        # Surface type prediction
        surface_data = self.transform_surface(image).unsqueeze(0).to(self.device)
        surface_classes, surface_values = self._predict(
            self.models["surface_type"],
            surface_data,
            self.class_mappings["surface_type"]
        )

        # Quality prediction based on detected surface type; None when no quality
        # model exists for the predicted surface.
        surface_type = surface_classes[0]
        quality_class = None
        quality_value = None

        if surface_type in self.models["quality"]:
            quality_classes, quality_values = self._predict(
                self.models["quality"][surface_type],
                surface_data,
                self.class_mappings["quality"][surface_type]
            )
            quality_class = quality_classes[0]
            quality_value = quality_values[0]

        return {
            "road_type": road_classes[0],
            # Classifier values are probability lists; report the winning probability.
            "road_type_confidence": max(road_values[0]) if isinstance(road_values[0], list) else road_values[0],
            "surface_type": surface_type,
            "surface_type_confidence": max(surface_values[0]) if isinstance(surface_values[0], list) else surface_values[0],
            "quality_class": quality_class,
            "quality_value": quality_value,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference request.

        Args:
            data: Request data containing either:
                - "inputs": base64-encoded image or URL
                - "image": PIL Image (when called directly)

        Returns:
            List of prediction results; a single-element error dict on failure.
        """
        inputs = data.get("inputs", data.get("image"))

        if inputs is None:
            return [{"error": "No input provided. Send 'inputs' with base64 image or URL."}]

        try:
            # Handle different input types
            if isinstance(inputs, str):
                if inputs.startswith("data:image"):
                    # Base64 data URL
                    inputs = inputs.split(",")[1]
                    image_bytes = base64.b64decode(inputs)
                    image = Image.open(io.BytesIO(image_bytes))
                elif inputs.startswith("http"):
                    # URL - fetch it
                    import requests
                    response = requests.get(inputs, timeout=10)
                    # Fail fast on HTTP errors instead of letting PIL choke on an
                    # error-page body with a confusing message.
                    response.raise_for_status()
                    image = Image.open(io.BytesIO(response.content))
                else:
                    # Assume raw base64
                    image_bytes = base64.b64decode(inputs)
                    image = Image.open(io.BytesIO(image_bytes))
            elif isinstance(inputs, Image.Image):
                image = inputs
            elif isinstance(inputs, bytes):
                image = Image.open(io.BytesIO(inputs))
            else:
                return [{"error": f"Unsupported input type: {type(inputs)}"}]

            result = self._process_image(image)
            return [result]

        except Exception as e:
            logger.exception("Error processing request")
            return [{"error": str(e)}]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Runtime dependencies for handler.py (SurfaceAI inference endpoint).
torch>=2.0.0
torchvision>=0.15.0
huggingface_hub>=0.20.0
Pillow>=9.0.0
requests>=2.28.0