File size: 6,593 Bytes
b44d3ab
16510bb
 
 
 
 
 
 
b44d3ab
16510bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import streamlit as st
import torch
import json
import requests
import os
from torchvision import models, transforms
from PIL import Image
from urllib.request import urlretrieve

# --- ATUR PATH MODEL DAN LABEL (gunakan direktori yang dapat ditulis di Hugging Face Spaces) ---
BASE_DIR = "/tmp/streamlit_app"

# Pastikan STREAMLIT_HOME berada di direktori yang dapat ditulis
os.environ["STREAMLIT_HOME"] = BASE_DIR
MODEL_DIR = os.path.join(BASE_DIR, "models")
LABELS_DIR = os.path.join(BASE_DIR, "labels")

os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(LABELS_DIR, exist_ok=True)

MODEL_FILENAME = os.getenv("MODEL_FILENAME","mobilenetv2.pth")
LABELS_FILENAME = os.getenv("LABELS_FILENAME", "labels.json")

model_path = os.path.join(MODEL_DIR, MODEL_FILENAME)
labels_path = os.path.join(LABELS_DIR, LABELS_FILENAME)

MODEL_URL = os.getenv("MODEL_URL","https://download.pytorch.org/models/mobilenet_v2-b0353104.pth")
LABELS_URL = os.getenv("LABELS_URL", "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")

# --- KONFIGURASI APLIKASI ---
st.set_page_config(
    page_title="Klasifikasi Gambar (PyTorch) 📸",
    page_icon="🖼️",
    layout="centered"
)

# --- FUNGSI-FUNGSI ---
@st.cache_resource
def load_model():
    """Memuat model MobileNetV2 dari file lokal atau mengunduh jika belum ada."""
    if not os.path.exists(model_path):
        # st.info("Mengunduh model MobileNetV2...")
        try:
            urlretrieve(MODEL_URL, model_path)
            # st.success("Model berhasil diunduh.")
        except Exception as e:
            st.error(f"Gagal mengunduh model: {str(e)}")
            return None

    try:
        # Buat model tanpa weight
        model = models.mobilenet_v2(weights=None)
        # Muat state_dict dari file lokal
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        st.error(f"Gagal memuat model: {str(e)}")
        return None



@st.cache_data
def load_labels():
    """Memuat label dari file lokal atau mengunduh jika belum ada."""
    if not os.path.exists(labels_path):
        # st.info("Mengunduh label ImageNet...")
        try:
            response = requests.get(LABELS_URL)
            response.raise_for_status()
            with open(labels_path, 'w') as f:
                json.dump(response.json(), f)
            # st.success("Label berhasil diunduh.")
        except Exception as e:
            st.error(f"Gagal mengunduh label: {str(e)}")
            return None

    try:
        with open(labels_path, 'r') as f:
            labels = json.load(f)
        return labels
    except Exception as e:
        st.error(f"Gagal memuat label: {str(e)}")
        return None



def preprocess_image(image):
    """Melakukan pra-pemrosesan gambar agar sesuai dengan input model PyTorch."""
    try:
        # Definisikan transformasi
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Terapkan transformasi dan tambahkan dimensi batch
        img_t = preprocess(image)
        batch_t = torch.unsqueeze(img_t, 0)
        return batch_t
    except Exception as e:
        st.error(f"Gagal memproses gambar: {str(e)}")
        return None

def predict(image, model, labels):
    """Melakukan prediksi klasifikasi pada gambar."""
    try:
        st.info("🧠 Model sedang menganalisis gambar...")

        # Pra-pemrosesan gambar
        batch_t = preprocess_image(image)
        if batch_t is None:
            return None

        # Lakukan prediksi tanpa menghitung gradien
        with torch.no_grad():
            output = model(batch_t)

        # Dapatkan probabilitas dengan softmax
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Dapatkan 3 kelas dengan probabilitas tertinggi
        top3_prob, top3_catid = torch.topk(probabilities, 3)

        # Siapkan hasil
        results = []
        for i in range(top3_prob.size(0)):
            class_name = labels[top3_catid[i]]
            probability = top3_prob[i].item()
            results.append((class_name, probability))

        return results
    except Exception as e:
        st.error(f"Gagal melakukan prediksi: {str(e)}")
        return None

# --- TAMPILAN UTAMA APLIKASI ---

st.title("🖼️ Aplikasi Klasifikasi Gambar (PyTorch)")
st.write(
    "Unggah sebuah gambar, dan AI akan mencoba menebak objek apa yang ada di dalamnya! "
    "Aplikasi ini menggunakan model **MobileNetV2** dari PyTorch."
)

# Muat model dan label
try:
    model = load_model()
    labels = load_labels()

    if model is None or labels is None:
        st.error("Aplikasi tidak dapat dijalankan karena gagal memuat model atau label.")
        st.stop()
except Exception as e:
    st.error(f"Kesalahan saat inisialisasi aplikasi: {str(e)}")
    st.stop()

# Komponen untuk unggah file
uploaded_file = st.file_uploader(
    "Pilih sebuah gambar...",
    type=["jpg", "jpeg", "png"],
    help="Format file yang didukung: JPG, JPEG, PNG"
)

if uploaded_file is not None:
    try:
        # Buka dan tampilkan gambar yang diunggah
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption='Gambar yang Anda Unggah', use_column_width=True)

        # Tombol untuk memulai klasifikasi
        if st.button('✨ Klasifikasikan Gambar Ini!'):
            with st.spinner('Tunggu sebentar...'):
                # Lakukan prediksi
                predictions = predict(image, model, labels)

                if predictions is not None:
                    st.subheader("✅ Hasil Prediksi Teratas:")
                    for i, (label, score) in enumerate(predictions):
                        st.write(f"{i+1}. **{label.replace('_', ' ').title()}** - Keyakinan: {score:.2%}")
                else:
                    st.error("Prediksi gagal. Silakan coba lagi atau unggah gambar lain.")
    except Exception as e:
        st.error(f"Kesalahan saat memproses gambar yang diunggah: {str(e)}")
        # Tambahan debugging untuk membantu identifikasi
        st.write("Detail error: Periksa koneksi internet atau format gambar.")

st.divider()
# st.markdown(
#     "Dibuat dengan ❤️ menggunakan [Streamlit](https://streamlit.io), [PyTorch](https://pytorch.org/) & [Hugging Face Spaces](https://huggingface.co/spaces)."
# )