rezzzq Claude commited on
Commit
3138c42
·
0 Parent(s):

Add multi-model inference handler with RDD and SurfaceAI support

Browse files

- business/finishing: classification models for road assessment
- rdd: road damage detection with bounding boxes (YOLO12s RDD2022)
- surfaceai: surface type, road type, and quality classification
- Automatically selects quality model based on detected surface type

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
handler.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Inference Endpoint Custom Handler
3
+ Handles inference for multiple YOLO models:
4
+ - business/finishing: classification models
5
+ - rdd: road damage detection (object detection with bounding boxes)
6
+ - surfaceai: surface type, road type, and quality classification
7
+ """
8
+
9
+ import base64
10
+ import io
11
+ from typing import Any, Dict, List
12
+ from PIL import Image
13
+ from ultralytics import YOLO
14
+
15
+
16
+ class EndpointHandler:
17
+ # SurfaceAI class mappings
18
+ SURFACE_TYPES = ["asphalt", "concrete", "paving_stones", "sett", "unpaved"]
19
+ QUALITY_LEVELS = ["excellent", "good", "intermediate", "bad", "very_bad"]
20
+
21
+ def __init__(self, path: str = ""):
22
+ """
23
+ Initialize the handler by loading all models.
24
+
25
+ Args:
26
+ path: Path to the model directory (provided by HF)
27
+ """
28
+ # Classification models
29
+ self.models = {
30
+ "business": YOLO(f"{path}/models/business_best.pt"),
31
+ "finishing": YOLO(f"{path}/models/finishing_best.pt")
32
+ }
33
+
34
+ # Road Damage Detection model
35
+ self.rdd_model = YOLO(f"{path}/models/rdd/yolo12s_RDD2022_best.pt")
36
+
37
+ # SurfaceAI models
38
+ self.surfaceai_models = {
39
+ "surface_type": YOLO(f"{path}/models/surfaceai/surface_type_v1.pt"),
40
+ "road_type": YOLO(f"{path}/models/surfaceai/road_type_v1.pt"),
41
+ "quality": {
42
+ "asphalt": YOLO(f"{path}/models/surfaceai/quality/surface_quality_asphalt_v1.pt"),
43
+ "concrete": YOLO(f"{path}/models/surfaceai/quality/surface_quality_concrete_v1.pt"),
44
+ "paving_stones": YOLO(f"{path}/models/surfaceai/quality/surface_quality_paving_stones_v1.pt"),
45
+ "sett": YOLO(f"{path}/models/surfaceai/quality/surface_quality_sett_v1.pt"),
46
+ "unpaved": YOLO(f"{path}/models/surfaceai/quality/surface_quality_unpaved_v1.pt"),
47
+ }
48
+ }
49
+
50
+ def _decode_image(self, image_input: Any) -> Image.Image:
51
+ """
52
+ Decode image from various input formats.
53
+
54
+ Args:
55
+ image_input: Base64 string, URL, or raw bytes
56
+
57
+ Returns:
58
+ PIL Image object
59
+ """
60
+ if isinstance(image_input, str):
61
+ if image_input.startswith(("http://", "https://")):
62
+ import requests
63
+ response = requests.get(image_input, timeout=30)
64
+ response.raise_for_status()
65
+ return Image.open(io.BytesIO(response.content))
66
+ else:
67
+ # Handle base64 with or without data URI prefix
68
+ if "base64," in image_input:
69
+ image_input = image_input.split("base64,")[1]
70
+ image_data = base64.b64decode(image_input)
71
+ return Image.open(io.BytesIO(image_data))
72
+ elif isinstance(image_input, bytes):
73
+ return Image.open(io.BytesIO(image_input))
74
+ else:
75
+ raise ValueError(f"Unsupported image input type: {type(image_input)}")
76
+
77
+ def _run_classification(self, model: YOLO, image: Image.Image) -> Dict[str, Any]:
78
+ """Run classification inference and return formatted results."""
79
+ prediction = model.predict(image, verbose=False)[0]
80
+ probs = prediction.probs
81
+ top_class_id = int(probs.top1)
82
+ top_class_name = prediction.names[top_class_id]
83
+ top_confidence = float(probs.top1conf)
84
+
85
+ all_probs = {
86
+ prediction.names[i]: float(probs.data[i])
87
+ for i in range(len(probs.data))
88
+ }
89
+
90
+ return {
91
+ "class": top_class_name,
92
+ "class_id": top_class_id,
93
+ "confidence": round(top_confidence, 4),
94
+ "all_probs": {k: round(v, 4) for k, v in all_probs.items()}
95
+ }
96
+
97
+ def _run_rdd(self, image: Image.Image, conf_threshold: float = 0.25) -> Dict[str, Any]:
98
+ """
99
+ Run Road Damage Detection and return detections with bounding boxes.
100
+
101
+ Returns:
102
+ {
103
+ "detections": [
104
+ {
105
+ "class": "D00",
106
+ "class_id": 0,
107
+ "confidence": 0.85,
108
+ "bbox": [x1, y1, x2, y2]
109
+ },
110
+ ...
111
+ ],
112
+ "count": 2
113
+ }
114
+ """
115
+ prediction = self.rdd_model.predict(image, verbose=False, conf=conf_threshold)[0]
116
+ detections = []
117
+
118
+ if prediction.boxes is not None and len(prediction.boxes) > 0:
119
+ for box in prediction.boxes:
120
+ class_id = int(box.cls[0])
121
+ class_name = prediction.names[class_id]
122
+ confidence = float(box.conf[0])
123
+ bbox = box.xyxy[0].tolist() # [x1, y1, x2, y2]
124
+
125
+ detections.append({
126
+ "class": class_name,
127
+ "class_id": class_id,
128
+ "confidence": round(confidence, 4),
129
+ "bbox": [round(coord, 2) for coord in bbox]
130
+ })
131
+
132
+ return {
133
+ "detections": detections,
134
+ "count": len(detections)
135
+ }
136
+
137
+ def _run_surfaceai(self, image: Image.Image) -> Dict[str, Any]:
138
+ """
139
+ Run SurfaceAI models for surface type, road type, and quality assessment.
140
+
141
+ Returns:
142
+ {
143
+ "surface_type": {
144
+ "class": "asphalt",
145
+ "confidence": 0.92,
146
+ "all_probs": {...}
147
+ },
148
+ "road_type": {
149
+ "class": "primary",
150
+ "confidence": 0.88,
151
+ "all_probs": {...}
152
+ },
153
+ "surface_quality": {
154
+ "class": "good",
155
+ "confidence": 0.75,
156
+ "all_probs": {...},
157
+ "model_used": "asphalt"
158
+ }
159
+ }
160
+ """
161
+ results = {}
162
+
163
+ # Get surface type
164
+ surface_result = self._run_classification(
165
+ self.surfaceai_models["surface_type"], image
166
+ )
167
+ results["surface_type"] = surface_result
168
+
169
+ # Get road type
170
+ road_result = self._run_classification(
171
+ self.surfaceai_models["road_type"], image
172
+ )
173
+ results["road_type"] = road_result
174
+
175
+ # Get surface quality based on detected surface type
176
+ detected_surface = surface_result["class"].lower()
177
+ if detected_surface in self.surfaceai_models["quality"]:
178
+ quality_model = self.surfaceai_models["quality"][detected_surface]
179
+ quality_result = self._run_classification(quality_model, image)
180
+ quality_result["model_used"] = detected_surface
181
+ results["surface_quality"] = quality_result
182
+ else:
183
+ # Fallback to asphalt quality model if surface type not recognized
184
+ quality_model = self.surfaceai_models["quality"]["asphalt"]
185
+ quality_result = self._run_classification(quality_model, image)
186
+ quality_result["model_used"] = "asphalt"
187
+ quality_result["note"] = f"Surface type '{detected_surface}' not recognized, using asphalt model"
188
+ results["surface_quality"] = quality_result
189
+
190
+ return results
191
+
192
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
193
+ """
194
+ Process inference request.
195
+
196
+ Expected input format:
197
+ {
198
+ "inputs": "<base64_string or URL>",
199
+ "parameters": {
200
+ "model": "business" | "finishing" | "both" | "rdd" | "surfaceai"
201
+ "conf_threshold": 0.25 # optional, for RDD only
202
+ }
203
+ }
204
+
205
+ Returns for business/finishing/both:
206
+ [
207
+ {
208
+ "business": {"class": "...", "class_id": 0, "confidence": 0.95, "all_probs": {...}},
209
+ "finishing": {"class": "...", "class_id": 0, "confidence": 0.92, "all_probs": {...}}
210
+ }
211
+ ]
212
+
213
+ Returns for rdd:
214
+ [
215
+ {
216
+ "detections": [
217
+ {"class": "D00", "class_id": 0, "confidence": 0.85, "bbox": [x1, y1, x2, y2]},
218
+ ...
219
+ ],
220
+ "count": 2
221
+ }
222
+ ]
223
+
224
+ Returns for surfaceai:
225
+ [
226
+ {
227
+ "surface_type": {"class": "asphalt", "confidence": 0.92, "all_probs": {...}},
228
+ "road_type": {"class": "primary", "confidence": 0.88, "all_probs": {...}},
229
+ "surface_quality": {"class": "good", "confidence": 0.75, "all_probs": {...}, "model_used": "asphalt"}
230
+ }
231
+ ]
232
+ """
233
+ # Get image input
234
+ image_input = data.get("inputs")
235
+ if not image_input:
236
+ return [{"error": "Missing required field: inputs"}]
237
+
238
+ # Get parameters
239
+ parameters = data.get("parameters", {})
240
+ model_choice = parameters.get("model", "both")
241
+
242
+ try:
243
+ # Decode image
244
+ image = self._decode_image(image_input)
245
+
246
+ # Handle RDD model
247
+ if model_choice == "rdd":
248
+ conf_threshold = parameters.get("conf_threshold", 0.25)
249
+ return [self._run_rdd(image, conf_threshold)]
250
+
251
+ # Handle SurfaceAI models
252
+ if model_choice == "surfaceai":
253
+ return [self._run_surfaceai(image)]
254
+
255
+ # Handle classification models (business/finishing/both)
256
+ if model_choice == "both":
257
+ models_to_run = ["business", "finishing"]
258
+ elif model_choice in self.models:
259
+ models_to_run = [model_choice]
260
+ else:
261
+ return [{"error": f"Invalid model choice: {model_choice}. Use 'business', 'finishing', 'both', 'rdd', or 'surfaceai'"}]
262
+
263
+ # Run classification inference
264
+ results = {}
265
+ for model_name in models_to_run:
266
+ model = self.models[model_name]
267
+ results[model_name] = self._run_classification(model, image)
268
+
269
+ return [results]
270
+
271
+ except Exception as e:
272
+ return [{"error": str(e)}]
models/business_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ca532122b323ba29e3c280cff785fd64a62fc6fd7016f5c2239a9e690aa0abd
3
+ size 3197634
models/finishing_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08c62c0a9a44899427df2b33f34f84b76d0a72f760e9790733a6384bf5a46bf2
3
+ size 3201346
models/rdd/yolo12s_RDD2022_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccf91e3199d81ee761f4a3a012da752939c9d72e9b1eaa242e508db3163270a0
3
+ size 18966618
models/surfaceai/quality/surface_quality_asphalt_v1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3bb1d9b7884fff0dbbf736e792e7dbcf149fb7c300818601bc96f042f24ec50
3
+ size 81618348
models/surfaceai/quality/surface_quality_concrete_v1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baf27927570d268cc0bb13115b7281c05aa8892dc56d559f113a11cb8ce91402
3
+ size 81619134
models/surfaceai/quality/surface_quality_paving_stones_v1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd8f7510c2cf8b4a6be0edbb66c1cd5909a8a52a06a3f6ee1b422d965039f084
3
+ size 81623064
models/surfaceai/quality/surface_quality_sett_v1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7075f1e8c0d8388d7f7d49763807701f3a124a4d2e2d6f83362c514fa59ca58d
3
+ size 81615990
models/surfaceai/quality/surface_quality_unpaved_v1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29c0b73838129bdb244d8b377746d8099e2f8dee4011c13ab1e0a5a3fc9a045e
3
+ size 81618348
models/surfaceai/road_type_v1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01b7fb3e5619a8afc1d6b8bc4e382eb21e21a0b5f5a3d380eea0d67bf818a8e5
3
+ size 81643376
models/surfaceai/surface_type_v1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c87e8d1649b7aec19fc6927da0315f34c865cbc6913d9e0c35c430c69e7c1de
3
+ size 81630118
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ultralytics>=8.3.0
2
+ pillow>=10.0.0
3
+ requests>=2.31.0