ericxlima commited on
Commit
ccb94c4
·
1 Parent(s): a012208

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DEVELOP_MODE = False
2
+ USER_MODE = not DEVELOP_MODE
3
+ AZURE_SEARCH_KEY = ""
4
+
5
+ import os
6
+ from pathlib import Path
7
+ import gradio as gr
8
+ from fastai.vision.all import *
9
+
10
+
11
+ if DEVELOP_MODE:
12
+ import fastbook
13
+ from fastbook import *
14
+ from fastai.vision.widgets import *
15
+ from fastai.vision.all import *
16
+ fastbook.setup_book()
17
+
18
+ import uuid
19
+ import requests
20
+ import imghdr
21
+ from PIL import Image
22
+ import numpy as np
23
+
24
+
25
+ attn_slicing_enabled = True
26
+
27
+
28
+ def download_unique_image(url, folder_path):
29
+ try:
30
+ response = requests.get(url, timeout=10)
31
+ content_type = response.headers.get('Content-Type')
32
+ if content_type.startswith('image'):
33
+ image_type = imghdr.what(None, response.content)
34
+ if image_type == 'jpeg':
35
+ extension = 'jpg'
36
+ else:
37
+ extension = image_type
38
+ filename = str(uuid.uuid4()) + '.' + extension
39
+ filepath = os.path.join(folder_path, filename)
40
+ with open(filepath, 'wb') as f:
41
+ f.write(response.content)
42
+ except:
43
+ pass
44
+
45
+
46
+ def remove_corrupted_images(folder_path):
47
+ count = 0
48
+ for file_name in os.listdir(folder_path):
49
+ file_path = os.path.join(folder_path, file_name)
50
+ try:
51
+ with Image.open(file_path) as img:
52
+ pass
53
+ except Exception as err:
54
+ os.remove(file_path)
55
+ count += 1
56
+
57
+
58
+ def normalize_dog_name(dog_name):
59
+ return dog_name.replace(' ', '_').lower()
60
+
61
+
62
+ def download_images_():
63
+ dogs = {
64
+ 'Zwergspitz Dog': [],
65
+ 'Bouledogue Français Dog': [],
66
+ 'Shih Tzu Dog': [],
67
+ 'Rottweiler Dog': [],
68
+ 'Pug Dog': [],
69
+ 'Golden Retriever Dog': [],
70
+ 'Deutscher Schäferhund Dog': [],
71
+ 'Yorkshire Terrier Dog': [],
72
+ 'Border Collie Dog': [],
73
+ 'Dachshund Dog': [],
74
+ 'Poodle Dog': [],
75
+ 'Labrador Retriever Dog': [],
76
+ 'Pinscher Dog': [],
77
+ 'Golden Retriever': [],
78
+ }
79
+ DOGS_NAMES = tuple(dogs.keys())
80
+ if DEVELOP_MODE:
81
+ if not PATH.exists():
82
+ PATH.mkdir()
83
+ for dog_name in DOGS_NAMES:
84
+ urls = search_images_bing(
85
+ AZURE_KEY, dog_name).attrgot('contentUrl')
86
+ dogs[dog_name] = urls
87
+
88
+ dest = os.path.join(PATH, normalize_dog_name(dog_name))
89
+ if not os.path.exists(dest):
90
+ os.mkdir(dest)
91
+ download_images(dest, urls=urls)
92
+ remove_corrupted_images(dest)
93
+ return [dog.replace('Dog', '') for dog in DOGS_NAMES]
94
+
95
+
96
+ def train_model():
97
+ dogs_datablock = DataBlock(
98
+ blocks=(ImageBlock, CategoryBlock),
99
+ get_items=get_image_files,
100
+ splitter=RandomSplitter(valid_pct=0.2, seed=42),
101
+ get_y=parent_label,
102
+ item_tfms=[Resize(128, ResizeMethod.Squish),
103
+ Resize(128, ResizeMethod.Pad, pad_mode='zeros'),
104
+ RandomResizedCrop(128, min_scale=0.3),
105
+ ]
106
+ )
107
+ dogs_dataloaders = dogs_datablock.dataloaders(PATH)
108
+ # dogs_dataloaders = dogs_dataloaders.new(
109
+ # item_tfms=Resize(128, ResizeMethod.Squish))
110
+ learn_ = vision_learner(dogs_dataloaders, resnet18, metrics=error_rate)
111
+ learn_.fine_tune(4)
112
+ learn_.export('dogs.pkl')
113
+ return learn_
114
+
115
+
116
+ def classify_image(image):
117
+ global learing
118
+ pred, pred_idx, probs = learing.predict(image)
119
+ return f"Prediction: {pred.replace('_', '').replace('dog', '').title()};\n Probability: {probs[pred_idx]:.04f}"
120
+
121
+
122
+ def get_model_():
123
+ path = Path()
124
+ model = None
125
+
126
+ if any(file.endswith('.pkl') for file in os.listdir(path)):
127
+ model_ = load_learner('dogs.pkl')
128
+ else:
129
+ model_ = train_model()
130
+ return model_
131
+
132
+
133
+ AZURE_KEY = os.environ.get(
134
+ 'AZURE_SEARCH_KEY',
135
+ AZURE_SEARCH_KEY,
136
+ )
137
+ PATH = Path('dogs')
138
+
139
+ dogs = download_images_()
140
+ learing = get_model_()
141
+
142
+
143
+ # Gradio
144
+ iface = gr.Interface(
145
+ classify_image,
146
+ inputs="image",
147
+ outputs="text",
148
+ title="Classificação de Imagens",
149
+ description="Insira uma imagem para ser classificada"
150
+ )
151
+
152
+
153
+ def set_mem_optimizations(pipe):
154
+ if attn_slicing_enabled:
155
+ pipe.enable_attention_slicing()
156
+ else:
157
+ pipe.disable_attention_slicing()
158
+
159
+
160
+ def list_breeds():
161
+ global dogs
162
+ html = "<div class='row'>"
163
+ html += "<div class='column'>"
164
+ html += "<h2>List of breed dogs trained:</h2>"
165
+ html += "<ol>" + "".join([f"<li>{breed}</li>" for breed in dogs]) + "</ol>"
166
+ html += "</div>"
167
+ html += "<div class='column'>"
168
+ html += "<h2>Author:</h2>"
169
+ html += "<a href='https://github.com/ericxlima'><img src='https://avatars.githubusercontent.com/u/58092119?v=4' alt='profile image' style='width:40%' /></a>"
170
+ html += "<h2><a href='https://github.com/ericxlima'>Eric de Lima</a></h2>"
171
+ html += "</div>"
172
+ html += "</div>"
173
+ return html
174
+
175
+
176
+ image = gr.Image(shape=(224, 224))
177
+ label = gr.Label(num_top_classes=3)
178
+ breeds_list = list_breeds()
179
+
180
+ demo = gr.Interface(
181
+ fn=classify_image,
182
+ inputs=image,
183
+ outputs=label,
184
+ title="🐶 Dog Breed Classifier",
185
+ interpretation="default",
186
+ description="Upload an image of a dog and the model will predict its breed.",
187
+ article=breeds_list,
188
+ css=".row { display: flex; } .column { flex: 50%; }",
189
+ )
190
+
191
+ demo.launch(share=True, debug=True)