"""Gradio demo app for the LookThem V8 MNIST digit classifier.

Downloads the model weights from Hugging Face when they are not present
locally, defines the network architecture, loads the weights, and serves a
sketchpad UI that classifies a hand-drawn digit.
"""
import os
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import gradio as gr
import numpy as np
from PIL import Image

import os
import urllib.request
import zipfile

# --- CONFIG & MODEL DOWNLOAD ---
MODEL_PATH = "LookThem_V8_MNIST.pth"
ZIP_PATH = "LookThem_V8_MNIST.zip"
HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"


def _ensure_model_weights(model_path=MODEL_PATH, zip_path=ZIP_PATH, url=HF_URL):
    """Make sure the model weights exist at ``model_path``, downloading if needed.

    The file at ``url`` is first saved to ``zip_path`` and then handled in one
    of three ways:

    * it is a zip archive containing a ``.pth`` member: the first such member
      is extracted and moved to ``model_path``;
    * it is a zip archive without a ``.pth`` member: the downloaded file itself
      is moved to ``model_path`` (the download may already be a zip-format
      checkpoint);
    * it is not a zip archive at all: the downloaded file is moved to
      ``model_path`` unchanged.

    Nothing happens when ``model_path`` already exists. The temporary download
    is removed on every exit path, including failures.
    """
    if os.path.exists(model_path):
        return

    print("Downloading model weights from Hugging Face...")
    try:
        # Download the file as a zip first
        urllib.request.urlretrieve(url, zip_path)
        print("Download complete! Checking for zip compression...")

        member = None
        extracted_path = None
        try:
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                # Look for a .pth file inside the zip
                pth_members = [n for n in zip_ref.namelist() if n.endswith(".pth")]
                if pth_members:
                    member = pth_members[0]
                    # extract() returns the real on-disk path, which can differ
                    # from the archive member name after path sanitisation.
                    extracted_path = zip_ref.extract(member, path=".")
        except zipfile.BadZipFile:
            # If it wasn't actually a zip file, just rename it
            print("File is not a zip archive. Proceeding with standard weight loading.")
            os.rename(zip_path, model_path)
            return

        # All renames happen after the archive handle is closed: renaming a
        # file that is still open fails on some platforms (e.g. Windows).
        if extracted_path is not None:
            if os.path.abspath(extracted_path) != os.path.abspath(model_path):
                os.rename(extracted_path, model_path)
            print(f"Successfully extracted: {member} -> {model_path}")
        else:
            # If no .pth inside, maybe the zip itself *is* the model
            print("No .pth file found inside zip. Renaming zip directly to .pth...")
            os.rename(zip_path, model_path)
    finally:
        # Clean up the temporary zip file if it still exists
        if os.path.exists(zip_path):
            os.remove(zip_path)


_ensure_model_weights()

# --- DEFINE YOUR MODEL ARCHITECTURE ---
class LookThemLayer(nn.Module):
    """Token-mixing layer driven by ratios of per-token scalar scores.

    Every token owns two small MLPs, "mod1" and "mod2"
    (``in_features -> hidden_dim -> 1`` with a GELU in between), each producing
    one scalar per token. For every ordered token pair ``(i, j)`` the layer
    forms ``tanh(m1_i / m2_j)`` and ``tanh(m1_j / m2_i)``, passes both through
    a per-token affine map (``trans_w`` / ``trans_b``, indexed by ``j``), and
    uses them to weight ``x_i`` and ``x_j`` respectively. The two weighted
    terms are averaged, the diagonal ``i == j`` is masked out, and the result
    is averaged over the remaining tokens ``j``.

    Args:
        num_tokens: Number of tokens ``N`` the layer is built for.
        in_features: Feature width of each token.
        hidden_dim: Hidden width of the per-token MLPs.

    Shapes:
        input  ``(batch, num_tokens, in_features)``
        output ``(batch, num_tokens, in_features)``
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        # Per-token MLP "mod1": in_features -> hidden_dim -> 1
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token MLP "mod2": same layout, supplies the ratio denominators
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token scalar affine map applied to the tanh'd ratios
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the four MLP weight tensors (``trans_w`` stays randn)."""
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
            nn.init.xavier_uniform_(w)

    def forward(self, x):
        """Mix tokens of ``x`` (``(batch, num_tokens, in_features)``); shape is preserved."""
        N = self.num_tokens
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # Keep the denominator at least 1e-6 away from zero while preserving its
        # sign. copysign is used instead of sign() * clamp() because sign(0) is
        # 0, which would turn an exact zero back into a zero denominator
        # (0/0 -> NaN). For non-zero values both forms give identical results.
        out_m2_safe = torch.copysign(torch.clamp(torch.abs(out_m2), min=1e-6), out_m2)
        # compare[b, i, j]  = tanh(m1_i / m2_j); compare2[b, i, j] = tanh(m1_j / m2_i)
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Per-token affine on the ratios; weight and bias are selected by index j.
        bias = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + bias
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + bias

        # Pairwise interaction: x_i weighted by one ratio, x_j by the other.
        interaction = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2
        # Zero out self-pairs, then average over the other tokens. The divisor
        # is clamped to 1 so a single-token layer returns zeros instead of NaN.
        mask = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
        return (interaction * mask).sum(dim=2) / max(N - 1.0, 1.0)

