ma4389 commited on
Commit
cc65e2c
Β·
verified Β·
1 Parent(s): 2c95f0d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +76 -0
  2. best_model.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms, models
5
+ from PIL import Image
6
+
7
+ # ===============================
8
+ # πŸ“Œ Classes (same order as ImageFolder during training)
9
+ # Replace with trainers.classes if available
10
+ # ===============================
11
+ class_names = [
12
+ "Acura", "Alfa Romeo", "Aston Martin", "Audi", "Bentley",
13
+ "BMW", "Bugatti", "Buick", "Cadillac", "Chevrolet",
14
+ "Chrysler", "Citroen", "Dodge", "Ferrari", "Fiat",
15
+ "Ford", "Genesis", "GMC", "Honda", "Hyundai",
16
+ "Infiniti", "Jaguar", "Jeep", "Kia", "Lamborghini",
17
+ "Land Rover", "Lexus", "Lincoln", "Maserati", "Mazda",
18
+ "McLaren", "Mercedes", "Mini", "Mitsubishi", "Nissan",
19
+ "Pagani", "Peugeot", "Porsche", "Ram", "Renault",
20
+ "Rolls Royce", "Saab", "Subaru", "Suzuki", "Tesla",
21
+ "Toyota", "Volkswagen", "Volvo", "Others1", "Others2"
22
+ ]
23
+
24
+ # ===============================
25
+ # πŸ“Œ Inference Transform (NO augmentation)
26
+ # ===============================
27
+ transform = transforms.Compose([
28
+ transforms.Lambda(lambda x: x.convert("RGB")),
29
+ transforms.Resize((224,224)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize([0.5]*3, [0.5]*3)
32
+ ])
33
+
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+ # ===============================
37
+ # πŸ“Œ Load Model
38
+ # ===============================
39
+ def load_model(weights_path="best_model.pth"):
40
+ base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
41
+ in_features = base_model.fc.in_features
42
+ base_model.fc = nn.Sequential(
43
+ nn.Linear(in_features, 512),
44
+ nn.ReLU(),
45
+ nn.Dropout(0.5),
46
+ nn.Linear(512, len(class_names))
47
+ )
48
+ base_model.load_state_dict(torch.load(weights_path, map_location=device))
49
+ return base_model.to(device).eval()
50
+
51
+ model = load_model()
52
+
53
+ # ===============================
54
+ # πŸ“Œ Prediction Function
55
+ # ===============================
56
+ def predict(image):
57
+ img = transform(image).unsqueeze(0).to(device)
58
+ with torch.no_grad():
59
+ outputs = model(img)
60
+ probs = torch.softmax(outputs, dim=1)[0]
61
+ top5_probs, top5_idx = torch.topk(probs, 5)
62
+ return {class_names[idx]: float(prob) for idx, prob in zip(top5_idx, top5_probs)}
63
+
64
+ # ===============================
65
+ # πŸ“Œ Gradio Interface
66
+ # ===============================
67
+ demo = gr.Interface(
68
+ fn=predict,
69
+ inputs=gr.Image(type="pil"),
70
+ outputs=gr.Label(num_top_classes=5),
71
+ title="πŸš— Car Brand Classifier",
72
+ description="Upload a car image and the model predicts the brand (Top-5)."
73
+ )
74
+
75
+ if __name__ == "__main__":
76
+ demo.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c15273cfeec798e0051940a97c1c18cf51a14a995e8aa155a9d98eeade1cc043
3
+ size 98650771
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow