| import os |
| import urllib.request |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
|
|
| import os |
| import urllib.request |
| import zipfile |
|
|
| |
# Local file the model weights are loaded from.
MODEL_PATH = "LookThem_V8_MNIST.pth"
# Temporary name for the raw download until its container format is known.
ZIP_PATH = "LookThem_V8_MNIST.zip"
HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"


def _install_weights(download_path, model_path):
    """Turn the downloaded file into ``model_path``.

    Three cases are handled:

    * the download is a zip archive containing a ``.pth`` member: the first
      such member is extracted and moved to ``model_path``;
    * the download is a zip archive without a ``.pth`` member: the download
      itself is renamed to ``model_path``;
    * the download is not a zip archive: it is renamed to ``model_path``.

    Args:
        download_path: Path of the freshly downloaded file.
        model_path: Destination path for the usable weights file.
    """
    member = None
    extracted_path = None
    try:
        with zipfile.ZipFile(download_path, "r") as archive:
            pth_members = [n for n in archive.namelist() if n.endswith(".pth")]
            if pth_members:
                member = pth_members[0]
                # extract() returns the path it actually wrote, which can
                # differ from the raw member name after path sanitising.
                extracted_path = archive.extract(member, path=".")
    except zipfile.BadZipFile:
        print("File is not a zip archive. Proceeding with standard weight loading.")
        os.rename(download_path, model_path)
        return

    # Every rename below happens after the archive handle is closed; renaming
    # a file that is still open fails on Windows.
    if member is None:
        # NOTE(review): checkpoints written by recent torch.save versions are
        # presumably zip containers themselves, which would make this the
        # usual path for a plain .pth download -- confirm.
        print("No .pth file found inside zip. Renaming zip directly to .pth...")
        os.rename(download_path, model_path)
        return

    if os.path.abspath(extracted_path) != os.path.abspath(model_path):
        os.rename(extracted_path, model_path)
    print(f"Successfully extracted: {member} -> {model_path}")


if not os.path.exists(MODEL_PATH):
    print("Downloading model weights from Hugging Face...")
    try:
        urllib.request.urlretrieve(HF_URL, ZIP_PATH)
        print("Download complete! Checking for zip compression...")
        _install_weights(ZIP_PATH, MODEL_PATH)
    finally:
        # Remove the temporary download even when the download or the
        # extraction failed, so a retry starts from a clean state.
        if os.path.exists(ZIP_PATH):
            os.remove(ZIP_PATH)
