"""Gradio app that classifies bean-leaf images with a fine-tuned ResNet-18.

At startup the script downloads the model weights from the Hugging Face Model
Hub, loads one example image per class from the ``beans`` dataset, and then
launches a Gradio interface.
"""

import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np

# === Repository holding the model and its artifacts ===
REPO_ID = "vGiacomov/image-classifier-beans"
MODEL_FILENAME = "resnet18_beans.pth"

# === Classes ===
# labels[i] is the name displayed for model output i.
# NOTE(review): the Hugging Face ``beans`` dataset appears to number its labels
# 0=angular_leaf_spot, 1=bean_rust, 2=healthy -- the reverse of this list for
# classes 0 and 2. Confirm which order the checkpoint was trained with.
labels = ["Healthy", "Bean Rust", "Angular Leaf Spot"]
NUM_CLASSES = len(labels)

# === Automatically download the model from the Model Hub ===
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# === Load the model ===
model = models.resnet18()
# Replace the classification head so its output size matches the class list.
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
# The checkpoint is a plain state_dict fetched over the network;
# weights_only=True prevents torch.load from unpickling arbitrary objects.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)
model.eval()

# === Image transforms ===
# Resize to 224x224, convert to a float tensor and normalise with the
# standard ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# === Prediction function ===
def classify(image):
    """Classify a bean-leaf image.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when nothing has been uploaded.

    Returns:
        A dict mapping each class name to its softmax probability, suitable
        for ``gr.Label``. When there is no input or an error occurs, a
        single-entry dict whose key describes the problem.
    """
    if image is None:
        return {"No image uploaded": 1.0}
    try:
        # Let Pillow infer the mode from the array shape, then convert, so
        # grayscale (H, W) or RGBA (H, W, 4) arrays are handled too. Forcing
        # mode "RGB" misreads such arrays, and that argument is deprecated.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
        tensor = transform(pil_image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            outputs = model(tensor)
        probs = torch.nn.functional.softmax(outputs[0], dim=0)
        return dict(zip(labels, probs.tolist()))
    except Exception as e:  # UI boundary: show the error in the Label widget
        return {f"Error: {str(e)}": 1.0}


# === Fetch example images from the beans dataset ===
def get_example_images():
    """Return one example image per class from the ``beans`` training split.

    Returns:
        A list of numpy image arrays ordered by dataset label id (a class with
        no image in the split is skipped), or an empty list if loading fails.
    """
    try:
        dataset = load_dataset("beans", split="train")
        # Search the label column only; iterating the dataset row by row
        # decodes every image up to the first match of each class.
        label_ids = list(dataset["labels"])
        examples = []
        for label_id in range(NUM_CLASSES):
            try:
                index = label_ids.index(label_id)
            except ValueError:
                continue  # no image of this class in the split
            # Convert the PIL image to a numpy array, the format used by
            # this app's gr.Image(type="numpy") input.
            examples.append(np.array(dataset[index]["image"]))
        return examples
    except Exception as e:
        print(f"Nie udało się załadować przykładów: {e}")
        return []


# === Fetch the examples ===
example_images = get_example_images()

# === Gradio interface ===
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="numpy", sources=["upload"], label="Upload an image"),
    outputs=gr.Label(num_top_classes=NUM_CLASSES),
    title="Bean Disease Classifier",
    description="Upload an image of a bean leaf to detect disease.",
    # Nested list: one inner list per example, one entry per input component.
    examples=[[img] for img in example_images] or None,
    cache_examples=False,  # avoid example caching on the CPU Basic tier
)
demo.launch(debug=True)