arturevs's picture
update
5653677
import streamlit as st
import joblib
import torch
import torch.nn as nn
from torchvision.models import resnet18
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
# Função para carregar o modelo e o LabelEncoder do Hugging Face
@st.cache_resource
def load_model():
# Baixar o modelo e o LabelEncoder do Hugging Face
model_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="mlp_classifier.joblib")
label_encoder_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="label_encoder.joblib")
# Carregar o modelo e o LabelEncoder
model = joblib.load(model_path)
label_encoder = joblib.load(label_encoder_path)
return model, label_encoder
# Função para carregar a CNN (ResNet18) para extração de características
@st.cache_resource
def load_cnn():
# Carregar a ResNet18 pré-treinada e remover a última camada
cnn_model = resnet18(pretrained=True)
cnn_model = nn.Sequential(*list(cnn_model.children())[:-1]) # Remove a última camada
cnn_model.eval() # Colocar o modelo em modo de avaliação
return cnn_model
# Função para pré-processar a imagem e extrair características
def extract_features(image, cnn_model):
# Converter a imagem para RGB (remover canal alfa se existir)
image = image.convert("RGB")
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalização para a ResNet18
])
image = transform(image).unsqueeze(0) # Adicionar dimensão do batch
with torch.no_grad(): # Desativar cálculo de gradientes
features = cnn_model(image) # Extrair características
return features.flatten().numpy().reshape(1, -1) # Achatar e converter para numpy array
# Interface da aplicação
st.title("Classificador de Animais")
st.write("Envie uma imagem de um animal (gato, cachorro ou pássaro) para classificação.")
# Carregar o modelo, o LabelEncoder e a CNN
model, label_encoder = load_model()
cnn_model = load_cnn()
# Upload da imagem
uploaded_file = st.file_uploader("Escolha uma imagem...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
try:
image = Image.open(uploaded_file)
st.image(image, caption="Imagem enviada", use_column_width=True)
# Extrair características da imagem usando a CNN
image_features = extract_features(image, cnn_model)
# Fazer a previsão
prediction = model.predict(image_features)
predicted_class = label_encoder.inverse_transform(prediction)[0]
# Exibir o resultado
st.write(f"Classificação: **{predicted_class}**")
except Exception as e:
st.error(f"Erro ao processar a imagem: {e}")