# File size: 9,766 Bytes
import os
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import gradio as gr
import numpy as np
from PIL import Image
import os
import urllib.request
import zipfile
# --- CONFIG & MODEL DOWNLOAD ---
MODEL_PATH = "LookThem_V8_MNIST.pth"
ZIP_PATH = "LookThem_V8_MNIST.zip"
HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"


def _materialize_weights(download_path, model_path):
    """Turn the raw file at *download_path* into the checkpoint at *model_path*.

    The download is handled in one of three ways:

    * a zip archive containing a ``.pth`` member: the first such member is
      extracted and moved to *model_path*;
    * a zip archive with no ``.pth`` member (a ``torch.save`` checkpoint is
      itself a zip archive): the download is moved to *model_path* as-is;
    * not a zip archive at all: the download is moved to *model_path* as-is.

    The archive is fully closed before *download_path* is moved, because
    renaming a file that is still open fails on Windows.
    """
    member = None
    extracted = None
    try:
        with zipfile.ZipFile(download_path, "r") as archive:
            pth_members = [name for name in archive.namelist() if name.endswith(".pth")]
            if pth_members:
                member = pth_members[0]
                # extract() returns the path it actually wrote, which can differ
                # from the raw member name once unsafe components are sanitized.
                extracted = archive.extract(member, path=".")
    except zipfile.BadZipFile:
        print("File is not a zip archive. Proceeding with standard weight loading.")
        os.replace(download_path, model_path)
        return

    if member is None:
        print("No .pth file found inside zip. Renaming zip directly to .pth...")
        os.replace(download_path, model_path)
        return

    if os.path.abspath(extracted) != os.path.abspath(model_path):
        os.replace(extracted, model_path)
    print(f"Successfully extracted: {member} -> {model_path}")


if not os.path.exists(MODEL_PATH):
    print("Downloading model weights from Hugging Face...")
    # Download to a temporary name first; the payload may or may not be a zip.
    urllib.request.urlretrieve(HF_URL, ZIP_PATH)
    print("Download complete! Checking for zip compression...")
    _materialize_weights(ZIP_PATH, MODEL_PATH)

# Clean up the temporary download (or a stale one left by an interrupted run).
if os.path.exists(ZIP_PATH):
    os.remove(ZIP_PATH)
