ASomeoneWhoInterestedWithAI commited on
Commit
51849ff
·
verified ·
1 Parent(s): ceb8453

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -17,7 +17,7 @@ if not os.path.exists(MODEL_PATH):
17
  urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
  print("Download complete!")
19
 
20
- # --- DEFINE YOUR MODEL ARCHITECTURE (TETAP SAMA) ---
21
  class LookThemLayer(nn.Module):
22
  def __init__(self, num_tokens, in_features, hidden_dim):
23
  super().__init__()
@@ -107,7 +107,6 @@ class LookThemV8MNIST(nn.Module):
107
  x = self.compressor(x).flatten(1)
108
  x = self.res_blocks(self.input_proj(x))
109
  return self.head(x)
110
- # ... (Salin definisi kelas LookThemLayer, LiteResidualBlock, dan LookThemV8MNIST Anda di sini) ...
111
 
112
  # --- LOAD WEIGHTS ON CPU/GPU ---
113
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -124,33 +123,43 @@ transform_fn = transforms.Compose([
124
  ])
125
 
126
  def predict_digit(input_image):
127
- # Default output jika kanvas kosong
128
  default_output = {str(i): 0.1 for i in range(10)}
129
 
130
  if input_image is None:
131
  return default_output
132
 
133
  try:
134
- # gr.Image(source="canvas") mengembalikan numpy array secara langsung
135
- img_array = input_image
136
-
137
- # Konversi ke grayscale jika perlu (hasil kanvas biasanya sudah grayscale)
 
 
 
 
 
 
 
 
138
  if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
139
- # Ambil channel pertama jika multichannel, atau konversi ke luminance
140
- if img_array.shape[-1] == 4: # RGBA -> alpha
141
  grayscale = img_array[..., 3]
142
  else: # RGB -> luminance
143
  grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
144
  else:
145
  grayscale = img_array
146
 
147
- # Cek apakah kanvas kosong (semua piksel bernilai 0 atau mendekati)
148
  if np.max(grayscale) < 5:
149
  return default_output
150
 
151
- # Resize & normalisasi
 
 
 
 
152
  img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
153
- img = img.resize((28, 28), Image.Resampling.BILINEAR)
154
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
155
 
156
  with torch.no_grad():
@@ -174,14 +183,11 @@ with gr.Blocks() as demo:
174
 
175
  with gr.Row():
176
  with gr.Column():
177
- # Gunakan gr.Image dengan source="canvas"
178
- input_canvas = gr.Image(
179
- image_mode="L",
180
- height=280,
181
- width=280,
182
- sources="canvas", # Mengaktifkan mode kanvas untuk menggambar
183
- invert_colors=True, # Membalik warna: latar hitam, coretan putih
184
- brush=gr.Brush(default_color="rgb(0,0,0)", color_mode="fixed") # Kuas hitam (akan dibalik jadi putih)
185
  )
186
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
187
 
@@ -191,4 +197,4 @@ with gr.Blocks() as demo:
191
  submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)
192
 
193
  if __name__ == "__main__":
194
- demo.launch()
 
17
  urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
  print("Download complete!")
19
 
20
+ # --- DEFINE YOUR MODEL ARCHITECTURE ---
21
  class LookThemLayer(nn.Module):
22
  def __init__(self, num_tokens, in_features, hidden_dim):
23
  super().__init__()
 
107
  x = self.compressor(x).flatten(1)
108
  x = self.res_blocks(self.input_proj(x))
109
  return self.head(x)
 
110
 
111
  # --- LOAD WEIGHTS ON CPU/GPU ---
112
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
123
  ])
124
 
125
  def predict_digit(input_image):
 
126
  default_output = {str(i): 0.1 for i in range(10)}
127
 
128
  if input_image is None:
129
  return default_output
130
 
131
  try:
132
+ # Extract the background or composite layer from the Gradio Sketchpad dictionary
133
+ if isinstance(input_image, dict):
134
+ img_array = input_image.get("composite", None)
135
+ if img_array is None:
136
+ img_array = input_image.get("background", None)
137
+ else:
138
+ img_array = input_image
139
+
140
+ if img_array is None:
141
+ return default_output
142
+
143
+ # Extract channels safely
144
  if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
145
+ if img_array.shape[-1] == 4: # RGBA -> alpha channel
 
146
  grayscale = img_array[..., 3]
147
  else: # RGB -> luminance
148
  grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
149
  else:
150
  grayscale = img_array
151
 
152
+ # Check if canvas is essentially empty
153
  if np.max(grayscale) < 5:
154
  return default_output
155
 
156
+ # Ensure the background is black and the text is white (standard MNIST setup)
157
+ # If your brush was black and canvas was white, invert it here:
158
+ # grayscale = 255 - grayscale
159
+
160
+ # Resize & normalize
161
  img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
162
+ img = img.resize((28, 28), Image.Image.Resampling.BILINEAR if hasattr(Image, 'Image') else Image.BILINEAR)
163
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
164
 
165
  with torch.no_grad():
 
183
 
184
  with gr.Row():
185
  with gr.Column():
186
+ # Standardized setup for canvas sketching in modern Gradio versions
187
+ input_canvas = gr.Sketchpad(
188
+ type="numpy",
189
+ layers=False,
190
+ canvas_size=(280, 280)
 
 
 
191
  )
192
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
193
 
 
197
  submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)
198
 
199
  if __name__ == "__main__":
200
+ demo.launch()