# Spaces: Sleeping
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import open_clip | |
| import matplotlib.pyplot as plt | |
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint passed as `pretrained` and the architecture name it belongs to.
model_path = "ViT-B-32.pt"
model_name = "ViT-B-32"

# Build the model together with its image transform and tokenizer. The second
# return value of create_model_and_transforms is unused by this app.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name=model_name,
    pretrained=model_path,
)
tokenizer = open_clip.get_tokenizer(model_name)

# Move the weights to the selected device and switch to inference mode.
# The open_clip usage notes state that models come back in training mode, so
# eval() is called explicitly to keep any train-only layers (dropout,
# batch-norm statistics updates, stochastic depth) from affecting predictions.
model.to(device)
model.eval()
def predict_emotion(image, prompts):
    """Score one image against a list of text prompts.

    Uses the module-level ``model``, ``preprocess``, ``tokenizer`` and
    ``device`` objects.

    Args:
        image: Image accepted by the module-level ``preprocess`` transform
            (``main`` passes a PIL image).
        prompts: List of text prompts to compare the image against.

    Returns:
        A NumPy array with one row (the single input image) and one column
        per prompt, holding the softmax over the scaled cosine similarities
        between the image embedding and each prompt embedding.
    """
    # Add a leading batch dimension so the encoder receives a batch of one.
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_input = tokenizer(prompts).to(device)

    # Mixed precision only helps on CUDA. torch.autocast takes the device type
    # explicitly, so the CPU fallback no longer goes through the CUDA-only
    # torch.cuda.amp.autocast() (which warns when no GPU is present).
    use_amp = device.type == "cuda"
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_input)

        # L2-normalise so the dot product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Scale the similarities by 100 before the softmax over prompts.
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    return text_probs.cpu().numpy()
def _plot_probabilities(prompts, probs):
    """Draw a bar chart of ``probs`` labelled with ``prompts`` in the app."""
    fig, ax = plt.subplots(figsize=(8, 6))

    # Place bars by index rather than by label so that two identical prompts
    # still get two separate bars.
    positions = list(range(len(prompts)))
    ax.bar(positions, probs, color='skyblue')
    ax.set_xticks(positions)
    ax.set_xticklabels(prompts)

    ax.set_title('Emotion Probabilities')
    ax.set_xlabel('Prompt')
    ax.set_ylabel('Probability')
    ax.set_ylim(0, 1)  # Probabilities always lie in [0, 1].

    # Hand Streamlit an explicit Figure (not the global pyplot state) and
    # close it afterwards so figures do not pile up across script reruns.
    st.pyplot(fig)
    plt.close(fig)


def main():
    """Run the Streamlit page: upload an image, enter prompts, show scores.

    Nothing beyond the title and uploader is rendered until an image has been
    uploaded. After "Predict" is pressed, the non-blank prompts are scored
    with ``predict_emotion`` and shown both as text and as a bar chart.
    """
    st.title("Zero-Shot OpenAI CLIP Fine-tuned for Emotion analysis")

    uploaded_image = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
    if uploaded_image is None:
        return

    # Display the uploaded image.
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Prompt inputs.
    st.write("Enter four prompts:")
    raw_prompts = [st.text_input(f"Prompt {idx}:") for idx in range(1, 5)]

    if not st.button("Predict"):
        return

    # Blank fields carry no meaning but would still take a share of the
    # softmax, so they are dropped before scoring.
    prompts = [text.strip() for text in raw_prompts if text.strip()]
    if len(prompts) < 2:
        st.warning("Please enter at least two prompts to compare.")
        return

    with st.spinner("Predicting..."):
        probabilities = predict_emotion(image, prompts)

    # predict_emotion returns one row per image; there is exactly one image.
    probs = probabilities[0]

    # Iterate the pairs directly: building a dict keyed by prompt would
    # silently drop rows whenever two prompts are identical.
    st.write("Emotion Probabilities:")
    for prompt, prob in zip(prompts, probs):
        st.write(f"{prompt}: {prob:.5f}")

    _plot_probabilities(prompts, probs)
# Start the app only when this file is executed as a script (for example by
# `streamlit run`), not when it is imported as a module.
if __name__ == "__main__":
    main()