NararyaPutra commited on
Commit
6670680
·
1 Parent(s): 73ba6a9

update env2

Browse files
Files changed (2) hide show
  1. requirements.txt +5 -3
  2. src/streamlit_app.py +190 -38
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ requests
src/streamlit_app.py CHANGED
@@ -1,40 +1,192 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import json
4
+ import requests
5
+ import os
6
+ from torchvision import models, transforms
7
+ from PIL import Image
8
+ from urllib.request import urlretrieve
9
 
10
+ # --- ATUR PATH MODEL DAN LABEL (gunakan direktori yang dapat ditulis di Hugging Face Spaces) ---
11
+ BASE_DIR = "/tmp/streamlit_app"
12
+
13
+ # Pastikan STREAMLIT_HOME berada di direktori yang dapat ditulis
14
+ os.environ["STREAMLIT_HOME"] = BASE_DIR
15
+ MODEL_DIR = os.path.join(BASE_DIR, "models")
16
+ LABELS_DIR = os.path.join(BASE_DIR, "labels")
17
+
18
+ os.makedirs(MODEL_DIR, exist_ok=True)
19
+ os.makedirs(LABELS_DIR, exist_ok=True)
20
+
21
+ MODEL_FILENAME = os.getenv("MODEL_FILENAME","mobilenetv2.pth")
22
+ LABELS_FILENAME = os.getenv("LABELS_FILENAME", "labels.json")
23
+
24
+ model_path = os.path.join(MODEL_DIR, MODEL_FILENAME)
25
+ labels_path = os.path.join(LABELS_DIR, LABELS_FILENAME)
26
+
27
+ MODEL_URL = os.getenv("MODEL_URL","https://download.pytorch.org/models/mobilenet_v2-b0353104.pth")
28
+ LABELS_URL = os.getenv("LABELS_URL", "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")
29
+
30
+ # --- KONFIGURASI APLIKASI ---
31
+ st.set_page_config(
32
+ page_title="Klasifikasi Gambar (PyTorch) 📸",
33
+ page_icon="🖼️",
34
+ layout="centered"
35
+ )
36
+
37
+ # --- FUNGSI-FUNGSI ---
38
+ @st.cache_resource
39
+ def load_model():
40
+ """Memuat model MobileNetV2 dari file lokal atau mengunduh jika belum ada."""
41
+ if not os.path.exists(model_path):
42
+ st.info("Mengunduh model MobileNetV2...")
43
+ try:
44
+ urlretrieve(MODEL_URL, model_path)
45
+ st.success("Model berhasil diunduh.")
46
+ except Exception as e:
47
+ st.error(f"Gagal mengunduh model: {str(e)}")
48
+ return None
49
+
50
+ try:
51
+ # Buat model tanpa weight
52
+ model = models.mobilenet_v2(weights=None)
53
+ # Muat state_dict dari file lokal
54
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
55
+ model.load_state_dict(state_dict)
56
+ model.eval()
57
+ return model
58
+ except Exception as e:
59
+ st.error(f"Gagal memuat model: {str(e)}")
60
+ return None
61
+
62
+
63
+
64
+ @st.cache_data
65
+ def load_labels():
66
+ """Memuat label dari file lokal atau mengunduh jika belum ada."""
67
+ if not os.path.exists(labels_path):
68
+ st.info("Mengunduh label ImageNet...")
69
+ try:
70
+ response = requests.get(LABELS_URL)
71
+ response.raise_for_status()
72
+ with open(labels_path, 'w') as f:
73
+ json.dump(response.json(), f)
74
+ st.success("Label berhasil diunduh.")
75
+ except Exception as e:
76
+ st.error(f"Gagal mengunduh label: {str(e)}")
77
+ return None
78
+
79
+ try:
80
+ with open(labels_path, 'r') as f:
81
+ labels = json.load(f)
82
+ return labels
83
+ except Exception as e:
84
+ st.error(f"Gagal memuat label: {str(e)}")
85
+ return None
86
+
87
+
88
+
89
+ def preprocess_image(image):
90
+ """Melakukan pra-pemrosesan gambar agar sesuai dengan input model PyTorch."""
91
+ try:
92
+ # Definisikan transformasi
93
+ preprocess = transforms.Compose([
94
+ transforms.Resize(256),
95
+ transforms.CenterCrop(224),
96
+ transforms.ToTensor(),
97
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
98
+ ])
99
+ # Terapkan transformasi dan tambahkan dimensi batch
100
+ img_t = preprocess(image)
101
+ batch_t = torch.unsqueeze(img_t, 0)
102
+ return batch_t
103
+ except Exception as e:
104
+ st.error(f"Gagal memproses gambar: {str(e)}")
105
+ return None
106
+
107
+ def predict(image, model, labels):
108
+ """Melakukan prediksi klasifikasi pada gambar."""
109
+ try:
110
+ st.info("🧠 Model sedang menganalisis gambar...")
111
+
112
+ # Pra-pemrosesan gambar
113
+ batch_t = preprocess_image(image)
114
+ if batch_t is None:
115
+ return None
116
+
117
+ # Lakukan prediksi tanpa menghitung gradien
118
+ with torch.no_grad():
119
+ output = model(batch_t)
120
+
121
+ # Dapatkan probabilitas dengan softmax
122
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
123
+
124
+ # Dapatkan 3 kelas dengan probabilitas tertinggi
125
+ top3_prob, top3_catid = torch.topk(probabilities, 3)
126
+
127
+ # Siapkan hasil
128
+ results = []
129
+ for i in range(top3_prob.size(0)):
130
+ class_name = labels[top3_catid[i]]
131
+ probability = top3_prob[i].item()
132
+ results.append((class_name, probability))
133
+
134
+ return results
135
+ except Exception as e:
136
+ st.error(f"Gagal melakukan prediksi: {str(e)}")
137
+ return None
138
+
139
+ # --- TAMPILAN UTAMA APLIKASI ---
140
+
141
+ st.title("🖼️ Aplikasi Klasifikasi Gambar (PyTorch)")
142
+ st.write(
143
+ "Unggah sebuah gambar, dan AI akan mencoba menebak objek apa yang ada di dalamnya! "
144
+ "Aplikasi ini menggunakan model **MobileNetV2** dari PyTorch."
145
+ )
146
+
147
+ # Muat model dan label
148
+ try:
149
+ model = load_model()
150
+ labels = load_labels()
151
+
152
+ if model is None or labels is None:
153
+ st.error("Aplikasi tidak dapat dijalankan karena gagal memuat model atau label.")
154
+ st.stop()
155
+ except Exception as e:
156
+ st.error(f"Kesalahan saat inisialisasi aplikasi: {str(e)}")
157
+ st.stop()
158
+
159
+ # Komponen untuk unggah file
160
+ uploaded_file = st.file_uploader(
161
+ "Pilih sebuah gambar...",
162
+ type=["jpg", "jpeg", "png"],
163
+ help="Format file yang didukung: JPG, JPEG, PNG"
164
+ )
165
+
166
+ if uploaded_file is not None:
167
+ try:
168
+ # Buka dan tampilkan gambar yang diunggah
169
+ image = Image.open(uploaded_file).convert('RGB')
170
+ st.image(image, caption='Gambar yang Anda Unggah', use_column_width=True)
171
+
172
+ # Tombol untuk memulai klasifikasi
173
+ if st.button('✨ Klasifikasikan Gambar Ini!'):
174
+ with st.spinner('Tunggu sebentar...'):
175
+ # Lakukan prediksi
176
+ predictions = predict(image, model, labels)
177
+
178
+ if predictions is not None:
179
+ st.subheader("✅ Hasil Prediksi Teratas:")
180
+ for i, (label, score) in enumerate(predictions):
181
+ st.write(f"{i+1}. **{label.replace('_', ' ').title()}** - Keyakinan: {score:.2%}")
182
+ else:
183
+ st.error("Prediksi gagal. Silakan coba lagi atau unggah gambar lain.")
184
+ except Exception as e:
185
+ st.error(f"Kesalahan saat memproses gambar yang diunggah: {str(e)}")
186
+ # Tambahan debugging untuk membantu identifikasi
187
+ st.write("Detail error: Periksa koneksi internet atau format gambar.")
188
+
189
+ st.divider()
190
+ st.markdown(
191
+ "Dibuat dengan ❤️ menggunakan [Streamlit](https://streamlit.io), [PyTorch](https://pytorch.org/) & [Hugging Face Spaces](https://huggingface.co/spaces)."
192
+ )