ma4389 commited on
Commit
f726feb
·
verified ·
1 Parent(s): 11cfb1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -32
app.py CHANGED
@@ -1,64 +1,89 @@
1
  import torch
2
  import torch.nn as nn
 
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # -----------------------------
8
- # 1. Load your model
9
- # -----------------------------
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- model = torch.load("best_model.pth", map_location=device) # update filename if needed
12
- model.eval()
13
-
14
- # -----------------------------
15
- # 2. Class mapping (from training)
16
- # -----------------------------
17
  class_to_idx = {
18
- 'Acura': 0, 'Alfa Romeo': 1, 'Aston Martin': 2, 'Audi': 3, 'BMW': 4, 'Bentley': 5, 'Bugatti': 6,
19
- 'Buick': 7, 'Cadillac': 8, 'Chevrolet': 9, 'Chrysler': 10, 'Citroen': 11, 'Daewoo': 12,
20
- 'Dodge': 13, 'Ferrari': 14, 'Fiat': 15, 'Ford': 16, 'GMC': 17, 'Genesis': 18, 'Honda': 19,
21
- 'Hudson': 20, 'Hyundai': 21, 'Infiniti': 22, 'Jaguar': 23, 'Jeep': 24, 'Kia': 25, 'Land Rover': 26,
22
- 'Lexus': 27, 'Lincoln': 28, 'MG': 29, 'Maserati': 30, 'Mazda': 31, 'Mercedes-Benz': 32, 'Mini': 33,
23
- 'Mitsubishi': 34, 'Nissan': 35, 'Oldsmobile': 36, 'Peugeot': 37, 'Pontiac': 38, 'Porsche': 39,
24
- 'Ram Trucks': 40, 'Renault': 41, 'Saab': 42, 'Studebaker': 43, 'Subaru': 44, 'Suzuki': 45,
25
- 'Tesla': 46, 'Toyota': 47, 'Volkswagen': 48, 'Volvo': 49
 
 
 
26
  }
27
  idx_to_class = {v: k for k, v in class_to_idx.items()}
28
 
29
- # -----------------------------
30
- # 3. Transform (inference version: no randomness)
31
- # -----------------------------
32
  transform = transforms.Compose([
33
  transforms.Lambda(lambda x: x.convert("RGB")),
34
  transforms.Resize((224,224)),
 
 
 
 
 
35
  transforms.ToTensor(),
36
  transforms.Normalize([0.5]*3, [0.5]*3)
37
  ])
38
 
39
- # -----------------------------
40
- # 4. Prediction function
41
- # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def predict(img):
43
  img = transform(img).unsqueeze(0).to(device)
44
 
45
  with torch.no_grad():
46
- outputs = model(img)
47
  probs = torch.softmax(outputs, dim=1)[0]
48
 
49
- top5_probs, top5_idx = torch.topk(probs, 5)
50
- results = {idx_to_class[idx.item()]: float(top5_probs[i]) for i, idx in enumerate(top5_idx)}
51
- return results
52
 
53
- # -----------------------------
54
- # 5. Gradio UI
55
- # -----------------------------
56
  demo = gr.Interface(
57
  fn=predict,
58
  inputs=gr.Image(type="pil"),
59
  outputs=gr.Label(num_top_classes=5),
60
  title="Car Brand Classifier",
61
- description="Upload a car image to classify its brand"
62
  )
63
 
64
  if __name__ == "__main__":
 
1
  import torch
2
  import torch.nn as nn
3
+ import torchvision.models as models
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
  import gradio as gr
7
 
8
+ # --------------------
9
+ # Class Mapping
10
+ # --------------------
 
 
 
 
 
 
 
11
  class_to_idx = {
12
+ 'Acura': 0, 'Alfa Romeo': 1, 'Aston Martin': 2, 'Audi': 3, 'BMW': 4,
13
+ 'Bentley': 5, 'Bugatti': 6, 'Buick': 7, 'Cadillac': 8, 'Chevrolet': 9,
14
+ 'Chrysler': 10, 'Citroen': 11, 'Daewoo': 12, 'Dodge': 13, 'Ferrari': 14,
15
+ 'Fiat': 15, 'Ford': 16, 'GMC': 17, 'Genesis': 18, 'Honda': 19,
16
+ 'Hudson': 20, 'Hyundai': 21, 'Infiniti': 22, 'Jaguar': 23, 'Jeep': 24,
17
+ 'Kia': 25, 'Land Rover': 26, 'Lexus': 27, 'Lincoln': 28, 'MG': 29,
18
+ 'Maserati': 30, 'Mazda': 31, 'Mercedes-Benz': 32, 'Mini': 33,
19
+ 'Mitsubishi': 34, 'Nissan': 35, 'Oldsmobile': 36, 'Peugeot': 37,
20
+ 'Pontiac': 38, 'Porsche': 39, 'Ram Trucks': 40, 'Renault': 41,
21
+ 'Saab': 42, 'Studebaker': 43, 'Subaru': 44, 'Suzuki': 45, 'Tesla': 46,
22
+ 'Toyota': 47, 'Volkswagen': 48, 'Volvo': 49
23
  }
24
  idx_to_class = {v: k for k, v in class_to_idx.items()}
25
 
26
+ # --------------------
27
+ # Image Transform
28
+ # --------------------
29
  transform = transforms.Compose([
30
  transforms.Lambda(lambda x: x.convert("RGB")),
31
  transforms.Resize((224,224)),
32
+ transforms.RandomHorizontalFlip(p=0.5),
33
+ transforms.RandomVerticalFlip(p=0.2),
34
+ transforms.RandomRotation(20),
35
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
36
+ transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
37
  transforms.ToTensor(),
38
  transforms.Normalize([0.5]*3, [0.5]*3)
39
  ])
40
 
41
+ # --------------------
42
+ # Load Model
43
+ # --------------------
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+
46
+ # Load pretrained ResNet50 correctly
47
+ base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
48
+
49
+ # Replace final fully connected layer (your head)
50
+ in_features = base_model.fc.in_features # 2048 for ResNet50
51
+ base_model.fc = nn.Sequential(
52
+ nn.Linear(in_features, 512),
53
+ nn.ReLU(),
54
+ nn.Dropout(0.5),
55
+ nn.Linear(512, 50) # 50 classes
56
+ )
57
+
58
+ # Load state dict
59
+ state_dict = torch.load("best_model.pth", map_location=device)
60
+ base_model.load_state_dict(state_dict)
61
+ base_model = base_model.to(device)
62
+ base_model.eval()
63
+
64
+ # --------------------
65
+ # Prediction Function
66
+ # --------------------
67
  def predict(img):
68
  img = transform(img).unsqueeze(0).to(device)
69
 
70
  with torch.no_grad():
71
+ outputs = base_model(img)
72
  probs = torch.softmax(outputs, dim=1)[0]
73
 
74
+ top5_prob, top5_idx = torch.topk(probs, 5)
75
+ result = {idx_to_class[idx.item()]: float(top5_prob[i]) for i, idx in enumerate(top5_idx)}
76
+ return result
77
 
78
+ # --------------------
79
+ # Gradio UI
80
+ # --------------------
81
  demo = gr.Interface(
82
  fn=predict,
83
  inputs=gr.Image(type="pil"),
84
  outputs=gr.Label(num_top_classes=5),
85
  title="Car Brand Classifier",
86
+ description="Upload a car image to predict its brand."
87
  )
88
 
89
  if __name__ == "__main__":