# Page residue from the hosting site, kept as comments so the module parses:
# ASomeoneWhoInterestedWithAI's picture
# Update app.py
# 3fec527 verified
import os
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import gradio as gr
import numpy as np
from PIL import Image
import os
import urllib.request
import zipfile
# --- CONFIG & MODEL DOWNLOAD ---
MODEL_PATH = "LookThem_V8_MNIST.pth"
ZIP_PATH = "LookThem_V8_MNIST.zip"
HF_URL = "https://huggingface.co/ASomeoneWhoInterestedWithAI/LookThem_V8-MNIST_Classifier/resolve/main/LookThem_V8_MNIST%20(2).pth"


def _fetch_model_weights(url, download_path, model_path):
    """Download the checkpoint at *url* and leave it at *model_path*.

    The file is first saved to *download_path*. If it opens as a zip archive
    containing a ``*.pth`` member, the first such member is extracted and
    moved to *model_path*. Otherwise the downloaded file itself is moved to
    *model_path*. Per the original comment, a checkpoint in the zip-based
    ``torch.save`` format is itself a zip, typically without a ``.pth``
    member. A non-zip file raises ``BadZipFile`` and is used as-is.

    The temporary *download_path* is removed on every exit path, including a
    failed or interrupted download. Errors from the download or from the file
    moves propagate to the caller.
    """
    try:
        print("Downloading model weights from Hugging Face...")
        urllib.request.urlretrieve(url, download_path)
        print("Download complete! Checking for zip compression...")

        member = None
        extracted_path = None
        try:
            with zipfile.ZipFile(download_path, "r") as archive:
                pth_members = [name for name in archive.namelist() if name.endswith(".pth")]
                if pth_members:
                    member = pth_members[0]
                    # extract() returns the sanitised on-disk path. That path can
                    # differ from the member name (e.g. absolute or '..' parts).
                    extracted_path = archive.extract(member, path=".")
        except zipfile.BadZipFile:
            print("File is not a zip archive. Proceeding with standard weight loading.")
        else:
            if member is None:
                print("No .pth file found inside zip. Renaming zip directly to .pth...")

        # Move files only after the archive handle is closed, because renaming
        # a file that is still open fails on Windows.
        if extracted_path is None:
            os.replace(download_path, model_path)
        else:
            if os.path.abspath(extracted_path) != os.path.abspath(model_path):
                os.replace(extracted_path, model_path)
            print(f"Successfully extracted: {member} -> {model_path}")
    finally:
        # Clean up the temporary download if it is still around.
        if os.path.exists(download_path):
            os.remove(download_path)


if not os.path.exists(MODEL_PATH):
    _fetch_model_weights(HF_URL, ZIP_PATH, MODEL_PATH)
