md896 commited on
Commit
da63f24
·
verified ·
1 Parent(s): 00829d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -2
app.py CHANGED
@@ -8,21 +8,48 @@ from sklearn.linear_model import LogisticRegression
8
 
9
  DEVICE = torch.device("cpu")
10
 
 
11
  encoder = load_encoder("encoder_resnet18_simclr.pth")
12
 
 
13
  data = np.load("linear_probe_cifar10.npz", allow_pickle=True)
14
- clf = data["clf"].item()
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  CLASSES = [
17
  "airplane","automobile","bird","cat","deer",
18
  "dog","frog","horse","ship","truck"
19
  ]
20
 
 
21
  transform = transforms.Compose([
22
  transforms.Resize((32, 32)),
23
  transforms.ToTensor()
24
  ])
25
 
 
26
  def predict(image):
27
  image = Image.fromarray(image).convert("RGB")
28
  x = transform(image).unsqueeze(0).to(DEVICE)
@@ -30,11 +57,12 @@ def predict(image):
30
  with torch.no_grad():
31
  emb = encoder(x).cpu().numpy()
32
 
33
- pred = clf.predict(emb)[0]
34
  probs = clf.predict_proba(emb)[0]
 
35
 
36
  return {CLASSES[i]: float(probs[i]) for i in range(10)}
37
 
 
38
  demo = gr.Interface(
39
  fn=predict,
40
  inputs=gr.Image(type="numpy"),
 
8
 
9
  DEVICE = torch.device("cpu")
10
 
11
+ # ---------------- LOAD ENCODER ----------------
12
  encoder = load_encoder("encoder_resnet18_simclr.pth")
13
 
14
+ # ---------------- LOAD LINEAR PROBE ----------------
15
  data = np.load("linear_probe_cifar10.npz", allow_pickle=True)
16
+ keys = list(data.files)
17
 
18
+ print("Loaded NPZ keys:", keys)
19
+
20
+ clf = None
21
+
22
+ # Case 1: saved full sklearn model
23
+ for k in keys:
24
+ if isinstance(data[k].item(), LogisticRegression):
25
+ clf = data[k].item()
26
+
27
+ # Case 2: saved coef / intercept
28
+ if clf is None and "coef" in keys and "intercept" in keys:
29
+ clf = LogisticRegression()
30
+ clf.coef_ = data["coef"]
31
+ clf.intercept_ = data["intercept"]
32
+ clf.classes_ = np.arange(10)
33
+
34
+ # If still no classifier → FAIL gracefully
35
+ if clf is None:
36
+ raise ValueError(f"❌ Could not find classifier inside NPZ. Found keys: {keys}")
37
+
38
+ print("Classifier Loaded Successfully")
39
+
40
+ # ---------------- CLASSES ----------------
41
  CLASSES = [
42
  "airplane","automobile","bird","cat","deer",
43
  "dog","frog","horse","ship","truck"
44
  ]
45
 
46
+ # ---------------- TRANSFORM ----------------
47
  transform = transforms.Compose([
48
  transforms.Resize((32, 32)),
49
  transforms.ToTensor()
50
  ])
51
 
52
+ # ---------------- PREDICT ----------------
53
  def predict(image):
54
  image = Image.fromarray(image).convert("RGB")
55
  x = transform(image).unsqueeze(0).to(DEVICE)
 
57
  with torch.no_grad():
58
  emb = encoder(x).cpu().numpy()
59
 
 
60
  probs = clf.predict_proba(emb)[0]
61
+ pred = np.argmax(probs)
62
 
63
  return {CLASSES[i]: float(probs[i]) for i in range(10)}
64
 
65
+ # ---------------- UI ----------------
66
  demo = gr.Interface(
67
  fn=predict,
68
  inputs=gr.Image(type="numpy"),