Update app.py
Browse files
app.py
CHANGED
|
@@ -123,49 +123,50 @@ transform_fn = transforms.Compose([
|
|
| 123 |
transforms.ToTensor(),
|
| 124 |
transforms.Normalize((0.1307,), (0.3081,))
|
| 125 |
])
|
| 126 |
-
|
| 127 |
def predict_digit(input_image):
|
| 128 |
if input_image is None:
|
| 129 |
return "Please draw a number!"
|
| 130 |
-
|
| 131 |
try:
|
| 132 |
-
#
|
| 133 |
-
if isinstance(input_image, dict)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
img_array = input_image["composite"]
|
| 137 |
else:
|
| 138 |
img_array = input_image
|
| 139 |
|
| 140 |
-
#
|
| 141 |
-
if
|
| 142 |
-
if img_array.
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
| 146 |
else:
|
| 147 |
-
|
|
|
|
| 148 |
|
| 149 |
-
# 3. Cek
|
| 150 |
-
if
|
| 151 |
return {str(i): 0.1 for i in range(10)}
|
| 152 |
|
| 153 |
-
# 4.
|
| 154 |
-
img = Image.fromarray(grayscale.astype(np.uint8)
|
| 155 |
-
img = img.resize((28, 28),
|
| 156 |
-
|
| 157 |
-
# 5. Jalankan normalisasi PyTorch sesuai training LookThem V8 kamu
|
| 158 |
tensor_img = transform_fn(img).unsqueeze(0).to(device)
|
| 159 |
-
|
| 160 |
-
#
|
| 161 |
with torch.no_grad():
|
| 162 |
outputs = model(tensor_img)
|
| 163 |
probabilities = F.softmax(outputs, dim=1)[0]
|
| 164 |
-
|
| 165 |
return {str(i): float(probabilities[i]) for i in range(10)}
|
| 166 |
|
| 167 |
except Exception as e:
|
| 168 |
-
|
|
|
|
| 169 |
|
| 170 |
|
| 171 |
# --- GRADIO INTERFACE CONSTRUCTION ---
|
|
@@ -179,12 +180,12 @@ with gr.Blocks() as demo:
|
|
| 179 |
|
| 180 |
with gr.Row():
|
| 181 |
with gr.Column():
|
| 182 |
-
# Menggunakan gr.Paint dengan mode L (Luminance/Grayscale)
|
| 183 |
input_canvas = gr.Paint(
|
| 184 |
-
image_mode="L",
|
| 185 |
height=280,
|
| 186 |
width=280,
|
| 187 |
-
|
|
|
|
| 188 |
)
|
| 189 |
submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
|
| 190 |
|
|
|
|
| 123 |
transforms.ToTensor(),
|
| 124 |
transforms.Normalize((0.1307,), (0.3081,))
|
| 125 |
])
|
|
|
|
| 126 |
def predict_digit(input_image):
|
| 127 |
if input_image is None:
|
| 128 |
return "Please draw a number!"
|
| 129 |
+
|
| 130 |
try:
|
| 131 |
+
# Versi aman: ambil composite jika ada (numpy array HxW atau HxWxC)
|
| 132 |
+
if isinstance(input_image, dict):
|
| 133 |
+
# Beberapa versi Gradio meletakkan hasil akhir di 'composite'
|
| 134 |
+
img_array = input_image.get("composite", input_image["layers"][0])
|
|
|
|
| 135 |
else:
|
| 136 |
img_array = input_image
|
| 137 |
|
| 138 |
+
# Konversi ke grayscale 2D
|
| 139 |
+
if isinstance(img_array, np.ndarray):
|
| 140 |
+
if img_array.ndim == 3:
|
| 141 |
+
if img_array.shape[-1] == 4: # RGBA -> ambil alpha
|
| 142 |
+
grayscale = img_array[..., 3]
|
| 143 |
+
else: # RGB -> luminance
|
| 144 |
+
grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
|
| 145 |
+
else:
|
| 146 |
+
grayscale = img_array
|
| 147 |
else:
|
| 148 |
+
# Kalau ternyata PIL.Image
|
| 149 |
+
grayscale = np.array(img_array.convert("L"))
|
| 150 |
|
| 151 |
+
# 3. Cek kanvas kosong (kali ini cek nilai maks > 0, bukan == 0)
|
| 152 |
+
if grayscale.max() == 0:
|
| 153 |
return {str(i): 0.1 for i in range(10)}
|
| 154 |
|
| 155 |
+
# 4. Resize dan normalisasi
|
| 156 |
+
img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
|
| 157 |
+
img = img.resize((28, 28), Image.Resampling.BILINEAR)
|
|
|
|
|
|
|
| 158 |
tensor_img = transform_fn(img).unsqueeze(0).to(device)
|
| 159 |
+
|
| 160 |
+
# 5. Prediksi
|
| 161 |
with torch.no_grad():
|
| 162 |
outputs = model(tensor_img)
|
| 163 |
probabilities = F.softmax(outputs, dim=1)[0]
|
| 164 |
+
|
| 165 |
return {str(i): float(probabilities[i]) for i in range(10)}
|
| 166 |
|
| 167 |
except Exception as e:
|
| 168 |
+
# Untuk debug, kembalikan pesan errornya
|
| 169 |
+
return {"error": str(e)}
|
| 170 |
|
| 171 |
|
| 172 |
# --- GRADIO INTERFACE CONSTRUCTION ---
|
|
|
|
| 180 |
|
| 181 |
with gr.Row():
|
| 182 |
with gr.Column():
|
|
|
|
| 183 |
input_canvas = gr.Paint(
|
| 184 |
+
image_mode="L",
|
| 185 |
height=280,
|
| 186 |
width=280,
|
| 187 |
+
canvas_color="black", # ⬅️ ini yang wajib ditambahkan
|
| 188 |
+
brush=gr.components.image_editor.Brush(default_color="rgb(255, 255, 255)")
|
| 189 |
)
|
| 190 |
submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
|
| 191 |
|