ASomeoneWhoInterestedWithAI committed on
Commit
2ab4f6e
·
verified ·
1 Parent(s): 307755a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -120
app.py CHANGED
@@ -17,96 +17,8 @@ if not os.path.exists(MODEL_PATH):
17
  urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
  print("Download complete!")
19
 
20
- # --- DEFINE YOUR MODEL ARCHITECTURE (sama seperti sebelumnya) ---
21
- class LookThemLayer(nn.Module):
22
- def __init__(self, num_tokens, in_features, hidden_dim):
23
- super().__init__()
24
- self.num_tokens = num_tokens
25
- self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
26
- self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
27
- self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
28
- self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
29
- self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
30
- self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
31
- self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
32
- self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
33
- self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
34
- self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
35
- self._init_weights()
36
-
37
- def _init_weights(self):
38
- for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
39
- nn.init.xavier_uniform_(w)
40
-
41
- def forward(self, x):
42
- N = self.num_tokens
43
- h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
44
- out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
45
- h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
46
- out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2
47
-
48
- out_m2_safe = torch.sign(out_m2) * torch.clamp(torch.abs(out_m2), min=1e-6)
49
- compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
50
- compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))
51
-
52
- trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + self.trans_b.view(1, 1, N, 1)
53
- trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + self.trans_b.view(1, 1, N, 1)
54
-
55
- interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2
56
- mask = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
57
- return (interaksi * mask).sum(dim=2) / (N - 1.0)
58
-
59
- class LiteResidualBlock(nn.Module):
60
- def __init__(self, dim, dropout=0.05):
61
- super().__init__()
62
- self.block = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
63
- self.norm = nn.LayerNorm(dim)
64
- def forward(self, x):
65
- return self.norm(x + self.block(x))
66
-
67
- class LookThemV8MNIST(nn.Module):
68
- def __init__(self):
69
- super().__init__()
70
- self.stream_a = nn.Sequential(
71
- nn.Conv2d(1, 4, 3, 2, 1),
72
- nn.BatchNorm2d(4), nn.GELU(),
73
- nn.Conv2d(4, 8, 3, 2, 1),
74
- nn.BatchNorm2d(8), nn.GELU(),
75
- nn.AdaptiveMaxPool2d((8, 8)))
76
- self.stream_b = nn.Sequential(
77
- nn.Conv2d(1, 4, 3, 1, 1),
78
- nn.BatchNorm2d(4), nn.GELU(),
79
- nn.Conv2d(4, 8, 3, 1, 1),
80
- nn.BatchNorm2d(8), nn.GELU(),
81
- nn.AdaptiveMaxPool2d((8, 8)))
82
-
83
- self.lookthemA = LookThemLayer(64, 8, 32)
84
- self.lookthemB = LookThemLayer(64, 8, 32)
85
- self.lookthem_comb = LookThemLayer(64, 16, 32)
86
- self.comb_norm = nn.LayerNorm(16)
87
-
88
- self.FFN1 = nn.Conv1d(16, 8, 1)
89
- self.lookthem2 = LookThemLayer(64, 8, 32)
90
- self.FFN2 = nn.Conv1d(8, 8, 1)
91
-
92
- self.compressor = nn.Conv1d(8, 4, 1)
93
- self.input_proj = nn.Linear(64 * 4, 128)
94
- self.res_blocks = nn.Sequential(LiteResidualBlock(128), LiteResidualBlock(128))
95
- self.head = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))
96
-
97
- def forward(self, x):
98
- b = x.size(0)
99
- fa = self.lookthemA(self.stream_a(x).view(b, 8, 64).transpose(1, 2))
100
- fb = self.lookthemB(self.stream_b(x).view(b, 8, 64).transpose(1, 2))
101
- x = self.comb_norm(self.lookthem_comb(torch.cat([fa, fb], dim=2)))
102
- x = x.transpose(1, 2)
103
- x = self.FFN1(x).transpose(1, 2)
104
- res = x
105
- x = self.lookthem2(x).transpose(1, 2)
106
- x = self.FFN2(x) + res.transpose(1, 2)
107
- x = self.compressor(x).flatten(1)
108
- x = self.res_blocks(self.input_proj(x))
109
- return self.head(x)
110
 
111
  # --- LOAD WEIGHTS ON CPU/GPU ---
112
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -123,40 +35,28 @@ transform_fn = transforms.Compose([
123
  ])
124
 
125
  def predict_digit(input_image):
126
- # Selalu kembalikan dictionary 10 digit untuk gr.Label
127
  default_output = {str(i): 0.1 for i in range(10)}