|
|
| |
class LookThemLayer(nn.Module):
    """Pairwise ratio ("fraction") interaction layer over a fixed token set.

    Every token owns two small MLPs (``mod1_*`` and ``mod2_*``) that each map
    its feature vector to one scalar score. For each ordered token pair the
    ratio of a ``mod1`` score to a ``mod2`` score is squashed with ``tanh``,
    passed through a per-token affine map (``trans_w`` / ``trans_b``) and used
    to scale the features of the two tokens involved. The result is averaged
    over all *other* tokens.

    Args:
        num_tokens: Number of tokens ``T`` along dim 1 of the input.
        in_features: Feature size ``I`` along dim 2 of the input.
        hidden_dim: Hidden width of the per-token scoring MLPs.
    """

    # Smallest magnitude allowed for a ratio denominator.
    _EPS = 1e-6

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the four MLP weight tensors with Xavier-uniform.

        ``trans_w`` keeps its standard-normal initialisation and all biases
        stay at zero.
        """
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
            nn.init.xavier_uniform_(w)

    def forward(self, x):
        """Apply the pairwise interaction.

        Args:
            x: Tensor of shape ``(B, T, I)``.

        Returns:
            Tensor of shape ``(B, T, I)``.
        """
        n = self.num_tokens

        # Per-token scalar scores, each of shape (B, T, 1).
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # Keep the denominator at least _EPS away from zero. torch.sign()
        # returns 0 for an exactly-zero score, which would zero the whole
        # denominator again, so zero is treated as positive here.
        unit = torch.ones_like(out_m2)
        sign = torch.where(out_m2 < 0, -unit, unit)
        out_m2_safe = sign * out_m2.abs().clamp(min=self._EPS)

        # compare[b, i, j]  = tanh(m1[i] / m2[j])
        # compare2[b, i, j] = tanh(m1[j] / m2[i])
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Per-token affine map, indexed by the second token j.
        bias = self.trans_b.view(1, 1, n, 1)
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + bias
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + bias

        # (B, T, T, I): scale token i's and token j's features, then average.
        interaction = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # Drop self-pairs (i == j). The mask follows the input dtype so that
        # half/double inputs are not silently promoted to float32.
        mask = (1.0 - torch.eye(n, device=x.device, dtype=x.dtype)).view(1, n, n, 1)

        # Mean over the other tokens; with a single token there are none, so
        # the divisor is floored at 1 and the output is zero instead of NaN.
        return (interaction * mask).sum(dim=2) / max(n - 1, 1)
|
|
class LiteResidualBlock(nn.Module):
    """Residual MLP block with post-normalisation: ``LayerNorm(x + MLP(x))``.

    Args:
        dim: Width of the input, of the hidden layer and of the output.
        dropout: Dropout probability applied after the GELU activation.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        # Linear -> GELU -> Dropout -> Linear, all at the same width.
        mlp_layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*mlp_layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the layer-normalised sum of ``x`` and its MLP transform."""
        residual_sum = x + self.block(x)
        return self.norm(residual_sum)
|
|
class LookThemV8MNIST(nn.Module):
    """Two-stream CNN + LookThem classifier that outputs 10 class logits.

    Two small convolutional streams (one strided, one not) each reduce a
    single-channel image to an 8-channel 8x8 map, which is read as 64 tokens
    of 8 features. Each token set passes through its own ``LookThemLayer``;
    the results are concatenated, mixed by further LookThem / 1x1-conv
    stages, compressed, and classified by a small residual MLP head.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept unchanged so that
        # state-dict keys and seeded initialisation stay the same.
        self.stream_a = self._conv_stream(stride=2)
        self.stream_b = self._conv_stream(stride=1)

        self.lookthemA = LookThemLayer(64, 8, 32)
        self.lookthemB = LookThemLayer(64, 8, 32)
        self.lookthem_comb = LookThemLayer(64, 16, 32)
        self.comb_norm = nn.LayerNorm(16)

        self.FFN1 = nn.Conv1d(16, 8, 1)
        self.lookthem2 = LookThemLayer(64, 8, 32)
        self.FFN2 = nn.Conv1d(8, 8, 1)

        self.compressor = nn.Conv1d(8, 4, 1)
        self.input_proj = nn.Linear(64 * 4, 128)
        self.res_blocks = nn.Sequential(
            LiteResidualBlock(128),
            LiteResidualBlock(128),
        )
        self.head = nn.Sequential(
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Linear(128, 10),
        )

    @staticmethod
    def _conv_stream(stride):
        """Build one conv stream: two 3x3 conv/BN/GELU stages, then 8x8 max-pool."""
        return nn.Sequential(
            nn.Conv2d(1, 4, 3, stride, 1),
            nn.BatchNorm2d(4),
            nn.GELU(),
            nn.Conv2d(4, 8, 3, stride, 1),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

    def forward(self, x):
        """Classify a batch of single-channel images.

        Args:
            x: Tensor of shape ``(B, 1, H, W)``.

        Returns:
            Tensor of shape ``(B, 10)`` holding unnormalised class scores.
        """
        batch = x.size(0)

        # (B, 8, 8, 8) feature maps -> (B, 64 tokens, 8 features).
        tokens_a = self.stream_a(x).view(batch, 8, 64).transpose(1, 2)
        feat_a = self.lookthemA(tokens_a)
        tokens_b = self.stream_b(x).view(batch, 8, 64).transpose(1, 2)
        feat_b = self.lookthemB(tokens_b)

        # Fuse both streams along the feature axis: (B, 64, 16).
        merged = torch.cat([feat_a, feat_b], dim=2)
        merged = self.comb_norm(self.lookthem_comb(merged))

        # The 1x1 convs expect channels first, hence the transposes.
        reduced = self.FFN1(merged.transpose(1, 2)).transpose(1, 2)  # (B, 64, 8)
        mixed = self.lookthem2(reduced).transpose(1, 2)              # (B, 8, 64)
        mixed = self.FFN2(mixed) + reduced.transpose(1, 2)           # residual

        flat = self.compressor(mixed).flatten(1)                     # (B, 256)
        hidden = self.res_blocks(self.input_proj(flat))
        return self.head(hidden)
|
|
| |
# Run inference on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network, restore the weights and switch to inference mode.
# weights_only=True restricts unpickling to tensors and plain containers.
model = LookThemV8MNIST()
_checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(_checkpoint)
model.to(device)
model.eval()

# PIL image -> normalised 1x28x28 float tensor.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set mean and
# standard deviation -- confirm against the training pipeline.
_preprocess_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform_fn = transforms.Compose(_preprocess_steps)
|
|
# Weights used to collapse RGB to a single luminance channel.
_LUMA_WEIGHTS = (0.2989, 0.5870, 0.1140)


def _extract_grayscale(img_array):
    """Collapse a canvas array to one float channel on a 0-255 scale.

    * 2-D input is used as-is.
    * For input with an alpha plane (2 or 4 channels) the alpha plane is used
      as the stroke mask when it actually varies. A constant alpha plane
      carries no stroke information, so the colour channels are used instead.
    * Colour channels are reduced with the usual luminance weights; a single
      channel is taken directly.
    """
    arr = np.asarray(img_array)
    if arr.ndim != 3:
        return arr.astype(np.float64)

    channels = arr.shape[-1]
    if channels in (2, 4):
        alpha = arr[..., -1]
        if alpha.min() != alpha.max():
            return alpha.astype(np.float64)
        # Uniform alpha (e.g. a fully opaque image): fall back to the colours.
        arr = arr[..., :-1]
        channels -= 1

    if channels == 1:
        return arr[..., 0].astype(np.float64)
    return np.dot(arr[..., :3], _LUMA_WEIGHTS)


def predict_digit(input_image):
    """Classify a hand-drawn digit and return per-class probabilities.

    Args:
        input_image: Either a numpy image array, or a dict holding such an
            array under ``"composite"`` (preferred) or ``"background"``.
            ``None`` is accepted.
            NOTE(review): the dict form is presumably what the Gradio
            sketchpad delivers -- confirm against the installed version.

    Returns:
        Dict mapping ``"0"``..``"9"`` to a float probability. A uniform
        ``0.1`` for every class is returned when there is no input, the
        canvas is effectively blank, or prediction fails.
    """
    default_output = {str(i): 0.1 for i in range(10)}

    if input_image is None:
        return default_output

    try:
        if isinstance(input_image, dict):
            img_array = input_image.get("composite")
            if img_array is None:
                img_array = input_image.get("background")
        else:
            img_array = input_image

        if img_array is None:
            return default_output

        grayscale = _extract_grayscale(img_array)

        # The model expects a bright digit on a dark background; invert
        # images that are mostly bright.
        if np.mean(grayscale) > 127:
            grayscale = 255.0 - grayscale

        # Nothing (or almost nothing) drawn.
        if np.max(grayscale) < 15:
            return default_output

        print(f"Processed image shape: {grayscale.shape} | Max Val: {np.max(grayscale)} | Mean Val: {np.mean(grayscale):.2f}")

        # Clip before the cast so out-of-range values cannot wrap around.
        # A 2-D uint8 array is read as mode "L" without passing mode=.
        pixels = np.clip(grayscale, 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels)
        img = img.resize((28, 28), Image.Resampling.BILINEAR)
        tensor_img = transform_fn(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(tensor_img)
            probabilities = F.softmax(outputs, dim=1)[0]

        return {str(i): float(probabilities[i]) for i in range(10)}

    except Exception as e:
        # UI boundary: report the problem and fall back to the neutral
        # output rather than surfacing a traceback in the interface.
        print(f"Prediction error: {e}")
        return default_output
|
|
|
|
| |
# Gradio UI: a sketch canvas and a button on the left, the top-3 class
# probabilities on the right.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown(
        """
        # 🧠 LookThem V8 - MNIST Fraction Engine Classifier
        ### 315K Parameters | **99.53% Validation Accuracy**
        """
    )

    with gr.Row():
        with gr.Column():
            # Drawing surface; type="numpy" requests numpy image data, with
            # layers disabled and a 280x280 canvas.
            input_canvas = gr.Sketchpad(
                type="numpy",
                layers=False,
                canvas_size=(280, 280)
            )
            submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")

        with gr.Column():
            # Shows the three highest-probability classes.
            output_label = gr.Label(num_top_classes=3, label="Top Probabilities")

    # Clicking the button sends the canvas value to predict_digit and routes
    # the returned probability dict to the label.
    submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|