md896 commited on
Commit
f178e3e
·
verified ·
1 Parent(s): 24e5d05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -10,6 +10,7 @@ DEVICE = torch.device("cpu")
10
 
11
  # ---------------- LOAD ENCODER ----------------
12
  encoder = load_encoder("encoder_resnet18_simclr.pth")
 
13
 
14
  # ---------------- LOAD LINEAR PROBE ----------------
15
  data = np.load("linear_probe_cifar10.npz", allow_pickle=True)
@@ -33,12 +34,15 @@ transform = transforms.Compose([
33
  # ---------------- PREDICT ----------------
34
  def predict(image):
35
  image = Image.fromarray(image).convert("RGB")
36
- x = transform(image).unsqueeze(0)
37
 
38
  with torch.no_grad():
39
- emb = encoder(x).numpy()
40
 
41
- emb = torch.tensor(emb)
 
 
 
42
 
43
  logits = emb @ W.T + b
44
  probs = F.softmax(logits, dim=1).numpy()[0]
@@ -54,4 +58,5 @@ demo = gr.Interface(
54
  description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe."
55
  )
56
 
57
- demo.launch()
 
 
10
 
11
  # ---------------- LOAD ENCODER ----------------
12
  encoder = load_encoder("encoder_resnet18_simclr.pth")
13
+ encoder.eval()
14
 
15
  # ---------------- LOAD LINEAR PROBE ----------------
16
  data = np.load("linear_probe_cifar10.npz", allow_pickle=True)
 
34
  # ---------------- PREDICT ----------------
35
  def predict(image):
36
  image = Image.fromarray(image).convert("RGB")
37
+ x = transform(image).unsqueeze(0).to(DEVICE)
38
 
39
  with torch.no_grad():
40
+ emb = encoder(x).cpu().numpy()
41
 
42
+ # IMPORTANT — match training behavior
43
+ emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
44
+
45
+ emb = torch.tensor(emb, dtype=torch.float32)
46
 
47
  logits = emb @ W.T + b
48
  probs = F.softmax(logits, dim=1).numpy()[0]
 
58
  description="Upload an image. Model trained WITHOUT labels using SimCLR. Evaluated using Linear Probe."
59
  )
60
 
61
+ demo.queue(concurrency_count=4)
62
+ demo.launch(server_name="0.0.0.0", server_port=7860)