File size: 1,724 Bytes
e2e08dc
 
 
 
 
3b572d1
24e5d05
e2e08dc
3b572d1
e2e08dc
da63f24
3b572d1
f178e3e
e2e08dc
da63f24
00829d5
e2e08dc
24e5d05
 
da63f24
e2e08dc
 
 
 
 
24e5d05
 
da63f24
e2e08dc
 
 
 
 
da63f24
e2e08dc
3b572d1
f178e3e
e2e08dc
 
f178e3e
24e5d05
ea839ae
f178e3e
 
 
e2e08dc
24e5d05
 
e2e08dc
3b572d1
e2e08dc
da63f24
e2e08dc
 
3b572d1
e2e08dc
3b572d1
 
e2e08dc
 
ea839ae
f178e3e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from model import load_encoder
import torch.nn.functional as F

DEVICE = torch.device("cpu")

# ---------------- LOAD ENCODER ----------------
# load_encoder is project-local (model.py); the checkpoint name suggests a
# SimCLR-trained ResNet-18 backbone — confirm against model.py.
encoder = load_encoder("encoder_resnet18_simclr.pth")
# Place the encoder on the same device predict() sends its inputs to, so the
# two can never disagree (e.g. if DEVICE is switched to CUDA, or if
# load_encoder returns a module on a different device).
encoder.to(DEVICE)
# Inference only: put layers such as BatchNorm/Dropout into eval behavior.
encoder.eval()

# ---------------- LOAD LINEAR PROBE ----------------
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code — only load trusted checkpoint files. W and b look like
# plain numeric arrays, so the flag may be unnecessary; confirm before removing.
# The `with` block closes the .npz archive's file handle once W and b have been
# copied into tensors (torch.tensor copies the data).
with np.load("linear_probe_cifar10.npz", allow_pickle=True) as data:
    W = torch.tensor(data["W"], dtype=torch.float32)
    b = torch.tensor(data["b"], dtype=torch.float32)

CLASSES = [
    "airplane","automobile","bird","cat","deer",
    "dog","frog","horse","ship","truck"
]

# predict() computes `emb @ W.T + b`, so W must be (num_classes, embedding_dim)
# with one row per entry of CLASSES; fail fast instead of mislabeling classes.
if W.ndim != 2 or W.shape[0] != len(CLASSES):
    raise ValueError(
        f"Linear probe weight W has shape {tuple(W.shape)}; expected "
        f"({len(CLASSES)}, embedding_dim) to match CLASSES."
    )

print("Classifier Loaded Successfully")

# ---------------- TRANSFORM ----------------
# Preprocessing for a single PIL image: resize to 32x32, then convert to a
# float tensor scaled to [0, 1] via ToTensor. No mean/std Normalize step is
# applied — presumably this mirrors the encoder's training pipeline; TODO confirm.
_resize_32 = transforms.Resize((32, 32))
_to_tensor = transforms.ToTensor()
transform = transforms.Compose([_resize_32, _to_tensor])

# ---------------- PREDICT ----------------
def predict(image):
    """Classify one uploaded image with the SimCLR encoder and linear probe.

    Args:
        image: Array supplied by the ``gr.Image(type="numpy")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        Dict mapping each name in ``CLASSES`` to its softmax probability,
        the format consumed by the ``gr.Label`` output.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # convert("RGB") drops an alpha channel or expands grayscale so the
    # encoder always receives 3 channels.
    pil_image = Image.fromarray(image).convert("RGB")
    x = transform(pil_image).unsqueeze(0).to(DEVICE)  # add batch dimension

    with torch.no_grad():
        # assumes the encoder returns a (batch, embedding_dim) tensor — the
        # normalization below works along dim=1; TODO confirm in model.py.
        emb = encoder(x).float().cpu()
        # Match training normalization: L2-normalize each embedding row.
        # F.normalize clamps the norm at a small eps, so an all-zero embedding
        # yields zeros rather than dividing by zero and producing NaNs.
        emb = F.normalize(emb, p=2, dim=1)
        logits = emb @ W.T + b
        probs = F.softmax(logits, dim=1)[0]

    # Pair probabilities with class names instead of hard-coding 10 classes.
    return {name: float(p) for name, p in zip(CLASSES, probs.tolist())}

# ---------------- UI ----------------
# Interface components: a numpy image input, and a label output that shows the
# three highest-probability classes returned by predict().
image_input = gr.Image(type="numpy")
top_classes_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=top_classes_output,
    title="CRLF — CIFAR10 SimCLR Demo",
    description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe.",
)

# Enable Gradio's request queue (kept for Hugging Face hosting), then serve on
# all interfaces at port 7860. NOTE: this runs at import time, so importing
# this module starts the server.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)