|
|
""" |
|
|
HuggingFace Inference Endpoint handler for SurfaceAI models. |
|
|
|
|
|
This handler loads all 7 SurfaceAI models and performs hierarchical classification: |
|
|
1. Road type classification |
|
|
2. Surface type classification |
|
|
3. Surface quality regression (model selected based on surface type) |
|
|
|
|
|
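Deploy by creating an Inference Endpoint pointing to this repo. The model checkpoints themselves are downloaded at startup from the Hub repo named in MODEL_CONFIG["hf_repo"].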
|
|
""" |
|
|
|
|
|
import base64 |
|
|
import io |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import torch |
|
|
from huggingface_hub import hf_hub_download |
|
|
from PIL import Image |
|
|
from torchvision import models, transforms |
|
|
from torch import nn, Tensor |
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Per-channel (R, G, B) mean and standard deviation applied by both
# preprocessing pipelines after ToTensor.
NORM_MEAN = [0.42834484577178955, 0.4461250305175781, 0.4350937306880951]
NORM_SD = [0.22991590201854706, 0.23555299639701843, 0.26348039507865906]

# Crop-style identifiers; EndpointHandler._custom_crop dispatches on these.
CROP_LOWER_MIDDLE_HALF = "lower_middle_half"
CROP_LOWER_HALF = "lower_half"


# Checkpoint path (inside the Hub repo) of each surface-quality model, keyed by
# the surface-type label whose prediction selects that model.
_SURFACE_QUALITY_FILES = {
    "asphalt": "v1/surface_quality_asphalt_v1.pt",
    "concrete": "v1/surface_quality_concrete_v1.pt",
    "paving_stones": "v1/surface_quality_paving_stones_v1.pt",
    "sett": "v1/surface_quality_sett_v1.pt",
    "unpaved": "v1/surface_quality_unpaved_v1.pt",
}

# Preprocessing settings shared by both pipelines; only the crop differs.
_TRANSFORM_DEFAULTS = {"resize": 256, "normalize": (NORM_MEAN, NORM_SD)}

MODEL_CONFIG = {
    "hf_repo": "SurfaceAI/models-moved",
    "models": {
        "road_type": "v1/road_type_v1.pt",
        "surface_type": "v1/surface_type_v1.pt",
        "surface_quality": _SURFACE_QUALITY_FILES,
    },
    "transform_surface": {**_TRANSFORM_DEFAULTS, "crop": CROP_LOWER_MIDDLE_HALF},
    "transform_road_type": {**_TRANSFORM_DEFAULTS, "crop": CROP_LOWER_HALF},
}


# Integer quality index -> label, from 1 ("excellent") to 5 ("very_bad").
QUALITY_CLASSES = dict(
    enumerate(("excellent", "good", "intermediate", "bad", "very_bad"), start=1)
)
|
|
|
|
|
|
|
|
class CustomEfficientNetV2SLinear(nn.Module):
    """EfficientNetV2-S backbone with a single linear head for classification/regression.

    ``num_classes == 1`` builds a regression head (``is_regression`` is True);
    any other value builds a head producing one logit per class.
    """

    def __init__(self, num_classes, avg_pool=1, *, weights="IMAGENET1K_V1"):
        """
        Args:
            num_classes: Output width of the linear head; 1 means regression.
            avg_pool: Output size of the adaptive average pool. The head's input
                width is scaled by ``avg_pool ** 2`` to match the flattened pool.
            weights: torchvision weights spec for the backbone. The default keeps
                the previous behaviour (ImageNet weights). Pass ``None`` when a
                complete checkpoint is loaded immediately afterwards with
                ``load_state_dict``; this skips fetching weights that would be
                overwritten anyway.
        """
        super().__init__()
        model = models.efficientnet_v2_s(weights=weights)
        in_features = model.classifier[-1].in_features * (avg_pool * avg_pool)
        # Replace torchvision's final layer with a head of the requested width.
        model.classifier[-1] = nn.Linear(in_features, num_classes, bias=True)

        self.features = model.features
        self.avgpool = nn.AdaptiveAvgPool2d(avg_pool)
        self.classifier = model.classifier
        self.is_regression = num_classes == 1

    def forward(self, x: Tensor) -> Tensor:
        """Run features -> adaptive avg-pool -> flatten -> classifier; return raw outputs."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def get_class_probabilities(self, x):
        """Post-process raw outputs ``x`` from ``forward``.

        Returns the flattened outputs for regression models. Otherwise returns
        a softmax over dim 1 (one probability row per sample).
        """
        if self.is_regression:
            return x.flatten()
        return nn.functional.softmax(x, dim=1)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoint handler for SurfaceAI.

    Construction downloads and loads every checkpoint listed in
    ``MODEL_CONFIG["models"]``. Each request is run through three stages:
    road type -> surface type -> surface quality. The quality model is picked
    by the predicted surface type.
    """

    def __init__(self, path: str = ""):
        """
        Initialize handler and load all models.

        Args:
            path: Path to model directory (provided by HF Inference Endpoints).
                Not used here: checkpoints are fetched from
                ``MODEL_CONFIG["hf_repo"]`` instead.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        # "road_type" / "surface_type" -> model; "quality" -> {surface type: model}.
        self.models = {}
        # Same layout as self.models, holding each checkpoint's class_to_idx.
        self.class_mappings = {}
        self._load_all_models()

        self.transform_surface = self._build_transform(MODEL_CONFIG["transform_surface"])
        self.transform_road_type = self._build_transform(MODEL_CONFIG["transform_road_type"])

    def _download_model(self, filename: str) -> str:
        """Download ``filename`` from the configured Hub repo and return its local path."""
        return hf_hub_download(
            repo_id=MODEL_CONFIG["hf_repo"],
            filename=filename,
        )

    def _load_model(self, model_path: str) -> tuple:
        """Load a single checkpoint.

        Returns:
            ``(model, class_to_idx, is_regression)``. The model is moved to
            ``self.device`` and switched to eval mode.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects. That is
        # only acceptable because checkpoints come from MODEL_CONFIG["hf_repo"],
        # which must be a trusted repository.
        state = torch.load(model_path, map_location=self.device, weights_only=False)

        is_regression = state["is_regression"]
        class_to_idx = state["class_to_idx"]
        # Regression checkpoints predict a single scalar.
        num_classes = 1 if is_regression else len(class_to_idx)

        model = CustomEfficientNetV2SLinear(num_classes=num_classes)
        model.load_state_dict(state["model_state_dict"])
        model.to(self.device)
        model.eval()

        return model, class_to_idx, is_regression

    def _load_all_models(self):
        """Load the road-type model, the surface-type model and every surface-quality model."""
        logger.info("Loading SurfaceAI models...")

        for name in ("road_type", "surface_type"):
            path = self._download_model(MODEL_CONFIG["models"][name])
            self.models[name], self.class_mappings[name], _ = self._load_model(path)
            logger.info("Loaded %s model", name)

        self.models["quality"] = {}
        self.class_mappings["quality"] = {}
        for surface_type, model_file in MODEL_CONFIG["models"]["surface_quality"].items():
            path = self._download_model(model_file)
            model, class_to_idx, _ = self._load_model(path)
            self.models["quality"][surface_type] = model
            self.class_mappings["quality"][surface_type] = class_to_idx
            logger.info("Loaded quality model for %s", surface_type)

        logger.info("All models loaded successfully")

    @staticmethod
    def _custom_crop(img: Image.Image, crop_style: str) -> Image.Image:
        """Crop ``img`` according to ``crop_style``.

        ``CROP_LOWER_MIDDLE_HALF`` keeps the bottom half of the image, at the
        middle half of its width. ``CROP_LOWER_HALF`` keeps the full-width
        bottom half. Any other style returns the image unchanged. Floor division
        means the crop can stop one pixel short of the bottom/right edge for
        odd dimensions.
        """
        im_width, im_height = img.size

        if crop_style == CROP_LOWER_MIDDLE_HALF:
            top = im_height // 2
            left = im_width // 4
            height = im_height // 2
            width = im_width // 2
        elif crop_style == CROP_LOWER_HALF:
            top = im_height // 2
            left = 0
            height = im_height // 2
            width = im_width
        else:
            return img

        return img.crop((left, top, left + width, top + height))

    def _build_transform(self, config: dict) -> transforms.Compose:
        """Build a torchvision pipeline from ``config``.

        The steps are an optional custom crop, an optional resize (an int means a
        square size), ToTensor, and an optional normalize.
        """
        transform_list = []

        if config.get("crop"):
            transform_list.append(
                transforms.Lambda(lambda img: self._custom_crop(img, config["crop"]))
            )

        if config.get("resize"):
            size = config["resize"]
            if isinstance(size, int):
                size = (size, size)
            transform_list.append(transforms.Resize(size))

        transform_list.append(transforms.ToTensor())

        if config.get("normalize"):
            transform_list.append(transforms.Normalize(*config["normalize"]))

        return transforms.Compose(transform_list)

    def _predict(self, model, data: torch.Tensor, class_to_idx: dict) -> tuple:
        """Run ``model`` on ``data`` and map the outputs to class labels.

        Returns:
            ``(classes, values)``. ``classes`` holds one label per sample.
            ``values`` holds one float per sample for regression models, or one
            probability list per sample for classifiers.
        """
        with torch.no_grad():
            outputs = model(data)
            values = model.get_class_probabilities(outputs)

        idx_to_class = {i: cls for cls, i in class_to_idx.items()}

        if values.dim() < 2:
            # Regression: round each scalar to the nearest class index, clamped to
            # the index range the checkpoint knows about. This assumes the indices
            # are contiguous between lo and hi (TODO confirm for every checkpoint).
            lo = min(class_to_idx.values())
            hi = max(class_to_idx.values())
            classes = [idx_to_class[min(max(int(v.round().item()), lo), hi)] for v in values]
        else:
            # Classification: take the most probable class per row.
            classes = [idx_to_class[idx.item()] for idx in torch.argmax(values, dim=1)]

        return classes, values.tolist()

    @staticmethod
    def _top_confidence(value):
        """Return the highest probability of a classifier row, or a regression value unchanged."""
        return max(value) if isinstance(value, list) else value

    def _process_image(self, image: Image.Image) -> dict:
        """Run one image through road-type, surface-type and surface-quality models.

        Returns a dict with keys ``road_type``, ``road_type_confidence``,
        ``surface_type``, ``surface_type_confidence``, ``quality_class`` and
        ``quality_value``. The two quality fields are None when no quality model
        exists for the predicted surface type.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        road_data = self.transform_road_type(image).unsqueeze(0).to(self.device)
        road_classes, road_values = self._predict(
            self.models["road_type"],
            road_data,
            self.class_mappings["road_type"],
        )

        surface_data = self.transform_surface(image).unsqueeze(0).to(self.device)
        surface_classes, surface_values = self._predict(
            self.models["surface_type"],
            surface_data,
            self.class_mappings["surface_type"],
        )

        surface_type = surface_classes[0]
        quality_class = None
        quality_value = None

        if surface_type in self.models["quality"]:
            # Quality models consume the same preprocessed tensor as the
            # surface-type model.
            quality_classes, quality_values = self._predict(
                self.models["quality"][surface_type],
                surface_data,
                self.class_mappings["quality"][surface_type],
            )
            quality_class = quality_classes[0]
            quality_value = quality_values[0]

        return {
            "road_type": road_classes[0],
            "road_type_confidence": self._top_confidence(road_values[0]),
            "surface_type": surface_type,
            "surface_type_confidence": self._top_confidence(surface_values[0]),
            "quality_class": quality_class,
            "quality_value": quality_value,
        }

    @staticmethod
    def _image_from_string(inputs: str) -> Image.Image:
        """Decode a data URI, an http(s) URL or a bare base64 string into a PIL image."""
        if inputs.startswith("data:image"):
            # Drop the "data:image/...;base64," header and keep the payload.
            _, _, payload = inputs.partition(",")
            return Image.open(io.BytesIO(base64.b64decode(payload)))

        if inputs.startswith(("http://", "https://")):
            # NOTE(security): this fetches an arbitrary caller-supplied URL from
            # inside the endpoint (SSRF risk). Restrict allowed hosts if the
            # endpoint is reachable by untrusted clients.
            import requests

            response = requests.get(inputs, timeout=10)
            # Surface HTTP errors directly instead of letting PIL fail on an
            # error page with "cannot identify image file".
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))

        return Image.open(io.BytesIO(base64.b64decode(inputs)))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference request.

        Args:
            data: Request data containing either:
                - "inputs": base64-encoded image (optionally as a data URI),
                  http(s) URL, raw image bytes, or a PIL Image
                - "image": PIL Image (when called directly)

        Returns:
            A one-element list holding either the prediction dict or
            ``{"error": message}``. Errors are reported, not raised.
        """
        inputs = data.get("inputs", data.get("image"))

        if inputs is None:
            return [{"error": "No input provided. Send 'inputs' with base64 image or URL."}]

        try:
            if isinstance(inputs, str):
                image = self._image_from_string(inputs)
            elif isinstance(inputs, Image.Image):
                image = inputs
            elif isinstance(inputs, (bytes, bytearray)):
                image = Image.open(io.BytesIO(inputs))
            else:
                return [{"error": f"Unsupported input type: {type(inputs)}"}]

            return [self._process_image(image)]

        except Exception as e:
            # Request boundary: log the traceback and report the error to the caller.
            logger.exception("Error processing request")
            return [{"error": str(e)}]
|
|
|