class LiteResidualBlock(nn.Module):
    """Residual MLP block with a trailing LayerNorm: ``norm(x + mlp(x))``.

    The MLP is ``Linear -> GELU -> Dropout -> Linear``, all at width ``dim``.

    Args:
        dim: Feature width of the input and output.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        mlp_layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*mlp_layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Add the MLP branch to ``x`` and normalise the sum."""
        branch = self.block(x)
        return self.norm(x + branch)

class LookThemV8MNIST(nn.Module):
    """Two-stream convolutional front end + LookThem token mixing + MLP head.

    A single-channel image batch is encoded by two conv streams (stride 2 and
    stride 1), each pooled to an 8x8 grid of 8-channel features, i.e. 64 tokens
    of width 8. Each stream is mixed by its own ``LookThemLayer``; the results
    are concatenated to width 16, mixed again, normalised, projected back to
    width 8, mixed once more with a skip connection, compressed to 4 channels,
    flattened, and classified into 10 logits.
    """

    @staticmethod
    def _conv_stream(stride):
        """Build one conv stream; the two streams differ only in conv stride."""
        return nn.Sequential(
            nn.Conv2d(1, 4, 3, stride, 1),
            nn.BatchNorm2d(4),
            nn.GELU(),
            nn.Conv2d(4, 8, 3, stride, 1),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

    def __init__(self):
        super().__init__()
        # Convolutional feature extractors (pooled to 8x8 regardless of input size).
        self.stream_a = self._conv_stream(2)
        self.stream_b = self._conv_stream(1)

        # Token mixers: one per stream, one on the concatenated features.
        self.lookthemA = LookThemLayer(64, 8, 32)
        self.lookthemB = LookThemLayer(64, 8, 32)
        self.lookthem_comb = LookThemLayer(64, 16, 32)
        self.comb_norm = nn.LayerNorm(16)

        # 1x1 convolutions act as per-token linear maps on the channel axis.
        self.FFN1 = nn.Conv1d(16, 8, 1)
        self.lookthem2 = LookThemLayer(64, 8, 32)
        self.FFN2 = nn.Conv1d(8, 8, 1)

        # Classifier: 64 tokens x 4 channels -> 128 -> 10 logits.
        self.compressor = nn.Conv1d(8, 4, 1)
        self.input_proj = nn.Linear(64 * 4, 128)
        self.res_blocks = nn.Sequential(LiteResidualBlock(128), LiteResidualBlock(128))
        self.head = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for an image batch ``x``."""
        batch = x.size(0)

        # (batch, 8, 8, 8) feature maps -> (batch, 64 tokens, 8 features)
        tokens_a = self.stream_a(x).view(batch, 8, 64).transpose(1, 2)
        feat_a = self.lookthemA(tokens_a)
        tokens_b = self.stream_b(x).view(batch, 8, 64).transpose(1, 2)
        feat_b = self.lookthemB(tokens_b)

        # Fuse both streams along the feature axis -> (batch, 64, 16)
        fused = torch.cat([feat_a, feat_b], dim=2)
        fused = self.comb_norm(self.lookthem_comb(fused))

        # Conv1d expects channels first, so transpose around each 1x1 conv.
        tokens = self.FFN1(fused.transpose(1, 2)).transpose(1, 2)  # (batch, 64, 8)
        skip = tokens.transpose(1, 2)                               # (batch, 8, 64)
        mixed = self.lookthem2(tokens).transpose(1, 2)
        mixed = self.FFN2(mixed) + skip

        flat = self.compressor(mixed).flatten(1)                    # (batch, 256)
        hidden = self.res_blocks(self.input_proj(flat))
        return self.head(hidden)

# --- LOAD WEIGHTS ON CPU/GPU ---
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LookThemV8MNIST()
# map_location places the checkpoint tensors on the chosen device.
_state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(_state_dict)
del _state_dict  # the parameters now live in the model; drop the extra copy

# Inference only: move to the device and switch to eval mode.
model.to(device).eval()

# --- PREPROCESSING MATCHING TRAINING PIPELINE ---
# Resize to 28x28, convert to a tensor, then normalise the single channel
# with mean 0.1307 and std 0.3081.
_preprocess_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform_fn = transforms.Compose(_preprocess_steps)

def _canvas_to_grayscale(img_array):
    """Reduce a canvas array to one intensity channel.

    * RGBA whose alpha channel varies: the alpha channel is returned, so
      strokes on a transparent background become the bright pixels.
    * RGBA with uniform alpha (e.g. a fully opaque canvas), or RGB: the
      luminance of the colour channels is returned. A uniform alpha channel
      carries no stroke information, so using it would make every drawing
      look like a blank canvas.
    * Anything else (e.g. an already 2-D array): a copy is returned.
    """
    if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
        if img_array.shape[-1] == 4:
            alpha = img_array[..., 3]
            if alpha.min() != alpha.max():
                return alpha
        return np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
    return img_array.copy()


def predict_digit(input_image):
    """Classify a sketched digit and return per-class probabilities.

    Args:
        input_image: ``None``, an image array, or a dict with ``"composite"``
            and/or ``"background"`` entries holding the image array (the
            composite is preferred).

    Returns:
        A dict mapping ``"0"``..``"9"`` to probabilities. A uniform 0.1 per
        class is returned when there is no input, the canvas is empty, or
        anything fails during preprocessing or inference (the error is
        printed, not raised).
    """
    default_output = {str(i): 0.1 for i in range(10)}

    if input_image is None:
        return default_output

    # Broad catch is deliberate: this is the UI callback boundary, and a
    # failed prediction should degrade to the uniform output, not crash the app.
    try:
        # 1. Handle Gradio Sketchpad dictionary output
        if isinstance(input_image, dict):
            img_array = input_image.get("composite")
            if img_array is None:
                img_array = input_image.get("background")
        else:
            img_array = input_image

        if img_array is None:
            return default_output

        # 2. Convert to a single intensity channel
        grayscale = _canvas_to_grayscale(img_array)

        # 3. AUTO-INVERT: ensure a bright digit on a dark background.
        # A bright average (> 127) means dark strokes on a light background.
        if np.mean(grayscale) > 127:
            grayscale = 255.0 - grayscale

        # 4. Check if the canvas is empty
        if np.max(grayscale) < 15:
            return default_output

        # Debugging print to check what the model is actually receiving
        print(f"Processed image shape: {grayscale.shape} | Max Val: {np.max(grayscale)} | Mean Val: {np.mean(grayscale):.2f}")

        # 5. Convert to PIL, resize, and transform.
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        # The image mode is inferred from the 2-D uint8 array, so no explicit
        # `mode` argument is needed (it is deprecated in some Pillow releases).
        pixels = np.clip(grayscale, 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels)
        img = img.resize((28, 28), Image.Resampling.BILINEAR)
        tensor_img = transform_fn(img).unsqueeze(0).to(device)

        # 6. Model inference
        with torch.no_grad():
            outputs = model(tensor_img)
            probabilities = F.softmax(outputs, dim=1)[0]

        return {str(i): float(probabilities[i]) for i in range(10)}

    except Exception as e:
        print(f"Prediction error: {e}")
        return default_output


# --- GRADIO INTERFACE ---
# Layout: a sketchpad with a submit button on the left, the top-3 class
# probabilities on the right.
demo = gr.Blocks()
with demo:
    # Page header.
    gr.Markdown(
        """
        # 🧠 LookThem V8 - MNIST Fraction Engine Classifier
        ### 315K Parameters | **99.53% Validation Accuracy**
        """
    )

    with gr.Row():
        with gr.Column():
            # Drawing surface; the callback receives the sketch as numpy data.
            input_canvas = gr.Sketchpad(
                canvas_size=(280, 280),
                layers=False,
                type="numpy",
            )
            submit_btn = gr.Button(value="Classify Digit 🏎️", variant="primary")

        with gr.Column():
            output_label = gr.Label(label="Top Probabilities", num_top_classes=3)

    # Run the classifier on the current sketch when the button is pressed.
    submit_btn.click(predict_digit, inputs=input_canvas, outputs=output_label)

if __name__ == "__main__":
    demo.launch()