Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,8 +7,7 @@ import gradio as gr
|
|
| 7 |
from PIL import Image
|
| 8 |
import math
|
| 9 |
|
| 10 |
-
#
|
| 11 |
-
# Configuration (must match your trained model)
|
| 12 |
cfg = {
|
| 13 |
"image_size": 32,
|
| 14 |
"patch_size": 4,
|
|
@@ -41,10 +40,7 @@ classes = [
|
|
| 41 |
|
| 42 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
|
| 44 |
-
# ------------------------
|
| 45 |
-
# Model definition
|
| 46 |
# ViT model implementation
|
| 47 |
-
# --- Conv stem (replace PatchEmbed) ---
|
| 48 |
class ConvPatchEmbed(nn.Module):
|
| 49 |
def __init__(self, in_chans=3, embed_dim=384):
|
| 50 |
super().__init__()
|
|
@@ -67,9 +63,9 @@ class ConvPatchEmbed(nn.Module):
|
|
| 67 |
|
| 68 |
def forward(self, x):
|
| 69 |
# x: (B, C, H, W)
|
| 70 |
-
x = self.conv(x)
|
| 71 |
-
x = x.flatten(2)
|
| 72 |
-
x = x.transpose(1, 2)
|
| 73 |
return x
|
| 74 |
|
| 75 |
class MLP(nn.Module):
|
|
@@ -192,13 +188,12 @@ class ViT(nn.Module):
|
|
| 192 |
out = self.head(cls)
|
| 193 |
return out
|
| 194 |
|
| 195 |
-
|
| 196 |
# Load model weights
|
| 197 |
model = ViT(cfg).to(device)
|
| 198 |
-
model.load_state_dict(torch.load("
|
| 199 |
model.eval()
|
| 200 |
|
| 201 |
-
# ------------------------
|
| 202 |
# Image preprocessing
|
| 203 |
transform = transforms.Compose([
|
| 204 |
transforms.Resize((32,32)),
|
|
@@ -215,7 +210,6 @@ def predict(img: Image.Image):
|
|
| 215 |
result = {classes[i]: float(probs[i]) for i in top5.indices}
|
| 216 |
return result
|
| 217 |
|
| 218 |
-
# ------------------------
|
| 219 |
# Gradio interface
|
| 220 |
iface = gr.Interface(
|
| 221 |
fn=predict,
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
import math
|
| 9 |
|
| 10 |
+
# Configuration
|
|
|
|
| 11 |
cfg = {
|
| 12 |
"image_size": 32,
|
| 13 |
"patch_size": 4,
|
|
|
|
| 40 |
|
| 41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
|
|
|
|
|
|
|
| 43 |
# ViT model implementation
|
|
|
|
| 44 |
class ConvPatchEmbed(nn.Module):
|
| 45 |
def __init__(self, in_chans=3, embed_dim=384):
|
| 46 |
super().__init__()
|
|
|
|
| 63 |
|
| 64 |
def forward(self, x):
|
| 65 |
# x: (B, C, H, W)
|
| 66 |
+
x = self.conv(x)
|
| 67 |
+
x = x.flatten(2)
|
| 68 |
+
x = x.transpose(1, 2)
|
| 69 |
return x
|
| 70 |
|
| 71 |
class MLP(nn.Module):
|
|
|
|
| 188 |
out = self.head(cls)
|
| 189 |
return out
|
| 190 |
|
| 191 |
+
|
| 192 |
# Load model weights
|
| 193 |
model = ViT(cfg).to(device)
|
| 194 |
+
model.load_state_dict(torch.load("best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device))
|
| 195 |
model.eval()
|
| 196 |
|
|
|
|
| 197 |
# Image preprocessing
|
| 198 |
transform = transforms.Compose([
|
| 199 |
transforms.Resize((32,32)),
|
|
|
|
| 210 |
result = {classes[i]: float(probs[i]) for i in top5.indices}
|
| 211 |
return result
|
| 212 |
|
|
|
|
| 213 |
# Gradio interface
|
| 214 |
iface = gr.Interface(
|
| 215 |
fn=predict,
|