File size: 2,862 Bytes
359c4f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np


# === Repository holding the model and its artifacts ===
REPO_ID = "vGiacomov/image-classifier-beans"
MODEL_FILENAME = "resnet18_beans.pth"


# === Automatic download of the model file from the Model Hub ===
# The returned value is used below as the path handed to torch.load.
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)


# === Load the model ===
# ResNet-18 whose final fully connected layer is replaced by a 3-output head;
# the weights come from the downloaded state dict, mapped onto the CPU.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 3)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # evaluation mode: this script only runs inference


# === Classes ===
# Display names, indexed by the model's output position.
# NOTE(review): the public "beans" dataset numbers its classes
# 0=angular_leaf_spot, 1=bean_rust, 2=healthy, which is the reverse of this
# list. If the checkpoint was trained on the dataset's own label ids, index 0
# and index 2 are swapped here -- confirm against the training script.
labels = ["Healthy", "Bean Rust", "Angular Leaf Spot"]


# === Image transforms ===
# Resize to 224x224, convert to a tensor, then normalize each channel.
# The mean/std values match the commonly used ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


# === Prediction function ===
def classify(image):
    """Classify a bean-leaf image and return per-class probabilities.

    Args:
        image: NumPy array as delivered by the Gradio image input, or
            ``None`` when nothing was uploaded.

    Returns:
        A dict mapping each entry of ``labels`` to its softmax probability.
        When ``image`` is ``None`` a single placeholder entry is returned;
        when anything fails during inference a single ``"Error: ..."`` entry
        is returned so the message shows up in the label output instead of
        crashing the request.
    """
    if image is None:
        return {"No image uploaded": 1.0}
    try:
        # Let Pillow infer the mode from the array shape and convert
        # afterwards. Forcing "RGB" inside fromarray only works for HxWx3
        # input; converting also accepts grayscale and RGBA arrays.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
        batch = transform(pil_image).unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)[0]
        probabilities = torch.softmax(logits, dim=0).tolist()
        # Pair labels with probabilities directly rather than relying on a
        # hard-coded class count.
        return dict(zip(labels, probabilities))
    except Exception as e:
        return {f"Error: {e}": 1.0}


# === NEW: Fetch example images from the beans dataset ===
def get_example_images():
    """Return one example image per class from the beans training split.

    Returns:
        A list of NumPy arrays ordered by class id (0, 1, 2). A class with
        no sample in the split is skipped. If the dataset cannot be loaded
        or read, the error is printed and an empty list is returned so the
        app can still start without examples.
    """
    wanted = range(3)
    try:
        dataset = load_dataset("beans", split="train")

        # Scan only the label column to find the first row of each class.
        # Iterating full rows would decode every image along the way, and
        # restarting the scan per class would repeat that work.
        first_row = {}
        for row, label_id in enumerate(dataset["labels"]):
            if label_id in wanted and label_id not in first_row:
                first_row[label_id] = row
                if len(first_row) == len(wanted):
                    break

        # Decode just the selected rows; Gradio receives NumPy arrays.
        return [
            np.array(dataset[first_row[label_id]]["image"])
            for label_id in wanted
            if label_id in first_row
        ]
    except Exception as e:
        print(f"Nie udało się załadować przykładów: {e}")
        return []


# === Fetch the example images ===
example_images = get_example_images()


# === Gradio interface ===
# An empty example list is passed on as None.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="numpy", sources=["upload"], label="Upload an image"),
    outputs=gr.Label(num_top_classes=3),
    title="Bean Disease Classifier",
    description="Upload an image of a bean leaf to detect disease.",
    examples=example_images or None,
    cache_examples=False,  # avoid example caching on the CPU Basic tier
)
demo.launch(debug=True)