saneshashank commited on
Commit
45b5eaa
·
verified ·
1 Parent(s): 8bda9c1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ import json
6
+ from torchvision import transforms
7
+ from model import resnet50
8
+
9
+ # Load class labels from local file
10
+ with open("imagenet_classes.json", "r") as f:
11
+ class_labels = json.load(f)
12
+
13
+ # Load model
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ model = resnet50(num_classes=1000, drop_path_rate=0.0, use_blurpool=True)
16
+ model.load_state_dict(torch.load("best_resnet50_imagenet_1k.pt", map_location=device))
17
+ model.to(device)
18
+ model.eval()
19
+
20
+ # Image preprocessing
21
+ transform = transforms.Compose([
22
+ transforms.Resize(256),
23
+ transforms.CenterCrop(224),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
+ ])
27
+
28
+ def predict(image):
29
+ # Preprocess
30
+ img_tensor = transform(image).unsqueeze(0).to(device)
31
+
32
+ # Predict
33
+ with torch.no_grad():
34
+ outputs = model(img_tensor)
35
+ probabilities = F.softmax(outputs, dim=1)[0]
36
+
37
+ # Get top 5 predictions
38
+ top5_prob, top5_idx = torch.topk(probabilities, 5)
39
+
40
+ # Format results
41
+ results = {class_labels[idx]: float(prob) for idx, prob in zip(top5_idx, top5_prob)}
42
+ return results
43
+
44
+ # Create Gradio interface
45
+ demo = gr.Interface(
46
+ fn=predict,
47
+ inputs=gr.Image(type="pil"),
48
+ outputs=gr.Label(num_top_classes=5),
49
+ title="ImageNet ResNet50 Classifier (71% Accuracy)",
50
+ description="ResNet50 trained on ImageNet with improved stem, BlurPool, and progressive resizing. Achieved 71% top-1 accuracy under $30 budget.",
51
+ examples=[]
52
+ )
53
+
54
+ if __name__ == "__main__":
55
+ demo.launch()