Aumkeshchy2003 committed on
Commit
1396142
·
verified ·
1 Parent(s): c14f0d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -7,8 +7,7 @@ import gradio as gr
7
  from PIL import Image
8
  import math
9
 
10
- # ------------------------
11
- # Configuration (must match your trained model)
12
  cfg = {
13
  "image_size": 32,
14
  "patch_size": 4,
@@ -41,10 +40,7 @@ classes = [
41
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
- # ------------------------
45
- # Model definition
46
  # ViT model implementation
47
- # --- Conv stem (replace PatchEmbed) ---
48
  class ConvPatchEmbed(nn.Module):
49
  def __init__(self, in_chans=3, embed_dim=384):
50
  super().__init__()
@@ -67,9 +63,9 @@ class ConvPatchEmbed(nn.Module):
67
 
68
  def forward(self, x):
69
  # x: (B, C, H, W)
70
- x = self.conv(x) # (B, E, H/4, W/4) -> H/4=8 for 32x32
71
- x = x.flatten(2) # (B, E, N)
72
- x = x.transpose(1, 2) # (B, N, E)
73
  return x
74
 
75
  class MLP(nn.Module):
@@ -192,13 +188,12 @@ class ViT(nn.Module):
192
  out = self.head(cls)
193
  return out
194
 
195
- # ------------------------
196
  # Load model weights
197
  model = ViT(cfg).to(device)
198
- model.load_state_dict(torch.load("best_vit_cifar100_small.pt", map_location=device))
199
  model.eval()
200
 
201
- # ------------------------
202
  # Image preprocessing
203
  transform = transforms.Compose([
204
  transforms.Resize((32,32)),
@@ -215,7 +210,6 @@ def predict(img: Image.Image):
215
  result = {classes[i]: float(probs[i]) for i in top5.indices}
216
  return result
217
 
218
- # ------------------------
219
  # Gradio interface
220
  iface = gr.Interface(
221
  fn=predict,
 
7
  from PIL import Image
8
  import math
9
 
10
+ # Configuration
 
11
  cfg = {
12
  "image_size": 32,
13
  "patch_size": 4,
 
40
 
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
 
 
 
43
  # ViT model implementation
 
44
  class ConvPatchEmbed(nn.Module):
45
  def __init__(self, in_chans=3, embed_dim=384):
46
  super().__init__()
 
63
 
64
  def forward(self, x):
65
  # x: (B, C, H, W)
66
+ x = self.conv(x)
67
+ x = x.flatten(2)
68
+ x = x.transpose(1, 2)
69
  return x
70
 
71
  class MLP(nn.Module):
 
188
  out = self.head(cls)
189
  return out
190
 
191
+
192
  # Load model weights
193
  model = ViT(cfg).to(device)
194
+ model.load_state_dict(torch.load("best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device))
195
  model.eval()
196
 
 
197
  # Image preprocessing
198
  transform = transforms.Compose([
199
  transforms.Resize((32,32)),
 
210
  result = {classes[i]: float(probs[i]) for i in top5.indices}
211
  return result
212
 
 
213
  # Gradio interface
214
  iface = gr.Interface(
215
  fn=predict,