import streamlit as st import joblib import torch import torch.nn as nn from torchvision.models import resnet18 from PIL import Image import torchvision.transforms as transforms from huggingface_hub import hf_hub_download # Função para carregar o modelo e o LabelEncoder do Hugging Face @st.cache_resource def load_model(): # Baixar o modelo e o LabelEncoder do Hugging Face model_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="mlp_classifier.joblib") label_encoder_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="label_encoder.joblib") # Carregar o modelo e o LabelEncoder model = joblib.load(model_path) label_encoder = joblib.load(label_encoder_path) return model, label_encoder # Função para carregar a CNN (ResNet18) para extração de características @st.cache_resource def load_cnn(): # Carregar a ResNet18 pré-treinada e remover a última camada cnn_model = resnet18(pretrained=True) cnn_model = nn.Sequential(*list(cnn_model.children())[:-1]) # Remove a última camada cnn_model.eval() # Colocar o modelo em modo de avaliação return cnn_model # Função para pré-processar a imagem e extrair características def extract_features(image, cnn_model): # Converter a imagem para RGB (remover canal alfa se existir) image = image.convert("RGB") transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalização para a ResNet18 ]) image = transform(image).unsqueeze(0) # Adicionar dimensão do batch with torch.no_grad(): # Desativar cálculo de gradientes features = cnn_model(image) # Extrair características return features.flatten().numpy().reshape(1, -1) # Achatar e converter para numpy array # Interface da aplicação st.title("Classificador de Animais") st.write("Envie uma imagem de um animal (gato, cachorro ou pássaro) para classificação.") # Carregar o modelo, o LabelEncoder e a CNN model, label_encoder = load_model() cnn_model = load_cnn() # Upload da imagem uploaded_file = st.file_uploader("Escolha uma imagem...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: try: image = Image.open(uploaded_file) st.image(image, caption="Imagem enviada", use_column_width=True) # Extrair características da imagem usando a CNN image_features = extract_features(image, cnn_model) # Fazer a previsão prediction = model.predict(image_features) predicted_class = label_encoder.inverse_transform(prediction)[0] # Exibir o resultado st.write(f"Classificação: **{predicted_class}**") except Exception as e: st.error(f"Erro ao processar a imagem: {e}")