mitbersh commited on
Commit
a4c4c95
·
verified ·
1 Parent(s): 198726a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +73 -1
README.md CHANGED
@@ -4,4 +4,76 @@ datasets:
4
  base_model:
5
  - microsoft/resnet-18
6
  pipeline_tag: image-classification
7
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  base_model:
5
  - microsoft/resnet-18
6
  pipeline_tag: image-classification
7
+ ---
8
+ # AutoInspect — Car View Classifier (ResNet18)
9
+
10
+ Модель для определения ракурса автомобиля на изображении.
11
+ Часть проекта **AutoInspect** (pipeline: *view classification → car parts segmentation → damage segmentation*).
12
+
13
+ ## Task
14
+
15
+ Multi-class классификация ракурса автомобиля (9 классов).
16
+
17
+ ## Labels
18
+
19
+ ```python
20
+ CLASS_NAMES = [
21
+ "back",
22
+ "back-left",
23
+ "back-right",
24
+ "front",
25
+ "front-left",
26
+ "front-right",
27
+ "left",
28
+ "other",
29
+ "right",
30
+ ]
31
+ ```
32
+
33
+ ## How to use
34
+
35
+ ### Load model
36
+
37
+ ```python
38
+ import torch
39
+ import torch.nn as nn
40
+ from torchvision import models, transforms
41
+ from PIL import Image
42
+
43
+ CLASS_NAMES = [
44
+ "back", "back-left", "back-right",
45
+ "front", "front-left", "front-right",
46
+ "left", "other", "right"
47
+ ]
48
+
49
+ device = "cuda" if torch.cuda.is_available() else "cpu"
50
+
51
+ model = models.resnet18(weights=None)
52
+ model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
53
+
54
+ state_dict = torch.load("best_car_view_model.pth", map_location=device)
55
+ model.load_state_dict(state_dict)
56
+
57
+ model.eval()
58
+ model.to(device)
59
+ ```
60
+
61
+ ### Predict
62
+
63
+ ```python
64
+ preprocess = transforms.Compose([
65
+ transforms.Resize((224, 224)),
66
+ transforms.ToTensor(),
67
+ transforms.Normalize([0.485, 0.456, 0.406],
68
+ [0.229, 0.224, 0.225])
69
+ ])
70
+
71
+ img = Image.open("car.jpg").convert("RGB")
72
+ x = preprocess(img).unsqueeze(0).to(device)
73
+
74
+ with torch.no_grad():
75
+ logits = model(x)
76
+ pred = torch.argmax(logits, dim=1).item()
77
+
78
+ print("Prediction:", CLASS_NAMES[pred])
79
+ ```