# --- DEFINE YOUR MODEL ARCHITECTURE ---
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer.

    Each of the ``num_tokens`` tokens owns two small per-token MLPs, ``mod1``
    and ``mod2`` (``in_features -> hidden_dim -> GELU -> 1``), which map the
    token to a scalar ``m1`` and ``m2``. For every ordered pair (i, j) with
    ``i != j`` the layer computes two gates:

    - ``g1 = tanh(m1_i / m2_j)``
    - ``g2 = tanh(m1_j / m2_i)``

    Each gate is passed through the per-token affine map
    ``trans_w[j]``/``trans_b[j]``. The pair contribution is
    ``(g1' * x_i + g2' * x_j) / 2``, where ``g1'`` and ``g2'`` are the mapped
    gates. Contributions are averaged over the ``N - 1`` partners j.

    Args:
        num_tokens: number of tokens N (the second axis of the input).
        in_features: feature size of each token (last axis of the input).
        hidden_dim: hidden width of the per-token MLPs.

    Input ``x`` has shape ``(batch, num_tokens, in_features)``; the output has
    the same shape.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        # Per-token MLP #1 (numerator branch).
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token MLP #2 (denominator branch).
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # Per-token scalar affine map applied to each gate (indexed by partner j).
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the MLP weight tensors.

        Biases stay zero, and ``trans_w`` keeps its ``randn`` draw.
        """
        # NOTE(review): for these 3-D tensors xavier_uniform_ derives
        # fan_in = size(1) * size(2) and fan_out = size(0) * size(2), so the
        # token axis feeds fan_out. This only matters when training from
        # scratch, not when loading a checkpoint.
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
            nn.init.xavier_uniform_(w)

    def forward(self, x):
        N = self.num_tokens
        # Per-token MLPs: (B, N, in) -> (B, N, hidden) -> (B, N, 1).
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # Keep the denominator at least 1e-6 away from zero while preserving
        # its sign. torch.sign(0) == 0 would map an exact zero back to 0 and
        # produce inf/NaN below. copysign maps +/-0 to +/-1e-6 instead.
        out_m2_safe = torch.copysign(torch.clamp(out_m2.abs(), min=1e-6), out_m2)

        # (B, N, N, 1): compare[b, i, j] = tanh(m1_i / m2_j),
        #               compare2[b, i, j] = tanh(m1_j / m2_i).
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Affine map of each gate using the partner token j's trans_w / trans_b.
        trans_b = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + trans_b
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + trans_b

        # (B, N, N, in): gated blend of token i's and token j's features.
        interaction = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # Exclude self-pairs (i == j), then average over the remaining partners.
        # max(..., 1) keeps N == 1 finite (no partners -> zeros, not 0/0 NaN).
        mask = (1.0 - torch.eye(N, device=x.device, dtype=x.dtype)).view(1, N, N, 1)
        return (interaction * mask).sum(dim=2) / max(N - 1, 1)
class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP block computing ``LayerNorm(x + MLP(x))``.

    The MLP is ``Linear(dim, dim) -> GELU -> Dropout(dropout) -> Linear(dim, dim)``,
    so the output has the same shape as the input (last axis ``dim``).
    The submodule names (``block``, ``norm``) and the Sequential order define
    the state_dict keys, so keep them unchanged to stay checkpoint-compatible.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        skip = x
        transformed = self.block(x)
        return self.norm(skip + transformed)
class LookThemV8MNIST(nn.Module):
    """Two-stream CNN front-end followed by LookThem token-interaction layers.

    Input: single-channel images of shape ``(B, 1, H, W)`` (presumably 28x28
    MNIST digits). Output: unnormalised class scores of shape ``(B, 10)``.

    Both CNN streams end in ``AdaptiveMaxPool2d((8, 8))`` with 8 channels.
    Each feature map is therefore read as 64 tokens (the 8x8 grid positions)
    of 8 features.
    Attribute names and construction order match the trained checkpoint's
    state_dict, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Stream A: two stride-2 convolutions, then pooled to an 8x8 grid.
        self.stream_a = nn.Sequential(
            nn.Conv2d(1, 4, 3, 2, 1), nn.BatchNorm2d(4), nn.GELU(),
            nn.Conv2d(4, 8, 3, 2, 1), nn.BatchNorm2d(8), nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )
        # Stream B: same channels with stride 1, then pooled to an 8x8 grid.
        self.stream_b = nn.Sequential(
            nn.Conv2d(1, 4, 3, 1, 1), nn.BatchNorm2d(4), nn.GELU(),
            nn.Conv2d(4, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )
        # Token mixers over the 64 grid tokens.
        self.lookthemA = LookThemLayer(64, 8, 32)
        self.lookthemB = LookThemLayer(64, 8, 32)
        self.lookthem_comb = LookThemLayer(64, 16, 32)  # on concatenated A|B features
        self.comb_norm = nn.LayerNorm(16)
        # Kernel-size-1 Conv1d layers act as per-token linear maps over channels.
        self.FFN1 = nn.Conv1d(16, 8, 1)
        self.lookthem2 = LookThemLayer(64, 8, 32)
        self.FFN2 = nn.Conv1d(8, 8, 1)
        self.compressor = nn.Conv1d(8, 4, 1)
        # Dense classifier on the flattened (4 channels x 64 tokens) features.
        self.input_proj = nn.Linear(64 * 4, 128)
        self.res_blocks = nn.Sequential(LiteResidualBlock(128), LiteResidualBlock(128))
        self.head = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 10))

    def forward(self, x):
        batch = x.size(0)

        def to_tokens(feature_map):
            # (B, 8, 8, 8) -> (B, 8 channels, 64 positions) -> (B, 64 tokens, 8 features)
            return feature_map.view(batch, 8, 64).transpose(1, 2)

        tokens_a = self.lookthemA(to_tokens(self.stream_a(x)))
        tokens_b = self.lookthemB(to_tokens(self.stream_b(x)))

        fused = self.lookthem_comb(torch.cat([tokens_a, tokens_b], dim=2))
        fused = self.comb_norm(fused)                                   # (B, 64, 16)

        # Conv1d expects channels on dim 1, so transpose in and out.
        reduced = self.FFN1(fused.transpose(1, 2)).transpose(1, 2)      # (B, 64, 8)
        mixed = self.lookthem2(reduced)                                 # (B, 64, 8)
        refined = self.FFN2(mixed.transpose(1, 2)) + reduced.transpose(1, 2)  # (B, 8, 64)

        flat = self.compressor(refined).flatten(1)                      # (B, 256)
        hidden = self.res_blocks(self.input_proj(flat))                 # (B, 128)
        return self.head(hidden)                                        # (B, 10)
