Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, models | |
| import pickle | |
| from resnest.torch import resnest50 | |
| # Carregar nomes das classes originais | |
| with open('class_names.pkl', 'rb') as f: | |
| class_names_en = pickle.load(f) | |
| # Imprimir as classes originais para debug | |
| print("Classes originais encontradas:", class_names_en) | |
| # Dicionário de tradução mais completo (incluindo variações) | |
| class_names_pt = { | |
| 'apple': 'maçã', | |
| 'Apple': 'maçã', | |
| 'Apple 10': 'maçã', # adicionando variações | |
| 'banana': 'banana', | |
| 'Banana': 'banana', | |
| 'cherry': 'cereja', | |
| 'Cherry': 'cereja', | |
| 'chico': 'sapoti', | |
| 'grape': 'uva', | |
| 'Grape': 'uva', | |
| 'kiwi': 'kiwi', | |
| 'Kiwi': 'kiwi', | |
| 'mango': 'manga', | |
| 'Mango': 'manga', | |
| 'orange': 'laranja', | |
| 'Orange': 'laranja', | |
| 'pear': 'pera', | |
| 'Pear': 'pera', | |
| 'tomato': 'tomate', | |
| 'Tomato': 'tomate' | |
| } | |
| # Criar lista de nomes em português, usando o nome original se não houver tradução | |
| class_names = [] | |
| for en in class_names_en: | |
| # Remover números e espaços extras para normalizar | |
| base_name = ''.join([i for i in en if not i.isdigit()]).strip() | |
| translated = class_names_pt.get(base_name, class_names_pt.get(en, en)) | |
| class_names.append(translated) | |
| print("Classes traduzidas:", class_names) | |
| # Restante do código permanece igual... | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = resnest50(pretrained=None) | |
| model.fc = nn.Sequential( | |
| nn.Dropout(0.2), | |
| nn.Linear(model.fc.in_features, len(class_names)) | |
| ) | |
| # Carregar os pesos do modelo | |
| model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True)) | |
| model = model.to(device) | |
| model.eval() | |
| # Definir o mesmo pré-processamento usado no treinamento | |
| preprocess = transforms.Compose([ | |
| transforms.Resize((100, 100)), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
| ]) | |
| def predict_image(img): | |
| img = img.convert('RGB') | |
| # Aplicar pré-processamento | |
| input_tensor = preprocess(img) | |
| # Adicionar dimensão de batch e mover para o dispositivo | |
| input_batch = input_tensor.unsqueeze(0).to(device) | |
| # Fazer previsão | |
| with torch.no_grad(): | |
| output = model(input_batch) | |
| # Calcular probabilidades | |
| probabilities = torch.nn.functional.softmax(output[0], dim=0) | |
| # Obter as 3 melhores previsões | |
| top3_probs, top3_indices = torch.topk(probabilities, 3) | |
| results = { | |
| class_names[i]: float(p) | |
| for p, i in zip(top3_probs, top3_indices) | |
| } | |
| # Obter a melhor previsão | |
| best_class = class_names[top3_indices[0]] | |
| best_conf = float(top3_probs[0]) * 100 | |
| # Salvar resultados | |
| with open('/tmp/prediction_results.txt', 'a') as f: | |
| f.write(f"Imagem: {img}\n" | |
| f"Previsão: {best_class}\n" | |
| f"Confiança: {best_conf:.2f}%\n" | |
| f"Top 3: {results}\n" | |
| f"------------------------\n") | |
| return best_class, f"{best_conf:.2f}%", results | |
| # Criar interface Gradio | |
| def create_interface(): | |
| examples = [ | |
| "r0_0_100.jpg", | |
| "r0_18_100.jpg" | |
| ] | |
| with gr.Blocks(title="Sistema de Classificação de Frutas", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# 🍎 Sistema de Reconhecimento de Frutas") | |
| with gr.Row(): | |
| with gr.Column(): | |
| image_input = gr.Image(type="pil", label="Envie uma imagem") | |
| gr.Examples(examples=examples, inputs=image_input) | |
| submit_btn = gr.Button("Classificar", variant="primary") | |
| with gr.Column(): | |
| best_pred = gr.Textbox(label="Resultado da Previsão") | |
| confidence = gr.Textbox(label="Nível de Confiança") | |
| full_results = gr.Label(label="Top 3", num_top_classes=3) | |
| # Evento de clique do botão 'Classificar' | |
| submit_btn.click( | |
| fn=predict_image, | |
| inputs=image_input, | |
| outputs=[best_pred, confidence, full_results] | |
| ) | |
| return demo | |
| if __name__ == "__main__": | |
| interface = create_interface() | |
| interface.launch(share=False) |