Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from datasets import load_dataset | |
| from transformers import CLIPProcessor, CLIPModel | |
| from PIL import Image | |
| import random | |
# Page header, rendered above the battle controls further down the script.
st.title("Meme Battle AI")
st.write("Stream memes directly and let AI determine the winner!")
@st.cache_resource
def load_model(checkpoint="openai/clip-vit-base-patch32"):
    """Load a CLIP model and its matching processor from the Hugging Face Hub.

    Streamlit re-executes this whole script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``. The weights are downloaded and
    instantiated once per server process, and the same objects are reused on
    every rerun instead of being reloaded each time the button is pressed.

    Args:
        checkpoint: Hub id of the CLIP checkpoint to load. Both the model and
            the processor come from the same checkpoint so that they agree.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    return model, processor


model, processor = load_model()
@st.cache_resource
def load_streamed_dataset():
    """Open the meme/caption dataset as a streaming (iterable) dataset.

    Cached with ``st.cache_resource`` so the Hub lookup happens once per server
    process rather than on every Streamlit rerun. Callers derive new streams
    from it (e.g. via ``shuffle``/``take``), so the shared object is only read.

    Returns:
        The ``train`` split of ``Dhruv-goyal/memes_with_captions`` opened with
        ``streaming=True``.
    """
    return load_dataset("Dhruv-goyal/memes_with_captions", split="train", streaming=True)


dataset = load_streamed_dataset()
def fetch_random_memes(*, buffer_size=1000, source=None):
    """Fetch two random memes from the streamed dataset.

    The stream is shuffled with a fresh seed and its first two examples are
    taken. A streamed shuffle fills a buffer of ``buffer_size`` examples and
    draws from it at random, so the pair comes from (roughly) the first
    ``buffer_size`` examples of the stream. Only the two chosen examples are
    materialized; nothing is fetched and then thrown away.

    Args:
        buffer_size: Shuffle-buffer size. Larger means more variety, but more
            data is read before the first meme is yielded.
        source: Object exposing ``shuffle(seed=..., buffer_size=...)`` and
            ``take(n)``, such as a streaming dataset. Defaults to the
            module-level ``dataset``.

    Returns:
        A ``(meme1, meme2)`` tuple of dataset records taken from two different
        positions in the stream.

    Raises:
        ValueError: If the stream yields fewer than two examples.
    """
    if source is None:
        source = dataset
    # A 32-bit seed drawn from `random` gives a wide range of distinct
    # shuffles while still honouring `random.seed` for reproducibility.
    seed = random.getrandbits(32)
    pair = list(source.shuffle(seed=seed, buffer_size=buffer_size).take(2))
    if len(pair) < 2:
        raise ValueError(
            f"Need at least 2 memes for a battle, but the stream yielded {len(pair)}."
        )
    meme1, meme2 = pair
    return meme1, meme2
def parse_meme(meme):
    """Split a dataset record into the caption to display and its image.

    The caption is the first entry of the record's ``answers`` field. A
    placeholder is used when that field is absent, ``None``, or empty.
    The image is returned exactly as stored under ``image``; the original
    author notes it is already a PIL image.

    Returns:
        A ``(caption, image)`` tuple.
    """
    answers = meme.get("answers")
    if answers:
        caption = answers[0]
    else:
        caption = "No caption available"
    picture = meme["image"]
    return caption, picture
def score_meme(image, caption):
    """Score a meme by how well its caption matches its image under CLIP.

    Args:
        image: The meme image, as returned by ``parse_meme``.
        caption: Caption text to compare against the image.

    Returns:
        CLIP's text-to-image logit for the pair as a Python float (higher means
        a better match), or ``0`` if scoring fails. Failures are reported in
        the UI via ``st.error`` rather than raised, so a battle can still be
        rendered.
    """
    import torch  # only needed for the no-grad context below

    try:
        # CLIP's text encoder has a fixed maximum token length. Truncate long
        # captions instead of letting them fail and silently score 0.
        inputs = processor(
            text=[caption],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Inference only: don't build an autograd graph for the forward pass.
        with torch.no_grad():
            outputs = model(**inputs)
        # One text and one image give a single logit; .item() unwraps it.
        return outputs.logits_per_text.item()
    except Exception as e:
        # Deliberate best-effort: show the problem and fall back to 0.
        st.error(f"Error scoring meme: {e}")
        return 0
if st.button("Start Meme Battle"):
    # Draw two random memes and split each record into (caption, image).
    meme1, meme2 = fetch_random_memes()
    caption1, image1 = parse_meme(meme1)
    caption2, image2 = parse_meme(meme2)

    # Score each meme by image-caption compatibility; the higher score wins.
    score1 = score_meme(image1, caption1)
    score2 = score_meme(image2, caption2)

    # Display Meme 1 and Meme 2 side by side with their scores.
    col1, col2 = st.columns(2)
    for column, label, image, caption, score in (
        (col1, "Meme 1", image1, caption1, score1),
        (col2, "Meme 2", image2, caption2, score2),
    ):
        with column:
            st.write(f"#### {label}")
            st.image(image, caption=f"Caption: {caption}")
            st.write(f"AI Score: {score:.2f}")

    # Determine the winner. The emoji in these messages previously appeared
    # as mis-encoded text ("π" / "π€"); they are spelled out explicitly here.
    if score1 > score2:
        st.write("🏆 **Meme 1 Wins!**")
    elif score2 > score1:
        st.write("🏆 **Meme 2 Wins!**")
    else:
        st.write("🤝 **It's a tie!**")