# Spaces: Sleeping
import logging
import os
from typing import Dict, List

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

from .config import DEFAULT_LABELS
from .utils import softmax
class SimpleVisionModel(nn.Module):
    """Lightweight image classifier with a dependency-free fallback.

    For training, use training/train_vision.py. The network is built with
    ``timm`` when it can be imported and model creation succeeds; otherwise a
    tiny pool-and-MLP network is used so construction never fails.

    The two architectures have different parameter names and shapes, so a
    checkpoint saved with one cannot be loaded into the other.

    Args:
        num_classes: Number of output logits.
        backbone: ``timm`` model name to create.
        pretrained: Passed to ``timm.create_model``; ``True`` may trigger a
            weight download. Set ``False`` when a trained checkpoint will be
            loaded anyway.
    """

    def __init__(
        self,
        num_classes: int,
        *,
        backbone: str = "mobilenetv3_small_100",
        pretrained: bool = True,
    ):
        super().__init__()
        try:
            import timm

            self.net = timm.create_model(
                backbone, pretrained=pretrained, num_classes=num_classes
            )
        except Exception as err:
            # Deliberately best-effort: a missing timm install or a failed
            # weight download must not prevent construction. Record why the
            # fallback was chosen instead of swallowing the error silently.
            logging.getLogger(__name__).warning(
                "timm backbone %r unavailable (%s: %s); using fallback network",
                backbone,
                type(err).__name__,
                err,
            )
            self.net = nn.Sequential(
                nn.AdaptiveAvgPool2d((8, 8)),  # any spatial size -> 8x8
                nn.Flatten(),
                nn.Linear(8 * 8 * 3, 128),  # flattened size assumes 3 input channels
                nn.ReLU(),
                nn.Linear(128, num_classes),
            )

    def forward(self, x):
        """Return unnormalized class scores (logits) for the batch ``x``."""
        return self.net(x)
class VisionInference:
    """Score an image against a fixed, ordered set of labels.

    If a checkpoint is loaded successfully, the classifier is used. Otherwise
    ``predict`` returns rule-based scores computed from simple colour and
    contrast statistics. This covers a checkpoint that is absent or
    unloadable, and model inference that raises.

    Args:
        labels: Ordered label names; ``DEFAULT_LABELS`` when omitted or empty.
            The order must match the classifier outputs in the checkpoint.
        ckpt_path: File holding either a state dict or a dict that stores the
            state dict under ``"model"``.
    """

    def __init__(self, labels: List[str] = None, ckpt_path: str = "checkpoints/vision/best.pt"):
        self.labels = labels or DEFAULT_LABELS
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SimpleVisionModel(num_classes=len(self.labels)).to(self.device)
        # Inference only: switch layers such as BatchNorm/Dropout to eval
        # behaviour so predictions are deterministic and use running stats.
        self.model.eval()
        # NOTE(review): no mean/std normalization is applied; confirm this
        # matches the preprocessing used in training/train_vision.py.
        self.transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
        self.ready = False
        if os.path.exists(ckpt_path):
            try:
                # SECURITY: torch.load unpickles the file; only load trusted
                # checkpoints (consider weights_only=True on supporting torch).
                state = torch.load(ckpt_path, map_location=self.device)
                self.model.load_state_dict(state["model"] if "model" in state else state)
                self.ready = True
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to load vision checkpoint %s; using heuristic scores",
                    ckpt_path,
                    exc_info=True,
                )

    def predict(self, image: Image.Image) -> Dict[str, float]:
        """Return a ``{label: score}`` mapping for ``image``.

        ``None`` yields all-zero scores. Without a loaded checkpoint, or when
        model inference fails, heuristic scores (summing to 1) are returned.
        """
        if image is None:
            return {lbl: 0.0 for lbl in self.labels}
        if not self.ready:
            # No trained weights were loaded for these labels, so the
            # classifier output would be arbitrary; use the rule-based path.
            return self._heuristic_scores(image)
        try:
            x = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)
            with torch.no_grad():  # no autograd graph is needed at inference
                logits = self.model(x)[0].cpu().tolist()
            probs = softmax(logits)
            return {lbl: float(p) for lbl, p in zip(self.labels, probs)}
        except Exception:
            logging.getLogger(__name__).warning(
                "Vision model inference failed; using heuristic scores",
                exc_info=True,
            )
            return self._heuristic_scores(image)

    def _heuristic_scores(self, image: Image.Image) -> Dict[str, float]:
        """Rule-based scores from grayscale contrast and per-channel means.

        Every label starts at 0.01; threshold rules add to specific labels
        only when those labels are present in ``self.labels``, so custom
        label sets do not raise ``KeyError``. Scores are normalized to sum 1.
        """
        import numpy as np

        arr = np.array(image.convert("RGB").resize((64, 64))).astype("float32") / 255.0
        contrast = float(arr.mean(axis=2).std())
        red_mean = float(arr[:, :, 0].mean())
        green_mean = float(arr[:, :, 1].mean())
        blue_mean = float(arr[:, :, 2].mean())

        scores = {lbl: 0.01 for lbl in self.labels}

        def bump(label: str, amount: float) -> None:
            # Silently skip rules whose label is not configured.
            if label in scores:
                scores[label] += amount

        if contrast > 0.22:
            bump("scratch_dent", 0.2)
            bump("paint_damage", 0.15)
            bump("bumper_damage", 0.1)
        if blue_mean < 0.35 and green_mean < 0.35:
            bump("rust", 0.2)
        if red_mean > 0.55:
            bump("engine_leak", 0.15)
        total = sum(scores.values())
        return {k: v / total for k, v in scores.items()}