Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,6 +10,7 @@ DEVICE = torch.device("cpu")
|
|
| 10 |
|
| 11 |
# ---------------- LOAD ENCODER ----------------
|
| 12 |
encoder = load_encoder("encoder_resnet18_simclr.pth")
|
|
|
|
| 13 |
|
| 14 |
# ---------------- LOAD LINEAR PROBE ----------------
|
| 15 |
data = np.load("linear_probe_cifar10.npz", allow_pickle=True)
|
|
@@ -33,12 +34,15 @@ transform = transforms.Compose([
|
|
| 33 |
# ---------------- PREDICT ----------------
|
| 34 |
def predict(image):
|
| 35 |
image = Image.fromarray(image).convert("RGB")
|
| 36 |
-
x = transform(image).unsqueeze(0)
|
| 37 |
|
| 38 |
with torch.no_grad():
|
| 39 |
-
emb = encoder(x).numpy()
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
logits = emb @ W.T + b
|
| 44 |
probs = F.softmax(logits, dim=1).numpy()[0]
|
|
@@ -54,4 +58,5 @@ demo = gr.Interface(
|
|
| 54 |
description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe."
|
| 55 |
)
|
| 56 |
|
| 57 |
-
demo.
|
|
|
|
|
|
| 10 |
|
| 11 |
# ---------------- LOAD ENCODER ----------------
|
| 12 |
encoder = load_encoder("encoder_resnet18_simclr.pth")
|
| 13 |
+
encoder.eval()
|
| 14 |
|
| 15 |
# ---------------- LOAD LINEAR PROBE ----------------
|
| 16 |
data = np.load("linear_probe_cifar10.npz", allow_pickle=True)
|
|
|
|
| 34 |
# ---------------- PREDICT ----------------
|
| 35 |
def predict(image):
|
| 36 |
image = Image.fromarray(image).convert("RGB")
|
| 37 |
+
x = transform(image).unsqueeze(0).to(DEVICE)
|
| 38 |
|
| 39 |
with torch.no_grad():
|
| 40 |
+
emb = encoder(x).cpu().numpy()
|
| 41 |
|
| 42 |
+
# IMPORTANT — match training behavior
|
| 43 |
+
emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
|
| 44 |
+
|
| 45 |
+
emb = torch.tensor(emb, dtype=torch.float32)
|
| 46 |
|
| 47 |
logits = emb @ W.T + b
|
| 48 |
probs = F.softmax(logits, dim=1).numpy()[0]
|
|
|
|
| 58 |
description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe."
|
| 59 |
)
|
| 60 |
|
| 61 |
+
demo.queue(concurrency_count=4)
|
| 62 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|