ASomeoneWhoInterestedWithAI commited on
Commit
a605261
·
verified ·
1 Parent(s): 89c8935

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as transforms
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ # --- CONFIG & MODEL DOWNLOAD ---
12
+ MODEL_PATH = "LookThem_V8_MNIST.pth"
13
+ HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"
14
+
15
+ if not os.path.exists(MODEL_PATH):
16
+ print(f"Downloading model weights from Hugging Face...")
17
+ urllib.request.urlretrieve(HF_URL, MODEL_PATH)
18
+ print("Download complete!")
19
+
20
+ # --- DEFINE YOUR MODEL ARCHITECTURE ---
21
+ class LookThemLayer(nn.Module):
22
+ def __init__(self, num_tokens, in_features, hidden_dim):
23
+ super().__init__()
24
+ self.num_tokens = num_tokens
25
+ self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
26
+ self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
27
+ self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
28
+ self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
29
+ self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
30
+ self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
31
+ self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
32
+ self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
33
+ self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
34
+ self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
35
+
36
+ def forward(self, x):
37
+ N = self.num_tokens
38
+ h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
39
+ out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
40
+ h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
41
+ out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2
42
+
43
+ out_m2_safe = torch.sign(out_m2) * torch.clamp(torch.abs(out_m2), min=1e-6)
44
+ compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
45
+ compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))
46
+
47
+ trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + self.trans_b.view(1, 1, N, 1)
48
+ trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + self.trans_b.view(1, 1, N, 1)
49
+
50
+ interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2
51
+ mask = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
52
+ return (interaksi * mask).sum(dim=2) / (N - 1.0)
53
+
54
+ class LiteResidualBlock(nn.Module):
55
+ def __init__(self, dim, dropout=0.05):
56
+ super().__init__()
57
+ self.block = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
58
+ self.norm = nn.LayerNorm(dim)
59
+ def forward(self, x):
60
+ return self.norm(x + self.block(x))
61
+
62
+ class LookThemV8MNIST(nn.Module):
63
+ def __init__(self):
64
+ super().__init__()
65
+ self.stream_a = nn.Sequential(
66
+ nn.Conv2d(1, 4, 3, 2, 1),
67
+ nn.BatchNorm2d(4), nn.GELU(),
68
+ nn.Conv2d(4, 8, 3, 2, 1),
69
+ nn.BatchNorm2d(8), nn.GELU(),
70
+ nn.AdaptiveMaxPool2d((8, 8)))
71
+ self.stream_b = nn.Sequential(
72
+ nn.Conv2d(1, 4, 3, 1, 1),
73
+ nn.BatchNorm2d(4), nn.GELU(),
74
+ nn.Conv2d(4, 8, 3, 1, 1),
75
+ nn.BatchNorm2d(8), nn.GELU(),
76
+ nn.AdaptiveMaxPool2d((8, 8)))
77
+
78
+ self.lookthemA = LookThemLayer(64, 8, 32)
79
+ self.lookthemB = LookThemLayer(64, 8, 32)
80
+ self.lookthem_comb = LookThemLayer(64, 16, 32)
81
+ self.comb_norm = nn.LayerNorm(16)
82
+
83
+ self.FFN1 = nn.Conv1d(16, 8, 1)
84
+ self.lookthem2 = LookThemLayer(64, 8, 32)
85
+ self.FFN2 = nn.Conv1d(8, 8, 1)
86
+
87
+ self.compressor = nn.Conv1d(8, 4, 1)
88
+ self.input_proj = nn.Linear(64 * 4, 128)
89
+ self.res_blocks = nn.Sequential(LiteResidualBlock(128), LiteResidualBlock(128))
90
+ self.head = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))
91
+
92
+ def forward(self, x):
93
+ b = x.size(0)
94
+ fa = self.lookthemA(self.stream_a(x).view(b, 8, 64).transpose(1, 2))
95
+ fb = self.lookthemB(self.stream_b(x).view(b, 8, 64).transpose(1, 2))
96
+ x = self.comb_norm(self.lookthem_comb(torch.cat([fa, fb], dim=2)))
97
+ x = x.transpose(1, 2)
98
+ x = self.FFN1(x).transpose(1, 2)
99
+ res = x
100
+ x = self.lookthem2(x).transpose(1, 2)
101
+ x = x.transpose(1, 2)
102
+ x = self.FFN2(x) + res.transpose(1, 2)
103
+ x = self.compressor(x).flatten(1)
104
+ x = self.res_blocks(self.input_proj(x))
105
+ return self.head(x)
106
+
107
+ # --- LOAD WEIGHTS ON CPU/GPU ---
108
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
109
+ model = LookThemV8MNIST()
110
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
111
+ model.to(device)
112
+ model.eval()
113
+
114
+ # --- PREPROCESSING MATCHING TRAINING PIPELINE ---
115
+ # Using the exact MNIST normalization values from your training code
116
+ transform_fn = transforms.Compose([
117
+ transforms.Resize((28, 28)),
118
+ transforms.ToTensor(),
119
+ transforms.Normalize((0.1307,), (0.3081,))
120
+ ])
121
+
122
+ def predict_digit(input_image):
123
+ if input_image is None:
124
+ return "Please draw a number!"
125
+
126
+ # Process image background depending on Gradio Sketchpad structure (composite dictionary)
127
+ if isinstance(input_image, dict) and "composite" in input_image:
128
+ img = input_image["composite"]
129
+ else:
130
+ img = input_image
131
+
132
+ # Convert to grayscale
133
+ img = Image.fromarray(img.astype('uint8')).convert('L')
134
+
135
+ # Apply matching transformations
136
+ tensor_img = transform_fn(img).unsqueeze(0).to(device)
137
+
138
+ with torch.no_grad():
139
+ outputs = model(tensor_img)
140
+ probabilities = F.softmax(outputs, dim=1)[0]
141
+
142
+ # Format top class probabilities for Gradio output
143
+ return {str(i): float(probabilities[i]) for i in range(10)}
144
+
145
+ # --- GRADIO INTERFACE CONSTRUCTION ---
146
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
147
+ gr.Markdown(
148
+ """
149
+ # 🧠 LookThem V8 - MNIST Fraction Engine Classifier
150
+ ### Built by a 13-year-old developer | 315K Parameters | **99.53% Validation Accuracy**
151
+
152
+ Draw a single digit (0-9) in the sketchpad below to see how the fractional token gating engine analyzes structural patterns!
153
+ """
154
+ )
155
+
156
+ with gr.Row():
157
+ with gr.Column():
158
+ # Create a 280x280 canvas for white drawing on black canvas background
159
+ input_canvas = gr.Sketchpad(
160
+ crop_size=(280, 280),
161
+ type="numpy",
162
+ label="Draw Digit Here",
163
+ layers=False,
164
+ canvas_size=(280, 280)
165
+ )
166
+ submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
167
+ clear_btn = gr.Button("Clear Canvas")
168
+
169
+ with gr.Column():
170
+ output_label = gr.Label(num_top_classes=3, label="Top Probabilities")
171
+
172
+ # Hook up action events
173
+ submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)
174
+ clear_btn.click(fn=lambda: None, outputs=input_canvas)
175
+
176
+ if __name__ == "__main__":
177
+ demo.launch()