Spaces:
Runtime error
Runtime error
Commit
·
0c49c14
1
Parent(s):
10432c9
Update app.py
Browse files
app.py
CHANGED
|
@@ -55,4 +55,21 @@ def training():
|
|
| 55 |
print("Test accuracy:", score[1])
|
| 56 |
|
| 57 |
push_to_hub_keras(model, "active-learning/mnist_classifier")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
print("Test accuracy:", score[1])
|
| 56 |
|
| 57 |
push_to_hub_keras(model, "active-learning/mnist_classifier")
|
| 58 |
+
|
| 59 |
+
def find_samples_to_label():
|
| 60 |
+
loaded_model = from_pretrained_keras("active-learning/mnist_classifier")
|
| 61 |
+
loaded_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
|
| 62 |
|
| 63 |
+
unlabeled_data = load_dataset("active-learning/unlabeled_samples")["train"]
|
| 64 |
+
processed_data = unlabeled_data.map(to_numpy, batched=True)
|
| 65 |
+
processed_data = processed_data["pixel_values"]
|
| 66 |
+
processed_data = tf.expand_dims(processed_data, -1)
|
| 67 |
+
|
| 68 |
+
# Get all predictions
|
| 69 |
+
# And then get the 5 samples with the lowest prediction score
|
| 70 |
+
preds = loaded_model.predict(unlabeled_data)
|
| 71 |
+
top_pred_confs = 1 - np.max(preds, axis=1)
|
| 72 |
+
idx_to_label = np.argpartition(top_pred_confs, -5)[-5:]
|
| 73 |
+
|
| 74 |
+
to_label_data = unlabeled_data.select(idx_to_label)
|
| 75 |
+
to_label_data.push_to_hub("active-learning/to_label_samples")
|