0xtimi commited on
Commit
515bcf8
·
1 Parent(s): 1f9fa31

upgrade gradio

Browse files
Files changed (3) hide show
  1. .gitignore +125 -0
  2. app.py +40 -6
  3. requirements.txt +1 -1
.gitignore ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .gradio/
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ *.manifest
31
+ *.spec
32
+
33
+ # Installer logs
34
+ pip-log.txt
35
+ pip-delete-this-directory.txt
36
+
37
+ # Unit test / coverage reports
38
+ htmlcov/
39
+ .tox/
40
+ .nox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+ .dmypy.json
106
+ dmypy.json
107
+
108
+ # Pyre type checker
109
+ .pyre/
110
+
111
+ # IDE
112
+ .idea/
113
+ .vscode/
114
+ *.swp
115
+ *.swo
116
+ *~
117
+
118
+ # OS
119
+ .DS_Store
120
+ .DS_Store?
121
+ ._*
122
+ .Spotlight-V100
123
+ .Trashes
124
+ ehthumbs.db
125
+ Thumbs.db
app.py CHANGED
@@ -2,6 +2,9 @@ from pathlib import Path
2
  import torch
3
  import gradio as gr
4
  from torch import nn
 
 
 
5
 
6
  LABELS = Path("class_names.txt").read_text().splitlines()
7
 
@@ -20,27 +23,58 @@ model = nn.Sequential(
20
  nn.ReLU(),
21
  nn.Linear(256, len(LABELS)),
22
  )
 
23
  state_dict = torch.load("pytorch_model.bin", map_location="cpu")
24
  model.load_state_dict(state_dict, strict=False)
25
  model.eval()
26
 
27
-
28
  def predict(im):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
 
30
  with torch.no_grad():
31
  out = model(x)
 
32
  probabilities = torch.nn.functional.softmax(out[0], dim=0)
33
  values, indices = torch.topk(probabilities, 5)
 
34
  return {LABELS[i]: v.item() for i, v in zip(indices, values)}
35
 
 
36
  interface = gr.Interface(
37
- predict,
38
- inputs="sketchpad",
39
- outputs="label",
40
- theme="huggingface",
 
 
 
41
  title="Sketch Recognition",
42
  description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
43
  article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
44
- live=True,
45
  )
 
46
  interface.launch(share=True)
 
2
  import torch
3
  import gradio as gr
4
  from torch import nn
5
+ import numpy as np
6
+
7
+ print(gr.__version__)
8
 
9
  LABELS = Path("class_names.txt").read_text().splitlines()
10
 
 
23
  nn.ReLU(),
24
  nn.Linear(256, len(LABELS)),
25
  )
26
+
27
  state_dict = torch.load("pytorch_model.bin", map_location="cpu")
28
  model.load_state_dict(state_dict, strict=False)
29
  model.eval()
30
 
 
31
  def predict(im):
32
+ if im is None:
33
+ return {}
34
+
35
+ # 处理输入图像
36
+ # 如果是字典格式(新版Gradio Sketchpad的输出),提取图像
37
+ if isinstance(im, dict):
38
+ im = im['image'] if 'image' in im else im.get('composite', None)
39
+
40
+ # 转换为numpy数组并确保是灰度图
41
+ if isinstance(im, np.ndarray):
42
+ if len(im.shape) == 3:
43
+ # 如果是RGB图像,转换为灰度图
44
+ im = np.mean(im, axis=2)
45
+ else:
46
+ return {}
47
+
48
+ # 确保图像尺寸正确(28x28)
49
+ if im.shape != (28, 28):
50
+ from PIL import Image
51
+ im_pil = Image.fromarray(im.astype('uint8'))
52
+ im_pil = im_pil.resize((28, 28))
53
+ im = np.array(im_pil)
54
+
55
+ # 转换为tensor并进行预测
56
  x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
57
+
58
  with torch.no_grad():
59
  out = model(x)
60
+
61
  probabilities = torch.nn.functional.softmax(out[0], dim=0)
62
  values, indices = torch.topk(probabilities, 5)
63
+
64
  return {LABELS[i]: v.item() for i, v in zip(indices, values)}
65
 
66
+ # 创建Gradio界面
67
  interface = gr.Interface(
68
+ fn=predict,
69
+ inputs=gr.Sketchpad(
70
+ image_mode="L", # 灰度模式
71
+ canvas_size=(280, 280), # 画布大小
72
+ brush=gr.Brush(default_size=10) # 画笔设置
73
+ ),
74
+ outputs=gr.Label(num_top_classes=5), # 显示前5个预测结果
75
  title="Sketch Recognition",
76
  description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
77
  article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
 
78
  )
79
+
80
  interface.launch(share=True)
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  torch
2
- gradio==3.50.0
 
1
  torch
2
+ gradio