# Hugging Face Space source (the hosting page reported "Build error" when this was captured)
| import gradio as gr | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
# Vocabulary of target words the game can guess from
words = [
    "cat", "car", "tree", "house", "dog", "cloud", "flower",
    "bicycle", "boat", "star", "bird", "fish", "sun",
]

# Load the CLIP processor and model from the same checkpoint
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Embed every word once at start-up so each guess only needs an image pass
text_inputs = processor(text=words, return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)
def guess_drawing(drawing):
    """Guess which entry of ``words`` a sketch depicts, using CLIP.

    Args:
        drawing: Value emitted by the Gradio sketchpad. Expected to be a
            mapping that holds the picture under ``'composite'`` and/or
            ``'background'``; ``None`` (nothing drawn yet) is tolerated.

    Returns:
        ``"AI's guess: <word>"`` for the word whose text embedding is most
        similar to the image embedding, or an error message when the input
        carries no usable image.
    """
    invalid_message = "Invalid drawing format. Unable to process."

    # With live=True the callback can fire before anything has been drawn.
    if drawing is None:
        return invalid_message

    # NOTE(review): per the Gradio ImageEditor docs the strokes end up in
    # 'composite' (all layers flattened) while 'background' is only the
    # underlying canvas/upload, so prefer 'composite' and keep 'background'
    # as a fallback for payloads that lack it -- confirm against the
    # installed Gradio version.
    pixels = None
    for key in ("composite", "background"):
        if key in drawing and drawing[key] is not None:
            pixels = drawing[key]
            break
    if pixels is None:
        return invalid_message

    image = Image.fromarray(np.array(pixels, dtype=np.uint8))

    # Dropping the alpha channel directly would expose whatever colour sits
    # under transparent pixels (often black), hiding dark strokes; flatten
    # onto a white canvas first. Images without alpha are left unchanged.
    if image.mode in ("RGBA", "LA"):
        white = Image.new("RGBA", image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, image.convert("RGBA"))

    # CLIP's processor expects an RGB image
    image = image.convert("RGB")

    # Embed the image with the CLIP vision tower
    image_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)

    # Cosine similarity of the image embedding against every word embedding
    similarity = torch.nn.functional.cosine_similarity(image_features, text_features)

    # Debug: print similarity scores for each word
    for word, score in zip(words, similarity.tolist()):
        print(f"Similarity score for '{word}': {score}")

    best_match = words[similarity.argmax().item()]

    # Return the AI's best guess
    return f"AI's guess: {best_match}"
# Set up the Gradio interface: a live sketchpad whose drawing is sent to
# guess_drawing, with the guess shown as plain text.
interface = gr.Interface(
    fn=guess_drawing,
    inputs=gr.Sketchpad(),
    outputs="text",
    live=True,
    # Build the prompt from `words` so it cannot drift from the vocabulary
    # the model actually scores against.
    description="Draw " + ", ".join(words),
)
interface.launch()