Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import joblib | |
| import torch | |
| import torch.nn as nn | |
| from torchvision.models import resnet18 | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from huggingface_hub import hf_hub_download | |
| # Função para carregar o modelo e o LabelEncoder do Hugging Face | |
| def load_model(): | |
| # Baixar o modelo e o LabelEncoder do Hugging Face | |
| model_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="mlp_classifier.joblib") | |
| label_encoder_path = hf_hub_download(repo_id="arturevs/90AnimalClassifier", filename="label_encoder.joblib") | |
| # Carregar o modelo e o LabelEncoder | |
| model = joblib.load(model_path) | |
| label_encoder = joblib.load(label_encoder_path) | |
| return model, label_encoder | |
| # Função para carregar a CNN (ResNet18) para extração de características | |
| def load_cnn(): | |
| # Carregar a ResNet18 pré-treinada e remover a última camada | |
| cnn_model = resnet18(pretrained=True) | |
| cnn_model = nn.Sequential(*list(cnn_model.children())[:-1]) # Remove a última camada | |
| cnn_model.eval() # Colocar o modelo em modo de avaliação | |
| return cnn_model | |
| # Função para pré-processar a imagem e extrair características | |
| def extract_features(image, cnn_model): | |
| # Converter a imagem para RGB (remover canal alfa se existir) | |
| image = image.convert("RGB") | |
| transform = transforms.Compose([ | |
| transforms.Resize((256, 256)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalização para a ResNet18 | |
| ]) | |
| image = transform(image).unsqueeze(0) # Adicionar dimensão do batch | |
| with torch.no_grad(): # Desativar cálculo de gradientes | |
| features = cnn_model(image) # Extrair características | |
| return features.flatten().numpy().reshape(1, -1) # Achatar e converter para numpy array | |
| # Interface da aplicação | |
| st.title("Classificador de Animais") | |
| st.write("Envie uma imagem de um animal (gato, cachorro ou pássaro) para classificação.") | |
| # Carregar o modelo, o LabelEncoder e a CNN | |
| model, label_encoder = load_model() | |
| cnn_model = load_cnn() | |
| # Upload da imagem | |
| uploaded_file = st.file_uploader("Escolha uma imagem...", type=["jpg", "jpeg", "png"]) | |
| if uploaded_file is not None: | |
| try: | |
| image = Image.open(uploaded_file) | |
| st.image(image, caption="Imagem enviada", use_column_width=True) | |
| # Extrair características da imagem usando a CNN | |
| image_features = extract_features(image, cnn_model) | |
| # Fazer a previsão | |
| prediction = model.predict(image_features) | |
| predicted_class = label_encoder.inverse_transform(prediction)[0] | |
| # Exibir o resultado | |
| st.write(f"Classificação: **{predicted_class}**") | |
| except Exception as e: | |
| st.error(f"Erro ao processar a imagem: {e}") |