Spaces:
Build error
Build error
sashavor commited on
Commit ·
367e0cd
1
Parent(s): b444b38
final push for now
Browse files
app.py
CHANGED
|
@@ -32,7 +32,8 @@ def model_classify(radio, im):
|
|
| 32 |
outputs = model(**inputs)
|
| 33 |
logits = outputs.logits
|
| 34 |
predicted_class_idx = logits.argmax(-1).item()
|
| 35 |
-
|
|
|
|
| 36 |
else:
|
| 37 |
return None, None, False
|
| 38 |
|
|
@@ -61,15 +62,14 @@ def check_score(pred, truth, current_score, total_score, has_guessed):
|
|
| 61 |
|
| 62 |
|
| 63 |
|
| 64 |
-
def compare_score(userclass,
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
return "The AI model got it right this time, try again!"
|
| 73 |
|
| 74 |
with gr.Blocks() as demo:
|
| 75 |
user_score = gr.State(0)
|
|
@@ -80,7 +80,9 @@ with gr.Blocks() as demo:
|
|
| 80 |
has_guessed = gr.State(False)
|
| 81 |
|
| 82 |
gr.Markdown("# ImageNet Quiz")
|
| 83 |
-
gr.Markdown("
|
|
|
|
|
|
|
| 84 |
with gr.Row():
|
| 85 |
|
| 86 |
with gr.Column(min_width= 900):
|
|
@@ -89,14 +91,14 @@ with gr.Blocks() as demo:
|
|
| 89 |
with gr.Column():
|
| 90 |
prediction = gr.Label(label="The AI model predicts:")
|
| 91 |
score = gr.Label(label="Your Score")
|
| 92 |
-
message = gr.
|
| 93 |
|
| 94 |
btn = gr.Button("Next image")
|
| 95 |
|
| 96 |
demo.load(random_image, None, [image, image_label, radio, prediction])
|
| 97 |
radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
|
| 98 |
radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
|
| 99 |
-
radio.change(compare_score, [radio,
|
| 100 |
btn.click(random_image, None, [image, image_label, radio, prediction])
|
| 101 |
btn.click(lambda :False, None, has_guessed)
|
| 102 |
|
|
|
|
| 32 |
outputs = model(**inputs)
|
| 33 |
logits = outputs.logits
|
| 34 |
predicted_class_idx = logits.argmax(-1).item()
|
| 35 |
+
modelclass=model.config.id2label[predicted_class_idx]
|
| 36 |
+
return modelclass.split(',')[0], predicted_class_idx, True
|
| 37 |
else:
|
| 38 |
return None, None, False
|
| 39 |
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
|
| 65 |
+
def compare_score(userclass, truth, has_guessed):
|
| 66 |
+
if userclass is None:
|
| 67 |
+
return"Try guessing a category!"
|
| 68 |
+
else:
|
| 69 |
+
if userclass == classes[int(truth)]:
|
| 70 |
+
return "Great! You guessed it right"
|
| 71 |
+
else:
|
| 72 |
+
return "The right answer was " +str(classes[int(truth)])+ "! Try guessing the next image."
|
|
|
|
| 73 |
|
| 74 |
with gr.Blocks() as demo:
|
| 75 |
user_score = gr.State(0)
|
|
|
|
| 80 |
has_guessed = gr.State(False)
|
| 81 |
|
| 82 |
gr.Markdown("# ImageNet Quiz")
|
| 83 |
+
gr.Markdown("### ImageNet is one of the most popular datasets used for training and evaluating AI models.")
|
| 84 |
+
gr.Markdown("### But many of its categories are hard to guess, even for humans.")
|
| 85 |
+
gr.Markdown("#### Try your hand at guessing the category of each image displayed, from the options provided. Compare your answers to that of a neural network trained on the dataset, and see if you can do better!")
|
| 86 |
with gr.Row():
|
| 87 |
|
| 88 |
with gr.Column(min_width= 900):
|
|
|
|
| 91 |
with gr.Column():
|
| 92 |
prediction = gr.Label(label="The AI model predicts:")
|
| 93 |
score = gr.Label(label="Your Score")
|
| 94 |
+
message = gr.Label(label="Did you guess it right?")
|
| 95 |
|
| 96 |
btn = gr.Button("Next image")
|
| 97 |
|
| 98 |
demo.load(random_image, None, [image, image_label, radio, prediction])
|
| 99 |
radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
|
| 100 |
radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
|
| 101 |
+
radio.change(compare_score, [radio, image_label, has_guessed], message)
|
| 102 |
btn.click(random_image, None, [image, image_label, radio, prediction])
|
| 103 |
btn.click(lambda :False, None, has_guessed)
|
| 104 |
|