Spaces:
Sleeping
Sleeping
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import load_encoder

DEVICE = torch.device("cpu")

# ---------------- LOAD ENCODER ----------------
# SimCLR-pretrained backbone; `load_encoder` lives in model.py (not shown here).
encoder = load_encoder("encoder_resnet18_simclr.pth")
# NOTE(review): assumes load_encoder returns a torch.nn.Module. It is moved to
# DEVICE so it sits on the same device that predict() sends its inputs to.
encoder.to(DEVICE)
encoder.eval()

# ---------------- LOAD LINEAR PROBE ----------------
# np.load on an .npz returns an NpzFile that holds the archive open; the `with`
# block closes it once W and b have been copied into tensors.
# NOTE(review): allow_pickle=True is only required if the archive stores object
# arrays. Acceptable for this bundled file, but never use it on untrusted uploads.
with np.load("linear_probe_cifar10.npz", allow_pickle=True) as data:
    W = torch.tensor(data["W"], dtype=torch.float32)
    b = torch.tensor(data["b"], dtype=torch.float32)

CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# predict() scores embeddings with `emb @ W.T + b`, so W must be 2-D with one
# row per class. Fail at startup rather than mislabel or IndexError at request time.
if W.ndim != 2 or W.shape[0] != len(CLASSES):
    raise ValueError(
        f"linear probe W has shape {tuple(W.shape)}; "
        f"expected ({len(CLASSES)}, embedding_dim)"
    )

print("Classifier Loaded Successfully")

# ---------------- TRANSFORM ----------------
# Resize to the 32x32 CIFAR-10 resolution and convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalisation is applied here; confirm this matches
# the transform the encoder was trained with.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
# ---------------- PREDICT ----------------
def predict(image):
    """Classify an uploaded image into one of the CIFAR-10 classes.

    The image is converted to RGB, resized to 32x32 by ``transform``, embedded
    by the SimCLR ``encoder``, L2-normalised per row, and scored by the linear
    probe ``emb @ W.T + b``.

    Args:
        image: numpy array delivered by ``gr.Image(type="numpy")``, or ``None``
            when the user submits without uploading an image.

    Returns:
        dict mapping each name in ``CLASSES`` to its softmax probability.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        # Gradio passes None for an empty input; report it instead of crashing
        # inside Image.fromarray.
        raise gr.Error("Please upload an image.")

    pil_image = Image.fromarray(image).convert("RGB")
    x = transform(pil_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        emb = encoder(x).cpu().float()
        # Match training normalization: unit L2 norm per row. F.normalize clamps
        # the norm at a small eps, so an all-zero embedding yields zeros, not NaN.
        emb = F.normalize(emb, p=2, dim=1)
        logits = emb @ W.T + b
        probs = F.softmax(logits, dim=1)[0]

    # zip() pairs each class name with its probability without hard-coding 10.
    return dict(zip(CLASSES, probs.tolist()))
# ---------------- UI ----------------
# Single-function Gradio app: an image goes in, the top-3 class probabilities
# from predict() come out as a label widget.
image_input = gr.Image(type="numpy")  # predict() receives a numpy array
label_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="CRLF — CIFAR10 SimCLR Demo",
    description=(
        "Upload an image. Model trained WITHOUT labels using SimCLR. "
        "Evaluated using Linear Probe."
    ),
)

# Enable request queuing (HF friendly), then serve on all network interfaces
# at port 7860.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)