File size: 2,867 Bytes
526a74f
 
e0f5958
 
 
526a74f
 
 
 
 
 
 
 
12484e8
 
526a74f
 
 
 
 
 
e0f5958
 
 
 
 
 
 
 
 
 
 
7ea11ff
 
 
526a74f
 
 
e0f5958
526a74f
e0f5958
 
 
 
526a74f
 
 
 
 
e0f5958
526a74f
e0f5958
526a74f
 
 
 
7ea11ff
 
 
526a74f
7ea11ff
 
526a74f
7ea11ff
 
 
526a74f
7ea11ff
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st
import joblib
import torch
import torch.nn as nn
from torchvision.models import resnet18
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download

# Função para carregar o modelo e o LabelEncoder do Hugging Face
@st.cache_resource
def load_model():
    # Baixar o modelo e o LabelEncoder do Hugging Face
    model_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="mlp_classifier.joblib")
    label_encoder_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="label_encoder.joblib")
    
    # Carregar o modelo e o LabelEncoder
    model = joblib.load(model_path)
    label_encoder = joblib.load(label_encoder_path)
    return model, label_encoder

# Função para carregar a CNN (ResNet18) para extração de características
@st.cache_resource
def load_cnn():
    # Carregar a ResNet18 pré-treinada e remover a última camada
    cnn_model = resnet18(pretrained=True)
    cnn_model = nn.Sequential(*list(cnn_model.children())[:-1])  # Remove a última camada
    cnn_model.eval()  # Colocar o modelo em modo de avaliação
    return cnn_model

# Função para pré-processar a imagem e extrair características
def extract_features(image, cnn_model):
    # Converter a imagem para RGB (remover canal alfa se existir)
    image = image.convert("RGB")
    
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalização para a ResNet18
    ])
    image = transform(image).unsqueeze(0)  # Adicionar dimensão do batch
    with torch.no_grad():  # Desativar cálculo de gradientes
        features = cnn_model(image)  # Extrair características
    return features.flatten().numpy().reshape(1, -1)  # Achatar e converter para numpy array

# Interface da aplicação
st.title("Classificador de Animais")
st.write("Envie uma imagem de um animal (gato, cachorro ou pássaro) para classificação.")

# Carregar o modelo, o LabelEncoder e a CNN
model, label_encoder = load_model()
cnn_model = load_cnn()

# Upload da imagem
uploaded_file = st.file_uploader("Escolha uma imagem...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        st.image(image, caption="Imagem enviada", use_column_width=True)

        # Extrair características da imagem usando a CNN
        image_features = extract_features(image, cnn_model)

        # Fazer a previsão
        prediction = model.predict(image_features)
        predicted_class = label_encoder.inverse_transform(prediction)[0]

        # Exibir o resultado
        st.write(f"Classificação: **{predicted_class}**")
    except Exception as e:
        st.error(f"Erro ao processar a imagem: {e}")