# md896's picture
# Update app.py
# ea839ae verified
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from model import load_encoder
import torch.nn.functional as F
# Device used for the encoder and for the inputs built in predict().
DEVICE = torch.device("cpu")

# ---------------- LOAD ENCODER ----------------
encoder = load_encoder("encoder_resnet18_simclr.pth")
# Keep the encoder on the same device predict() moves its inputs to, so that
# changing DEVICE does not cause a device mismatch.
# NOTE(review): assumes load_encoder returns a torch.nn.Module — confirm.
encoder.to(DEVICE)
encoder.eval()

# ---------------- LOAD LINEAR PROBE ----------------
# NOTE(security): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code if the archive is untrusted. It is kept because
# the archive's contents are not visible here; switch to allow_pickle=False if
# "W" and "b" are stored as plain numeric arrays.
# The context manager closes the underlying .npz file handle once the weights
# have been copied into tensors (torch.tensor copies its input).
with np.load("linear_probe_cifar10.npz", allow_pickle=True) as data:
    W = torch.tensor(data["W"], dtype=torch.float32)
    b = torch.tensor(data["b"], dtype=torch.float32)

# Label names, indexed in the same order as the probe's output logits.
CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

print("Classifier Loaded Successfully")
# ---------------- TRANSFORM ----------------
# Preprocessing applied to every uploaded image, in order: resize to 32x32,
# then convert the PIL image to a tensor.
_PREPROCESS_STEPS = [
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_PREPROCESS_STEPS)
# ---------------- PREDICT ----------------
def predict(image):
    """Classify an uploaded image with the encoder and the linear probe.

    Args:
        image: Image array as delivered by the Gradio image input, or
            ``None`` when nothing was uploaded.

    Returns:
        dict mapping each name in ``CLASSES`` to its softmax probability
        (a Python float).

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Submitting without an upload passes None; show a readable message
        # instead of failing inside PIL.
        raise gr.Error("Please upload an image first.")

    rgb = Image.fromarray(image).convert("RGB")
    x = transform(rgb).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # The probe weights W and b are float32 CPU tensors, so bring the
        # embedding to the same device and dtype before combining them.
        # NOTE(review): assumes the encoder returns a (batch, dim) tensor.
        emb = encoder(x).cpu().float()
        # Match training normalization (unit L2 norm per row). F.normalize
        # clamps the norm at a small epsilon, so an all-zero embedding yields
        # zeros instead of NaN from a 0/0 division.
        emb = F.normalize(emb, p=2, dim=1)
        logits = emb @ W.T + b
        probs = F.softmax(logits, dim=1)[0].tolist()

    # Pair labels with probabilities directly rather than hard-coding 10.
    return dict(zip(CLASSES, probs))
# ---------------- UI ----------------
# Collect the interface configuration first, then build the app from it:
# a numpy image input wired to predict(), with a label output that shows the
# top three classes.
_interface_options = dict(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),
    title="CRLF — CIFAR10 SimCLR Demo",
    description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe.",
)
demo = gr.Interface(**_interface_options)

demo.queue()  # HF friendly

# Listen on all interfaces at port 7860.
_launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
demo.launch(**_launch_options)