# --- DEFINE YOUR MODEL ARCHITECTURE ---
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer with per-token weights.

    Each of the ``num_tokens`` tokens owns two small MLPs (``mod1`` and
    ``mod2``: in_features -> hidden_dim -> 1, GELU in between) that map its
    feature vector to one scalar, ``m1`` and ``m2``.  For every ordered pair
    (i, j) the layer computes ``tanh(m1_i / m2_j)`` and ``tanh(m1_j / m2_i)``,
    applies a scalar affine map indexed by the partner token j
    (``trans_w[j]``, ``trans_b[j]``), and uses the two results to weight
    ``x_i`` and ``x_j`` respectively.  The weighted features are averaged over
    all partners ``j != i``.

    Args:
        num_tokens: Number of tokens N; every token has its own parameters.
        in_features: Feature size per token (output keeps the same size).
        hidden_dim: Hidden width of each per-token MLP.

    Shapes (from the einsum subscripts): ``x`` is (batch, N, in_features) and
    the output has the same shape.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for the MLP weights; biases stay zero, trans_w keeps randn."""
        # NOTE(review): on 3-D tensors xavier_uniform_ derives fan_in/fan_out
        # conv-style (dim 0 as outputs, dim 2 as receptive field), not per
        # token -- confirm this scaling is intended.
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
            nn.init.xavier_uniform_(w)

    def forward(self, x):
        N = self.num_tokens
        # Per-token MLPs: (B, N, in) -> (B, N, hidden) -> (B, N, 1).
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2
        # Keep the denominator at least 1e-6 away from zero while preserving its
        # sign.  torch.sign(0) is 0, so a sign * clamp(|.|) form would leave exact
        # zeros at 0 and yield inf/NaN ratios; exact zeros map to +1e-6 here.
        magnitude = torch.clamp(torch.abs(out_m2), min=1e-6)
        out_m2_safe = torch.where(out_m2 < 0, -magnitude, magnitude)
        # compare[b, i, j] = tanh(m1_i / m2_j); compare2[b, i, j] = tanh(m1_j / m2_i).
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))
        # Scalar affine map indexed by the partner token j.
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + self.trans_b.view(1, 1, N, 1)
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + self.trans_b.view(1, 1, N, 1)
        # (B, N, N, in): weighted x_i plus weighted x_j for every pair (i, j).
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2
        # Exclude self-pairs (i == j) before averaging.
        mask = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
        # Average over the N - 1 partners; with N == 1 there are none and the
        # masked sum is zero, so divide by 1 instead of 0.
        return (interaksi * mask).sum(dim=2) / max(N - 1.0, 1.0)
class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP of constant width.

    Computes ``LayerNorm(x + Linear(Dropout(GELU(Linear(x)))))``; input and
    output both have last dimension ``dim``.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        # Sequential indices (0 = first Linear, 3 = second Linear) are part of
        # the state_dict keys, so the layer order must not change.
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        update = self.block(x)
        return self.norm(x + update)
class LookThemV8MNIST(nn.Module):
    """Two-stream LookThem classifier: 1-channel image in, 10 logits out.

    Both convolutional streams end with 8 channels pooled to 8x8, i.e. 64
    spatial tokens of 8 features.  Stream A uses stride-2 convolutions, stream
    B stride-1.  The token sets pass through LookThem layers, are fused,
    compressed to 64 x 4 features and classified by a residual MLP head.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys matched by load_state_dict, so
        # they (and the creation order, which fixes seeded init) stay as-is.
        self.stream_a = self._conv_stream(stride=2)
        self.stream_b = self._conv_stream(stride=1)
        self.lookthemA = LookThemLayer(64, 8, 32)
        self.lookthemB = LookThemLayer(64, 8, 32)
        self.lookthem_comb = LookThemLayer(64, 16, 32)
        self.comb_norm = nn.LayerNorm(16)
        self.FFN1 = nn.Conv1d(16, 8, 1)
        self.lookthem2 = LookThemLayer(64, 8, 32)
        self.FFN2 = nn.Conv1d(8, 8, 1)
        self.compressor = nn.Conv1d(8, 4, 1)
        self.input_proj = nn.Linear(64 * 4, 128)
        self.res_blocks = nn.Sequential(LiteResidualBlock(128), LiteResidualBlock(128))
        self.head = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))

    @staticmethod
    def _conv_stream(stride):
        """Build a 1 -> 4 -> 8 channel 3x3 conv stack (padding 1), max-pooled to 8x8."""
        return nn.Sequential(
            nn.Conv2d(1, 4, 3, stride, 1),
            nn.BatchNorm2d(4),
            nn.GELU(),
            nn.Conv2d(4, 8, 3, stride, 1),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

    def forward(self, x):
        batch = x.size(0)

        def to_tokens(feature_map):
            # (B, 8, 8, 8) -> (B, 8, 64) -> (B, 64 tokens, 8 features)
            return feature_map.view(batch, 8, 64).transpose(1, 2)

        tokens_a = self.lookthemA(to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(to_tokens(self.stream_b(x)))
        fused = self.lookthem_comb(torch.cat([tokens_a, tokens_b], dim=2))
        fused = self.comb_norm(fused)

        # Conv1d wants channels in dim 1: hop between (B, T, C) and (B, C, T).
        reduced = self.FFN1(fused.transpose(1, 2)).transpose(1, 2)
        mixed = self.lookthem2(reduced).transpose(1, 2)
        merged = self.FFN2(mixed) + reduced.transpose(1, 2)

        flat = self.compressor(merged).flatten(1)
        return self.head(self.res_blocks(self.input_proj(flat)))
# --- LOAD WEIGHTS ON CPU/GPU ---
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = LookThemV8MNIST()
# weights_only=True restricts unpickling to tensors/primitive containers.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device, weights_only=True)
)
model = model.to(device)
# Inference mode for BatchNorm/Dropout layers.
model.eval()
# --- PREPROCESSING MATCHING TRAINING PIPELINE ---
# PIL image -> 28x28 float tensor in [0, 1] -> standardized with the commonly
# used MNIST mean/std.  (Images in this file are already resized to 28x28
# before reaching this pipeline, so Resize is effectively a safeguard.)
transform_fn = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
def _ink_from_array(img_array):
    """Reduce *img_array* to a 2-D intensity array.

    * (H, W, 1): the single channel.
    * (H, W, 4) with any alpha below 255: the alpha channel (strokes drawn on a
      transparent canvas).
    * (H, W, 4) fully opaque, or any other (H, W, C): RGB luminance of the first
      three channels (raises if C < 3; the caller catches this).
    * anything else (e.g. an (H, W) array): a copy of the input.
    """
    if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
        channels = img_array.shape[-1]
        if channels == 1:
            return img_array[..., 0]
        if channels == 4:
            alpha = img_array[..., 3]
            # A fully opaque canvas has a constant alpha that carries no stroke
            # information; only use alpha when it actually varies.
            if alpha.min() < 255:
                return alpha
        return np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
    return img_array.copy()


def predict_digit(input_image):
    """Classify a hand-drawn digit and return per-class probabilities.

    Args:
        input_image: Gradio Sketchpad value -- either a dict whose
            ``"composite"`` (falling back to ``"background"``) entry holds the
            image, or the image itself.  ``None`` is accepted.

    Returns:
        Dict mapping ``"0"``..``"9"`` to probabilities.  A uniform 0.1 per
        class is returned for missing input, an empty canvas, or any error
        raised while processing.
    """
    default_output = {str(i): 0.1 for i in range(10)}
    if input_image is None:
        return default_output
    try:
        # 1. Handle Gradio Sketchpad dictionary output
        if isinstance(input_image, dict):
            img_array = input_image.get("composite")
            if img_array is None:
                img_array = input_image.get("background")
        else:
            img_array = input_image
        if img_array is None:
            return default_output

        # 2. Collapse to a single intensity channel.
        grayscale = _ink_from_array(img_array)

        # 3. AUTO-INVERT: the model expects a bright digit on a dark background.
        # A mean above 127 means dark strokes were drawn on a light canvas.
        if np.mean(grayscale) > 127:
            grayscale = 255.0 - grayscale

        # 4. Treat a canvas with no bright pixel as empty.
        if np.max(grayscale) < 15:
            return default_output

        # Debugging print to check what the model is actually receiving
        print(f"Processed image shape: {grayscale.shape} | Max Val: {np.max(grayscale)} | Mean Val: {np.mean(grayscale):.2f}")

        # 5. Convert to PIL, Resize, and Transform.  Clip so out-of-range values
        # cannot wrap around in the uint8 cast; a 2-D uint8 array is mode "L".
        pixels = np.clip(grayscale, 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels)
        img = img.resize((28, 28), Image.Resampling.BILINEAR)
        tensor_img = transform_fn(img).unsqueeze(0).to(device)

        # 6. Model Inference
        with torch.no_grad():
            outputs = model(tensor_img)
        probabilities = F.softmax(outputs, dim=1)[0]
        return {str(i): float(probabilities[i]) for i in range(10)}
    except Exception as e:
        # UI boundary: report the problem and fall back to the neutral output.
        print(f"Prediction error: {e}")
        return default_output
# --- GRADIO INTERFACE ---
# Builds the web UI: a header, a sketchpad plus "classify" button in the left
# column, and a top-3 probability label in the right column.  Components are
# laid out in the order they are created inside the context managers.
with gr.Blocks() as demo:
    # Page header.  The string's lines start at column 0 on purpose: their
    # leading whitespace is part of the Markdown text.
    gr.Markdown(
        """
# 🧠 LookThem V8 - MNIST Fraction Engine Classifier
### 315K Parameters | **99.53% Validation Accuracy**
"""
    )
    with gr.Row():
        with gr.Column():
            # Standardized setup for canvas sketching in modern Gradio versions.
            # type="numpy" requests NumPy image data; predict_digit accepts
            # either the dict form or a bare array.
            input_canvas = gr.Sketchpad(
                type="numpy",
                layers=False,
                canvas_size=(280, 280)
            )
            submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
        with gr.Column():
            # Shows only the three most probable classes from predict_digit's dict.
            output_label = gr.Label(num_top_classes=3, label="Top Probabilities")
    # Run inference only when the button is clicked (not on every stroke).
    submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)
# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()