Jkyutech commited on
Commit
111453b
·
1 Parent(s): 1852565

Add application file

Browse files
Files changed (2) hide show
  1. app.py +22 -19
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,32 +1,35 @@
1
  import requests
2
-
3
  import gradio as gr
4
  import torch
5
  from timm import create_model
6
  from timm.data import resolve_data_config
7
  from timm.data.transforms_factory import create_transform
 
8
 
9
  IMAGENET_1k_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
10
- LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
11
-
12
- model = create_model('resnet50', pretrained=True)
13
 
14
- transform = create_transform(
15
- **resolve_data_config({}, model=model)
16
- )
17
- model.eval()
18
 
19
- def predict_fn(img):
20
- img = img.convert('RGB')
21
- img = transform(img).unsqueeze(0)
22
 
 
 
 
23
  with torch.no_grad():
24
- out = model(img)
25
-
26
- probabilites = torch.nn.functional.softmax(out[0], dim=0)
27
-
28
- values, indices = torch.topk(probabilites, k=5)
29
-
30
- return {LABELS[i]: v.item() for i, v in zip(indices, values)}
 
 
 
 
31
 
32
- gr.Interface(predict_fn, gr.inputs.Image(type='pil'), outputs='label').launch()
 
 
1
  import requests
 
2
  import gradio as gr
3
  import torch
4
  from timm import create_model
5
  from timm.data import resolve_data_config
6
  from timm.data.transforms_factory import create_transform
7
+ from PIL import Image
8
 
9
  IMAGENET_1k_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
10
+ LABELS = requests.get(IMAGENET_1k_URL, timeout=10).text.strip().split("\n")
 
 
11
 
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
13
 
14
+ model = create_model("resnet50", pretrained=True).to(device).eval()
15
+ config = resolve_data_config({}, model=model)
16
+ transform = create_transform(**config)
17
 
18
+ def predict_fn(img: Image.Image):
19
+ img = img.convert("RGB")
20
+ x = transform(img).unsqueeze(0).to(device)
21
  with torch.no_grad():
22
+ out = model(x)[0]
23
+ probs = torch.nn.functional.softmax(out, dim=0)
24
+ values, indices = torch.topk(probs, k=5)
25
+ return {LABELS[i.item()]: v.item() for i, v in zip(indices, values)}
26
+
27
+ demo = gr.Interface(
28
+ fn=predict_fn,
29
+ inputs=gr.Image(type="pil"),
30
+ outputs=gr.Label(num_top_classes=5),
31
+ title="ResNet-50 (timm) · ImageNet Top-5"
32
+ )
33
 
34
+ if __name__ == "__main__":
35
+ demo.launch()
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  timm
2
- gradio==3.50.2
 
1
  timm
2
+ gradio