ASomeoneWhoInterestedWithAI committed on
Commit
9cbe616
·
verified ·
1 Parent(s): e9455f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -126
app.py CHANGED
@@ -18,96 +18,8 @@ if not os.path.exists(MODEL_PATH):
18
  print("Download complete!")
19
 
20
  # --- DEFINE YOUR MODEL ARCHITECTURE ---
21
-
22
class LookThemLayer(nn.Module):
    """Pairwise token-comparison layer with per-token parameters.

    Each token owns two small two-layer MLPs (``mod1_*`` and ``mod2_*``) that
    each reduce the token's features to a single scalar.  For every ordered
    pair of tokens ``(i, j)`` the layer computes ``tanh(m1[i] / m2[j])`` and
    ``tanh(m1[j] / m2[i])``, applies the affine map of token ``j``
    (``trans_w[j]``, ``trans_b[j]``) to both, and uses the results to gate
    ``x[i]`` and ``x[j]`` respectively.  The gated features are averaged and
    then averaged again over all tokens ``j != i``.

    Args:
        num_tokens: Number of tokens; every parameter has a leading
            ``num_tokens`` axis.
        in_features: Feature size of each token.
        hidden_dim: Hidden width of the per-token MLPs.

    Shape:
        Input ``(batch, num_tokens, in_features)``; the output has the same
        shape.
    """

    # Smallest magnitude allowed for the denominator of the ratio terms.
    _EPS = 1e-6

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        # Parameter names and creation order are kept as-is so existing
        # checkpoints (state_dict keys) keep loading.
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the four MLP weight tensors with Xavier-uniform."""
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
            nn.init.xavier_uniform_(w)

    def forward(self, x):
        """Return the pairwise-gated features for ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, in_features)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        N = self.num_tokens

        # Per-token MLPs: (b, t, in) -> (b, t, hidden) -> (b, t, 1).
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # torch.sign(0) is 0, which would cancel the clamp and leave a zero
        # denominator (NaN after the division).  Treat an exact zero as
        # positive so the magnitude floor always applies.
        sign = torch.where(out_m2 < 0, -torch.ones_like(out_m2), torch.ones_like(out_m2))
        out_m2_safe = sign * torch.clamp(torch.abs(out_m2), min=self._EPS)

        # compare[b, i, j]  = tanh(m1[i] / m2[j])
        # compare2[b, i, j] = tanh(m1[j] / m2[i])
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Affine map selected by the second token index j.
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + self.trans_b.view(1, 1, N, 1)
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + self.trans_b.view(1, 1, N, 1)

        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # Exclude each token's interaction with itself, then average over the
        # remaining tokens.  With a single token there are no partners, so
        # the divisor is floored at 1 to return zeros instead of 0/0.
        mask = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
        return (interaksi * mask).sum(dim=2) / max(N - 1.0, 1.0)
59
-
60
class LiteResidualBlock(nn.Module):
    """Residual MLP block followed by layer normalisation.

    Computes ``LayerNorm(x + MLP(x))`` where the MLP is
    ``Linear -> GELU -> Dropout -> Linear`` and keeps the feature size at
    ``dim`` throughout.

    Args:
        dim: Size of the last (feature) axis of the input.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        # Same layers, same order: the Sequential indices are part of the
        # checkpoint key names (block.0.*, block.3.*).
        mlp_layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*mlp_layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the normalised sum of ``x`` and its MLP transform."""
        summed = x + self.block(x)
        return self.norm(summed)
