godsofheaven commited on
Commit
793b943
·
verified ·
1 Parent(s): 05794ce

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Any, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torchvision.models as tvm
7
+ from torchvision.transforms import functional as F
8
+ from torchvision import transforms as T
9
+ from PIL import Image
10
+ import gradio as gr
11
+
12
+
13
+ CHECKPOINT_PATH = os.environ.get("CKPT_PATH", "runs/best.pth")
14
+
15
+
16
+ def get_device() -> torch.device:
17
+ if torch.cuda.is_available():
18
+ return torch.device("cuda")
19
+ return torch.device("cpu")
20
+
21
+
22
+ def build_model(num_classes: int = 1000) -> nn.Module:
23
+ model = tvm.resnet50(weights=None)
24
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
25
+ return model
26
+
27
+
28
+ def get_preprocess_and_labels():
29
+ # Use torchvision's ImageNet-1k metadata for categories and canonical transforms
30
+ try:
31
+ weights = tvm.ResNet50_Weights.IMAGENET1K_V2
32
+ except Exception:
33
+ # Fallback if weights enum not available
34
+ weights = None
35
+ if weights is not None:
36
+ preprocess = weights.transforms()
37
+ labels = weights.meta.get("categories", [str(i) for i in range(1000)])
38
+ else:
39
+ preprocess = T.Compose(
40
+ [
41
+ T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
42
+ T.CenterCrop(224),
43
+ T.ToTensor(),
44
+ T.Normalize(
45
+ mean=[0.485, 0.456, 0.406],
46
+ std=[0.229, 0.224, 0.225],
47
+ ),
48
+ ]
49
+ )
50
+ labels = [str(i) for i in range(1000)]
51
+ return preprocess, labels
52
+
53
+
54
+ def load_checkpoint_into_model(model: nn.Module, checkpoint_path: str) -> None:
55
+ if not os.path.exists(checkpoint_path):
56
+ raise FileNotFoundError(
57
+ f"Checkpoint not found at '{checkpoint_path}'. "
58
+ f"Place your file at runs/exp1/best.pth or set CKPT_PATH env var."
59
+ )
60
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
61
+ # Support either a full training checkpoint dict or a raw state_dict
62
+ state_dict = checkpoint.get("model", checkpoint)
63
+ model.load_state_dict(state_dict, strict=False)
64
+ model.eval()
65
+
66
+
67
+ device = get_device()
68
+ model = build_model(num_classes=1000).to(device)
69
+ preprocess, imagenet_labels = get_preprocess_and_labels()
70
+ load_checkpoint_into_model(model, CHECKPOINT_PATH)
71
+
72
+
73
+ def predict_images(
74
+ images: Union[Image.Image, List[Image.Image]],
75
+ top_k: int = 5,
76
+ ) -> List[List[Dict[str, Any]]]:
77
+ if images is None:
78
+ return []
79
+ if not isinstance(images, list):
80
+ images = [images]
81
+
82
+ results: List[List[Dict[str, Any]]] = []
83
+ with torch.no_grad():
84
+ for image in images:
85
+ if not isinstance(image, Image.Image):
86
+ # Some gradio versions may return dicts; handle defensively
87
+ image = Image.fromarray(image)
88
+ tensor = preprocess(image).unsqueeze(0).to(device)
89
+ logits = model(tensor)
90
+ probs = torch.softmax(logits, dim=1)[0]
91
+ topk = torch.topk(probs, k=top_k)
92
+ sample_result: List[Dict[str, Any]] = []
93
+ for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
94
+ label = imagenet_labels[idx] if 0 <= idx < len(imagenet_labels) else str(idx)
95
+ sample_result.append({"label": label, "probability": float(score)})
96
+ results.append(sample_result)
97
+ return results
98
+
99
+
100
+ with gr.Blocks(title="ResNet-50 ImageNet-1k Classifier") as demo:
101
+ gr.Markdown(
102
+ """
103
+ **ResNet-50 ImageNet-1k Classifier**
104
+
105
+ - Upload one or more images and get top-5 predictions.
106
+ - Model weights loaded from `runs/exp1/best.pth`.
107
+ """
108
+ )
109
+ with gr.Row():
110
+ with gr.Column():
111
+ input_images = gr.Image(
112
+ label="Upload images",
113
+ type="pil",
114
+ sources=["upload", "clipboard"],
115
+ )
116
+ gr.Examples(
117
+ examples=[
118
+ "input-examples/goldfish.png",
119
+ "input-examples/tiger-shark.png",
120
+ "input-examples/toilet-tissue.png",
121
+ ],
122
+ inputs=input_images,
123
+ label="Example images",
124
+ )
125
+ topk = gr.Slider(1, 10, value=5, step=1, label="Top-K")
126
+ run_btn = gr.Button("Predict")
127
+ with gr.Column():
128
+ output = gr.JSON(label="Predictions (per-image top-K)")
129
+
130
+ run_btn.click(fn=predict_images, inputs=[input_images, topk], outputs=output)
131
+
132
+
133
+ if __name__ == "__main__":
134
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
135
+
136
+