# Author: vGiacomov
# Commit 359c4f4 (verified): "Upload app.py with huggingface_hub"
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
# === Hugging Face Hub repository holding the model and its artifacts ===
REPO_ID = "vGiacomov/image-classifier-beans"
MODEL_FILENAME = "resnet18_beans.pth"

# === Download the checkpoint from the Model Hub (returns a local cached path) ===
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# === Build the network and load the trained weights ===
# ResNet-18 whose final fully connected layer is replaced by a 3-output head,
# one logit per class.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 3)
# The checkpoint is a file downloaded from the Hub and torch.load is
# pickle-based; weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict needs (this is also the default from
# PyTorch 2.6 onward).
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)
# Switch to evaluation mode so BatchNorm uses its running statistics.
model.eval()
# === Class names ===
# Entry i is the display name for output i of the model head. The order
# follows the ClassLabel mapping of the Hugging Face "beans" dataset
# (0 = angular_leaf_spot, 1 = bean_rust, 2 = healthy), the same label ids
# get_example_images selects by. Assumes the checkpoint was trained on those
# raw dataset label ids — confirm against the training script.
labels = ["Angular Leaf Spot", "Bean Rust", "Healthy"]
# === Image preprocessing ===
# Resize to a fixed 224x224, convert the PIL image to a float tensor in [0, 1],
# then normalize each RGB channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# === Prediction function ===
def classify(image):
    """Classify a bean-leaf image and return per-class probabilities.

    Args:
        image: Image array as delivered by the ``gr.Image(type="numpy")``
            input, or ``None`` when nothing has been uploaded.

    Returns:
        A dict mapping each name in ``labels`` to its softmax probability,
        in the shape ``gr.Label`` expects. If no image is given, or if
        preprocessing/inference raises, a single-entry dict whose key
        describes the problem (value 1.0) is returned instead, so the UI
        shows the message rather than failing.
    """
    if image is None:
        return {"No image uploaded": 1.0}
    try:
        # Let Pillow infer the mode from the array shape and then normalize to
        # RGB, so grayscale (H, W) or RGBA (H, W, 4) arrays also reach the
        # 3-channel Normalize step correctly.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
        batch = transform(pil_image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            logits = model(batch)
        probs = torch.nn.functional.softmax(logits[0], dim=0)
        return {label: float(p) for label, p in zip(labels, probs)}
    except Exception as e:  # best-effort: surface the error in the UI
        return {f"Error: {e}": 1.0}
# === Fetch sample images from the beans dataset ===
def get_example_images():
    """Return one sample image per class from the ``beans`` train split.

    For each class id 0, 1 and 2 (in that order), the first training row
    carrying that label is converted to a NumPy array, the format used by
    the ``gr.Image(type="numpy")`` input. Class ids with no matching row are
    skipped.

    Returns:
        A list of up to three image arrays. If anything fails (for example
        the dataset cannot be downloaded), the error is printed and an empty
        list is returned so the app can start without examples.
    """
    try:
        dataset = load_dataset("beans", split="train")
        # Locate the first row of each class by scanning only the integer
        # label column; iterating the rows themselves would decode the image
        # of every row passed over before the match.
        first_row = {}
        for row, label_id in enumerate(dataset["labels"]):
            first_row.setdefault(label_id, row)
        examples = []
        for label_id in range(3):
            row = first_row.get(label_id)
            if row is None:
                continue  # this class does not occur in the split
            # Convert the PIL image to a NumPy array (the format Gradio's
            # numpy Image input uses).
            examples.append(np.array(dataset[row]["image"]))
        return examples
    except Exception as e:
        print(f"Nie udało się załadować przykładów: {e}")
        return []
# === Load the example images shown under the input ===
example_images = get_example_images()

# === Gradio interface ===
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="numpy", sources=["upload"], label="Upload an image"),
    outputs=gr.Label(num_top_classes=3),
    title="Bean Disease Classifier",
    description="Upload an image of a bean leaf to detect disease.",
    # Pass None rather than an empty list when no examples could be loaded.
    examples=example_images or None,
    # Avoid caching examples (which would run the model on them at startup)
    # on the CPU Basic hardware tier.
    cache_examples=False,
)
demo.launch(debug=True)