# --- LOAD WEIGHTS ON CPU/GPU ---
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = LookThemV8MNIST()
# weights_only=True restricts unpickling to tensors and plain containers,
# so loading the downloaded checkpoint cannot run arbitrary pickled code.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device, weights_only=True)
)
model = model.to(device)
# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()
# --- PREPROCESSING MATCHING TRAINING PIPELINE ---
# PIL image -> resized to 28x28 -> float tensor scaled to [0, 1]
# (shape (1, 28, 28) for a single-channel "L" image) -> standardised with
# mean 0.1307 / std 0.3081 (commonly quoted MNIST training-set statistics).
transform_fn = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
def _select_canvas_array(input_image):
    """Return the image array carried by *input_image*, or None.

    A Sketchpad dict is searched for ``"composite"`` first and
    ``"background"`` second. Any non-dict value is returned unchanged and
    treated as the image array itself.
    """
    if isinstance(input_image, dict):
        composite = input_image.get("composite")
        return composite if composite is not None else input_image.get("background")
    return input_image


def _to_grayscale(img_array):
    """Collapse *img_array* to a 2-D float64 array (pixel values assumed 0..255).

    Raises:
        ValueError: if the array is neither 2-D nor 3-D (channels-last).
    """
    arr = np.asarray(img_array)
    if arr.ndim == 2:
        return arr.astype(np.float64)
    if arr.ndim != 3:
        raise ValueError(f"Unsupported image shape: {arr.shape}")
    channels = arr.shape[-1]
    if channels in (2, 4):
        # Trailing alpha channel (LA / RGBA). Strokes drawn on a transparent
        # canvas show up only in alpha, so use alpha when it varies.
        # A uniform alpha (fully opaque or fully transparent) carries no
        # stroke information. Fall back to the colour channels in that case;
        # otherwise an opaque canvas would always read as "empty".
        alpha = arr[..., -1]
        if alpha.min() != alpha.max():
            return alpha.astype(np.float64)
        arr = arr[..., :-1]
        channels -= 1
    if channels == 1:
        return arr[..., 0].astype(np.float64)
    # RGB -> luma using ITU-R BT.601 weights on the first three channels.
    return np.dot(arr[..., :3], [0.2989, 0.5870, 0.1140])


def predict_digit(input_image):
    """Classify a hand-drawn digit and return per-class probabilities.

    Args:
        input_image: the Sketchpad value (a dict with "composite" /
            "background" arrays), a bare numpy image, or None.

    Returns:
        A dict mapping "0".."9" to probabilities, suitable for ``gr.Label``.
        A uniform 0.1 distribution is returned in three cases: there is no
        image, the canvas is effectively blank, or preprocessing/inference
        fails. Failures are printed with a traceback rather than raised, so
        the UI keeps responding.
    """
    default_output = {str(i): 0.1 for i in range(10)}
    if input_image is None:
        return default_output
    try:
        img_array = _select_canvas_array(input_image)
        if img_array is None:
            return default_output

        grayscale = _to_grayscale(img_array)

        # Aim for a light digit on a dark background. A bright average means
        # dark strokes on a light canvas, so invert.
        if np.mean(grayscale) > 127:
            grayscale = 255.0 - grayscale

        # Nothing (or almost nothing) was drawn.
        if np.max(grayscale) < 15:
            return default_output

        # Debugging print to check what the model is actually receiving
        print(f"Processed image shape: {grayscale.shape} | Max Val: {np.max(grayscale)} | Mean Val: {np.mean(grayscale):.2f}")

        # A 2-D uint8 array is converted to a mode "L" image automatically;
        # the explicit mode= argument to fromarray is deprecated in Pillow.
        pixels = np.clip(grayscale, 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels).resize((28, 28), Image.Resampling.BILINEAR)
        # NOTE(review): the whole canvas is resized with no cropping or
        # centring. If training used standard MNIST (digits centred in the
        # frame), small or off-centre drawings may score worse.
        tensor_img = transform_fn(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(tensor_img)
        # Use one device->host transfer instead of ten per-element float() calls.
        probabilities = F.softmax(outputs, dim=1)[0].tolist()
        return {str(i): p for i, p in enumerate(probabilities)}
    except Exception as e:
        import traceback

        print(f"Prediction error: {e}")
        traceback.print_exc()
        return default_output
# --- GRADIO INTERFACE ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🧠 LookThem V8 - MNIST Fraction Engine Classifier
### 315K Parameters | **99.53% Validation Accuracy**
"""
    )

    with gr.Row():
        # Left column: drawing canvas and trigger button.
        with gr.Column():
            # 280x280 single-layer canvas; its value reaches predict_digit as numpy data.
            input_canvas = gr.Sketchpad(
                type="numpy",
                layers=False,
                canvas_size=(280, 280),
            )
            submit_btn = gr.Button("Classify Digit 🏎️", variant="primary")
        # Right column: top-3 class probabilities.
        with gr.Column():
            output_label = gr.Label(num_top_classes=3, label="Top Probabilities")

    # Inference runs only on button press, not on every brush stroke.
    submit_btn.click(fn=predict_digit, inputs=input_canvas, outputs=output_label)

if __name__ == "__main__":
    demo.launch()