vGiacomov commited on
Commit
359c4f4
verified
1 Parent(s): a3c37ad

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ from huggingface_hub import hf_hub_download
6
+ from datasets import load_dataset
7
+ import numpy as np
8
+
9
+
10
+ # === Repozytorium z modelem i artefaktami ===
11
+ REPO_ID = "vGiacomov/image-classifier-beans"
12
+ MODEL_FILENAME = "resnet18_beans.pth"
13
+
14
+
15
+ # === Automatyczne pobranie modelu z Model Hub ===
16
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
17
+
18
+
19
+ # === Wczytanie modelu ===
20
+ model = models.resnet18()
21
+ model.fc = torch.nn.Linear(model.fc.in_features, 3)
22
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
23
+ model.eval()
24
+
25
+
26
+ # === Klasy ===
27
+ labels = ["Healthy", "Bean Rust", "Angular Leaf Spot"]
28
+
29
+
30
+ # === Transformacje obrazu ===
31
+ transform = transforms.Compose([
32
+ transforms.Resize((224, 224)),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize([0.485, 0.456, 0.406],
35
+ [0.229, 0.224, 0.225])
36
+ ])
37
+
38
+
39
+ # === Funkcja predykcji ===
40
+ def classify(image):
41
+ if image is None:
42
+ return {"No image uploaded": 1.0}
43
+ try:
44
+ image = Image.fromarray(image.astype("uint8"), "RGB")
45
+ tensor = transform(image).unsqueeze(0)
46
+ with torch.no_grad():
47
+ outputs = model(tensor)
48
+ probs = torch.nn.functional.softmax(outputs[0], dim=0)
49
+ return {labels[i]: float(probs[i]) for i in range(3)}
50
+ except Exception as e:
51
+ return {f"Error: {str(e)}": 1.0}
52
+
53
+
54
+ # === NOWE: Pobierz przyk艂adowe obrazy z datasetu beans ===
55
+ def get_example_images():
56
+ """Pobiera przyk艂adowe obrazy z ka偶dej klasy datasetu beans"""
57
+ try:
58
+ dataset = load_dataset("beans", split="train")
59
+ examples = []
60
+
61
+ # Pobierz po jednym przyk艂adzie z ka偶dej klasy (0, 1, 2)
62
+ for label_id in range(3):
63
+ # Znajd藕 pierwszy obraz dla danej klasy
64
+ for item in dataset:
65
+ if item["labels"] == label_id:
66
+ # Konwertuj PIL Image na numpy array (format wymagany przez Gradio)
67
+ img_array = np.array(item["image"])
68
+ examples.append(img_array)
69
+ break
70
+
71
+ return examples
72
+ except Exception as e:
73
+ print(f"Nie uda艂o si臋 za艂adowa膰 przyk艂ad贸w: {e}")
74
+ return []
75
+
76
+
77
+ # === Pobierz przyk艂ady ===
78
+ example_images = get_example_images()
79
+
80
+
81
+ # === Interfejs Gradio ===
82
+ gr.Interface(
83
+ fn=classify,
84
+ inputs=gr.Image(type="numpy", sources=["upload"], label="Upload an image"),
85
+ outputs=gr.Label(num_top_classes=3),
86
+ title="Bean Disease Classifier",
87
+ description="Upload an image of a bean leaf to detect disease.",
88
+ examples=example_images if example_images else None,
89
+ cache_examples=False # Unikaj cachowania na CPU Basic
90
+ ).launch(debug=True)