Spaces:
Sleeping
Sleeping
File size: 1,724 Bytes
e2e08dc 3b572d1 24e5d05 e2e08dc 3b572d1 e2e08dc da63f24 3b572d1 f178e3e e2e08dc da63f24 00829d5 e2e08dc 24e5d05 da63f24 e2e08dc 24e5d05 da63f24 e2e08dc da63f24 e2e08dc 3b572d1 f178e3e e2e08dc f178e3e 24e5d05 ea839ae f178e3e e2e08dc 24e5d05 e2e08dc 3b572d1 e2e08dc da63f24 e2e08dc 3b572d1 e2e08dc 3b572d1 e2e08dc ea839ae f178e3e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from model import load_encoder
import torch.nn.functional as F
DEVICE = torch.device("cpu")

# ---------------- LOAD ENCODER ----------------
# Feature extractor (a SimCLR ResNet-18, per the checkpoint filename); the
# loader itself lives in model.py.
encoder = load_encoder("encoder_resnet18_simclr.pth")
# Keep the weights on the same device predict() moves its inputs to.
encoder = encoder.to(DEVICE)
# Switch layers such as BatchNorm/Dropout to inference behaviour.
encoder.eval()

# ---------------- LOAD LINEAR PROBE ----------------
# The probe is a plain affine classifier applied as `emb @ W.T + b`.
# W and b must be numeric arrays (torch.tensor rejects object arrays), so
# pickle support is not needed; leaving allow_pickle at its default (False)
# avoids unpickling arbitrary objects from the file. The `with` block closes
# the underlying .npz file handle once both arrays are copied into tensors.
with np.load("linear_probe_cifar10.npz") as data:
    W = torch.tensor(data["W"], dtype=torch.float32)
    b = torch.tensor(data["b"], dtype=torch.float32)

# Output label order; presumably matches the row order of W used when the
# probe was trained — TODO confirm against the training script.
CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
print("Classifier Loaded Successfully")
# ---------------- TRANSFORM ----------------
# Preprocessing applied to every upload:
#   1. Resize to exactly 32x32 (a fixed (h, w) size, so aspect ratio is not
#      preserved).
#   2. Convert the PIL image to a float tensor of shape (C, H, W) in [0, 1].
# No mean/std normalization is applied here.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
    ]
)
# ---------------- PREDICT ----------------
def predict(image):
    """Classify one uploaded image with the encoder + linear probe.

    Args:
        image: the numpy array supplied by ``gr.Image(type="numpy")``; it is
            converted to RGB here. May be ``None`` if the user submits
            without uploading an image.

    Returns:
        dict mapping each name in ``CLASSES`` to its softmax probability,
        the mapping format ``gr.Label`` displays.

    Raises:
        gr.Error: when no image was provided, so the UI shows a readable
            message instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    pil_image = Image.fromarray(image).convert("RGB")
    x = transform(pil_image).unsqueeze(0).to(DEVICE)  # add batch dim -> (1, C, H, W)

    with torch.no_grad():
        # Assumes the encoder returns a (batch, dim) feature tensor — TODO
        # confirm in model.load_encoder.
        emb = encoder(x).cpu().float()
        # Match training normalization: unit L2 norm per row. F.normalize
        # clamps the denominator (eps=1e-12), so an all-zero embedding gives
        # zeros rather than NaN probabilities.
        emb = F.normalize(emb, p=2, dim=1)
        logits = emb @ W.T + b
        probs = F.softmax(logits, dim=1)[0].tolist()

    # Pair labels with probabilities by position instead of a hard-coded 10.
    return dict(zip(CLASSES, probs))
# ---------------- UI ----------------
# Single-function Gradio app: one image in, the three most likely classes out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),  # predict() receives a numpy array
    outputs=gr.Label(num_top_classes=3),
    title="CRLF — CIFAR10 SimCLR Demo",
    description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe.",
)

# Enable request queuing (noted by the author as HF friendly), then serve on
# all interfaces at port 7860.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
|