"""Gradio demo: classify an uploaded image with a SimCLR encoder + linear probe.

The script loads an encoder checkpoint and a linear classifier (weights ``W``
and bias ``b``) from disk at import time, then serves a Gradio interface whose
handler is :func:`predict`.
"""

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import load_encoder

DEVICE = torch.device("cpu")

# ---------------- LOAD ENCODER ----------------
encoder = load_encoder("encoder_resnet18_simclr.pth")
# Inputs are moved to DEVICE in predict(); keep the encoder on the same device
# so changing DEVICE does not cause a device mismatch.
encoder.to(DEVICE)
encoder.eval()

# ---------------- LOAD LINEAR PROBE ----------------
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which
# can execute code if the .npz file is not trusted. Only "W" and "b" are read
# here; if both are plain numeric arrays this flag can be dropped — confirm
# against the script that wrote the file.
with np.load("linear_probe_cifar10.npz", allow_pickle=True) as data:
    # Arrays are materialised into tensors before the archive is closed.
    W = torch.tensor(data["W"], dtype=torch.float32)
    b = torch.tensor(data["b"], dtype=torch.float32)

CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

print("Classifier Loaded Successfully")

# ---------------- TRANSFORM ----------------
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])


# ---------------- PREDICT ----------------
def predict(image):
    """Return class probabilities for one uploaded image.

    Args:
        image: Image as a NumPy array (the interface is configured with
            ``gr.Image(type="numpy")``), or ``None`` when nothing was uploaded.

    Returns:
        Dict mapping each name in ``CLASSES`` to its softmax probability as a
        Python float.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Without this guard Image.fromarray(None) fails with an opaque error.
        raise gr.Error("Please upload an image before submitting.")

    pil_image = Image.fromarray(image).convert("RGB")
    x = transform(pil_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # Assumes the encoder returns a 2-D (batch, features) embedding, as
        # required by the matmul with W.T below — TODO confirm.
        emb = encoder(x).cpu().float()

        # Match training normalization: L2-normalise each embedding row.
        # F.normalize clamps the norm with a small eps, so an all-zero
        # embedding no longer produces NaN probabilities.
        emb = F.normalize(emb, p=2, dim=1)

        logits = emb @ W.T + b
        probs = F.softmax(logits, dim=1)[0]

    # Pair labels with probabilities directly instead of hard-coding 10 classes.
    return dict(zip(CLASSES, probs.tolist()))


# ---------------- UI ----------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),
    title="CRLF — CIFAR10 SimCLR Demo",
    description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe.",
)

demo.queue()  # HF friendly
demo.launch(server_name="0.0.0.0", server_port=7860)