ASomeoneWhoInterestedWithAI commited on
Commit
e9455f8
·
verified ·
1 Parent(s): 68860ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -28
app.py CHANGED
@@ -123,49 +123,50 @@ transform_fn = transforms.Compose([
123
  transforms.ToTensor(),
124
  transforms.Normalize((0.1307,), (0.3081,))
125
  ])
126
-
127
  def predict_digit(input_image):
128
  if input_image is None:
129
  return "Please draw a number!"
130
-
131
  try:
132
- # 1. Ambil layer coretan pertama langsung seperti kode referensi
133
- if isinstance(input_image, dict) and "layers" in input_image and len(input_image["layers"]) > 0:
134
- img_array = input_image["layers"][0]
135
- elif isinstance(input_image, dict) and "composite" in input_image:
136
- img_array = input_image["composite"]
137
  else:
138
  img_array = input_image
139
 
140
- # 2. Pastikan bentuknya 2D Grayscale
141
- if len(img_array.shape) == 3:
142
- if img_array.shape[-1] == 4: # Jika RGBA, ambil alpha channel (coretannya)
143
- grayscale = img_array[:, :, 3]
144
- else: # Jika RGB, konversi ke grayscale
145
- grayscale = np.dot(img_array[...,:3], [0.2989, 0.5870, 0.1140])
 
 
 
146
  else:
147
- grayscale = img_array
 
148
 
149
- # 3. Cek jika kanvas kosong
150
- if np.max(grayscale) == 0:
151
  return {str(i): 0.1 for i in range(10)}
152
 
153
- # 4. Konversi ke PIL dan Resize ke 28x28 (Gaya kode referensimu)
154
- img = Image.fromarray(grayscale.astype(np.uint8)).convert('L')
155
- img = img.resize((28, 28), resample=Image.Resampling.BILINEAR)
156
-
157
- # 5. Jalankan normalisasi PyTorch sesuai training LookThem V8 kamu
158
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
159
-
160
- # 6. Prediksi dengan Otak Sniper LookThem
161
  with torch.no_grad():
162
  outputs = model(tensor_img)
163
  probabilities = F.softmax(outputs, dim=1)[0]
164
-
165
  return {str(i): float(probabilities[i]) for i in range(10)}
166
 
167
  except Exception as e:
168
- return {"Error": 0.0}
 
169
 
170
 
171
  # --- GRADIO INTERFACE CONSTRUCTION ---
@@ -179,12 +180,12 @@ with gr.Blocks() as demo:
179
 
180
  with gr.Row():
181
  with gr.Column():
182
- # Menggunakan gr.Paint dengan mode L (Luminance/Grayscale)
183
  input_canvas = gr.Paint(
184
- image_mode="L",
185
  height=280,
186
  width=280,
187
- brush=gr.components.image_editor.Brush(default_color="rgb(255, 255, 255)") # Kuas warna putih
 
188
  )
189
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
190
 
 
123
  transforms.ToTensor(),
124
  transforms.Normalize((0.1307,), (0.3081,))
125
  ])
 
126
  def predict_digit(input_image):
127
  if input_image is None:
128
  return "Please draw a number!"
129
+
130
  try:
131
+ # Versi aman: ambil composite jika ada (numpy array HxW atau HxWxC)
132
+ if isinstance(input_image, dict):
133
+ # Beberapa versi Gradio meletakkan hasil akhir di 'composite'
134
+ img_array = input_image.get("composite", input_image["layers"][0])
 
135
  else:
136
  img_array = input_image
137
 
138
+ # Konversi ke grayscale 2D
139
+ if isinstance(img_array, np.ndarray):
140
+ if img_array.ndim == 3:
141
+ if img_array.shape[-1] == 4: # RGBA -> ambil alpha
142
+ grayscale = img_array[..., 3]
143
+ else: # RGB -> luminance
144
+ grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
145
+ else:
146
+ grayscale = img_array
147
  else:
148
+ # Kalau ternyata PIL.Image
149
+ grayscale = np.array(img_array.convert("L"))
150
 
151
+ # 3. Cek kanvas kosong (kali ini cek nilai maks > 0, bukan == 0)
152
+ if grayscale.max() == 0:
153
  return {str(i): 0.1 for i in range(10)}
154
 
155
+ # 4. Resize dan normalisasi
156
+ img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
157
+ img = img.resize((28, 28), Image.Resampling.BILINEAR)
 
 
158
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
159
+
160
+ # 5. Prediksi
161
  with torch.no_grad():
162
  outputs = model(tensor_img)
163
  probabilities = F.softmax(outputs, dim=1)[0]
164
+
165
  return {str(i): float(probabilities[i]) for i in range(10)}
166
 
167
  except Exception as e:
168
+ # Untuk debug, kembalikan pesan errornya
169
+ return {"error": str(e)}
170
 
171
 
172
  # --- GRADIO INTERFACE CONSTRUCTION ---
 
180
 
181
  with gr.Row():
182
  with gr.Column():
 
183
  input_canvas = gr.Paint(
184
+ image_mode="L",
185
  height=280,
186
  width=280,
187
+ canvas_color="black", # ⬅️ ini yang wajib ditambahkan
188
+ brush=gr.components.image_editor.Brush(default_color="rgb(255, 255, 255)")
189
  )
190
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
191