mgbam commited on
Commit
058024c
·
verified ·
1 Parent(s): 62a234a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from torch.utils.data import DataLoader
5
+ from PIL import Image
6
+ import gradio as gr
7
+
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # Load class names (make sure this file is in the Space)
11
+ with open("cifar10_classes.txt") as f:
12
+ CLASSES = [line.strip() for line in f.readlines()]
13
+
14
+ def build_model(num_classes: int, device: str = "cpu"):
15
+ try:
16
+ weights = models.ResNet18_Weights.DEFAULT
17
+ model = models.resnet18(weights=weights)
18
+ except AttributeError:
19
+ model = models.resnet18(weights="IMAGENET1K_V1")
20
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
21
+ model = model.to(device)
22
+ return model
23
+
24
+ num_classes = len(CLASSES)
25
+ model = build_model(num_classes, device=DEVICE)
26
+
27
+ state_dict = torch.load("ast_cifar10_resnet18.pth", map_location=DEVICE)
28
+ model.load_state_dict(state_dict)
29
+ model.eval()
30
+
31
+ preprocess = transforms.Compose([
32
+ transforms.Resize((224, 224)),
33
+ transforms.ToTensor(),
34
+ ])
35
+
36
+ def predict(image: Image.Image):
37
+ if image is None:
38
+ return {}
39
+ x = preprocess(image).unsqueeze(0).to(DEVICE)
40
+ with torch.no_grad():
41
+ logits = model(x)
42
+ probs = torch.softmax(logits, dim=1)[0]
43
+ return {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
44
+
45
+ demo = gr.Interface(
46
+ fn=predict,
47
+ inputs=gr.Image(type="pil", label="Upload CIFAR-like image"),
48
+ outputs=gr.Label(num_top_classes=3, label="Top-3 Predictions"),
49
+ title="AST CIFAR-10 Classifier",
50
+ description="ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.",
51
+ )
52
+
53
+ if __name__ == "__main__":
54
+ demo.launch()