# serviceadvisor / car_advisor / vision_model.py
# Uploaded by viswanani ("Upload 22 files", commit 1c7bc31).
import logging
import os
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

from .config import DEFAULT_LABELS
from .utils import softmax
class SimpleVisionModel(nn.Module):
    """
    Wrapper around a lightweight image classifier. For training, use
    training/train_vision.py.

    The backbone is timm's ``mobilenetv3_small_100`` when timm is importable
    and the model can be created; otherwise a tiny pooling + MLP head is used
    so the module can always be constructed.

    Args:
        num_classes: Number of output logits.
        pretrained: Ask timm for ImageNet-pretrained weights (default True,
            the original behaviour). Pass False when the weights will be
            replaced by a checkpoint anyway. That skips a network download.
            More importantly, it avoids silently falling back to the tiny head
            when the download fails offline, since the tiny head cannot load a
            MobileNet checkpoint.
    """

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        try:
            import timm  # optional dependency

            self.net = timm.create_model(
                "mobilenetv3_small_100",
                pretrained=pretrained,
                num_classes=num_classes,
            )
        except Exception:
            # timm missing, or model creation / weight download failed.
            logging.getLogger(__name__).warning(
                "timm backbone unavailable; using tiny fallback classifier.",
                exc_info=True,
            )
            # Fallback expects 3-channel input: 8x8 pooled map * 3 channels.
            self.net = nn.Sequential(
                nn.AdaptiveAvgPool2d((8, 8)),
                nn.Flatten(),
                nn.Linear(8 * 8 * 3, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes),
            )

    def forward(self, x):
        """Return the raw logits produced by the wrapped network for ``x``."""
        return self.net(x)


class VisionInference:
    """
    Score single PIL images against a fixed list of damage labels.

    Predictions come from ``SimpleVisionModel`` only when a checkpoint was
    found at ``ckpt_path`` and loaded cleanly (``self.ready`` is True).
    Otherwise, or if model inference raises, ``predict`` returns rule-based
    scores derived from image contrast and mean colour.

    Args:
        labels: Class names in the model's output order. Falls back to
            ``DEFAULT_LABELS`` when None or empty.
        ckpt_path: Checkpoint file holding either a bare state dict or a dict
            with the state dict stored under ``"model"``.
    """

    def __init__(
        self,
        labels: Optional[List[str]] = None,
        ckpt_path: str = "checkpoints/vision/best.pt",
    ):
        self.labels = labels or DEFAULT_LABELS
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Weights come from the checkpoint (or the model is unused when there
        # is none), so never download ImageNet weights here.
        self.model = SimpleVisionModel(
            num_classes=len(self.labels), pretrained=False
        ).to(self.device)
        # Inference only: use BatchNorm running stats and disable dropout.
        self.model.eval()
        self.transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
        self.ready = False
        if os.path.exists(ckpt_path):
            try:
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints from a trusted source.
                state = torch.load(ckpt_path, map_location=self.device)
                if isinstance(state, dict) and "model" in state:
                    state = state["model"]
                self.model.load_state_dict(state)
                self.ready = True
            except Exception:
                logging.getLogger(__name__).warning(
                    "Could not load vision checkpoint %s; using heuristic scores.",
                    ckpt_path,
                    exc_info=True,
                )
                self.ready = False

    @torch.no_grad()
    def predict(self, image: Image.Image) -> Dict[str, float]:
        """
        Score ``image`` against every label.

        Returns:
            Mapping of label to score. ``None`` input yields 0.0 for every
            label. Without a loaded checkpoint, or if the model raises, the
            heuristic scores (normalised to sum to 1) are returned instead.
        """
        if image is None:
            return {lbl: 0.0 for lbl in self.labels}
        if not self.ready:
            # No trained weights: the classifier head would be random.
            return self._heuristic_scores(image)
        try:
            x = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)
            logits = self.model(x)[0].cpu().tolist()
            probs = softmax(logits)
            return {lbl: float(p) for lbl, p in zip(self.labels, probs)}
        except Exception:
            logging.getLogger(__name__).warning(
                "Vision model inference failed; using heuristic scores.",
                exc_info=True,
            )
            return self._heuristic_scores(image)

    def _heuristic_scores(self, image: Image.Image) -> Dict[str, float]:
        """Rule-based scores from contrast and mean RGB, normalised to sum to 1."""
        import numpy as np  # lazy: only needed on the fallback path

        img = image.convert("RGB").resize((64, 64))
        arr = np.array(img).astype("float32") / 255.0
        contrast = float(arr.mean(axis=2).std())
        red_mean = float(arr[:, :, 0].mean())
        green_mean = float(arr[:, :, 1].mean())
        blue_mean = float(arr[:, :, 2].mean())

        scores = {lbl: 0.01 for lbl in self.labels}

        def bump(label: str, amount: float) -> None:
            # Custom label sets may omit the built-in damage classes; skip
            # them instead of raising KeyError.
            if label in scores:
                scores[label] += amount

        if contrast > 0.22:
            bump("scratch_dent", 0.2)
            bump("paint_damage", 0.15)
            bump("bumper_damage", 0.1)
        if blue_mean < 0.35 and green_mean < 0.35:
            bump("rust", 0.2)
        if red_mean > 0.55:
            bump("engine_leak", 0.15)

        total = sum(scores.values())
        return {k: v / total for k, v in scores.items()}