import os
import urllib.request
import zipfile

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# --- CONFIG & MODEL DOWNLOAD ---
MODEL_PATH = "LookThem_V8_MNIST.pth"
# Temporary download target; deleted once MODEL_PATH is in place.
ZIP_PATH = "LookThem_V8_MNIST.zip"
HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"


def _download_model_weights():
    """Download the checkpoint from HF_URL and leave it at MODEL_PATH.

    The file is first saved to ZIP_PATH. If it is a zip archive containing a
    ``.pth`` member, that member is extracted and moved to MODEL_PATH.
    Otherwise the downloaded file itself is used as the checkpoint. That case
    covers two situations: the file is not a zip at all, or it is a zip
    without a ``.pth`` member. PyTorch's default ``torch.save`` format is
    itself a zip archive, so a plain checkpoint takes the second path.

    ZIP_PATH may be left behind; the caller is responsible for deleting it.
    """
    print("Downloading model weights from Hugging Face...")
    urllib.request.urlretrieve(HF_URL, ZIP_PATH)
    print("Download complete! Checking for zip compression...")

    inner_pth = None
    extracted_path = None
    try:
        with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
            # Look for a .pth file inside the zip.
            pth_files = [name for name in zip_ref.namelist() if name.endswith(".pth")]
            if pth_files:
                inner_pth = pth_files[0]
                # extract() sanitises the member name; use the path it really wrote.
                extracted_path = zip_ref.extract(inner_pth, path=".")
    except zipfile.BadZipFile:
        print("File is not a zip archive. Proceeding with standard weight loading.")
        os.replace(ZIP_PATH, MODEL_PATH)
        return

    # Renames happen only after the archive is closed: renaming a file that is
    # still open fails on Windows.
    if extracted_path is not None:
        if os.path.abspath(extracted_path) != os.path.abspath(MODEL_PATH):
            os.replace(extracted_path, MODEL_PATH)
        print(f"Successfully extracted: {inner_pth} -> {MODEL_PATH}")
    else:
        print("No .pth file found inside zip. Renaming zip directly to .pth...")
        os.replace(ZIP_PATH, MODEL_PATH)


if not os.path.exists(MODEL_PATH):
    try:
        _download_model_weights()
    finally:
        # Clean up the temporary download if it still exists, including after a
        # failed download.
        if os.path.exists(ZIP_PATH):
            os.remove(ZIP_PATH)


# --- DEFINE YOUR MODEL ARCHITECTURE ---
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer.

    Input and output are both ``(batch, num_tokens, in_features)``.

    Each token has two private two-layer MLPs (``mod1_*`` and ``mod2_*``).
    Each MLP reduces the token to one scalar. For every ordered token pair
    (i, j), a ratio of those scalars is squashed with ``tanh`` and passed
    through an affine map indexed by the partner token j (``trans_*``). The
    result weights the features of both tokens in the pair. Each token's
    output is the mean over all *other* tokens; self-pairs are masked out.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        # Per-token MLP #1: in_features -> hidden_dim -> 1.
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token MLP #2: in_features -> hidden_dim -> 1.
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token scalar affine map applied to the tanh-compared ratios.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the MLP weight tensors.

        Biases and ``trans_*`` keep their constructor values.
        """
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
            nn.init.xavier_uniform_(w)

    def forward(self, x):
        """Mix every token with every other token.

        Args:
            x: tensor of shape ``(batch, num_tokens, in_features)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        N = self.num_tokens
        # Per-token MLPs -> one scalar per token, shape (b, N, 1).
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # Keep the denominator at least 1e-6 away from zero, preserving its sign.
        # torch.sign(0) == 0 would turn an exact zero back into a zero
        # denominator, so exact zeros are treated as positive instead.
        magnitude = torch.clamp(torch.abs(out_m2), min=1e-6)
        out_m2_safe = torch.where(out_m2 < 0, -magnitude, magnitude)

        # compare[b, i, j] = tanh(m1[i] / m2[j]); compare2[b, i, j] = tanh(m1[j] / m2[i]).
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))
        # Affine map indexed by the partner token j; result shape (b, N, N, 1).
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + self.trans_b.view(1, 1, N, 1)
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + self.trans_b.view(1, 1, N, 1)

        # "interaksi" = interaction. First term weights token i's own features,
        # second term weights partner j's features; shape (b, N, N, in_features).
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2
        # Drop self-pairs (diagonal) and average over the remaining N - 1 partners.
        mask = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
        return (interaksi * mask).sum(dim=2) / (N - 1.0)


