ASomeoneWhoInterestedWithAI committed on
Commit
307755a
·
verified ·
1 Parent(s): c65e33f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -27
app.py CHANGED
@@ -17,8 +17,7 @@ if not os.path.exists(MODEL_PATH):
17
  urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
  print("Download complete!")
19
 
20
- # --- DEFINE YOUR MODEL ARCHITECTURE ---
21
- # (Bagian kelas LookThemLayer, LiteResidualBlock, dan LookThemV8MNIST tetap sama)
22
  class LookThemLayer(nn.Module):
23
  def __init__(self, num_tokens, in_features, hidden_dim):
24
  super().__init__()
@@ -108,12 +107,11 @@ class LookThemV8MNIST(nn.Module):
108
  x = self.compressor(x).flatten(1)
109
  x = self.res_blocks(self.input_proj(x))
110
  return self.head(x)
111
- # ... (Salin definisi model Anda di sini) ...
112
 
113
  # --- LOAD WEIGHTS ON CPU/GPU ---
114
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115
  model = LookThemV8MNIST()
116
- model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
117
  model.to(device)
118
  model.eval()
119
 
@@ -125,34 +123,59 @@ transform_fn = transforms.Compose([
125
  ])
126
 
127
  def predict_digit(input_image):
 
 
 
128
  if input_image is None:
129
- return "Please draw a number!"
130
 
131
  try:
132
- # gr.Sketchpad mengembalikan numpy array secara langsung
133
- img_array = input_image
134
-
135
- # Cek apakah kanvas kosong (semua piksel bernilai 0)
136
- if np.max(img_array) == 0:
137
- return {str(i): 0.1 for i in range(10)}
138
-
139
- # Konversi ke PIL Image dan resize
140
- img = Image.fromarray(img_array.astype(np.uint8), mode="L")
141
- img = img.resize((28, 28), Image.Resampling.BILINEAR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- # Transformasi dan prediksi
 
 
 
 
 
 
144
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
145
 
146
  with torch.no_grad():
147
  outputs = model(tensor_img)
148
- probabilities = F.softmax(outputs, dim=1)[0]
149
-
150
- return {str(i): float(probabilities[i]) for i in range(10)}
151
-
152
  except Exception as e:
153
- return {"Error": str(e)}
 
 
154
 
155
- # --- GRADIO INTERFACE CONSTRUCTION ---
156
  with gr.Blocks() as demo:
157
  gr.Markdown(
158
  """
@@ -163,15 +186,12 @@ with gr.Blocks() as demo:
163
 
164
  with gr.Row():
165
  with gr.Column():
166
- # Gunakan gr.Sketchpad
167
  input_canvas = gr.Sketchpad(
168
  image_mode="L",
169
  height=280,
170
  width=280,
171
- brush=gr.Brush(
172
- default_color="rgb(255, 255, 255)", # Kuas putih
173
- color_mode="fixed"
174
- )
175
  )
176
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
177
 
 
17
  urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
  print("Download complete!")
19
 
20
+ # --- DEFINE YOUR MODEL ARCHITECTURE (same as before) ---
 
21
  class LookThemLayer(nn.Module):
22
  def __init__(self, num_tokens, in_features, hidden_dim):
23
  super().__init__()
 
107
  x = self.compressor(x).flatten(1)
108
  x = self.res_blocks(self.input_proj(x))
109
  return self.head(x)
 
110
 
111
  # --- LOAD WEIGHTS ON CPU/GPU ---
112
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
113
  model = LookThemV8MNIST()
114
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
115
  model.to(device)
116
  model.eval()
117
 
 
123
  ])
124
 
125
  def predict_digit(input_image):
126
+ # Always return a 10-digit dictionary for gr.Label
127
+ default_output = {str(i): 0.1 for i in range(10)}
128
+
129
  if input_image is None:
130
+ return default_output
131
 
132
  try:
133
+ # Handle the various input formats (dict from Paint, array from Sketchpad, etc.)
134
+ if isinstance(input_image, dict):
135
+ # older gr.Paint versions -> take the composite or the first layer
136
+ img_array = input_image.get("composite")
137
+ if img_array is None and "layers" in input_image:
138
+ layers = input_image["layers"]
139
+ img_array = layers[0] if layers else None
140
+ if img_array is None:
141
+ return default_output
142
+ else:
143
+ img_array = input_image
144
+
145
+ # Convert to a numpy array if it is not one already
146
+ if not isinstance(img_array, np.ndarray):
147
+ img_array = np.array(img_array)
148
+
149
+ # If the image is in color, pick the appropriate channel
150
+ if img_array.ndim == 3:
151
+ if img_array.shape[-1] == 4: # RGBA → alpha
152
+ grayscale = img_array[..., 3]
153
+ else: # RGB → luminance
154
+ grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
155
+ else:
156
+ grayscale = img_array
157
 
158
+ # Check for an empty canvas
159
+ if grayscale.max() == 0:
160
+ return default_output
161
+
162
+ # Resize & normalize
163
+ img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
164
+ img = img.resize((28, 28), Image.Resampling.BILINEAR)
165
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
166
 
167
  with torch.no_grad():
168
  outputs = model(tensor_img)
169
+ probs = F.softmax(outputs, dim=1)[0]
170
+
171
+ return {str(i): float(probs[i]) for i in range(10)}
172
+
173
  except Exception as e:
174
+ # Return a uniform distribution if an unexpected error occurs
175
+ print(f"Prediction error: {e}")
176
+ return default_output
177
 
178
+ # --- GRADIO INTERFACE ---
179
  with gr.Blocks() as demo:
180
  gr.Markdown(
181
  """
 
186
 
187
  with gr.Row():
188
  with gr.Column():
189
+ # CHANGED: use Sketchpad for a black background + white pen
190
  input_canvas = gr.Sketchpad(
191
  image_mode="L",
192
  height=280,
193
  width=280,
194
+ brush=gr.Brush(default_color="rgb(255,255,255)", color_mode="fixed")
 
 
 
195
  )
196
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
197