67
-
68
class LookThemV8MNIST(nn.Module):
    """Two-stream convolutional front end + LookThem layers + MLP head.

    Two convolutional streams (stride 2 and stride 1) each turn a 1-channel
    image into an 8-channel 8x8 map, which is viewed as 64 tokens of 8
    features.  Each stream goes through its own ``LookThemLayer``; the two
    results are concatenated to 16 features per token, mixed by another
    ``LookThemLayer``, projected back to 8 features, refined once more with a
    residual connection, compressed to 4 features per token, flattened, and
    classified into 10 output scores.
    """

    def __init__(self):
        super().__init__()
        # Submodule names and construction order are unchanged so that
        # state_dict keys and default initialisation stay the same.
        self.stream_a = self._conv_stream(stride=2)
        self.stream_b = self._conv_stream(stride=1)

        self.lookthemA = LookThemLayer(64, 8, 32)
        self.lookthemB = LookThemLayer(64, 8, 32)
        self.lookthem_comb = LookThemLayer(64, 16, 32)
        self.comb_norm = nn.LayerNorm(16)

        self.FFN1 = nn.Conv1d(16, 8, 1)
        self.lookthem2 = LookThemLayer(64, 8, 32)
        self.FFN2 = nn.Conv1d(8, 8, 1)

        self.compressor = nn.Conv1d(8, 4, 1)
        self.input_proj = nn.Linear(64 * 4, 128)
        self.res_blocks = nn.Sequential(LiteResidualBlock(128), LiteResidualBlock(128))
        self.head = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))

    @staticmethod
    def _conv_stream(stride):
        """Build one conv stream: two 3x3 conv/BN/GELU stages pooled to 8x8."""
        return nn.Sequential(
            nn.Conv2d(1, 4, 3, stride, 1),
            nn.BatchNorm2d(4),
            nn.GELU(),
            nn.Conv2d(4, 8, 3, stride, 1),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

    def forward(self, x):
        """Return a ``(batch, 10)`` tensor of class scores for images ``x``."""
        b = x.size(0)

        # (b, 8, 8, 8) feature maps -> (b, 64 tokens, 8 features) per stream.
        tokens_a = self.stream_a(x).view(b, 8, 64).transpose(1, 2)
        fa = self.lookthemA(tokens_a)
        tokens_b = self.stream_b(x).view(b, 8, 64).transpose(1, 2)
        fb = self.lookthemB(tokens_b)

        # Fuse both streams: (b, 64, 16) -> LookThem -> LayerNorm.
        fused = self.comb_norm(self.lookthem_comb(torch.cat([fa, fb], dim=2)))

        # 1x1 convs work channel-first, LookThem layers token-first, hence
        # the transposes around each of them.
        reduced = self.FFN1(fused.transpose(1, 2)).transpose(1, 2)
        refined = self.lookthem2(reduced).transpose(1, 2)
        refined = self.FFN2(refined) + reduced.transpose(1, 2)

        flat = self.compressor(refined).flatten(1)
        hidden = self.res_blocks(self.input_proj(flat))
        return self.head(hidden)
111
 
112
  # --- LOAD WEIGHTS ON CPU/GPU ---
113
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -117,57 +29,39 @@ model.to(device)
117
  model.eval()
118
 
119
  # --- PREPROCESSING MATCHING TRAINING PIPELINE ---
120
- # Using the exact MNIST normalization values from your training code
121
# Preprocessing applied to every drawing before inference: resize to 28x28,
# convert to a tensor, then normalise with mean 0.1307 and std 0.3081 (the
# values the training pipeline used, per the section header above).
_preprocessing_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
]
transform_fn = transforms.Compose(_preprocessing_steps)
 
126
def _to_grayscale_array(payload):
    """Reduce the drawing component's payload to a grayscale array.

    Args:
        payload: A dict with "composite" and/or "layers" entries, a numpy
            array, or an object exposing ``convert("L")`` (assumed to be a
            PIL image).

    Returns:
        A numpy array of pixel intensities, or ``None`` when a dict payload
        carries neither a composite nor any layer.
    """
    if isinstance(payload, dict):
        # Look the keys up lazily.  Passing payload["layers"][0] as the
        # default of .get() evaluated it even when "composite" existed and
        # failed on payloads with a missing or empty "layers" entry.
        img_array = payload.get("composite")
        if img_array is None:
            layers = payload.get("layers")
            if layers is None or len(layers) == 0:
                return None
            img_array = layers[0]
    else:
        img_array = payload

    if not isinstance(img_array, np.ndarray):
        # Not an array: assumed to be a PIL image.
        return np.array(img_array.convert("L"))
    if img_array.ndim != 3:
        return img_array
    if img_array.shape[-1] == 4:
        # RGBA: use the alpha channel as the stroke intensity.
        return img_array[..., 3]
    # RGB: luminance.
    return np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])


def predict_digit(input_image):
    """Classify a hand-drawn digit coming from the Gradio canvas.

    Args:
        input_image: Canvas payload (dict, numpy array, or PIL image), or
            ``None`` when nothing was submitted.

    Returns:
        * the string ``"Please draw a number!"`` when there is no image data;
        * a uniform ``{"0": 0.1, ..., "9": 0.1}`` dict for a blank canvas;
        * a ``{digit: probability}`` dict for a real drawing;
        * ``{"error": message}`` if anything raises while predicting.
          NOTE(review): this dict maps to a string, not a float - confirm
          the output component accepts it.
    """
    if input_image is None:
        return "Please draw a number!"

    try:
        grayscale = _to_grayscale_array(input_image)
        if grayscale is None:
            return "Please draw a number!"

        # Blank canvas (no pixels, or every pixel is 0): uniform answer.
        if grayscale.size == 0 or grayscale.max() == 0:
            return {str(i): 0.1 for i in range(10)}

        # Resize and normalise.
        img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
        img = img.resize((28, 28), Image.Resampling.BILINEAR)
        tensor_img = transform_fn(img).unsqueeze(0).to(device)

        # Predict.
        with torch.no_grad():
            outputs = model(tensor_img)
            probabilities = F.softmax(outputs, dim=1)[0]

        return {str(i): float(probabilities[i]) for i in range(10)}

    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        return {"error": str(e)}
170
-
171
 
172
  # --- GRADIO INTERFACE CONSTRUCTION ---
