# Spaces: Sleeping  (hosting-page status banner captured together with the source)
import random
from collections import OrderedDict

import gradio as gr
import open_clip
import torch
from datasets import load_dataset
from datasets import load_from_disk
from PIL import Image

# Image dataset read from the local "./train" directory, loaded once at import
# time. The UI code reads rows by integer index and uses their 'image' field.
dataset = load_from_disk("./train")
# Class-index (as a string) -> plural fruit name, in index order.
FRUITS30_CLASSES = OrderedDict(
    {
        "0": "acerolas",
        "1": "apples",
        "2": "apricots",
        "3": "avocados",
        "4": "bananas",
        "5": "blackberries",
        "6": "blueberries",
        "7": "cantaloupes",
        "8": "cherries",
        "9": "coconuts",
        "10": "figs",
        "11": "grapefruits",
        "12": "grapes",
        "13": "guava",
        "14": "kiwifruit",
        "15": "lemons",
        "16": "limes",
        "17": "mangos",
        "18": "olives",
        "19": "oranges",
        "20": "passionfruit",
        "21": "peaches",
        "22": "pears",
        "23": "pineapples",
        "24": "plums",
        "25": "pomegranates",
        "26": "raspberries",
        "27": "strawberries",
        "28": "tomatoes",
        "29": "watermelons",
    }
)

# Fruit names only, in class-index order; used as the candidate captions.
labels = list(FRUITS30_CLASSES.values())
# CLIP ViT-B-32 with the 'laion2b_s34b_b79k' pretrained weights. The second
# element of the returned triple is not used by this app.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k'
)
# model in train mode by default, impacts some models with BatchNorm or
# stochastic depth active
model.eval()
# Tokenizer matching the same architecture name; turns caption strings into
# the tensor passed to model.encode_text.
tokenizer = open_clip.get_tokenizer('ViT-B-32')
def create_interface():
    """Build the fruit-naming Gradio app and launch it.

    The page holds an image, a textbox for the player's guess, "Submit" and
    "Refresh" buttons and an "Answer" textbox. "Refresh" shows a random image
    from the module-level ``dataset``; "Submit" asks the CLIP ``model`` which
    candidate caption matches the image best and displays that caption.

    The function calls ``demo.launch(debug=True)`` itself and returns None.
    """

    def get_image():
        """Return the ``'image'`` field of one randomly chosen dataset row."""
        return dataset[random.randrange(len(dataset))]['image']

    def on_submit(img1, label1):
        """Return the candidate caption the model scores highest for ``img1``.

        Candidates are the module-level ``labels``, the player's own guess
        (only when it is not blank) and the fallback ``"not a fruit"``.
        """
        # Nothing to classify until an image has been loaded; without this
        # guard ``preprocess(None)`` raises.
        if img1 is None:
            return "Please load an image first."

        # A blank guess would add an empty caption that could win the argmax
        # and be shown as the answer, so it is left out.
        guess = (label1 or "").strip()
        candidates = labels + ([guess] if guess else []) + ["not a fruit"]

        image = preprocess(img1).unsqueeze(0)  # add a leading batch axis
        text = tokenizer(candidates)
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            # Unit-normalise both sides so the matrix product below is a
            # cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Single image, so the flattened argmax is the winning candidate index.
        return candidates[text_probs.argmax().item()]

    with gr.Blocks() as demo:
        # NOTE(review): the original nesting was lost; the image and the guess
        # box are assumed to share one row - confirm against the deployed app.
        with gr.Row():
            img1 = gr.Image(type="pil", label="Fruit", height=300, width=300)
            label1 = gr.Textbox(label="Name this fruit")
        submit_btn = gr.Button("Submit")
        refresh_btn = gr.Button("Refresh")
        result = gr.Textbox(label="Answer")

        # Refresh replaces the displayed image with a random dataset image.
        refresh_btn.click(fn=get_image, outputs=[img1])

        # Submit classifies the displayed image and shows the winning caption.
        submit_btn.click(fn=on_submit, inputs=[img1, label1], outputs=result)

    demo.launch(debug=True)
# Run the game: builds the UI and launches it as soon as this module executes.
create_interface()