Commit
·
63a23a5
1
Parent(s):
7868581
Subindo arquivos7
Browse files- app.py +22 -22
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -12,9 +12,9 @@ import torch.optim as optim
|
|
| 12 |
from torchvision import datasets, transforms, models
|
| 13 |
from torch.utils.data import DataLoader, random_split
|
| 14 |
from PIL import Image
|
| 15 |
-
import joblib #
|
| 16 |
|
| 17 |
-
#
|
| 18 |
model_dict = {
|
| 19 |
'AlexNet': models.alexnet,
|
| 20 |
'ResNet18': models.resnet18,
|
|
@@ -23,7 +23,7 @@ model_dict = {
|
|
| 23 |
'MobileNetV2': models.mobilenet_v2
|
| 24 |
}
|
| 25 |
|
| 26 |
-
#
|
| 27 |
model = None
|
| 28 |
train_loader = None
|
| 29 |
val_loader = None
|
|
@@ -32,38 +32,38 @@ dataset_path = 'dataset'
|
|
| 32 |
class_dirs = []
|
| 33 |
test_dataset_path = 'test_dataset'
|
| 34 |
test_class_dirs = []
|
| 35 |
-
num_classes = 2 #
|
| 36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
|
| 38 |
-
#
|
| 39 |
def setup_classes(num_classes_value):
|
| 40 |
global class_dirs, dataset_path, num_classes
|
| 41 |
|
| 42 |
-
num_classes = int(num_classes_value) #
|
| 43 |
-
#
|
| 44 |
if os.path.exists(dataset_path):
|
| 45 |
shutil.rmtree(dataset_path)
|
| 46 |
os.makedirs(dataset_path)
|
| 47 |
|
| 48 |
-
#
|
| 49 |
class_dirs = [os.path.join(dataset_path, f'class_{i}') for i in range(num_classes)]
|
| 50 |
for class_dir in class_dirs:
|
| 51 |
os.makedirs(class_dir)
|
| 52 |
|
| 53 |
return f"Criados {num_classes} diretórios para classes."
|
| 54 |
|
| 55 |
-
#
|
| 56 |
def upload_images(class_id, images):
|
| 57 |
class_dir = class_dirs[int(class_id)]
|
| 58 |
for image in images:
|
| 59 |
shutil.copy(image, class_dir)
|
| 60 |
return f"Imagens salvas na classe {class_id}."
|
| 61 |
|
| 62 |
-
#
|
| 63 |
def prepare_data(batch_size=32, resize=(224, 224)):
|
| 64 |
global train_loader, val_loader, test_loader, num_classes
|
| 65 |
|
| 66 |
-
#
|
| 67 |
transform = transforms.Compose([
|
| 68 |
transforms.Resize(resize),
|
| 69 |
transforms.ToTensor(),
|
|
@@ -74,7 +74,7 @@ def prepare_data(batch_size=32, resize=(224, 224)):
|
|
| 74 |
if len(dataset.classes) != num_classes:
|
| 75 |
return f"Erro: Número de classes detectadas ({len(dataset.classes)}) não corresponde ao número esperado ({num_classes}). Verifique suas imagens."
|
| 76 |
|
| 77 |
-
#
|
| 78 |
train_size = int(0.7 * len(dataset))
|
| 79 |
val_size = int(0.2 * len(dataset))
|
| 80 |
test_size = len(dataset) - train_size - val_size
|
|
@@ -86,7 +86,7 @@ def prepare_data(batch_size=32, resize=(224, 224)):
|
|
| 86 |
|
| 87 |
return "Preparação dos dados concluída com sucesso."
|
| 88 |
|
| 89 |
-
#
|
| 90 |
def start_training(model_name, epochs, lr):
|
| 91 |
global model, train_loader, val_loader, device
|
| 92 |
|
|
@@ -117,7 +117,7 @@ def start_training(model_name, epochs, lr):
|
|
| 117 |
torch.save(model.state_dict(), 'modelo.pth')
|
| 118 |
return f"Treinamento concluído com sucesso. Modelo salvo."
|
| 119 |
|
| 120 |
-
#
|
| 121 |
def evaluate_model(loader):
|
| 122 |
global model, device, num_classes
|
| 123 |
|
|
@@ -144,7 +144,7 @@ def evaluate_model(loader):
|
|
| 144 |
except Exception as e:
|
| 145 |
return f"Erro durante a avaliação: {str(e)}"
|
| 146 |
|
| 147 |
-
#
|
| 148 |
def show_confusion_matrix(loader):
|
| 149 |
global model, device, num_classes
|
| 150 |
|
|
@@ -174,7 +174,7 @@ def show_confusion_matrix(loader):
|
|
| 174 |
buf.seek(0)
|
| 175 |
return Image.open(buf)
|
| 176 |
|
| 177 |
-
#
|
| 178 |
def predict_images(images):
|
| 179 |
global model, device, num_classes
|
| 180 |
|
|
@@ -202,7 +202,7 @@ def predict_images(images):
|
|
| 202 |
|
| 203 |
return results
|
| 204 |
|
| 205 |
-
#
|
| 206 |
def export_model(format):
|
| 207 |
global model
|
| 208 |
|
|
@@ -225,7 +225,7 @@ def export_model(format):
|
|
| 225 |
|
| 226 |
return f"Modelo exportado com sucesso para {file_path}"
|
| 227 |
|
| 228 |
-
#
|
| 229 |
def setup_test_classes():
|
| 230 |
global test_class_dirs, test_dataset_path
|
| 231 |
|
|
@@ -233,21 +233,21 @@ def setup_test_classes():
|
|
| 233 |
shutil.rmtree(test_dataset_path)
|
| 234 |
os.makedirs(test_dataset_path)
|
| 235 |
|
| 236 |
-
#
|
| 237 |
test_class_dirs = [os.path.join(test_dataset_path, f'class_{i}') for i in range(num_classes)]
|
| 238 |
for class_dir in test_class_dirs:
|
| 239 |
os.makedirs(class_dir)
|
| 240 |
|
| 241 |
return f"Criados {num_classes} diretórios para classes de teste."
|
| 242 |
|
| 243 |
-
#
|
| 244 |
def upload_test_images(class_id, images):
|
| 245 |
class_dir = test_class_dirs[int(class_id)]
|
| 246 |
for image in images:
|
| 247 |
shutil.copy(image, class_dir)
|
| 248 |
return f"Imagens de teste salvas na classe {class_id}."
|
| 249 |
|
| 250 |
-
#
|
| 251 |
def prepare_test_data(batch_size=32, resize=(224, 224)):
|
| 252 |
global test_loader, num_classes
|
| 253 |
|
|
@@ -265,7 +265,7 @@ def prepare_test_data(batch_size=32, resize=(224, 224)):
|
|
| 265 |
|
| 266 |
return "Preparação dos dados de teste concluída com sucesso."
|
| 267 |
|
| 268 |
-
#
|
| 269 |
def main():
|
| 270 |
with gr.Blocks() as demo:
|
| 271 |
gr.Markdown("# Image Classification Training")
|
|
|
|
| 12 |
from torchvision import datasets, transforms, models
|
| 13 |
from torch.utils.data import DataLoader, random_split
|
| 14 |
from PIL import Image
|
| 15 |
+
import joblib # Para salvar como .pkl
|
| 16 |
|
| 17 |
+
# Modelos para seleção
|
| 18 |
model_dict = {
|
| 19 |
'AlexNet': models.alexnet,
|
| 20 |
'ResNet18': models.resnet18,
|
|
|
|
| 23 |
'MobileNetV2': models.mobilenet_v2
|
| 24 |
}
|
| 25 |
|
| 26 |
+
# Variáveis globais
|
| 27 |
model = None
|
| 28 |
train_loader = None
|
| 29 |
val_loader = None
|
|
|
|
| 32 |
class_dirs = []
|
| 33 |
test_dataset_path = 'test_dataset'
|
| 34 |
test_class_dirs = []
|
| 35 |
+
num_classes = 2 # Valor padrão para o número de classes
|
| 36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
|
| 38 |
+
# Função para configurar as classes
|
| 39 |
def setup_classes(num_classes_value):
|
| 40 |
global class_dirs, dataset_path, num_classes
|
| 41 |
|
| 42 |
+
num_classes = int(num_classes_value) # Atualizar a variável global num_classes
|
| 43 |
+
# Limpar diretório antigo se existir
|
| 44 |
if os.path.exists(dataset_path):
|
| 45 |
shutil.rmtree(dataset_path)
|
| 46 |
os.makedirs(dataset_path)
|
| 47 |
|
| 48 |
+
# Criar diretórios para cada classe
|
| 49 |
class_dirs = [os.path.join(dataset_path, f'class_{i}') for i in range(num_classes)]
|
| 50 |
for class_dir in class_dirs:
|
| 51 |
os.makedirs(class_dir)
|
| 52 |
|
| 53 |
return f"Criados {num_classes} diretórios para classes."
|
| 54 |
|
| 55 |
+
# Função para upload de imagens
|
| 56 |
def upload_images(class_id, images):
|
| 57 |
class_dir = class_dirs[int(class_id)]
|
| 58 |
for image in images:
|
| 59 |
shutil.copy(image, class_dir)
|
| 60 |
return f"Imagens salvas na classe {class_id}."
|
| 61 |
|
| 62 |
+
# Função para preparação dos dados
|
| 63 |
def prepare_data(batch_size=32, resize=(224, 224)):
|
| 64 |
global train_loader, val_loader, test_loader, num_classes
|
| 65 |
|
| 66 |
+
# Transformações para os dados de treinamento e validação
|
| 67 |
transform = transforms.Compose([
|
| 68 |
transforms.Resize(resize),
|
| 69 |
transforms.ToTensor(),
|
|
|
|
| 74 |
if len(dataset.classes) != num_classes:
|
| 75 |
return f"Erro: Número de classes detectadas ({len(dataset.classes)}) não corresponde ao número esperado ({num_classes}). Verifique suas imagens."
|
| 76 |
|
| 77 |
+
# Divisão do conjunto de dados em treinamento, validação e teste
|
| 78 |
train_size = int(0.7 * len(dataset))
|
| 79 |
val_size = int(0.2 * len(dataset))
|
| 80 |
test_size = len(dataset) - train_size - val_size
|
|
|
|
| 86 |
|
| 87 |
return "Preparação dos dados concluída com sucesso."
|
| 88 |
|
| 89 |
+
# Função para iniciar o treinamento
|
| 90 |
def start_training(model_name, epochs, lr):
|
| 91 |
global model, train_loader, val_loader, device
|
| 92 |
|
|
|
|
| 117 |
torch.save(model.state_dict(), 'modelo.pth')
|
| 118 |
return f"Treinamento concluído com sucesso. Modelo salvo."
|
| 119 |
|
| 120 |
+
# Função para avaliação do modelo com conjunto de teste
|
| 121 |
def evaluate_model(loader):
|
| 122 |
global model, device, num_classes
|
| 123 |
|
|
|
|
| 144 |
except Exception as e:
|
| 145 |
return f"Erro durante a avaliação: {str(e)}"
|
| 146 |
|
| 147 |
+
# Função para mostrar a matriz de confusão
|
| 148 |
def show_confusion_matrix(loader):
|
| 149 |
global model, device, num_classes
|
| 150 |
|
|
|
|
| 174 |
buf.seek(0)
|
| 175 |
return Image.open(buf)
|
| 176 |
|
| 177 |
+
# Função para predição de imagens desconhecidas
|
| 178 |
def predict_images(images):
|
| 179 |
global model, device, num_classes
|
| 180 |
|
|
|
|
| 202 |
|
| 203 |
return results
|
| 204 |
|
| 205 |
+
# Função para exportar o modelo
|
| 206 |
def export_model(format):
|
| 207 |
global model
|
| 208 |
|
|
|
|
| 225 |
|
| 226 |
return f"Modelo exportado com sucesso para {file_path}"
|
| 227 |
|
| 228 |
+
# Função para configurar diretórios de teste
|
| 229 |
def setup_test_classes():
|
| 230 |
global test_class_dirs, test_dataset_path
|
| 231 |
|
|
|
|
| 233 |
shutil.rmtree(test_dataset_path)
|
| 234 |
os.makedirs(test_dataset_path)
|
| 235 |
|
| 236 |
+
# Criar diretórios para cada classe
|
| 237 |
test_class_dirs = [os.path.join(test_dataset_path, f'class_{i}') for i in range(num_classes)]
|
| 238 |
for class_dir in test_class_dirs:
|
| 239 |
os.makedirs(class_dir)
|
| 240 |
|
| 241 |
return f"Criados {num_classes} diretórios para classes de teste."
|
| 242 |
|
| 243 |
+
# Função para upload de imagens de teste
|
| 244 |
def upload_test_images(class_id, images):
|
| 245 |
class_dir = test_class_dirs[int(class_id)]
|
| 246 |
for image in images:
|
| 247 |
shutil.copy(image, class_dir)
|
| 248 |
return f"Imagens de teste salvas na classe {class_id}."
|
| 249 |
|
| 250 |
+
# Função para preparar dados de teste
|
| 251 |
def prepare_test_data(batch_size=32, resize=(224, 224)):
|
| 252 |
global test_loader, num_classes
|
| 253 |
|
|
|
|
| 265 |
|
| 266 |
return "Preparação dos dados de teste concluída com sucesso."
|
| 267 |
|
| 268 |
+
# Interface Gradio
|
| 269 |
def main():
|
| 270 |
with gr.Blocks() as demo:
|
| 271 |
gr.Markdown("# Image Classification Training")
|
requirements.txt
CHANGED
|
@@ -9,8 +9,10 @@ joblib==1.1.0
|
|
| 9 |
onnx==1.10.2
|
| 10 |
onnx-tf==1.9.0
|
| 11 |
tensorflow==2.9.1
|
|
|
|
| 12 |
tensorflowjs==3.20.0
|
| 13 |
numpy==1.21.6
|
| 14 |
Cython==0.29.24
|
| 15 |
|
| 16 |
|
|
|
|
|
|
| 9 |
onnx==1.10.2
|
| 10 |
onnx-tf==1.9.0
|
| 11 |
tensorflow==2.9.1
|
| 12 |
+
tensorflow-addons
|
| 13 |
tensorflowjs==3.20.0
|
| 14 |
numpy==1.21.6
|
| 15 |
Cython==0.29.24
|
| 16 |
|
| 17 |
|
| 18 |
+
|