arturevs commited on
Commit
e0f5958
·
1 Parent(s): 12484e8
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import streamlit as st
2
  import joblib
 
 
 
3
  from PIL import Image
4
  import torchvision.transforms as transforms
5
  from huggingface_hub import hf_hub_download
@@ -16,22 +19,34 @@ def load_model():
16
  label_encoder = joblib.load(label_encoder_path)
17
  return model, label_encoder
18
 
19
- # Função para pré-processar a imagem
20
- def preprocess_image(image):
 
 
 
 
 
 
 
 
 
21
  transform = transforms.Compose([
22
  transforms.Resize((256, 256)),
23
  transforms.ToTensor(),
24
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
25
  ])
26
- image = transform(image).flatten().numpy().reshape(1, -1)
27
- return image
 
 
28
 
29
  # Interface da aplicação
30
  st.title("Classificador de Animais")
31
  st.write("Envie uma imagem de um animal (gato, cachorro ou pássaro) para classificação.")
32
 
33
- # Carregar o modelo e o LabelEncoder
34
  model, label_encoder = load_model()
 
35
 
36
  # Upload da imagem
37
  uploaded_file = st.file_uploader("Escolha uma imagem...", type=["jpg", "jpeg", "png"])
@@ -39,11 +54,11 @@ if uploaded_file is not None:
39
  image = Image.open(uploaded_file)
40
  st.image(image, caption="Imagem enviada", use_column_width=True)
41
 
42
- # Pré-processar a imagem
43
- image_processed = preprocess_image(image)
44
 
45
  # Fazer a previsão
46
- prediction = model.predict(image_processed)
47
  predicted_class = label_encoder.inverse_transform(prediction)[0]
48
 
49
  # Exibir o resultado
 
1
  import streamlit as st
2
  import joblib
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision.models import resnet18
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
  from huggingface_hub import hf_hub_download
 
19
  label_encoder = joblib.load(label_encoder_path)
20
  return model, label_encoder
21
 
22
+ # Função para carregar a CNN (ResNet18) para extração de características
23
+ @st.cache_resource
24
+ def load_cnn():
25
+ # Carregar a ResNet18 pré-treinada e remover a última camada
26
+ cnn_model = resnet18(pretrained=True)
27
+ cnn_model = nn.Sequential(*list(cnn_model.children())[:-1]) # Remove a última camada
28
+ cnn_model.eval() # Colocar o modelo em modo de avaliação
29
+ return cnn_model
30
+
31
+ # Função para pré-processar a imagem e extrair características
32
+ def extract_features(image, cnn_model):
33
  transform = transforms.Compose([
34
  transforms.Resize((256, 256)),
35
  transforms.ToTensor(),
36
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalização para a ResNet18
37
  ])
38
+ image = transform(image).unsqueeze(0) # Adicionar dimensão do batch
39
+ with torch.no_grad(): # Desativar cálculo de gradientes
40
+ features = cnn_model(image) # Extrair características
41
+ return features.flatten().numpy().reshape(1, -1) # Achatar e converter para numpy array
42
 
43
  # Interface da aplicação
44
  st.title("Classificador de Animais")
45
  st.write("Envie uma imagem de um animal (gato, cachorro ou pássaro) para classificação.")
46
 
47
+ # Carregar o modelo, o LabelEncoder e a CNN
48
  model, label_encoder = load_model()
49
+ cnn_model = load_cnn()
50
 
51
  # Upload da imagem
52
  uploaded_file = st.file_uploader("Escolha uma imagem...", type=["jpg", "jpeg", "png"])
 
54
  image = Image.open(uploaded_file)
55
  st.image(image, caption="Imagem enviada", use_column_width=True)
56
 
57
+ # Extrair características da imagem usando a CNN
58
+ image_features = extract_features(image, cnn_model)
59
 
60
  # Fazer a previsão
61
+ prediction = model.predict(image_features)
62
  predicted_class = label_encoder.inverse_transform(prediction)[0]
63
 
64
  # Exibir o resultado