Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from datasets import load_dataset | |
| import numpy as np | |
# === Hugging Face Hub repository holding the model and its artifacts ===
REPO_ID = "vGiacomov/image-classifier-beans"
MODEL_FILENAME = "resnet18_beans.pth"

# === Download the checkpoint from the Model Hub ===
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# === Build the network and load the fine-tuned weights ===
model = models.resnet18()
# Swap the final fully-connected layer for a 3-way head so the architecture
# matches the saved state dict.
model.fc = torch.nn.Linear(model.fc.in_features, 3)
# torch.load unpickles the file, and the file comes from a remote repository.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute code. load_state_dict only needs a
# mapping of tensors, so nothing more is required here.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)
model.eval()  # switch dropout / batch-norm layers to inference behaviour
# === Class names ===
# labels[i] is the name shown for output unit i of the model, so this list
# must follow the class-id order used during training. The `beans` dataset
# (whose integer `labels` ids this file also relies on in get_example_images)
# encodes 0 = angular_leaf_spot, 1 = bean_rust, 2 = healthy. Alphabetical
# folder ordering (e.g. torchvision ImageFolder) gives the same order.
# NOTE(review): assumes the checkpoint was trained on these native ids and was
# not remapped. Confirm against the training script.
labels = ["Angular Leaf Spot", "Bean Rust", "Healthy"]
# === Image preprocessing ===
# Resize to the 224x224 input used by ResNet and convert to a float tensor.
# Then normalize each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# === Prediction function ===
def classify(image):
    """Classify a bean-leaf image with the ResNet-18 model.

    Args:
        image: the array produced by ``gr.Image(type="numpy")``, or ``None``
            when nothing was uploaded. Grayscale and RGBA arrays are
            converted to RGB before preprocessing.

    Returns:
        A ``{label: probability}`` dict for ``gr.Label``. On any failure a
        single ``{"Error: <message>": 1.0}`` entry is returned, so the
        message appears in the UI instead of an exception being raised.
    """
    if image is None:
        return {"No image uploaded": 1.0}
    try:
        # Let Pillow infer the mode from the array shape, then convert.
        # Forcing mode "RGB" fails on 2-D (grayscale) arrays and scrambles
        # 4-channel (RGBA) arrays instead of converting them.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
        batch = transform(pil_image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            logits = model(batch)
        probs = torch.nn.functional.softmax(logits[0], dim=0)
        # Pair names with probabilities so the mapping follows `labels`
        # rather than a hard-coded class count.
        return dict(zip(labels, probs.tolist()))
    except Exception as e:
        return {f"Error: {str(e)}": 1.0}
# === Example images taken from the beans dataset ===
def get_example_images():
    """Return one example image per class from the ``beans`` train split.

    For each label id 0, 1 and 2, the first row of the split with that label
    is used. Each image is converted to a NumPy array, the format expected by
    ``gr.Image(type="numpy")``. Classes with no matching row are skipped.

    Returns:
        A list of up to three image arrays ordered by label id. If anything
        fails (e.g. the dataset cannot be downloaded), the error is printed
        and an empty list is returned.
    """
    try:
        dataset = load_dataset("beans", split="train")
        # Scan only the integer label column. Iterating `dataset` row by row
        # decodes every image it passes over, and a separate loop per class
        # would restart that scan from row 0 each time.
        first_row = {}
        for row_index, label_id in enumerate(dataset["labels"]):
            first_row.setdefault(label_id, row_index)
        # Decode only the selected rows.
        return [
            np.array(dataset[first_row[label_id]]["image"])
            for label_id in range(3)
            if label_id in first_row
        ]
    except Exception as e:
        print(f"Nie udało się załadować przykładów: {e}")
        return []
# === Load the example images once, at startup ===
example_images = get_example_images()

# === Gradio interface ===
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="numpy", sources=["upload"], label="Upload an image"),
    outputs=gr.Label(num_top_classes=3),
    title="Bean Disease Classifier",
    description="Upload an image of a bean leaf to detect disease.",
    # An empty list means no examples could be loaded; pass None instead.
    examples=example_images or None,
    # Do not pre-compute example outputs (avoids caching work on the CPU Basic tier).
    cache_examples=False,
)
demo.launch(debug=True)