128
 
129
  if input_image is None:
130
  return default_output
131
 
132
  try:
133
- # Tangani berbagai format input (dict dari Paint, array dari Sketchpad, dll.)
134
- if isinstance(input_image, dict):
135
- # gr.Paint versi lama -> ambil composite atau layer pertama
136
- img_array = input_image.get("composite")
137
- if img_array is None and "layers" in input_image:
138
- layers = input_image["layers"]
139
- img_array = layers[0] if layers else None
140
- if img_array is None:
141
- return default_output
142
- else:
143
- img_array = input_image
144
-
145
- # Konversi ke numpy array jika belum
146
- if not isinstance(img_array, np.ndarray):
147
- img_array = np.array(img_array)
148
 
149
- # Jika gambar berwarna, ambil channel yang tepat
150
- if img_array.ndim == 3:
151
- if img_array.shape[-1] == 4: # RGBA alpha
 
152
  grayscale = img_array[..., 3]
153
- else: # RGB luminance
154
  grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
155
  else:
156
  grayscale = img_array
157
-
158
- # Cek kanvas kosong
159
- if grayscale.max() == 0:
160
  return default_output
161
 
162
  # Resize & normalisasi
@@ -166,12 +66,11 @@ def predict_digit(input_image):
166
 
167
  with torch.no_grad():
168
  outputs = model(tensor_img)
169
- probs = F.softmax(outputs, dim=1)[0]
170
-
171
- return {str(i): float(probs[i]) for i in range(10)}
172
 
173
  except Exception as e:
174
- # Kembalikan uniform jika terjadi error tak terduga
175
  print(f"Prediction error: {e}")
176
  return default_output
177
 
@@ -186,12 +85,14 @@ with gr.Blocks() as demo:
186
 
187
  with gr.Row():
188
  with gr.Column():
189
- # GANTI: gunakan Sketchpad agar latar hitam + pena putih
190
- input_canvas = gr.Sketchpad(
191
  image_mode="L",
192
  height=280,
193
  width=280,
194
- brush=gr.Brush(default_color="rgb(255,255,255)", color_mode="fixed")
 
 
195
  )
196
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
197
 
 
17
  urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
  print("Download complete!")
19
 
20
+ # --- DEFINE YOUR MODEL ARCHITECTURE (TETAP SAMA) ---
21
+ # ... (Salin definisi kelas LookThemLayer, LiteResidualBlock, dan LookThemV8MNIST Anda di sini) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # --- LOAD WEIGHTS ON CPU/GPU ---
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
35
  ])
36
 
37
  def predict_digit(input_image):
38
+ # Default output jika kanvas kosong
39
  default_output = {str(i): 0.1 for i in range(10)}
40
 
41
  if input_image is None:
42
  return default_output
43
 
44
  try:
45
+ # gr.Image(source="canvas") mengembalikan numpy array secara langsung
46
+ img_array = input_image
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ # Konversi ke grayscale jika perlu (hasil kanvas biasanya sudah grayscale)
49
+ if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
50
+ # Ambil channel pertama jika multichannel, atau konversi ke luminance
51
+ if img_array.shape[-1] == 4: # RGBA -> alpha
52
  grayscale = img_array[..., 3]
53
+ else: # RGB -> luminance
54
  grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
55
  else:
56
  grayscale = img_array
57
+
58
+ # Cek apakah kanvas kosong (semua piksel bernilai 0 atau mendekati)
59
+ if np.max(grayscale) < 5:
60
  return default_output
61
 
62
  # Resize & normalisasi
 
66
 
67
  with torch.no_grad():
68
  outputs = model(tensor_img)
69
+ probabilities = F.softmax(outputs, dim=1)[0]
70
+
71
+ return {str(i): float(probabilities[i]) for i in range(10)}
72
 
73
  except Exception as e:
 
74
  print(f"Prediction error: {e}")
75
  return default_output
76
 
 
85
 
86
  with gr.Row():
87
  with gr.Column():
88
+ # Gunakan gr.Image dengan source="canvas"
89
+ input_canvas = gr.Image(
90
  image_mode="L",
91
  height=280,
92
  width=280,
93
+ source="canvas", # Mengaktifkan mode kanvas untuk menggambar
94
+ invert_colors=True, # Membalik warna: latar hitam, coretan putih
95
+ brush=gr.Brush(default_color="rgb(0,0,0)", color_mode="fixed") # Kuas hitam (akan dibalik jadi putih)
96
  )
97
  submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
98