173
  with gr.Blocks() as demo:
@@ -180,12 +74,15 @@ with gr.Blocks() as demo:
180
 
181
  with gr.Row():
182
  with gr.Column():
183
- input_canvas = gr.Paint(
 
184
  image_mode="L",
185
  height=280,
186
  width=280,
187
- canvas_color="black", # ⬅️ ini yang wajib ditambahkan
188
- brush=gr.components.image_editor.Brush(default_color="rgb(255, 255, 255)")
 
 
189
  )
190
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
191
 
@@ -194,6 +91,5 @@ with gr.Blocks() as demo:
194
 
195
  submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)
196
 
197
-
198
  if __name__ == "__main__":
199
- demo.launch(theme=gr.themes.Soft())
 
18
  print("Download complete!")
19
 
20
  # --- DEFINE YOUR MODEL ARCHITECTURE ---
21
+ # (The LookThemLayer, LiteResidualBlock, and LookThemV8MNIST class definitions stay the same)
22
+ # ... (Copy your model definitions here) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # --- LOAD WEIGHTS ON CPU/GPU ---
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
29
  model.eval()
30
 
31
  # --- PREPROCESSING MATCHING TRAINING PIPELINE ---
 
32
# Preprocessing applied to every drawing before inference: resize to 28x28,
# convert to a tensor, then normalise with mean 0.1307 and std 0.3081 (the
# values the training pipeline used, per the section header above).
_preprocessing_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
]
transform_fn = transforms.Compose(_preprocessing_steps)
37
+
38
def _to_grayscale_array(payload):
    """Reduce the sketchpad payload to a grayscale array.

    Args:
        payload: A numpy array, a dict with "composite" and/or "layers"
            entries, or an object exposing ``convert("L")`` (assumed to be a
            PIL image).

    Returns:
        A numpy array of pixel intensities, or ``None`` when a dict payload
        carries neither a composite nor any layer.
    """
    if isinstance(payload, dict):
        # NOTE(review): presumably the component can deliver an editor dict
        # (background / layers / composite) rather than a bare array, as the
        # earlier version of this function handled - confirm against the
        # installed Gradio release.
        img_array = payload.get("composite")
        if img_array is None:
            layers = payload.get("layers")
            if layers is None or len(layers) == 0:
                return None
            img_array = layers[0]
    else:
        img_array = payload

    if not isinstance(img_array, np.ndarray):
        # Not an array: assumed to be a PIL image.
        return np.array(img_array.convert("L"))
    if img_array.ndim != 3:
        return img_array
    if img_array.shape[-1] == 4:
        # RGBA: use the alpha channel as the stroke intensity.
        return img_array[..., 3]
    # RGB: luminance.
    return np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])


def predict_digit(input_image):
    """Classify a hand-drawn digit coming from the sketchpad.

    Args:
        input_image: Sketchpad payload (numpy array, editor dict, or PIL
            image), or ``None`` when nothing was submitted.

    Returns:
        * the string ``"Please draw a number!"`` when there is no image data;
        * a uniform ``{"0": 0.1, ..., "9": 0.1}`` dict for a blank canvas;
        * a ``{digit: probability}`` dict for a real drawing;
        * ``{"Error": message}`` if anything raises while predicting.
          NOTE(review): this dict maps to a string, not a float - confirm
          the output component accepts it.
    """
    if input_image is None:
        return "Please draw a number!"

    try:
        # Accept dict payloads and multi-channel arrays as well as a bare
        # 2-D array; previously anything else ended up in the except branch.
        img_array = _to_grayscale_array(input_image)
        if img_array is None:
            return "Please draw a number!"

        # Blank canvas (no pixels, or every pixel is 0): uniform answer.
        if img_array.size == 0 or np.max(img_array) == 0:
            return {str(i): 0.1 for i in range(10)}

        # Convert to a PIL image and resize.
        img = Image.fromarray(img_array.astype(np.uint8), mode="L")
        img = img.resize((28, 28), Image.Resampling.BILINEAR)

        # Transform and predict.
        tensor_img = transform_fn(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(tensor_img)
            probabilities = F.softmax(outputs, dim=1)[0]

        return {str(i): float(probabilities[i]) for i in range(10)}

    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        return {"Error": str(e)}
 
 
65
 
66
  # --- GRADIO INTERFACE CONSTRUCTION ---
67
  with gr.Blocks() as demo:
 
74
 
75
  with gr.Row():
76
  with gr.Column():
77
+ # Use gr.Sketchpad
78
+ input_canvas = gr.Sketchpad(
79
  image_mode="L",
80
  height=280,
81
  width=280,
82
+ brush=gr.Brush(
83
+ default_color="rgb(255, 255, 255)", # Kuas putih
84
+ color_mode="fixed"
85
+ )
86
  )
87
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
88
 
 
91
 
92
  submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)
93
 
 
94
  if __name__ == "__main__":
95
+ demo.launch()