Rahul2020 commited on
Commit
eeaf3e5
·
verified ·
1 Parent(s): fc0efe5
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms, models
5
+ from PIL import Image
6
+ import requests
7
+
8
+ # Load ImageNet labels
9
+ LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
10
+ labels = requests.get(LABELS_URL).text.strip().split("\n")
11
+
12
+ # Load model (change this to your model path)
13
+ model = models.resnet50(weights=None)
14
+
15
+ # If using your converted FP16 model:
16
+ # state = torch.load("model_cpu.pt", map_location="cpu")
17
+ # def to_fp32(obj):
18
+ # if isinstance(obj, torch.Tensor) and obj.dtype == torch.float16:
19
+ # return obj.float()
20
+ # if isinstance(obj, dict):
21
+ # return {k: to_fp32(v) for k, v in obj.items()}
22
+ # return obj
23
+ # model.load_state_dict(to_fp32(state))
24
+
25
+ # For demo, using pretrained weights
26
+ model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
27
+ model.eval()
28
+
29
+ # Preprocessing
30
+ transform = transforms.Compose([
31
+ transforms.Resize(256),
32
+ transforms.CenterCrop(224),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
35
+ ])
36
+
37
+ def predict(image):
38
+ if image is None:
39
+ return {}
40
+
41
+ img = transform(image).unsqueeze(0)
42
+
43
+ with torch.no_grad():
44
+ outputs = model(img)
45
+ probs = F.softmax(outputs, dim=1)[0]
46
+
47
+ top5_probs, top5_indices = torch.topk(probs, 5)
48
+
49
+ return {labels[idx]: float(prob) for prob, idx in zip(top5_probs, top5_indices)}
50
+
51
+ # Gradio interface
52
+ demo = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Image(type="pil", label="Upload Image"),
55
+ outputs=gr.Label(num_top_classes=5, label="Predictions"),
56
+ title="🖼️ ImageNet 1K Classifier",
57
+ description="Upload an image to classify it into one of 1000 ImageNet categories.",
58
+ examples=[
59
+ ["https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/1200px-Cat_November_2010-1a.jpg"],
60
+ ],
61
+ theme=gr.themes.Soft(),
62
+ )
63
+
64
+ if __name__ == "__main__":
65
+ demo.launch()