class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP: ``LayerNorm(x + Linear(Dropout(GELU(Linear(x)))))``."""

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.block(x))


class LookThemV8MNIST(nn.Module):
    """Two-stream CNN front end followed by LookThem interaction layers and an MLP head.

    The two streams use stride-2 and stride-1 convolutions respectively. Each
    stream is adaptively max-pooled to an 8x8 grid, giving 64 tokens with 8
    channels. The head produces 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Downsampling stream (two stride-2 convolutions).
        self.stream_a = nn.Sequential(
            nn.Conv2d(1, 4, 3, 2, 1), nn.BatchNorm2d(4), nn.GELU(),
            nn.Conv2d(4, 8, 3, 2, 1), nn.BatchNorm2d(8), nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)))
        # Full-resolution stream (stride-1 convolutions), pooled to the same grid.
        self.stream_b = nn.Sequential(
            nn.Conv2d(1, 4, 3, 1, 1), nn.BatchNorm2d(4), nn.GELU(),
            nn.Conv2d(4, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)))
        self.lookthemA = LookThemLayer(64, 8, 32)
        self.lookthemB = LookThemLayer(64, 8, 32)
        self.lookthem_comb = LookThemLayer(64, 16, 32)
        self.comb_norm = nn.LayerNorm(16)
        self.FFN1 = nn.Conv1d(16, 8, 1)
        self.lookthem2 = LookThemLayer(64, 8, 32)
        self.FFN2 = nn.Conv1d(8, 8, 1)
        self.compressor = nn.Conv1d(8, 4, 1)
        self.input_proj = nn.Linear(64 * 4, 128)
        self.res_blocks = nn.Sequential(LiteResidualBlock(128), LiteResidualBlock(128))
        self.head = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)``.

        Args:
            x: single-channel image batch ``(batch, 1, H, W)``; the app feeds 28x28.
        """
        b = x.size(0)
        # Each stream yields (b, 8, 8, 8) -> 64 tokens x 8 channels: (b, 64, 8).
        fa = self.lookthemA(self.stream_a(x).view(b, 8, 64).transpose(1, 2))
        fb = self.lookthemB(self.stream_b(x).view(b, 8, 64).transpose(1, 2))
        # Fuse both streams per token: (b, 64, 16).
        x = self.comb_norm(self.lookthem_comb(torch.cat([fa, fb], dim=2)))
        # Conv1d with kernel 1 mixes channels per token, so put channels on dim 1.
        x = x.transpose(1, 2)
        x = self.FFN1(x).transpose(1, 2)  # (b, 64, 8)
        res = x
        x = self.lookthem2(x).transpose(1, 2)  # (b, 8, 64)
        x = self.FFN2(x) + res.transpose(1, 2)  # residual around lookthem2 + FFN2
        x = self.compressor(x).flatten(1)  # (b, 4 * 64)
        x = self.res_blocks(self.input_proj(x))  # (b, 128)
        return self.head(x)


# --- LOAD WEIGHTS ON CPU/GPU ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LookThemV8MNIST()
# weights_only=True restricts unpickling to tensors and plain containers,
# which matters because the checkpoint is downloaded from the internet.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.to(device)
model.eval()

# --- PREPROCESSING MATCHING TRAINING PIPELINE ---
transform_fn = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


def _extract_image_array(input_image):
    """Return the image data from a Sketchpad value, or None if there is none.

    Dict values use their "composite" entry, falling back to "background".
    Anything else is returned unchanged.
    """
    if isinstance(input_image, dict):
        img_array = input_image.get("composite")
        if img_array is None:
            img_array = input_image.get("background")
        return img_array
    return input_image


def _to_grayscale(img_array):
    """Convert image data to a 2-D intensity array (values nominally 0..255).

    Conversion rules:
    - (H, W, 1): the single channel is used.
    - RGBA: if any pixel is not fully opaque, the alpha channel is used as the
      intensity map (strokes drawn on a transparent canvas). A fully opaque
      image carries no stroke information in alpha, so RGB luminance is used.
    - Other 3-D arrays: RGB luminance of the first three channels.
    - 2-D arrays: copied unchanged.
    """
    img_array = np.asarray(img_array)
    if img_array.ndim == 3 and img_array.shape[-1] == 1:
        return img_array[..., 0].copy()
    if img_array.ndim == 3 and img_array.shape[-1] == 4:
        alpha = img_array[..., 3]
        if np.any(alpha < 255):
            return alpha
    if img_array.ndim == 3:
        # RGB -> Grayscale
        return np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
    return img_array.copy()


def predict_digit(input_image):
    """Classify a hand-drawn digit.

    Args:
        input_image: Sketchpad value. This is either a dict with "composite" /
            "background" image arrays, a bare image array, or None.

    Returns:
        Dict mapping "0".."9" to class probabilities. A uniform 0.1
        distribution is returned for missing or blank input and on any error.
    """
    default_output = {str(i): 0.1 for i in range(10)}
    if input_image is None:
        return default_output

    try:
        # 1. Handle Gradio Sketchpad dictionary output.
        img_array = _extract_image_array(input_image)
        if img_array is None:
            return default_output

        # 2. Convert to grayscale.
        grayscale = _to_grayscale(img_array)

        # 3. AUTO-INVERT: ensure a white digit on a black background.
        # A bright average (> 127) means dark strokes on a light background.
        if np.mean(grayscale) > 127:
            grayscale = 255.0 - grayscale

        # 4. Check if the canvas is empty.
        if np.max(grayscale) < 15:
            return default_output

        # Debugging print to check what the model actually receives.
        print(f"Processed image shape: {grayscale.shape} | Max Val: {np.max(grayscale)} | Mean Val: {np.mean(grayscale):.2f}")

        # 5. Convert to PIL, resize and transform.
        # Clip before casting so out-of-range values saturate instead of wrapping.
        pixels = np.clip(grayscale, 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels)  # 2-D uint8 array -> mode "L"
        img = img.resize((28, 28), Image.Resampling.BILINEAR)
        tensor_img = transform_fn(img).unsqueeze(0).to(device)

        # 6. Model inference.
        with torch.no_grad():
            outputs = model(tensor_img)
            probabilities = F.softmax(outputs, dim=1)[0]
        # One device-to-host transfer instead of one per class.
        return {str(i): p for i, p in enumerate(probabilities.tolist())}
    except Exception as e:
        # UI boundary: log and fall back to the uniform distribution instead of
        # surfacing an error in the Gradio app.
        print(f"Prediction error: {e}")
        return default_output


# --- GRADIO INTERFACE ---
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🧠 LookThem V8 - MNIST Fraction Engine Classifier\n"
        "### 315K Parameters | **99.53% Validation Accuracy**\n"
    )

    with gr.Row():
        with gr.Column():
            # Sketchpad returns a dict of numpy arrays ("composite", "background", ...).
            input_canvas = gr.Sketchpad(
                type="numpy",
                layers=False,
                canvas_size=(280, 280),
            )
            submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")

        with gr.Column():
            output_label = gr.Label(num_top_classes=3, label="Top Probabilities")

    submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)

if __name__ == "__main__":
    demo.launch()